Learning Transformer-based World Models with Contrastive Predictive Coding

📅 2025-03-06
📈 Citations: 0
Influential: 0
🤖 AI Summary
Existing Transformer-based world models rely on a conventional next-state prediction objective, which limits their capacity to learn rich, high-level representations. To address this, we propose TWISTER, a framework that introduces action-conditioned Contrastive Predictive Coding (CPC) into Transformer world model training. By predicting latent states several steps ahead, conditioned on the intervening actions, TWISTER learns high-level temporal feature representations that next-state prediction alone fails to capture, moving beyond RNN-based designs such as Dreamer. The method jointly optimizes self-supervised temporal representation learning and world-model-based reinforcement learning. Evaluated on the Atari 100k benchmark, TWISTER achieves a human-normalized mean score of 162%, a new state of the art among methods that do not employ look-ahead search.

📝 Abstract
The DreamerV3 algorithm recently obtained remarkable performance across diverse environment domains by learning an accurate world model based on Recurrent Neural Networks (RNNs). Following the success of model-based reinforcement learning algorithms and the rapid adoption of the Transformer architecture for its superior training efficiency and favorable scaling properties, recent works such as STORM have proposed replacing RNN-based world models with Transformer-based world models using masked self-attention. However, despite the improved training efficiency of these methods, their impact on performance remains limited compared to the Dreamer algorithm, and they struggle to learn competitive Transformer-based world models. In this work, we show that the next-state prediction objective adopted in previous approaches is insufficient to fully exploit the representation capabilities of Transformers. We propose to extend world model predictions to longer time horizons by introducing TWISTER (Transformer-based World model wIth contraSTivE Representations), a world model using action-conditioned Contrastive Predictive Coding to learn high-level temporal feature representations and improve agent performance. TWISTER achieves a human-normalized mean score of 162% on the Atari 100k benchmark, setting a new record among state-of-the-art methods that do not employ look-ahead search.
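
The abstract's core mechanism, predicting latent states several steps ahead with an action-conditioned contrastive objective, can be illustrated in a few lines of PyTorch. This is a minimal sketch under stated assumptions: the class name, the per-offset linear heads, the temperature, and the tensor shapes are all illustrative choices, not the paper's actual architecture.

```python
# Sketch of an action-conditioned Contrastive Predictive Coding (InfoNCE)
# objective in the spirit of TWISTER. Names and hyperparameters here are
# illustrative assumptions, not the paper's implementation.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ActionConditionedCPC(nn.Module):
    def __init__(self, ctx_dim, act_dim, z_dim, horizon=5, temperature=0.1):
        super().__init__()
        self.horizon = horizon
        self.temperature = temperature
        # One prediction head per future offset k, conditioned on the
        # Transformer context and the k intervening actions.
        self.heads = nn.ModuleList([
            nn.Linear(ctx_dim + k * act_dim, z_dim) for k in range(1, horizon + 1)
        ])

    def forward(self, ctx, actions, z):
        # ctx:     (B, T, ctx_dim) Transformer hidden states
        # actions: (B, T, act_dim) embedded actions
        # z:       (B, T, z_dim)   latent states used as contrastive targets
        B, T, _ = ctx.shape
        loss = ctx.new_zeros(())
        for k, head in enumerate(self.heads, start=1):
            if T <= k:
                break
            # Stack the k actions taken between time t and t + k.
            act_win = torch.stack(
                [actions[:, t:t + k].reshape(B, -1) for t in range(T - k)], dim=1
            )
            pred = head(torch.cat([ctx[:, :T - k], act_win], dim=-1))
            p = F.normalize(pred.reshape(-1, pred.shape[-1]), dim=-1)
            q = F.normalize(z[:, k:].reshape(-1, z.shape[-1]), dim=-1)
            # InfoNCE: each prediction must identify its own future latent
            # among all other (batch, time) pairs, which act as negatives.
            logits = p @ q.t() / self.temperature
            labels = torch.arange(logits.shape[0], device=logits.device)
            loss = loss + F.cross_entropy(logits, labels)
        return loss / self.horizon
```

In the paper, the contexts and latents would come from the world model's Transformer states and stochastic encoder outputs; here they are placeholder tensors, e.g. `ActionConditionedCPC(256, 18, 32)(torch.randn(4, 16, 256), torch.randn(4, 16, 18), torch.randn(4, 16, 32))`.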
Problem

Research questions and friction points this paper is trying to address.

Closing the performance gap between Transformer-based world models and the RNN-based Dreamer algorithm.
Extending world model predictions to longer time horizons.
Improving agent performance with Contrastive Predictive Coding.
Innovation

Methods, ideas, or system contributions that make the work stand out.

TWISTER, a Transformer-based world model trained with action-conditioned Contrastive Predictive Coding.
Extends world model predictions to longer time horizons to learn high-level temporal features.
Achieves a human-normalized mean score of 162% on the Atari 100k benchmark, a record among methods without look-ahead search (the metric is sketched below).
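
For context, the 162% figure uses the standard Atari 100k convention of normalizing each game's score against human and random reference scores before averaging. A minimal sketch of that metric (the formula is standard for the benchmark; the numbers below are made up):

```python
# Human-normalized score as conventionally reported on Atari 100k:
# (agent - random) / (human - random), averaged over games.
def human_normalized_mean(agent, random, human):
    return sum(
        (a - r) / (h - r) for a, r, h in zip(agent, random, human)
    ) / len(agent)

# Illustrative numbers only, not results from the paper.
print(human_normalized_mean([1200.0, 90.0], [200.0, 10.0], [800.0, 60.0]))  # ~1.633
```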
Maxime Burchi
University of Würzburg
Computer Vision · Speech Recognition · Reinforcement Learning
R. Timofte
Computer Vision Lab, CAIDAS & IFI, University of Würzburg, Germany