🤖 AI Summary
In multi-task learning (MTL), conflicting task gradients—both in direction and magnitude—hinder model performance and generalization. To address this, we propose SAMO, a lightweight optimization method built upon the Sharpness-Aware Minimization (SAM) framework. SAMO innovatively integrates global and local perturbations: it enforces layer-wise normalization to constrain local perturbations and approximates the gradient correction via a single forward pass, drastically reducing computational and memory overhead. This design simultaneously mitigates loss surface sharpness and harmonizes conflicting task gradients, thereby enhancing generalization. Evaluated on multiple standard MTL benchmarks, SAMO achieves superior performance over state-of-the-art methods at significantly lower computational cost, striking a better trade-off between efficiency and effectiveness.
📝 Abstract
Multi-task learning (MTL) enables a joint model to capture commonalities across multiple tasks, reducing computation costs and improving data efficiency. However, a major challenge in MTL optimization is task conflicts, where the task gradients differ in direction or magnitude, limiting model performance compared to single-task counterparts. Sharpness-aware minimization (SAM) minimizes task loss while simultaneously reducing the sharpness of the loss landscape. Our empirical observations show that SAM effectively mitigates task conflicts in MTL. Motivated by these findings, we explore integrating SAM into MTL but face two key challenges. While both the average loss gradient and individual task gradients (referred to as global and local information, respectively) contribute to SAM, how to combine them remains unclear. Moreover, directly computing each task gradient introduces significant computational and memory overheads. To address these challenges, we propose SAMO, a lightweight **S**harpness-**A**ware **M**ulti-task **O**ptimization approach that leverages a joint global-local perturbation. The local perturbations are approximated using only forward passes and are layer-wise normalized to improve efficiency. Extensive experiments on a suite of multi-task benchmarks demonstrate both the effectiveness and efficiency of our method. Code is available at https://github.com/OptMN-Lab/SAMO.
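To make the joint global-local perturbation concrete, below is a toy NumPy sketch of a SAM-style step in the spirit described above: the global perturbation follows the normalized average-loss gradient, while per-task local directions are estimated without any extra backward pass and normalized before mixing. All specifics here are illustrative assumptions, not the paper's algorithm: the mixing weight `alpha`, the per-coordinate forward-difference estimate (the paper uses a single forward pass), and whole-vector normalization standing in for layer-wise normalization (the toy parameters form a single block).

```python
import numpy as np

# Two conflicting toy "tasks": quadratic losses with different optima.
def avg_loss(w, targets):
    return float(np.mean([np.sum((w - t) ** 2) for t in targets]))

def numerical_grad(f, w, h=1e-5):
    # Central-difference gradient, standing in for one backward pass.
    g = np.zeros_like(w)
    for i in range(w.size):
        e = np.zeros_like(w); e[i] = h
        g[i] = (f(w + e) - f(w - e)) / (2 * h)
    return g

def forward_only_task_dirs(w, targets, h=1e-3):
    # Per-task ascent directions estimated with forward evaluations only
    # (no per-task backward pass), then normalized. This crude
    # finite-difference estimate is an assumption for illustration.
    dirs = []
    for t in targets:
        loss = lambda v: np.sum((v - t) ** 2)
        base = loss(w)
        g = np.zeros_like(w)
        for i in range(w.size):
            e = np.zeros_like(w); e[i] = h
            g[i] = (loss(w + e) - base) / h
        dirs.append(g / (np.linalg.norm(g) + 1e-12))
    return dirs

def samo_like_step(w, targets, lr=0.2, rho=0.05, alpha=0.5):
    # Global perturbation: normalized average-loss gradient (standard SAM).
    g_avg = numerical_grad(lambda v: avg_loss(v, targets), w)
    global_pert = rho * g_avg / (np.linalg.norm(g_avg) + 1e-12)
    # Local perturbation: mean of normalized per-task directions.
    local_pert = rho * np.mean(forward_only_task_dirs(w, targets), axis=0)
    # Joint perturbation, then descend on the gradient at the perturbed point.
    eps = (1 - alpha) * global_pert + alpha * local_pert
    g_sharp = numerical_grad(lambda v: avg_loss(v, targets), w + eps)
    return w - lr * g_sharp

targets = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
w = np.zeros(2)
for _ in range(50):
    w = samo_like_step(w, targets)
```

On this toy problem the iterates settle near `[0.5, 0.5]`, the minimizer of the average loss, showing the mechanics of perturbing with mixed global and local directions before taking the descent step.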