🤖 AI Summary
This work addresses the limitations of conventional Bayesian neural networks, whose mean-field Gaussian posteriors require O(mn) parameters per weight matrix, ignore structural dependencies among weights, and incur high computational cost. The authors instead parameterize each weight matrix in the low-rank form W = AB^T, which concentrates the posterior on a rank-r manifold and explicitly models shared latent factors and the structural dependencies they induce. The approach is the first to integrate low-rank structure into the PAC-Bayes framework, yielding a generalization bound whose complexity term scales as √(r(m+n)) rather than √(mn); companion loss bounds use the Eckart–Young–Mirsky theorem to decompose the error into an optimization term and a rank-induced bias, and Gaussian complexity bounds for low-rank deterministic networks are adapted to the Bayesian predictive mean. Experiments on MLPs, LSTMs, and Transformers show predictive performance comparable to a 5-member deep ensemble with up to 15× fewer parameters, alongside substantially improved out-of-distribution detection and, often, better calibration.
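To make the parameterization concrete, here is a minimal PyTorch sketch, not the authors' released code, of a low-rank variational linear layer. The class name `LowRankBayesianLinear`, the initialization scales, and the softplus link for the factor standard deviations are illustrative assumptions; the point it demonstrates is that the variational posterior has r(m+n) Gaussian factors over A and B rather than mn over W, and that every sampled weight W = AB^T lies on the rank-r manifold by construction.

```python
# Minimal sketch (illustrative, not the paper's implementation) of a
# low-rank variational layer: mean-field Gaussians over the factors A, B.
import torch
import torch.nn as nn
import torch.nn.functional as F

class LowRankBayesianLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int):
        super().__init__()
        # Variational means and pre-softplus scales for A (m x r) and B (n x r):
        # r(m + n) Gaussian factors instead of m*n for a full-rank posterior.
        self.a_mu = nn.Parameter(0.01 * torch.randn(out_features, rank))
        self.a_rho = nn.Parameter(torch.full((out_features, rank), -5.0))
        self.b_mu = nn.Parameter(0.01 * torch.randn(in_features, rank))
        self.b_rho = nn.Parameter(torch.full((in_features, rank), -5.0))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reparameterized samples of the low-rank factors.
        A = self.a_mu + F.softplus(self.a_rho) * torch.randn_like(self.a_mu)
        B = self.b_mu + F.softplus(self.b_rho) * torch.randn_like(self.b_mu)
        # The sampled weight W = A B^T has rank at most r, so the induced
        # posterior over W is singular w.r.t. the Lebesgue measure on R^{m x n}.
        W = A @ B.t()
        return F.linear(x, W, self.bias)

# Example: each forward pass draws a fresh weight sample.
layer = LowRankBayesianLinear(in_features=784, out_features=256, rank=8)
y = layer(torch.randn(32, 784))
```

Since each forward pass draws a new weight sample, a Monte Carlo predictive mean would average several passes; a KL term over the factor posteriors (omitted here) would complete a standard variational objective.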
📝 Abstract
Bayesian neural networks promise calibrated uncertainty but require $O(mn)$ parameters for standard mean-field Gaussian posteriors. We argue this cost is often unnecessary, particularly when weight matrices exhibit fast singular value decay. By parameterizing weights as $W = AB^{\top}$ with $A \in \mathbb{R}^{m \times r}$, $B \in \mathbb{R}^{n \times r}$, we induce a posterior that is singular with respect to the Lebesgue measure, concentrating on the rank-$r$ manifold. This singularity captures structured weight correlations through shared latent factors, geometrically distinct from the mean-field independence assumption. We derive PAC-Bayes generalization bounds whose complexity term scales as $\sqrt{r(m+n)}$ instead of $\sqrt{mn}$, and prove loss bounds that decompose the error into optimization and rank-induced bias using the Eckart–Young–Mirsky theorem. We further adapt recent Gaussian complexity bounds for low-rank deterministic networks to Bayesian predictive means. Empirically, across MLPs, LSTMs, and Transformers on standard benchmarks, our method achieves predictive performance competitive with 5-member Deep Ensembles while using up to $15\times$ fewer parameters. Furthermore, it substantially improves out-of-distribution (OOD) detection and often improves calibration relative to mean-field and perturbation baselines.
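As a quick back-of-the-envelope check of the claimed scaling (the numbers below are illustrative, not from the paper), consider a square layer with $m = n = 1024$ and rank $r = 16$:

$$
mn = 1024 \times 1024 = 1{,}048{,}576, \qquad r(m+n) = 16 \times 2048 = 32{,}768,
$$

$$
\frac{\sqrt{mn}}{\sqrt{r(m+n)}} = \sqrt{\frac{mn}{r(m+n)}} = \sqrt{32} \approx 5.7,
$$

so the PAC-Bayes complexity term shrinks by roughly $5.7\times$ while the number of Gaussian factors in the posterior drops by $32\times$.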