Adaptive Batch Size Schedules for Distributed Training of Language Models with Data and Model Parallelism

📅 2024-12-30
📈 Citations: 0
Influential: 0
🤖 AI Summary
To address the efficiency–generalization trade-off in selecting batch sizes for distributed large language model training, this paper proposes an adaptive batch-size scheduling method compatible with both data and model parallelism. Methodologically, it introduces a general, scalable dynamic scheduling mechanism that overcomes limitations of conventional fixed or warmup-based batch-size strategies. Theoretically, it establishes the first convergence guarantee for adaptive batch-size schemes with the Adam optimizer under nonconvex objectives. On the systems side, it implements coordinated sharding of parameters, gradients, and optimizer states using PyTorch's Fully Sharded Data Parallel (FSDP). Empirical evaluation on Llama-family models (up to 3B parameters) during pretraining demonstrates substantial improvements in convergence speed and final generalization performance over constant and warmup baselines.
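To make the idea of a dynamic schedule concrete, the sketch below shows one simple way such a mechanism could work: grow the global batch size when training progress stalls. This is a hypothetical illustration of the general concept, not the paper's actual algorithm; the class name, plateau criterion, and doubling rule are all assumptions.

```python
# Hypothetical sketch of an adaptive batch-size schedule: double the
# global batch size whenever a moving average of the training loss
# plateaus. Illustrates the general idea only; the paper's method
# may use a different adaptation criterion.

class AdaptiveBatchSchedule:
    def __init__(self, init_batch_size, max_batch_size, window=4, tol=1e-3):
        self.batch_size = init_batch_size
        self.max_batch_size = max_batch_size
        self.window = window   # number of recent losses averaged per comparison
        self.tol = tol         # minimum relative improvement to count as progress
        self.history = []

    def step(self, loss):
        """Record one training loss; grow the batch size if progress stalled."""
        self.history.append(loss)
        if len(self.history) >= 2 * self.window:
            prev = sum(self.history[-2 * self.window:-self.window]) / self.window
            curr = sum(self.history[-self.window:]) / self.window
            if prev - curr < self.tol * max(abs(prev), 1e-12):
                # Loss has plateaued: reduce gradient noise by enlarging the batch.
                self.batch_size = min(2 * self.batch_size, self.max_batch_size)
                self.history.clear()  # restart measurement after a size change
        return self.batch_size
```

In a training loop, `step(loss)` would be called once per optimizer update and its return value used to size the next batch; a flat loss curve triggers doubling, while a steadily decreasing one leaves the batch size unchanged.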

📝 Abstract
An appropriate choice of batch sizes in large-scale model training is crucial, yet it involves an intrinsic dilemma: large-batch training improves training efficiency in terms of memory utilization, while generalization performance often deteriorates due to the reduced gradient noise. Despite this dilemma, the common practice of choosing batch sizes in language model training often prioritizes training efficiency -- employing either constant large sizes with data parallelism or implementing batch size warmup schedules. However, such batch size schedule designs remain heuristic and often fail to adapt to training dynamics, presenting the challenge of designing adaptive batch size schedules. Given the abundance of available datasets and the data-hungry nature of language models, data parallelism has become an indispensable distributed training paradigm, enabling the use of larger batch sizes for gradient computation. However, vanilla data parallelism requires replicas of model parameters, gradients, and optimizer states at each worker, which prohibits training larger models with billions of parameters. To optimize memory usage, more advanced parallelism strategies must be employed. In this work, we propose general-purpose and theoretically principled adaptive batch size schedules compatible with data parallelism and model parallelism. We develop a practical implementation with PyTorch Fully Sharded Data Parallel, facilitating the pretraining of language models of different sizes. We empirically demonstrate that our proposed approaches outperform constant batch sizes and heuristic batch size warmup schedules in the pretraining of models in the Llama family, with particular focus on smaller models with up to 3 billion parameters. We also establish theoretical convergence guarantees for such adaptive batch size schedules with Adam for general smooth nonconvex objectives.
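The abstract notes that data parallelism enables larger batch sizes for gradient computation. The standard bookkeeping behind this is worth making explicit: the global (effective) batch size is the product of the per-device batch, the number of data-parallel workers, and any gradient-accumulation steps. The helper below is a generic sketch of that arithmetic, not code from the paper.

```python
# Global batch size under data parallelism: each of `world_size` workers
# processes `per_device_batch` samples per micro-step, and gradients may
# be accumulated over `grad_accum_steps` micro-steps before an optimizer
# update. (Standard bookkeeping, not specific to this paper.)

def global_batch_size(per_device_batch: int, world_size: int,
                      grad_accum_steps: int = 1) -> int:
    return per_device_batch * world_size * grad_accum_steps
```

For example, 8 samples per device across 64 workers with 4 accumulation steps yields a global batch of 2048 samples per optimizer update; an adaptive schedule can change any of the three factors to realize a new global batch size.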
Problem

Research questions and friction points this paper is trying to address.

Automatic Batch Size Adjustment
Large-scale Language Model Training
Distributed Training Strategies

Innovation

Methods, ideas, or system contributions that make the work stand out.

Batch Adaptation
Parallel Strategies
Adam Optimizer Convergence
Tim Tsz-Kit Lau
University of Pennsylvania
Machine Learning, Optimization, Statistics, Artificial Intelligence
Weijian Li
Department of Computer Science, Northwestern University, Evanston, IL 60208, USA
Chenwei Xu
Northwestern University
Deep Learning, Machine Learning
Han Liu
Department of Computer Science, Northwestern University, Evanston, IL 60208, USA; Department of Statistics and Data Science, Northwestern University, Evanston, IL 60208, USA
Mladen Kolar
University of Southern California
Machine Learning, Statistics