🤖 AI Summary
Conventional wisdom holds that training language models with very small batch sizes—especially batch size 1—leads to unstable pretraining and fine-tuning, necessitating gradient accumulation to increase the effective batch size.
Method: We challenge that paradigm by systematically investigating the feasibility and stability of training with minimal batch sizes, down to batch size 1. We propose a principled Adam hyperparameter scaling rule that jointly adapts the learning rate, β₁, and β₂ with batch size. Furthermore, we demonstrate, for the first time on standard language modeling tasks, stable training with pure SGD: no momentum and no optimizer state.
Results: Our small-batch approach achieves equal or better performance at equal or lower computational cost, significantly reduces memory footprint, and is more robust to hyperparameter choices. Crucially, we show that gradient accumulation is unnecessary in single-device settings, establishing an efficient, lightweight, accumulation-free small-batch training paradigm.
📝 Abstract
Conventional wisdom dictates that small batch sizes make language model pretraining and fine-tuning unstable, motivating gradient accumulation, which trades off the number of optimizer steps for a proportional increase in batch size. While it is common to decrease the learning rate for smaller batch sizes, other hyperparameters are often held fixed. In this work, we revisit small batch sizes all the way down to batch size one, and we propose a rule for scaling Adam hyperparameters to small batch sizes. We find that small batch sizes (1) train stably, (2) are consistently more robust to hyperparameter choices, (3) achieve equal or better per-FLOP performance than larger batch sizes, and (4) notably enable stable language model training with vanilla SGD, even without momentum, despite storing no optimizer state. Building on these results, we provide practical recommendations for selecting a batch size and setting optimizer hyperparameters. We further recommend against gradient accumulation unless training on multiple devices with multiple model replicas, bottlenecked by inter-device bandwidth.
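The abstract does not spell out the scaling rule itself. One natural instantiation — an illustrative assumption, not necessarily the paper's exact prescription — is to keep each exponential moving average's half-life fixed when measured in tokens rather than in optimizer steps, which gives β_new = β_old^(B_new/B_old), paired here with the common square-root learning-rate rule for Adam (also an assumption):

```python
import math

def scale_adam_hparams(lr, beta1, beta2, base_batch, new_batch):
    """Rescale Adam hyperparameters when changing the batch size.

    Assumption (illustrative): hold each EMA's half-life constant in
    *tokens*. The half-life in steps is ln(0.5)/ln(beta), and a step
    processes a number of tokens proportional to the batch size, so the
    token half-life is preserved by beta_new = beta_old ** (new/base).
    The learning rate uses the common sqrt(batch-ratio) rule for Adam;
    this too is an assumption, not a claim about the paper's rule.
    """
    ratio = new_batch / base_batch
    return {
        "lr": lr * math.sqrt(ratio),
        "beta1": beta1 ** ratio,
        "beta2": beta2 ** ratio,
    }

# Example: shrink from batch size 512 to batch size 1, starting from
# standard Adam defaults. The betas move toward 1 (longer half-life in
# steps, identical half-life in tokens) and the learning rate shrinks.
hp = scale_adam_hparams(lr=1e-3, beta1=0.9, beta2=0.999,
                        base_batch=512, new_batch=1)
```

Note that under this rule the batch-size-1 betas end up very close to 1, which is consistent with the intuition that a tiny batch needs a longer averaging window in steps to see the same amount of data.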