🤖 AI Summary
This work addresses the computational and memory bottlenecks of approximating the Hessian matrix in modern deep networks. The authors construct a structured Hessian approximation from a single gradient by analytically averaging over the group actions on weight space that leave the loss invariant. The resulting approximation can be estimated, stored, and inverted efficiently. Because it explicitly incorporates weight symmetries into curvature modeling, the framework also offers a unified perspective on existing methods: a specific choice of symmetry group recovers Shampoo/Muon-like curvature estimates. Experiments across a range of architectures, including a small language model, show that the framework improves second-order optimization. The choice of symmetry group controls the trade-off between accuracy and computation, and the authors suggest possible extensions to downstream tasks such as uncertainty quantification and continual learning.
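To make the averaging idea concrete, here is a minimal NumPy sketch. It rests on assumptions that are not taken from the paper. The parameter is a single m×n matrix W. The symmetry group is O(m), acting by W → QW; in a real network this action would be paired with a compensating action on a neighboring layer. The averaged quantity is the gradient outer product vec(G)vec(G)ᵀ, used as a single-gradient curvature proxy. For this toy case the Haar average has the closed form (1/m) I_m ⊗ GᵀG, and the sketch checks that closed form against a Monte Carlo estimate. The Kronecker structure also means the damped inverse needs only an n×n solve. All function names are hypothetical, and this illustrates the general principle rather than the authors' estimator.

```python
import numpy as np

def haar_orthogonal(m, rng):
    """Sample Q from the Haar (uniform) measure on the orthogonal group O(m)."""
    z = rng.standard_normal((m, m))
    q, r = np.linalg.qr(z)
    # Fix column signs so the distribution is exactly Haar, not biased by QR.
    return q * np.sign(np.diag(r))

def averaged_outer_product_mc(g, num_samples, rng):
    """Monte Carlo estimate of E_Q[vec(QG) vec(QG)^T] with Q ~ Haar(O(m))."""
    m, n = g.shape
    acc = np.zeros((m * n, m * n))
    for _ in range(num_samples):
        v = (haar_orthogonal(m, rng) @ g).reshape(-1)  # row-major vec
        acc += np.outer(v, v)
    return acc / num_samples

def averaged_outer_product_exact(g):
    """Closed form of the same group average: (1/m) * kron(I_m, G^T G)."""
    m, _ = g.shape
    return np.kron(np.eye(m), g.T @ g) / m

def precondition(g, damping=1e-3):
    """Apply (kron(I_m, G^T G / m) + damping * I)^{-1} to vec(G).

    The damped matrix equals kron(I_m, C) with C = G^T G / m + damping * I_n,
    so the result is vec(G C^{-1}) and only an n x n system is solved.
    """
    m, n = g.shape
    c = g.T @ g / m + damping * np.eye(n)
    return np.linalg.solve(c, g.T).T  # C is symmetric, so (C^{-1} G^T)^T = G C^{-1}

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 3))  # stand-in for a single weight gradient
mc = averaged_outer_product_mc(G, 20_000, rng)
exact = averaged_outer_product_exact(G)
rel_err = np.linalg.norm(mc - exact) / np.linalg.norm(exact)
print(f"relative error, Monte Carlo vs closed form: {rel_err:.3f}")
```

The closed form follows from the Haar second moment E[Q_ia Q_kb] = δ_ik δ_ab / m. The one-sided factor GᵀG it produces resembles a single Shampoo-style Kronecker factor.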
📝 Abstract
Many machine learning techniques rely on approximating a loss function's curvature, but this is notoriously hard to do at the scale of modern deep networks. Surprisingly, no previous work has exploited the curvature constraints that arise from well-known weight-space symmetries in loss landscapes. By analytically averaging over group actions that leave the loss invariant, we construct structured Hessian approximations from single gradients that can be tractably estimated, stored, and inverted. The user's choice of symmetry group directly governs the trade-off between approximation accuracy and computational cost. Moreover, our framework provides a unifying theoretical lens for viewing existing methods; in particular, a specific choice of symmetry group recovers Shampoo/Muon-like curvature estimates. We validate our method on a range of network architectures and deploy it in second-order optimization benchmarks, including one with a small language model. Our curvature estimation framework might find applications in other machine learning problems such as uncertainty estimation, continual learning, compression/pruning, training data attribution, and more.
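The abstract says that one symmetry group recovers Shampoo/Muon-like curvature estimates. The sketch below does not reproduce that derivation. It only checks a well-known link between the two existing methods. With no gradient accumulation and no damping, Shampoo's direction L^(-1/4) G R^(-1/4), where L = GGᵀ and R = GᵀG, equals Muon's orthogonalized gradient UVᵀ for a full-rank square G = USVᵀ. The function names are hypothetical. The sketch also omits Shampoo's accumulation, damping, and grafting, and it replaces Muon's Newton-Schulz iteration with an exact SVD.

```python
import numpy as np

def inverse_root(a, p):
    """Return A^(-1/p) for a symmetric positive-definite matrix A via eigh."""
    w, v = np.linalg.eigh(a)
    return (v * w ** (-1.0 / p)) @ v.T  # V diag(w^(-1/p)) V^T

def shampoo_direction(g):
    """One-step Shampoo-style direction L^(-1/4) G R^(-1/4), no accumulation."""
    left = g @ g.T   # L = G G^T
    right = g.T @ g  # R = G^T G
    return inverse_root(left, 4) @ g @ inverse_root(right, 4)

def muon_direction(g):
    """Muon-style orthogonalized direction U V^T, computed exactly via SVD
    (Muon itself approximates this with Newton-Schulz iterations)."""
    u, _, vt = np.linalg.svd(g, full_matrices=False)
    return u @ vt

rng = np.random.default_rng(0)
G = rng.standard_normal((4, 4))  # square, full-rank stand-in gradient
print(np.allclose(shampoo_direction(G), muon_direction(G)))
```

Algebraically, L^(-1/4) = U S^(-1/2) Uᵀ and R^(-1/4) = V S^(-1/2) Vᵀ. The singular values therefore cancel and only UVᵀ remains.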