Multi-Head Attention Runs Several Smaller Attentions in Parallel
A single attention head produces one attention distribution per token, so it tends to capture only one pattern of "what to attend to." But language has many simultaneous relationships: syntactic (subject-verb), semantic (pronoun-referent), positional (adjacent words). One set of Q/K/V weights struggles to capture all of these at once.
Multi-head attention splits the Q, K, and V projections into h independent heads, each operating in a smaller subspace. Head 1 might learn to attend to the previous word. Head 2 might track the subject of the sentence. Head 3 might focus on punctuation boundaries. They all run in parallel.
After each head computes its own attention output, the results are concatenated back into one vector and passed through a final output projection (W_O). This is not the same as running one big attention head — each head has its own learned Q/K/V weights and attends to different things.
In practice, the model does not run h separate small projections and attention calls. Instead, it performs one large projection, reshapes the result, and batches the attention computation across heads:
- Project: Multiply hidden states [n_tokens, d_model] by W_Q [d_model, d_model] to get a full Q tensor of shape [n_tokens, d_model]. K and V are produced the same way.
- Reshape: View the [n_tokens, d_model] tensor as [n_tokens, n_heads, d_head]. For a contiguous tensor this is essentially free: the data stays in the same memory, and only the indexing changes.
- Per-head attention: Each head operates on its own [n_tokens, d_head] slice independently. The attention computation runs in parallel across all heads.
- Concat + project: Concatenate all head outputs back to [n_tokens, d_model] and multiply by W_O [d_model, d_model].
The key realization: the "split into heads" is a reshape, not a copy or a separate computation. The columns of W_Q are implicitly partitioned into groups of d_head, and each group corresponds to one head. This is why model code often shows one large matmul followed by a reshape, rather than h small matmuls.
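The four steps above can be sketched in NumPy as follows. This is a minimal illustration rather than any particular library's implementation: the function names, the random weights, the 1/sqrt(d_head) scaling, and the absence of masking and biases are all assumptions made for the example. The second function does the same work with h small matmuls on column slices, so the two results can be compared.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max along the axis for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, W_Q, W_K, W_V, W_O, n_heads):
    n_tokens, d_model = x.shape
    d_head = d_model // n_heads

    # Project: one large matmul each for Q, K, V -> [n_tokens, d_model].
    q, k, v = x @ W_Q, x @ W_K, x @ W_V

    # Reshape to [n_tokens, n_heads, d_head], then move heads to the front.
    def split(t):
        return t.reshape(n_tokens, n_heads, d_head).transpose(1, 0, 2)
    q, k, v = split(q), split(k), split(v)

    # Per-head attention, batched over the leading heads axis.
    # Scaling by sqrt(d_head) is assumed (standard scaled dot-product).
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_head)  # [n_heads, n_tokens, n_tokens]
    out = softmax(scores) @ v                            # [n_heads, n_tokens, d_head]

    # Concat + project: back to [n_tokens, d_model], then multiply by W_O.
    out = out.transpose(1, 0, 2).reshape(n_tokens, d_model)
    return out @ W_O

def multi_head_attention_loop(x, W_Q, W_K, W_V, W_O, n_heads):
    # Reference version: h small matmuls on explicit column groups.
    d_head = x.shape[1] // n_heads
    heads = []
    for i in range(n_heads):
        cols = slice(i * d_head, (i + 1) * d_head)
        q, k, v = x @ W_Q[:, cols], x @ W_K[:, cols], x @ W_V[:, cols]
        heads.append(softmax(q @ k.T / np.sqrt(d_head)) @ v)
    return np.concatenate(heads, axis=1) @ W_O

rng = np.random.default_rng(0)
n_tokens, d_model, n_heads = 5, 8, 2
x = rng.standard_normal((n_tokens, d_model))
W_Q, W_K, W_V, W_O = (rng.standard_normal((d_model, d_model)) for _ in range(4))

fast = multi_head_attention(x, W_Q, W_K, W_V, W_O, n_heads)
slow = multi_head_attention_loop(x, W_Q, W_K, W_V, W_O, n_heads)
print(fast.shape)                # (5, 8)
print(np.allclose(fast, slow))   # the two versions agree up to float error
```

The batched version never slices the weight matrices. The reshape alone decides which columns belong to which head.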
Example: with d_model = 8 and n_heads = 2, d_head = 4. Each head works in 4 dimensions instead of 8, and the outputs are recombined to full width.
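The d_model = 8, n_heads = 2 case is small enough to inspect directly. The sketch below uses a stand-in Q tensor filled with the values 0..23 (an arbitrary choice for readability, as is n_tokens = 3) and checks that the reshape is a view onto the same memory and that each head sees one contiguous group of columns.

```python
import numpy as np

d_model, n_heads = 8, 2
d_head = d_model // n_heads   # 4
n_tokens = 3                  # arbitrary small value for the example

# Stand-in for a projected Q tensor of shape [n_tokens, d_model].
q = np.arange(n_tokens * d_model, dtype=np.float64).reshape(n_tokens, d_model)

# View [3, 8] as [3, 2, 4]; no data is copied for a contiguous array.
q_heads = q.reshape(n_tokens, n_heads, d_head)
print(q_heads.shape)                 # (3, 2, 4)
print(np.shares_memory(q, q_heads))  # True

# Head 0 is columns 0..3 of q, head 1 is columns 4..7.
head0_matches = np.array_equal(q_heads[:, 0, :], q[:, :d_head])
head1_matches = np.array_equal(q_heads[:, 1, :], q[:, d_head:])
print(head0_matches, head1_matches)  # True True

# Recombining to full width is the inverse reshape.
recombined = q_heads.reshape(n_tokens, d_model)
print(np.array_equal(recombined, q)) # True
```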
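Each head runs independently, then outputs are concatenated and projected. In the standard formulation:
$$\mathrm{head}_i = \mathrm{Attention}\big(X W_Q^{(i)},\; X W_K^{(i)},\; X W_V^{(i)}\big), \qquad \mathrm{MultiHead}(X) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W_O$$
where X is the [n_tokens, d_model] hidden states, W_Q^{(i)}, W_K^{(i)}, and W_V^{(i)} are the [d_model, d_head] column blocks of W_Q, W_K, and W_V that belong to head i, h = n_heads, and d_head = d_model / n_heads (typically). The total Q/K/V projection parameters equal those of a single-head attention (since h × d_head = d_model), plus the W_O output projection. The difference is not parameter count. It is that each head computes its own attention pattern, letting the model attend to different relationships simultaneously.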
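The parameter-count claim is easy to check with arithmetic. The helper names below are hypothetical, biases are ignored, and d_model = 768 is just an example value. The Q/K/V total comes out the same for any head count that divides d_model.

```python
def qkv_params(d_model, n_heads):
    d_head = d_model // n_heads
    # Each head owns three [d_model, d_head] matrices (Q, K, V); biases ignored.
    return n_heads * 3 * d_model * d_head

def output_params(d_model):
    # W_O is [d_model, d_model] regardless of the number of heads.
    return d_model * d_model

d_model = 768  # example value only
for n_heads in (1, 12, 16):
    print(n_heads, qkv_params(d_model, n_heads), output_params(d_model))
```

Every row prints the same two counts, because n_heads × d_head always multiplies back to d_model.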
In practice, multi-head attention is often implemented as a single large Q/K/V projection followed by a reshape into [n_tokens, n_heads, d_head]. The attention computation then runs per-head (often batched across heads for parallelism). The final concatenation is a reshape back to [n_tokens, d_model], preceded by a transpose if the heads axis was moved to the front for batching.
Multi-head attention does not increase the matrix-multiply compute compared to a single large head: the score and value products need the same number of multiply-adds either way, because h × d_head = d_model. But it does increase parallelism: all heads run independently, which maps well to GPU architectures. The smaller d_head reduces the size of per-head Q, K, and V tensors and the cost of each dot-product computation, but the score matrix itself is always [n_tokens, n_tokens] per head regardless of d_head, so the number of score entries to store and pass through softmax grows with the number of heads.
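A rough count makes the trade-off concrete. The sketch below is a simplified cost model under stated assumptions: it counts only the multiply-adds in the two attention matmuls (scores and weighted values), ignores the projections and the softmax arithmetic itself, and uses a hypothetical helper name and example sizes.

```python
def attention_cost(n_tokens, d_model, n_heads):
    d_head = d_model // n_heads
    # Q K^T: per head, n_tokens x n_tokens dot products of length d_head.
    score_macs = n_heads * n_tokens * n_tokens * d_head
    # weights @ V: per head, n_tokens x d_head outputs, each summing n_tokens terms.
    value_macs = n_heads * n_tokens * d_head * n_tokens
    # Score entries that must be stored and normalized by softmax.
    score_entries = n_heads * n_tokens * n_tokens
    return score_macs + value_macs, score_entries

single = attention_cost(n_tokens=128, d_model=768, n_heads=1)
multi = attention_cost(n_tokens=128, d_model=768, n_heads=12)
print(single[0] == multi[0])   # matmul multiply-adds are equal
print(multi[1] // single[1])   # 12 times as many score entries
```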
Why is multi-head attention not the same as simply repeating one attention head multiple times?
If d_model = 512 and n_heads = 8, what is d_head? What is the shape of Q for one head with 10 tokens?