TNT: Improving Chunkwise Training for Test-Time Memorization

📅 2025-11-10
📈 Citations: 0
Influential: 0
🤖 AI Summary
Existing deep test-time memory RNNs (e.g., Titans, TTT) offer linear scalability but suffer from prohibitively slow training and low hardware utilization, limiting practical deployment. Their parallelization is fundamentally constrained by a trade-off in chunk-size selection: large chunks accelerate training but degrade model performance, whereas small chunks preserve accuracy at the cost of severe computational inefficiency. This work introduces the TNT training paradigm, built on a two-stage decoupled design. Stage I is an efficiency-focused pre-training phase that enables context-parallel long-range modeling via a hierarchical memory architecture with periodic local state resets. Stage II is a brief fine-tuning phase in which only the local memory modules are adapted to a smaller, high-resolution chunk size. Evaluated on Titans and TTT, TNT achieves up to a 17× training speedup while simultaneously improving accuracy, marking a significant step forward in the scalability of RNN-based models.

📝 Abstract
Recurrent neural networks (RNNs) with deep test-time memorization modules, such as Titans and TTT, represent a promising, linearly-scaling paradigm distinct from Transformers. While these expressive models do not yet match the peak performance of state-of-the-art Transformers, their potential has been largely untapped due to prohibitively slow training and low hardware utilization. Existing parallelization methods force a fundamental conflict governed by the chunksize hyperparameter: large chunks boost speed but degrade performance, necessitating a fixed, suboptimal compromise. To solve this challenge, we introduce TNT, a novel training paradigm that decouples training efficiency from inference performance through a two-stage process. Stage one is an efficiency-focused pre-training phase utilizing a hierarchical memory. A global module processes large, hardware-friendly chunks for long-range context, while multiple parallel local modules handle fine-grained details. Crucially, by periodically resetting local memory states, we break sequential dependencies to enable massive context parallelization. Stage two is a brief fine-tuning phase where only the local memory modules are adapted to a smaller, high-resolution chunksize, maximizing accuracy with minimal overhead. Evaluated on Titans and TTT models, TNT achieves a substantial acceleration in training speed (up to 17 times faster than the most accurate baseline configuration) while simultaneously improving model accuracy. This improvement removes a critical scalability barrier, establishing a practical foundation for developing expressive RNNs and facilitating future work to close the performance gap with Transformers.
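The chunksize trade-off the abstract describes can be made concrete with a toy sketch. The code below is not the actual Titans/TTT update rule; it assumes a simple linear memory `W` trained with one self-supervised reconstruction update per chunk. All tokens in a chunk read the same stale state (so larger chunks give coarser reads), but larger chunks also mean fewer sequential update steps (better hardware utilization):

```python
import numpy as np

def chunkwise_memory(tokens, chunk_size, lr=0.1):
    """Toy chunkwise test-time memory (illustrative only).

    tokens: array of shape (seq_len, d).
    Returns per-token reads and the final memory state W.
    """
    d = tokens.shape[1]
    W = np.zeros((d, d))
    outputs = []
    for start in range(0, len(tokens), chunk_size):
        chunk = tokens[start:start + chunk_size]
        # Read: every token in the chunk sees the same (stale) memory,
        # so a larger chunk_size means coarser, less up-to-date reads.
        outputs.append(chunk @ W.T)
        # Write: one gradient-style update per chunk; fewer sequential
        # steps with large chunks -> higher parallel efficiency.
        pred = chunk @ W.T
        W += lr * (chunk - pred).T @ chunk / len(chunk)
    return np.vstack(outputs), W
```

With `chunk_size = seq_len` there is a single update (fast, coarse); with `chunk_size = 1` the memory updates every token (accurate, strictly sequential), which is the conflict TNT's two-stage design sidesteps.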
Problem

Research questions and friction points this paper is trying to address.

Addressing slow training and low hardware utilization in test-time memory RNNs
Resolving the trade-off between large chunks (speed) and small chunks (accuracy)
Enabling parallelization while maintaining model performance in chunkwise training
Innovation

Methods, ideas, or system contributions that make the work stand out.

Decouples training efficiency from inference performance
Uses hierarchical memory with global and local modules
Enables massive parallelization through periodic memory resets
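The third point (parallelization via periodic resets) can be sketched in a few lines. This is a hypothetical toy write rule, not the paper's: because each local memory starts from a reset (zero) state at a segment boundary, segments share no sequential dependency, so processing them one after another and processing them independently (e.g., on different devices) give identical results:

```python
import numpy as np

def run_segment(segment, lr=0.1):
    """Local memory over one segment, starting from a reset (zero) state."""
    d = segment.shape[1]
    W = np.zeros((d, d))
    outs = []
    for x in segment:
        outs.append(W @ x)                  # read before write
        W += lr * np.outer(x - W @ x, x)    # toy associative write rule
    return np.vstack(outs)

def sequential(tokens, seg_len):
    """Segments processed in order, as a naive recurrent loop would."""
    return np.vstack([run_segment(tokens[s:s + seg_len])
                      for s in range(0, len(tokens), seg_len)])

def parallel(tokens, seg_len):
    """Segments are independent thanks to the reset, so this list could be
    mapped across devices; here we just iterate to show equivalence."""
    segments = [tokens[s:s + seg_len] for s in range(0, len(tokens), seg_len)]
    return np.vstack([run_segment(seg) for seg in segments])
```

In TNT the global module (large chunks) carries the long-range context that the resets discard locally; this sketch only shows why the resets make the local computation embarrassingly parallel.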