🤖 AI Summary
This work addresses a gap in the JAX ecosystem: no existing sparse linear solver combines GPU-accelerated algebraic multigrid (AMG), automatic differentiation, and distributed multi-GPU execution. The authors integrate NVIDIA AmgX natively as a JAX primitive, exposing AMG and Krylov subspace methods through a single interface within JAX's computational framework. The integration supports just-in-time (JIT) compilation, reverse-mode automatic differentiation via adjoint methods, batched solves, and MPI-based distributed execution across multiple GPUs. A solver caching mechanism amortizes the setup cost across repeated solves. The resulting high-performance, scalable sparse linear algebra layer connects differentiable simulation with established scientific-computing tooling and fits directly into scientific machine learning pipelines, particularly for PDE-constrained optimization and inverse problems.
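Solver caching pays off because AMG setup (building the coarse-grid hierarchy) is usually far more expensive than a single solve, so repeated solves with the same matrix should reuse it. The summary does not describe how JAX-AMG implements its cache, so the sketch below is a hypothetical illustration only: the class `CachedSolver` and its hashing scheme are assumptions, and a dense LU factorization (`jax.scipy.linalg.lu_factor` / `lu_solve`) stands in for the AMG setup and solve phases.

```python
import hashlib

import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve


class CachedSolver:
    """Hypothetical solver cache: pay the setup cost once per distinct matrix.

    A dense LU factorization stands in for the expensive setup phase
    (e.g. building an AMG hierarchy); every later solve reuses it.
    """

    def __init__(self):
        self._setups = {}
        self.setup_count = 0  # how many times the expensive setup actually ran

    @staticmethod
    def _key(A):
        # Hash the concrete matrix so identical matrices map to the same entry.
        # This needs concrete (non-traced) arrays, i.e. it runs outside jax.jit.
        a = np.asarray(A)
        return (a.shape, a.dtype.str, hashlib.sha256(a.tobytes()).hexdigest())

    def solve(self, A, b):
        key = self._key(A)
        if key not in self._setups:
            self._setups[key] = lu_factor(A)  # expensive: done once per matrix
            self.setup_count += 1
        return lu_solve(self._setups[key], b)  # cheap: done per right-hand side


solver = CachedSolver()
A = jnp.array([[4.0, 1.0, 0.0], [1.0, 5.0, 2.0], [0.0, 1.0, 3.0]])
for rhs in (jnp.array([1.0, 2.0, 3.0]), jnp.array([0.0, 1.0, 0.0])):
    x = solver.solve(A, rhs)  # the second call reuses the cached factorization
print(solver.setup_count)  # 1
```

Keying on a hash of the full matrix keeps the sketch simple but means any change in the values triggers a complete new setup; a production cache would also need an eviction policy and a decision about whether a values-only change can reuse part of the previous setup. The sketch addresses neither.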
📝 Abstract
Sparse linear systems from PDE discretizations are central to scientific computing, yet no existing JAX-ecosystem solver simultaneously provides GPU-accelerated algebraic multigrid (AMG), automatic differentiation (AD), and distributed multi-GPU execution. JAX-AMG fills this gap by wrapping the Nvidia AmgX solver suite as a native JAX primitive, exposing AMG and Krylov methods with configurable preconditioners through a unified interface compatible with JIT compilation, reverse-mode AD via adjoint methods, batched solves, and MPI-based distributed execution. Solver caching amortizes setup costs across repeated solves, making JAX-AMG practical for PDE-constrained optimization and inverse problems. The result is a robust, scalable sparse linear algebra layer that integrates seamlessly into differentiable simulation and scientific machine learning pipelines.
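The adjoint method makes reverse-mode AD through a linear solve cheap. For $x = A^{-1} b$ and a scalar loss $L(x)$ with $\bar{x} = \partial L / \partial x$, one extra transposed solve $A^{\top} \lambda = \bar{x}$ gives $\partial L / \partial b = \lambda$ and $\partial L / \partial A_{ij} = -\lambda_i x_j$, so the solver is never differentiated iteration by iteration. The abstract does not show JAX-AMG's interface, so the following is a minimal sketch under stated assumptions: `make_sparse_solve` is a hypothetical helper, the matrix is stored in COO form with a fixed sparsity pattern, and `jax.scipy.sparse.linalg.bicgstab` stands in for the AmgX solve. The adjoint rule is attached with `jax.custom_vjp` and checked against differentiation through a dense direct solve.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab

jax.config.update("jax_enable_x64", True)  # float64 so the tight Krylov tolerance is reachable


def coo_matvec(values, rows, cols, x, n):
    # y[i] = sum over entries k with rows[k] == i of values[k] * x[cols[k]]
    return jax.ops.segment_sum(values * x[cols], rows, num_segments=n)


def make_sparse_solve(rows, cols, n, tol=1e-10):
    """Hypothetical helper: differentiable solve of A(values) x = b for a fixed COO pattern.

    BiCGSTAB is only a stand-in for the AmgX solve; the adjoint rule does not depend on it.
    """

    def _krylov_solve(values, rhs, transpose=False):
        # Swapping rows and cols applies A^T instead of A.
        r, c = (cols, rows) if transpose else (rows, cols)
        x, _ = bicgstab(lambda v: coo_matvec(values, r, c, v, n), rhs, tol=tol)
        return x

    @jax.custom_vjp
    def solve(values, b):
        return _krylov_solve(values, b)

    def solve_fwd(values, b):
        x = _krylov_solve(values, b)
        return x, (values, x)  # save the solution for the backward pass

    def solve_bwd(residuals, x_bar):
        values, x = residuals
        lam = _krylov_solve(values, x_bar, transpose=True)  # adjoint solve: A^T lam = x_bar
        values_bar = -lam[rows] * x[cols]  # dL/dA_ij = -lam_i * x_j, restricted to the pattern
        return values_bar, lam  # dL/db = lam

    solve.defvjp(solve_fwd, solve_bwd)
    return solve


# Usage: a small nonsymmetric 3x3 system stored in COO format.
rows = jnp.array([0, 0, 1, 1, 1, 2, 2])
cols = jnp.array([0, 1, 0, 1, 2, 1, 2])
values = jnp.array([4.0, 1.0, 1.0, 5.0, 2.0, 1.0, 3.0])
b = jnp.array([1.0, 2.0, 3.0])
solve = make_sparse_solve(rows, cols, n=3)


def loss(values, b):
    return jnp.sum(solve(values, b) ** 2)


def dense_loss(values, b):
    # Reference: differentiate straight through a dense direct solve.
    A = jnp.zeros((3, 3)).at[rows, cols].add(values)
    return jnp.sum(jnp.linalg.solve(A, b) ** 2)


grads = jax.jit(jax.grad(loss, argnums=(0, 1)))(values, b)
ref_grads = jax.grad(dense_loss, argnums=(0, 1))(values, b)
```

Closing over the sparsity pattern (`rows`, `cols`) keeps it a compile-time constant, so only the nonzero values and the right-hand side are differentiated, and the whole gradient can be wrapped in `jax.jit`.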