🤖 AI Summary
This work addresses the high memory overhead imposed by KV caching in large language models during long-context reasoning. To mitigate this, the authors propose GradMem, a compressed memory mechanism that optimizes a small set of prefix memory tokens via per-sample gradient descent at test time, while keeping model weights frozen, using a self-supervised context reconstruction loss to iteratively refine the memory contents. This approach integrates test-time gradient-based optimization into memory encoding, enabling loss-guided error correction and substantially enhancing both memory capacity and generalization. Experiments demonstrate that GradMem outperforms forward-pass memory methods under equivalent memory budgets on associative key-value retrieval tasks and achieves competitive performance on natural language benchmarks such as bAbI and SQuAD variants, answering multiple queries using only the compressed memory representation.
📝 Abstract
Many large language model applications require conditioning on long contexts. Transformers typically support this by storing a large per-layer KV-cache of past activations, which incurs substantial memory overhead. A desirable alternative is compressive memory: read a context once, store it in a compact state, and answer many queries from that state. We study this in a context removal setting, where the model must generate an answer without access to the original context at inference time. We introduce GradMem, which writes context into memory via per-sample test-time optimization. Given a context, GradMem performs a few steps of gradient descent on a small set of prefix memory tokens while keeping model weights frozen. GradMem explicitly optimizes a model-level self-supervised context reconstruction loss, resulting in a loss-driven write operation with iterative error correction, unlike forward-only methods. On associative key-value retrieval, GradMem outperforms forward-only memory writers with the same memory size, and additional gradient steps scale capacity much more effectively than repeated forward writes. We further show that GradMem transfers beyond synthetic benchmarks: with pretrained language models, it attains competitive results on natural language tasks including bAbI and SQuAD variants, relying only on information encoded in memory.
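The abstract describes the write operation only at a high level. The following minimal PyTorch sketch shows one plausible instantiation of the idea: a handful of trainable prefix memory embeddings are optimized per sample against a context reconstruction cross-entropy loss while the language model stays frozen, and queries are then answered from the memory alone. The model choice, the 16-token memory size, the Adam optimizer, the step count, and the helper names `write_memory` / `answer` are illustrative assumptions, not the authors' implementation.

```python
# Hedged sketch of a GradMem-style write/read loop (not the authors' code).
# Assumptions: a HuggingFace causal LM, trainable prefix "memory token"
# embeddings, and a context-reconstruction loss as the write objective.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"                       # placeholder model for illustration
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()
for p in model.parameters():              # model weights stay frozen
    p.requires_grad_(False)

embed = model.get_input_embeddings()
d_model = embed.embedding_dim
n_mem, n_steps, lr = 16, 50, 1e-2         # memory size / write budget (assumed)

def write_memory(context: str) -> torch.Tensor:
    """Per-sample test-time optimization of prefix memory tokens."""
    ctx_ids = tok(context, return_tensors="pt").input_ids           # (1, T)
    ctx_emb = embed(ctx_ids)                                        # (1, T, d)
    memory = (0.02 * torch.randn(1, n_mem, d_model)).requires_grad_(True)
    opt = torch.optim.Adam([memory], lr=lr)
    # Labels: ignore the memory positions, reconstruct the context tokens.
    labels = torch.cat(
        [torch.full((1, n_mem), -100, dtype=torch.long), ctx_ids], dim=1
    )
    for _ in range(n_steps):
        inputs = torch.cat([memory, ctx_emb], dim=1)
        loss = model(inputs_embeds=inputs, labels=labels).loss
        opt.zero_grad()
        loss.backward()                   # gradients flow only into `memory`
        opt.step()
    return memory.detach()

def answer(memory: torch.Tensor, query: str, max_new_tokens: int = 20) -> str:
    """Answer from the compressed memory alone; the raw context is removed."""
    q_ids = tok(query, return_tensors="pt").input_ids
    inputs = torch.cat([memory, embed(q_ids)], dim=1)
    out = model.generate(
        inputs_embeds=inputs,
        attention_mask=torch.ones(inputs.shape[:2], dtype=torch.long),
        max_new_tokens=max_new_tokens,
    )
    return tok.decode(out[0], skip_special_tokens=True)

mem = write_memory("key A maps to value 7. key B maps to value 3.")
print(answer(mem, "What value does key A map to? Answer:"))
```

In this reading, "more gradient steps scale capacity" corresponds to increasing `n_steps`, whereas a forward-only writer would produce the memory state in a single pass with no loss-driven correction of what was missed.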