AI Summary
Pretraining large language models (LLMs) incurs prohibitive computational and memory overhead. Method: This paper proposes CoLA-M, the first framework to deeply integrate low-rank activation modeling into LLM pretraining. Instead of compressing model weights via conventional low-rank approximation, CoLA-M exploits the intrinsic low-rank structure of activation tensors. It jointly optimizes parameter count, FLOPs, and memory footprint through nonlinear coupling factorization of weight matrices, memory-aware gradient checkpointing, and structured sparse training. Contribution/Results: Evaluated on LLaMA variants (60M–7B parameters), CoLA-M achieves a 2× reduction in computation, a 1.86× increase in training throughput, and a 2× compression in model size, while preserving full-rank performance. Inference latency and memory efficiency are also improved. CoLA-M establishes a novel paradigm for efficient LLM pretraining.
Abstract
Large language models (LLMs) are revolutionizing many science and engineering fields. However, their huge model sizes impose extremely demanding computational resource requirements in the pre-training stage. Although low-rank factorizations can reduce model parameters, their direct application in LLM pre-training often leads to non-negligible performance loss. To address this fundamental challenge, we introduce CoLA and its memory-efficient implementation, CoLA-M. We leverage the low-rank structure observed widely in model activations, enforcing non-linear transformations between factorized weight matrices to reduce model size, boost model capacity and training efficiency. Experiments on LLaMA models with 60 million to 7 billion parameters show that CoLA reduces the computing cost by $\pmb{2\times}$ and improves training throughput by $\pmb{1.86\times}$ while maintaining full-rank level performance. CoLA-M further squeezes memory cost without sacrificing throughput, offering a pre-training approach with collectively superior parameter, computing, and memory efficiency. The LLMs produced are also $\pmb{2\times}$ smaller, enabling faster inference with lower memory cost on resource-constrained platforms.
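The core idea, replacing a dense weight matrix with two low-rank factors coupled by a non-linearity, can be illustrated with a minimal NumPy sketch. This is a hypothetical illustration under assumed shapes and a tanh activation, not the authors' implementation; the names `cola_linear`, `A`, and `B` are inventions for this example.

```python
import numpy as np

def cola_linear(x, A, B, act=np.tanh):
    # Factorized layer with a non-linearity between the factors:
    #   y = B @ act(A @ x)
    # replacing the dense y = W @ x, where W has shape (d_out, d_in).
    return B @ act(A @ x)

d_in, d_out, r = 1024, 1024, 256  # rank r chosen well below d_in, d_out
rng = np.random.default_rng(0)
A = rng.standard_normal((r, d_in)) * 0.02   # down-projection factor
B = rng.standard_normal((d_out, r)) * 0.02  # up-projection factor
x = rng.standard_normal(d_in)

y = cola_linear(x, A, B)

dense_params = d_in * d_out          # parameters of the dense layer
cola_params = r * (d_in + d_out)     # parameters of the factorized layer
print(y.shape, dense_params / cola_params)  # (1024,) 2.0
```

With rank $r = d/4$, the factorized layer uses half the parameters of the dense one, consistent with the roughly $2\times$ model-size reduction reported in the abstract; the non-linearity between the factors is what distinguishes this from a plain low-rank decomposition, which would collapse back into a single matrix.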