PiKE: Adaptive Data Mixing for Multi-Task Learning Under Low Gradient Conflicts

📅 2025-02-10
📈 Citations: 0
Influential: 0
📄 PDF
🤖 AI Summary
In multi-task learning, static or heuristic data mixing strategies are inefficient when gradient conflict between tasks is weak, as in multilingual or multi-domain large language model (LLM) pretraining. To address this, the paper proposes PiKE (Positive gradient interaction-based K-task weights Estimator), an adaptive task sampling method grounded in gradient interactions. Its core idea is to use gradient inner products to estimate *positive synergistic effects* among tasks, enabling online optimization of task sampling weights. The authors provide theoretical convergence guarantees and a fairness extension that ensures balanced progress across tasks. PiKE incurs almost no additional computational overhead and requires no architectural modifications. Experiments on large-scale LLM pretraining show that PiKE accelerates convergence and consistently outperforms static and heuristic baselines, achieving higher average downstream task performance.

📝 Abstract
Modern machine learning models are trained on diverse datasets and tasks to improve generalization. A key challenge in multitask learning is determining the optimal data mixing and sampling strategy across different data sources. Prior research in this multi-task learning setting has primarily focused on mitigating gradient conflicts between tasks. However, we observe that many real-world multitask learning scenarios, such as multilingual training and multi-domain learning in large foundation models, exhibit predominantly positive task interactions with minimal or no gradient conflict. Building on this insight, we introduce PiKE (Positive gradient interaction-based K-task weights Estimator), an adaptive data mixing algorithm that dynamically adjusts task contributions throughout training. PiKE optimizes task sampling to minimize overall loss, effectively leveraging positive gradient interactions with almost no additional computational overhead. We establish theoretical convergence guarantees for PiKE and demonstrate its superiority over static and non-adaptive mixing strategies. Additionally, we extend PiKE to promote fair learning across tasks, ensuring balanced progress and preventing task underrepresentation. Empirical evaluations on large-scale language model pretraining show that PiKE consistently outperforms existing heuristic and static mixing strategies, leading to faster convergence and improved downstream task performance.
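The core mechanism described in the abstract, using gradient inner products between tasks to adaptively reweight task sampling, can be illustrated with a minimal sketch. This is not the paper's actual algorithm; the function name `update_task_weights`, the interaction score, and the multiplicative update rule are illustrative assumptions, shown only to make the idea of positive-interaction-driven mixing concrete.

```python
import numpy as np

def update_task_weights(task_grads, weights, lr=0.1):
    """Illustrative sketch (not PiKE itself): reweight task sampling
    using pairwise gradient inner products.

    task_grads: list of per-task gradient vectors (np.ndarray, same shape)
    weights:    current task sampling weights (non-negative, sum to 1)
    """
    # Pairwise gradient inner products: positive entries indicate
    # synergistic (non-conflicting) interactions between tasks.
    G = np.array([[np.dot(gi, gj) for gj in task_grads]
                  for gi in task_grads])
    # Score each task by its weighted interaction with all tasks,
    # i.e. how much sampling it is expected to help the mixture.
    scores = G @ weights
    # Multiplicative-weights-style update (an assumed choice here),
    # renormalized back to a probability distribution.
    new_w = weights * np.exp(lr * scores)
    return new_w / new_w.sum()
```

Under this sketch, a task whose gradient aligns more strongly with the other tasks' gradients receives a larger sampling weight at the next step, which is the qualitative behavior the abstract attributes to leveraging positive gradient interactions.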
Problem

Research questions and friction points this paper is trying to address.

Optimize data mixing in multitask learning
Leverage positive gradient interactions
Ensure fair task representation in training
Innovation

Methods, ideas, or system contributions that make the work stand out.

Adaptive data mixing algorithm
Leverages positive gradient interactions
Ensures fair learning across tasks