🤖 AI Summary
This work proposes the MERIT framework to address two problems in large-scale instruction tuning on heterogeneous data: gradient conflicts between data sources and heavy communication overhead. MERIT combines conflict-aware task partitioning with weighted merging in parameter space. It estimates dataset-level gradient conflicts and partitions the task mixture along the top principal (PCA) conflict axes, so each partition can be fine-tuned in parallel with no inter-partition communication. It then merges the resulting models once via token-weighted averaging. A theoretical analysis based on a local quadratic approximation inside a shared flat basin shows that this merging yields curvature-weighted variance reduction, acts as spectral filtering, and provides implicit norm regularization. On Qwen2.5-VL-3B fine-tuned on 136 Vision-FLAN tasks, MERIT raises the 8-benchmark average from 54.3 (joint training) to 57.0. It also scales to a 7B model trained on a 1.6M-example mixture drawn from 176 heterogeneous sources, matching or surpassing centralized joint training.
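The partition-then-merge steps can be sketched as follows. This is a minimal illustration under stated assumptions, not MERIT's implementation. It assumes each dataset is summarized by a single gradient vector (for example, an averaged gradient of the shared base model on that dataset). It treats the principal axes of the unit-normalized gradients as the "conflict axes" and assigns each dataset to a partition by the sign pattern of its projections on the top axes. Partition models are then merged by token-weighted averaging of flattened parameter vectors. The function names `conflict_pca_partition` and `token_weighted_merge` are hypothetical.

```python
import numpy as np

# Hypothetical sketch of conflict-aware partitioning and token-weighted merging;
# not taken from the MERIT codebase.


def conflict_pca_partition(grads: np.ndarray, num_axes: int = 1) -> list[list[int]]:
    """Group datasets by the sign of their projections on the top principal axes.

    grads: (num_datasets, num_params) array with one summary gradient per dataset
    (assumed input, e.g. a mean gradient of the shared base model on that dataset).
    Returns up to 2**num_axes non-empty lists of dataset indices.
    """
    # Normalize rows so the split depends on gradient direction, not magnitude
    # (assumes no all-zero gradient rows).
    unit = grads / np.linalg.norm(grads, axis=1, keepdims=True)
    # PCA via SVD of the centered matrix: rows of vt are the principal axes.
    centered = unit - unit.mean(axis=0, keepdims=True)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    proj = centered @ vt[:num_axes].T  # (num_datasets, num_axes)
    # Datasets sharing a sign pattern on the top axes form one partition.
    groups: dict[tuple[bool, ...], list[int]] = {}
    for idx, row in enumerate(proj):
        key = tuple(bool(v >= 0) for v in row)
        groups.setdefault(key, []).append(idx)
    return [groups[key] for key in sorted(groups)]


def token_weighted_merge(params: list[np.ndarray], tokens: list[int]) -> np.ndarray:
    """Average flattened per-partition parameters, weighting each partition
    by its number of training tokens (weights are normalized to sum to 1)."""
    return np.average(np.stack(params), axis=0, weights=np.asarray(tokens, dtype=float))


# Toy usage: datasets 0 and 1 pull roughly along +x, datasets 2 and 3 along -x.
g = np.array([
    [1.0, 0.1, 0.0],
    [0.9, -0.1, 0.0],
    [-1.0, 0.0, 0.1],
    [-0.8, 0.1, -0.1],
])
parts = conflict_pca_partition(g)  # {0, 1} and {2, 3}; list order follows the SVD sign
merged = token_weighted_merge([np.zeros(3), np.ones(3)], tokens=[1000, 3000])
print(merged)  # → [0.75 0.75 0.75]
```

Splitting by sign along the top axis places datasets whose gradients point in opposing directions into different partitions, so those datasets do not interfere during training. With `num_axes=m`, at most 2^m partitions result. The merge needs only one exchange of parameters at the end rather than per-step synchronization.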
📝 Abstract
Instruction tuning aligns large language models, including multimodal ones, with diverse user intents, but scaling to heterogeneous mixtures is hindered by gradient interference and bandwidth-heavy synchronization. We ask whether these two bottlenecks can be addressed jointly by training parts of the mixture independently and reconciling them once in parameter space. We develop a local quadratic theory inside a shared flat basin that yields three results: weight merging produces a curvature-weighted variance reduction; PCA-aligned conflict splitting maximizes this gain along high-curvature directions; and merging additionally acts as spectral filtering with implicit norm regularization. These results directly motivate MERIT, a decentralized merge-ready instruction-tuning pipeline that estimates dataset-level gradient conflicts, partitions the mixture along the top PCA conflict axes, fine-tunes each partition independently with no inter-partition communication, and merges once via token-weighted averaging. On Qwen2.5-VL-3B with 136 Vision-FLAN tasks, MERIT improves the 8-benchmark average from 54.3 (joint training) to 57.0. The same recipe scales to a 7B model on a 1.6M-example, 176-source mixture (matching or exceeding centralized joint training with minimal cost overhead) and transfers to text-only FLAN. Our code is available at https://github.com/naver-ai/merit.
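The variance-reduction result can be illustrated with a minimal derivation. It assumes a shared minimizer θ* with positive semidefinite Hessian H = Σ_i λ_i u_i u_iᵀ inside the basin, and that partition k ends at θ_k = θ* + δ_k. It also assumes token weights w_k = n_k / Σ_j n_j, where n_k is the token count of partition k. Within the quadratic model, the equalities below are exact. This is the standard Jensen-gap identity for quadratics, offered as a sketch; the paper's theorems may be stated differently.

```latex
\begin{aligned}
\mathcal{L}(\theta) &= \mathcal{L}(\theta^\star) + \tfrac{1}{2}(\theta-\theta^\star)^\top H(\theta-\theta^\star),
\qquad \bar\theta = \sum_k w_k\theta_k,\quad \bar\delta = \sum_k w_k\delta_k,\\
\mathcal{L}(\bar\theta)-\mathcal{L}(\theta^\star)
&= \sum_k w_k\big[\mathcal{L}(\theta_k)-\mathcal{L}(\theta^\star)\big]
 - \tfrac{1}{2}\sum_k w_k(\delta_k-\bar\delta)^\top H(\delta_k-\bar\delta)\\
&= \sum_k w_k\big[\mathcal{L}(\theta_k)-\mathcal{L}(\theta^\star)\big]
 - \tfrac{1}{2}\sum_i \lambda_i \sum_k w_k\big(u_i^\top(\delta_k-\bar\delta)\big)^2 .
\end{aligned}
```

Because every λ_i ≥ 0, the subtracted term is a sum of curvature-weighted variances. Within this model, the merged model's excess loss is therefore never above the token-weighted average of the partitions' excess losses. The gain is concentrated on high-curvature eigendirections (large λ_i) along which the partitions' deviations disagree. This is consistent with the abstract's claim that splitting along the top conflict axes maximizes the benefit.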