🤖 AI Summary
To address the large parameter counts, low sample efficiency, and high computational cost of model-based reinforcement learning (MBRL), this paper introduces a world model framework built on the Mamba state-space model. The authors leverage Mamba's linear-time complexity and its ability to model long-range dependencies. They design an adaptive trajectory sampling mechanism that mitigates the suboptimality caused by an inaccurate world model early in training, and they combine it with a lightweight training architecture and enhanced model predictive control (MPC). With a world model of only 7 million trainable parameters, the approach achieves normalized scores comparable to leading state-of-the-art MBRL methods on standard benchmarks. The world model can be trained on a conventional laptop, which drastically lowers hardware requirements. The core contribution is bringing Mamba into world model construction, reconciling high sample efficiency with low computational overhead.
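To illustrate why an SSM layer is linear in sequence length, here is a minimal single-channel sketch of a Mamba-style selective scan in pure Python. It is not the Drama implementation and not the API of any Mamba library. The parameter names (`w_delta`, `b_delta`, `w_b`, `w_c`) are hypothetical. The sketch assumes a diagonal state matrix A, zero-order-hold discretization of A, and the common simplification B̄ ≈ Δ·B. The step size Δ, B and C all depend on the current input, which is what makes the scan "selective".

```python
import math
from typing import List, Sequence


def softplus(z: float) -> float:
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def selective_scan(
    x: Sequence[float],
    a: Sequence[float],
    w_delta: float,
    b_delta: float,
    w_b: Sequence[float],
    w_c: Sequence[float],
) -> List[float]:
    """Single-channel selective SSM scan with a diagonal state matrix A.

    x                : input sequence of length n
    a                : diagonal of A; negative entries give a decaying, stable state
    w_delta, b_delta : hypothetical weights of the input-dependent step size
    w_b, w_c         : hypothetical projections giving input-dependent B_t and C_t

    One pass costs O(n * N) time for state size N = len(a), and only an
    N-dimensional state is carried from one step to the next.
    """
    if not (len(a) == len(w_b) == len(w_c)):
        raise ValueError("a, w_b and w_c must share the state dimension")
    h = [0.0] * len(a)
    ys: List[float] = []
    for x_t in x:
        # Step size depends on the current input ("selective" parameters).
        delta = softplus(w_delta * x_t + b_delta)
        y_t = 0.0
        for i, a_i in enumerate(a):
            a_bar = math.exp(delta * a_i)  # zero-order-hold discretization of A
            b_t = w_b[i] * x_t             # input-dependent B_t
            c_t = w_c[i] * x_t             # input-dependent C_t
            # Simplified discretization of B: B_bar ~= delta * B_t.
            h[i] = a_bar * h[i] + delta * b_t * x_t
            y_t += c_t * h[i]              # y_t = C_t . h_t
        ys.append(y_t)
    return ys


if __name__ == "__main__":
    outputs = selective_scan(
        [1.0, 0.5, -0.3, 2.0],
        a=[-1.0, -0.5], w_delta=1.0, b_delta=0.0, w_b=[1.0, 0.5], w_c=[0.2, 1.0],
    )
    print(len(outputs), outputs)
```

Changing a later input leaves all earlier outputs unchanged. This causal, recurrent structure is what a world model relies on when it rolls out predictions step by step. The sequential loop is only for readability. Mamba's own kernels compute the same recurrence with a hardware-aware parallel scan.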
📝 Abstract
Model-based reinforcement learning (RL) offers a solution to the data inefficiency that plagues most model-free RL algorithms. However, learning a robust world model often demands complex and deep architectures, which are expensive to compute and train. Within the world model, the dynamics model is particularly crucial for accurate predictions, and various dynamics-model architectures have been explored, each with its own set of challenges. World models based on recurrent neural networks (RNNs) suffer from vanishing gradients and struggle to capture long-term dependencies. Transformer-based world models, in contrast, inherit the well-known limitation of self-attention: both memory and computational complexity scale as $O(n^2)$, where $n$ is the sequence length. To address these challenges, we propose a world model based on a state space model (SSM), specifically Mamba, that achieves $O(n)$ memory and computational complexity while effectively capturing long-term dependencies, allowing longer training sequences to be used efficiently. We also introduce a novel sampling method that mitigates the suboptimality caused by an incorrect world model in the early stages of training. By combining this sampling method with the SSM-based world model, we achieve a normalised score comparable to other state-of-the-art model-based RL algorithms using a world model with only 7 million trainable parameters. This model is accessible and can be trained on an off-the-shelf laptop. Our code is available at https://github.com/realwenlongwang/Drama.git
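A back-of-the-envelope count of activation entries per layer makes the $O(n^2)$ versus $O(n)$ contrast concrete. This is a simplified sketch under stated assumptions. It models a naive single-head self-attention that materialises the full $n \times n$ score matrix. It also models an SSM scan that stores its $n \times d_{\text{model}}$ outputs plus one fixed-size hidden state. The function names and the widths $d_{\text{model}} = 256$ and $d_{\text{state}} = 16$ are hypothetical and are not taken from the paper.

```python
def attention_activation_entries(n: int, d_model: int) -> int:
    """Naive single-head self-attention: an n x n score matrix plus n x d_model outputs."""
    return n * n + n * d_model


def ssm_activation_entries(n: int, d_model: int, d_state: int) -> int:
    """Recurrent SSM scan: n x d_model outputs plus one d_model x d_state hidden state."""
    return n * d_model + d_model * d_state


if __name__ == "__main__":
    D_MODEL, D_STATE = 256, 16  # illustrative widths, not the paper's configuration
    for n in (1024, 2048, 4096):
        attn = attention_activation_entries(n, D_MODEL)
        ssm = ssm_activation_entries(n, D_MODEL, D_STATE)
        print(f"n={n}: attention={attn:,} entries, ssm={ssm:,} entries")
```

Doubling $n$ from 1024 to 2048 multiplies the attention count by 3.6, a factor that approaches 4 as the $n^2$ term dominates. Over the same step, the SSM count grows by slightly less than 2, which is the linear scaling that makes longer training sequences affordable.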