Multi-head self-attention is a core component of the Transformer architecture, introduced by Vaswani et al. in 2017 in the landmark paper "Attention Is All You Need."[1] Rather than computing a single attention function over the full model dimension, multi-head self-attention splits the queries, keys, and values into multiple parallel "heads," each of which independently computes self-attention in a lower-dimensional subspace. The outputs of all heads are then concatenated and linearly projected to produce the final result. This design allows the model to jointly attend to information from different representation subspaces, capturing diverse linguistic and structural relationships such as syntax, semantics, positional patterns, and coreference.
Multi-head self-attention is used throughout modern deep learning, powering large language models like GPT, BERT, and LLaMA, as well as vision transformers and multimodal systems. Variants such as Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA) have been developed to improve inference efficiency while preserving quality.
Before the Transformer, sequence modeling relied primarily on recurrent neural networks (RNNs) and long short-term memory (LSTM) networks, which processed tokens sequentially. This sequential processing created a bottleneck for parallelization and made it difficult to capture long-range dependencies efficiently. The attention mechanism, initially proposed by Bahdanau et al. (2014) for machine translation, offered a way for models to focus on relevant parts of an input sequence regardless of distance.[2] However, a single attention function computes a weighted average that can only capture one set of relationships at a time, limiting the model's ability to represent the rich, multi-faceted structure of language.
Vaswani et al. addressed this limitation by introducing multi-head attention. Instead of performing a single attention computation over the full model dimension, the mechanism runs multiple attention operations in parallel across different learned subspaces. As stated in the original paper: "Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this."[1] This architectural choice became one of the defining innovations of the Transformer.
The foundation of multi-head self-attention is the scaled dot-product attention function. Given a set of queries (Q), keys (K), and values (V), the mechanism computes:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
where d_k is the dimension of the key vectors. The dot product Q K^T measures the similarity between each query and each key, producing an attention score matrix. The scaling factor 1/sqrt(d_k) prevents the dot products from growing too large in magnitude, which would push the softmax function into regions with extremely small gradients.[1] After applying softmax, the resulting attention weights are used to compute a weighted sum of the value vectors.
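A minimal sketch of this computation, using PyTorch; the tensor shapes and sizes are illustrative assumptions, not tied to any particular model.

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    """Q, K: (n, d_k); V: (n, d_v). Returns an (n, d_v) tensor."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (n, n) similarity matrix
    weights = torch.softmax(scores, dim=-1)            # attention weights; each row sums to 1
    return weights @ V                                 # weighted sum of the value vectors

# Example: a sequence of 5 tokens with d_k = d_v = 64
Q = K = V = torch.randn(5, 64)
out = scaled_dot_product_attention(Q, K, V)            # shape (5, 64)
```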
In the multi-head variant, the input representations are linearly projected h times into different subspaces using separate learned weight matrices. Each projection produces a lower-dimensional set of queries, keys, and values that are processed by an independent attention head:
MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
Here, W_i^Q, W_i^K, and W_i^V are the learned projection matrices for head i, each mapping from the full model dimension d_model to a head dimension d_k = d_model / h. The output projection matrix W^O maps the concatenated head outputs back to d_model.[1]
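A compact sketch of the multi-head computation, assuming d_model = 512 and h = 8 as in the original Transformer. The linear layers W_q, W_k, W_v, and W_o correspond to W^Q, W^K, W^V, and W^O; each full-width projection is equivalent to the h per-head matrices stacked side by side.

```python
import math
import torch
import torch.nn as nn

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model: int = 512, h: int = 8):
        super().__init__()
        assert d_model % h == 0
        self.h, self.d_k = h, d_model // h
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                                # x: (batch, n, d_model)
        b, n, _ = x.shape
        def split(t):                                    # (b, n, d_model) -> (b, h, n, d_k)
            return t.view(b, n, self.h, self.d_k).transpose(1, 2)
        q, k, v = split(self.W_q(x)), split(self.W_k(x)), split(self.W_v(x))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)      # (b, h, n, n)
        out = torch.softmax(scores, dim=-1) @ v                     # (b, h, n, d_k)
        out = out.transpose(1, 2).reshape(b, n, self.h * self.d_k)  # concatenate heads
        return self.W_o(out)                                        # project back to d_model

x = torch.randn(2, 10, 512)
y = MultiHeadSelfAttention()(x)                          # shape (2, 10, 512)
```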
When the queries, keys, and values all come from the same sequence, the mechanism is called self-attention. This is the most common variant in encoder-only models (like BERT) and decoder-only models (like GPT). In encoder-decoder architectures, a variant called cross-attention is also used, where the queries come from the decoder and the keys and values come from the encoder output. Both forms use the same multi-head mechanism, but they differ in the source of their inputs.[1]
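The distinction is only in where the inputs come from, as the following illustrative sketch shows; the tensors and shapes are assumptions for demonstration, not drawn from a specific model.

```python
import torch
import torch.nn as nn

d_model = 512
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

x           = torch.randn(2, 10, d_model)   # a single sequence (self-attention input)
decoder_x   = torch.randn(2, 7,  d_model)   # decoder-side sequence
encoder_out = torch.randn(2, 12, d_model)   # encoder output sequence

# Self-attention: queries, keys, and values all come from the same sequence
q_self, k_self, v_self = W_q(x), W_k(x), W_v(x)

# Cross-attention: queries from the decoder, keys and values from the encoder output
q_cross, k_cross, v_cross = W_q(decoder_x), W_k(encoder_out), W_v(encoder_out)
```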
Language contains many simultaneous types of relationships. When reading a sentence, a model must simultaneously track syntactic structure (subject-verb agreement), semantic relationships (meaning and context), positional patterns (proximity of related words), and coreference (what pronouns refer to). A single attention head computes a single weighted average, forcing all these different relationship types to compete for representation in the same set of attention weights.
Multiple heads address this problem by allowing different heads to specialize in different aspects of the input. Empirical analyses of trained models, discussed in the sections on head pruning and head specialization below, confirm that different heads do learn qualitatively different, often linguistically interpretable functions.
The number of attention heads h and the per-head dimension d_k = d_model / h are important architectural hyperparameters. The total computation and parameter count of multi-head attention remain the same as single-head attention (since d_k * h = d_model), but the distribution across heads affects model quality.
Typical configurations across well-known models are shown in the following table:
| Model | d_model | Number of Heads (h) | Head Dimension (d_k) | Attention Type | Parameters |
|---|---|---|---|---|---|
| Transformer (original) | 512 | 8 | 64 | MHA | 65M |
| BERT-Base | 768 | 12 | 64 | MHA | 110M |
| BERT-Large | 1024 | 16 | 64 | MHA | 340M |
| GPT-2 Small | 768 | 12 | 64 | MHA | 117M |
| GPT-2 XL | 1600 | 25 | 64 | MHA | 1.5B |
| GPT-3 | 12288 | 96 | 128 | MHA | 175B |
| LLaMA 3.1 8B | 4096 | 32 | 128 | GQA (8 KV heads) | 8B |
| LLaMA 3.1 405B | 16384 | 128 | 128 | GQA | 405B |
| DeepSeek-V2 | 5120 | 128 | 128 | MLA | 236B |
The head dimension d_k has remained remarkably stable at 64 or 128 across most architectures, even as model sizes have scaled from millions to hundreds of billions of parameters. Scaling is achieved primarily by increasing the number of heads and the number of layers, rather than by increasing d_k.
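A quick arithmetic check of the d_k = d_model / h relationship for a few rows of the table above (values taken directly from the table):

```python
configs = {
    "Transformer (original)": (512, 8),
    "BERT-Base": (768, 12),
    "GPT-3": (12288, 96),
    "LLaMA 3.1 8B": (4096, 32),
}
for name, (d_model, h) in configs.items():
    print(f"{name}: d_k = {d_model // h}")   # 64, 64, 128, 128
```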
The computational complexity of multi-head self-attention is O(n^2 * d_model), where n is the sequence length and d_model is the model dimension. Each head computes attention with complexity O(n^2 * d_k), and since there are h heads with d_k = d_model / h, the total complexity sums to O(n^2 * d_model). This is the same as single-head attention over the full dimension.[1]
The quadratic dependence on sequence length n is the primary computational bottleneck, particularly for long sequences. This has motivated research into efficient attention variants such as sparse attention, linear attention, and FlashAttention, which aim to reduce the O(n^2) factor while preserving model quality.
The parameter count for a single multi-head attention layer consists of the projection matrices for queries, keys, values, and the output projection: 4 * d_model^2 parameters total (since each of W^Q, W^K, W^V, and W^O has shape d_model x d_model).
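The following small sketch restates these two quantities as code; it counts only the four projection matrices (bias terms ignored) and only the quadratic-in-n score and weighted-sum computations, under the simplified accounting used above.

```python
def attn_params(d_model: int) -> int:
    # W^Q, W^K, W^V, and W^O, each d_model x d_model
    return 4 * d_model * d_model

def attn_score_flops(n: int, d_model: int) -> int:
    # Q K^T and the attention-weighted sum each cost ~n^2 * d_model multiply-adds
    return 2 * n * n * d_model

print(attn_params(768))             # 2,359,296 parameters per layer at BERT-Base width
print(attn_score_flops(512, 768))   # ~4.0e8 multiply-adds for a 512-token sequence
```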
As large language models have scaled to billions of parameters, the memory and compute costs of standard multi-head attention during inference have become a significant bottleneck. Several architectural variants have been proposed to address this.
Multi-Query Attention, introduced by Shazeer (2019), shares a single set of key and value projections across all query heads.[5] While each head still has its own independent query projection, there is only one key projection and one value projection for the entire layer. This dramatically reduces the size of the key-value (KV) cache during autoregressive inference.
| Aspect | Standard MHA | Multi-Query Attention |
|---|---|---|
| Query heads | h independent | h independent |
| Key heads | h independent | 1 shared |
| Value heads | h independent | 1 shared |
| KV cache size | O(n * h * d_k) | O(n * d_k) |
| KV cache reduction | Baseline | ~87.5% (with 8 heads) |
| Quality impact | Baseline | ~2% perplexity increase |
| Inference speedup | Baseline | 1.8-2.4x |
MQA has been adopted by models such as Google's PaLM and Falcon. The quality degradation is small, but for the largest and most capable models, even a small loss can be significant.
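A hedged sketch of the MQA projection layout described above: h independent query heads, but a single key projection and a single value projection shared by all of them. Dimensions follow the notation used earlier and are illustrative.

```python
import math
import torch
import torch.nn as nn

class MultiQueryAttention(nn.Module):
    def __init__(self, d_model: int = 512, h: int = 8):
        super().__init__()
        self.h, self.d_k = h, d_model // h
        self.W_q = nn.Linear(d_model, d_model, bias=False)    # h query heads
        self.W_k = nn.Linear(d_model, self.d_k, bias=False)   # 1 shared key head
        self.W_v = nn.Linear(d_model, self.d_k, bias=False)   # 1 shared value head
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                                     # x: (b, n, d_model)
        b, n, _ = x.shape
        q = self.W_q(x).view(b, n, self.h, self.d_k).transpose(1, 2)  # (b, h, n, d_k)
        k = self.W_k(x).unsqueeze(1)                                   # (b, 1, n, d_k), broadcast over heads
        v = self.W_v(x).unsqueeze(1)                                   # (b, 1, n, d_k)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)         # (b, h, n, n)
        out = torch.softmax(scores, dim=-1) @ v                        # (b, h, n, d_k)
        return self.W_o(out.transpose(1, 2).reshape(b, n, -1))
```

Only the single k and v tensors need to be cached during autoregressive decoding, which is the source of the KV cache reduction in the table above.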
Grouped-Query Attention, proposed by Ainslie et al. (2023), provides a middle ground between standard MHA and MQA.[6] Instead of sharing one KV head across all query heads (MQA) or maintaining one KV head per query head (MHA), GQA divides query heads into g groups, with each group sharing a single KV head.
GQA was rapidly adopted after its publication. Meta used GQA in LLaMA 2 (July 2023) and retained it in LLaMA 3. Mistral AI used GQA in the Mistral 7B model (September 2023). GQA has become the default attention mechanism for most new large language models.
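A hedged sketch of the GQA layout: h query heads share g KV heads, with each KV head repeated to serve h / g query heads. Setting g = 1 recovers MQA and g = h recovers standard MHA; the default sizes below are illustrative.

```python
import math
import torch
import torch.nn as nn

class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int = 4096, h: int = 32, g: int = 8):
        super().__init__()
        assert h % g == 0
        self.h, self.g, self.d_k = h, g, d_model // h
        self.W_q = nn.Linear(d_model, h * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, g * self.d_k, bias=False)   # only g KV heads
        self.W_v = nn.Linear(d_model, g * self.d_k, bias=False)
        self.W_o = nn.Linear(h * self.d_k, d_model, bias=False)

    def forward(self, x):                                          # x: (b, n, d_model)
        b, n, _ = x.shape
        q = self.W_q(x).view(b, n, self.h, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(b, n, self.g, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(b, n, self.g, self.d_k).transpose(1, 2)
        # Each KV head serves h // g query heads
        k = k.repeat_interleave(self.h // self.g, dim=1)
        v = v.repeat_interleave(self.h // self.g, dim=1)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        out = torch.softmax(scores, dim=-1) @ v
        return self.W_o(out.transpose(1, 2).reshape(b, n, -1))
```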
Multi-Head Latent Attention, introduced in the DeepSeek-V2 paper (2024), takes a fundamentally different approach to KV cache reduction.[7] Instead of sharing heads, MLA compresses the key and value representations into a low-dimensional latent vector using low-rank factorization. During inference, the model caches only the compressed latent vector rather than the full key and value tensors. At computation time, the latent vector is decompressed back to full-dimensional keys and values.
This approach achieves dramatic cache reductions. DeepSeek-V2 reported a 93.3% reduction in KV cache size compared to standard MHA, while also boosting maximum generation throughput by 5.76x compared to DeepSeek 67B.[7] Unlike GQA, which sacrifices some model capacity by sharing heads, MLA can actually match or exceed the modeling performance of standard MHA because each head still receives unique key and value representations (decompressed from the shared latent).
MLA also introduces a "decoupled RoPE" technique to maintain compatibility with rotary position embeddings, splitting each attention head into a content component and a positional component.
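The following is a heavily simplified sketch of the MLA idea, covering only the low-rank KV compression; the decoupled-RoPE component and other details of the DeepSeek-V2 design are omitted, and all dimensions (including the latent size d_c) are illustrative assumptions.

```python
import math
import torch
import torch.nn as nn

class LatentKVAttention(nn.Module):
    def __init__(self, d_model: int = 512, h: int = 8, d_c: int = 64):
        super().__init__()
        self.h, self.d_k = h, d_model // h
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_down = nn.Linear(d_model, d_c, bias=False)     # compress to latent (this is what gets cached)
        self.W_up_k = nn.Linear(d_c, d_model, bias=False)     # decompress latent -> per-head keys
        self.W_up_v = nn.Linear(d_c, d_model, bias=False)     # decompress latent -> per-head values
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):                                     # x: (b, n, d_model)
        b, n, _ = x.shape
        latent = self.W_down(x)                               # (b, n, d_c): cached at inference time
        def split(t):
            return t.view(b, n, self.h, self.d_k).transpose(1, 2)
        q = split(self.W_q(x))
        k = split(self.W_up_k(latent))                        # each head gets unique decompressed keys
        v = split(self.W_up_v(latent))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)
        out = torch.softmax(scores, dim=-1) @ v
        return self.W_o(out.transpose(1, 2).reshape(b, n, -1))
```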
| Variant | KV Heads per Layer | KV Cache Size | Quality vs. MHA | Key Adoption |
|---|---|---|---|---|
| Multi-Head Attention (MHA) | h (one per query head) | O(n * h * d_k) | Baseline | Original Transformer, BERT, GPT |
| Multi-Query Attention (MQA) | 1 (shared) | O(n * d_k) | Slightly lower | PaLM, Falcon |
| Grouped-Query Attention (GQA) | g (1 < g < h) | O(n * g * d_k) | Near baseline | LLaMA 2/3, Mistral |
| Multi-Head Latent Attention (MLA) | h (decompressed from latent) | O(n * d_c) where d_c << h*d_k | Comparable or better | DeepSeek-V2, DeepSeek-V3 |
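A back-of-the-envelope comparison of the per-token, per-layer KV cache sizes in the table, using a LLaMA-3.1-8B-style shape (h = 32, d_k = 128, g = 8) and an assumed latent size d_c = 512 for the MLA row; these are illustrative values only.

```python
h, d_k, g, d_c = 32, 128, 8, 512
mha = 2 * h * d_k          # keys + values for every head
mqa = 2 * d_k              # one shared KV head
gqa = 2 * g * d_k          # g shared KV heads
mla = d_c                  # only the compressed latent is cached
print(mha, mqa, gqa, mla)  # 8192, 256, 2048, 512 cached values per token per layer
```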
Not all attention heads contribute equally to model performance. Research has shown that many heads are redundant and can be removed with minimal impact on quality.
Michel et al. (2019) demonstrated that a large percentage of attention heads can be pruned at test time without significantly degrading performance.[8] In BERT, up to 40% of heads could be removed with negligible impact. In machine translation models, 20% of heads could be removed while maintaining reasonable translation quality. The authors developed a greedy pruning algorithm based on sensitivity scores derived from the gradient of the model's loss with respect to each head. Pruning yielded up to a 17.5% increase in inference speed for BERT-based models.
Voita et al. (2019) conducted a detailed analysis of head specialization in the Transformer, finding that only a small subset of heads play important, linguistically interpretable roles.[4] These "specialized heads" handle functions such as attending to adjacent words, tracking syntactic dependencies, and resolving rare or positionally significant tokens. Using a differentiable relaxation of the L0 penalty for pruning, they demonstrated that specialized heads are consistently the last to be pruned. On an English-Russian translation task, 38 out of 48 encoder heads could be pruned with a drop of only 0.15 BLEU.
These findings have practical implications. They suggest that models may be over-parameterized in their number of heads, and that structured pruning of attention heads is a viable approach to model compression and faster inference.
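The greedy procedure can be sketched as follows. This is a simplified stand-in for the method of Michel et al., who score heads with a first-order (gradient-based) sensitivity estimate rather than full re-evaluation; here `eval_loss(masked_heads)` is a hypothetical helper, assumed to run the model with the given heads disabled and return a scalar validation loss.

```python
def greedy_head_pruning(all_heads, eval_loss, budget):
    """Greedily remove `budget` heads, always dropping the head whose
    removal increases the (hypothetical) validation loss the least."""
    pruned = set()
    while len(pruned) < budget:
        candidates = [h for h in all_heads if h not in pruned]
        # Score each remaining head by the loss obtained when it is also masked
        scores = {h: eval_loss(pruned | {h}) for h in candidates}
        pruned.add(min(scores, key=scores.get))
    return pruned
```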
Visualization of attention patterns has become an important tool for understanding what Transformer models learn. Tools such as BertViz allow researchers to inspect the attention weights of individual heads across layers.
Clark et al. (2019) conducted a systematic analysis of BERT's attention heads and identified several recurring specialization patterns:[3]
| Head Role | Description | Example |
|---|---|---|
| Positional heads | Attend to tokens at fixed relative positions (e.g., the previous or next token) | Head attends to the immediately preceding word |
| Syntactic heads | Track grammatical relationships | Head links verbs to their direct objects ("baked" to "cake") |
| Determiner-noun heads | Connect determiners to the nouns they modify | Head links "the" to "dog" |
| Prepositional heads | Attend to objects of prepositions | Head links "in" to "Paris" |
| Coreference heads | Link pronouns to their antecedents | Head connects "she" to "lawyer" |
| Delimiter heads | Attend to special tokens like [SEP] or [CLS] | Head focuses on sentence boundaries |
| Broad attention heads | Distribute attention roughly uniformly across the sequence | Head attends broadly for contextual mixing |
These findings demonstrate that multi-head attention does not simply provide redundant copies of the same computation. Instead, different heads learn qualitatively different functions that correspond to recognized linguistic categories.
Research has also revealed a hierarchical structure in how heads specialize across layers. Lower layers (layers 2-4 in BERT) tend to capture basic grammatical relationships such as noun-verb and determiner-noun links. Middle layers focus on more complex syntactic structures. Upper layers capture higher-level semantic relationships and task-specific patterns.[3]
Multi-head self-attention is a foundational building block across a wide range of applications, including large language models, vision transformers, and multimodal systems.
Imagine you are in a classroom with lots of friends, and you need to figure out what everyone is talking about. Instead of trying to listen to everything yourself, you send out several helper listeners. Each helper pays attention to something different: one helper listens for who is friends with whom, another helper listens for what game everyone wants to play, and a third helper listens for who is sitting next to whom.
After listening, all the helpers come back and tell you what they heard. You combine all their reports to get a really complete picture of what is going on in the classroom.
That is what multi-head self-attention does in a computer. Each "head" is like a helper that focuses on a different kind of relationship between words. One head might notice grammar rules, another might notice word meanings, and another might notice which words are close together. By combining all their findings, the computer understands sentences much better than if it only had one helper listening to everything at once.