🤖 AI Summary
This work investigates the optimization dynamics of linear bigram models for next-token prediction under Zipf-distributed token frequencies (frequency ∝ 1/k^α), focusing on gradient descent (GD) and sign gradient descent (SignGD), a simplified proxy for Adam, across varying tail exponents α. Methodologically, it integrates power-law modeling, spectral analysis of feature matrices, and deterministic first-order optimization theory. Theoretically, it establishes α-explicit optimization scaling laws: when α ≤ 1, as is characteristic of real-world heavy-tailed text distributions, GD exhibits near-linear iteration complexity in the vocabulary dimension d, leading to severe slowdown in high dimensions; in contrast, SignGD achieves O(√d) convergence, substantially mitigating the heavy-tail challenge. Crucially, α = 1 is identified as the worst case for GD, whereas SignGD remains robust and accelerates convergence by over an order of magnitude in large-vocabulary settings. These results provide foundational insights into optimizer selection for language modeling under realistic data distributions.
📝 Abstract
Recent works have highlighted optimization difficulties faced by gradient descent in training the first and last layers of transformer-based language models, which are overcome by optimizers such as Adam. These works suggest that the difficulty is linked to the heavy-tailed distribution of words in text data, where the frequency of the $k$th most frequent word $\pi_k$ is proportional to $1/k$, following Zipf's law. To better understand the impact of the data distribution on training performance, we study a linear bigram model for next-token prediction when the tokens follow a power law $\pi_k \propto 1/k^\alpha$ parameterized by the exponent $\alpha > 0$. We derive optimization scaling laws for deterministic gradient descent and sign descent as a proxy for Adam as a function of the exponent $\alpha$. Existing theoretical investigations in scaling laws assume that the eigenvalues of the data decay as a power law with exponent $\alpha > 1$. This assumption effectively makes the problem "finite dimensional", as most of the loss comes from a few of the largest eigencomponents. In comparison, we show that the problem is more difficult when the data have heavier tails. The case $\alpha = 1$ as found in text data is "worst-case" for gradient descent, in that the number of iterations required to reach a small relative error scales almost linearly with dimension. While the performance of sign descent also depends on the dimension, for Zipf-distributed data the number of iterations scales only with the square root of the dimension, leading to a large improvement for large vocabularies.
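The contrast between gradient descent and sign descent can be seen on a toy problem. The sketch below (not the paper's actual setup; the diagonal quadratic, the step-size choices, and the halving schedule for sign descent are illustrative assumptions) minimizes a quadratic whose curvatures decay as $1/k^\alpha$: with $\alpha = 1$, gradient descent leaves a large residual after many steps because the tail modes barely move, while sign descent, whose per-coordinate step is independent of the curvature, drives every mode down at the same geometric rate.

```python
import numpy as np

def relative_loss(optimizer, alpha=1.0, d=1000, iters=100):
    """Relative loss after `iters` steps on the diagonal quadratic
    L(w) = 0.5 * sum_k lam_k * w_k^2, with power-law curvatures
    lam_k = 1/k^alpha (a stylized stand-in for Zipf-distributed data)."""
    lam = 1.0 / np.arange(1, d + 1) ** alpha  # eigenvalues 1, 1/2^a, 1/3^a, ...
    w = np.ones(d)                            # unit initial error in every mode
    loss0 = 0.5 * np.sum(lam * w ** 2)
    for t in range(iters):
        grad = lam * w
        if optimizer == "gd":
            w -= grad                         # step size 1 = 1/lam_max
        else:
            # Sign descent with a halving step schedule (illustrative choice):
            # since every coordinate stays positive here, w_k decays as 2^{-t}
            # regardless of how small lam_k is.
            w -= 0.5 ** (t + 1) * np.sign(grad)
    return 0.5 * np.sum(lam * w ** 2) / loss0
```

Running both with $\alpha = 1$ and $d = 1000$ shows gradient descent retaining a non-negligible fraction of the initial loss after 100 iterations (the modes with $k \gg 100$ have decayed by only a factor of roughly $e^{-100/k}$), whereas sign descent's relative loss is vanishingly small.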