🤖 AI Summary
Large language models trained with single-token autoregressive prediction can develop impoverished hidden-state representations, which hinders long-range reasoning. To address this, the authors propose joint multi-token prediction (JTP), a lightweight extension of standard next-token prediction. JTP teacher-forces future tokens through a compact representation bottleneck and trains the model to predict several future tokens jointly, encouraging hidden states that capture a short-horizon belief state at negligible additional training cost. Unlike popular multi-token prediction alternatives, which the authors show fail to acquire such belief-state representations, JTP markedly improves the predictiveness of hidden states. On the synthetic star graph navigation task of Bachmann and Nagarajan [2024], JTP achieves a significant performance improvement over existing methods; the results are preliminary but promising.
📝 Abstract
In this short report, we introduce joint multi-token prediction (JTP), a lightweight modification of standard next-token prediction designed to enrich hidden-state representations by jointly predicting multiple future tokens. Unlike previous multi-token prediction approaches, JTP strategically employs teacher forcing of future tokens through a carefully designed representation bottleneck, allowing the model to encode rich predictive information with minimal computational overhead during training. We show that JTP achieves a short-horizon belief-state representation, while popular alternatives for multi-token prediction fail to do so. We demonstrate the effectiveness of our method on the synthetic star graph navigation task from Bachmann and Nagarajan [2024], highlighting a significant performance improvement over existing methods. This manuscript presents promising preliminary results intended to stimulate further research.
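To make the mechanism concrete, the following is a minimal numpy sketch of a JTP-style training objective, under assumptions of ours: the hidden state is squeezed through a low-dimensional bottleneck, and each of K future tokens is predicted with teacher forcing on the true preceding future token. All parameter names (`W_b`, `E`, `W_out`) and the additive way of combining the bottleneck with the forced token are illustrative, not the paper's exact architecture.

```python
import numpy as np

rng = np.random.default_rng(0)

V, D, B, K = 10, 16, 4, 3  # vocab size, hidden dim, bottleneck dim, horizon

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Hypothetical parameters for the sketch (not from the paper)
W_b = rng.normal(scale=0.1, size=(D, B))    # representation bottleneck projection
E = rng.normal(scale=0.1, size=(V, B))      # embeddings for teacher-forced future tokens
W_out = rng.normal(scale=0.1, size=(B, V))  # shared output head over the vocabulary

def jtp_loss(h, future_tokens):
    """Joint cross-entropy over K future tokens, teacher-forced through the bottleneck.

    All predictive information must flow through the K-dim bottleneck z,
    which is what pressures the hidden state h to encode a belief state.
    """
    z = h @ W_b
    ctx, loss = z, 0.0
    for i in range(K):
        p = softmax(ctx @ W_out)
        loss -= np.log(p[future_tokens[i]])
        # teacher forcing: condition the next prediction on the *true* token,
        # but only via the bottleneck-sized context
        ctx = z + E[future_tokens[i]]
    return loss / K

h = rng.normal(size=D)                 # a hidden state at position t
tokens = rng.integers(0, V, size=K)    # ground-truth tokens t+1 .. t+K
print(float(jtp_loss(h, tokens)))
```

With near-zero random weights the per-token loss sits near log V (the uniform-prediction value); in training, gradients through the shared bottleneck `z` are what enrich the hidden state, which is the point of the method.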