Delayed Attention Training Improves Length Generalization in Transformer–RNN Hybrids

📅 2025-09-30
📈 Citations: 0
Influential: 0
🤖 AI Summary
This work addresses the limited length generalization of sequence models on composite tasks involving state tracking and associative memory. While RNNs excel at state tracking, their memory capacity is constrained; Transformers possess superior memory capabilities but struggle to generalize to sequences longer than those seen during training. To bridge this gap, we propose a hybrid RNN–Transformer architecture and introduce delayed attention training: attention layers are frozen during early training so that optimization first builds up RNN-based state tracking, then unfrozen and fine-tuned in later stages, which discourages the Transformer component from relying on short-range spurious shortcuts. This staged schedule allows the two mechanisms to be integrated synergistically. Our approach achieves over 90% accuracy on sequences three times longer than the training length, demonstrating substantial improvements in modeling long-range dependencies and length generalization.
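The delayed-training schedule described above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the paper's implementation: the `HybridBlock` architecture, the `attn` submodule name, and the two-phase toggle are assumptions chosen to show the freeze-then-unfreeze mechanic.

```python
# Sketch of delayed attention training for a hybrid RNN-attention model.
# Module names ("rnn", "attn") and the block design are illustrative
# assumptions, not taken from the paper.
import torch
import torch.nn as nn

class HybridBlock(nn.Module):
    """Toy hybrid block: a GRU for state tracking plus self-attention for recall."""
    def __init__(self, dim: int):
        super().__init__()
        self.rnn = nn.GRU(dim, dim, batch_first=True)
        self.attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)

    def forward(self, x):
        h, _ = self.rnn(x)          # recurrent state-tracking pathway
        a, _ = self.attn(h, h, h)   # attention-based recall pathway
        return h + a

def set_attention_trainable(model: nn.Module, trainable: bool) -> None:
    """Freeze or unfreeze every parameter living under a submodule named 'attn'."""
    for name, param in model.named_parameters():
        if name.startswith("attn.") or ".attn." in name:
            param.requires_grad = trainable

model = HybridBlock(dim=16)

# Phase 1: freeze attention so early gradient steps optimize the RNN pathway.
set_attention_trainable(model, False)
frozen = [n for n, p in model.named_parameters() if not p.requires_grad]

# ... train for the delay period, then:
# Phase 2: unfreeze attention and fine-tune the full model.
set_attention_trainable(model, True)
```

In practice the optimizer would be built (or its parameter groups refreshed) after each toggle so that frozen parameters accumulate no state; the delay length itself is a hyperparameter the paper's experiments would determine.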

📝 Abstract
We study length generalization in sequence models on a composite problem involving both state tracking and associative recall. Prior work finds that recurrent networks handle state tracking well but struggle with recall, whereas Transformers excel at recall yet fail to extend state-tracking capabilities to longer sequences. Motivated by the complementary strengths of these architectures, we construct hybrid models integrating recurrent and attention-based components, and train them on the combined task to evaluate whether both capabilities can be preserved. Our results reveal that, in such hybrids, the Transformer component tends to exploit shortcut solutions, leading to poor length generalization. We identify this shortcut reliance as a key obstacle and propose a simple yet effective training strategy -- delaying the training of the attention layers -- that mitigates this effect and significantly improves length generalization performance. Our experiments show that this approach enables hybrid models to achieve near-perfect accuracy ($>90\%$) on hybrid sequences three times longer than those used during training.
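To make the abstract's composite problem concrete, here is a toy generator that interleaves state tracking (a running mod-5 counter) with associative recall (key-value stores and queries). This is purely an illustrative assumption; the paper's actual benchmark construction may differ.

```python
# Toy composite task mixing state tracking with associative recall.
# The token vocabulary and task mix are illustrative assumptions,
# not the paper's exact benchmark.
import random

def make_example(length: int, seed: int = 0):
    rng = random.Random(seed)
    state = 0        # tracked state: a counter mod 5
    memory = {}      # associative memory: key -> value
    tokens, targets = [], []
    for _ in range(length):
        if rng.random() < 0.5:
            # State-tracking step: apply an increment, answer the new state.
            inc = rng.randint(1, 4)
            state = (state + inc) % 5
            tokens.append(f"ADD{inc}")
            targets.append(f"S{state}")
        else:
            key = rng.choice("abc")
            if key in memory and rng.random() < 0.5:
                # Recall step: answer the value last stored under the key.
                tokens.append(f"QRY{key}")
                targets.append(memory[key])
            else:
                # Store step: write a fresh key-value pair.
                val = str(rng.randint(0, 9))
                memory[key] = val
                tokens.append(f"PUT{key}{val}")
                targets.append(val)
    return tokens, targets

tokens, targets = make_example(8)
```

Solving such a sequence correctly requires both capabilities at once: the counter must be tracked through every `ADD`, while arbitrary key-value pairs must be recalled across long gaps, which is exactly the combination the hybrid architecture targets.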
Problem

Research questions and friction points this paper is trying to address.

Hybrid RNN–Transformer models generalize poorly to longer sequences because the Transformer component exploits shortcut solutions
Recurrent networks track state well but have limited associative-recall capacity
Transformers excel at recall but fail to extend state tracking beyond the training length
Innovation

Methods, ideas, or system contributions that make the work stand out.

Delayed attention training in hybrid models
Combines recurrent and attention-based components
Improves length generalization to sequences three times the training length (>90% accuracy)