🤖 AI Summary
This work addresses the challenge of approximating causal attention in long-context language modeling, where resource bottlenecks arise in long-context prefill, KV cache compression, and long-form memory- or compute-constrained decoding. The authors introduce Express, a tool that converts a non-causal attention approximation into a causal approximation with matching approximation guarantees. Combining Express with the Thinformer approximation improves upon the best known causal attention guarantees. For a sequence of length $n$, it achieves $\log^{3/2}(n)/s$ approximation error with only $O(s)$ memory and $O(s^2 \log^2(n))$ compression overhead. The authors also provide an efficient I/O-aware Triton implementation that delivers substantial speedups over FlashAttention-2, and they use Express to address four resource bottlenecks in the language modeling pipeline.
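For reference, the quantity being approximated is exact causal (masked) softmax attention, in which each query $i$ attends only to keys $j \le i$. The sketch below is a minimal pure-Python version, assuming a single head, the usual $1/\sqrt{d}$ score scaling, and no batching. `causal_attention` is a hypothetical helper written for illustration, not part of the authors' code; it computes the exact quadratic-cost output that causal approximations aim to reproduce more cheaply.

```python
import math
from typing import List

Matrix = List[List[float]]

def causal_attention(Q: Matrix, K: Matrix, V: Matrix) -> Matrix:
    """Exact causal softmax attention: query i attends to keys 0..i only."""
    n, d = len(Q), len(Q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for i in range(n):
        # Scaled dot-product scores against the causal prefix (keys 0..i).
        scores = [scale * sum(q * k for q, k in zip(Q[i], K[j])) for j in range(i + 1)]
        # Subtract the max score before exponentiating for numerical stability.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Softmax-weighted average of the prefix values, one output column at a time.
        row = [sum(w * V[j][c] for j, w in enumerate(weights)) / total
               for c in range(len(V[0]))]
        out.append(row)
    return out
```

A quick check of the causal mask: the first output row always equals `V[0]`, because query 0 can see only key 0. The doubly nested loop over queries and prefix keys is the $O(n^2)$ cost that motivates approximation at long context lengths.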
📝 Abstract
We introduce a new tool, Express, for converting a non-causal attention approximation into a causal approximation with matching approximation guarantees. When combined with the state-of-the-art Thinformer approximation, Express improves upon the best known causal attention guarantees, delivering $\log^{3/2}(n)/s$ approximation error with only $O(s)$ memory and $O(s^2 \log^2(n))$ compression overhead for a sequence of length $n$. We pair these developments with an efficient I/O-aware Triton implementation, demonstrate substantial speedups over FlashAttention 2, and use Express to overcome four resource bottlenecks in the language modeling pipeline: long-context prefill, KV cache compression, long-form memory-constrained decoding, and long-form compute-constrained decoding.
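To get a feel for the stated guarantee, one can ignore the unspecified constants and ask how large $s$ must be for $\log^{3/2}(n)/s$ to fall below a target error $\varepsilon$, which requires $s \ge \log^{3/2}(n)/\varepsilon$. The sketch below assumes natural logarithms and unit constants in all three bounds. Both choices are assumptions made purely for illustration, and `summary_size` and `costs` are hypothetical helpers, not part of the authors' code.

```python
import math

def summary_size(n: int, eps: float) -> int:
    """Smallest integer s with log(n)**1.5 / s <= eps (unit constants assumed)."""
    return math.ceil(math.log(n) ** 1.5 / eps)

def costs(n: int, s: int) -> dict:
    """Order-of-magnitude quantities from the abstract, with all constants set to 1."""
    return {
        "error": math.log(n) ** 1.5 / s,             # log^{3/2}(n) / s
        "memory": s,                                  # O(s)
        "compression": s**2 * math.log(n) ** 2,       # O(s^2 log^2(n))
    }

for n in (2**12, 2**16, 2**20):
    s = summary_size(n, eps=0.1)
    c = costs(n, s)
    print(f"n={n:>8}  s={s:>4}  error={c['error']:.3f}  compression~{c['compression']:.2e}")
```

Under these assumptions, the required $s$ grows only polylogarithmically in $n$, so the $O(s)$ memory footprint stays far below the $n$ key-value pairs that an exact KV cache would store.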