🤖 AI Summary
To address the accelerator-memory and I/O-bandwidth bottlenecks in training high-resolution, long-horizon global weather forecasting models, this paper proposes Jigsaw, a parallel training framework that unifies domain parallelism and tensor parallelism to eliminate redundant memory consumption. Jigsaw delivers strong scaling in compute-bound regimes and superscalar weak scaling in I/O-bound regimes. Paired with WeatherMixer, a lightweight MLP-based architecture, and a customized model-parallel strategy, it enables efficient training on a 256-GPU cluster. Experiments demonstrate sustained peak performance of 9–11 PFLOPs (23%–28% of theoretical peak) and scaling efficiency of 68%–72%, up from 51% without model parallelism. This is a substantial advance in the scalability of large-scale meteorological AI training, overcoming key limitations of conventional architectures.
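To make the combination of domain and tensor parallelism concrete, here is a minimal PyTorch sketch of the general idea: each rank holds only its own latitude-longitude tile of the input (domain parallelism) and a slice of each weight matrix (tensor parallelism), so no rank replicates the full activations or full weights. This is a generic Megatron-style row-parallel layer for illustration only; all names are hypothetical and this is not Jigsaw's actual implementation.

```python
import torch
import torch.distributed as dist

class RowParallelLinear(torch.nn.Module):
    """Row-parallel linear layer (illustrative sketch, not Jigsaw's kernel).

    The weight is split along its input dimension across the ranks of a
    tensor-parallel group, so each rank stores only 1/tp of the matrix and
    computes a partial product that is summed with an all-reduce.
    """

    def __init__(self, in_features: int, out_features: int,
                 tp_group: dist.ProcessGroup):
        super().__init__()
        tp = dist.get_world_size(tp_group)
        assert in_features % tp == 0
        self.tp_group = tp_group
        # Each rank stores only its slice of the full weight matrix.
        self.weight = torch.nn.Parameter(
            torch.empty(out_features, in_features // tp).normal_(std=0.02))

    def forward(self, x_shard: torch.Tensor) -> torch.Tensor:
        # x_shard: (local_grid_points, in_features // tp). The grid-point
        # dimension is already sharded by domain parallelism, so each
        # tensor-parallel group only ever sees its own lat-lon tile.
        partial = x_shard @ self.weight.t()
        # Sum the partial products over the feature shards.
        dist.all_reduce(partial, group=self.tp_group)
        return partial
```

Under such a scheme, domain parallelism requires exchanging only halo regions between neighboring tiles, and each rank reads only its own tile from disk; both per-rank memory and per-rank I/O shrink as the cluster grows, which is what makes superscalar weak scaling in I/O-bound regimes plausible.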
📄 Abstract
AI-based methods have revolutionized atmospheric forecasting, with recent successes in medium-range forecasting spurring the development of climate foundation models. Accurate modeling of complex atmospheric dynamics at high spatial resolutions and longer lead times requires large neural networks and gigabyte-sized data samples, making accelerator memory and I/O bandwidth the bottlenecks for model training. We introduce WeatherMixer, a multi-layer-perceptron-based architecture whose workload scales linearly with input size, allowing the model to learn global weather phenomena at accuracies similar to numerical weather prediction. To cope with the computational demand, we propose Jigsaw, a novel model parallelization scheme that employs both domain and tensor parallelism, eliminating memory redundancy. Jigsaw exceeds state-of-the-art performance in strong scaling in compute-communication-limited systems and achieves superscalar weak scaling in I/O-bandwidth-limited systems. We scale training to 256 GPUs, reaching peak performances of 9 and 11 PFLOPs (23% and 28% of theoretical peak, respectively), and achieve 68% and 72% scaling efficiency, versus 51% without model parallelism.
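The abstract does not describe WeatherMixer's operators in detail, but the claim that workload scales linearly with input size is exactly the behavior of point-wise MLP blocks, where the same weights are applied independently at every grid point. The following hypothetical sketch illustrates that property; it is not WeatherMixer's actual architecture.

```python
import torch

class ChannelMLPBlock(torch.nn.Module):
    """Per-grid-point MLP block (illustrative only).

    The same weights are applied independently at every grid point, so
    total FLOPs grow linearly with the number of points.
    """

    def __init__(self, channels: int, hidden: int):
        super().__init__()
        self.norm = torch.nn.LayerNorm(channels)
        self.fc1 = torch.nn.Linear(channels, hidden)
        self.fc2 = torch.nn.Linear(hidden, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, lat * lon, channels), the flattened global grid.
        # Residual connection around a two-layer MLP over the channel dim.
        return x + self.fc2(torch.nn.functional.gelu(self.fc1(self.norm(x))))
```

In a block like this, doubling the spatial resolution doubles the number of grid points and hence the FLOPs, whereas attention-based mixing would grow quadratically with the number of points, which is the contrast that motivates an MLP-based design at global high resolution.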