🤖 AI Summary
This work addresses the high computational and memory overhead in long-context reinforcement learning caused by redundantly recomputing shared prompt prefixes across multiple trajectories. The authors propose a scheduler-level prefix reuse mechanism that decouples prefix and suffix computation. The prefix runs a single forward pass and a single backward pass. Its key/value (K/V) states are cached for the suffix microbatches to read, and the corresponding gradients (gK/gV) are accumulated across those microbatches before the one prefix backward pass. The reordered schedule is equivalent to baseline training in exact arithmetic and matches it within finite-precision tolerance. It remains compatible with tensor, expert, context, pipeline, and data parallelism, and it preserves MoE auxiliary losses. Experiments on Llama3-8B, Qwen3-8B, and Qwen3-MoE-30B-A3B show up to a 4.395× speedup and up to a 59.1% reduction in peak Phase-B memory. On Llama3-8B, total token capacity increases from 17,920 to 29,696.
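During suffix microbatches, only the prefix K/V and gK/gV need to stay resident. The rough estimate below shows why this hot state can be much smaller than the full set of saved prefix activations. It is a sketch under stated assumptions:

- the published Llama3-8B attention shape (32 layers, 8 grouped-query K/V heads, head dimension 128);
- bf16 K/V;
- fp32 gK/gV accumulation, which is an assumption because the summary does not give the accumulator dtype;
- no tensor- or context-parallel sharding.

The helper names are hypothetical.

```python
# Hypothetical back-of-the-envelope estimate of the "hot" prefix state kept
# resident while suffix microbatches run. Assumes an unsharded single device.

# Published Llama3-8B attention shape (grouped-query attention).
LLAMA3_8B = dict(num_layers=32, num_kv_heads=8, head_dim=128)


def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem):
    # K and V each store num_kv_heads * head_dim values per layer per token.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def hot_prefix_state_bytes(prefix_len, num_layers, num_kv_heads, head_dim,
                           kv_bytes=2, grad_bytes=4):
    # kv_bytes=2 assumes bf16 K/V; grad_bytes=4 assumes an fp32 gK/gV accumulator.
    shape = (num_layers, num_kv_heads, head_dim)
    kv = prefix_len * kv_bytes_per_token(*shape, kv_bytes)
    gkv = prefix_len * kv_bytes_per_token(*shape, grad_bytes)
    return kv, gkv


if __name__ == "__main__":
    per_tok = kv_bytes_per_token(**LLAMA3_8B, bytes_per_elem=2)
    kv, gkv = hot_prefix_state_bytes(16384, **LLAMA3_8B)
    print(f"K/V per token: {per_tok / 1024:.0f} KiB")
    print(f"16,384-token prefix: K/V {kv / 2**30:.0f} GiB, gK/gV {gkv / 2**30:.0f} GiB")
```

Under these assumptions, each prefix token costs 128 KiB of bf16 K/V. A hypothetical 16,384-token prefix therefore needs 2 GiB for K/V and 4 GiB for an fp32 gK/gV cache. Tensor or context parallelism would shard these tensors across devices.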
📝 Abstract
GRPO- and PPO-style LLM post-training commonly sample multiple trajectories from the same prompt and then train on the resulting group. In long-context RL workloads, this shared prompt-side prefix can contain retrieved passages, visual tokens, tool schemas, system instructions, or task context, while the full rollout group is still too large to pack into one training microbatch. Standard dense trainers therefore recompute the same prefix forward and backward for every trajectory. We present a schedule-level reuse mechanism that decouples prefix and suffix computation. The schedule runs prefix forward once, executes suffixes as ordinary microbatches while reading prefix K/V and accumulating prefix-side gK/gV, and then runs prefix backward once on the accumulated gradient cache. This reordered schedule is equivalent to baseline training in exact (real-number) arithmetic and aligns numerically within finite-precision tolerance. Because only K/V and gK/gV are hot during suffix computation, the approach offloads dormant prefix activations, integrates with TP/EP/CP/PP and DP-style placement at the execution level, and preserves aux-loss-based MoE router semantics through logical prefix-token accounting. On dense Llama3-8B, Qwen3-8B, and MoE Qwen3-MoE-30B-A3B configurations, the schedule matches optimizer updates across TP/CP/PP/EP combinations, aligns over a 100-step replay of a real RL trace, reaches up to a 4.395x speedup (2.930x under a conservative compile-on comparison) as prefix ratio and rollout group size grow, and reduces Phase-B peak HBM by up to 59.1%, extending the Llama3-8B capacity frontier from 17,920 to 29,696 total tokens.
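The equivalence claim rests on linearity. The prefix backward pass maps incoming gK/gV to parameter gradients linearly, so running it once on the summed gK/gV gives the same result as running it once per trajectory and summing the results. The pure-Python sketch below checks this on a deliberately tiny, hypothetical model. It uses scalar weights `w_k`, `w_v`, `w_q` and an unnormalized linear-attention readout, not the paper's transformer, and all function names are illustrative. The two schedules agree only up to floating-point reordering, which mirrors the finite-precision tolerance mentioned above.

```python
import math
from dataclasses import dataclass

# Hypothetical toy model (illustration only, not the paper's architecture):
#   K_j = w_k * x_j, V_j = w_v * x_j       prefix K/V for prefix token j
#   q_i = w_q * s_i                        suffix query for trajectory i
#   o_i = sum_j (q_i * K_j) * V_j          unnormalized linear attention
#   L_i = 0.5 * (o_i - t_i) ** 2           per-trajectory loss


@dataclass
class Params:
    w_k: float  # prefix-side key weight
    w_v: float  # prefix-side value weight
    w_q: float  # suffix-side query weight


def prefix_forward(p, xs):
    """Run the shared prefix once; K/V is the only state kept hot."""
    return [p.w_k * x for x in xs], [p.w_v * x for x in xs]


def suffix_step(p, ks, vs, s, t, gk, gv):
    """One suffix microbatch: read cached K/V, add dL/dK and dL/dV into the
    shared gK/gV cache in place, and return the suffix-side gradient."""
    q = p.w_q * s
    kv = sum(k * v for k, v in zip(ks, vs))
    e = q * kv - t  # dL/do
    for j in range(len(ks)):
        gk[j] += e * q * vs[j]
        gv[j] += e * q * ks[j]
    return e * kv * s  # dL/dw_q


def prefix_backward(xs, gk, gv):
    """Single prefix backward pass driven by the accumulated gK/gV."""
    g_wk = sum(g * x for g, x in zip(gk, xs))
    g_wv = sum(g * x for g, x in zip(gv, xs))
    return g_wk, g_wv


def reuse_schedule(p, xs, suffixes):
    ks, vs = prefix_forward(p, xs)            # prefix forward once
    gk, gv = [0.0] * len(xs), [0.0] * len(xs)
    g_wq = 0.0
    for s, t in suffixes:                     # suffixes as ordinary microbatches
        g_wq += suffix_step(p, ks, vs, s, t, gk, gv)
    g_wk, g_wv = prefix_backward(xs, gk, gv)  # prefix backward once
    return g_wk, g_wv, g_wq


def baseline(p, xs, suffixes):
    """Dense baseline: every trajectory recomputes prefix forward and backward."""
    g_wk = g_wv = g_wq = 0.0
    for s, t in suffixes:
        ks, vs = prefix_forward(p, xs)
        q = p.w_q * s
        kv = sum(k * v for k, v in zip(ks, vs))
        e = q * kv - t
        g_wk += sum(e * q * v * x for v, x in zip(vs, xs))
        g_wv += sum(e * q * k * x for k, x in zip(ks, xs))
        g_wq += e * kv * s
    return g_wk, g_wv, g_wq


def schedules_match(p, xs, suffixes, tol=1e-12):
    a, b = reuse_schedule(p, xs, suffixes), baseline(p, xs, suffixes)
    return all(math.isclose(x, y, rel_tol=tol, abs_tol=tol) for x, y in zip(a, b))


if __name__ == "__main__":
    params = Params(w_k=0.5, w_v=-0.3, w_q=1.2)
    prefix = [0.1, 0.7, -0.4]
    group = [(1.0, 0.2), (-0.5, 0.0), (2.0, -1.0)]  # (suffix input, target) per rollout
    print(reuse_schedule(params, prefix, group))
    print(schedules_match(params, prefix, group))
```

In this toy, the `gk`/`gv` lists stand in for the per-layer gradient cache. The baseline pays one prefix forward and backward per rollout, while the reuse schedule pays one of each per group.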