Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning

📅 2025-06-09
📈 Citations: 0
✨ Influential: 0
🤖 AI Summary
Large language models (LLMs) suffer from spurious correlations in training data during long-context reasoning, causing attention misallocation toward irrelevant tokens, redundant inference, and erroneous responses. To address this, we propose a causal-driven two-stage intervention framework: (1) gradient-based contrastive analysis to automatically identify spurious-correlation-inducing tokens; and (2) differentiable token-level pruning coupled with teacher-student attention distribution alignment for causal-aware attention distillation. This work pioneers the integration of causal intervention into attention mechanism optimization, establishing the first end-to-end closed loop, from discovery of data-level spurious correlations to interpretable, inference-time attention correction. Evaluated on mathematical reasoning and code generation benchmarks, our method achieves significant absolute performance gains, effectively suppresses attention to confounding tokens, and enhances both model interpretability and response reliability.
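The stage-1 identification described in the summary can be sketched as a contrastive saliency comparison. This is a minimal toy illustration under our own assumptions, not the authors' implementation: `identify_confounding_tokens` is a hypothetical helper, and the gradient scores below are mock numpy values standing in for per-token gradient magnitudes from a student and a teacher model.

```python
import numpy as np

def identify_confounding_tokens(student_saliency, teacher_saliency, top_k=2):
    """Flag tokens the student attends to strongly but the teacher does not.

    Both inputs are per-token gradient-magnitude scores (higher = more
    influence on the loss). Tokens with the largest positive
    (student - teacher) contrast are treated as confounding candidates.
    """
    student = np.asarray(student_saliency, dtype=float)
    teacher = np.asarray(teacher_saliency, dtype=float)
    # Normalize each saliency profile so the two models are comparable.
    student = student / student.sum()
    teacher = teacher / teacher.sum()
    contrast = student - teacher           # positive = student over-attends
    order = np.argsort(contrast)[::-1]     # most suspicious tokens first
    return sorted(order[:top_k].tolist())

# Toy example: 5 context tokens; the student puts weight on tokens 1 and 3
# that the teacher largely ignores.
student_grads = [0.10, 0.40, 0.10, 0.30, 0.10]
teacher_grads = [0.25, 0.05, 0.25, 0.05, 0.40]
print(identify_confounding_tokens(student_grads, teacher_grads))  # [1, 3]
```

Tokens whose gradients the student emphasizes far more than the teacher's are flagged as confounding candidates, to be pruned in the second stage.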

πŸ“ Abstract
Large language models (LLMs) have demonstrated significant improvements in contextual understanding. However, their ability to attend to truly critical information during long-context reasoning and generation still lags behind. Specifically, our preliminary experiments reveal that certain distracting patterns can misdirect the model's attention during inference, and removing these patterns substantially improves reasoning accuracy and generation quality. We attribute this phenomenon to spurious correlations in the training data, which obstruct the model's capacity to infer authentic causal instruction-response relationships. This phenomenon may induce redundant reasoning processes, potentially resulting in significant inference overhead and, more critically, the generation of erroneous or suboptimal responses. To mitigate this, we introduce a two-stage framework called Learning to Focus (LeaF), which leverages intervention-based inference to disentangle confounding factors. In the first stage, LeaF employs gradient-based comparisons with an advanced teacher to automatically identify confounding tokens based on causal relationships in the training corpus. In the second stage, it prunes these tokens during distillation to enact the intervention, aligning the student's attention with the teacher's focus distribution on truly critical context tokens. Experimental results demonstrate that LeaF not only achieves absolute improvements on various mathematical reasoning and code generation benchmarks but also effectively suppresses attention to confounding tokens during inference, yielding a more interpretable and reliable reasoning model.
Problem

Research questions and friction points this paper is trying to address.

Improving attention to critical information in long-context reasoning
Reducing spurious correlations that misdirect model attention
Enhancing reasoning accuracy and generation quality via token pruning
Innovation

Methods, ideas, or system contributions that make the work stand out.

Gradient-guided token pruning for attention
Two-stage causal attention distillation framework
Intervention-based inference to reduce spurious correlations
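The pruning-plus-alignment step listed above can likewise be illustrated with a small sketch. Hedged: the KL-based alignment term and all names here (`attention_alignment_loss`, `keep_mask`) are our own illustrative choices, restricted to the context positions that survive confounding-token pruning; they are not the paper's exact objective.

```python
import numpy as np

def attention_alignment_loss(student_attn, teacher_attn, keep_mask, eps=1e-9):
    """KL(teacher || student) over the unpruned context positions.

    `keep_mask` marks tokens kept after confounding-token pruning; both
    attention rows are renormalized over those positions before comparison.
    """
    keep = np.asarray(keep_mask, dtype=bool)
    s = np.asarray(student_attn, dtype=float)[keep]
    t = np.asarray(teacher_attn, dtype=float)[keep]
    s = s / s.sum()
    t = t / t.sum()
    return float(np.sum(t * np.log((t + eps) / (s + eps))))

# Prune tokens 1 and 3 (the confounders from stage 1), align on the rest.
student = [0.10, 0.40, 0.10, 0.30, 0.10]
teacher = [0.25, 0.05, 0.25, 0.05, 0.40]
mask    = [True, False, True, False, True]
loss = attention_alignment_loss(student, teacher, mask)
assert loss > 0.0  # the two restricted distributions differ
```

Minimizing this term pushes the student's attention, on the surviving tokens only, toward the teacher's focus distribution, which is the stage-2 distillation idea in miniature.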
Yiju Guo
Gaoling School of Artificial Intelligence, Renmin University of China
Wenkai Yang
Renmin University of China
Natural Language Processing, Machine Learning
Zexu Sun
Renmin University of China
Causal inference, Reinforcement learning, Large language models
Ning Ding
Department of Computer Science and Technology, Tsinghua University
Zhiyuan Liu
Department of Computer Science and Technology, Tsinghua University
Yankai Lin
Associate Professor (Tenure Track), Gaoling School of AI, Renmin University of China
Natural Language Processing, Large Language Models