🤖 AI Summary
In deep Transformers, improper weight initialization induces rank collapse (representation homogenization) and entropy collapse (excessive attention concentration) in self-attention layers, leading to training failure. This work establishes the first unified analytical theory characterizing signal propagation at initialization through a full Transformer block, including LayerNorm and residual connections, and rigorously derives critical conditions for both collapses. It identifies a novel entropy collapse phenomenon under high-variance initialization and proposes a theoretically grounded framework based on a random energy model. This enables construction of an analytically tractable trainability phase diagram and yields quantitative design principles for weight scaling and residual coefficients. Empirical validation on TinyStories demonstrates significantly improved training stability for BERT-style models, with theoretical predictions deviating by less than 5% from observed behavior.
📝 Abstract
Finding the right initialisation for neural networks is crucial to ensure smooth training and good performance. In transformers, the wrong initialisation can lead to one of two failure modes of self-attention layers: rank collapse, where all tokens collapse into similar representations, and entropy collapse, where highly concentrated attention scores lead to training instability. While the right initialisation has been extensively studied in feed-forward networks, an exact description of signal propagation through a full transformer block has so far been lacking. Here, we provide an analytical theory of signal propagation through vanilla transformer blocks with self-attention layers, layer normalisation, skip connections and ReLU MLP. To treat the self-attention layer, we draw on a formal parallel with the Random Energy Model from statistical physics. We identify and characterise two regimes governed by the variance of the query and key initialisations: a low-variance regime, where we recover the known rank collapse behaviour; and a previously unexplored high-variance regime, where signal is preserved but *entropy collapse* occurs. In the low-variance regime, we calculate the critical strength for the residual connection to ensure signal propagation. Our theory yields trainability diagrams that identify the correct choice of initialisation hyper-parameters for a given architecture. Experiments with BERT-style models trained on TinyStories validate our predictions. Our theoretical framework gives a unified perspective on the two failure modes of self-attention, along with quantitative predictions for the scale of both weights and residual connections that guarantees smooth training.
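The two regimes can be illustrated with a minimal numpy sketch. This is not the paper's setup (it omits LayerNorm, skip connections and the MLP, and `sigma`, `T`, `d` and `depth` are illustrative choices): a bare stack of self-attention layers whose query/key weights are drawn at scale `sigma`. Low `sigma` drives tokens towards identical representations (rank collapse); high `sigma` drives attention rows towards near one-hot distributions (entropy collapse).

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention_layer(X, sigma, rng):
    """One vanilla self-attention layer (no residual, no LayerNorm).
    Query/key weights are drawn at scale sigma; values at scale 1."""
    d = X.shape[1]
    Wq = rng.normal(0.0, sigma / np.sqrt(d), (d, d))
    Wk = rng.normal(0.0, sigma / np.sqrt(d), (d, d))
    Wv = rng.normal(0.0, 1.0 / np.sqrt(d), (d, d))
    A = softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(d))  # (T, T) attention matrix
    return A @ (X @ Wv), A

def mean_attention_entropy(A):
    # Entropy per attention row, averaged; near 0 signals entropy collapse,
    # near log(T) signals uniform attention.
    return float(-(A * np.log(A + 1e-12)).sum(axis=-1).mean())

def mean_token_similarity(X):
    # Mean pairwise cosine similarity between tokens; near 1 signals rank collapse.
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    C = Xn @ Xn.T
    T = len(C)
    return float((C.sum() - T) / (T * (T - 1)))

T, d, depth = 16, 32, 5  # toy sequence length, width, and depth
results = {}
for sigma in (0.5, 8.0):  # low- vs high-variance query/key initialisation
    rng = np.random.default_rng(0)
    X = rng.normal(size=(T, d))
    for _ in range(depth):
        X, A = attention_layer(X, sigma, rng)
    results[sigma] = (mean_token_similarity(X), mean_attention_entropy(A))
    print(f"sigma={sigma}: token similarity={results[sigma][0]:.3f}, "
          f"attention entropy={results[sigma][1]:.3f} (uniform={np.log(T):.3f})")
```

In the low-variance run the attention matrices stay near uniform, so repeated averaging homogenises the tokens; in the high-variance run the softmax concentrates and the attention entropy drops well below its uniform value.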