🤖 AI Summary
Transformer attention’s quadratic time and memory complexity severely hinders efficient long-context modeling. To address this, we propose an **adaptive token retention mechanism** that learns end-to-end which tokens to retain under a strict global memory budget *M*. Our core contributions are: (i) **learnable Bernoulli gating** with a **Hard-Concrete relaxation**, which enables differentiable training and deterministic top-*M* token selection at inference; and (ii) a **probabilistic layer-wise selection strategy** that uses variational relaxation to propagate gradients, allowing plug-and-play integration without modifying the original attention architecture. Experiments show that retaining only 30–50% of tokens preserves at least 95% of full-model performance, reduces peak memory by 35–45%, and improves throughput by up to 1.8×. The method improves the efficiency–accuracy trade-off across text classification, extractive question answering, and long-document summarization.
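The Hard-Concrete gate mentioned above can be sketched in plain Python. This is a minimal illustration, not the authors' implementation: it uses the standard Hard-Concrete parameterization of Louizos et al. (2018) with their default stretch constants (β = 2/3, γ = −0.1, ζ = 1.1), which are assumptions here, and `budget_penalty` is a hypothetical way of tying the expected number of kept tokens to the budget *M*.

```python
import math
import random

# Default Hard-Concrete constants from Louizos et al. (2018).
# ASSUMPTION: the paper's actual values are not given in this summary.
BETA, GAMMA, ZETA = 2.0 / 3.0, -0.1, 1.1


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def hard_concrete_sample(log_alpha: float, rng: random.Random) -> float:
    """Draw one relaxed gate z in [0, 1] for a token whose gate logit is log_alpha."""
    u = rng.uniform(1e-6, 1.0 - 1e-6)  # keep u away from 0 and 1 to avoid log(0)
    s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / BETA)
    s_bar = s * (ZETA - GAMMA) + GAMMA  # stretch (0, 1) to (GAMMA, ZETA)
    return min(1.0, max(0.0, s_bar))    # hard clip: exact 0s and 1s become possible


def prob_nonzero(log_alpha: float) -> float:
    """P(z > 0) in closed form; differentiable in log_alpha."""
    return sigmoid(log_alpha - BETA * math.log(-GAMMA / ZETA))


def budget_penalty(log_alphas: list[float], budget_m: float) -> float:
    """HYPOTHETICAL budget term: squared overshoot of expected kept tokens over M."""
    expected_kept = sum(prob_nonzero(a) for a in log_alphas)
    return max(0.0, expected_kept - budget_m) ** 2


if __name__ == "__main__":
    rng = random.Random(0)
    log_alphas = [3.0, -3.0, 0.5, 2.0]  # illustrative per-token gate logits
    print([round(hard_concrete_sample(a, rng), 2) for a in log_alphas])
    print(round(budget_penalty(log_alphas, budget_m=2), 3))
```

Because the gate is clipped after stretching, exact zeros and ones occur with nonzero probability during training, while `prob_nonzero` stays smooth in `log_alpha`; that combination is what allows a budget term on the expected number of kept tokens to be optimized by gradient descent.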
📝 Abstract
Transformer attention scales quadratically with sequence length, O(n²), which limits long-context use. We propose Adaptive Retention, a probabilistic, layer-wise token selection mechanism that learns which representations to keep under a strict global budget M. Retention is modeled with Bernoulli gates trained via a Hard-Concrete/variational relaxation and enforced with a simple top-M rule at inference, making the method differentiable during training and a drop-in addition to standard encoders. Across classification, extractive QA, and long-document summarization, keeping only 30–50% of tokens preserves ≥95% of full-model performance while cutting peak memory by ~35–45% and improving throughput by up to 1.8×. This architecture-agnostic approach delivers practical long-context efficiency without modifying the base attention or task heads.
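At inference, the top-M rule replaces stochastic gating with a deterministic cut. The sketch below is hypothetical: it assumes each token already carries a learned retention score (for instance, its gate logit) and treats M as the number of tokens kept at a single selection point; how the global budget is divided across layers is not specified here. The helper names `top_m_keep` and `apply_retention` are illustrative only.

```python
import heapq


def top_m_keep(scores: list[float], m: int) -> list[int]:
    """Indices of the m highest-scoring tokens, returned in original sequence order."""
    if m >= len(scores):
        return list(range(len(scores)))  # budget not binding: keep everything
    top = heapq.nlargest(m, range(len(scores)), key=lambda i: scores[i])
    return sorted(top)  # restore positional order for subsequent layers


def apply_retention(tokens: list[str], scores: list[float], m: int) -> list[str]:
    """Drop every token outside the top-m set before the next layer runs."""
    return [tokens[i] for i in top_m_keep(scores, m)]


if __name__ == "__main__":
    tokens = ["[CLS]", "the", "cat", "sat", "on", "the", "mat"]
    scores = [9.0, -1.0, 2.5, 1.8, -0.5, -1.2, 2.0]  # illustrative retention scores
    print(apply_retention(tokens, scores, m=4))  # ['[CLS]', 'cat', 'sat', 'mat']
```

Sorting the surviving indices is a design choice in this sketch: selection is by score, but the retained sequence keeps its original token order, so later layers still see tokens in positional order.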