🤖 AI Summary
Large language models (LLMs) exhibit limited reasoning capabilities on complex mathematical tasks. Method: We propose SPRINT, a contrastive learning–based framework for dynamic, structured pruning of attention heads. During inference, SPRINT adaptively selects optimal head-layer combinations to enable task-driven Best-of-N path optimization. Contribution/Results: We empirically demonstrate—contrary to conventional wisdom—that selective head pruning can enhance, rather than degrade, LLM reasoning performance. SPRINT introduces attention-head embedding alignment and dynamic pruning to preserve semantic consistency and path validity post-pruning. Evaluated on MATH500 and GSM8K, SPRINT significantly outperforms standard Best-of-N sampling and random pruning, achieving up to a 4.2-percentage-point absolute accuracy gain. This validates that dynamically sparse inference paths effectively augment LLMs’ capacity for complex mathematical reasoning.
📝 Abstract
Model pruning in transformer-based language models is traditionally viewed as a means of achieving computational savings, yet it can also enhance a model's reasoning capabilities. In this work, we uncover a surprising phenomenon: selectively pruning certain attention heads improves reasoning performance, particularly on challenging tasks. Motivated by this observation, we propose SPRINT, a novel contrastive learning framework that dynamically selects the optimal head and layer to prune during inference. By aligning question embeddings with head embeddings, SPRINT identifies the pruned-head configurations that yield more accurate reasoning. Extensive experiments demonstrate that our method significantly outperforms traditional best-of-$N$ and random head-selection strategies on the MATH500 and GSM8K datasets.
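The head-selection step described in the abstract (aligning a question embedding with learned per-head embeddings to pick which head to prune) can be sketched as follows. This is a minimal illustrative sketch, not the paper's implementation: the function name `select_head_to_prune`, the cosine-similarity scoring rule, and the toy embedding dimensions are all assumptions.

```python
import numpy as np

def select_head_to_prune(q_emb, head_embs):
    """Pick the (layer, head) whose learned embedding best aligns with the
    question embedding, here scored by cosine similarity (an assumption;
    the paper's contrastive objective may use a different score).

    head_embs: dict mapping (layer, head) -> embedding vector.
    Returns the (layer, head) key with the highest alignment score.
    """
    def cos(a, b):
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-9))
    return max(head_embs, key=lambda k: cos(q_emb, head_embs[k]))

# Toy example: 2 layers x 2 heads with 4-dimensional embeddings.
rng = np.random.default_rng(0)
q = rng.normal(size=4)                                   # question embedding
heads = {(l, h): rng.normal(size=4)                      # learned head embeddings
         for l in range(2) for h in range(2)}
layer, head = select_head_to_prune(q, heads)
```

At inference time, the selected head-layer configuration would be pruned before generating each of the best-of-$N$ candidate paths, so the sampled reasoning paths reflect the task-conditioned sparse model rather than the dense one.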