🤖 AI Summary
This work addresses the significant accuracy degradation that large language models suffer under 4-bit quantization because of severe precision loss. To mitigate this, the authors propose a quantization-aware framework that jointly learns the decomposed subspaces. The weights are split into learnable low-rank and residual components, and the structural and quantization parameters are optimized jointly over the Stiefel manifold and the general linear group. The approach further shapes the component distributions and equalizes their dynamic ranges to reduce quantization error. Additionally, a fused dual-branch inference kernel that keeps intermediate results on-chip is designed to accelerate computation. Experiments on LLaMA3 and Qwen3 show that the method achieves near-FP16 accuracy at 4-bit precision, with end-to-end inference speedups of up to 1.8× over an FP16 baseline.
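Neither the summary nor the abstract says how the orthonormal (Stiefel) factors are updated during training. One standard way to keep a matrix on the Stiefel manifold under gradient-based optimization is Riemannian gradient descent with a QR retraction, sketched below in NumPy. The toy loss is a hypothetical stand-in for TwinQuant's quantization objective. The function names, step size, and shapes are illustrative assumptions, and the general-linear-group factors are not covered.

```python
import numpy as np

def project_to_tangent(x, g):
    """Project a Euclidean gradient G onto the tangent space of the Stiefel
    manifold at X (embedded metric): G - X sym(X^T G)."""
    xtg = x.T @ g
    return g - x @ ((xtg + xtg.T) / 2.0)

def qr_retraction(x, xi):
    """Map X + xi back onto the Stiefel manifold via a reduced QR decomposition.
    Column signs are flipped so that diag(R) is positive, making the map unique."""
    q, r = np.linalg.qr(x + xi)
    signs = np.sign(np.diag(r))
    signs[signs == 0] = 1.0
    return q * signs

def toy_loss(x, target):
    """Hypothetical smooth stand-in objective: 0.5 * ||X - target||_F^2."""
    return 0.5 * np.linalg.norm(x - target) ** 2

rng = np.random.default_rng(0)
n, p = 64, 8
x, _ = np.linalg.qr(rng.standard_normal((n, p)))  # start on the manifold: X^T X = I_p
target = rng.standard_normal((n, p))

initial_loss = toy_loss(x, target)
step = 0.02
for _ in range(200):
    euclid_grad = x - target                        # gradient of toy_loss in R^{n x p}
    riem_grad = project_to_tangent(x, euclid_grad)  # tangent-space (Riemannian) gradient
    x = qr_retraction(x, -step * riem_grad)         # step, then pull back onto the manifold
final_loss = toy_loss(x, target)

print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")
print("orthonormal:", np.allclose(x.T @ x, np.eye(p)))
```

The tangent-space projection removes the gradient component that would break orthonormality to first order, and the QR retraction corrects the remaining drift, so `x.T @ x` stays at the identity after every step.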
📝 Abstract
4-bit quantization reduces the memory footprint and latency of large language model inference, but its aggressive precision reduction can severely degrade accuracy. Prior methods address this by decomposing each weight matrix into two components (e.g., via singular value decomposition) and quantizing them separately, assigning the bulk of values to a low-precision residual component while handling outliers with a high-precision low-rank component. However, such decompositions are designed to minimize the real-valued energy of the residual, rather than the post-quantization error of the residual and low-rank components. We propose TwinQuant, a 4-bit quantization framework that learns quantization-friendly decomposed subspaces and jointly reshapes both the low-rank and residual components. TwinQuant learns component-specific transformations via a joint optimization over the Stiefel and general linear manifolds, flattening their distributions and reducing dynamic-range imbalance. To enable efficient end-to-end execution, we further design a fused dual-component kernel that pipelines the two-stage low-rank computation on-chip and merges both components with a single epilogue, avoiding intermediate global-memory traffic. Across LLaMA3 and Qwen3 models, TwinQuant preserves near-FP16 accuracy and delivers up to 1.8× end-to-end speedup over an FP16 baseline.
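The prior-method baseline described above (SVD split, high-precision low-rank branch, 4-bit residual) and the two-branch product that a dual-component kernel must compute can be sketched in NumPy. This is an unfused reference under stated assumptions, not TwinQuant. The rank of 16, the symmetric per-row INT4 fake quantizer, the synthetic outlier channels, and all function names are hypothetical, and TwinQuant's learned transformations and fused kernel are not reproduced.

```python
import numpy as np

def fake_quant_int4(w):
    """Symmetric per-output-row INT4 fake quantization: scale, round, clip to [-8, 7], dequantize."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)  # guard all-zero rows
    return np.clip(np.round(w / scale), -8, 7) * scale

def svd_split(w, rank):
    """Baseline split W = A @ B + R, where A @ B is the rank-`rank` truncated SVD of W."""
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    a = u[:, :rank] * s[:rank]  # (out, rank), kept in high precision
    b = vt[:rank, :]            # (rank, in), kept in high precision
    return a, b, w - a @ b      # residual R goes to the 4-bit branch

def dual_branch_forward(x, a, b, r_q):
    """Unfused reference of the two-branch product: x R_q^T + (x B^T) A^T."""
    low_rank = (x @ b.T) @ a.T   # two-stage low-rank path
    return x @ r_q.T + low_rank  # merge both branches

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 512))
w[:, :4] *= 50.0  # hypothetical outlier input channels

a, b, r = svd_split(w, rank=16)
r_q = fake_quant_int4(r)

err_plain = np.linalg.norm(w - fake_quant_int4(w)) / np.linalg.norm(w)
err_split = np.linalg.norm(w - (a @ b + r_q)) / np.linalg.norm(w)
print(f"INT4 on W directly:        {err_plain:.4f}")
print(f"SVD split + INT4 residual: {err_split:.4f}")

x = rng.standard_normal((8, 512))
y = dual_branch_forward(x, a, b, r_q)
```

On this synthetic matrix, quantizing W directly lets the outlier channels set each row's scale, so most ordinary weights round to zero; the SVD split moves the outliers into the low-rank branch and leaves a residual with a much narrower range. A fused kernel would compute the same `y` without writing the intermediate `x @ b.T` or the partial branch outputs to global memory.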