🤖 AI Summary
This work addresses the challenge of out-of-distribution (OOD) generalization when models face both correlation shifts across environments and diversity shifts caused by rare or difficult samples within environments. To tackle this, the paper proposes ECTR, a framework that, for the first time, unifies environment-level invariant learning with sample-level tail reweighting. Specifically, ECTR employs total-variation-based invariant risk minimization to capture features that are invariant across environments, and introduces environment-conditioned tail reweighting to improve robustness to hard examples. The framework further extends to settings without environment labels by inferring latent environmental structure through a minimax optimization strategy. Extensive experiments on regression, tabular, time-series, and image classification benchmarks demonstrate that ECTR significantly improves worst-environment performance and overall OOD generalization.
📝 Abstract
Out-of-distribution (OOD) generalization remains challenging when models simultaneously encounter correlation shifts across environments and diversity shifts driven by rare or hard samples. Existing invariant risk minimization (IRM) methods primarily address spurious correlations at the environment level but often overlook sample-level heterogeneity within environments, which can critically degrade OOD performance. In this work, we propose Environment-Conditioned Tail Reweighting for Total Variation Invariant Risk Minimization (ECTR), a unified framework that augments TV-based invariant learning with environment-conditioned tail reweighting to jointly address both types of distribution shift. By integrating environment-level invariance with within-environment robustness, ECTR makes the two mechanisms complementary under mixed distribution shifts. We further extend the framework to settings without explicit environment annotations by inferring latent environments through a minimax formulation. Experiments on regression, tabular, time-series, and image classification benchmarks under mixed distribution shifts demonstrate consistent improvements in both worst-environment and average OOD performance.
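The objective described above combines within-environment tail reweighting with a cross-environment invariance term. The sketch below is only illustrative, not the paper's actual formulation: it assumes a simple loss-quantile rule for identifying tail samples, and it substitutes a variance-across-environments penalty (in the spirit of V-REx) for the total-variation-based IRM term, whose exact form is not reproduced here.

```python
import numpy as np

def ectr_style_loss(losses_per_env, tail_q=0.8, alpha=2.0, lam=1.0):
    """Illustrative ECTR-style objective (hypothetical sketch, not the paper's exact loss).

    losses_per_env: list of 1-D arrays of per-sample losses, one array per environment.
    tail_q: within-environment quantile above which a sample counts as a 'tail' (hard) sample.
    alpha: extra weight placed on tail samples (environment-conditioned reweighting).
    lam: strength of the cross-environment invariance penalty.
    """
    env_risks = []
    for losses in losses_per_env:
        # Environment-conditioned threshold: the tail is defined per environment.
        thresh = np.quantile(losses, tail_q)
        # Upweight hard (tail) samples within this environment.
        weights = np.where(losses >= thresh, alpha, 1.0)
        env_risks.append(np.average(losses, weights=weights))
    env_risks = np.array(env_risks)
    # Variance across per-environment risks as a simple stand-in for the
    # total-variation-based invariance penalty used in the paper.
    return env_risks.mean() + lam * env_risks.var()
```

With `alpha=1.0` and `lam=0.0` this reduces to plain ERM averaged over environments; raising `alpha` emphasizes hard samples within each environment, while `lam` penalizes environments whose (reweighted) risks diverge, which is the sense in which the two mechanisms are complementary.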