🤖 AI Summary
Existing on-policy distillation (OPD) methods rely on token-level supervision signals, which struggle to distinguish genuine reasoning divergences from superficial surface-form differences; as a result, the student's reasoning paths are not adequately aligned with the teacher's. This work proposes Trajectory-aware On-Policy Distillation (TOPD), which uses near-future trajectory information to identify critical divergence points and extends supervision from individual tokens to multiple future tokens, thereby achieving trajectory-level alignment. Rather than relying only on local reverse-KL correction at isolated tokens, TOPD targets the short-horizon distributional drift through which reasoning failures unfold. It substantially improves on standard OPD, raising average accuracy from 47.8% to 52.2%; notably, it reaches 63.3% on AIME24 (+3.3 points) and 53.3% on AIME25 (+6.6 points).
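The exact divergence criterion and weighting scheme TOPD uses are not specified here, so the following is only a minimal, hypothetical sketch of the general idea of using near-future information to separate real forks from surface-form mismatches. Everything in it is an assumption for illustration, not the authors' method: per-token reverse-KL values are taken as already computed along one student-sampled trajectory; a high-loss token counts as a divergence point only if the mean KL over the next `horizon` tokens also exceeds `window_threshold`; and the correction weight is spread over the following tokens with a geometric `decay`. The function names and all thresholds are hypothetical.

```python
from typing import List


def future_window_divergence(token_kl: List[float], horizon: int) -> List[float]:
    """For each position t, the mean KL over tokens t+1 .. t+horizon.

    Hypothetical proxy for "does the trajectory keep drifting after t?".
    Positions with no future tokens get a score of 0.0.
    """
    scores = []
    for t in range(len(token_kl)):
        window = token_kl[t + 1 : t + 1 + horizon]
        scores.append(sum(window) / len(window) if window else 0.0)
    return scores


def find_divergence_points(
    token_kl: List[float],
    horizon: int = 3,
    loss_threshold: float = 1.0,
    window_threshold: float = 0.5,
) -> List[int]:
    """Keep only high-loss tokens whose near future also diverges (hypothetical rule)."""
    scores = future_window_divergence(token_kl, horizon)
    return [
        t
        for t, (kl, score) in enumerate(zip(token_kl, scores))
        if kl >= loss_threshold and score >= window_threshold
    ]


def spread_weights(
    length: int, points: List[int], horizon: int = 3, decay: float = 0.5
) -> List[float]:
    """Distribute guidance from each divergence point over the next `horizon` tokens.

    Weight decays geometrically: 1, decay, decay**2, ... (an illustrative choice).
    """
    weights = [0.0] * length
    for p in points:
        for k in range(horizon):
            if p + k < length:
                weights[p + k] += decay ** k
    return weights


if __name__ == "__main__":
    # Toy per-token KL values: token 1 is an isolated spike (calm future),
    # token 5 is followed by sustained drift.
    kl = [0.1, 2.0, 0.1, 0.1, 0.1, 1.5, 0.9, 0.8, 0.1]
    points = find_divergence_points(kl)
    print(points)  # [5]
    print(spread_weights(len(kl), points))  # [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.5, 0.25, 0.0]
```

In this toy trajectory both tokens 1 and 5 would count as "high-loss" under a purely token-level rule, but only token 5 is flagged once the near future is taken into account, and its guidance is shared with the two tokens that follow it.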
📝 Abstract
On-Policy Distillation (OPD) improves large language model reasoning by training a student model on trajectories sampled from its own policy under teacher supervision. Although OPD operates on trajectories, its learning signal remains token-level: it identifies deviations through high-loss tokens and repairs them through local reverse-KL correction. We show that this "trajectory-sampled but token-learned" mechanism cannot reliably bridge student trajectories toward teacher trajectories. About 30% of high-loss tokens fall into the low-divergence regime, indicating that many are surface-form mismatches rather than real reasoning forks. Moreover, even truly divergent tokens are difficult to repair with isolated token-level supervision, since reasoning failures often unfold as short-horizon distributional drift. We propose Trajectory-aware OPD (TOPD), which uses near-future trajectory information to identify real divergent states and distribute guidance across multiple future tokens. Experiments show that suppressing non-divergent high-loss tokens improves standard OPD from 47.8% to 48.2% average accuracy, while TOPD further improves performance to 52.2%, raising accuracy on AIME24 from 60.0% to 63.3% and on AIME25 from 46.7% to 53.3%.
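For reference, the token-level signal attributed to standard OPD can be written as a per-position reverse KL between the student and teacher next-token distributions, KL(π_student ‖ π_teacher), evaluated on a trajectory the student sampled itself; "high-loss tokens" are then the positions where this value is large. The sketch below is a minimal pure-Python illustration under that assumption. The paper's exact estimator (for example, full-distribution KL versus a single-sample estimate on the sampled token) is not stated here, and all function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one vocabulary-sized logit vector."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def reverse_kl(student_logits: Sequence[float], teacher_logits: Sequence[float]) -> float:
    """KL(p_s || p_t) = sum_v p_s(v) * (log p_s(v) - log p_t(v)) at one position."""
    log_p = log_softmax(student_logits)
    log_q = log_softmax(teacher_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def per_token_reverse_kl(
    student_seq: Sequence[Sequence[float]], teacher_seq: Sequence[Sequence[float]]
) -> List[float]:
    """Per-position reverse KL along one student-sampled trajectory.

    Assumes both logit sequences come from scoring the SAME student-generated
    prefix with each model, so position t conditions on identical context.
    """
    return [reverse_kl(s, t) for s, t in zip(student_seq, teacher_seq)]


def opd_loss(
    student_seq: Sequence[Sequence[float]], teacher_seq: Sequence[Sequence[float]]
) -> float:
    """Hypothetical token-level OPD objective: mean per-token reverse KL."""
    kls = per_token_reverse_kl(student_seq, teacher_seq)
    return sum(kls) / len(kls)


if __name__ == "__main__":
    # Two positions over a 2-token vocabulary.
    # Position 0: student is uniform (0.5, 0.5); teacher prefers token 0 (0.75, 0.25).
    # Position 1: identical logits, so the KL is exactly zero.
    student = [[0.0, 0.0], [2.0, 0.0]]
    teacher = [[math.log(3.0), 0.0], [2.0, 0.0]]
    print(per_token_reverse_kl(student, teacher))  # first value equals 0.5 * ln(4/3)
    print(opd_loss(student, teacher))
```

Because each term only looks at a single position, this objective has no notion of whether a large value at position t is followed by further drift, which is the limitation the abstract describes.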