JAX-AMG: A GPU-Accelerated Differentiable Sparse Linear Solver Library for JAX

📅 2026-06-07
🤖 AI Summary
This work addresses a critical gap in the JAX ecosystem: the absence of a sparse linear solver that supports GPU-accelerated algebraic multigrid (AMG), automatic differentiation, and distributed multi-GPU execution. The authors present the first native integration of NVIDIA AmgX as a JAX primitive, unifying AMG with Krylov subspace methods within JAX’s computational framework. This integration enables just-in-time (JIT) compilation, reverse-mode automatic differentiation, batching, and MPI-based distributed execution across multiple GPUs. To enhance efficiency, the implementation incorporates a solver caching mechanism that mitigates repeated setup overhead. By delivering a high-performance, scalable sparse linear algebra layer, this work bridges a key tooling gap between differentiable simulation and scientific computing, enabling seamless incorporation into scientific machine learning pipelines—particularly for PDE-constrained optimization and inverse problems.
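The caching idea can be illustrated in plain JAX. This is a minimal sketch, not JAX-AMG's API. The names `cached_solve`, `_setup`, `_SETUP_CACHE`, and the string key are hypothetical. A Jacobi (inverse-diagonal) preconditioner stands in for the AmgX AMG hierarchy, whose construction is the expensive setup step, and `jax.scipy.sparse.linalg.cg` stands in for the AmgX Krylov solver. The expensive part is built once per matrix and then reused for every right-hand side.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Hypothetical setup cache: user-chosen matrix key -> preconditioner.
# In JAX-AMG the costly setup would be the AmgX AMG hierarchy; here a cheap
# Jacobi (inverse-diagonal) preconditioner stands in for it.
_SETUP_CACHE = {}
SETUP_CALLS = {"count": 0}  # counts how often setup actually runs


def _setup(A):
    """Build a Jacobi preconditioner M ~ A^{-1} (stand-in for AMG setup)."""
    SETUP_CALLS["count"] += 1
    inv_diag = 1.0 / jnp.diag(A)
    return lambda r: inv_diag * r


def cached_solve(key, A, b, tol=1e-5):
    """Preconditioned CG solve of A x = b, reusing the setup stored under `key`."""
    if key not in _SETUP_CACHE:
        _SETUP_CACHE[key] = _setup(A)
    x, _ = cg(A, b, tol=tol, M=_SETUP_CACHE[key])
    return x


# Usage: one 1D Poisson matrix, several right-hand sides.
n = 8
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b1 = jnp.ones(n)
b2 = jnp.arange(n, dtype=jnp.float32)
x1 = cached_solve("poisson-1d", A, b1)  # runs setup
x2 = cached_solve("poisson-1d", A, b2)  # reuses the cached setup
```

Keying the cache on an explicit identifier rather than on the array itself is a choice made for this sketch: JAX arrays are not hashable, and inside `jit` they are tracers. The trade-off is that the caller must use a new key (or clear the cache) whenever the matrix values change.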
📝 Abstract
Sparse linear systems from PDE discretizations are central to scientific computing, yet no existing JAX-ecosystem solver simultaneously provides GPU-accelerated algebraic multigrid (AMG), automatic differentiation (AD), and distributed multi-GPU execution. JAX-AMG fills this gap by wrapping the NVIDIA AmgX solver suite as a native JAX primitive, exposing AMG and Krylov methods with configurable preconditioners through a unified interface compatible with JIT compilation, reverse-mode AD via adjoint methods, batched solves, and MPI-based distributed execution. Solver caching amortizes setup costs across repeated solves, making JAX-AMG practical for PDE-constrained optimization and inverse problems. The result is a robust, scalable sparse linear algebra layer that integrates seamlessly into differentiable simulation and scientific machine learning pipelines.
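Reverse-mode AD via the adjoint method means the backward pass of x = A⁻¹b costs one extra linear solve with Aᵀ instead of differentiating through the solver's iterations. For a scalar loss L, solving Aᵀλ = ∂L/∂x gives ∂L/∂b = λ and ∂L/∂A_ij = -λ_i x_j, evaluated only on the stored nonzeros. The following minimal sketch expresses this with `jax.custom_vjp`, assuming a symmetric positive definite matrix in COO form with a fixed sparsity pattern. `make_sparse_solve` is a hypothetical name, and CG from `jax.scipy.sparse.linalg` stands in for the AmgX backend. JAX's own `cg` is already differentiable; the explicit `custom_vjp` only makes the adjoint step visible. The same function also works under `jit` and `vmap` (batched right-hand sides).

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

jax.config.update("jax_enable_x64", True)  # float64 so the gradient check is tight


def make_sparse_solve(rows, cols, n, tol=1e-10):
    """Differentiable solve of A x = b for a COO matrix A[rows[k], cols[k]] += data[k].

    CG is a stand-in for an AmgX-backed solve and assumes A is symmetric positive
    definite; a nonsymmetric A would need a method such as GMRES or BiCGStab.
    """

    def matvec(data, v):   # y = A v
        return jnp.zeros(n, v.dtype).at[rows].add(data * v[cols])

    def rmatvec(data, v):  # y = A^T v
        return jnp.zeros(n, v.dtype).at[cols].add(data * v[rows])

    @jax.custom_vjp
    def solve(data, b):
        x, _ = cg(lambda v: matvec(data, v), b, tol=tol)
        return x

    def solve_fwd(data, b):
        x = solve(data, b)
        return x, (data, x)  # keep the solution for the backward pass

    def solve_bwd(res, x_bar):
        data, x = res
        # Adjoint solve: A^T lam = dL/dx (one extra linear solve per gradient).
        lam, _ = cg(lambda v: rmatvec(data, v), x_bar, tol=tol)
        # dL/dA_ij = -lam_i x_j on the stored nonzeros; dL/db = lam.
        return -lam[rows] * x[cols], lam

    solve.defvjp(solve_fwd, solve_bwd)
    return solve


# Usage: 1D Poisson matrix (2 on the diagonal, -1 off it) in COO form.
n = 8
idx = jnp.arange(n)
rows = jnp.concatenate([idx, idx[:-1], idx[1:]])
cols = jnp.concatenate([idx, idx[1:], idx[:-1]])
data = jnp.concatenate([2.0 * jnp.ones(n), -jnp.ones(n - 1), -jnp.ones(n - 1)])
b = jnp.ones(n)
solve = make_sparse_solve(rows, cols, n)


def loss(data, b):
    return 0.5 * jnp.sum(solve(data, b) ** 2)


grad_data, grad_b = jax.jit(jax.grad(loss, argnums=(0, 1)))(data, b)


def dense_loss(data, b):
    # Reference: differentiate straight through a dense direct solve.
    A = jnp.zeros((n, n)).at[rows, cols].add(data)
    return 0.5 * jnp.sum(jnp.linalg.solve(A, b) ** 2)


ref_data, ref_b = jax.grad(dense_loss, argnums=(0, 1))(data, b)

# Batched solves: vmap over right-hand sides while sharing the matrix.
B = jnp.stack([b, idx.astype(b.dtype), b - idx])
xs = jax.vmap(solve, in_axes=(None, 0))(data, B)
```

Because the backward pass only needs the forward solution and one transposed solve, its cost does not grow with the number of Krylov iterations, which is what makes the adjoint approach attractive for PDE-constrained optimization.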
Problem

Research questions and friction points this paper is trying to address.

sparse linear solver
algebraic multigrid
automatic differentiation
GPU acceleration
JAX
Innovation

Methods, ideas, or system contributions that make the work stand out.

differentiable solver
algebraic multigrid
GPU acceleration
JAX
sparse linear systems