🤖 AI Summary
Existing adaptive-step probabilistic ODE solvers suffer from unbounded memory growth during long-horizon simulations: their storage needs depend on the unpredictable number of accepted steps, so they can crash with out-of-memory errors and are difficult to reconcile with full-trajectory probabilistic inference. This work introduces the first adaptive probabilistic ODE solver with strictly bounded memory consumption. Building on a robust state-estimation framework, it propagates the solution with fixed, statically known memory demands while retaining adaptive step-size selection under a hard memory budget. Because memory use is known ahead of time, the solver becomes compatible with JAX's just-in-time compilation and, more broadly, with scientific computing in JAX. Experiments demonstrate that the method eliminates memory-overflow failures, accelerates simulations by orders of magnitude, and fully preserves the probabilistic outputs, including calibrated uncertainty quantification and complete trajectory inference, without compromising accuracy or adaptivity.
📝 Abstract
Despite substantial progress in recent years, probabilistic solvers with adaptive step sizes still cannot solve memory-demanding differential equations -- unless we care only about a single point in time (which is far too restrictive; we want the whole time series). Counterintuitively, the culprit is the adaptivity itself: its unpredictable memory demands easily exceed our machine's capabilities, making our simulations fail unexpectedly and without warning. Still, dropping adaptivity would abandon years of progress, which can't be the answer. In this work, we solve this conundrum. We develop an adaptive probabilistic solver with fixed memory demands, building on recent developments in robust state estimation. Switching to our method (i) eliminates memory issues for long time series, (ii) accelerates simulations by orders of magnitude by unlocking just-in-time compilation, and (iii) makes adaptive probabilistic solvers compatible with scientific computing in JAX.