# Multi-Head Self-Attention

> Source: https://aiwiki.ai/wiki/multi-head_self-attention
> Updated: 2026-06-21
> Categories: Deep Learning, Machine Learning, Model Architecture, Neural Networks, Transformer Models
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

**Multi-head self-attention** is the core sequence-mixing mechanism of the [Transformer](/wiki/transformer) architecture: it runs several [scaled dot-product attention](/wiki/attention) operations ("heads") in parallel over different learned projections of the same sequence, then concatenates and linearly projects their outputs. It was introduced by [Ashish Vaswani](/wiki/ashish_vaswani), [Noam Shazeer](/wiki/noam_shazeer), and colleagues in the 2017 paper ["Attention Is All You Need"](/wiki/attention_is_all_you_need_transformer), which used 8 parallel heads with a model dimension of d_model = 512 and a per-head dimension of d_k = d_v = 64.[1] Each head computes [self-attention](/wiki/self_attention) independently in its own lower-dimensional subspace, so different heads can capture different relationships, such as syntax, semantics, positional patterns, and coreference. Because each head operates at the reduced dimension d_k = d_model / h, "the total computational cost is similar to that of single-head attention with full dimensionality."[1]

Multi-head self-attention sits at the heart of modern [deep learning](/wiki/deep_learning), powering [large language models](/wiki/large_language_model) like [GPT-4](/wiki/gpt-4), [Claude](/wiki/claude), [Gemini](/wiki/gemini), and [LLaMA](/wiki/llama), as well as vision transformers and multimodal systems. Variants such as [Multi-Query Attention](/wiki/multi_query_attention) (MQA), [Grouped-Query Attention](/wiki/grouped_query_attention) (GQA), and [Multi-Head Latent Attention](/wiki/multi_latent_attention) (MLA) have reshaped how heads share parameters, while engineering work such as [FlashAttention](/wiki/flash_attention) and [PagedAttention](/wiki/paged_attention) has reshaped how the same math is mapped onto modern hardware.

This article focuses on the multi-head structure specifically: why running attention in parallel heads matters, how the heads specialize, how variants modify the head layout to shrink the [KV cache](/wiki/kv_cache), and how production stacks implement it. For the general mechanics of attention as a sequence operation, see the parent article on [self-attention](/wiki/self_attention).

## When was multi-head attention introduced?

Before the Transformer, sequence modeling relied primarily on [recurrent neural networks](/wiki/recurrent_neural_network) (RNNs) and [long short-term memory](/wiki/long_short-term_memory_lstm) (LSTM) networks. Single-head attention, proposed by Bahdanau et al. (2014) for [machine translation](/wiki/machine_translation), let models focus on relevant parts of an input regardless of distance.[2] But a single attention function computes a weighted average that can only capture one set of relationships at a time.

Vaswani et al. addressed this in June 2017 by introducing multi-head attention: multiple attention operations running in parallel across different learned subspaces. As the original paper put it, "Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this."[1] The paper's ablations on the English-German development set (newstest2013) varied the number of heads while keeping the amount of computation constant: single-head attention with d_k = 512 was 0.9 BLEU worse than the best setting, but going to too many heads (32 with d_k = 16) also hurt quality, suggesting that heads stop being useful once their dimension gets too small.[1]

## How does multi-head self-attention work?

### Scaled dot-product attention

The foundation of multi-head self-attention is the scaled dot-product attention function. Given a set of queries (Q), keys (K), and values (V), the mechanism computes:

```
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
```

where d_k is the dimension of the key vectors. The dot product Q K^T measures the similarity between each query and each key, producing an attention score matrix. The scaling factor 1/sqrt(d_k) keeps these scores in a reasonable range: if the components of a query and a key are independent with mean 0 and variance 1, their dot product has variance d_k, and such large values would push the [softmax](/wiki/softmax) function into regions with extremely small gradients.[1] After applying softmax, the resulting attention weights are used to compute a weighted sum of the value vectors.
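
The formula can be sketched in NumPy as follows. This is a minimal, unoptimized illustration with function names of our own choosing (not a library API); it also checks numerically that the dot product of two random vectors with unit-variance components has variance close to d_k, which is what the 1/sqrt(d_k) factor corrects.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the row maximum leaves softmax unchanged but avoids overflow in exp.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v, mask=None):
    """q: (..., n_q, d_k), k: (..., n_k, d_k), v: (..., n_k, d_v).
    mask: optional boolean array, True where attention is allowed."""
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)   # (..., n_q, n_k)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    weights = softmax(scores, axis=-1)                    # each row sums to 1
    return weights @ v, weights                           # output: (..., n_q, d_v)

rng = np.random.default_rng(0)
n, d_k, d_v = 5, 64, 32
q = rng.standard_normal((n, d_k))
k = rng.standard_normal((n, d_k))
v = rng.standard_normal((n, d_v))
out, w = scaled_dot_product_attention(q, k, v)
print(out.shape)   # (5, 32)

# Why scale: unit-variance components give raw dot products with variance about d_k.
a = rng.standard_normal((10_000, d_k))
b = rng.standard_normal((10_000, d_k))
raw = (a * b).sum(axis=-1)
print(raw.var(), (raw / np.sqrt(d_k)).var())   # roughly 64 and 1
```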

The shape arithmetic is worth keeping in mind. For an input of length n with model dimension d_model, Q, K, and V each have shape (n, d_model). The attention score matrix has shape (n, n), and the output has shape (n, d_model). When extended to a batch of B sequences and h heads with d_k = d_model / h, the standard tensor layout is (B, h, n, d_k), and the score matrix is (B, h, n, n). Most of the optimization work in modern attention kernels is about computing these tensors efficiently without materializing the full (B, h, n, n) score matrix in the GPU's high-bandwidth memory (HBM).
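
A quick calculation shows why avoiding that materialization matters. The configuration below (one sequence, 32 heads, 32,768 tokens, 2-byte FP16/BF16 scores) is an illustrative assumption, and the figure covers one layer's scores alone:

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Size of a fully materialized (B, h, n, n) attention score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

size = score_matrix_bytes(batch=1, heads=32, seq_len=32_768)
print(size / 2**30, "GiB")   # 64.0 GiB, for one layer of one sequence
```

Doubling the sequence length quadruples this figure, which is the quadratic term in concrete form.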

### Multi-head mechanism

In the multi-head variant, the input representations are linearly projected h times into different subspaces using separate learned weight matrices. Each projection produces a lower-dimensional set of queries, keys, and values that are processed by an independent attention head:

```
MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) W^O

where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
```

Here, W_i^Q, W_i^K, and W_i^V are the learned projection matrices for head i, each mapping from the full model dimension d_model to a head dimension d_k = d_model / h. The output projection matrix W^O maps the concatenated head outputs back to d_model.[1]

In practice, the per-head projections are fused into one large matrix per role: W_i^Q is simply the i-th block of d_k columns of a single (d_model, d_model) matrix W^Q, and likewise for keys and values. A typical implementation computes Q, K, V each of shape (B, n, d_model), reshapes each to (B, n, h, d_k), and transposes to (B, h, n, d_k) before the attention kernel. Some libraries go further and fuse W^Q, W^K, W^V into a single (d_model, 3 * d_model) matrix so that one matmul produces all three, which can improve [GPU](/wiki/gpu) utilization at small batch sizes.
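
The NumPy sketch below, with illustrative function names, computes the formula both ways: literally, by slicing one column block per head out of each projection, and in the fused layout with a single (d_model, 3 * d_model) matmul followed by a reshape to (B, h, n, d_k). Random weights stand in for trained ones.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def attention(q, k, v):
    return softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])) @ v

def mha_per_head(x, w_qkv, w_o, h):
    """Literal form: head_i = Attention(X W_i^Q, X W_i^K, X W_i^V), then Concat(...) W^O."""
    d_k = x.shape[-1] // h
    w_q, w_k, w_v = np.split(w_qkv, 3, axis=1)        # each (d_model, d_model)
    heads = []
    for i in range(h):
        cols = slice(i * d_k, (i + 1) * d_k)           # W_i^Q is the i-th column block of W^Q
        heads.append(attention(x @ w_q[:, cols], x @ w_k[:, cols], x @ w_v[:, cols]))
    return np.concatenate(heads, axis=-1) @ w_o

def mha_fused(x, w_qkv, w_o, h):
    """Fused form: one matmul for Q, K, V, then reshape to (B, h, n, d_k)."""
    B, n, d_model = x.shape
    d_k = d_model // h

    def split_heads(t):                                # (B, n, d_model) -> (B, h, n, d_k)
        return t.reshape(B, n, h, d_k).transpose(0, 2, 1, 3)

    q, k, v = np.split(x @ w_qkv, 3, axis=-1)          # each (B, n, d_model)
    out = attention(split_heads(q), split_heads(k), split_heads(v))
    out = out.transpose(0, 2, 1, 3).reshape(B, n, d_model)   # concatenate heads
    return out @ w_o

rng = np.random.default_rng(0)
B, n, d_model, h = 2, 6, 512, 8
x = rng.standard_normal((B, n, d_model))
w_qkv = rng.standard_normal((d_model, 3 * d_model)) / np.sqrt(d_model)
w_o = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
print(np.allclose(mha_per_head(x, w_qkv, w_o, h), mha_fused(x, w_qkv, w_o, h)))   # True
```

The two agree, so the fused layout is purely an implementation choice: it replaces h small matmuls with one large one and leaves the math unchanged.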

### Step-by-step process

1. **Linear projection**: The input embedding matrix X is multiplied by three separate weight matrices (W^Q, W^K, W^V) to produce the query, key, and value matrices.
2. **Splitting into heads**: Q, K, and V are split into h sets of lower-dimensional representations. If d_model = 512 and h = 8, each head operates on vectors of dimension d_k = 64.
3. **Independent attention**: Each head independently computes scaled dot-product attention. Because the heads operate in different learned subspaces, they can specialize in different types of relationships.
4. **Concatenation**: The outputs from all h heads are concatenated along the feature dimension, reconstructing a vector of dimension d_model.
5. **Output projection**: A final linear transformation via W^O produces the multi-head attention output, which is then passed to the next layer in the [neural network](/wiki/neural_network).

### Self-attention, cross-attention, and causal masking

When the queries, keys, and values all come from the same sequence, the mechanism is multi-head self-attention. This is the most common variant in encoder-only models (like [BERT](/wiki/bert)) and decoder-only models (like [GPT](/wiki/gpt_generative_pre-trained_transformer)). In encoder-decoder architectures, a sibling form called [cross-attention](/wiki/cross_attention) is also used, where the queries come from the decoder and the keys and values come from the encoder output. Both forms use the same multi-head plumbing but differ in the source of their inputs.[1]

In autoregressive models, multi-head self-attention uses [causal masking](/wiki/causal_attention): scores above the diagonal (future positions) are set to negative infinity so that, after softmax, position i can attend only to positions 1 through i. The same (n, n) mask is applied identically to every head. Because roughly half of the score matrix is masked out, production attention kernels like FlashAttention skip tiles that lie entirely above the diagonal, roughly halving the attention work compared with computing the full square and masking it afterwards.
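
A minimal NumPy sketch of the causal mask, assuming a (h, n, d_k) layout and random inputs: one boolean (n, n) mask broadcasts across all heads, and perturbing the last position's value leaves every earlier output unchanged.

```python
import numpy as np

def causal_attention(q, k, v):
    """q, k, v: (h, n, d_k). The same (n, n) mask broadcasts over all heads."""
    n = q.shape[-2]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    future = np.triu(np.ones((n, n), dtype=bool), k=1)      # True where j > i
    scores = np.where(future, -np.inf, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)    # diagonal keeps every row finite
    w = np.exp(scores)                                       # exp(-inf) == 0 for future slots
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v, w

rng = np.random.default_rng(0)
h, n, d_k = 4, 6, 16
q, k, v = (rng.standard_normal((h, n, d_k)) for _ in range(3))
out, w = causal_attention(q, k, v)

v_changed = v.copy()
v_changed[:, -1] += 100.0                     # perturb only the last position
out_changed, _ = causal_attention(q, k, v_changed)
print(np.allclose(out[:, :-1], out_changed[:, :-1]))   # True: earlier positions never see it
print(int((w[0] > 0).sum()), n * (n + 1) // 2)          # 21 21: about half of n * n
```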

## Why does multi-head attention use multiple heads?

Language contains many simultaneous types of relationships. A model must track syntactic structure (subject-verb agreement), semantic context, positional patterns, and coreference at the same time. A single attention head computes a single weighted average, forcing all these different relationship types to compete for representation in the same set of attention weights.

Multiple heads let different heads specialize. Different heads track syntactic relations (for example, linking verbs to their direct objects), positional patterns (attending to the immediately preceding token), or coreference (connecting pronouns to their antecedents).[3][4] Each head operates in a different learned subspace, which discourages heads from collapsing onto identical patterns. The output projection W^O then learns how to pool the per-head opinions into one combined update to the residual stream.

A useful way to think about it is as an ensemble inside a single layer. Each head is a small attention module with its own opinion about which positions matter. Because the original ablations found 8 heads of dimension 64 outperforming 1 head of dimension 512 at roughly the same compute,[1] the parallel specialization appears to do meaningful work rather than merely adding redundancy.

## How many attention heads does a Transformer use?

The number of attention heads h and the per-head dimension d_k = d_model / h are important architectural hyperparameters. Because d_k * h = d_model, multi-head attention has the same parameter count as single-head attention over the full dimension and nearly the same computation, but how that budget is divided among heads affects model quality.

Typical configurations across well-known models are shown in the following table.

| Model | d_model | Heads (h) | KV heads | Head dim (d_k) | Attention type | Params |
|---|---|---|---|---|---|---|
| Transformer (base) [1] | 512 | 8 | 8 | 64 | MHA | ~65M |
| Transformer (big) [1] | 1024 | 16 | 16 | 64 | MHA | ~213M |
| [BERT](/wiki/bert)-Base | 768 | 12 | 12 | 64 | MHA | 110M |
| [BERT](/wiki/bert)-Large | 1024 | 16 | 16 | 64 | MHA | 340M |
| [GPT-2](/wiki/gpt-2) Small | 768 | 12 | 12 | 64 | MHA | 117M |
| [GPT-2](/wiki/gpt-2) XL | 1600 | 25 | 25 | 64 | MHA | 1.5B |
| [GPT-3](/wiki/gpt-3) | 12288 | 96 | 96 | 128 | MHA | 175B |
| PaLM 540B | 18432 | 48 | 1 | 384 | MQA | 540B |
| Falcon 40B | 8192 | 64 | 8 | 128 | GQA (multi-group MQA) | 40B |
| [Mistral 7B](/wiki/mistral_7b) | 4096 | 32 | 8 | 128 | GQA + sliding window | 7B |
| [LLaMA](/wiki/llama) 2 7B | 4096 | 32 | 32 | 128 | MHA | 7B |
| [LLaMA](/wiki/llama) 2 70B | 8192 | 64 | 8 | 128 | GQA | 70B |
| [LLaMA](/wiki/llama) 3.1 8B | 4096 | 32 | 8 | 128 | GQA | 8B |
| [LLaMA](/wiki/llama) 3.1 70B | 8192 | 64 | 8 | 128 | GQA | 70B |
| [LLaMA](/wiki/llama) 3.1 405B | 16384 | 128 | 8 | 128 | GQA | 405B |
| Qwen2.5 72B | 8192 | 64 | 8 | 128 | GQA | 72B |
| DeepSeek-V2 [7] | 5120 | 128 | shared latent | 128 + 64 | MLA | 236B (21B active) |
| [DeepSeek-V3](/wiki/deepseek_v3) [12] | 7168 | 128 | shared latent | 128 + 64 | MLA | 671B (37B active) |

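The DeepSeek rows reflect MLA's split of each query and key head into a content path (128 dims) and a decoupled [RoPE](/wiki/rope) path (64 dims). MLA also decouples the head dimension from d_model / h: DeepSeek-V3's 128 heads of 128 content dims exceed its d_model of 7168. Instead of conventional per-head K and V tensors, MLA caches a single shared 512-dim latent per token, described in detail below.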

Head dimension d_k has remained at 64 or 128 across most architectures, even as model sizes have scaled from millions to hundreds of billions of parameters; PaLM is a notable exception in the table above. Scaling generally happens by widening d_model and adding heads and layers, not by increasing d_k. This empirical regularity is one of the more underdiscussed conventions in modern LLM design.

## What is the computational complexity of multi-head self-attention?

For a sequence of length n, the attention computation itself (the QK^T scores and the weighted sum over V) costs O(n^2 * d_model). Each head costs O(n^2 * d_k), and since the h heads have d_k * h = d_model, the total matches single-head attention over the full dimension.[1] The four projections add a further O(n * d_model^2). The quadratic dependence on sequence length is the primary bottleneck for long sequences, which motivates research into [sparse attention](/wiki/sparse_attention), [linear attention](/wiki/linear_attention), and FlashAttention.
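
The head-count independence can be checked with a small FLOP counter. This is a rough sketch under stated assumptions: each multiply-add counts as 2 FLOPs, softmax and masking are ignored, and the helper names are our own.

```python
def attention_core_flops(n, d_model, h):
    """FLOPs for QK^T and weights @ V only (2 FLOPs per multiply-add; softmax ignored)."""
    d_k = d_model // h
    scores = 2 * n * n * d_k        # (n x d_k) @ (d_k x n) per head
    mix = 2 * n * n * d_k           # (n x n) @ (n x d_k) per head
    return h * (scores + mix)

def projection_flops(n, d_model):
    """FLOPs for the four d_model x d_model projections (W^Q, W^K, W^V, W^O)."""
    return 4 * 2 * n * d_model * d_model

n, d_model = 4096, 512
for h in (1, 8, 32):
    print(h, attention_core_flops(n, d_model, h))   # identical for every h
print(attention_core_flops(n, d_model, 8) > projection_flops(n, d_model))   # True
```

Under this counting the core term is 4 * n^2 * d_model and the projection term is 8 * n * d_model^2, so the quadratic part dominates once n exceeds about 2 * d_model.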

Ignoring biases, the parameter count for a single multi-head attention layer is 4 * d_model^2: the projection matrices for queries, keys, values, and output. Across an L-layer model, attention contributes 4 * L * d_model^2 parameters, and the position-wise feed-forward network contributes about twice that when d_ff = 4 * d_model. As models have scaled, the cost of KV cache memory at inference has come to dominate, which is what motivates most of the variants discussed below.
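
A small parameter counter makes the 4 * d_model^2 figure and the factor-of-two comparison concrete. The widths below are the Transformer-base values (d_model = 512, d_ff = 2048); the helper names are illustrative.

```python
def attention_params(d_model, bias=False):
    """W^Q, W^K, W^V, W^O, each d_model x d_model (independent of the number of heads)."""
    return 4 * d_model * d_model + (4 * d_model if bias else 0)

def ffn_params(d_model, d_ff, bias=False):
    """Two linear layers: d_model -> d_ff -> d_model."""
    return 2 * d_model * d_ff + ((d_ff + d_model) if bias else 0)

d_model = 512                                   # Transformer-base width
print(attention_params(d_model))                # 1048576
print(ffn_params(d_model, 4 * d_model))         # 2097152, twice the attention count
```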

## How does MHA differ from MQA, GQA, and MLA?

As [large language models](/wiki/large_language_model) have scaled, the cost of standard multi-head attention during inference has become a bottleneck. The cost is dominated by the KV cache: every decoding step loads every cached key and value tensor from HBM, and cache size grows linearly with sequence length and number of heads. Several architectural variants address this by changing how heads share parameters.

### Multi-Query Attention (MQA)

Multi-Query Attention, introduced by [Noam Shazeer](/wiki/noam_shazeer) (2019), shares a single set of K and V projections across all query heads.[5] Each head keeps its own query projection, but there is only one K and one V projection for the entire layer. This reduces KV cache size by a factor equal to the number of heads.

| Aspect | Standard MHA | Multi-Query Attention |
|---|---|---|
| Query heads | h independent | h independent |
| Key/Value heads | h independent | 1 shared each |
| KV cache size | O(n * h * d_k) | O(n * d_k) |
| KV cache reduction (8 heads) | Baseline | ~87.5% |
| Quality impact | Baseline | small perplexity increase, occasional training instability |
| Decoding speedup | Baseline | 1.8x to 2.4x typical |

MQA was adopted by Google's PaLM and several open models such as Falcon and StarCoder. The quality degradation is small on most benchmarks, but for the largest and most capable models even a small loss matters, which is one reason many teams now prefer GQA.

### Grouped-Query Attention (GQA)

Grouped-Query Attention, proposed by Joshua Ainslie and colleagues (2023), is a middle ground between MHA and MQA.[6] Query heads are divided into g groups, and each group shares a single KV head. When g = h, GQA is equivalent to MHA; when g = 1, it is equivalent to MQA. Intermediate values (for example g = 8 with h = 32) give most of the speed benefits of MQA while keeping quality close to MHA.
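
A single NumPy sketch covers all three layouts: h query heads share g KV heads, so g = h is MHA and g = 1 is MQA. The consecutive grouping (query heads 0 to h/g - 1 share KV head 0, and so on) is an assumed convention; the expansion with `np.repeat` is for clarity only, since optimized kernels index the shared head instead of copying it.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def grouped_query_attention(q, k, v):
    """q: (h, n, d_k); k, v: (g, n, d_k) with h divisible by g.
    g == h is standard MHA; g == 1 is MQA."""
    h, g = q.shape[0], k.shape[0]
    if h % g != 0:
        raise ValueError("number of query heads must be a multiple of KV heads")
    # Query heads [i * h//g, (i+1) * h//g) share KV head i. Copying is for clarity;
    # real kernels read the shared head directly so the cache stays small.
    k = np.repeat(k, h // g, axis=0)
    v = np.repeat(v, h // g, axis=0)
    w = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1]))
    return w @ v

rng = np.random.default_rng(0)
h, n, d_k = 32, 10, 128
q = rng.standard_normal((h, n, d_k))
for g in (32, 8, 1):                                   # MHA, GQA, MQA
    k = rng.standard_normal((g, n, d_k))
    v = rng.standard_normal((g, n, d_k))
    out = grouped_query_attention(q, k, v)
    print(g, out.shape, "cached K+V floats:", k.size + v.size)
```

The printed cache sizes scale with g while the output shape stays (h, n, d_k), which is the whole trade-off in one line.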

The Ainslie paper also showed that an existing MHA checkpoint could be "uptrained" into GQA using only about 5% of the original pretraining compute, which made the variant practical to retrofit. Meta used GQA in [LLaMA](/wiki/llama) 2 70B (July 2023) and retained it in LLaMA 3 across all sizes. [Mistral](/wiki/mistral) AI used GQA in [Mistral 7B](/wiki/mistral_7b) (September 2023). GQA has become the default attention mechanism for new LLMs since 2023.
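
The conversion step in the GQA paper mean-pools the key and value projection matrices of the heads in each group before uptraining. The sketch below shows only that pooling for a single projection matrix, assuming the same consecutive head grouping as above; the helper name and sizes are illustrative, and the uptraining itself is not shown.

```python
import numpy as np

def mean_pool_kv_projection(w_kv, h, g):
    """Turn an MHA key or value projection (d_model, h * d_k) into a GQA one
    (d_model, g * d_k) by averaging the h // g heads in each group."""
    d_model, width = w_kv.shape
    d_k = width // h
    grouped = w_kv.reshape(d_model, g, h // g, d_k)   # head index = group * (h // g) + j
    return grouped.mean(axis=2).reshape(d_model, g * d_k)

rng = np.random.default_rng(0)
d_model, h, g, d_k = 16, 8, 2, 4
w_k = rng.standard_normal((d_model, h * d_k))
w_k_gqa = mean_pool_kv_projection(w_k, h, g)
print(w_k_gqa.shape)    # (16, 8): two KV heads of width 4
```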

### Multi-Head Latent Attention (MLA)

Multi-Head Latent Attention, introduced in the DeepSeek-V2 paper (2024) and reused in [DeepSeek-V3](/wiki/deepseek_v3) (2024-2025), takes a different angle on KV cache reduction.[7][12] Instead of sharing heads, MLA jointly compresses the key and value tensors of every head into a single low-dimensional latent vector through a learned down-projection. The model caches only this latent vector; at compute time, an up-projection reconstructs full per-head keys and values on the fly.
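
The core idea can be sketched with three matrices: one shared down-projection whose output is cached, and two up-projections that rebuild per-head keys and values. All sizes below are hypothetical, weights are random, and the sketch omits the decoupled RoPE path, query compression, and the trick described in DeepSeek-V2 of folding the up-projections into the query and output projections at inference.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_k, d_c = 1024, 16, 64, 128   # hypothetical sizes; d_c << 2 * h * d_k

# Learned projections (random here): one shared down-projection, two up-projections.
w_dkv = rng.standard_normal((d_model, d_c)) / np.sqrt(d_model)
w_uk = rng.standard_normal((d_c, h * d_k)) / np.sqrt(d_c)
w_uv = rng.standard_normal((d_c, h * d_k)) / np.sqrt(d_c)

def compress(x):
    """(n, d_model) -> (n, d_c). Only this latent goes into the KV cache."""
    return x @ w_dkv

def expand(c):
    """(n, d_c) -> per-head keys and values, each (h, n, d_k), rebuilt at compute time."""
    n = c.shape[0]
    k = (c @ w_uk).reshape(n, h, d_k).transpose(1, 0, 2)
    v = (c @ w_uv).reshape(n, h, d_k).transpose(1, 0, 2)
    return k, v

x = rng.standard_normal((10, d_model))
latent = compress(x)
k, v = expand(latent)
print(latent.shape, k.shape, v.shape)                      # (10, 128) (16, 10, 64) (16, 10, 64)
print("cached per token:", d_c, "vs MHA:", 2 * h * d_k)    # 128 vs 2048
```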

Compared with the earlier DeepSeek 67B model, DeepSeek-V2 reported a 93.3% reduction in KV cache size and a 5.76x increase in maximum generation throughput; these figures compare whole models, not MLA in isolation.[7] Because every head still receives its own reconstructed K and V, the DeepSeek authors report that MLA can match or exceed the quality of standard MHA, whereas GQA gives up some capacity by sharing K and V within each group.

MLA is non-trivial to combine with [RoPE](/wiki/rope), because rotary position embeddings act on full Q and K tensors and cannot be applied cleanly to a compressed latent. The DeepSeek authors solve this with a "decoupled RoPE" trick: each query and key is split into a content component (carried through the latent compression) and a small positional component (typically 64 dimensions per head) that bypasses the latent and carries RoPE directly. In DeepSeek-V3, the architecture uses 128 attention heads with a 128-dim content path plus a 64-dim RoPE path, and a 512-dim shared latent for K/V across all heads. The effective cache is therefore about 576 dims per token per layer (512 latent plus 64 RoPE key), regardless of head count.[12]

### Sliding window attention

In causal models, sliding window attention restricts each query to the W most recent positions. It bounds per-step KV memory at O(W * h * d_k) regardless of total sequence length, and it can be combined with any of MHA, MQA, GQA, or MLA. The Longformer paper (Beltagy et al. 2020) popularized the idea for encoders by combining a window with a small set of global tokens.

[Mistral 7B](/wiki/mistral_7b) (September 2023) brought sliding window attention to a frontier-quality decoder-only model with a window of 4,096 tokens under a context of 8,192. Because each Transformer layer propagates a window's worth of context, an L-layer stack can route information across roughly L * W tokens even though each individual call sees only W. Mistral combined this with GQA and a custom FlashAttention kernel for the windowed mask.
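
Both the mask and the layer-stacking argument fit in a short NumPy sketch with illustrative helper names. Here the window includes the query itself, so each layer extends the reach by W - 1 positions and L layers reach back L * (W - 1), which is the "roughly L * W" figure above.

```python
import numpy as np

def sliding_window_mask(n, window):
    """True where query i may attend to key j: causal and among the last `window` positions."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j > i - window)

def reach_back(n, window, n_layers):
    """How far back the last position can draw information after stacking n_layers."""
    m = sliding_window_mask(n, window).astype(int)
    reach = np.eye(n, dtype=int)
    for _ in range(n_layers):
        reach = np.minimum(reach @ m, 1)     # compose one more layer of attention
    last = n - 1
    return last - int(np.flatnonzero(reach[last]).min())

print(sliding_window_mask(6, 3).astype(int))
print([reach_back(64, window=4, n_layers=L) for L in (1, 2, 3)])   # [3, 6, 9]
```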

[LLaMA](/wiki/llama) 4 (2025) uses a related design called iRoPE that interleaves three RoPE-equipped local-attention layers with one global-attention layer that drops position embeddings (NoPE) and uses a full causal mask. The pattern lets the model retain global reach while keeping most layers cheap, and it is part of how the LLaMA 4 Scout model claims a 10M-token context length.

### Comparison of head-sharing variants

| Variant | Year | KV heads per layer | KV cache size | Quality vs MHA | Notable adopters |
|---|---|---|---|---|---|
| MHA | 2017 | h (one per query head) | O(n * h * d_k) | Baseline | Original [Transformer](/wiki/transformer), [BERT](/wiki/bert), [GPT-2](/wiki/gpt-2), [GPT-3](/wiki/gpt-3) |
| [MQA](/wiki/multi_query_attention) | 2019 | 1 (shared) | O(n * d_k) | Slightly lower; sometimes unstable | PaLM, Falcon, StarCoder |
| [GQA](/wiki/grouped_query_attention) | 2023 | g (1 < g < h) | O(n * g * d_k) | Near baseline | [LLaMA](/wiki/llama) 2/3/4, [Mistral](/wiki/mistral), Qwen, Mixtral |
| Sliding window | 2020 | h or g, but only over W tokens | O(W * h * d_k) | Slight degradation past window | [Longformer](/wiki/longformer), [Mistral 7B](/wiki/mistral_7b), iRoPE in [LLaMA](/wiki/llama) 4 |
| [MLA](/wiki/multi_latent_attention) | 2024 | shared latent, decompressed at use | O(n * d_c), d_c << h * d_k | Comparable or better | DeepSeek-V2, [DeepSeek-V3](/wiki/deepseek_v3), DeepSeek-R1 |
| Cross-Layer KV sharing | 2024 | shared across multiple layers | reduced by the number of layers sharing each K/V | Small degradation | Character.AI's MQA-CLA, MiniCPM, recent research models |

Cross-Layer Attention (CLA), proposed in 2024, takes a different axis of compression by sharing the KV tensors across adjacent layers rather than across heads within a layer. It can be stacked on top of MQA or GQA for additional cache reduction at modest quality cost.

## FlashAttention and IO-aware kernels

Apart from sliding windows, which cut the cost to O(n * W), the variants above leave the O(n^2) cost of dense attention unchanged; they mainly shrink the KV cache. A complementary line of work has attacked the constants instead, by computing exact attention more efficiently on GPUs.

[FlashAttention](/wiki/flash_attention), introduced by [Tri Dao](/wiki/tri_dao) and colleagues at Stanford in 2022, is the most successful example.[8] On a modern [GPU](/wiki/gpu) like an A100 or H100, standard attention is bounded by memory bandwidth rather than compute: the naive implementation writes the full (n, n) score matrix to HBM, reads it back to compute softmax, writes the probabilities, and reads them again to multiply by V, so most of its time goes to moving O(n^2) intermediate data rather than to arithmetic. FlashAttention reorganizes the computation into tiles that fit in on-chip SRAM. It streams blocks of K and V past blocks of Q, keeps a running softmax maximum and normalizer for each query row, and writes only the final output back to HBM. The score matrix is never materialized, extra memory drops from O(n^2) to O(n), HBM traffic falls substantially, and the result remains mathematically exact.
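
The online-softmax recurrence at the heart of this tiling can be sketched in NumPy. This illustrates the arithmetic only: it does not model SRAM, HBM, or GPU parallelism, it tiles over keys but not queries, and the function names are our own.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=16):
    """Visit K/V in blocks, keeping only a running max (m), running denominator (l),
    and unnormalized output (acc) per query row; no full score matrix is stored."""
    n_q, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((n_q, 1), -np.inf)
    l = np.zeros((n_q, 1))
    acc = np.zeros((n_q, v.shape[-1]))
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T * scale                               # scores for this tile only
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)                     # rescale earlier partial sums
        p = np.exp(s - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ vb
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
q = rng.standard_normal((100, 64))
k = rng.standard_normal((100, 64))
v = rng.standard_normal((100, 64))
print(np.allclose(tiled_attention(q, k, v, block=16), naive_attention(q, k, v)))   # True
```

Whenever a new tile raises the running maximum, the previous sums are rescaled by exp(m_old - m_new), which is why the tiled result is exact rather than approximate.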

### FlashAttention timeline

| Version | Year | Lead author | Target hardware | Key contributions | Reported throughput |
|---|---|---|---|---|---|
| FlashAttention v1 [8] | May 2022 | Tri Dao | Ampere (A100) | IO-aware tiling, online softmax, kernel fusion, no n^2 materialization | 15% wall-clock speedup on BERT-Large; 3x on GPT-2 |
| FlashAttention v2 [9] | July 2023 | Tri Dao | Ampere (A100) | Parallelize over sequence as well as heads/batch, fewer non-matmul FLOPs | ~2x v1, reaches 50-73% of peak FLOPs on A100 |
| FlashAttention v3 [10] | July 2024 | Jay Shah, Tri Dao | Hopper (H100) | Warp-specialized scheduling, async TMA loads, GEMM/softmax overlap, FP8 | 1.5-2x v2 in FP16, ~740 TFLOPs/s; ~1.2 PFLOPs/s with FP8 |

FlashAttention-3 (arXiv 2407.08608, NeurIPS 2024) targets the [H100](/wiki/h100) specifically.[10] On Hopper, the Tensor Memory Accelerator (TMA) and warp specialization let the kernel overlap data movement with computation in ways that were not possible on Ampere. FP8 FlashAttention-3 also uses incoherent processing (block-wise random rotations) to reduce quantization error, achieving 2.6x lower error than a baseline FP8 attention implementation.

FlashAttention-style kernels have become the default in most production [LLM](/wiki/large_language_model) training and inference stacks. FlashAttention is one of the backends that PyTorch's `torch.nn.functional.scaled_dot_product_attention` can dispatch to, and FlashAttention or closely related fused kernels are used in [vLLM](/wiki/vllm), TensorRT-LLM, and the [Hugging Face](/wiki/hugging_face) Transformers library.

### Distributed attention: Ring and Star

Ring Attention (Liu et al. 2023) distributes exact attention across many devices.[18] Each device holds a slice of the sequence, and the K and V blocks rotate around a logical ring so that every device sees every other device's K and V exactly once. Sending the next block overlaps with computing on the current one, the math stays exact, per-device activation memory drops to O(n / devices), and the maximum context length grows roughly linearly with device count. Ring Attention is widely used for training million-token-context models and is typically combined with FlashAttention as the kernel inside each ring step.
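
What makes the ring exact is that partial attention results over different K/V shards can be merged using their log-sum-exp values. The sketch below simulates the devices in a single process, without communication or causal masking; the device count and function names are illustrative.

```python
import numpy as np

def full_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def partial_attention(q, k, v):
    """Attention over one K/V shard: normalized output plus log-sum-exp of the scores."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)
    return (p / l) @ v, m + np.log(l)

def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results exactly, as if one softmax spanned both shards."""
    lse = np.logaddexp(lse_a, lse_b)
    return np.exp(lse_a - lse) * out_a + np.exp(lse_b - lse) * out_b, lse

def simulated_ring_attention(q, k, v, n_devices):
    qs, ks, vs = (np.array_split(t, n_devices) for t in (q, k, v))
    outputs = []
    for dev in range(n_devices):                 # each "device" keeps its own query shard
        out, lse = partial_attention(qs[dev], ks[dev], vs[dev])
        for step in range(1, n_devices):         # K/V shards arrive one ring hop at a time
            src = (dev - step) % n_devices
            o, s = partial_attention(qs[dev], ks[src], vs[src])
            out, lse = merge(out, lse, o, s)
        outputs.append(out)
    return np.concatenate(outputs)

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
print(np.allclose(simulated_ring_attention(q, k, v, n_devices=4), full_attention(q, k, v)))   # True
```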

[NVIDIA](/wiki/nvidia)'s Star Attention (November 2024) targets long-context inference rather than training.[13] It splits a long input into blocks and prepends a shared "anchor block" (the first block of the context) to each shard, so that blocks can be encoded in parallel across many hosts in a first phase. In a second phase, query tokens attend to all of the cached context through global attention, with partial results from each host combined on a designated query host. The anchor blocks soak up the [attention sinks](/wiki/attention_sink) that would otherwise emerge at the start of each independent block, preserving accuracy. Star Attention reports 97-100% of the accuracy of full global attention on RULER and BABILong with substantial inference speedups on Llama 3.1 8B and 70B at sequence lengths up to 1M tokens.

## KV cache and inference

During autoregressive generation, every new token must attend over the entire prefix. Recomputing the keys and values of the whole prefix at every step would make the projection work for n generated tokens grow quadratically in n. The standard remedy is the [KV cache](/wiki/kv_cache): each token's keys and values are computed once and stored in GPU memory, so each step computes only the new token's Q, K, and V, appends the new K and V to the cache, and reads the cached K and V for everything earlier.
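
A single-layer, single-head NumPy sketch (random weights and inputs, illustrative names) shows that incremental decoding with a cache reproduces full causal attention exactly while projecting each token only once. Growing the cache with `np.vstack` is for readability; real systems preallocate or page the cache.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, n = 32, 32, 8
w_q, w_k, w_v = (rng.standard_normal((d_model, d_k)) / np.sqrt(d_model) for _ in range(3))
x = rng.standard_normal((n, d_model))    # stand-in for the hidden states of n tokens

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Full causal attention over the whole sequence (what training computes).
scores = (x @ w_q) @ (x @ w_k).T / np.sqrt(d_k)
scores[np.triu_indices(n, k=1)] = -np.inf
full = softmax(scores) @ (x @ w_v)

# Incremental decoding: project each new token once and append its K and V to the cache.
k_cache = np.empty((0, d_k))
v_cache = np.empty((0, d_k))
steps = []
for t in range(n):
    q_t = x[t] @ w_q                                  # only the new token's query
    k_cache = np.vstack([k_cache, x[t] @ w_k])
    v_cache = np.vstack([v_cache, x[t] @ w_v])
    w_t = softmax(q_t @ k_cache.T / np.sqrt(d_k))     # attends to positions 0..t
    steps.append(w_t @ v_cache)

print(np.allclose(np.stack(steps), full))   # True
```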

With MHA, the cache for one layer at sequence length n holds 2 * h * d_k * n elements (keys plus values), which is 4 KB per token per layer in FP16 when h * d_k = 1024. Multiplied across L layers and B concurrent requests, this adds up quickly: even a GQA model like LLaMA 2 70B (80 layers, 8 KV heads of dimension 128) needs about 10 GiB of HBM for the K and V of a single 32K-token request. The cache also turns generation into a memory-bandwidth problem rather than a compute problem: each decoding step reads the entire cache from HBM, performs a small QK^T and softmax, and writes one new K and V. This is the core reason MQA, GQA, MLA, sliding window, and CLA exist: they all shrink the cache and therefore raise effective tokens per second.
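
A small calculator reproduces these figures. The layer count (80) and head layout are LLaMA 2 70B's; the MLA helper uses the 512 + 64 dims described earlier for DeepSeek-V3; FP16 storage and the helper names are assumptions of this sketch.

```python
def kv_cache_bytes(n_layers, seq_len, n_kv_heads, head_dim, bytes_per_elem=2, batch=1):
    """Standard KV cache: K and V, each (seq_len, n_kv_heads, head_dim), per layer."""
    return 2 * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem * batch

def mla_cache_bytes(n_layers, seq_len, latent_dim=512, rope_dim=64, bytes_per_elem=2, batch=1):
    """MLA-style cache: one shared latent plus one decoupled RoPE key per token per layer."""
    return n_layers * seq_len * (latent_dim + rope_dim) * bytes_per_elem * batch

GiB = 2**30
print(kv_cache_bytes(1, 1, n_kv_heads=8, head_dim=128))         # 4096 bytes: h * d_k = 1024 in FP16
print(kv_cache_bytes(80, 32_768, 8, 128) / GiB)                 # 10.0 GiB with 8 KV heads
print(kv_cache_bytes(80, 32_768, 64, 128) / GiB)                # 80.0 GiB if all 64 heads kept K/V
print(mla_cache_bytes(1, 1) / kv_cache_bytes(1, 1, 128, 128))   # ~0.018 of 128-head MHA per layer
```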

### PagedAttention and vLLM

[PagedAttention](/wiki/paged_attention), introduced by Woosuk Kwon and colleagues at UC Berkeley at SOSP 2023, treats the KV cache like virtual memory.[11] Instead of allocating a contiguous block of HBM per request (which fragments and over-reserves when generation is shorter than the maximum), PagedAttention stores K and V in fixed-size blocks scattered across HBM and maintains a per-request page table mapping logical positions to physical blocks. The attention kernel walks the page table when computing scores.
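
The bookkeeping can be sketched with a toy class. It is hypothetical and far simpler than vLLM (one layer, one head, no freeing, no block sharing), but it shows the essential move: blocks are allocated only when needed from a shared pool, and a per-request page table maps logical positions to physical blocks.

```python
import numpy as np

class PagedKVCache:
    """Toy single-layer, single-head paged KV cache (illustrative; not vLLM's API)."""

    def __init__(self, num_blocks, block_size, head_dim):
        self.block_size = block_size
        self.k_pool = np.zeros((num_blocks, block_size, head_dim))
        self.v_pool = np.zeros((num_blocks, block_size, head_dim))
        self.free = list(range(num_blocks))[::-1]   # stack of unused physical blocks
        self.page_tables = {}                       # request id -> physical block ids
        self.lengths = {}                           # request id -> tokens stored

    def append(self, req, k, v):
        table = self.page_tables.setdefault(req, [])
        pos = self.lengths.get(req, 0)
        if pos % self.block_size == 0:              # last block is full (or none yet)
            table.append(self.free.pop())           # allocate a block only when needed
        block, offset = table[pos // self.block_size], pos % self.block_size
        self.k_pool[block, offset] = k
        self.v_pool[block, offset] = v
        self.lengths[req] = pos + 1

    def gather(self, req):
        """Walk the page table to rebuild the logically contiguous K and V."""
        n, table = self.lengths[req], self.page_tables[req]
        k = self.k_pool[table].reshape(-1, self.k_pool.shape[-1])[:n]
        v = self.v_pool[table].reshape(-1, self.v_pool.shape[-1])[:n]
        return k, v

rng = np.random.default_rng(0)
cache = PagedKVCache(num_blocks=8, block_size=4, head_dim=2)
ka = rng.standard_normal((6, 2))
kb = rng.standard_normal((3, 2))
for t in range(6):                                  # two requests decoding concurrently
    cache.append("a", ka[t], -ka[t])
    if t < 3:
        cache.append("b", kb[t], -kb[t])
print(cache.page_tables)                            # {'a': [0, 2], 'b': [1]}
```

Request "a" ends up in non-adjacent physical blocks 0 and 2, yet `gather` recovers its keys in order; internal waste is bounded by one partially filled block per request.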

PagedAttention powers [vLLM](/wiki/vllm), the most widely used open-source LLM serving system. The Kwon paper reports that prior systems wasted 60-80% of KV cache memory to fragmentation; vLLM's paging brings waste close to zero. Combined with continuous batching (dynamically adding and removing requests from a running batch), vLLM reports 2-4x throughput improvements over earlier systems like FasterTransformer and Orca at the same latency.[11] PagedAttention is now also implemented in TensorRT-LLM, SGLang, and other production stacks.

### KV cache compression and quantization

Several lower-level optimizations reduce KV cache size without changing the architecture:

- **Quantization**: Storing K and V in INT8 or FP8 instead of FP16 halves cache size, and 4-bit formats quarter it. Production stacks routinely run with an FP8 KV cache on Hopper hardware. Research methods (KIVI, AWQ-derived schemes) push to 4-bit or even 2-bit KV with small quality loss.
- **Eviction and compression**: Methods like H2O (Heavy Hitter Oracle), Scissorhands, and StreamingLLM with attention sinks keep only a subset of cached tokens, exploiting the observation that most tokens contribute little to future attention beyond a few "heavy hitter" positions plus an initial sink.
- **Cross-layer sharing (CLA)**: Sharing K and V across adjacent layers. Unlike the items above, this changes the architecture and is normally applied during training.
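
The quantization bullet can be illustrated with a simple symmetric INT8 scheme using one scale per token vector. This is a generic sketch, not the scheme of any particular serving stack or of KIVI; the per-token granularity and FP16 scales are assumptions.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric INT8 with one scale per token vector (last axis)."""
    scale = np.abs(x).max(axis=-1, keepdims=True) / 127.0
    scale = np.where(scale == 0, 1.0, scale)            # guard all-zero vectors
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale.astype(np.float16)                 # scales stored in FP16

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale.astype(np.float32)

rng = np.random.default_rng(0)
k = rng.standard_normal((1024, 128)).astype(np.float16)    # (tokens, head_dim) FP16 keys
k32 = k.astype(np.float32)
q8, s = quantize_int8(k32)
err = np.abs(dequantize_int8(q8, s) - k32).max()
ratio = (q8.nbytes + s.nbytes) / k.nbytes
print(f"max abs error {err:.4f}, size ratio {ratio:.3f}")   # size ratio 0.508
```

The per-token scales add under 1% overhead here, so the cache is close to half its FP16 size.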

### Speculative decoding

[Speculative decoding](/wiki/speculative_decoding) is not a multi-head variant, but it is closely tied to attention and the KV cache. A small draft model proposes several candidate tokens; the larger target model then verifies all of them in a single forward pass, attending over its own KV cache plus the draft tokens with the same multi-head attention machinery. Each accepted draft token saves one sequential forward pass of the target model. Production deployments commonly combine speculative decoding with FlashAttention and PagedAttention, with reported end-to-end speedups typically in the 1.5-3x range.
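
The acceptance rule is easiest to see in the greedy variant, sketched below with hypothetical toy "models" over integer tokens (the sampling variant instead uses a rejection-sampling rule). For clarity the target is called once per position; a real system scores every draft position in one batched target forward pass.

```python
def greedy_speculative_step(prefix, draft_next, target_next, k):
    """One round of greedy speculative decoding over token lists.
    draft_next / target_next map a token list to the next token (hypothetical callables)."""
    draft = []
    for _ in range(k):                               # cheap model proposes k tokens
        draft.append(draft_next(prefix + draft))
    accepted = []
    for tok in draft:
        expected = target_next(prefix + accepted)    # batched in a real system
        if tok != expected:
            return prefix + accepted + [expected]    # first mismatch: take the target's token
        accepted.append(tok)
    return prefix + accepted + [target_next(prefix + accepted)]   # all accepted: bonus token

def target_next(tokens):                             # toy "large" model: count upward mod 10
    return (tokens[-1] + 1) % 10

def draft_next(tokens):                              # toy draft that is wrong after a 3
    return 0 if tokens[-1] == 3 else (tokens[-1] + 1) % 10

print(greedy_speculative_step([1], draft_next, target_next, k=4))    # [1, 2, 3, 4]
print(greedy_speculative_step([5], draft_next, target_next, k=4))    # [5, 6, 7, 8, 9, 0]
```

In the greedy case the output always equals what the target alone would produce; the draft only changes how many target passes are needed.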

## Recent variants and active research

Multi-head attention is not a finished design. Several recent variants change the head structure itself rather than just the sharing pattern.

### Differential attention

[Differential Transformer](/wiki/differential_transformer), introduced by Microsoft Research and Tsinghua in October 2024 (arXiv 2410.05258, ICLR 2025), modifies each head to compute the difference of two softmax attention maps rather than a single softmax.[14] Each query and key is split in half to feed two parallel softmax operations, and the second is subtracted from the first with a learnable scale. The construction acts like noise cancellation: shared low-amplitude patterns cancel out, and only the meaningful signal survives.
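
A single differential head can be sketched in NumPy. Here λ is a fixed scalar for illustration; the paper learns it through a reparameterization and also applies per-head normalization and output scaling, all omitted in this sketch.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def differential_attention(q, k, v, lam):
    """q, k: (n, 2*d); v: (n, d_v). Splits Q and K in half and subtracts two softmax maps."""
    d = q.shape[-1] // 2
    q1, q2 = q[:, :d], q[:, d:]
    k1, k2 = k[:, :d], k[:, d:]
    a1 = softmax(q1 @ k1.T / np.sqrt(d))
    a2 = softmax(q2 @ k2.T / np.sqrt(d))
    diff = a1 - lam * a2          # rows sum to 1 - lam; entries may be negative
    return diff @ v, diff

rng = np.random.default_rng(0)
n, d, d_v = 6, 16, 8
q = rng.standard_normal((n, 2 * d))
k = rng.standard_normal((n, 2 * d))
v = rng.standard_normal((n, d_v))
out, diff = differential_attention(q, k, v, lam=0.8)
print(diff.sum(axis=-1))   # every row sums to 1 - lam = 0.2
```

Any pattern the two maps share is subtracted away, which is the noise-cancellation intuition in the text; with λ = 0 the head reduces to ordinary softmax attention on the first half of the dimensions.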

The paper reports that Diff Transformer outperforms standard Transformer baselines on language modeling at matched FLOPs, with notable gains on long-context retrieval, in-context learning, and hallucination mitigation. The authors keep the parameter and FLOP budget close to vanilla Transformer by halving the head dimension feeding each softmax. Differential attention is a candidate replacement for the inner softmax in any of MHA, GQA, or MLA.

### Attention sinks and length extrapolation

Guangxuan Xiao and colleagues (StreamingLLM, ICLR 2024) found that the first few tokens of a sequence accumulate disproportionately large attention weights, even when they carry no semantic content.[15] They named this the [attention sink](/wiki/attention_sink) and showed that preserving the first few tokens in the KV cache, alongside a sliding window of recent tokens, lets a frozen LLM generate over essentially unbounded contexts (over 4M tokens with Llama-2 in their experiments). Several recent open-weight models train with explicit sink tokens to make long-context streaming inference more stable.
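
The cache policy itself is simple enough to state as a function. The defaults below (4 sink tokens plus a 1,020-token window, for a 1,024-entry cache) are illustrative choices, and the sketch omits StreamingLLM's detail of assigning positions within the cache rather than positions in the original text.

```python
def streaming_keep_indices(cache_len, n_sink=4, window=1020):
    """Token positions kept by a StreamingLLM-style cache: the first n_sink tokens
    (attention sinks) plus the most recent `window` tokens; the rest are evicted."""
    if cache_len <= n_sink + window:
        return list(range(cache_len))
    return list(range(n_sink)) + list(range(cache_len - window, cache_len))

print(streaming_keep_indices(10, n_sink=2, window=3))   # [0, 1, 7, 8, 9]
print(len(streaming_keep_indices(1_000_000)))           # 1024: constant memory
```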

### Position encoding interactions

Multi-head attention is permutation-equivariant, so position information must be added separately. The dominant choice in modern LLMs is [Rotary Position Embedding](/wiki/rope) (RoPE), which rotates each pair of query and key dimensions by an angle proportional to the token's absolute position, with a different frequency for each pair; because rotations compose, the resulting query-key dot product depends only on the relative offset between the two positions. Every head applies the same rotation to its own d_k-dimensional Q and K. MLA had to invent decoupled RoPE (a separate small head dimension reserved for position) precisely because applying RoPE inside a low-rank latent compression is mathematically inconvenient. iRoPE in LLaMA 4 alternates RoPE-equipped local-attention layers with NoPE global-attention layers as a way to manage long contexts without degrading short-range behavior. [ALiBi](/wiki/alibi), used in MPT and BLOOM, instead adds a fixed bias proportional to the query-key distance to the attention scores, with a different slope for each head. See the [self-attention](/wiki/self_attention) parent article for a fuller treatment of position encoding choices.
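
A minimal RoPE sketch in NumPy, assuming the interleaved pairing of dimensions (0, 1), (2, 3), and so on, with base 10000 (some implementations instead pair the first and second halves of the vector). The check at the end shows that the score between a query at position m and a key at position n depends only on m - n.

```python
import numpy as np

def rope(x, positions, base=10000.0):
    """Rotate dimension pairs (0,1), (2,3), ... of x (n, d) by position * base**(-2i/d)."""
    d = x.shape[-1]
    freqs = base ** (-np.arange(0, d, 2) / d)                                  # one per pair
    angles = np.asarray(positions, dtype=float)[:, None] * freqs[None, :]      # (n, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal((1, 64))
k = rng.standard_normal((1, 64))

def score(m, n):
    """Dot product of q placed at position m with k placed at position n."""
    return (rope(q, [m]) @ rope(k, [n]).T).item()

print(np.isclose(score(5, 2), score(105, 102)))   # True: only m - n matters
```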

## Head pruning and importance

Not all attention heads contribute equally. Research has shown that many heads are redundant and can be removed with minimal impact on quality.

Michel et al. ("Are sixteen heads really better than one?", 2019) demonstrated that a large fraction of attention heads can be pruned at test time without significantly degrading performance.[16] In BERT, up to 40% of heads could be removed with negligible impact; in machine translation, 20% of heads could be removed while maintaining reasonable quality. The authors used a greedy pruning algorithm based on gradient-derived sensitivity scores. Pruning yielded up to 17.5% inference speedup on BERT-based models.
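
Michel et al. attach a gate variable to each head's output and score heads by the gradient of the loss with respect to those gates. The sketch below shows only the underlying equivalence that makes pruning pay off: zeroing a head's gate gives the same output as physically deleting its columns of W^Q, W^K, W^V and its rows of W^O. Sizes, weights, and helper names are illustrative.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mha(x, w_q, w_k, w_v, w_o, h, head_mask=None):
    """x: (n, d_model). head_mask: optional (h,) gate multiplying each head's output."""
    n = x.shape[0]
    d_k = w_q.shape[1] // h

    def split_heads(t):                               # (n, h * d_k) -> (h, n, d_k)
        return t.reshape(n, h, d_k).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    heads = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d_k)) @ v
    if head_mask is not None:
        heads = heads * head_mask[:, None, None]
    return heads.transpose(1, 0, 2).reshape(n, h * d_k) @ w_o

def prune_heads(w_q, w_k, w_v, w_o, h, keep):
    """Physically remove heads: drop their columns of W^Q/W^K/W^V and rows of W^O."""
    d_k = w_q.shape[1] // h
    idx = np.concatenate([np.arange(i * d_k, (i + 1) * d_k) for i in keep])
    return w_q[:, idx], w_k[:, idx], w_v[:, idx], w_o[idx, :]

rng = np.random.default_rng(0)
n, d_model, h = 5, 32, 4
x = rng.standard_normal((n, d_model))
w_q, w_k, w_v, w_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))

masked = mha(x, w_q, w_k, w_v, w_o, h, head_mask=np.array([1.0, 1.0, 0.0, 1.0]))
pruned = mha(x, *prune_heads(w_q, w_k, w_v, w_o, h, keep=[0, 1, 3]), h=3)
print(np.allclose(masked, pruned))   # True
```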

Voita et al. (2019) found that only a small subset of heads play important, linguistically interpretable roles.[4] These "specialized heads" attend to adjacent words (positional heads), track syntactic dependencies (syntactic heads), or point to the rarest tokens in a sentence (rare-word heads). Using a differentiable L0 relaxation, they showed that specialized heads were the last to be pruned. On an English-Russian translation task, 38 out of 48 encoder heads could be pruned with a drop of only 0.15 BLEU.

These findings suggest production models may be over-parameterized in their number of heads, which helped motivate the head-sharing variants above: if many heads are redundant, sharing parameters across them is a sensible architectural prior.

## Attention head visualization and interpretability

Visualization tools such as BertViz[21] let researchers inspect the attention weights of individual heads across layers, and such inspection has become a common way to study what Transformer models learn.
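Tools of this kind display one (n, n) weight matrix per head. The sketch below computes those matrices from scratch, assuming small list-based inputs. The d_model columns of the projected queries and keys are split into `num_heads` contiguous slices of width d_k, and each slice gets its own softmax. In practice, Hugging Face Transformers models return these matrices when called with `output_attentions=True`. The helper names here are illustrative.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def per_head_attention_weights(x, w_q, w_k, num_heads):
    """Return one (n x n) attention-weight matrix per head.

    x: n x d_model token representations; w_q, w_k: d_model x d_model projections.
    Head h uses columns [h*d_k, (h+1)*d_k) of the projected queries and keys.
    """
    d_model = len(w_q[0])
    d_k = d_model // num_heads
    q, k = matmul(x, w_q), matmul(x, w_k)
    weights = []
    for h in range(num_heads):
        cols = slice(h * d_k, (h + 1) * d_k)
        q_h = [row[cols] for row in q]
        k_h = [row[cols] for row in k]
        scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k_h]
                  for qi in q_h]
        weights.append([softmax(r) for r in scores])
    return weights
```

Printing each matrix row by row (for example with `f"{p:.2f}"`) gives a crude text version of the per-head heatmaps that BertViz renders interactively.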

### What do individual attention heads learn?

Clark et al. (2019) analyzed [BERT](/wiki/bert)'s attention heads and identified several recurring specialization patterns:[3]

| Head role | Description | Example |
|---|---|---|
| Positional heads | Attend to tokens at fixed relative positions (for example, the previous or next token) | Head attends to the immediately preceding word |
| Syntactic heads | Track grammatical relationships | Head links verbs to their direct objects ("baked" to "cake") |
| Determiner-noun heads | Connect determiners to the nouns they modify | Head links "the" to "dog" |
| Prepositional heads | Attend to objects of prepositions | Head links "in" to "Paris" |
| Coreference heads | Link pronouns to their antecedents | Head connects "she" to "lawyer" |
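| Delimiter heads | Place most of their attention on special tokens such as [SEP] or [CLS] | Head attends almost entirely to [SEP]; Clark et al. suggest this can act as a "no-op" when the head's function does not apply |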
| Broad attention heads | Distribute attention roughly uniformly across the sequence | Head attends broadly for contextual mixing |
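Clark et al. characterized heads partly with simple statistics of their attention maps. These include how much weight goes to the previous token, how much goes to [CLS]/[SEP], and the entropy of each attention distribution. The sketch below computes statistics of that kind for one head's weight matrix and maps them to a role with hypothetical thresholds. The thresholds and the `guess_role` helper are illustrative assumptions, not the paper's procedure.

```python
import math


def head_stats(weights, tokens, special=("[CLS]", "[SEP]")):
    """Summarize one head's (n x n) attention matrix (rows are queries)."""
    n = len(weights)
    # Average weight each query (after the first) puts on the token just before it.
    prev = sum(weights[i][i - 1] for i in range(1, n)) / (n - 1)
    # Average total weight placed on special tokens.
    special_mass = sum(
        sum(w for w, t in zip(row, tokens) if t in special) for row in weights
    ) / n
    # Average entropy of the attention distributions (log(n) means uniform).
    entropy = sum(-sum(p * math.log(p) for p in row if p > 0) for row in weights) / n
    return {"prev_token": prev, "special": special_mass, "entropy": entropy}


def guess_role(stats, n, mass_threshold=0.5, entropy_frac=0.8):
    # Hypothetical decision rule, for illustration only.
    if stats["prev_token"] > mass_threshold:
        return "positional (previous token)"
    if stats["special"] > mass_threshold:
        return "delimiter"
    if stats["entropy"] > entropy_frac * math.log(n):
        return "broad"
    return "other"
```

Syntactic and coreference roles cannot be read off statistics this simple. Clark et al. identified those by comparing where heads attend against annotated dependency and coreference labels.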

These findings indicate that multi-head attention does not simply provide redundant copies of the same computation. Different heads learn qualitatively different functions, many of which correspond to recognized linguistic categories.

### Hierarchical organization and induction heads

Heads also specialize hierarchically across layers. Lower layers in BERT (layers 2-4) capture basic grammatical relationships such as noun-verb and determiner-noun links; middle layers handle complex syntactic structures; upper layers capture higher-level semantic relationships and task-specific patterns.[3]

Research from [Anthropic](/wiki/anthropic) published in the Transformer Circuits Thread identified the "induction head," a two-layer circuit. An earlier "previous-token" head copies information about each token's predecessor into the residual stream. A later head uses that information to attend to the token that followed a previous occurrence of the current token and copies it forward, so that a sequence ...[A][B]...[A] is continued with [B]. Induction heads form during a sharp phase change early in training that coincides with a jump in in-context learning ability. The evidence that they cause that ability is strongest in small attention-only models.[17] They give a concrete mechanistic example of how the multi-head structure supports computation that goes beyond simple weighted averaging.
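The behavior of an idealized induction circuit can be written as a symbolic rule over token lists. This is a sketch of what the two heads compute together, not of how learned weights implement it. The first step plays the role of the previous-token head. The second step plays the induction head: it matches the current token against recorded predecessors and copies the token found there. The function name is hypothetical.

```python
def induction_predict(tokens):
    """Predict the next token as an idealized two-head induction circuit would.

    Returns None when the current token has no earlier occurrence to match.
    """
    # Step 1 ("previous-token head"): each position records the token before it.
    prev_feature = [None] + tokens[:-1]
    current = tokens[-1]
    # Step 2 ("induction head"): from the current position, attend to the most
    # recent earlier position whose recorded predecessor equals the current token,
    # then copy the token at that position.
    for i in range(len(tokens) - 2, 0, -1):
        if prev_feature[i] == current:
            return tokens[i]
    return None
```

For example, `induction_predict(["Harry", "Potter", "went", "to", "see", "Harry"])` returns `"Potter"`. The continuation comes from the context rather than from anything memorized about the tokens, which is why induction heads are tied to in-context learning.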

## Hardware considerations

Multi-head self-attention is increasingly analyzed through a hardware lens. On modern accelerators its cost is often dominated by data movement rather than arithmetic, especially in unfused implementations and during autoregressive decoding.

### Memory hierarchy

A modern [GPU](/wiki/gpu) like an [H100](/wiki/h100) (SXM version) has 80 GB of HBM3 memory at about 3.35 TB/s. Each of its streaming multiprocessors has a 256 KB register file and up to 228 KB of on-chip shared memory (SRAM), with far higher bandwidth than HBM. Naive attention writes the full (n, n) score matrix to HBM and reads it back, in addition to Q, K, and V. FlashAttention instead loads Q, K, and V in tiles that fit in SRAM and uses an online softmax, so the (n, n) matrix is never materialized in HBM.[8] The same logic applies to NVIDIA's Blackwell GPUs and AMD's MI300 accelerators.
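The online softmax that lets FlashAttention avoid storing the score matrix can be sketched for a single query vector. Keys and values are consumed one block at a time. Only a running maximum, a running normalizer, and an unnormalized output vector survive between blocks, and these are rescaled whenever a larger score appears. This shows only the arithmetic. The real kernel also tiles the queries, keeps tiles in SRAM, and runs the products on Tensor Cores. The function names are illustrative.

```python
import math


def naive_attention(q, keys, values):
    """Reference: materialize every score for this query, then softmax."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z for j in range(len(values[0]))]


def tiled_attention(q, keys, values, block_size=2):
    """Same result, computed one key/value block at a time (online softmax)."""
    d_v = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * d_v  # running max, normalizer, unnormalized output
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in k_blk]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescale old statistics to the new max (0.0 on first block)
        p = [math.exp(si - m_new) for si in s]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, v_blk)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]
```

Because the rescaling is exact, the tiled result matches the naive one up to floating-point rounding for any block size. The block size can therefore be chosen purely to fit on-chip memory.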

### Tensor cores and precision

Modern accelerators provide specialized matrix-multiply units (Tensor Cores on NVIDIA GPUs) that run mixed-precision matrix multiplication roughly an order of magnitude faster than general-purpose ALUs. The main cost of multi-head attention is two batched matmuls per head (QK^T and softmax(QK^T)V), so kernels work hard to keep both on Tensor Cores in FP16, BF16, or FP8. The non-matmul parts (softmax, masking, dropout) run outside the Tensor Cores, typically in FP32. Once the matmuls are fully utilized, these parts can become the bottleneck; reducing these non-matmul FLOPs was one of the contributions of FlashAttention-2.[9]
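The arithmetic behind that bottleneck can be checked with a few lines. Per head, QK^T and PV each cost about 2·n²·d_k FLOPs, while the softmax needs one exponential per score. That gives 4·d_k matmul FLOPs per exponential. The throughput figures below (989 TFLOPS FP16 matmul versus 3.9 TFLOPS for special functions such as exp) are approximate H100 SXM5 numbers quoted in the FlashAttention-3 paper.[10] They are used here as assumptions. The sketch also ignores causal-mask savings and any overlap between the two kinds of work.

```python
def attention_flops(seq_len, num_heads, head_dim):
    """Matmul FLOPs for one attention layer's forward pass (no causal-mask savings).

    Per head: QK^T is (n x d_k) @ (d_k x n) -> 2*n*n*d_k FLOPs,
    and P @ V is (n x n) @ (n x d_k) -> 2*n*n*d_k FLOPs.
    """
    return num_heads * 4 * seq_len * seq_len * head_dim


def exp_count(seq_len, num_heads):
    # The softmax evaluates one exponential per attention score.
    return num_heads * seq_len * seq_len


def exp_to_matmul_time_ratio(head_dim, matmul_tflops, exp_tflops):
    """Time spent on exponentials relative to time spent on the matmuls."""
    matmul_flops_per_exp = 4 * head_dim
    return (1.0 / exp_tflops) / (matmul_flops_per_exp / matmul_tflops)
```

With `exp_to_matmul_time_ratio(128, 989.0, 3.9)` the result is about 0.5. Although there are 512 matmul FLOPs per exponential, the exponentials could take roughly half as long as the matmuls if they were not overlapped with them.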

### Decoding versus prefill

A single LLM inference call has two phases. Prefill runs attention over the whole prompt at once with high arithmetic intensity and looks like a training forward pass. Decoding generates tokens one at a time and is typically bounded by HBM bandwidth. Each step must read the model weights and the entire KV cache just to compute attention for one new query vector per sequence. Besides the weights, the bytes read per step grow as (KV cache bytes per token per layer) × (number of layers) × (context length) × (batch size). Step time is roughly that total divided by HBM bandwidth. This is why the head-sharing variants pay off so directly: when the KV cache dominates the bytes read, cutting its size by 4x raises decoding throughput by close to 4x.
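A back-of-the-envelope sketch of a bandwidth-bound decoding step makes the scaling concrete. The model shape used in the usage note is hypothetical: 32 layers, head dimension 128, 32 KV heads for MHA versus 8 for GQA, FP16 storage, and about 16 GB of weights. The 3.35 TB/s bandwidth is roughly an H100 SXM. The estimate ignores compute, activations, and kernel overheads, so it is a lower bound on step time.

```python
def kv_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_elem=2):
    # Factor 2 for K and V, stored for every layer and every KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


def decode_step_seconds(context_len, batch, kv_per_token, weight_bytes, hbm_bytes_per_s):
    """Bandwidth-bound estimate of one decoding step: read all weights once
    plus the whole KV cache of every sequence in the batch."""
    bytes_read = weight_bytes + batch * context_len * kv_per_token
    return bytes_read / hbm_bytes_per_s


mha = kv_bytes_per_token(32, 32, 128)  # 524,288 bytes (512 KiB) per token
gqa = kv_bytes_per_token(32, 8, 128)   # 4x smaller
t_mha = decode_step_seconds(8192, 8, mha, 16e9, 3.35e12)
t_gqa = decode_step_seconds(8192, 8, gqa, 16e9, 3.35e12)
```

With these assumptions, a batch of 8 sequences at 8,192 tokens of context takes about 15 ms per step with MHA and about 7.3 ms with GQA. That is a roughly 2x speedup rather than 4x, because the weights are read every step regardless. Setting `weight_bytes` to 0, as when the cache dominates, recovers the full 4x.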

### Multi-GPU and distributed attention

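For models too large for a single GPU, attention is typically tensor-parallelized along the head dimension: each GPU holds a subset of the query, key, and value heads and computes attention for those heads independently. In the common Megatron-LM layout, each GPU then multiplies its heads' outputs by the matching row slice of the output projection. An all-reduce sums the partial results, which is equivalent to concatenating all heads and applying the full projection. This works cleanly with MHA. It also works with GQA and MQA, provided that KV heads are replicated when there are fewer of them than GPUs. MLA is more delicate because its compressed latent is shared by all heads, so splitting heads across GPUs does not shrink each GPU's KV cache. The DeepSeek-V3 report describes the parallelism layout DeepSeek uses instead.[12] For very long contexts, sequence parallelism such as Ring Attention[18] supplements tensor parallelism by sharding along the sequence dimension.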
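The sketch below covers two parts of head-parallel sharding. The first assigns query heads to ranks and replicates a KV head on every rank that serves its query group. The second checks that summing per-rank partial products with the output projection (the job of the all-reduce) equals projecting the concatenated heads. Both helpers are hypothetical. They assume the tensor-parallel degree divides the number of query heads, and that query head h uses KV head h // (H / G).

```python
def shard_heads(num_q_heads, num_kv_heads, tp_size):
    """Assign query heads, and the KV heads they read, to tensor-parallel ranks."""
    assert num_q_heads % tp_size == 0 and num_q_heads % num_kv_heads == 0
    group = num_q_heads // num_kv_heads  # query heads sharing one KV head
    per_rank = num_q_heads // tp_size
    plan = []
    for r in range(tp_size):
        q_heads = list(range(r * per_rank, (r + 1) * per_rank))
        kv_heads = sorted({h // group for h in q_heads})  # replicated if ranks share a group
        plan.append({"q_heads": q_heads, "kv_heads": kv_heads})
    return plan


def row_parallel_output(head_outputs, w_o, tp_size):
    """head_outputs: per-head vectors of length d_k; w_o: (H*d_k) x d_model.

    Each rank multiplies its own heads by its row slice of W_O; summing the
    partial results stands in for the all-reduce.
    """
    d_k = len(head_outputs[0])
    per_rank = len(head_outputs) // tp_size
    d_model = len(w_o[0])
    total = [0.0] * d_model
    for r in range(tp_size):
        local = [x for h in head_outputs[r * per_rank:(r + 1) * per_rank] for x in h]
        rows = w_o[r * per_rank * d_k:(r + 1) * per_rank * d_k]
        partial = [sum(a * row[j] for a, row in zip(local, rows)) for j in range(d_model)]
        total = [t + p for t, p in zip(total, partial)]
    return total
```

For example, `shard_heads(8, 2, 4)` places KV head 0 on ranks 0 and 1 and KV head 1 on ranks 2 and 3. This replication is the cost GQA and MQA pay when the tensor-parallel degree exceeds the number of KV heads.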

## Implementations

Multi-head attention is the most heavily optimized component in the deep learning software stack. The default options as of 2026 include:

| Stack | Function or class | Notes |
|---|---|---|
| [PyTorch](/wiki/pytorch) | `torch.nn.functional.scaled_dot_product_attention` | Dispatches to FlashAttention v2/v3, memory-efficient attention, math, or cuDNN backends |
| flash-attn (Dao-AILab) | `flash_attn_func` and friends | Reference implementation of FlashAttention 1/2/3 in CUDA and Triton |
| [JAX](/wiki/jax) / [Flax](/wiki/flax) | `flax.linen.MultiHeadDotProductAttention`, `jax.nn.dot_product_attention` | Standard Flax attention module plus a core JAX attention function; custom TPU and GPU attention kernels can be written in Pallas |
| [Hugging Face](/wiki/hugging_face) Transformers | per-model attention classes plus `AttentionInterface` | Lets users swap between SDPA, FlashAttention, eager, and custom backends per model |
| [vLLM](/wiki/vllm) | `Attention` op with PagedAttention kernels | Production inference; handles MHA, GQA, MQA, MLA, sliding window, FP8 KV |
| TensorRT-LLM | fused MHA/GQA plugins | NVIDIA's optimized inference stack with PagedAttention and continuous batching |
| SGLang / FlashInfer | RadixAttention, modular attention kernels | Alternative inference stack with prefix caching; FlashInfer kernels also used by vLLM |

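For MHA, MQA, and GQA, these stacks generally need no separate user-level API: the variant is expressed through the number of key/value heads passed in, and the kernel adapts. For example, `flash_attn_func` accepts K and V with fewer heads than Q, and PyTorch's `scaled_dot_product_attention` has an `enable_gqa` flag. MLA caches a compressed latent rather than per-head keys and values, so inference stacks such as vLLM and SGLang typically run it through dedicated kernels.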
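A slow reference version makes the "variant is set by head shapes" point concrete. The sketch assumes H query heads and G key/value heads, with G dividing H. Query head h reads KV head h // (H / G), the same grouping flash-attn documents for K/V with fewer heads. G = H gives MHA, G = 1 gives MQA, and anything in between is GQA. Plain Python lists stand in for tensors, and the function name is illustrative.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def grouped_query_attention(q_heads, k_heads, v_heads):
    """q_heads: H query matrices (n x d); k_heads, v_heads: G matrices each (n x d).

    G == H is MHA, G == 1 is MQA, 1 < G < H is GQA. No masking (non-causal).
    """
    H, G = len(q_heads), len(k_heads)
    assert H % G == 0, "query heads must divide evenly into KV groups"
    group = H // G
    outputs = []
    for h, q in enumerate(q_heads):
        k, v = k_heads[h // group], v_heads[h // group]  # shared KV head for this group
        d = len(q[0])
        out = []
        for qi in q:
            w = softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k])
            out.append([sum(wj * vj[c] for wj, vj in zip(w, v)) for c in range(len(v[0]))])
        outputs.append(out)
    return outputs
```

Passing a single KV head gives exactly the same result as MHA with that KV head copied H times. The only difference is how much K/V data has to be stored and read, which is the point of the sharing variants.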

## Applications

Multi-head self-attention is a foundational building block in a wide range of applications:

- **Natural language processing**: [Machine translation](/wiki/machine_translation), [text summarization](/wiki/text_summarization), [question answering](/wiki/question_answering), [sentiment analysis](/wiki/sentiment_analysis), and language modeling.
- **Computer vision**: [Vision Transformers](/wiki/vision_transformer) (ViT) apply multi-head self-attention to sequences of image patches for classification, detection, and segmentation.
- **Speech processing**: Models like [Whisper](/wiki/whisper) use it for [speech recognition](/wiki/speech_recognition).
- **Multimodal models**: Systems like [CLIP](/wiki/clip), [GPT-4](/wiki/gpt-4), [Claude](/wiki/claude), and [Gemini](/wiki/gemini) use multi-head attention to align text, images, audio, and video.
- **Protein structure prediction**: AlphaFold 2 and AlphaFold 3 use it in their Evoformer and Pairformer modules.
- **Reinforcement learning and robotics**: Decision Transformer and Robotics Transformer (RT-2) use it over trajectories of states, actions, and observations.
- **Code models**: Codex, AlphaCode, and modern coding assistants like [Claude Opus 4.7](/wiki/claude_opus_4_7) and [GPT-5](/wiki/gpt-5) rely on multi-head attention as their core sequence-mixing primitive.

## What is not multi-head self-attention

- **Multi-head attention** without the "self": the same plumbing applied to cross-attention, where Q comes from one source and K/V from another. The math is identical, but it relates two different sequences rather than a single sequence to itself.
- **Multi-headed mixture of experts**: each token is routed to one or a few of many feed-forward experts. Independent of the attention layer.
- **Convolutional attention**: methods like Conformer mix multi-head attention with [convolutional](/wiki/convolutional_neural_network) layers in a single block but still use standard multi-head attention internally.
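- **Linear and kernel attention**: variants like [Linformer](/wiki/linformer) (which projects keys and values to a fixed, shorter length), [Performer](/wiki/performer) (which approximates the softmax kernel with random features), and Linear Transformer (which replaces the softmax with a positive feature map) reduce the cost to O(n) in sequence length. They keep the multi-head structure but change the per-head computation.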
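The linear-cost trick can be sketched for a single head in the Linear Transformer style (Katharopoulos et al., 2020). The softmax is replaced with the positive feature map phi(x) = elu(x) + 1. The product is then reassociated as phi(Q)(phi(K)^T V), so no n × n matrix is ever formed. The sketch is non-causal and uses plain Python lists. The quadratic reference computes the same kernelized attention the O(n²) way for comparison. Function names are illustrative.

```python
import math


def phi(vec):
    # elu(x) + 1: equals x + 1 for x > 0 and exp(x) otherwise, so it is always positive.
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def linear_attention(queries, keys, values):
    """Non-causal kernelized attention in O(n * d_k * d_v) time."""
    d_v = len(values[0])
    fk = [phi(k) for k in keys]
    d = len(fk[0])
    # Summaries computed once over all keys: S = sum_j phi(k_j) v_j^T and z = sum_j phi(k_j).
    s = [[sum(f[i] * v[c] for f, v in zip(fk, values)) for c in range(d_v)] for i in range(d)]
    z = [sum(f[i] for f in fk) for i in range(d)]
    out = []
    for q in queries:
        fq = phi(q)
        num = [sum(fq[i] * s[i][c] for i in range(d)) for c in range(d_v)]
        den = sum(fq[i] * z[i] for i in range(d))
        out.append([x / den for x in num])
    return out


def quadratic_reference(queries, keys, values):
    """The same kernelized attention computed with explicit n x n weights."""
    out = []
    for q in queries:
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, phi(k))) for k in keys]
        total = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, values)) / total
                    for c in range(len(values[0]))])
    return out
```

Both functions return the same outputs up to rounding. The cost of `linear_attention` grows linearly with the number of tokens because S and z are built once and reused for every query. The result is an approximation of softmax attention, not a reimplementation of it.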

For the broader sequence operation that multi-head self-attention extends, see the parent article on [self-attention](/wiki/self_attention). For history and the original paper, see [Attention Is All You Need](/wiki/attention_is_all_you_need_transformer).

## Explain like I am 5 (ELI5)

Imagine you are in a classroom with lots of friends, and you need to figure out what everyone is talking about. Instead of trying to listen to everything yourself, you send out several helper listeners. Each helper pays attention to something different: one helper listens for who is friends with whom, another listens for what game everyone wants to play, and a third listens for who is sitting next to whom.

After listening, all the helpers come back and tell you what they heard. You combine all their reports to get a complete picture of what is going on in the classroom.

That is what multi-head self-attention does in a computer. Each "head" is like a helper that focuses on a different kind of relationship between words. One head might notice grammar rules, another might notice word meanings, and another might notice which words are close together. By combining all their findings, the computer understands sentences much better than if it only had one helper listening to everything at once.

## References

1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). "Attention Is All You Need." *Advances in Neural Information Processing Systems 30 (NeurIPS 2017)*. arXiv:1706.03762. https://arxiv.org/abs/1706.03762
2. Bahdanau, D., Cho, K., & Bengio, Y. (2014). "Neural Machine Translation by Jointly Learning to Align and Translate." *arXiv preprint arXiv:1409.0473*. https://arxiv.org/abs/1409.0473
3. Clark, K., Khandelwal, U., Levy, O., & Manning, C.D. (2019). "What Does BERT Look At? An Analysis of BERT's Attention." *Proceedings of the 2019 ACL Workshop BlackboxNLP*. arXiv:1906.04341.
4. Voita, E., Talbot, D., Moiseev, F., Sennrich, R., & Titov, I. (2019). "Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned." *Proceedings of the 57th Annual Meeting of the ACL*. arXiv:1905.09418.
5. Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need." *arXiv preprint arXiv:1911.02150*. https://arxiv.org/abs/1911.02150
6. Ainslie, J., Lee-Thorp, J., de Jong, M., Zemlyanskiy, Y., Lebron, F., & Sanghai, S. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints." *EMNLP 2023*. arXiv:2305.13245.
7. DeepSeek-AI. (2024). "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model." *arXiv preprint arXiv:2405.04434*. https://arxiv.org/abs/2405.04434
8. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." *NeurIPS 2022*. arXiv:2205.14135.
9. Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning." arXiv:2307.08691.
10. Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision." *NeurIPS 2024*. arXiv:2407.08608. https://arxiv.org/abs/2407.08608
11. Kwon, W., Li, Z., Zhuang, S., Sheng, Y., Zheng, L., Yu, C.H., Gonzalez, J.E., Zhang, H., & Stoica, I. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention." *SOSP 2023*. arXiv:2309.06180.
12. DeepSeek-AI. (2024). "DeepSeek-V3 Technical Report." *arXiv preprint arXiv:2412.19437*. https://arxiv.org/abs/2412.19437
13. Acharya, S., Jia, F., & Ginsburg, B. (2024). "Star Attention: Efficient LLM Inference over Long Sequences." *arXiv preprint arXiv:2411.17116*.
14. Ye, T., Dong, L., Xia, Y., Sun, Y., Zhu, Y., Huang, G., & Wei, F. (2024). "Differential Transformer." *ICLR 2025*. arXiv:2410.05258. https://arxiv.org/abs/2410.05258
15. Xiao, G., Tian, Y., Chen, B., Han, S., & Lewis, M. (2023). "Efficient Streaming Language Models with Attention Sinks." *ICLR 2024*. arXiv:2309.17453.
16. Michel, P., Levy, O., & Neubig, G. (2019). "Are Sixteen Heads Really Better than One?" *NeurIPS 2019*. arXiv:1905.10650.
17. Olsson, C., Elhage, N., Nanda, N., Joseph, N., et al. (2022). "In-context Learning and Induction Heads." *Transformer Circuits Thread*, Anthropic.
18. Liu, H., Zaharia, M., & Abbeel, P. (2023). "Ring Attention with Blockwise Transformers for Near-Infinite Context." arXiv:2310.01889.
19. Beltagy, I., Peters, M.E., & Cohan, A. (2020). "Longformer: The Long-Document Transformer." arXiv:2004.05150.
20. Jiang, A.Q., Sablayrolles, A., Mensch, A., Bamford, C., Chaplot, D.S., et al. (2023). "Mistral 7B." arXiv:2310.06825.
21. Vig, J. (2019). "A Multiscale Visualization of Attention in the Transformer Model." *ACL 2019: System Demonstrations*. arXiv:1906.05714.
22. Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., & Liu, Y. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding." arXiv:2104.09864.
23. Press, O., Smith, N.A., & Lewis, M. (2021). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation." *ICLR 2022*. arXiv:2108.12409.
24. Dive into Deep Learning. (2023). "Multi-Head Attention." https://d2l.ai/chapter_attention-mechanisms-and-transformers/multihead-attention.html
