Multi-Head Attention Runs Several Smaller Attentions in Parallel

Why split attention into multiple heads?

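A single attention head produces only one attention distribution per token: a single weighting over the other positions that says "what to attend to." But language has many simultaneous relationships: syntactic (subject-verb), semantic (pronoun-referent), positional (adjacent words). One set of Q/K/V weights has to blend all of these into that single weighting, so it cannot represent them separately at the same time.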

Multi-head attention splits the Q, K, and V projections into h independent heads, each operating in a smaller subspace. Head 1 might learn to attend to the previous word. Head 2 might track the subject of the sentence. Head 3 might focus on punctuation boundaries. They all run in parallel.

After each head computes its own attention output, the results are concatenated back into one vector and passed through a final output projection (W_O). This is not the same as running one big attention head — each head has its own learned Q/K/V weights and attends to different things.

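In practice, the model does not run h separate projections or launch h separate attention operations. Instead, it performs one large projection for each of Q, K, and V, reshapes the result, and computes attention for all heads in one batched operation: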

  1. Project: Multiply hidden states [n_tokens, d_model] by W_Q [d_model, d_model] to get a full Q tensor of shape [n_tokens, d_model]. K and V are produced the same way with W_K and W_V.
  2. Reshape: View the [n_tokens, d_model] tensor as [n_tokens, n_heads, d_head]. This is a zero-cost operation — the data stays in the same memory, only the indexing changes.
  3. Per-head attention: Each head operates on its own [n_tokens, d_head] slice independently. The attention computation runs in parallel across all heads.
  4. Concat + project: Concatenate all head outputs back to [n_tokens, d_model] and multiply by W_O [d_model, d_model].

The key realization: the "split into heads" is a reshape, not a copy or a separate computation. The columns of W_Q are implicitly partitioned into groups of d_head, and each group corresponds to one head. This is why model code often shows one large matmul followed by a reshape, rather than h small matmuls.
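The column-partition claim can be checked numerically. The following is a minimal NumPy sketch, not taken from any particular model's code: the sizes, the random weights, and the variable names are illustrative assumptions. It compares one large projection followed by a reshape against h small projections that each use one column group of W_Q.

```python
import numpy as np

rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 5, 8, 2      # illustrative sizes (assumption)
d_head = d_model // n_heads

hidden = rng.standard_normal((n_tokens, d_model))
W_Q = rng.standard_normal((d_model, d_model))   # random stand-in for learned weights

# One large projection, then a reshape into heads.
Q = hidden @ W_Q                                  # [n_tokens, d_model]
Q_heads = Q.reshape(n_tokens, n_heads, d_head)    # [n_tokens, n_heads, d_head]

# The reshape is a view onto the same memory: nothing is copied.
assert np.shares_memory(Q, Q_heads)

# Equivalent formulation: h small projections, one per column group of W_Q.
for i in range(n_heads):
    W_Q_i = W_Q[:, i * d_head:(i + 1) * d_head]   # [d_model, d_head]
    assert np.allclose(hidden @ W_Q_i, Q_heads[:, i, :])
```

Both assertions pass, which is the sense in which the split is "just indexing": head i reads columns i × d_head through (i + 1) × d_head − 1 of the full Q tensor.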

Worked example with d_model = 8, n_heads = 2, so d_head = 4:

hidden state per token: [_, _, _, _, _, _, _, _]   (8 dims)
head 1: Q₁, K₁, V₁ each [n_tokens, 4] → attn output₁ [n_tokens, 4]
head 2: Q₂, K₂, V₂ each [n_tokens, 4] → attn output₂ [n_tokens, 4]
concat: [output₁ | output₂] = [n_tokens, 8]
project: concat × W_O → [n_tokens, 8]

Each head works in 4 dimensions instead of 8. The outputs are recombined to full width.

Before split:
Q, K, V projections: [n_tokens, d_model] → [n_tokens, n_heads × d_head]
Per head:
Qᵢ, Kᵢ, Vᵢ: [n_tokens, d_head]
Scores: [n_tokens, n_tokens] per head
Head output: [n_tokens, d_head]
After merge:
Concatenated: [n_tokens, n_heads × d_head]
After W_O: [n_tokens, d_model]

Each head runs independently, then outputs are concatenated and projected:

headᵢ = attention(Qᵢ, Kᵢ, Vᵢ)   // [n_tokens, d_head]
concat = [head₁ | head₂ | ... | headₕ]   // [n_tokens, h × d_head]
output = concat ⋅ W_O   // [n_tokens, d_model]
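The three formulas above translate almost line for line into a per-head loop. This is a reference sketch under stated assumptions: `attention` is taken to be unmasked scaled dot-product attention (the text does not specify masking or scaling), the weights are random stand-ins, and the function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q_i, K_i, V_i):
    """One head. Assumes unmasked scaled dot-product attention."""
    d_head = Q_i.shape[-1]
    scores = Q_i @ K_i.T / np.sqrt(d_head)    # [n_tokens, n_tokens]
    return softmax(scores) @ V_i              # [n_tokens, d_head]

def multi_head_attention(hidden, W_Q, W_K, W_V, W_O, n_heads):
    n_tokens, d_model = hidden.shape
    d_head = d_model // n_heads
    Q, K, V = hidden @ W_Q, hidden @ W_K, hidden @ W_V   # each [n_tokens, d_model]
    heads = []
    for i in range(n_heads):
        cols = slice(i * d_head, (i + 1) * d_head)       # column group of head i
        heads.append(attention(Q[:, cols], K[:, cols], V[:, cols]))
    concat = np.concatenate(heads, axis=-1)              # [n_tokens, h * d_head]
    return concat @ W_O                                  # [n_tokens, d_model]

# Usage with the sizes from the worked example: d_model = 8, n_heads = 2.
rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 3, 8, 2
hidden = rng.standard_normal((n_tokens, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))
output = multi_head_attention(hidden, W_Q, W_K, W_V, W_O, n_heads)
print(output.shape)  # (3, 8)
```

The loop is written for readability rather than speed; it makes explicit that each head sees only its own d_head columns of Q, K, and V.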

where d_head = d_model / n_heads (typically). The total Q/K/V projection parameters equal those of a single-head attention (since h × d_head = d_model), plus the W_O output projection. The difference is not parameter count — it is that each head learns independent attention patterns, letting the model attend to different relationships simultaneously.
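The parameter-count claim is simple arithmetic, sketched below. The helper name is hypothetical, and the count assumes weight matrices only (no bias terms) and d_head = d_model / n_heads.

```python
def attention_param_count(d_model, n_heads):
    """Weight-only parameter counts (bias terms ignored, by assumption)."""
    d_head = d_model // n_heads
    qkv = 3 * d_model * (n_heads * d_head)   # W_Q, W_K, W_V
    w_o = (n_heads * d_head) * d_model       # W_O
    return qkv, w_o

# With d_model = 8, the counts do not depend on the number of heads.
for h in (1, 2, 4, 8):
    print(h, attention_param_count(8, h))    # (192, 64) for every h
```

Because n_heads × d_head always multiplies back to d_model, changing the number of heads changes how the columns are grouped, not how many weights exist.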

In practice, multi-head attention is often implemented as a single large Q/K/V projection followed by a reshape into [n_tokens, n_heads, d_head]. The attention computation then runs per-head (often batched across heads for parallelism). The final concatenation is just a reshape back.
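A batched version of this pattern might look like the NumPy sketch below. It is an illustration under assumptions, not any framework's actual implementation: attention is taken to be unmasked scaled dot-product attention, and the function and helper names are hypothetical. It moves the head axis to the front so that one batched matmul covers all heads.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # stabilize before exponentiating
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_batched(hidden, W_Q, W_K, W_V, W_O, n_heads):
    n_tokens, d_model = hidden.shape
    d_head = d_model // n_heads

    def split_heads(x):
        # [n_tokens, d_model] -> [n_heads, n_tokens, d_head]
        return x.reshape(n_tokens, n_heads, d_head).transpose(1, 0, 2)

    Q, K, V = (split_heads(hidden @ W) for W in (W_Q, W_K, W_V))
    # Batched over the leading head axis: no Python loop over heads.
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_head)   # [n_heads, n_tokens, n_tokens]
    out = softmax(scores) @ V                             # [n_heads, n_tokens, d_head]
    # Merge heads: back to [n_tokens, n_heads, d_head], then flatten the last two axes.
    # This matches concatenating the head outputs side by side. After the transpose,
    # NumPy may need to copy here, unlike the initial split.
    merged = out.transpose(1, 0, 2).reshape(n_tokens, d_model)
    return merged @ W_O                                   # [n_tokens, d_model]

rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 3, 8, 2
d_head = d_model // n_heads
hidden = rng.standard_normal((n_tokens, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))
out = multi_head_attention_batched(hidden, W_Q, W_K, W_V, W_O, n_heads)
print(out.shape)  # (3, 8)
```

Putting the head axis first is one common layout choice because it lets a single batched matmul treat heads like an extra batch dimension.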

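Multi-head attention does not increase the matmul cost compared to a single large head: the projections and the two attention matmuls (Q times K-transpose, and attention weights times V) perform the same number of multiply-adds either way, because h × d_head = d_model. What it changes is how the work is organized: all heads run independently, which maps well to GPU architectures. The smaller d_head reduces the size of per-head Q, K, and V tensors and the cost of each dot product, but the score matrix is always [n_tokens, n_tokens] per head regardless of d_head, so the total number of score entries (and the softmax work over them) grows with the number of heads.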
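To make the accounting concrete, the sketch below counts multiply-adds in the two attention matmuls and the number of score-matrix entries. The helper name and the choice to count only these two matmuls (projections and softmax excluded) are assumptions made for illustration.

```python
def attention_cost(n_tokens, d_model, n_heads):
    """Rough counts for the attention core only (projections excluded)."""
    d_head = d_model // n_heads
    # Q_i @ K_i^T: n_tokens * n_tokens dot products of length d_head, per head.
    score_macs = n_heads * n_tokens * n_tokens * d_head
    # weights_i @ V_i: same shape arithmetic, per head.
    value_macs = n_heads * n_tokens * n_tokens * d_head
    # One [n_tokens, n_tokens] score matrix per head.
    score_entries = n_heads * n_tokens * n_tokens
    return score_macs + value_macs, score_entries

print(attention_cost(10, 8, 1))  # (1600, 100)
print(attention_cost(10, 8, 2))  # (1600, 200)
```

The multiply-add total is unchanged when the heads are split, while the number of score entries doubles when going from one head to two.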

Check Yourself
Q1 (conceptual)

Why is multi-head attention not the same as simply repeating one attention head multiple times?

Q2 (shape)

If d_model = 512 and n_heads = 8, what is d_head? What is the shape of Q for one head with 10 tokens?