SparDA: Sparse Decoupled Attention for Efficient Long-Context LLM Inference

📅 2026-06-03
🤖 AI Summary
This work addresses two bottlenecks in long-context large language model inference. The first is the KV cache, whose size grows with sequence length. The second is the O(T²) complexity of the sparse-attention selection step. The authors propose SparDA, an architecture that moves sparse-attention selection into a lightweight, trainable Forecast projection. This projection adds less than 0.5% parameter overhead and predicts which KV blocks the next layer will need, so that CPU-to-GPU prefetching can overlap with the current layer's computation. Using one shared Forecast head per GQA group further reduces selection overhead. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy. It achieves up to 1.25× faster prefill and 1.7× faster decode than the sparse-attention offload baseline. Because it allows larger batch sizes on a single GPU, it also reaches up to 5.3× higher decode throughput than the non-offload sparse baseline.
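Why lookahead selection helps can be seen with a simple timing model. The sketch below is a hypothetical cost model, not SparDA's implementation. It makes three assumptions: each layer's block transfer takes a fixed time, a layer's Forecast output is available as soon as that layer starts, and transfer and compute overlap perfectly. The function names `serial_time` and `lookahead_time` and all timings are illustrative.

```python
from typing import Sequence


def serial_time(compute_ms: Sequence[float], fetch_ms: Sequence[float]) -> float:
    """Without lookahead, layer l learns which blocks it needs only at layer l,
    so every fetch sits on the critical path before that layer's compute."""
    return sum(c + f for c, f in zip(compute_ms, fetch_ms))


def lookahead_time(compute_ms: Sequence[float], fetch_ms: Sequence[float]) -> float:
    """With a Forecast at layer l naming layer l+1's blocks, the fetch for
    layer l+1 runs while layer l computes (assumed perfect overlap).
    In this model only layer 0's fetch is left exposed."""
    n = len(compute_ms)
    total = fetch_ms[0]
    for layer in range(n):
        next_fetch = fetch_ms[layer + 1] if layer + 1 < n else 0.0
        # The next layer starts once both this layer's compute and its prefetch finish.
        total += max(compute_ms[layer], next_fetch)
    return total


compute = [1.0] * 4  # hypothetical per-layer compute time (ms)
fetch = [0.8] * 4    # hypothetical per-layer PCIe transfer time (ms)
print(f"serial={serial_time(compute, fetch):.2f} ms, "
      f"lookahead={lookahead_time(compute, fetch):.2f} ms")
# serial=7.20 ms, lookahead=4.80 ms
```

In this toy model the transfer is fully hidden whenever a layer's compute time is at least the next layer's fetch time. When transfers are slower than compute, the per-layer cost becomes the transfer time, so overlap helps but cannot remove the PCIe bound.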
📝 Abstract
Sparse attention reduces compute and memory bandwidth for long-context LLM inference. However, two key challenges remain: (1) KV cache capacity still grows with sequence length, and offloading to CPU memory introduces a PCIe transfer bottleneck; (2) the sparse selection step itself retains O(T²) complexity and can dominate attention cost at long contexts. We propose SparDA, a decoupled sparse attention architecture that introduces a fourth per-layer projection, the Forecast, alongside Query, Key, and Value. The Forecast predicts the KV blocks needed by the next layer, enabling lookahead selection that overlaps CPU-to-GPU prefetch with current-layer execution. Because Forecast is decoupled from the attention query, our GQA implementation uses one Forecast head per GQA group, reducing selection overhead versus the original multi-head selector. SparDA adds <0.5% parameters and trains only the Forecast projections by matching the original selector's attention distribution. On two sparse-pretrained 8B models, SparDA matches or slightly improves accuracy and delivers up to 1.25× prefill speedup and 1.7× decode speedup over the sparse-attention offload baseline. By enabling larger feasible batch sizes on a single GPU, SparDA further reaches up to 5.3× higher decode throughput than the non-offload sparse baseline. Our source code is available at https://github.com/NVlabs/SparDA.
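The Forecast mechanism and its training signal can be sketched in plain Python. This is a minimal illustration under stated assumptions and does not reproduce code from the SparDA repository:

- Each KV block is summarized by mean-pooling its keys.
- Each GQA group has one Forecast vector, produced by a hypothetical weight matrix `w_forecast`.
- "Matching the original selector's attention distribution" is modeled as a KL divergence from the group's averaged per-head selector distribution to the Forecast's softmax.

All function names here are hypothetical.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[Vector]


def project(weight: Matrix, hidden: Vector) -> Vector:
    """Hypothetical Forecast projection f = W_F h, a fourth projection next to Q, K, V."""
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]


def block_summaries(keys: Matrix, block_size: int) -> Matrix:
    """Mean-pool the next layer's cached keys within each KV block (assumed summary)."""
    summaries = []
    for start in range(0, len(keys), block_size):
        block = keys[start:start + block_size]
        summaries.append([sum(col) / len(block) for col in zip(*block)])
    return summaries


def forecast_scores(forecast: Vector, summaries: Matrix) -> Vector:
    """Scaled dot product of one per-group Forecast vector with every block summary."""
    scale = 1.0 / math.sqrt(len(forecast))
    return [scale * sum(f * s for f, s in zip(forecast, summ)) for summ in summaries]


def select_topk(scores: Vector, k: int) -> List[int]:
    """Indices of the k highest-scoring blocks (ascending), i.e. the blocks to prefetch."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return sorted(ranked[:k])


def softmax(logits: Vector) -> Vector:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def distill_loss(teacher_probs_per_head: Matrix, student_logits: Vector) -> float:
    """KL(target || softmax(student_logits)), where the target averages the
    original selector's per-head block distributions within one GQA group."""
    n_heads = len(teacher_probs_per_head)
    target = [sum(col) / n_heads for col in zip(*teacher_probs_per_head)]
    student = softmax(student_logits)
    # Clamp the student probability to avoid log(x / 0) after underflow.
    return sum(t * math.log(t / max(s, 1e-12)) for t, s in zip(target, student) if t > 0)


# Usage: one GQA group, 8 cached tokens of dimension 2, block size 2 -> 4 blocks.
keys = [[0, 1], [0, 1], [2, 0], [2, 0], [0, 0], [0, 0], [1, 0], [2, 0]]
w_forecast = [[1.0, 0.0], [0.0, 0.0]]  # hypothetical trained Forecast weights
hidden = [1.0, 0.5]                    # hypothetical hidden state at layer l
f = project(w_forecast, hidden)
scores = forecast_scores(f, block_summaries(keys, block_size=2))
print(select_topk(scores, k=2))  # [1, 3]: blocks to prefetch for layer l+1
```

In this sketch, scoring is done once per GQA group instead of once per query head. That is how a single shared Forecast head cuts selection work relative to a multi-head selector. Only `w_forecast` would receive gradients from `distill_loss`, which mirrors the abstract's statement that only the Forecast projections are trained.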
Problem

Research questions and friction points this paper is trying to address.

sparse attention
KV cache
long-context LLM inference
memory bandwidth
PCIe bottleneck
Innovation

Methods, ideas, or system contributions that make the work stand out.

Sparse Attention
Decoupled Forecast
KV Cache Prefetching
Long-Context LLM Inference
Grouped Query Attention