AI Summary
To address the scalability limits of large-scale spatio-temporal graph neural networks (ST-GNNs) under GPU memory constraints, and the lack of spatio-temporal awareness in existing distributed training frameworks, this paper proposes PGT-I, the first efficient distributed training framework designed specifically for ST-GNNs. Methodologically, PGT-I introduces (i) index-batching and distributed index-batching to reduce memory overhead and communication cost, and (ii) dynamic runtime snapshot construction, enabling end-to-end ST-GNN training on the full PeMS dataset without graph partitioning for the first time. Built on PyTorch Geometric Temporal, it combines distributed data parallelism with spatio-temporal locality optimizations. Evaluated on 128 GPUs, PGT-I achieves up to a 13.1× speedup over standard DDP and reduces peak GPU memory consumption by up to 89%, substantially easing the scalability bottleneck of large-scale ST-GNN training.
Abstract
Spatiotemporal graph neural networks (ST-GNNs) are powerful tools for modeling spatial and temporal data dependencies. However, their applications have been limited primarily to small-scale datasets because of memory constraints. While distributed training offers a solution, current frameworks lack support for spatiotemporal models and overlook the properties of spatiotemporal data. Informed by a scaling study on a large-scale workload, we present PyTorch Geometric Temporal Index (PGT-I), an extension to PyTorch Geometric Temporal that integrates distributed data parallel training and two novel strategies: index-batching and distributed-index-batching. Our index techniques exploit spatiotemporal structure to construct snapshots dynamically at runtime, significantly reducing memory overhead, while distributed-index-batching extends this approach by enabling scalable processing across multiple GPUs. Our techniques enable the first-ever training of an ST-GNN on the entire PeMS dataset without graph partitioning, reducing peak memory usage by up to 89% and achieving up to a 13.1x speedup over standard DDP with 128 GPUs.
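The core idea behind index-batching can be illustrated with a small sketch: rather than materializing every (window, horizon) snapshot up front, store the raw node-feature time series once and slice snapshots on demand from integer indices. This is a minimal illustration under assumed names (`IndexBatchedDataset`, `window`, `horizon` are hypothetical), not the actual PGT-I API.

```python
import numpy as np

class IndexBatchedDataset:
    """Illustrative index-batching: snapshots are built at access time
    as views into one shared array, so no per-snapshot copy is stored."""

    def __init__(self, features, window, horizon):
        # features: (num_timesteps, num_nodes, num_channels)
        self.features = features
        self.window = window
        self.horizon = horizon

    def __len__(self):
        # Number of valid (input window, prediction horizon) start indices.
        return self.features.shape[0] - self.window - self.horizon + 1

    def __getitem__(self, idx):
        # Slicing yields views, not copies: memory stays O(dataset),
        # not O(dataset * window) as with precomputed snapshots.
        x = self.features[idx : idx + self.window]
        y = self.features[idx + self.window : idx + self.window + self.horizon]
        return x, y

# Toy usage: 100 timesteps, 5 nodes, 2 feature channels.
data = np.arange(100 * 5 * 2, dtype=np.float32).reshape(100, 5, 2)
ds = IndexBatchedDataset(data, window=12, horizon=3)
x, y = ds[0]
```

In a distributed setting, only the integer indices need to be partitioned across workers, which is what makes the distributed-index-batching extension cheap to communicate.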