🤖 AI Summary
To address numerical errors in low-precision training of large language models (LLMs), and the lack of theoretical justification for existing mixed-precision strategies, this paper studies BF16 training with stochastic rounding (SR). It provides theoretical analyses of the implicit regularization effect of SR and of convergence under the Adam optimizer when SR is used. Building on these analyses, it extends the prior BF16+SR strategy to distributed training, improving stability and performance at scale. In pre-training of models with up to 6.7B parameters, it shows for the first time that BF16 with SR achieves better validation perplexity than (BF16, FP32) mixed-precision training, with up to 1.54× higher throughput and 30% less memory usage. The result is a theoretically grounded and empirically validated approach to LLM training that improves both accuracy and efficiency.
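As general background (a standard definition, not a formula quoted from the paper), stochastic rounding of a real value $x$ can be written as follows, where $\lfloor x \rfloor$ and $\lceil x \rceil$ here denote the nearest representable BF16 values below and above $x$, and $\operatorname{SR}(x) = x$ when $x$ is already representable:

$$
\operatorname{SR}(x) =
\begin{cases}
\lceil x \rceil & \text{with probability } p = \dfrac{x - \lfloor x \rfloor}{\lceil x \rceil - \lfloor x \rfloor},\\[4pt]
\lfloor x \rfloor & \text{with probability } 1 - p,
\end{cases}
\qquad
\mathbb{E}[\operatorname{SR}(x)] = \lfloor x \rfloor + p\,(\lceil x \rceil - \lfloor x \rfloor) = x.
$$

Round-to-nearest, by contrast, is deterministic and maps any increment smaller than half the gap $\lceil x \rceil - \lfloor x \rfloor$ back to the starting value, so repeated small updates can be lost entirely.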
📝 Abstract
As the parameters of Large Language Models (LLMs) have scaled to hundreds of billions, the demand for efficient training methods, which balance faster computation and reduced memory usage without sacrificing accuracy, has become more critical than ever. In recent years, various mixed-precision strategies, which use different precision levels for different optimization components, have been proposed to increase training speed with minimal accuracy degradation. However, these strategies often require manual adjustments and lack theoretical justification. In this work, we leverage stochastic rounding (SR) to address the numerical errors that arise when training with low-precision representations. We provide theoretical analyses of implicit regularization and convergence under the Adam optimizer when SR is used. Guided by these analyses, we extend the previous BF16 + SR strategy to distributed settings, enhancing stability and performance in large-scale training. Empirical results from pre-training models with up to 6.7B parameters demonstrate, for the first time, that our BF16 with SR strategy outperforms (BF16, FP32) mixed-precision strategies, achieving better validation perplexity, up to $1.54\times$ higher throughput, and 30% less memory usage.
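The numerical problem SR addresses can be illustrated with a small sketch. It is not the authors' implementation: it assumes finite FP32 inputs (NaN, Inf, and overflow are ignored), uses a plain additive update `w <- round(w + delta)` rather than Adam, and the helper names (`bf16_round_nearest`, `bf16_round_stochastic`, `accumulate`) are hypothetical. A weight stored in BF16 receives many updates smaller than half a BF16 ulp: round-to-nearest discards every one of them, while stochastic rounding keeps them in expectation.

```python
import random
import struct


def fp32_bits(x: float) -> int:
    """Return the IEEE-754 binary32 bit pattern of x (x is first rounded to FP32)."""
    return struct.unpack(">I", struct.pack(">f", x))[0]


def bits_to_fp32(b: int) -> float:
    """Interpret the low 32 bits of b as an IEEE-754 binary32 value."""
    return struct.unpack(">f", struct.pack(">I", b & 0xFFFFFFFF))[0]


def bf16_round_nearest(x: float) -> float:
    """Round FP32 -> BF16 with round-to-nearest-even (finite inputs only)."""
    b = fp32_bits(x)
    # Add 0x7FFF plus the lowest kept bit so that exact ties go to the even BF16 value.
    b += 0x7FFF + ((b >> 16) & 1)
    # BF16 keeps the top 16 bits of the FP32 pattern.
    return bits_to_fp32(b & 0xFFFF0000)


def bf16_round_stochastic(x: float, rng: random.Random) -> float:
    """Round FP32 -> BF16 stochastically (finite inputs only).

    Adding a uniform random 16-bit integer to the discarded low bits and then
    truncating rounds the magnitude up with probability equal to the discarded
    fraction of one BF16 ulp, so the result equals x in expectation.
    """
    b = fp32_bits(x)
    b += rng.getrandbits(16)
    return bits_to_fp32(b & 0xFFFF0000)


def accumulate(steps: int, delta: float, use_sr: bool, seed: int = 0) -> float:
    """Apply `steps` identical updates w <- round(w + delta), storing w in BF16."""
    rng = random.Random(seed)
    w = 1.0  # exactly representable in BF16
    for _ in range(steps):
        updated = w + delta  # the addition itself is done at higher precision
        w = bf16_round_stochastic(updated, rng) if use_sr else bf16_round_nearest(updated)
    return w


if __name__ == "__main__":
    # Near 1.0 the BF16 ulp is 2**-7, so a 1e-3 step is below half an ulp (2**-8).
    print("round-to-nearest:", accumulate(1000, 1e-3, use_sr=False))
    print("stochastic:      ", accumulate(1000, 1e-3, use_sr=True))
```

With round-to-nearest the weight stays at exactly 1.0, because each step rounds back to the starting value; with stochastic rounding it ends close to the exact sum of 2.0, with the precise value depending on the seed. The bit-level trick (add random low bits, then truncate) is one common way to implement SR for BF16; real training kernels would apply it elementwise on tensors.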