🤖 AI Summary
This work addresses the low computational efficiency and lack of hardware acceleration support for end-to-end differentiable linear programming (LP) in machine learning. We propose the first high-performance, JAX-based differentiable LP solver. Methodologically, we introduce the first JAX implementations of the restarted average primal–dual hybrid gradient (rAPDHG) and the reflected restarted Halpern primal–dual hybrid gradient (rH-PDHG), enabling batched computation, automatic differentiation, and seamless GPU/TPU parallelization. Our contributions are threefold: (1) a fully differentiable, device-agnostic LP solver; (2) substantial improvements in convergence speed and throughput, outperforming state-of-the-art LP solvers across multiple benchmark suites; and (3) an open-source, plug-and-play implementation that facilitates joint modeling of machine learning and optimization.
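The batching and differentiation claims above come for free from JAX's transformations once the solver is written as a pure function. As a minimal sketch (not MPAX's actual API), the pattern can be illustrated with a hypothetical `solve` stand-in: a few unrolled projected-gradient steps on a nonnegative least-squares problem, which `jax.vmap` then batches and `jax.grad` differentiates through.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an LP solve: unrolled projected-gradient steps on
# min_{x >= 0} 0.5 * ||A @ x - b||^2. Because the loop is a pure JAX function,
# vmap and grad compose with it directly.
def solve(A, b, steps=200, lr=0.1):
    x = jnp.zeros(A.shape[1])
    for _ in range(steps):
        x = jnp.maximum(x - lr * A.T @ (A @ x - b), 0.0)  # gradient step + projection
    return x

# Batched solving: map the solver over a leading batch axis of the right-hand sides.
batched_solve = jax.vmap(solve, in_axes=(None, 0))
B = jnp.array([[1.0, 2.0], [3.0, 4.0]])
X = batched_solve(jnp.eye(2), B)  # one solution per row of B

# Auto-differentiation: gradient of a downstream loss w.r.t. problem data b,
# obtained by differentiating through the unrolled solver iterations.
def loss(b, A):
    return jnp.sum(solve(A, b))

grad_b = jax.grad(loss)(jnp.array([1.0, 2.0]), jnp.eye(2))
```

The same pure-function structure is what lets JAX lower the computation to GPU/TPU via `jit` without code changes.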
📝 Abstract
This paper presents MPAX (Mathematical Programming in JAX), a versatile and efficient toolbox for integrating linear programming (LP) into machine learning workflows. MPAX implements two state-of-the-art first-order methods, the restarted average primal–dual hybrid gradient and the reflected restarted Halpern primal–dual hybrid gradient, to solve LPs in JAX. This provides native support for hardware acceleration along with features like batch solving, auto-differentiation, and device parallelism. Extensive numerical experiments demonstrate the advantages of MPAX over existing solvers. The solver is available at https://github.com/MIT-Lu-Lab/MPAX.
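To make the underlying method concrete, here is a minimal sketch of a vanilla primal–dual hybrid gradient (PDHG) iteration for the standard-form LP min c·x s.t. Ax = b, x ≥ 0; the restarted/Halpern variants that MPAX implements build on this base scheme. The function and problem data below are illustrative, not MPAX's API; the step sizes are assumed to satisfy the usual condition τσ‖A‖² < 1.

```python
import jax
import jax.numpy as jnp

def pdhg_lp(c, A, b, tau=0.1, sigma=0.1, iters=5000):
    """Vanilla PDHG for min c @ x s.t. A @ x = b, x >= 0 (illustrative sketch)."""
    def step(carry, _):
        x, y = carry
        # Primal step: gradient of the Lagrangian, then projection onto x >= 0.
        x_new = jnp.maximum(x - tau * (c - A.T @ y), 0.0)
        # Dual step with the standard primal extrapolation 2*x_new - x.
        y_new = y + sigma * (b - A @ (2.0 * x_new - x))
        return (x_new, y_new), None

    init = (jnp.zeros(c.shape), jnp.zeros(b.shape))
    (x, y), _ = jax.lax.scan(step, init, None, length=iters)
    return x, y

# Tiny example: min x1 + 2*x2  s.t.  x1 + x2 = 1, x >= 0; optimum is x = (1, 0).
c = jnp.array([1.0, 2.0])
A = jnp.array([[1.0, 1.0]])
b = jnp.array([1.0])
x, y = pdhg_lp(c, A, b)
```

Writing the iteration as a `jax.lax.scan` keeps the whole solve inside a single compiled computation, which is what enables the GPU/TPU execution the abstract refers to.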