JaxSGMC: Modular stochastic gradient MCMC in JAX

📅 2024-05-01
🏛️ SoftwareX
📈 Citations: 2
Influential: 0
🤖 AI Summary
To address the scalability limits of uncertainty quantification (UQ) in Bayesian deep learning for large-scale, high-dimensional settings, this work introduces the first modular stochastic gradient Markov chain Monte Carlo (SG-MCMC) framework designed natively for JAX. The library's highly decoupled architecture enables plug-and-play composition of core components, including gradient noise models, preconditioners, and numerical integrators, while preserving compatibility with automatic differentiation throughout. Leveraging JAX's functional programming paradigm, it achieves efficient cross-device (CPU/GPU/TPU) parallel sampling via just-in-time compilation (jit), vectorization (vmap), automatic differentiation (grad), and multi-device parallel mapping (pmap). Empirical evaluation on multiple Bayesian neural network benchmarks shows a 3–5× speedup in sampling throughput and a 40% reduction in memory consumption relative to PyTorch- and TensorFlow-based SG-MCMC implementations, substantially lowering the barrier to deploying SG-MCMC in trustworthy deep learning systems.
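As a rough illustration of the kind of sampler the summary describes, a stochastic gradient Langevin dynamics (SGLD) step can be written as a pure function and compiled with `jax.jit`. This is a hypothetical sketch, not the JaxSGMC API; `log_post`, `sgld_step`, and the toy standard-normal target are our own illustrative names:

```python
import jax
import jax.numpy as jnp

# Toy log-posterior: a standard normal, so grad log p(theta) = -theta.
# A real model would evaluate a minibatch likelihood here.
def log_post(theta, minibatch):
    return -0.5 * jnp.sum(theta ** 2)

grad_log_post = jax.grad(log_post)

@jax.jit
def sgld_step(key, theta, minibatch, step_size=1e-2):
    # SGLD update: theta' = theta + (eps/2) * grad log p + sqrt(eps) * N(0, I)
    g = grad_log_post(theta, minibatch)
    noise = jax.random.normal(key, theta.shape)
    return theta + 0.5 * step_size * g + jnp.sqrt(step_size) * noise

# Run a short chain on the toy target.
key = jax.random.PRNGKey(0)
theta = jnp.zeros(3)
for _ in range(1000):
    key, subkey = jax.random.split(key)
    theta = sgld_step(subkey, theta, None)
```

Because the step is a pure function of its inputs, the same kernel can be vectorized over chains with `vmap` or spread across devices with `pmap` without modification.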

Problem

Research questions and friction points this paper is trying to address.

Develops JaxSGMC for scalable Bayesian deep learning
Reduces barriers to adopting SG-MCMC sampling methods
Enables modular custom sampler design for SG-MCMC
Innovation

Methods, ideas, or system contributions that make the work stand out.

Modular SG-MCMC library in JAX
State-of-the-art SG-MCMC samplers
Custom samplers from standard blocks
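The modular design these bullets point to can be sketched by wiring a sampler together from interchangeable blocks. The block names below (`rmsprop_preconditioner`, `langevin_integrator`, `make_sampler`) are hypothetical illustrations under our own assumptions, not JaxSGMC's actual API:

```python
import jax
import jax.numpy as jnp

def rmsprop_preconditioner(state, grad, alpha=0.99, eps=1e-4):
    # Adaptive diagonal preconditioner in the spirit of pSGLD.
    v = alpha * state + (1.0 - alpha) * grad ** 2
    return v, 1.0 / (jnp.sqrt(v) + eps)

def langevin_integrator(key, theta, grad, precond, step_size):
    # Euler-Maruyama discretization of preconditioned Langevin dynamics.
    noise = jax.random.normal(key, theta.shape)
    return (theta + 0.5 * step_size * precond * grad
            + jnp.sqrt(step_size * precond) * noise)

def make_sampler(grad_fn, preconditioner, integrator, step_size):
    # Compose arbitrary blocks into one jit-compiled transition kernel.
    @jax.jit
    def step(key, theta, pre_state, batch):
        g = grad_fn(theta, batch)
        pre_state, precond = preconditioner(pre_state, g)
        return integrator(key, theta, g, precond, step_size), pre_state
    return step

# Swapping any block yields a different sampler without touching the rest.
grad_fn = lambda theta, batch: -theta  # toy target: standard normal
sampler = make_sampler(grad_fn, rmsprop_preconditioner,
                       langevin_integrator, 1e-2)
key, theta, v = jax.random.PRNGKey(1), jnp.zeros(2), jnp.ones(2)
for _ in range(200):
    key, subkey = jax.random.split(key)
    theta, v = sampler(subkey, theta, v, None)
```

Replacing `rmsprop_preconditioner` with an identity preconditioner recovers plain SGLD, which is the kind of plug-and-play reuse the Innovation bullets describe.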