🤖 AI Summary
To address high computational latency and poor scalability in state dimension and prediction horizon for real-time model predictive control (MPC) of legged robots, this paper proposes a GPU-accelerated parallel MPC framework. Its core idea is to integrate a parallel associative scan into the iterative linear-quadratic regulator (iLQR) framework and use it to solve the primal-dual Karush–Kuhn–Tucker (KKT) system, which enables parallelization across both time steps and state variables. The framework leverages JAX's automatic differentiation and native GPU support. This allows a single centralized controller to coordinate up to 16 robots with MPC solve times under 25 ms, and it natively supports end-to-end reinforcement learning with the MPC in the loop. Compared with acados and crocoddyl, experiments show runtime improvements of up to 60% for whole-body (WB) MPC and up to 700% for single-rigid-body-dynamics (SRBD) MPC as the prediction horizon varies. The framework also scales well with state dimension and enables large-scale parallel training across multiple environments.
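To illustrate why an associative scan parallelizes over time, here is a minimal JAX sketch. It is not the authors' implementation. It assumes that one part of an iLQR iteration, the forward rollout, reduces to an affine recurrence $x_{k+1} = F_k x_k + c_k$ (for example, closed-loop linearized dynamics with $F_k = A_k + B_k K_k$). Composing affine maps is associative, so `jax.lax.associative_scan` can compute all $N$ states in $O(\log N)$ parallel depth instead of $N$ sequential steps. The full parallel-in-time backward pass on the KKT system is more involved and is not shown. The helper names `affine_compose`, `rollout_parallel`, and `rollout_sequential` are hypothetical.

```python
import jax
import jax.numpy as jnp


def affine_compose(earlier, later):
    """Compose affine maps x -> F x + c: apply `earlier` first, then `later`.

    associative_scan passes batched arrays (leading batch axis), so this
    uses broadcasting matmul/einsum rather than 2-D-only operations.
    """
    F_e, c_e = earlier
    F_l, c_l = later
    F = F_l @ F_e                                       # F_l (F_e x)
    c = jnp.einsum("...ab,...b->...a", F_l, c_e) + c_l  # F_l c_e + c_l
    return F, c


def rollout_parallel(F, c, x0):
    """States x_1..x_N of x_{k+1} = F_k x_k + c_k via a parallel prefix scan."""
    # Prefix k of the scan is the composed affine map from x_0 to x_{k+1}.
    F_cum, c_cum = jax.lax.associative_scan(affine_compose, (F, c))
    return jnp.einsum("kab,b->ka", F_cum, x0) + c_cum


def rollout_sequential(F, c, x0):
    """Reference: the same recurrence evaluated step by step (N sequential steps)."""
    def step(x, fc):
        F_k, c_k = fc
        x_next = F_k @ x + c_k
        return x_next, x_next

    _, xs = jax.lax.scan(step, x0, (F, c))
    return xs


# Example with random (assumed) stable dynamics: horizon N = 32, state dim n = 4.
key = jax.random.PRNGKey(0)
kF, kc, kx = jax.random.split(key, 3)
N, n = 32, 4
F = 0.5 * jnp.eye(n) + 0.1 * jax.random.normal(kF, (N, n, n))
c = jax.random.normal(kc, (N, n))
x0 = jax.random.normal(kx, (n,))
xs = rollout_parallel(F, c, x0)  # shape (N, n)
```

The scan performs more total work than the sequential loop (matrix-matrix products instead of matrix-vector products), but its depth grows only logarithmically in the horizon. On a GPU, that trade favors long horizons.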
📝 Abstract
This paper introduces a novel Model Predictive Control (MPC) implementation for legged robot locomotion that leverages GPU parallelization. Our approach enables both temporal and state-space parallelization by incorporating a parallel associative scan to solve the primal-dual Karush–Kuhn–Tucker (KKT) system. As a result, the optimal control problem is solved in $\mathcal{O}(n \log N + m)$ complexity instead of $\mathcal{O}(N(n + m)^3)$, where $n$, $m$, and $N$ are the dimensions of the state and control vectors and the length of the prediction horizon, respectively. We demonstrate the advantages of this implementation over two state-of-the-art solvers (acados and crocoddyl), achieving up to a 60% improvement in runtime for Whole Body Dynamics (WB)-MPC and a 700% improvement for Single Rigid Body Dynamics (SRBD)-MPC when varying the prediction horizon length. The formulation also scales efficiently with the state dimension, enabling a centralized controller for up to 16 legged robots that can be computed in less than 25 ms. Furthermore, thanks to the JAX implementation, the solver supports large-scale parallelization across multiple environments, making it possible to perform learning with the MPC in the loop directly on the GPU.
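The abstract attributes multi-environment parallelization to the JAX implementation. The sketch below shows only that batching mechanism: a per-environment solve is written once and then batched with `jax.vmap`. The solver is deliberately simple and is not the paper's primal-dual method. It is an unconstrained finite-horizon LQR (a backward Riccati recursion that is sequential in time, with terminal cost $P_N = Q$ assumed) on a hypothetical double-integrator model. `lqr_solve`, the model matrices, and the batch of 1024 environments are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def lqr_solve(A, B, Q, R, x0, N):
    """Finite-horizon discrete-time LQR for one environment (illustrative toy).

    Backward Riccati recursion for the feedback gains, then a forward rollout.
    Returns the state trajectory x_0..x_N with shape (N + 1, n).
    """
    def backward(P, _):
        # K = (R + B^T P B)^{-1} B^T P A;  P <- Q + A^T P (A - B K)
        K = jnp.linalg.solve(R + B.T @ P @ B, B.T @ P @ A)
        P_new = Q + A.T @ P @ (A - B @ K)
        return P_new, K

    # reverse=True runs from k = N-1 down to 0 but stores gains in time order.
    _, Ks = jax.lax.scan(backward, Q, None, length=N, reverse=True)

    def forward(x, K):
        x_next = (A - B @ K) @ x  # closed loop with u_k = -K_k x_k
        return x_next, x_next

    _, xs = jax.lax.scan(forward, x0, Ks)
    return jnp.concatenate([x0[None], xs], axis=0)


# Hypothetical double-integrator model with dt = 0.1 s, shared by all environments.
dt = 0.1
A = jnp.array([[1.0, dt], [0.0, 1.0]])
B = jnp.array([[0.0], [dt]])
Q = jnp.eye(2)
R = 0.1 * jnp.eye(1)
N = 50

# vmap maps only over the initial state. The model and horizon are closed over.
batched_solve = jax.jit(jax.vmap(lambda x0: lqr_solve(A, B, Q, R, x0, N)))

x0s = jax.random.normal(jax.random.PRNGKey(0), (1024, 2))
trajs = batched_solve(x0s)  # shape (1024, N + 1, 2)
```

Because the model matrices are closed over rather than mapped, `vmap` leaves the Riccati gains unbatched and batches only the rollouts. Passing `A` and `B` as mapped arguments instead would give each environment its own dynamics.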