🤖 AI Summary
This paper investigates in-context learning (ICL) in Transformers applied to $d$-dimensional mixed linear regression. Methodologically, it integrates high-dimensional statistics, functional analysis, and optimization dynamics, employing a population mean-squared-error loss and analyzing both low- and high-signal-to-noise-ratio (SNR) regimes. Theoretically, it establishes the first provable ICL prediction error bound of $\mathcal{O}(\sqrt{d/n})$ under high SNR and an excess risk bound of $\mathcal{O}(L/\sqrt{B})$, explicitly characterizing dependencies on the dimension $d$, prompt size $n$, number of attention layers $L$, and number of training prompts $B$. It further proves global convergence of gradient flow for a single-layer linear self-attention model under appropriate initialization. Empirically, simulations on synthetic data suggest that transformers can outperform baselines, including the Expectation-Maximization (EM) algorithm. Collectively, this work provides the first rigorous, quantifiable theoretical framework for understanding Transformers' implicit modeling capability in ICL for mixed linear regression.
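For reference, here is a minimal sketch of the data model behind these bounds, assuming the standard mixture-of-linear-regressions prompt setup (the paper's exact notation may differ): each prompt fixes a mixture component and generates $n$ labeled examples from it.

```latex
% Assumed standard form: each prompt draws a component k, then
% generates n in-context examples from that component's regressor.
y_i = \langle \beta_k, x_i \rangle + \varepsilon_i, \qquad
k \sim \mathrm{Uniform}\{1, \dots, K\}, \quad
\varepsilon_i \overset{\text{i.i.d.}}{\sim} \mathcal{N}(0, \sigma^2),
\quad i = 1, \dots, n,
% with SNR = \|\beta_k\| / \sigma; the high-SNR regime is where the
% prediction error bound \mathcal{O}(\sqrt{d/n}) applies.
```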
📝 Abstract
We investigate the in-context learning capabilities of transformers for the $d$-dimensional mixture of linear regression model, providing theoretical insights into their existence, generalization bounds, and training dynamics. Specifically, we prove that there exists a transformer capable of achieving a prediction error of order $\mathcal{O}(\sqrt{d/n})$ with high probability, where $n$ represents the training prompt size in the high signal-to-noise ratio (SNR) regime. Moreover, we derive in-context excess risk bounds of order $\mathcal{O}(L/\sqrt{B})$ for the case of two mixtures, where $B$ denotes the number of training prompts, and $L$ represents the number of attention layers. The dependence of $L$ on the SNR is explicitly characterized, differing between low and high SNR settings. We further analyze the training dynamics of transformers with single linear self-attention layers, demonstrating that, with appropriately initialized parameters, gradient flow optimization over the population mean square loss converges to a global optimum. Extensive simulations suggest that transformers perform well on this task, potentially outperforming other baselines, such as the Expectation-Maximization algorithm.
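To make the setup concrete, the sketch below generates a mixed-linear-regression prompt and runs the reduction of a single linear self-attention layer to a preconditioned one-step estimator, $\hat{y} = x_{\text{query}}^\top W \left(\tfrac{1}{n}\sum_i y_i x_i\right)$. This is an illustrative assumption about the architecture's functional form, not the paper's exact construction; the names `sample_prompt` and `lsa_predict` and the identity preconditioner are hypothetical.

```python
import numpy as np

def sample_prompt(n, d, betas, snr, rng):
    """Draw one in-context prompt from a two-component mixture of linear
    regressions: y = <beta_k, x> + noise, with k uniform over components."""
    X = rng.standard_normal((n, d))
    k = rng.integers(len(betas))             # mixture component for this prompt
    sigma = np.linalg.norm(betas[k]) / snr   # noise scale set by the SNR
    y = X @ betas[k] + sigma * rng.standard_normal(n)
    return X, y, k

def lsa_predict(X, y, x_query, W):
    """Single linear self-attention layer, written in its reduced form as a
    preconditioned one-step estimator: y_hat = x_query^T W (X^T y / n)."""
    n = X.shape[0]
    return x_query @ W @ (X.T @ y) / n

rng = np.random.default_rng(0)
d, n, snr = 8, 200, 5.0                      # high-SNR regime for illustration
betas = [rng.standard_normal(d), rng.standard_normal(d)]
X, y, k = sample_prompt(n, d, betas, snr, rng)
x_query = rng.standard_normal(d)

W = np.eye(d)                                # identity preconditioner; training would tune W
y_hat = lsa_predict(X, y, x_query, W)
print(f"component {k}: prediction {y_hat:.3f} vs truth {x_query @ betas[k]:.3f}")
```

In this reading, the training-dynamics result concerns tuning $W$ by gradient flow on the population mean square loss over $B$ such prompts; the snippet fixes $W$ at identity only to show the forward pass.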