🤖 AI Summary
To address the scalability limitations of uncertainty quantification (UQ) in Bayesian deep learning for large-scale, high-dimensional settings, this work introduces the first modular stochastic gradient Markov chain Monte Carlo (SG-MCMC) framework designed natively for JAX. The method features a highly decoupled architecture that enables plug-and-play composition and automatic differentiation compatibility across core components, including gradient noise models, preconditioners, and numerical integrators. Leveraging JAX's functional programming paradigm, it achieves efficient cross-device (CPU/GPU/TPU) parallel sampling via just-in-time compilation (jit), vectorization (vmap), automatic differentiation (grad), and multi-device parallel mapping (pmap). Empirical evaluation on multiple Bayesian neural network benchmarks demonstrates a 3–5× speedup in sampling throughput and a 40% reduction in memory consumption compared to PyTorch- and TensorFlow-based SG-MCMC implementations, substantially lowering the barrier to deploying SG-MCMC in trustworthy deep learning systems.
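To make the composability claim concrete, here is a minimal sketch of how an SG-MCMC update (plain SGLD, for illustration) composes with JAX's `grad` and `jit` transforms. All function names, the toy model, and the step size are assumptions for this sketch, not the framework's actual API:

```python
import jax
import jax.numpy as jnp

# Hypothetical example model: Bayesian linear regression with a
# Gaussian likelihood and a standard Gaussian prior on the weights.
def log_posterior(params, x, y):
    pred = x @ params
    log_lik = -0.5 * jnp.sum((pred - y) ** 2)   # Gaussian likelihood term
    log_prior = -0.5 * jnp.sum(params ** 2)     # Gaussian prior term
    return log_lik + log_prior

@jax.jit
def sgld_step(params, key, x, y, step_size=1e-3):
    # Stochastic Gradient Langevin Dynamics: one gradient ascent step on
    # the log posterior plus properly scaled Gaussian injection noise.
    g = jax.grad(log_posterior)(params, x, y)
    noise = jax.random.normal(key, params.shape)
    return params + 0.5 * step_size * g + jnp.sqrt(step_size) * noise

# Toy data and a short sampling loop; vmap over per-chain keys (or pmap
# across devices) would run independent chains in parallel.
key = jax.random.PRNGKey(0)
x = jnp.ones((8, 3))
y = jnp.zeros(8)
params = jnp.zeros(3)
for _ in range(100):
    key, subkey = jax.random.split(key)
    params = sgld_step(params, subkey, x, y)
```

Because the update is a pure function of its inputs, swapping in a preconditioner or a different noise model amounts to replacing one component without touching the rest, which is the kind of decoupling the summary describes.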