AI Summary
To address the high inference latency and computational overhead caused by autoregressive encoding of long retrieval contexts in Retrieval-Augmented Generation (RAG), this paper introduces Block-Attention, a novel attention mechanism that partitions retrieved documents into independent blocks and recomputes only the key-value (KV) states for the final block while reusing cached KV states from prior blocks. This work presents the first cross-block KV state reuse scheme in RAG, integrating block-wise document segmentation, position re-encoding, and LLM fine-tuning with explicit block awareness to jointly optimize efficiency and generation quality. Evaluated on four standard RAG benchmarks, the method matches or exceeds the performance of full self-attention (e.g., 68.4% vs. 67.9% on Llama3; 62.8% vs. 59.6% on Mistral), reduces first-token latency to just 45 ms for 32K-length sequences, and cuts FLOPs by 99.8%.
Abstract
We introduce Block-Attention, an attention mechanism designed to address the increased inference latency and cost in Retrieval-Augmented Generation (RAG) scenarios. Traditional approaches often encode the entire context autoregressively. Instead, Block-Attention divides retrieved documents into discrete blocks, and every block except the final one computes its key-value (KV) states independently. In RAG scenarios, by defining each passage as a block, Block-Attention lets us reuse the KV states of passages that have been seen before, significantly reducing latency and computation overhead during inference. Implementing Block-Attention involves block segmentation, position re-encoding, and fine-tuning the LLM to adapt to the Block-Attention mechanism. Experiments on four RAG benchmarks demonstrate that, after block fine-tuning, the Block-Attention model achieves performance comparable to self-attention models (68.4% vs. 67.9% on Llama3) or even superior performance (62.8% vs. 59.6% on Mistral). Notably, Block-Attention sharply reduces the time to first token (TTFT) and floating point operations (FLOPs): it takes only 45 ms to output the first token for an input sequence with a total length of 32K. Compared to the self-attention models, TTFT and the corresponding FLOPs are reduced by 98.7% and 99.8%, respectively.
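The block structure described above can be pictured as an attention mask. The following is a minimal sketch, not the authors' implementation. The function name `block_attention_mask` and the example block lengths are hypothetical. It also assumes that each non-final block uses ordinary causal attention restricted to its own tokens, and that the final block attends to every preceding token; the abstract states only that the final block is the exception to independent encoding. Position re-encoding and fine-tuning are out of scope here.

```python
def block_attention_mask(block_lengths):
    """Build a boolean mask where mask[i][j] is True if token i may attend to token j.

    Assumption (illustrative only): non-final blocks are causal within their own
    block; the final block is causal over the whole sequence.
    """
    total = sum(block_lengths)
    mask = [[False] * total for _ in range(total)]
    last_block = len(block_lengths) - 1
    start = 0
    for idx, length in enumerate(block_lengths):
        end = start + length
        # The final block may look back to position 0; other blocks only
        # see tokens from the start of their own block.
        first_visible = 0 if idx == last_block else start
        for i in range(start, end):
            for j in range(first_visible, i + 1):
                mask[i][j] = True
        start = end
    return mask


# Two retrieved passages of 2 tokens each, followed by a 3-token final block.
mask = block_attention_mask([2, 2, 3])
for row in mask:
    print("".join("x" if visible else "." for visible in row))
# x......
# xx.....
# ..x....
# ..xx...
# xxxxx..
# xxxxxx.
# xxxxxxx
```

Under these assumptions, the rows for a non-final block depend only on that block's own tokens, which is what would allow its KV states to be computed once per passage and reused whenever the passage is retrieved again.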