🤖 AI Summary
This work addresses training instability, and in some cases outright collapse, caused by the mismatch between low-precision inference (e.g., FP8) and high-precision training (e.g., BF16), which biases the policy gradient. To mitigate this issue, the authors propose Adaptive Importance Sampling (AIS), the first method to incorporate a dynamic mixing coefficient into importance sampling. For each batch, AIS diagnoses weight reliability, divergence severity, and variance amplification and uses them to adjust the strength of its gradient correction, suppressing quantization-induced bias while preserving exploration. Implemented within GRPO with FP8 rollouts and a BF16 trainer, the approach pairs this diagnostic module with a mechanism that interpolates between the uncorrected and importance-weighted gradients. Experiments on LLaDA-8B-Instruct, Qwen3-8B, and Qwen3.5-9B show that AIS matches the BF16 baseline on most tasks while retaining a 1.5–2.76× rollout speedup.
📝 Abstract
The cost of reinforcement learning (RL) for large language models (LLMs) is dominated by rollout generation, which has motivated the use of low-precision rollouts (e.g., FP8) paired with a BF16 trainer to improve throughput and reduce memory pressure. This introduces a rollout-training mismatch that biases the policy gradient and can cause training to collapse outright on reasoning benchmarks. We show that the mismatch is non-stationary and acts as a double-edged sword: early in training it provides a stochastic exploration bonus, exposing the gradient to trajectories the trainer would otherwise under-sample, but the same perturbation transitions into a destabilizing source of bias as the policy concentrates.
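To make the bias concrete, the following minimal sketch replaces the LLM with a three-action softmax policy and uses a fixed logit perturbation as a hypothetical stand-in for FP8 rounding (real quantization error is neither fixed nor this simple). It enumerates the exact expected REINFORCE gradient, so there is no sampling noise: drawing actions from the rollout policy μ while differentiating the trainer policy π leaves a systematic bias, and reweighting each action by π(a)/μ(a) removes it. The helper names are illustrative, not taken from the paper.

```python
import math

def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_pg(sample_probs, target_probs, rewards, importance_weighted):
    """Exact expected REINFORCE gradient w.r.t. the trainer's logits.

    Actions are drawn from `sample_probs` (the rollout policy), while the
    score function grad log pi(a) = onehot(a) - pi uses `target_probs`
    (the trainer policy).
    """
    n = len(target_probs)
    grad = [0.0] * n
    for a in range(n):
        weight = sample_probs[a] * rewards[a]
        if importance_weighted:
            # Reweight by pi(a) / mu(a) to undo the sampling mismatch.
            weight *= target_probs[a] / sample_probs[a]
        for k in range(n):
            score = (1.0 if k == a else 0.0) - target_probs[k]
            grad[k] += weight * score
    return grad

trainer_logits = [2.0, 1.0, 0.5]
# Hypothetical stand-in for FP8 rounding error: a fixed logit perturbation.
rollout_logits = [x + d for x, d in zip(trainer_logits, [-0.05, 0.04, 0.03])]
rewards = [1.0, 0.0, 0.5]

pi = softmax(trainer_logits)
mu = softmax(rollout_logits)

true_grad = expected_pg(pi, pi, rewards, importance_weighted=False)
naive_grad = expected_pg(mu, pi, rewards, importance_weighted=False)
is_grad = expected_pg(mu, pi, rewards, importance_weighted=True)

bias_naive = max(abs(a - b) for a, b in zip(naive_grad, true_grad))
bias_is = max(abs(a - b) for a, b in zip(is_grad, true_grad))
print(f"max bias without correction: {bias_naive:.2e}")
print(f"max bias with importance weights: {bias_is:.2e}")
```

This toy captures only the bias itself; it does not model the exploration effect or the non-stationarity that the abstract describes.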
To solve this, we propose Adaptive Importance Sampling (AIS), a correction framework that adjusts the strength of its intervention on a per-batch basis. AIS combines three real-time diagnostics, namely weight reliability, divergence severity, and variance amplification, into a single mixing coefficient that interpolates between the uncorrected and fully importance-weighted gradients, suppressing the destabilizing component of the mismatch while preserving its exploratory benefit. We integrate AIS into GRPO and evaluate it on the diffusion-based LLaDA-8B-Instruct and the autoregressive Qwen3-8B and Qwen3.5-9B across mathematical reasoning and planning benchmarks. AIS matches the BF16 baseline on most tasks while retaining the 1.5 to 2.76x rollout speedup of FP8.
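The abstract names the three diagnostics but not their formulas, so the following is only a minimal sketch of the interpolation idea under stated assumptions. The specific choices are hypothetical stand-ins, not the authors' definitions: normalized effective sample size for weight reliability, the k3 estimator of KL(μ‖π) for divergence severity, the ratio Var(w·A)/Var(A) for variance amplification, a product rule that combines them into the mixing coefficient λ, and the `kl_scale` parameter. The design assumption is that the correction strengthens as the estimated divergence grows and backs off when the weights look unreliable or inflate variance. Because a policy-gradient estimate is linear in its per-sample weights, using the coefficient c_i = (1 − λ) + λ·w_i on each sample is equivalent to blending the uncorrected and fully importance-weighted gradients with weight λ.

```python
import math
from dataclasses import dataclass

@dataclass
class AISDiagnostics:
    reliability: float   # normalized effective sample size, in (0, 1]
    divergence: float    # k3 estimate of KL(rollout || trainer), >= 0
    variance_amp: float  # Var(w * A) / Var(A)
    mix: float           # mixing coefficient lambda, in [0, 1]

def _var(xs):
    # Population variance of a list of floats.
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

def ais_mixing(logp_train, logp_rollout, advantages, kl_scale=0.01):
    """Hypothetical per-batch mixing coefficient and per-sample gradient weights."""
    log_ratios = [lt - lr for lt, lr in zip(logp_train, logp_rollout)]
    weights = [math.exp(r) for r in log_ratios]  # w_i = pi / mu
    n = len(weights)

    # Weight reliability (assumed): Kish effective sample size divided by n.
    reliability = sum(weights) ** 2 / (n * sum(w * w for w in weights))

    # Divergence severity (assumed): k3 estimator, each term (w - 1) - log w >= 0.
    divergence = sum((w - 1.0) - r for w, r in zip(weights, log_ratios)) / n
    severity = 1.0 - math.exp(-divergence / kl_scale)

    # Variance amplification (assumed): how much weighting inflates the advantage signal.
    var_a = _var(advantages)
    amp = _var([w * a for w, a in zip(weights, advantages)]) / var_a if var_a > 0 else 1.0
    damping = min(1.0, 1.0 / amp) if amp > 0 else 1.0

    # Hypothetical combination rule, clipped to [0, 1].
    mix = min(1.0, max(0.0, severity * reliability * damping))
    # (1 - mix) * uncorrected + mix * importance-weighted, applied per sample.
    coeffs = [(1.0 - mix) + mix * w for w in weights]
    return coeffs, AISDiagnostics(reliability, divergence, amp, mix)

# Illustrative batch of sequence-level log-probabilities and advantages.
logp_train = [-1.20, -0.85, -2.10, -0.40]
logp_rollout = [-1.25, -0.80, -2.30, -0.42]
advantages = [1.0, -0.5, 0.8, -1.3]
coeffs, diag = ais_mixing(logp_train, logp_rollout, advantages)
print(diag)
print(coeffs)
```

In a GRPO-style loss, each coefficient would multiply that sample's advantage-weighted log-probability term as a constant, with no gradient flowing through it. When rollout and trainer log-probabilities agree, every weight is 1, the divergence estimate is 0, and λ = 0, so the sketch reduces to the uncorrected gradient.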