RSAM: Learning on manifolds with Riemannian Sharpness-aware Minimization

📅 2023-09-29
🏛️ arXiv.org
📈 Citations: 3
Influential: 0
🤖 AI Summary
This paper addresses the limited robustness and generalization of models on Riemannian manifolds. It presents the first extension of Sharpness-Aware Minimization (SAM) to non-Euclidean geometric spaces. The key contributions are: (1) a curvature-aware sharpness definition on manifolds, yielding a tighter generalization error bound; (2) the Riemannian SAM (RSAM) algorithm, which rigorously adapts SAM to the Riemannian optimization framework by integrating manifold geometry with sharpness-aware gradient updates; and (3) empirical validation showing substantial improvements in generalization across image classification and contrastive learning tasks on CIFAR-10/100 and FGVC Aircraft. RSAM consistently outperforms baseline methods under both standard and adversarial training settings, demonstrating enhanced model stability and transferability. The implementation is publicly available.
📝 Abstract
Nowadays, understanding the geometry of the loss landscape shows promise in enhancing a model's generalization ability. In this work, we draw upon prior works that apply geometric principles to optimization and present a novel approach to improve robustness and generalization ability for constrained optimization problems. Indeed, this paper aims to generalize the Sharpness-Aware Minimization (SAM) optimizer to Riemannian manifolds. In doing so, we first extend the concept of sharpness and introduce a novel notion of sharpness on manifolds. To support this notion of sharpness, we present a theoretical analysis characterizing generalization capabilities with respect to manifold sharpness, which demonstrates a tighter bound on the generalization gap, a result not known before. Motivated by this analysis, we introduce our algorithm, Riemannian Sharpness-Aware Minimization (RSAM). To demonstrate RSAM's ability to enhance generalization ability, we evaluate and contrast our algorithm on a broad set of problems, such as image classification and contrastive learning across different datasets, including CIFAR100, CIFAR10, and FGVCAircraft. Our code is publicly available at https://t.ly/RiemannianSAM.
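To make the idea of a sharpness-aware step on a manifold concrete, here is a minimal sketch in Python. It is not the authors' released implementation and does not reproduce the paper's exact update rule. It assumes the simplest constrained setting, the unit sphere, with tangent-space projection as the Riemannian gradient, renormalization as the retraction, and projection back to the original tangent space as the vector transport. The names tangent_project, retract, sam_step_on_sphere, lr, and rho are all illustrative choices, not names from the paper.

```python
import numpy as np

def tangent_project(w, g):
    """Project a Euclidean vector g onto the tangent space of the unit sphere at w."""
    return g - np.dot(w, g) * w

def retract(w, v):
    """Step from w along tangent vector v, then renormalize back onto the sphere."""
    y = w + v
    return y / np.linalg.norm(y)

def sam_step_on_sphere(w, grad_fn, lr=0.1, rho=0.05):
    """One sharpness-aware step on the unit sphere (illustrative sketch only)."""
    # 1) Riemannian gradient at the current point.
    g = tangent_project(w, grad_fn(w))
    g_norm = np.linalg.norm(g)
    if g_norm == 0.0:
        return w  # stationary point: no ascent direction to perturb along
    # 2) Ascent step of length rho along the normalized gradient,
    #    kept on the manifold by the retraction.
    w_adv = retract(w, rho * g / g_norm)
    # 3) Gradient at the perturbed point, brought back to the tangent
    #    space at w by projection (an assumed, simple vector transport).
    g_adv = tangent_project(w, grad_fn(w_adv))
    # 4) Descent step from the original point using the perturbed gradient.
    return retract(w, -lr * g_adv)

# Toy problem (hypothetical): minimize w^T A w subject to ||w|| = 1.
# The minimum value is the smallest eigenvalue of A, here 1.0.
A = np.diag([1.0, 2.0, 3.0])
grad_fn = lambda w: 2.0 * A @ w

w_final = np.ones(3) / np.sqrt(3.0)
for _ in range(500):
    w_final = sam_step_on_sphere(w_final, grad_fn)

loss = float(w_final @ A @ w_final)
print("norm:", np.linalg.norm(w_final), "loss:", loss)
```

Because the perturbation radius rho is fixed, the iterates in this toy problem hover in a small neighborhood of the minimizer rather than converging to it exactly. Shrinking rho or lr tightens that neighborhood.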
Problem

Research questions and friction points this paper is trying to address.

Enhancing model generalization via sharpness-aware teleportation
Leveraging Riemannian manifolds for loss landscape geometry
Reducing generalization gap between population and empirical loss
Innovation

Methods, ideas, or system contributions that make the work stand out.

Sharpness-aware teleportation on Riemannian manifolds
Decomposes iteration into teleportation and sharpness-aware steps
Leverages Riemannian quotient manifold for generalization