🤖 AI Summary
Large language models (LLMs) carry a substantial parameter count and computational overhead in their output layer because of the large vocabulary size V: the final linear projection is a V × d_model matrix. Existing acceleration techniques, such as hierarchical softmax, introduce architectural complexity and hinder end-to-end optimization. This work proposes the first application of vector quantization (VQ) to compress the LLM output layer. A compact, shared codebook of K ≪ V vectors replaces the full vocabulary projection: the model predicts logits over the codebook and differentiably scatters them to the vocabulary, which allows fully end-to-end training without modifying the model backbone or adding auxiliary structures. Evaluated on WikiText-103 and C4, the method achieves up to 99% parameter reduction in the output layer and a 6× speedup in logit computation, with only a ~4% increase in perplexity. Ablation studies on codebook size, initialization, and learning strategy confirm the robustness of the approach.
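The parameter claim follows from simple counting. Below is a worked sketch under two assumptions that are not from the paper: only the codebook's K × d_model entries count as output-layer parameters (the token-to-code map is treated as an integer lookup table), and the sizes V = 50,000 and K = 500 are hypothetical.

```latex
\begin{aligned}
P_{\text{full}} &= V \cdot d_{\text{model}}, \qquad P_{\text{VQ}} = K \cdot d_{\text{model}},\\
1 - \frac{P_{\text{VQ}}}{P_{\text{full}}} &= 1 - \frac{K}{V} = 1 - \frac{500}{50{,}000} = 0.99 .
\end{aligned}
```

Under these assumptions, a 99% reduction corresponds to K ≈ V/100 regardless of d_model. The projection's multiply-add count shrinks by the same factor K/V, while expanding to V token logits still costs O(V) per position.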
📝 Abstract
Large Language Models (LLMs) have achieved remarkable success but face significant computational and memory challenges, particularly due to their large output vocabularies. The final linear projection layer, which maps hidden states to vocabulary-sized logits, often constitutes a substantial portion of the model's parameters and computational cost during inference. Existing methods such as adaptive softmax or hierarchical softmax introduce structural complexity. In this paper, we propose VQ-Logits, a novel approach that leverages Vector Quantization (VQ) to drastically reduce the parameter count and computational load of the LLM output layer. VQ-Logits replaces the large V × d_model output embedding matrix with a small, shared codebook of K embedding vectors (K ≪ V). Each token in the vocabulary is mapped to one of these K codebook vectors. The LLM predicts logits over this compact codebook, which are then efficiently "scattered" to the full vocabulary space using the learned or preassigned mapping. We demonstrate through extensive experiments on standard language modeling benchmarks (e.g., WikiText-103, C4) that VQ-Logits can achieve up to 99% parameter reduction in the output layer and a 6× speedup in logit computation, with only a marginal 4% increase in perplexity compared to full softmax baselines. We further provide detailed ablation studies on codebook size, initialization, and learning strategies, showcasing the robustness and effectiveness of our approach.
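The following minimal pure-Python sketch makes the codebook-then-scatter step concrete for one output position. It is not the authors' implementation. The function names (`vq_logits`, `code_logit_grad`), the toy sizes, and the round-robin `token_to_code` mapping are all hypothetical. The mapping is held fixed, because the abstract does not say how a learned mapping is trained. In this sketch, tokens that share a code receive identical logits; the abstract does not state whether the method adds per-token terms to separate them. The backward function shows why the scatter is differentiable: the gradients from all tokens assigned to a code are summed into that code.

```python
import math
import random


def vq_logits(hidden, codebook, token_to_code):
    """Project one hidden state onto K code vectors, then expand to V token logits.

    hidden:        d floats (final hidden state at one position)
    codebook:      K rows of d floats (shared code vectors, replacing the V x d matrix)
    token_to_code: V ints; entry v is the code index assigned to token v (hypothetical)
    """
    # K dot products instead of V: the projection now costs K x d instead of V x d.
    code_logits = [sum(h * c for h, c in zip(hidden, row)) for row in codebook]
    # Expand to the vocabulary: in this sketch token v reuses its code's logit.
    token_logits = [code_logits[k] for k in token_to_code]
    return code_logits, token_logits


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def code_logit_grad(token_logits, token_to_code, target, num_codes):
    """Cross-entropy gradient w.r.t. the K code logits.

    The backward pass of the expansion is a scatter-add: each code accumulates
    the gradients of every token mapped to it.
    """
    probs = softmax(token_logits)
    grad_tokens = [p - (1.0 if v == target else 0.0) for v, p in enumerate(probs)]
    grad_codes = [0.0] * num_codes
    for v, g in enumerate(grad_tokens):
        grad_codes[token_to_code[v]] += g
    return grad_codes


# Toy usage with hypothetical sizes and a hypothetical round-robin mapping.
random.seed(0)
V, K, d = 12, 3, 4
hidden = [random.gauss(0.0, 1.0) for _ in range(d)]
codebook = [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(K)]
token_to_code = [v % K for v in range(V)]

code_logits, token_logits = vq_logits(hidden, codebook, token_to_code)
grad_codes = code_logit_grad(token_logits, token_to_code, target=5, num_codes=K)
# Chain rule onward: dL/dcodebook[k] = grad_codes[k] * hidden,
# and dL/dhidden = sum over k of grad_codes[k] * codebook[k].
grad_codebook = [[g * h for h in hidden] for g in grad_codes]
```

In an autograd framework, indexing the code-logit vector with the token-to-code index array performs the same gather in the forward pass, and the framework's backward pass for that indexing performs the scatter-add automatically.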