🤖 AI Summary
This work addresses the memory bottleneck and inefficient on-chip data movement that hinder attention inference in large-scale models such as Mixture-of-Experts (MoE) models. The authors propose FlatAttention, the first approach to co-optimize the collective communication primitives of on-chip interconnect networks together with the attention dataflow, enabling efficient support for diverse attention variants on tile-based accelerators. FlatAttention substantially reduces high-bandwidth memory (HBM) accesses, effectively mitigating the memory wall. On a 32×32 tile configuration, it achieves 86% compute utilization for compute-bound attention variants and 78% HBM bandwidth utilization for memory-bound ones. Compared to FlashAttention-3, FlatAttention delivers a 4.1× speedup and reduces HBM traffic by 16×. In end-to-end DeepSeek-v3 inference, it improves throughput by 1.9× and reduces per-token latency for single users by 1.4×.
📝 Abstract
Attention accounts for an increasingly dominant fraction of total computation during inference for mixture-of-experts (MoE) models, making its efficient acceleration critical. Emerging domain-specific accelerators for large-model inference are shifting toward chip-scale and wafer-scale tile-based architectures. Tiles contain large matrix and vector engines and are connected through on-chip interconnects, which support tile-to-tile traffic to relieve the tile-to-main-memory traffic bottleneck. Hence, dataflow management is crucial to achieving high utilization. We propose FlatAttention, a dataflow for modern attention variants on tile-based accelerators. FlatAttention minimizes expensive high-bandwidth memory (HBM) accesses by exploiting collective primitives integrated into the on-chip network fabric, achieving up to 92.3% utilization, a 4.1x speedup over FlashAttention-3, and 16x lower HBM traffic. On a 32x32 tile configuration with peak performance comparable to NVIDIA GH200, FlatAttention generalizes across multiple attention variants, achieving an average of 86% utilization for compute-bound attention variants and 78% HBM bandwidth utilization for memory-bound ones, resulting in an average 1.9x speedup over attention implementations on GH200. Finally, we evaluate end-to-end DeepSeek-v3 FP8 decoding with FlatAttention on a wafer-scale multi-die system, achieving a 1.9x improvement in system throughput and a 1.4x reduction in per-user token output latency, despite operating with 1.5x lower peak system performance than the state-of-the-art solution.
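To make the memory-traffic argument concrete, the following is an illustrative NumPy sketch of the tiled, online-softmax attention scheme that FlashAttention-style dataflows build on: key/value blocks are streamed tile by tile and running max/sum statistics are rescaled, so the full score matrix is never materialized in main memory. All function names and tile sizes here are illustrative assumptions, not the paper's implementation.

```python
import numpy as np

def reference_attention(Q, K, V):
    """Naive attention: softmax(Q K^T / sqrt(d)) V, materializing the full score matrix."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V

def tiled_attention(Q, K, V, tile=16):
    """Tiled attention with an online softmax (FlashAttention-style sketch).

    K and V are consumed one tile at a time; per-row running statistics
    (max m, normalizer l) are rescaled as each new tile arrives, so only
    O(tile) scores exist at any moment instead of the full n x n matrix.
    """
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((n, d))
    m = np.full((n, 1), -np.inf)   # running row-wise max of scores
    l = np.zeros((n, 1))           # running row-wise sum of exp(score - m)
    for j in range(0, K.shape[0], tile):
        Kj, Vj = K[j:j + tile], V[j:j + tile]
        S = (Q @ Kj.T) * scale                           # scores for this tile only
        m_new = np.maximum(m, S.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                        # rescale factor for old stats
        P = np.exp(S - m_new)
        l = alpha * l + P.sum(axis=-1, keepdims=True)
        O = alpha * O + P @ Vj
        m = m_new
    return O / l
```

In a tile-based accelerator, the rescaling and partial-sum accumulation in the inner loop are exactly the reductions that on-chip collective primitives can perform across tiles, which is why moving them into the network fabric cuts HBM traffic.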