🤖 AI Summary
This work addresses a weakness of non-autoregressive (one-shot) draft generation in speculative decoding: the drafter's distribution diverges from the verifier's true autoregressive distribution as draft depth increases, a problem exacerbated in tree-structured drafting because distinct branches share the same marginal distribution for subsequent tokens. To mitigate this, the authors propose TreeFlash, which adds an MLP layer conditioned on the drafter's hidden state and the preceding token to approximate an autoregressive distribution. A two-stage approximation mechanism preserves single-pass drafting with O(1) decoding time complexity. The authors report state-of-the-art performance across a variety of tasks and models; compared with marginal tree drafting, TreeFlash achieves 12% higher block efficiency and 9% higher speedup.
📝 Abstract
One-shot block drafters for speculative decoding generate the full draft in a single forward pass, achieving strong throughput by eliminating sequential token generation. However, they predict each draft token conditioned only on the prefix context, with no dependence on previously drafted tokens. This non-autoregressive conditioning causes the drafter's distribution to diverge from the verifier's true autoregressive distribution as draft depth grows. The problem becomes more severe in tree-based drafting, where distinct branches are forced to share the same marginal distribution for subsequent tokens. We propose TreeFlash, which addresses this by incorporating an MLP layer conditioned on the drafter's hidden state and the previous token to approximate an autoregressive distribution. TreeFlash retains the $\mathcal{O}(1)$ decoding time complexity of one-shot drafters by employing a two-stage approximation mechanism. TreeFlash achieves state-of-the-art performance across a variety of tasks and models, with $12\%$ higher block efficiency and $9\%$ higher speedup than marginal tree drafting.
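The contrast between marginal and previous-token-conditioned drafting can be illustrated with a toy example. The sketch below is not the authors' implementation: the sizes, the random weights, the `tanh` nonlinearity, and the names `marginal_dist` and `conditional_dist` are all assumptions made for illustration, and the two-stage approximation mechanism is not modeled because the abstract does not describe it. It only shows that a head reading the hidden state alone gives two sibling branches the same next-token distribution, whereas a head that also reads the previous token can give them different ones.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

# Hypothetical toy sizes and random parameters (not from the paper).
VOCAB, HIDDEN = 5, 4
rng = random.Random(0)
embed = rand_matrix(VOCAB, HIDDEN, rng)      # previous-token embeddings
W_in = rand_matrix(HIDDEN, 2 * HIDDEN, rng)  # MLP input layer
W_out = rand_matrix(VOCAB, HIDDEN, rng)      # projection to vocabulary logits

def marginal_dist(hidden):
    """One-shot head: depends only on the hidden state at this draft depth."""
    return softmax(matvec(W_out, hidden))

def conditional_dist(hidden, prev_token):
    """Assumed MLP head: reads the hidden state and the previous token."""
    x = hidden + embed[prev_token]  # list concatenation, length 2 * HIDDEN
    h = [math.tanh(v) for v in matvec(W_in, x)]
    return softmax(matvec(W_out, h))

# Two sibling branches of a draft tree differ in the token chosen at depth 1.
hidden_depth2 = [rng.uniform(-1.0, 1.0) for _ in range(HIDDEN)]
branch_a, branch_b = 1, 3

marg_a = marginal_dist(hidden_depth2)  # branch identity is never seen
marg_b = marginal_dist(hidden_depth2)
cond_a = conditional_dist(hidden_depth2, branch_a)
cond_b = conditional_dist(hidden_depth2, branch_b)

print("marginal heads identical across branches:", marg_a == marg_b)
print("max gap between conditional heads:",
      max(abs(p - q) for p, q in zip(cond_a, cond_b)))
```

Because the conditional head in this sketch uses only quantities available after the single drafter pass (the hidden state and a token embedding), it adds no extra drafter forward passes; how TreeFlash itself keeps the cost at $\mathcal{O}(1)$ is described in the paper, not here.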