🤖 AI Summary
Large reasoning models (LRMs) face dual challenges in long chain-of-thought generation: excessive KV cache memory pressure, and degraded reasoning coherence when sparse attention is used to relieve it. Moreover, existing methods cannot identify critical tokens dynamically during inference. To address this, we propose Multipole Attention: it clusters tokens in real time based on their key vectors, computing exact attention for the most important tokens while approximating the remaining keys via their cluster centroids. This enables incremental cluster updates and an efficient custom CUDA kernel implementation. The method requires no pre-processing and adapts online to newly generated tokens during inference. Evaluated on Qwen-8B, our approach preserves accuracy on complex reasoning tasks while accelerating attention computation by up to 4.5× and substantially reducing KV cache memory pressure. To the best of our knowledge, this is the first method achieving high-fidelity, low-overhead, fully online attention optimization for long-chain reasoning.
📝 Abstract
Large Reasoning Models (LRMs) have shown promising accuracy improvements on complex problem-solving tasks. While these models have attained high accuracy by leveraging additional computation at test time, they need to generate long chain-of-thought reasoning in order to think before answering, which requires generating thousands of tokens. While sparse attention methods can help reduce the KV cache pressure induced by this long autoregressive reasoning, these methods can introduce errors which disrupt the reasoning process. Additionally, prior methods often pre-process the input to make it easier to identify the important prompt tokens when computing attention during generation, and this pre-processing is challenging to perform online for newly generated reasoning tokens. Our work addresses these challenges by introducing Multipole Attention, which accelerates autoregressive reasoning by only computing exact attention for the most important tokens, while maintaining approximate representations for the remaining tokens. Our method first performs clustering to group together semantically similar key vectors, and then uses the cluster centroids both to identify important key vectors and to approximate the remaining key vectors in order to retain high accuracy. We design a fast cluster update process to quickly re-cluster the input and previously generated tokens, thereby allowing for accelerating attention to the previous output tokens. We evaluate our method using emerging LRMs such as Qwen-8B, demonstrating that our approach can maintain accuracy on complex reasoning tasks even with aggressive attention sparsity settings. We also provide kernel implementations to demonstrate the practical efficiency gains from our method, achieving up to 4.5× speedup for attention in long-context reasoning applications. Our code is available at https://github.com/SqueezeAILab/MultipoleAttention.
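The centroid-based approximation described in the abstract can be sketched in NumPy. This is an illustrative toy, not the paper's CUDA implementation: the function name, the plain k-means clustering step, and the cluster/`top_clusters` parameters are assumptions made for the example. Keys whose clusters score highest against the query get exact attention logits; all other keys are represented by their cluster centroid's logit.

```python
import numpy as np

def multipole_attention(q, K, V, n_clusters=4, top_clusters=2, iters=10):
    """Toy sketch of centroid-based sparse attention (not the paper's kernel).

    Exact logits are computed for keys in the clusters whose centroids score
    highest against the query; every remaining key is approximated by its
    cluster centroid's logit.
    """
    n, d = K.shape
    # Plain k-means over key vectors (stand-in for the paper's clustering).
    rng = np.random.default_rng(0)
    centroids = K[rng.choice(n, n_clusters, replace=False)].copy()
    assign = np.zeros(n, dtype=int)
    for _ in range(iters):
        assign = np.argmin(((K[:, None] - centroids[None]) ** 2).sum(-1), axis=1)
        for c in range(n_clusters):
            if np.any(assign == c):
                centroids[c] = K[assign == c].mean(0)
    # Score each cluster by its centroid's scaled dot product with the query.
    c_scores = centroids @ q / np.sqrt(d)
    important = np.argsort(c_scores)[-top_clusters:]
    # Exact logits inside important clusters, centroid logits elsewhere.
    logits = np.where(np.isin(assign, important),
                      K @ q / np.sqrt(d),
                      c_scores[assign])
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ V
```

When `top_clusters == n_clusters`, every key falls in an "important" cluster and the sketch reduces to exact softmax attention; shrinking `top_clusters` trades accuracy for fewer exact query-key dot products, which is the knob the sparsity settings in the paper correspond to.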