TokenWeave: Efficient Compute-Communication Overlap for Distributed LLM Inference

📅 2025-05-16
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
In distributed large language model (LLM) inference, cross-GPU communication adds up to 20% overhead even over high-speed NVLink interconnects. To address this, TokenWeave combines a wave-aware token-splitting strategy with a co-optimization of layer normalization and communication. First, it splits the inference batch into two near-equal token subsets in a wave-aware manner, overlapping one subset's computation with the other's communication. Second, it implements a fused AllReduce-RMSNorm kernel that performs collective communication and normalization together using only 2-8 streaming multiprocessors (SMs), leveraging the Multimem instruction support of NVIDIA Hopper GPUs. Experiments on NVLink-connected systems demonstrate up to 29% end-to-end latency reduction and up to 26% throughput improvement; notably, in several settings, TokenWeave outperforms an equivalent model with all communication removed.
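The overlap pattern the summary describes can be sketched as follows. This is an illustrative simulation only, not TokenWeave's actual code: the function names (`compute`, `all_reduce`, `run_layer`) are hypothetical stand-ins, `time.sleep` emulates GPU kernel and NVLink latencies, and Python threads emulate CUDA streams. The real system also overlaps one subset's communication with the next layer's computation, which this single-layer sketch omits.

```python
# Illustrative sketch (assumed structure, not TokenWeave's implementation) of
# overlapping one token subset's computation with the other's communication.
from concurrent.futures import ThreadPoolExecutor
import time

def compute(subset: str, layer: int) -> str:
    time.sleep(0.01)                      # stand-in for GPU matmul kernels
    return f"act[{subset}][L{layer}]"

def all_reduce(activation: str) -> str:
    time.sleep(0.005)                     # stand-in for NVLink AllReduce
    return f"reduced({activation})"

def run_layer(layer: int) -> tuple[str, str]:
    a = compute("A", layer)               # subset A computes first
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Overlap: subset A's AllReduce runs while subset B computes,
        # emulating two CUDA streams on separate SM partitions.
        fut_comm = pool.submit(all_reduce, a)
        b = compute("B", layer)
        a_reduced = fut_comm.result()
    b_reduced = all_reduce(b)             # tail communication for subset B
    return a_reduced, b_reduced
```

With roughly equal subsets, subset A's communication cost is hidden behind subset B's computation, which is the source of the latency gains reported above.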

📝 Abstract
Distributed inference of large language models (LLMs) can introduce overheads of up to 20% even over GPUs connected via high-speed interconnects such as NVLink. Multiple techniques have been proposed to mitigate these overheads by decomposing computations into finer-grained tasks and overlapping communication with sub-tasks as they complete. However, fine-grained decomposition of a large computation into many smaller computations on GPUs results in overheads. Further, the communication itself uses many streaming multiprocessors (SMs), adding to the overhead. We present TokenWeave to address these challenges. TokenWeave proposes a Token-Splitting technique that divides the tokens in the inference batch into two approximately equal subsets in a wave-aware manner. The computation of one subset is then overlapped with the communication of the other. In addition, TokenWeave optimizes the order of the layer normalization computation with respect to communication operations and implements a novel fused AllReduce-RMSNorm kernel carefully leveraging Multimem instruction support available on NVIDIA Hopper GPUs. These optimizations allow TokenWeave to perform communication and RMSNorm using only 2-8 SMs. Moreover, our kernel enables the memory-bound RMSNorm to be overlapped with the other batch's computation, providing additional gains. Our evaluations demonstrate up to 29% latency gains and up to 26% throughput gains across multiple models and workloads. In several settings, TokenWeave results in better performance compared to an equivalent model with all communication removed.
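The "wave-aware" split mentioned in the abstract can be illustrated with a small sketch. The rounding rule below is an assumption for illustration: a GPU "wave" is taken to process a fixed number of tokens at full occupancy, and the first subset is rounded down to a wave boundary so neither subset leaves a partially filled wave. TokenWeave's actual heuristic may differ.

```python
# Hypothetical sketch of wave-aware token splitting (not TokenWeave's code).
# Assumption: one GPU wave processes `wave` tokens at full occupancy, so the
# first subset's size is rounded to a multiple of `wave` to avoid a ragged
# final wave that would waste SM occupancy.

def wave_aware_split(num_tokens: int, wave: int) -> tuple[int, int]:
    """Split num_tokens into two near-equal subsets, the first wave-aligned."""
    if num_tokens <= wave:
        return num_tokens, 0              # too few tokens to split usefully
    half = num_tokens // 2
    first = max(wave, (half // wave) * wave)  # round down to a full wave
    return first, num_tokens - first
```

For example, 1000 tokens with a 128-token wave split into 384 and 616: the first subset fills exactly three waves, and its computation can be overlapped cleanly with the second subset's communication.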
Problem

Research questions and friction points this paper is trying to address.

Distributed LLM inference incurs up to 20% communication overhead even over high-speed NVLink interconnects
Fine-grained decomposition of computation into many small GPU tasks introduces its own overheads
Collective communication occupies many SMs, competing with computation for GPU resources
Innovation

Methods, ideas, or system contributions that make the work stand out.

Wave-aware Token-Splitting that overlaps one subset's computation with the other's communication
Reordering of layer normalization relative to communication operations
Fused AllReduce-RMSNorm kernel using Multimem instructions on Hopper GPUs, requiring only 2-8 SMs
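The semantics of the fused AllReduce-RMSNorm kernel can be sketched numerically. The NumPy simulation below is an assumption-laden illustration: it emulates what the fused kernel computes (sum partial activations across tensor-parallel ranks, then RMS-normalize), not how the multimem-based GPU kernel is implemented, and the function names are hypothetical.

```python
# NumPy sketch of fused AllReduce+RMSNorm semantics (illustrative only; the
# real kernel uses Hopper multimem instructions, which this does not model).
import numpy as np

def rmsnorm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # RMSNorm: x / sqrt(mean(x^2) + eps) * weight, over the hidden dimension
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def fused_allreduce_rmsnorm(partials: np.ndarray, weight: np.ndarray):
    """Sum tensor-parallel partial activations (AllReduce), then normalize.

    partials: shape (ranks, tokens, hidden) -- one partial sum per GPU.
    On Hopper, a multimem load-reduce could fetch the summed value directly,
    letting a handful of SMs do both steps in one pass; here we just emulate
    the math with an explicit sum.
    """
    reduced = np.sum(partials, axis=0)    # AllReduce across ranks
    return reduced, rmsnorm(reduced, weight)
```

Fusing the two steps matters because RMSNorm is memory-bound: reading the reduced activations once inside the communication kernel avoids an extra round trip through GPU memory, and the fused operation can itself be overlapped with the other subset's computation.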