🤖 AI Summary
This work addresses redundancy in causal Transformer batch inference, where shared prefixes (such as system prompts or few-shot examples) are recomputed independently for every sequence that contains them. To eliminate this, the authors propose RadixMLP, a stateless, prefix-trie-based compressed inference method: it maps each batch onto a trie over shared input segments, computes position-wise layers (MLPs, LayerNorms, linear projections, and embeddings) once per unique prefix token in a single forward pass, and scatters results back to per-sequence form only at attention boundaries. The method requires no caching or persistent state and is compatible with standard Transformer architectures. It yields end-to-end speedups of 1.44–1.59× on MS MARCO v1.1 using Qwen3 models (0.6B–8B parameters), with up to 5× acceleration on synthetic benchmarks.
📝 Abstract
Batch inference workloads for causal transformer models frequently process sequences that share common prefixes, such as system prompts, few-shot examples, or shared queries. Standard inference engines treat each sequence independently, redundantly recomputing identical MLP activations for every copy of the shared prefix. We introduce RadixMLP, a technique that exploits the position-wise nature of MLPs, LayerNorms, linear projections, and embeddings to eliminate this redundancy. RadixMLP dynamically maps batches to a prefix trie, gathering shared segments into a compressed representation for position-wise computation and scattering results back only at attention boundaries. RadixMLP is stateless and operates within a single forward pass. In end-to-end serving benchmarks on MS MARCO v1.1 with Qwen3 models (0.6B to 8B parameters), RadixMLP achieves 1.44–1.59× speedups in realistic reranking workloads, with up to 5× speedups on synthetic benchmarks with longer shared prefixes. Our code is available at https://github.com/michaelfeil/radix-mlp.
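The gather/compute/scatter idea from the abstract can be sketched in a few lines. This is an illustrative toy, not the paper's implementation: the names `build_trie_index` and `radix_apply` are ours, the trie is built with a plain dict, and the "model" is a stand-in position-wise function. The point is that tokens lying on a shared prefix map to the same trie node, so the position-wise work runs once per unique node rather than once per copy.

```python
import numpy as np

def build_trie_index(batch):
    """Map each (sequence, position) token to a trie-node id.

    Tokens on a shared prefix across sequences map to the same node,
    so position-wise work is done once per unique node.
    """
    node_of = {}   # (parent_node, token) -> node id
    index = []     # per sequence: node id of each token
    for seq in batch:
        parent, ids = -1, []   # -1 acts as a virtual root
        for tok in seq:
            key = (parent, tok)
            if key not in node_of:
                node_of[key] = len(node_of)
            parent = node_of[key]
            ids.append(parent)
        index.append(ids)
    return index, len(node_of)

def radix_apply(batch_hidden, index, n_nodes, position_wise_fn):
    """Gather into compressed form, apply fn once per node, scatter back."""
    dim = batch_hidden[0].shape[-1]
    compressed = np.zeros((n_nodes, dim))
    for seq_h, ids in zip(batch_hidden, index):
        compressed[ids] = seq_h   # shared-prefix rows receive identical values
    out = position_wise_fn(compressed)   # one evaluation per unique token
    return [out[ids] for ids in index]   # scatter back per sequence

# Two sequences sharing a 3-token prefix: 5 unique nodes instead of 8.
batch = [[1, 2, 3, 4], [1, 2, 3, 9]]
index, n_nodes = build_trie_index(batch)
# Toy "embeddings": identical tokens get identical hidden states.
hidden = [np.stack([np.full(8, t, dtype=float) for t in seq]) for seq in batch]
outputs = radix_apply(hidden, index, n_nodes, lambda x: x * 2.0 + 1.0)
```

In a real Transformer, attention mixes information across positions, which is why the paper scatters back to full per-sequence form at attention boundaries and re-gathers afterward; only the position-wise layers between those boundaries run on the compressed representation.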