AI Summary
This work addresses a limitation of existing on-policy distillation (OPD) methods: applying a stop-gradient to the divergence reward introduces bias into advantage estimation, which can impede knowledge transfer from teacher to student models. Using f-divergences between student and teacher, the authors develop a unified optimization framework that explains where this bias comes from for general divergence functions. They propose OPD+, a corrected advantage estimation method that keeps the reinforcement-learning-style objective and supports a variety of f-divergence choices. On mathematical reasoning and tool-use benchmarks, OPD+ is reported to improve student performance over the baseline KL-based distillation.
Abstract
On-policy distillation (OPD) is a widely used technique for transferring capabilities from capable teacher language models to base student models, and it can be formulated as a reinforcement-learning-style objective over student-generated rollouts. Yet although the divergence reward depends on the student model's likelihood, existing work usually applies a stop-gradient to it, primarily for stability, which makes the resulting advantage estimate questionable. In this work, we provide a generic optimization framework based on the f-divergence between the student and the teacher, and mathematically revisit whether this design choice is valid. We prove that, for general divergence functions, the stop-gradient operation leads to biased estimates of the reward objective and its gradient. We propose OPD+, a corrected version of OPD that demonstrates improved performance over the baseline KL approach and also supports a range of f-divergences. We validate our findings on mathematical reasoning and tool-use benchmarks.
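To make the stop-gradient concern concrete, the following toy sketch compares the exact gradient of a divergence with the gradient obtained when the per-sample divergence term is frozen. It is an illustration under stated assumptions, not the paper's derivation and not the OPD+ estimator. The assumptions are: a three-outcome categorical student parameterized by softmax logits, a fixed teacher distribution, the divergence written as a student-side expectation E_{x~student}[g(student(x)/teacher(x))] with g(u) = f(u)/u, and gradients taken by central finite differences. All function and variable names are hypothetical.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Per-sample terms g(u) = f(u) / u, where u = student(x) / teacher(x).
def reverse_kl_term(u):
    # f(u) = u * log(u) gives KL(student || teacher)
    return math.log(u)

def forward_kl_term(u):
    # f(u) = -log(u) gives KL(teacher || student)
    return -math.log(u) / u

def divergence(logits, teacher, term):
    # Exact expectation under the student: sum_x p(x) * g(p(x) / q(x)).
    student = softmax(logits)
    return sum(p * term(p / q) for p, q in zip(student, teacher))

def surrogate(logits, frozen_terms):
    # Same expectation, but the per-sample terms are constants
    # (this plays the role of a stop-gradient on the reward).
    student = softmax(logits)
    return sum(p * g for p, g in zip(student, frozen_terms))

def finite_diff_grad(fn, logits, eps=1e-6):
    # Central finite differences with respect to each logit.
    grad = []
    for k in range(len(logits)):
        up = list(logits)
        down = list(logits)
        up[k] += eps
        down[k] -= eps
        grad.append((fn(up) - fn(down)) / (2 * eps))
    return grad

def gradient_gap(logits, teacher, term):
    # Largest absolute difference between the exact gradient and
    # the gradient of the frozen-term surrogate.
    student = softmax(logits)
    frozen = [term(p / q) for p, q in zip(student, teacher)]
    exact = finite_diff_grad(lambda z: divergence(z, teacher, term), logits)
    stopped = finite_diff_grad(lambda z: surrogate(z, frozen), logits)
    return max(abs(a - b) for a, b in zip(exact, stopped))

LOGITS = [0.5, -0.2, 0.1]   # hypothetical student logits
TEACHER = [0.6, 0.3, 0.1]   # hypothetical teacher distribution

if __name__ == "__main__":
    print("reverse KL gap:", gradient_gap(LOGITS, TEACHER, reverse_kl_term))
    print("forward KL gap:", gradient_gap(LOGITS, TEACHER, forward_kl_term))
```

In this toy setting the reverse-KL gap is zero up to numerical error. The term that freezing drops is the expected score function, which vanishes because the student probabilities sum to one. For the forward-KL term the dropped contribution does not vanish, so the two gradients differ, in line with the abstract's claim about general divergence functions.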