NeuronMM: High-Performance Matrix Multiplication for LLM Inference on AWS Trainium

📅 2025-10-29
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
To address suboptimal matrix multiplication performance in LLM inference on AWS Trainium’s systolic array architecture, this paper proposes a hardware-aware kernel optimization methodology: fusing computation and memory access operations, restructuring data layouts to eliminate explicit transposition, and designing a software-controllable multi-level caching strategy to maximize SRAM bandwidth utilization. The approach is deeply co-designed with Trainium’s dataflow and memory hierarchy, significantly reducing off-chip data movement overhead. Experimental evaluation across four mainstream large language models and nine datasets demonstrates an average 1.35× speedup (up to 2.22×) for matrix multiplication kernels and an average 1.66× end-to-end inference acceleration (up to 2.49×). This work establishes a reusable, low-level optimization paradigm for efficient LLM deployment on Trainium accelerators.
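The fusion idea in the summary can be illustrated with a conceptual sketch. This is not the paper's Trainium kernel (which targets the systolic array through a software-managed memory hierarchy); it is a plain NumPy analogue showing why fusing a matmul tile with its epilogue avoids a round trip through slow memory. The function name, the GELU epilogue, and the tile size are illustrative assumptions.

```python
import numpy as np

def fused_matmul_gelu(A, B, tile=128):
    """Conceptual sketch (hypothetical, not the paper's kernel):
    apply the activation to each output tile while it is still
    "on chip", instead of writing the full matmul result out and
    reading it back for a separate activation pass."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    C = np.empty((M, N), dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            # matmul for one output tile
            t = A[i:i+tile, :] @ B[:, j:j+tile]
            # fused epilogue: tanh-approximation GELU applied to
            # the live tile before it is stored
            C[i:i+tile, j:j+tile] = 0.5 * t * (
                1.0 + np.tanh(0.7978845608 * (t + 0.044715 * t**3)))
    return C
```

On an accelerator, the fused version touches off-chip memory once per output tile instead of twice, which is the data-movement saving the summary describes.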

📝 Abstract
AI accelerators, customized to AI workloads, provide cost-effective and high-performance solutions for training and inference. Trainium, an AI accelerator recently developed by Amazon Web Services (AWS), provides an attractive option for LLM training and inference through its heterogeneous architecture. However, leveraging the Trainium architecture for high performance can be challenging because of its systolic array architecture and special requirements on data layout. In this paper, we design high-performance matrix multiplication (matmul), a critical compute kernel, for LLM inference on Trainium. We introduce a series of techniques customized to Trainium based on kernel fusion and novel caching strategies to reduce data movement across the software-managed memory hierarchy, maximize SRAM bandwidth, and avoid expensive matrix transposes. Evaluating with nine datasets and four recent LLMs, we show that our system largely outperforms the state-of-the-art matmul implemented by AWS on Trainium: at the level of the matmul kernel, it achieves an average 1.35x speedup (up to 2.22x), which translates to an average 1.66x speedup (up to 2.49x) for end-to-end LLM inference.
Problem

Research questions and friction points this paper is trying to address.

Optimizing matrix multiplication for LLM inference on AWS Trainium
Addressing data layout challenges in systolic array architecture
Reducing data movement through kernel fusion and caching strategies
Innovation

Methods, ideas, or system contributions that make the work stand out.

Kernel fusion reduces data movement across memory
Novel caching strategies maximize SRAM bandwidth usage
Avoids expensive matrix transpose operations
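The transpose-avoidance point above can be sketched in NumPy: if the right-hand operand is kept in a transposed layout, each tile contraction reads both operands contiguously along the reduction dimension, and no separate transpose pass over the data is needed. This is a conceptual analogue under assumed names and tile sizes, not the paper's Trainium data layout.

```python
import numpy as np

def tiled_matmul_pretransposed(A, B_T, tile=128):
    """Conceptual sketch (hypothetical helper): compute C = A @ B
    where B is stored as B_T (shape (N, K), i.e. B's transpose),
    so tiles of both operands are read contiguously along K and
    no explicit transpose kernel runs before the matmul."""
    M, K = A.shape
    N, K2 = B_T.shape
    assert K == K2
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile):
        for j in range(0, N, tile):
            for k in range(0, K, tile):
                a = A[i:i+tile, k:k+tile]
                bt = B_T[j:j+tile, k:k+tile]
                # bt.T is a zero-copy view in NumPy; the contraction
                # itself reads both tiles along the K dimension
                C[i:i+tile, j:j+tile] += a @ bt.T
    return C
```

Caching the `A` tile across the inner loops mirrors, loosely, the multi-level caching idea: operand tiles stay in fast memory while all output tiles that need them are produced.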
Dinghong Song
University of California, Merced
Jierui Xu
University of Wisconsin, Madison
Weichu Yang
University of Wisconsin, Madison
Pengfei Su
Assistant Professor of Computer Science and Engineering, University of California, Merced
Programming Languages, Program Analysis, High-Performance Computing, Parallel Programming
Dong Li
University of California, Merced