AI Summary
This work uncovers the mechanisms and conditions under which neural networks undergo feature forgetting during prolonged training. Focusing on infinitely wide two-layer networks trained with large-batch stochastic gradient descent, we formulate a multiscale differential equation framework that, for the first time, integrates fast-slow dynamics with critical manifold theory. We demonstrate that feature forgetting is driven primarily by data nonlinearity and modulated by the initial scale of the second-layer weights. Leveraging tensor programs and singular perturbation theory, we derive a scaling law for forgetting and precisely characterize its onset: stronger leading nonlinear components intensify forgetting, whereas increasing the initial magnitude of the second-layer weights effectively mitigates it. Numerical experiments corroborate our theoretical predictions, offering a novel perspective on dynamic forgetting in deep learning.
Abstract
The dynamics of gradient-based training in neural networks often exhibit nontrivial structure; understanding them remains a central challenge in theoretical machine learning. In particular, the concept of feature unlearning, in which a neural network progressively loses previously learned features over long training, has gained attention. In this study, we consider the infinite-width limit of a two-layer neural network updated with large-batch stochastic gradients and derive differential equations with different time scales, revealing the mechanism and conditions under which feature unlearning occurs. Specifically, we exploit the fast-slow dynamics: the alignment of the first-layer weights develops rapidly, while the second-layer weights evolve slowly. The direction of the flow on a critical manifold, determined by the slow dynamics, decides whether feature unlearning occurs. We provide numerical validation of this result and derive theoretical grounding and scaling laws for feature unlearning. Our results yield the following insights: (i) the strength of the primary nonlinear term in the data induces feature unlearning, and (ii) a larger initial scale of the second-layer weights mitigates feature unlearning. Technically, our analysis builds on Tensor Programs and singular perturbation theory.
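As a minimal illustration of the setting described above, the following toy sketch (an assumption for illustration, not the paper's exact model or parameterization) trains a two-layer network with large-batch SGD on a single-index target containing a nonlinear (cubic) component, and tracks the fast variable (alignment of first-layer weights with the target direction) alongside the slow variable (the second-layer weight norm):

```python
# Illustrative toy, not the paper's exact setting: a two-layer network
# f(x) = a^T tanh(W x) / sqrt(m) trained with large-batch SGD on a
# single-index target y = g(u^T x) whose g has a cubic component.
# Tracked quantities: (i) mean alignment of first-layer rows with the
# target direction u (the "fast" variable), and (ii) the norm of the
# second-layer weights a (the "slow" variable).
import numpy as np

rng = np.random.default_rng(0)
d, m = 20, 200              # input dimension, hidden width
batch, steps, lr = 4096, 300, 0.5

u = np.zeros(d); u[0] = 1.0                # hidden target direction
g = lambda z: z + 0.3 * (z**3 - 3 * z)     # linear term plus a cubic (He3) term

W = rng.normal(size=(m, d)) / np.sqrt(d)   # first-layer weights
a = rng.normal(size=m) * 0.1               # second layer, small initial scale

align_hist, anorm_hist = [], []
for _ in range(steps):
    X = rng.normal(size=(batch, d))        # large batch ~ near-population gradient
    y = g(X @ u)
    H = np.tanh(X @ W.T)                   # (batch, m) hidden activations
    resid = H @ a / np.sqrt(m) - y         # prediction error
    # gradients of the loss 0.5 * mean(resid**2)
    ga = H.T @ resid / (batch * np.sqrt(m))
    gH = np.outer(resid, a) / np.sqrt(m) * (1 - H**2)  # chain rule through tanh
    gW = gH.T @ X / batch
    a -= lr * ga
    W -= lr * gW
    # per-neuron cosine alignment with u, averaged over neurons
    align = np.abs(W @ u) / np.linalg.norm(W, axis=1)
    align_hist.append(align.mean())
    anorm_hist.append(np.linalg.norm(a))

print(f"alignment: {align_hist[0]:.3f} -> {align_hist[-1]:.3f}")
print(f"|a|:       {anorm_hist[0]:.3f} -> {anorm_hist[-1]:.3f}")
```

In this sketch, the cubic strength (`0.3`) and the second-layer initialization scale (`0.1`) are the two knobs the abstract identifies: the former controls the nonlinear component of the data, the latter the initial magnitude of the slow variable.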