🤖 AI Summary
Circuit discovery aims to identify the minimal subnetwork within a large language model (LLM) responsible for executing a specific task. Existing approaches rely on iterative edge pruning at coarse granularity, typically at the level of attention heads or MLP blocks, which entails high computational overhead and overestimates neuron importance. This paper proposes a multi-granularity node-level pruning framework that, for the first time, enables joint sparse optimization across hierarchical modules (e.g., layers, attention heads, neurons). It introduces learnable masks and hierarchical sparsity regularization, supporting end-to-end, single-phase fine-tuning for fine-grained circuit discovery without caching intermediate activations. Experiments demonstrate that the discovered circuits achieve comparable task performance with significantly fewer nodes, reduce memory footprint by 5-10x, and reveal that many neurons deemed “important” by coarse-grained methods are in fact redundant and can be safely removed.
📝 Abstract
Circuit discovery aims to identify the minimal subnetworks responsible for specific behaviors in large language models (LLMs). Existing approaches rely primarily on iterative edge pruning, which is computationally expensive and limited to coarse-grained units such as attention heads or MLP blocks, overlooking finer structures such as individual neurons. We propose a node-level pruning framework for circuit discovery that addresses both the scalability and the granularity limitations. Our method introduces learnable masks across multiple levels of granularity, from entire blocks down to individual neurons, within a unified optimization objective. Granularity-specific sparsity penalties guide the pruning process, enabling comprehensive compression in a single fine-tuning run. Empirically, our approach identifies circuits with substantially fewer nodes than those found by prior methods, and shows that many neurons deemed important by coarse-grained methods are in fact irrelevant and can be removed without sacrificing task performance. Furthermore, our method has a 5-10x lower memory footprint, since it does not require keeping intermediate activations in memory.
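The hierarchical masking idea can be sketched in a few lines. The toy below is our own illustration, not the paper's implementation: the function names, gate shapes, and penalty weights are assumptions. The key structural point it shows is that a neuron's effective mask is the product of the gates above it, so zeroing a coarse gate (a layer or a head) prunes everything it contains, while each granularity level carries its own L1 sparsity penalty.

```python
import numpy as np

def effective_mask(layer_gate, head_gates, neuron_gates):
    """Combine gates across granularities (illustrative, not the paper's API).

    layer_gate:   scalar in [0, 1] for the whole layer
    head_gates:   shape (H,), one gate per attention head
    neuron_gates: shape (H, D), one gate per neuron within each head
    A neuron survives only if its layer, head, and own gate are all nonzero.
    """
    return layer_gate * head_gates[:, None] * neuron_gates

def sparsity_penalty(gates_by_level, level_weights):
    """Granularity-specific L1 penalties: one weight per hierarchy level."""
    return sum(w * np.abs(g).sum() for g, w in zip(gates_by_level, level_weights))

# Toy setup: one layer with 2 heads of 3 neurons each.
layer = 1.0
heads = np.array([1.0, 0.0])   # second head fully pruned by its coarse gate
neurons = np.ones((2, 3))

m = effective_mask(layer, heads, neurons)
# Zeroing head 1's gate removes all of its neurons from the circuit:
assert m[1].sum() == 0.0
```

In a gradient-based version, the gates would be learnable parameters optimized jointly with the task loss plus this penalty, so pruning decisions at all levels are made in a single fine-tuning run rather than by iterative edge removal.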