🤖 AI Summary
To address suboptimal matrix multiplication performance in LLM inference on AWS Trainium’s systolic array architecture, this paper proposes a hardware-aware kernel optimization methodology: fusing computation with memory access operations, restructuring data layouts to eliminate explicit transposition, and designing a software-controlled multi-level caching strategy that maximizes SRAM bandwidth utilization. The approach is co-designed with Trainium’s dataflow and memory hierarchy and significantly reduces off-chip data movement. Experimental evaluation across four mainstream large language models and nine datasets demonstrates an average 1.35× speedup (up to 2.22×) for matrix multiplication kernels and an average 1.66× end-to-end inference speedup (up to 2.49×). This work establishes a reusable low-level optimization paradigm for efficient LLM deployment on Trainium accelerators.
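One of the ideas above, eliminating explicit transposition by restructuring the stored data layout, can be illustrated with a small NumPy sketch. This is purely illustrative and not the paper's Trainium kernel: the weights are stored pre-transposed once offline, and the product is expressed so the kernel reads them in their stored order, with no transpose materialized at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))    # activations
W = rng.standard_normal((8, 16))   # weights in their "natural" layout

# Conventional path: consume W in its natural layout.
y_ref = x @ W

# Layout-aware path: store the weights already transposed (done once,
# offline), then contract over the shared dimension directly, so no
# explicit transpose copy is made during inference.
W_T = np.ascontiguousarray(W.T)
y = np.einsum('ij,kj->ik', x, W_T)   # contracts over j in stored order

assert np.allclose(y, y_ref)
```

The same trade appears on systolic-array hardware: paying a one-time layout restructuring cost offline removes a per-invocation transpose from the critical path.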
📝 Abstract
AI accelerators, customized for AI workloads, provide cost-effective and high-performance solutions for training and inference. Trainium, an AI accelerator recently developed by Amazon Web Services (AWS), provides an attractive option for LLM training and inference through its heterogeneous architecture. However, leveraging the Trainium architecture for high performance can be challenging because of its systolic array architecture and special requirements on data layout. In this paper, we design high-performance matrix multiplication (matmul), a critical compute kernel, for LLM inference on Trainium. We introduce a series of techniques customized to Trainium, based on kernel fusion and novel caching strategies, to reduce data movement across the software-managed memory hierarchy, maximize SRAM bandwidth, and avoid expensive matrix transposes. Evaluating on nine datasets and four recent LLMs, we show that our system largely outperforms the state-of-the-art matmul kernels implemented by AWS on Trainium: at the matmul-kernel level, it achieves an average 1.35x speedup (up to 2.22x), which translates to an average 1.66x speedup (up to 2.49x) for end-to-end LLM inference.
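The fusion idea in the abstract, keeping an output tile in fast memory while the operations that consume it are applied, can be sketched generically in NumPy. The tile size and the relu-epilogue example are hypothetical choices for illustration; on Trainium the "fast memory" would be the software-managed SRAM, which the per-tile buffer below only stands in for.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.standard_normal((64, 32))
W = rng.standard_normal((32, 48))
b = rng.standard_normal(48)

def fused_matmul_bias_relu(x, W, b, tile=16):
    """Compute relu(x @ W + b) one row-tile at a time.

    Fusing the bias-add and ReLU into the matmul loop means each
    output tile is produced, post-processed, and written back in a
    single pass, instead of streaming the full intermediate x @ W
    through memory once per operation.
    """
    out = np.empty((x.shape[0], W.shape[1]))
    for i in range(0, x.shape[0], tile):
        acc = x[i:i + tile] @ W                      # tile stays "on chip"
        out[i:i + tile] = np.maximum(acc + b, 0.0)   # fused epilogue
    return out

y = fused_matmul_bias_relu(x, W, b)
assert np.allclose(y, np.maximum(x @ W + b, 0.0))
```

The unfused version would write the full 64x48 intermediate to memory and read it back twice (once for the bias, once for the activation); fusion replaces those round trips with work on a tile that already sits in fast memory.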