🤖 AI Summary
This work identifies a fundamental limitation of Sharpness-Aware Minimization (SAM): in the infinite-width limit, SAM's perturbations effectively act only on the final layer, so its benefits do not scale with model width. To address this, we propose μP² (Maximal Update and Perturbation Parameterization), the first parameterization that achieves layerwise balanced scaling of both update magnitudes and perturbation strengths, ensuring all layers actively contribute to robust optimization in wide networks. Grounded in Tensor Programs theory, μP² also extends, via an intuitive condition, to other perturbation rules such as Adaptive SAM and SAM-ON. Experiments across MLPs, ResNets, and Vision Transformers demonstrate that μP² significantly improves generalization and enables stable transfer of jointly tuned hyperparameters—specifically learning rate and perturbation radius—across model scales. Thus, μP² provides both a theoretically grounded, scalable framework and a practical recipe for robust training of large-scale neural networks.
📝 Abstract
Sharpness-Aware Minimization (SAM) enhances performance across various neural architectures and datasets. As models are continually scaled up to improve performance, a rigorous understanding of SAM's scaling behaviour is paramount. To this end, we study the infinite-width limit of neural networks trained with SAM, using the Tensor Programs framework. Our findings reveal that the dynamics of standard SAM effectively reduce to applying SAM solely in the last layer in wide neural networks, even with optimal hyperparameters. In contrast, we identify a stable parameterization with layerwise perturbation scaling, which we call $\textit{Maximal Update and Perturbation Parameterization}$ ($\mu$P$^2$), that ensures all layers are both feature learning and effectively perturbed in the limit. Through experiments with MLPs, ResNets and Vision Transformers, we empirically demonstrate that $\mu$P$^2$ achieves hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales. Moreover, we provide an intuitive condition to derive $\mu$P$^2$ for other perturbation rules like Adaptive SAM and SAM-ON, also ensuring balanced perturbation effects across all layers.
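To make the perturbation rule discussed above concrete, the following is a minimal NumPy sketch of one SAM step with layerwise perturbation scaling on a toy quadratic loss. The per-layer radii `rhos` are hypothetical illustrative values; the actual width-dependent scaling exponents that define $\mu$P$^2$ come from the paper's Tensor Programs analysis and are not reproduced here.

```python
import numpy as np

# Toy loss L(w) = 0.5 * ||A w - b||^2 with its gradient.
def loss_grad(w, A, b):
    r = A @ w - b
    return 0.5 * r @ r, A.T @ r

def sam_step(w, A, b, lr, rhos, layers):
    """One SAM update with layerwise perturbation scaling.

    `layers` is a list of slices partitioning w into "layers";
    `rhos` gives each layer its own perturbation radius (in standard
    SAM a single global radius is shared by all layers instead).
    """
    _, g = loss_grad(w, A, b)
    eps = np.zeros_like(w)
    for sl, rho in zip(layers, rhos):
        gl = g[sl]
        # Ascent direction, normalized and scaled per layer.
        eps[sl] = rho * gl / (np.linalg.norm(gl) + 1e-12)
    # Gradient evaluated at the perturbed ("sharpness-probed") weights.
    _, g_pert = loss_grad(w + eps, A, b)
    return w - lr * g_pert

rng = np.random.default_rng(0)
A = rng.normal(size=(8, 6))
b = rng.normal(size=8)
w = rng.normal(size=6)
layers = [slice(0, 3), slice(3, 6)]  # two illustrative "layers"
loss0, _ = loss_grad(w, A, b)
for _ in range(200):
    w = sam_step(w, A, b, lr=0.02, rhos=[0.01, 0.01], layers=layers)
```

The sketch only illustrates the mechanics: perturb each layer along its normalized gradient, then descend using the gradient at the perturbed point. The paper's contribution is choosing how `lr` and each `rho` must scale with width so that every layer remains effectively perturbed in the infinite-width limit.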