🤖 AI Summary
This work addresses the computational bottleneck posed by the large output vocabulary in language models during on-device deployment, where the dense classification head incurs substantial parameter and inference costs. To circumvent this issue, the authors reformulate the classification task as an information retrieval problem and introduce a training-free, hardware-efficient alternative. The proposed approach features a balanced clustering structure, parallel multi-probe scoring, a full-vocabulary probability sampling mechanism, and selective quantization. Evaluated on Llama-3.2, Gemma-3, and Qwen-3, the method achieves up to 1.75× end-to-end speedup while preserving the original model’s output accuracy.
📝 Abstract
Language models are increasingly adopting smaller architectures optimized for consumer devices. In this setting, inference efficiency is the primary constraint. Meanwhile, vocabulary sizes continue to grow rapidly, making the classification head a critical bottleneck that accounts for up to 60\% of model parameters and 50\% of inference compute. We introduce FlashHead, the first efficient drop-in replacement for the dense classification head that is training-free and hardware-friendly. FlashHead builds on principles from information retrieval, reframing the computation at the output head as a retrieval problem rather than a dense classification over the full vocabulary. FlashHead introduces four key innovations: (1) a balanced clustering scheme that structures vocabulary partitions into compact, hardware-efficient tensors, (2) an extension of multi-probe retrieval to language model heads that enables thousands of clusters to be scored in parallel, (3) a novel inference-time sampling mechanism that extends retrieval beyond the top tokens, enabling probabilistic sampling across the full vocabulary, and (4) selective quantization, enabling effective low-bit computation in the head. Experiments on Llama-3.2, Gemma-3, and Qwen-3 show that FlashHead delivers model-level inference speedups of up to \textbf{1.75x} while maintaining output accuracy compared to the original head. By overcoming the classification head bottleneck, FlashHead establishes a new benchmark for efficient inference and removes a key barrier to developing smaller, capable models for consumer hardware.
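To make the retrieval framing concrete, here is a minimal NumPy sketch of the general idea: partition the vocabulary into balanced clusters, score all cluster centroids in parallel, compute exact logits only for tokens in the top probed clusters, and sample from the resulting distribution. All sizes, the clustering (a random partition stands in for the paper's balanced clustering), and the function names are hypothetical; this is not the authors' FlashHead implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup (all sizes hypothetical): vocab of 1024 tokens, hidden size 64,
# partitioned into 32 equal-size clusters of 32 tokens each.
V, D, C = 1024, 64, 32
tokens_per_cluster = V // C

W = rng.standard_normal((V, D)).astype(np.float32)  # dense head weight
order = rng.permutation(V)                          # stand-in for balanced clustering
clusters = order.reshape(C, tokens_per_cluster)     # each row: token ids of one cluster
centroids = W[clusters].mean(axis=1)                # (C, D) cluster centroids

def retrieval_head(h, n_probe=4):
    """Approximate the dense head's W @ h by probing only the top clusters."""
    # 1) Score all centroids in parallel (cheap, since C << V).
    cluster_scores = centroids @ h                  # (C,)
    probed = np.argsort(cluster_scores)[-n_probe:]  # indices of top-n_probe clusters
    # 2) Exact logits only for tokens inside the probed clusters.
    cand = clusters[probed].ravel()                 # candidate token ids
    logits = W[cand] @ h                            # (n_probe * tokens_per_cluster,)
    # 3) Softmax over the candidates and sample one token id.
    p = np.exp(logits - logits.max())
    p /= p.sum()
    return int(rng.choice(cand, p=p))

h = rng.standard_normal(D).astype(np.float32)
tok = retrieval_head(h)
```

With `n_probe=4`, only 4 centroid rows plus 128 of the 1024 token rows are touched per step, which is the source of the head-level savings; the paper's sampling mechanism additionally extends the distribution beyond the probed tokens to cover the full vocabulary, which this sketch omits.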