🤖 AI Summary
This work investigates how training task diversity affects the generalization and optimization dynamics of in-context learning (ICL). It focuses on phenomena observed when diversity is defined by the number of function classes from which task vectors are drawn, since these have lacked a theoretical explanation. To this end, it develops a minimal analytical model in which training task vectors follow a mixture of low-rank Gaussians, and task diversity is measured by the number of non-overlapping columns between the subspaces that parameterize the covariance matrices. Analyzing ICL with linear attention under this model, the paper shows why task diversity shortens the ICL plateau and why ICL appears to achieve out-of-distribution generalization. The framework offers a unified explanation of existing observations, and experiments indicate that the results extend to nonlinear Transformers and nonlinear function classes.
📝 Abstract
The transformer's emergent ability to perform in-context learning (ICL) has sparked a wide range of studies designed to understand its underlying mechanisms. Existing works often study how training task diversity, defined either as the number of ICL training task vectors or as the number of function classes from which the task vectors are drawn, shapes both the learning dynamics and generalization capabilities of ICL. While both definitions have uncovered many interesting phenomena, many observations under the latter definition remain theoretically unexplained. This paper presents a minimal analytical model under which these phenomena provably emerge from the properties of the training data. By modeling the training task vectors as a mixture of low-rank Gaussians, we show how training task diversity, defined by the number of non-overlapping columns between the subspaces that parameterize the covariance matrices, improves both the generalization and the optimization trajectory of ICL with linear attention. In particular, we show that our model can explain (i) why training with task diversity shortens the ICL plateau and (ii) why ICL appears to achieve out-of-distribution generalization. We conclude by empirically demonstrating how our results extend to nonlinear transformers and nonlinear function classes. Overall, our work presents a tractable framework to unify existing observations.
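The data model described in the abstract can be made concrete with a small simulation. The following is a minimal sketch under our own assumptions, not the paper's construction: it uses two mixture components whose covariance subspaces are spanned by standard basis vectors, lets a hypothetical parameter `s` count the non-overlapping columns between them, draws task vectors as `w = U_k z` with `z ~ N(0, I_r)`, and computes a single linear attention layer's prediction in the simplified form `x_q^T Γ (1/n) Σ_i y_i x_i`, with `Γ` defaulting to the identity. All function and parameter names are illustrative.

```python
import random


def make_subspaces(d, r, s):
    """Two rank-r subspaces of R^d, each spanned by standard basis vectors.

    Hypothetical construction: the second subspace shares r - s columns
    with the first, so exactly s of its columns are non-overlapping.
    """
    if not (0 <= s <= r and r + s <= d):
        raise ValueError("need 0 <= s <= r and r + s <= d")
    first = list(range(r))          # columns e_0, ..., e_{r-1}
    second = list(range(s, r + s))  # drops e_0..e_{s-1}, adds e_r..e_{r+s-1}
    return [first, second]


def non_overlapping_columns(a, b):
    """Number of columns of subspace b that do not appear in subspace a."""
    return len(set(b) - set(a))


def sample_task(subspaces, d, rng):
    """Draw w = U_k z with k uniform over components and z ~ N(0, I_r)."""
    cols = rng.choice(subspaces)
    w = [0.0] * d
    for j in cols:
        w[j] = rng.gauss(0.0, 1.0)
    return w


def pooled_task_covariance_diag(subspaces, d):
    """Diagonal of the mixture covariance (1/K) * sum_k U_k U_k^T.

    With standard-basis columns this matrix is diagonal, so its rank is the
    number of coordinates covered by at least one subspace.
    """
    diag = [0.0] * d
    for cols in subspaces:
        for j in cols:
            diag[j] += 1.0 / len(subspaces)
    return diag


def linear_attention_predict(xs, ys, x_query, gamma=None):
    """Return x_q^T Gamma (1/n) sum_i y_i x_i.

    Gamma defaults to the identity, an illustrative simplification.
    """
    d, n = len(x_query), len(xs)
    h = [sum(y * x[j] for x, y in zip(xs, ys)) / n for j in range(d)]
    if gamma is not None:
        h = [sum(gamma[i][j] * h[j] for j in range(d)) for i in range(d)]
    return sum(q * hj for q, hj in zip(x_query, h))


if __name__ == "__main__":
    rng = random.Random(0)
    d, r, s, n = 8, 3, 2, 200
    subspaces = make_subspaces(d, r, s)
    w = sample_task(subspaces, d, rng)
    xs = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    ys = [sum(wi * xi for wi, xi in zip(w, x)) for x in xs]
    x_q = [rng.gauss(0.0, 1.0) for _ in range(d)]
    print(non_overlapping_columns(*subspaces))  # 2
    print(sum(v > 0 for v in pooled_task_covariance_diag(subspaces, d)))  # 5
    # With isotropic inputs, the identity-Gamma prediction approximates x_q^T w.
    print(linear_attention_predict(xs, ys, x_q),
          sum(wi * qi for wi, qi in zip(w, x_q)))
```

Restricting the columns to standard basis vectors keeps the pooled task covariance diagonal, so it is easy to see that in this toy setup it spans `r + s` directions: each non-overlapping column adds one direction covered during training. This is an observation about the sketch only, not a result stated in the abstract.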