
Residual Connections and RMSNorm

Why do layers need residual connections?

A transformer layer transforms hidden states — but if each layer replaced the input entirely, information from earlier layers could be lost. Worse, gradients would have to pass through every transformation to reach early layers during training, making learning unstable.

A residual connection solves this: the layer's output is added to its input. The layer only needs to learn the change, not the entire new representation. Information passes through even if the layer does nothing useful.

But addition can cause scale to drift: after many layers, the vectors in the residual stream can keep growing. RMSNorm (Root Mean Square Normalization) keeps this in check by rescaling each vector to a consistent magnitude before it enters a sub-layer, so every sub-layer sees inputs of the same scale no matter how large the residual stream has become. In modern transformers, the norm is applied before the transform (pre-norm), not after.
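A toy sketch of both effects, using plain Python and a hypothetical sub-layer that ignores its input and always proposes the same delta (not a real model). After 32 residual additions the stream's RMS has grown well past its starting value, while the normalized vector handed to the sub-layer still has RMS 1. The learned scale γ is fixed to 1 here as a simplifying assumption.

```python
import math

def rms(v):
    # Root mean square: sqrt of the mean of the squared elements
    return math.sqrt(sum(e * e for e in v) / len(v))

def rms_norm(v):
    # RMSNorm with the learned scale γ fixed to 1 for illustration
    r = rms(v)
    return [e / r for e in v]

def toy_sublayer(v):
    # Hypothetical sub-layer that always proposes the same positive delta
    return [0.5] * len(v)

x = [1.0, 2.0, 3.0, 4.0]
for _ in range(32):
    delta = toy_sublayer(rms_norm(x))        # the sub-layer sees a normalized input
    x = [a + b for a, b in zip(x, delta)]    # the residual stream accumulates deltas

print(f"residual stream RMS after 32 additions: {rms(x):.2f}")
print(f"RMS of the sub-layer's normalized input: {rms(rms_norm(x)):.2f}")
```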

The pattern inside a block is: norm → transform → add residual, repeated for both attention and FFN.
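That pattern can be sketched for a single token's hidden state as below. The attention and FFN functions are hypothetical placeholders (real attention also mixes information across tokens, and a real FFN applies learned projections), and the epsilon value of 1e-6 is an assumption.

```python
import math

def rms_norm(v, gamma, eps=1e-6):
    # Rescale v to unit root-mean-square, then apply the per-dimension scale gamma.
    # eps (assumed 1e-6 here) guards against dividing by zero.
    inv_rms = 1.0 / math.sqrt(sum(e * e for e in v) / len(v) + eps)
    return [e * inv_rms * g for e, g in zip(v, gamma)]

def add(a, b):
    # Residual connection: elementwise addition
    return [p + q for p, q in zip(a, b)]

def pre_norm_block(x, attn, ffn, gamma_attn, gamma_ffn):
    # Attention sub-layer: norm -> transform -> add residual
    h = add(x, attn(rms_norm(x, gamma_attn)))
    # FFN sub-layer: the same pattern, applied to the updated hidden state
    return add(h, ffn(rms_norm(h, gamma_ffn)))

def toy_attn(v):
    # Hypothetical stand-in for attention
    return [0.1 * e for e in v]

def toy_ffn(v):
    # Hypothetical stand-in for the feed-forward network
    return [-0.05 * e for e in v]

x = [1.0, 2.0, 3.0, 4.0]
ones = [1.0] * len(x)
out = pre_norm_block(x, toy_attn, toy_ffn, ones, ones)
print([round(e, 3) for e in out])
```

Each sub-layer gets its own γ vector, which is why the sketch takes `gamma_attn` and `gamma_ffn` separately.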

RMSNorm first normalizes a vector so its root-mean-square is 1, then applies a learned per-dimension scale. Given a vector x of dimension d:

RMS(x) = √((1/d) × ∑ᵢ xᵢ²)
RMSNorm(x)ᵢ = (xᵢ / RMS(x)) × γᵢ

The division by RMS(x) normalizes the vector magnitude. The learned scale parameter γ (one per dimension) lets the model adjust the normalized output — it can re-amplify dimensions that matter and suppress ones that do not.
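The two formulas translate almost line for line into plain Python. This sketch omits any stabilizing epsilon so that it matches the formula exactly, and the second γ is a made-up value chosen to show one dimension amplified and one suppressed.

```python
import math

def rms_norm(x, gamma):
    # RMS(x) = sqrt((1/d) * sum of x_i^2)
    d = len(x)
    rms = math.sqrt(sum(xi * xi for xi in x) / d)
    # Normalize to unit RMS, then apply the learned per-dimension scale gamma
    return [(xi / rms) * g for xi, g in zip(x, gamma)]

x = [1.0, 2.0, 3.0, 4.0]
unit = rms_norm(x, [1.0, 1.0, 1.0, 1.0])
scaled = rms_norm(x, [2.0, 1.0, 1.0, 0.0])   # hypothetical learned gamma

print([round(v, 2) for v in unit])    # the RMS of this vector is 1
print([round(v, 2) for v in scaled])  # dimension 0 amplified, dimension 3 suppressed
```

Multiplying x by any positive constant leaves the γ = 1 output unchanged, because RMS(x) scales by the same constant; that is the property that keeps sub-layer inputs at a consistent magnitude.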

Why RMSNorm instead of the older LayerNorm? LayerNorm also subtracts the mean before normalizing, so it must compute both the mean and the variance of each vector. RMSNorm skips the mean subtraction and computes only the mean of squares. In practice, this is slightly faster (one fewer reduction over the data) and works about as well for transformer hidden states, which is why many modern LLMs, including the LLaMA family, use it.
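A side-by-side sketch of the difference, with LayerNorm's learned scale and bias fixed to 1 and 0 and no epsilon (simplifying assumptions). The two normalizations agree whenever the vector already has zero mean, because then the variance equals the mean of squares.

```python
import math

def layer_norm(x):
    # LayerNorm (scale 1, bias 0): subtract the mean, divide by the standard deviation
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var) for v in x]

def rms_norm(x):
    # RMSNorm (gamma 1): no mean subtraction, divide by the root mean square
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms for v in x]

x = [1.0, 2.0, 3.0, 4.0]
print([round(v, 3) for v in layer_norm(x)])  # centered: the outputs sum to 0
print([round(v, 3) for v in rms_norm(x)])    # not centered: the positive mean survives

z = [-3.0, -1.0, 1.0, 3.0]  # already zero-mean
print(all(math.isclose(a, b) for a, b in zip(layer_norm(z), rms_norm(z))))
```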

The original transformer paper applied normalization after the sub-layer: transform → add residual → norm. Most modern LLMs use pre-norm: norm → transform → add residual.

The reason is training stability, specifically gradient flow. In post-norm, the output of each block is norm(x + transform(x)). During backpropagation, gradients from higher layers must pass through the norm layer to reach lower layers. The norm's Jacobian can scale gradients up or down unpredictably, causing training instability in deep networks.

In pre-norm, the output is x + transform(norm(x)). The residual connection provides a direct, unmodified path for gradients: they flow through the addition without passing through any norm or transform. Each sub-layer only needs to learn a useful delta to add. If a layer has nothing useful to contribute, it can learn weights close to zero, and the gradient still flows cleanly through the residual. This makes training much more stable for deep models (32 layers and beyond).
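A forward-pass sketch of the structural difference, assuming a hypothetical sub-layer whose weights have all gone to zero (γ fixed to 1). With pre-norm the block collapses to the identity, so its Jacobian is the identity and gradients pass through unchanged; with post-norm the norm still sits on the main path and rescales x. The sketch does not compute gradients itself.

```python
import math

def rms_norm(v):
    # RMSNorm with gamma fixed to 1 for illustration
    r = math.sqrt(sum(e * e for e in v) / len(v))
    return [e / r for e in v]

def post_norm(x, sublayer):
    # Original transformer: transform -> add residual -> norm
    return rms_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_norm(x, sublayer):
    # Modern LLMs: norm -> transform -> add residual
    return [a + b for a, b in zip(x, sublayer(rms_norm(x)))]

def zero_sublayer(v):
    # A sub-layer that has learned to contribute nothing
    return [0.0] * len(v)

x = [1.0, 2.0, 3.0, 4.0]
print(pre_norm(x, zero_sublayer))   # identical to x: the residual path is untouched
print(post_norm(x, zero_sublayer))  # rescaled: the norm still sits on the main path
```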

One token's hidden state through the attention sub-layer:

input: x = [1.0, 2.0, 3.0, 4.0]
step 1: RMS(x) = √((1+4+9+16)/4) = √7.5 ≈ 2.74
        x_norm = (x / RMS(x)) × γ = [0.37, 0.73, 1.10, 1.46]   (γ = [1, 1, 1, 1] in this example)
step 2: delta = attention(x_norm) = [0.1, -0.3, 0.5, 0.2]   (illustrative values)
output: x + delta = [1.1, 1.7, 3.5, 4.2]

The original input x flows through unchanged and the layer's contribution is simply added on top.
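The arithmetic of the example can be checked numerically. In this sketch the attention step is not computed: delta is hard-coded to the illustrative values from the example, since a real attention sub-layer would also need the other tokens and learned weights.

```python
import math

x = [1.0, 2.0, 3.0, 4.0]
gamma = [1.0, 1.0, 1.0, 1.0]

# Step 1: RMSNorm
rms = math.sqrt(sum(v * v for v in x) / len(x))   # sqrt(7.5)
x_norm = [v / rms * g for v, g in zip(x, gamma)]

# Step 2: placeholder for attention(x_norm), using the example's illustrative delta
delta = [0.1, -0.3, 0.5, 0.2]

# Residual addition: delta is added to the original x, not to x_norm
output = [v + d for v, d in zip(x, delta)]

print(round(rms, 2), [round(v, 2) for v in x_norm], [round(v, 2) for v in output])
```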

Input: [n_tokens, d_model]
After RMSNorm: [n_tokens, d_model]   (same shape, rescaled values)
After transform + residual: [n_tokens, d_model]
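Shape never changes. RMSNorm computes one RMS per token (per row) and then scales elementwise; the residual addition is purely elementwise. Neither operation mixes tokens or changes the hidden dimension.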
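A sketch of the shape bookkeeping on a small [n_tokens, d_model] matrix stored as nested Python lists. The sizes, the input values, and the toy sub-layer are hypothetical; the point is that the mean of squares is reduced within each row, so every token is normalized independently and all three matrices keep the same shape.

```python
import math

def rms_norm_rows(h, gamma, eps=1e-6):
    # h has shape [n_tokens, d_model]; each token's row gets its own RMS
    out = []
    for row in h:
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in row) / len(row) + eps)
        out.append([v * inv_rms * g for v, g in zip(row, gamma)])
    return out

def shape(m):
    # (rows, columns) of a list-of-lists matrix
    return (len(m), len(m[0]))

n_tokens, d_model = 3, 4
h = [[float(t * d_model + j + 1) for j in range(d_model)] for t in range(n_tokens)]

normed = rms_norm_rows(h, [1.0] * d_model)
# Hypothetical sub-layer: it must return d_model values per token for the add to work
delta = [[0.1 * v for v in row] for row in normed]
out = [[a + b for a, b in zip(r, d)] for r, d in zip(h, delta)]

print(shape(h), shape(normed), shape(out))
```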

The residual connection is a simple addition:

output = x + sublayer(RMSNorm(x))

RMSNorm rescales each vector by its root mean square:

RMSNorm(x)ᵢ = (xᵢ / RMS(x)) × γᵢ
RMS(x) = √(mean(x²))
γ is the learned per-element scale parameter (the same γ as above)
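The formula divides by RMS(x), which is zero for an all-zero vector. Implementations commonly add a small constant ε inside the square root, RMS(x) = √(mean(x²) + ε). The value is model-specific; the 1e-6 used in this sketch is an assumption. The sketch compares the two variants on a zero vector and on an ordinary one.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    # eps keeps the denominator positive even when every x_i is 0
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

def rms_norm_no_eps(x, gamma):
    # The formula exactly as written: fails on an all-zero input
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / rms * g for v, g in zip(x, gamma)]

gamma = [1.0, 1.0, 1.0, 1.0]
zeros = [0.0, 0.0, 0.0, 0.0]
print(rms_norm(zeros, gamma))  # all zeros, no error
try:
    rms_norm_no_eps(zeros, gamma)
except ZeroDivisionError:
    print("without eps: division by zero")

# For ordinary inputs the two variants differ only by a negligible amount
x = [1.0, 2.0, 3.0, 4.0]
print(max(abs(a - b) for a, b in zip(rms_norm(x, gamma), rms_norm_no_eps(x, gamma))))
```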

In llama.cpp, the build_norm() function in src/llama-graph.cpp applies RMSNorm (or LayerNorm, depending on the model). You will see it called just before build_attn() and build_ffn() in every layer.

RMSNorm is cheap: it needs only a mean of squares, a square root, and two elementwise multiplies (by 1/RMS(x) and by γ). The residual addition is even cheaper. These operations are negligible compared to the matrix multiplies in attention and FFN, but they are critical for training stability and inference quality.
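A back-of-the-envelope FLOP count makes "negligible" concrete. The hidden size of 4096 is a hypothetical choice, and the count ignores memory traffic; in practice small elementwise ops like these tend to be limited by memory bandwidth rather than arithmetic, so this compares arithmetic work only.

```python
d_model = 4096  # hypothetical hidden size, chosen for illustration

# RMSNorm per token: d squares + d adds (mean of squares), 1 sqrt,
# d multiplies by 1/RMS, d multiplies by gamma -> about 4 * d
rmsnorm_flops = 4 * d_model

# Residual add per token: d additions
residual_flops = d_model

# One d_model x d_model projection per token (e.g. a single attention weight matrix):
# d * d multiply-adds, conventionally counted as 2 * d * d FLOPs
projection_flops = 2 * d_model * d_model

ratio = (rmsnorm_flops + residual_flops) / projection_flops
print(f"norm + residual vs. one projection: {ratio:.5f}")
```

The ratio works out to 2.5 / d_model, so it shrinks as models get wider, and a real layer contains several such projections.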

Check Yourself

Suppose you remove all residual connections from a 32-layer transformer. What would you expect to happen to the hidden states in the deepest layers?


In a pre-norm transformer block, where does RMSNorm appear?