AI Summary
To address the challenge of scaling reinforcement learning (RL) for large language model (LLM) inference in massive distributed environments, this paper proposes RLAX, a scalable RL framework designed for TPU clusters. RLAX employs a parameter-server architecture with preemptible training and dynamic fault recovery. It integrates three key innovations: efficient synthetic data construction, a multi-algorithm-compatible distributed RL training pipeline, and fine-grained model weight synchronization. Evaluated on 1,024 v5p TPUs, RLAX improves the pass@8 accuracy of QwQ-32B by 12.8% in just 12 hours and 48 minutes, achieving significantly accelerated convergence and high training robustness. RLAX delivers a system-level solution for efficient, stable, and large-scale RL-based alignment of LLMs.
Abstract
Reinforcement learning (RL) has emerged as the de-facto paradigm for improving the reasoning capabilities of large language models (LLMs). We have developed RLAX, a scalable RL framework on TPUs. RLAX employs a parameter-server architecture: a master trainer periodically pushes updated model weights to the parameter server, while a fleet of inference workers pulls the latest weights and generates new rollouts. We introduce a suite of system techniques to enable scalable and preemptible RL for a diverse set of state-of-the-art RL algorithms. To accelerate convergence and improve model quality, we have devised new dataset curation and alignment techniques. Large-scale evaluations show that RLAX improves QwQ-32B's pass@8 accuracy by 12.8% in just 12 hours and 48 minutes on 1,024 v5p TPUs, while remaining robust to preemptions during training.
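The push/pull weight exchange described above can be sketched as a versioned parameter server: the trainer publishes weight snapshots, and each inference worker pulls only when a newer version is available. This is a minimal illustrative sketch, not the RLAX API; class and method names are assumptions.

```python
import threading

class ParameterServer:
    """Illustrative versioned weight store (names are hypothetical)."""

    def __init__(self):
        self._lock = threading.Lock()
        self._version = 0      # monotonically increasing snapshot id
        self._weights = None   # latest published weight snapshot

    def push(self, weights):
        """Trainer publishes a new weight snapshot after an update step."""
        with self._lock:
            self._version += 1
            self._weights = weights

    def pull(self, known_version):
        """Worker fetches weights only if a newer version exists.

        Returns (latest_version, weights) when updated, or
        (known_version, None) when the worker is already current.
        """
        with self._lock:
            if self._version > known_version:
                return self._version, self._weights
            return known_version, None

# Trainer pushes a snapshot; a worker syncs before generating rollouts.
server = ParameterServer()
server.push({"layer0": [0.1, 0.2]})
version, weights = server.pull(known_version=0)   # worker gets version 1
_, stale = server.pull(known_version=version)     # already current: None
```

In a real deployment the server would be sharded and the pull would stream only changed parameter slices (the paper's fine-grained weight synchronization); the versioning idea is the same.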