Efficient Joint Prediction of Multiple Future Tokens

📅 2025-03-24
📈 Citations: 0
Influential: 0
🤖 AI Summary
Large language models (LLMs) trained with single-token autoregressive prediction can develop hidden-state representations that carry little information about tokens beyond the immediate next one, hindering long-range reasoning. To address this, the authors propose joint multi-token prediction (JTP), a lightweight extension of the standard next-token prediction framework. JTP adds a compact representation bottleneck layer and a joint loss over multiple future tokens: ground-truth future tokens are teacher-forced through the bottleneck, so the model learns to predict several upcoming tokens simultaneously with negligible training overhead. The authors show that JTP yields a short-horizon belief-state representation of the hidden states, which popular multi-token prediction alternatives fail to achieve. Empirical evaluation on the synthetic star graph navigation task shows that JTP substantially outperforms existing baselines, supporting its effectiveness for structured reasoning tasks.

📝 Abstract
In this short report, we introduce joint multi-token prediction (JTP), a lightweight modification of standard next-token prediction designed to enrich hidden state representations by jointly predicting multiple future tokens. Unlike previous multi-token prediction approaches, JTP strategically employs teacher forcing of future tokens through a carefully designed representation bottleneck, allowing the model to encode rich predictive information with minimal computational overhead during training. We show that the JTP approach achieves a short-horizon belief state representation, while popular alternatives for multi-token prediction fail to do so. We demonstrate the effectiveness of our method on the synthetic star graph navigation task from Bachmann and Nagarajan [2024], highlighting a significant performance improvement over existing methods. This manuscript presents promising preliminary results intended to stimulate further research.
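The core idea of the abstract can be sketched in code: a compact bottleneck compresses the hidden state at position t, and the model is trained to predict tokens t+1 through t+K jointly, with the ground-truth future tokens teacher-forced back in at each step. The sketch below is a minimal NumPy illustration under assumed shapes and a shared prediction head; the paper's actual architecture, dimensions, and parameterization are not specified here, so all names (`W_bottleneck`, `W_embed`, `W_out`) and the way the teacher-forced embedding is combined with the bottleneck are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

V, D, B = 20, 16, 4   # vocab size, hidden dim, bottleneck dim (illustrative)
K = 3                 # number of jointly predicted future tokens

# Hypothetical parameters; the paper's exact layers are not given in this summary.
W_bottleneck = rng.normal(0, 0.1, (D, B))   # compresses the hidden state
W_embed = rng.normal(0, 0.1, (V, B))        # embeds teacher-forced future tokens
W_out = rng.normal(0, 0.1, (B, V))          # shared prediction head

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def jtp_loss(hidden, future_tokens):
    """Joint cross-entropy over K future tokens, teacher-forced via the bottleneck.

    hidden: (D,) hidden state at position t from the base model
    future_tokens: (K,) ground-truth tokens at positions t+1 .. t+K
    """
    z = hidden @ W_bottleneck              # compact bottleneck representation
    loss, prev = 0.0, np.zeros(B)
    for k in range(K):
        logits = (z + prev) @ W_out        # predict token t+1+k
        probs = softmax(logits)
        loss += -np.log(probs[future_tokens[k]])
        prev = W_embed[future_tokens[k]]   # teacher-force the *true* token
    return loss / K

h = rng.normal(size=D)                     # stand-in for a transformer hidden state
targets = rng.integers(0, V, size=K)
print(jtp_loss(h, targets))
```

Because the K prediction steps reuse one small bottleneck vector and a shared head, the extra training cost over plain next-token prediction stays small, which is the "lightweight" property the abstract emphasizes.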
Problem

Research questions and friction points this paper is trying to address.

Enhances hidden states by predicting multiple future tokens jointly
Uses teacher forcing with a bottleneck for efficient training
Improves performance on synthetic navigation tasks over existing methods
Innovation

Methods, ideas, or system contributions that make the work stand out.

Joint multi-token prediction (JTP) method
Teacher forcing with representation bottleneck
Short-horizon belief state representation