Multi-Head Self-Attention
Last reviewed
May 8, 2026
Sources
24 citations
Review status
Source-backed
Revision
v4 · 6,677 words
Multi-head self-attention is a core component of the Transformer architecture, introduced by Ashish Vaswani, Noam Shazeer, and colleagues in 2017 in the paper "Attention Is All You Need".[1] Rather than computing a single attention function over the full model dimension, multi-head self-attention splits the queries, keys, and values into multiple parallel "heads," each of which independently computes self-attention in a lower-dimensional subspace. The outputs of all heads are then concatenated and linearly projected to produce the final result. This design lets the model jointly attend to information from different representation subspaces, capturing diverse linguistic and structural relationships such as syntax, semantics, positional patterns, and coreference.
Multi-head self-attention sits at the heart of modern deep learning, powering large language models like GPT-4, Claude, Gemini, and LLaMA, as well as vision transformers and multimodal systems. Variants such as Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA) have reshaped how heads share parameters, while engineering work such as FlashAttention and PagedAttention has reshaped how the same math is mapped onto modern hardware.
This article focuses on the multi-head structure specifically: why running attention in parallel heads matters, how the heads specialize, how variants modify the head layout to shrink the KV cache, and how production stacks implement it. For the general mechanics of attention as a sequence operation, see the parent article on self-attention.
Before the Transformer, sequence modeling relied primarily on recurrent neural networks (RNNs) and long short-term memory (LSTM) networks. Single-head attention, proposed by Bahdanau et al. (2014) for machine translation, let models focus on relevant parts of an input regardless of distance.[2] But a single attention function computes a weighted average that can only capture one set of relationships at a time.
Vaswani et al. addressed this by introducing multi-head attention: multiple attention operations running in parallel across different learned subspaces. As the original paper put it, "Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this."[1] Ablations in the original paper, run on English-German translation with the amount of computation held constant, showed that the 8-head base model with d_k = 64 outperformed a 1-head model with d_k = 512 by about 0.9 BLEU, even though the two configurations have the same parameter count. Going to too many heads (32 with d_k = 16) also hurt quality, suggesting head dimension has a lower bound below which heads stop being useful.[1]
The foundation of multi-head self-attention is the scaled dot-product attention function. Given a set of queries (Q), keys (K), and values (V), the mechanism computes:
Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
where d_k is the dimension of the key vectors. The dot product Q K^T measures the similarity between each query and each key, producing an attention score matrix. The scaling factor 1/sqrt(d_k) prevents the dot products from growing too large in magnitude, which would push the softmax function into regions with extremely small gradients.[1] After applying softmax, the resulting attention weights are used to compute a weighted sum of the value vectors.
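The formula above can be sketched in a few lines of NumPy. This is a minimal illustration for a single unbatched sequence, not a production kernel; the function names and the toy sizes (n = 5, d_k = 16) are assumptions chosen for the example.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V):
    # Q: (n, d_k), K: (n, d_k), V: (n, d_v)
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)      # (n, n) query-key similarity
    weights = softmax(scores, axis=-1)   # each row is a distribution over keys
    return weights @ V, weights          # weighted sum of values, plus the weights

# Toy inputs (illustrative sizes only).
rng = np.random.default_rng(0)
n, d_k = 5, 16
Q, K, V = (rng.standard_normal((n, d_k)) for _ in range(3))
out, weights = scaled_dot_product_attention(Q, K, V)
```

Each row of `weights` sums to 1, so every output row is a convex combination of the value vectors.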
The shape arithmetic is worth keeping in mind. For an input of length n with model dimension d_model, Q, K, and V each have shape (n, d_model). The attention score matrix has shape (n, n), and the output has shape (n, d_model). When extended to a batch of B sequences and h heads with d_k = d_model / h, the standard tensor layout is (B, h, n, d_k), and the score matrix is (B, h, n, n). Most of the optimization work in modern attention kernels is about computing these tensors efficiently without materializing the full (B, h, n, n) score matrix in HBM.
In the multi-head variant, the input representations are linearly projected h times into different subspaces using separate learned weight matrices. Each projection produces a lower-dimensional set of queries, keys, and values that are processed by an independent attention head:
MultiHead(Q, K, V) = Concat(head_1, head_2, ..., head_h) W^O
where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V)
Here, W_i^Q, W_i^K, and W_i^V are the learned projection matrices for head i, each mapping from the full model dimension d_model to a head dimension d_k = d_model / h. The output projection matrix W^O maps the concatenated head outputs back to d_model.[1]
In practice, the per-head projections are fused into one big matrix per role. A modern implementation has three weight matrices W^Q, W^K, W^V each of shape (d_model, d_model), produces Q, K, V each of shape (B, n, d_model), then reshapes to (B, n, h, d_k) and transposes to (B, h, n, d_k) before the attention kernel. Some libraries fuse W^Q, W^K, W^V into a single (d_model, 3 * d_model) matrix for a single matmul, which improves GPU utilization at small batch sizes.
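The fused layout described above can be sketched as follows. This is an illustrative NumPy version under simplifying assumptions (no biases, no masking, no dropout); the function name, the weight initialization, and the toy sizes are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, W_qkv, W_o, h):
    # x: (B, n, d_model); W_qkv: (d_model, 3 * d_model); W_o: (d_model, d_model)
    B, n, d_model = x.shape
    d_k = d_model // h
    q, k, v = np.split(x @ W_qkv, 3, axis=-1)        # one matmul, then three (B, n, d_model) slices

    def split_heads(t):
        # (B, n, d_model) -> (B, n, h, d_k) -> (B, h, n, d_k)
        return t.reshape(B, n, h, d_k).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_k)   # (B, h, n, n)
    heads = softmax(scores) @ v                            # (B, h, n, d_k)
    concat = heads.transpose(0, 2, 1, 3).reshape(B, n, d_model)  # undo the head split
    return concat @ W_o                                    # (B, n, d_model)

# Toy sizes for illustration.
rng = np.random.default_rng(0)
B, n, d_model, h = 2, 6, 32, 4
x = rng.standard_normal((B, n, d_model))
W_qkv = rng.standard_normal((d_model, 3 * d_model)) / np.sqrt(d_model)
W_o = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)
y = multi_head_self_attention(x, W_qkv, W_o, h)
```

The reshape and transpose cost nothing mathematically; they only reinterpret the fused projection as h independent d_k-dimensional heads.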
When the queries, keys, and values all come from the same sequence, the mechanism is multi-head self-attention. This is the most common variant in encoder-only models (like BERT) and decoder-only models (like GPT). In encoder-decoder architectures, a sibling form called cross-attention is also used, where the queries come from the decoder and the keys and values come from the encoder output. Both forms use the same multi-head plumbing but differ in the source of their inputs.[1]
In autoregressive models, multi-head self-attention is wrapped in causal masking: an upper-triangular mask added to the score matrix sets future positions to negative infinity so that, after softmax, position i can only attend to positions 1 through i. The mask is applied per head identically. Production attention kernels like FlashAttention skip the upper triangle entirely in the causal case, roughly doubling speed in long-context training compared to a naive masked softmax.
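A minimal sketch of the causal mask for one head, using 0-based indices and an explicitly materialized mask (which optimized kernels avoid). The function name is an assumption for illustration.

```python
import numpy as np

def causal_attention_weights(scores):
    # scores: (n, n); entry [i, j] scores query i against key j.
    n = scores.shape[-1]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True where j > i
    masked = np.where(future, -np.inf, scores)           # future positions get -inf
    masked = masked - masked.max(axis=-1, keepdims=True) # diagonal is never masked, so max is finite
    e = np.exp(masked)                                   # exp(-inf) == 0
    return e / e.sum(axis=-1, keepdims=True)

scores = np.random.default_rng(0).standard_normal((4, 4))
w = causal_attention_weights(scores)
```

After the softmax, everything above the diagonal is exactly zero, and the first position can only attend to itself.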
Language contains many simultaneous types of relationships. A model must track syntactic structure (subject-verb agreement), semantic context, positional patterns, and coreference at the same time. A single attention head computes a single weighted average, forcing all these different relationship types to compete for representation in the same set of attention weights.
Multiple heads let different heads specialize. Different heads track syntactic relations (for example, linking verbs to their direct objects), positional patterns (attending to the immediately preceding token), or coreference (connecting pronouns to their antecedents).[3][4] Each head operates in a different learned subspace, which discourages heads from collapsing onto identical patterns. The output projection W^O then learns how to pool the per-head opinions into one combined update to the residual stream.
A useful way to think about it is as an ensemble inside a single layer. Each head is a small attention module with its own opinion about which positions matter. Because 8 heads of dimension 64 outperformed 1 head of dimension 512 in the original ablations,[1] the parallel specialization clearly does meaningful work, even though the total compute is identical.
The number of attention heads h and the per-head dimension d_k = d_model / h are important architectural hyperparameters. The total computation and parameter count of multi-head attention remain the same as single-head attention (since d_k * h = d_model), but the distribution across heads affects model quality.
Typical configurations across well-known models are shown in the following table. Per-head dimension d_k for Q has stayed remarkably stable at 64 or 128 for most architectures, regardless of model scale.
| Model | d_model | Heads (h) | KV heads | Head dim (d_k) | Attention type | Params |
|---|---|---|---|---|---|---|
| Transformer (base) [1] | 512 | 8 | 8 | 64 | MHA | ~65M |
| Transformer (big) [1] | 1024 | 16 | 16 | 64 | MHA | ~213M |
| BERT-Base | 768 | 12 | 12 | 64 | MHA | 110M |
| BERT-Large | 1024 | 16 | 16 | 64 | MHA | 340M |
| GPT-2 Small | 768 | 12 | 12 | 64 | MHA | 117M |
| GPT-2 XL | 1600 | 25 | 25 | 64 | MHA | 1.5B |
| GPT-3 | 12288 | 96 | 96 | 128 | MHA | 175B |
| PaLM 540B | 18432 | 48 | 1 | 256 | MQA | 540B |
| Falcon 40B | 8192 | 128 | 8 | 64 | GQA (multi-group MQA) | 40B |
| Mistral 7B | 4096 | 32 | 8 | 128 | GQA + sliding window | 7B |
| LLaMA 2 7B | 4096 | 32 | 32 | 128 | MHA | 7B |
| LLaMA 2 70B | 8192 | 64 | 8 | 128 | GQA | 70B |
| LLaMA 3.1 8B | 4096 | 32 | 8 | 128 | GQA | 8B |
| LLaMA 3.1 70B | 8192 | 64 | 8 | 128 | GQA | 70B |
| LLaMA 3.1 405B | 16384 | 128 | 8 | 128 | GQA | 405B |
| Qwen2.5 72B | 8192 | 64 | 8 | 128 | GQA | 72B |
| DeepSeek-V2 [7] | 5120 | 128 | shared latent | 128 + 64 | MLA | 236B (21B active) |
| DeepSeek-V3 [12] | 7168 | 128 | shared latent | 128 + 64 | MLA | 671B (37B active) |
The DeepSeek rows reflect MLA's split into a content path (128 dims) and a decoupled RoPE path (64 dims). Instead of conventional per-head K and V tensors, MLA stores a single shared 512-dim latent per token, described in detail below.
Head dimension d_k has remained at 64 or 128 across most architectures, even as model sizes have scaled from millions to hundreds of billions of parameters. Scaling generally happens by increasing the number of heads and the number of layers, not by increasing d_k. This empirical regularity is one of the more underdiscussed conventions in modern LLM design.
The computational complexity of multi-head self-attention is O(n^2 * d_model), where n is the sequence length. Each head computes attention with complexity O(n^2 * d_k), and since h heads with d_k = d_model / h sum to d_model, the total matches single-head attention over the full dimension.[1] The quadratic dependence on sequence length is the primary bottleneck for long sequences, which motivates research into sparse attention, linear attention, and FlashAttention.
The parameter count for a single multi-head attention layer is roughly 4 * d_model^2: the projection matrices for queries, keys, values, and output. Across an L-layer model, attention contributes 4 * L * d_model^2 parameters, with the position-wise feed-forward network contributing about twice that. As models have scaled, the cost of KV cache memory at inference has come to dominate, which is what motivates most of the variants discussed below.
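The parameter arithmetic can be checked with a short calculation. This sketch ignores biases and layer norms and assumes a feed-forward expansion factor of 4; the helper names are hypothetical, and the example uses the BERT-Base width (768) with an assumed depth of 12 layers.

```python
def attention_params(d_model, n_layers):
    # Four d_model x d_model matrices per layer: W^Q, W^K, W^V, W^O (biases ignored).
    return 4 * n_layers * d_model ** 2

def ffn_params(d_model, n_layers, expansion=4):
    # Two matrices per layer: d_model -> expansion * d_model -> d_model (biases ignored).
    return 2 * expansion * n_layers * d_model ** 2

# Assumed example: d_model = 768 and 12 layers.
attn = attention_params(768, 12)
ffn = ffn_params(768, 12)
print(attn, ffn, ffn / attn)
```

With an expansion factor of 4 the feed-forward block holds exactly twice as many weights as attention, which matches the ratio quoted above.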
As large language models have scaled, the cost of standard multi-head attention during inference has become a bottleneck. The cost is dominated by the KV cache: every decoding step loads every cached key and value tensor from HBM, and cache size grows linearly with sequence length and number of heads. Several architectural variants address this by changing how heads share parameters.
Multi-Query Attention, introduced by Noam Shazeer (2019), shares a single set of K and V projections across all query heads.[5] Each head keeps its own query projection, but there is only one K and one V projection for the entire layer. This reduces KV cache size by a factor equal to the number of heads.
| Aspect | Standard MHA | Multi-Query Attention |
|---|---|---|
| Query heads | h independent | h independent |
| Key/Value heads | h independent | 1 shared each |
| KV cache size | O(n * h * d_k) | O(n * d_k) |
| KV cache reduction (8 heads) | Baseline | ~87.5% |
| Quality impact | Baseline | small perplexity increase, occasional training instability |
| Decoding speedup | Baseline | 1.8x to 2.4x typical |
MQA was adopted by Google's PaLM and several open models such as Falcon and StarCoder. The quality degradation is small on most benchmarks, but for the largest and most capable models even a small loss matters, which is why most teams now prefer GQA.
Grouped-Query Attention, proposed by Joshua Ainslie and colleagues (2023), is a middle ground between MHA and MQA.[6] Query heads are divided into g groups, and each group shares a single KV head. When g = h, GQA is equivalent to MHA; when g = 1, it is equivalent to MQA. Intermediate values (for example g = 8 with h = 32) give most of the speed benefits of MQA while keeping quality close to MHA.
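The grouping can be sketched by repeating each KV head across the query heads in its group. This is a simplified single-sequence NumPy illustration with hypothetical names; real kernels avoid physically copying K and V.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    # q: (h, n, d_k) query heads; k, v: (g, n, d_k) KV heads, with h divisible by g.
    h, n, d_k = q.shape
    g = k.shape[0]
    assert h % g == 0, "number of query heads must be a multiple of KV heads"
    repeat = h // g
    k = np.repeat(k, repeat, axis=0)    # (h, n, d_k): consecutive query heads share a KV head
    v = np.repeat(v, repeat, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d_k)      # (h, n, n)
    scores = scores - scores.max(axis=-1, keepdims=True)
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v                                           # (h, n, d_k)

# Toy sizes: 8 query heads sharing 2 KV heads.
rng = np.random.default_rng(0)
h, g, n, d_k = 8, 2, 5, 16
q = rng.standard_normal((h, n, d_k))
k = rng.standard_normal((g, n, d_k))
v = rng.standard_normal((g, n, d_k))
out = grouped_query_attention(q, k, v)
```

Passing k and v with g = h reproduces standard MHA, and g = 1 reproduces MQA; only the cached tensors of shape (g, n, d_k) need to be stored.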
The Ainslie paper also showed that an existing MHA checkpoint could be "uptrained" into GQA using only about 5% of the original pretraining compute, which made the variant practical to retrofit. Meta used GQA in LLaMA 2 70B (July 2023) and retained it in LLaMA 3 across all sizes. Mistral AI used GQA in Mistral 7B (September 2023). GQA has become the default attention mechanism for new LLMs since 2023.
Multi-Head Latent Attention, introduced in the DeepSeek-V2 paper (2024) and reused in DeepSeek-V3 (2024-2025), takes a different angle on KV cache reduction.[7][12] Instead of sharing heads, MLA jointly compresses the key and value tensors of every head into a single low-dimensional latent vector through a learned down-projection. The model caches only this latent vector; at compute time, an up-projection reconstructs full per-head keys and values on the fly.
DeepSeek-V2 reported a 93.3% reduction in KV cache size and a 5.76x increase in maximum generation throughput, both measured against the earlier DeepSeek 67B model.[7] Because every head still receives its own reconstructed K and V, the DeepSeek authors report that MLA can match or exceed the quality of standard MHA, unlike GQA, which gives up some capacity by sharing K and V within each group.
MLA is non-trivial to combine with RoPE, because rotary position embeddings act on full Q and K tensors and cannot be applied cleanly to a compressed latent. The DeepSeek authors solve this with a "decoupled RoPE" trick: each query and key is split into a content component (carried through the latent compression) and a small positional component (typically 64 dimensions per head) that bypasses the latent and carries RoPE directly. In DeepSeek-V3, the architecture uses 128 attention heads with a 128-dim content path plus a 64-dim RoPE path, and a 512-dim shared latent for K/V across all heads. The effective per-token cache is about 576 dims regardless of head count.[12]
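The latent compression idea can be illustrated with a toy sketch. This is a deliberately simplified version that leaves out decoupled RoPE and query compression, uses random weights, and uses small made-up sizes; the variable names are assumptions, not DeepSeek's.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_k, d_c = 64, 4, 16, 8    # toy sizes; d_c is much smaller than h * d_k
n_tokens = 10

W_down = rng.standard_normal((d_model, d_c)) / np.sqrt(d_model)  # shared down-projection
W_up_k = rng.standard_normal((d_c, h * d_k)) / np.sqrt(d_c)      # up-projection to per-head keys
W_up_v = rng.standard_normal((d_c, h * d_k)) / np.sqrt(d_c)      # up-projection to per-head values

x = rng.standard_normal((n_tokens, d_model))
latent_cache = x @ W_down              # (n_tokens, d_c): the only tensor that is cached

# At attention time, per-head keys and values are rebuilt from the latent.
k = (latent_cache @ W_up_k).reshape(n_tokens, h, d_k)
v = (latent_cache @ W_up_v).reshape(n_tokens, h, d_k)

mha_cache_elems = 2 * h * d_k          # per token with standard MHA (K and V for every head)
mla_cache_elems = d_c                  # per token with the latent only
```

In this toy setting the cache shrinks from 128 to 8 elements per token, while each head still gets its own reconstructed K and V.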
Sliding window attention restricts each query to attend only to the W most recent keys for causal models. It bounds per-step KV memory at O(W * h * d_k) regardless of total sequence length, and it stacks with any of MHA, MQA, GQA, or MLA. The Longformer paper (Beltagy et al. 2020) popularized the idea for encoders by combining a window with a small set of global tokens.
Mistral 7B (September 2023) brought sliding window attention to a frontier-quality decoder-only model with a window of 4,096 tokens under a context of 8,192. Because each Transformer layer propagates a window's worth of context, an L-layer stack can route information across roughly L * W tokens even though each individual call sees only W. Mistral combined this with GQA and a custom FlashAttention kernel for the windowed mask.
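A causal sliding-window mask can be sketched as a boolean matrix. This is an illustrative construction with a hypothetical function name; optimized kernels skip the masked blocks instead of building the mask.

```python
import numpy as np

def sliding_window_mask(n, window):
    # allowed[i, j] is True when query i may attend to key j:
    # causal (j <= i) and within the most recent `window` positions (i - j < window).
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (i - j < window)

mask = sliding_window_mask(6, 3)
print(mask.astype(int))
```

Each row has at most `window` allowed keys, so the number of cached keys a decoding step needs stays bounded no matter how long the sequence grows.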
LLaMA 4 (2025) uses a related design called iRoPE that interleaves three RoPE-equipped local-attention layers with one global-attention layer that drops position embeddings (NoPE) and uses a full causal mask. The pattern lets the model retain global reach while keeping most layers cheap, and it is part of how the LLaMA 4 Scout model claims a 10M-token context length.
| Variant | Year | KV heads per layer | KV cache size | Quality vs MHA | Notable adopters |
|---|---|---|---|---|---|
| MHA | 2017 | h (one per query head) | O(n * h * d_k) | Baseline | Original Transformer, BERT, GPT-2, GPT-3 |
| MQA | 2019 | 1 (shared) | O(n * d_k) | Slightly lower; sometimes unstable | PaLM, Falcon, StarCoder |
| GQA | 2023 | g (1 < g < h) | O(n * g * d_k) | Near baseline | LLaMA 2/3/4, Mistral, Qwen, Mixtral |
| Sliding window | 2020 | h or g, but only over W tokens | O(W * h * d_k) | Slight degradation past window | Longformer, Mistral 7B, iRoPE in LLaMA 4 |
| MLA | 2024 | shared latent, decompressed at use | O(n * d_c), d_c << h * d_k | Comparable or better | DeepSeek-V2, DeepSeek-V3, DeepSeek-R1 |
| Cross-Layer KV sharing | 2024 | shared across multiple layers | proportional to layers shared | Small degradation | Character.AI's MQA-CLA, MiniCPM, recent research models |
Cross-Layer Attention (CLA), proposed in 2024, takes a different axis of compression by sharing the KV tensors across adjacent layers rather than across heads within a layer. It can be stacked on top of MQA or GQA for additional cache reduction at modest quality cost.
Apart from sliding window attention, which restricts the attention pattern so that cost grows as O(n * W), none of the variants above change the asymptotic O(n^2) cost of dense attention; they reduce KV cache size. A complementary line of work has attacked the constants instead, by computing exact attention more efficiently on GPUs.
FlashAttention, introduced by Tri Dao and colleagues at Stanford in 2022, is the most successful example.[8] On a modern GPU like an A100 or H100, a naive attention implementation is bounded by memory bandwidth, not compute: it materializes the full (n, n) score matrix in HBM, computes softmax, multiplies by V, and writes back, streaming O(n^2) data through HBM even though the matmuls themselves have plenty of arithmetic intensity. FlashAttention reorganizes the computation into tiled blocks that fit in on-chip SRAM. Each tile loads a block of Q, K, and V from HBM, updates a local output and running softmax statistics in SRAM, and only the final output (plus per-row softmax statistics) is written back. The score matrix is never materialized, so the extra memory footprint drops from O(n^2) to O(n), HBM traffic falls by a factor that grows with the available SRAM, and the algorithm stays mathematically exact.
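The running-softmax idea behind the tiling can be sketched in NumPy. This is only an illustration of the arithmetic, assuming tiling over keys and values alone and a hypothetical block size; it says nothing about SRAM placement, and it is not the FlashAttention kernel itself.

```python
import numpy as np

def reference_attention(Q, K, V):
    # Naive version that materializes the full score matrix.
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=4):
    # Exact attention computed one key/value block at a time, keeping only
    # per-query running statistics instead of the full (n, n) matrix.
    n, d = Q.shape
    out = np.zeros((n, V.shape[1]))
    row_max = np.full(n, -np.inf)    # running max of scores per query
    row_sum = np.zeros(n)            # running softmax denominator per query
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        s = Q @ Kb.T / np.sqrt(d)                  # scores for this tile only
        new_max = np.maximum(row_max, s.max(axis=1))
        scale = np.exp(row_max - new_max)          # rescales earlier partial results
        p = np.exp(s - new_max[:, None])
        row_sum = row_sum * scale + p.sum(axis=1)
        out = out * scale[:, None] + p @ Vb
        row_max = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
tiled = tiled_attention(Q, K, V, block=4)
naive = reference_attention(Q, K, V)
```

The two results agree to floating-point tolerance, which is the sense in which the tiled algorithm is exact rather than approximate.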
| Version | Year | Lead author | Target hardware | Key contributions | Reported throughput |
|---|---|---|---|---|---|
| FlashAttention v1 [8] | May 2022 | Tri Dao | Ampere (A100) | IO-aware tiling, online softmax, kernel fusion, no n^2 materialization | 15% wall-clock speedup on BERT-Large; 3x on GPT-2 |
| FlashAttention v2 [9] | July 2023 | Tri Dao | Ampere (A100) | Parallelize over sequence as well as heads/batch, fewer non-matmul FLOPs | ~2x v1, reaches 50-73% of peak FLOPs on A100 |
| FlashAttention v3 [10] | July 2024 | Jay Shah, Tri Dao | Hopper (H100) | Warp-specialized scheduling, async TMA loads, GEMM/softmax overlap, FP8 | 1.5-2x v2 in FP16, ~740 TFLOPs/s; ~1.2 PFLOPs/s with FP8 |
FlashAttention-3 (arXiv 2407.08608, NeurIPS 2024) targets the H100 specifically.[10] On Hopper, the Tensor Memory Accelerator (TMA) and warp specialization let the kernel overlap data movement with computation in ways that were not possible on Ampere. FP8 FlashAttention-3 also uses incoherent processing (block-wise random rotations) to reduce quantization error, achieving 2.6x lower error than a baseline FP8 attention implementation.
FlashAttention has become the default attention kernel in most production LLM training and inference stacks. It is available as one of the backends of PyTorch's torch.nn.functional.scaled_dot_product_attention and is used by vLLM, TensorRT-LLM, and the Hugging Face Transformers library.
Ring Attention (Liu et al. 2023) distributes exact attention across many devices.[18] Each device holds a slice of the sequence and the K and V tensors rotate around a logical ring, so each device sees every other device's K and V exactly once. The math stays exact, per-device memory drops to O(n / devices), and total throughput grows roughly linearly with device count. Ring Attention is now standard for training million-token-context models and is typically combined with FlashAttention as the kernel inside each ring step.
NVIDIA's Star Attention (November 2024) targets long-context inference rather than training.[13] It splits a long input into blocks and prepends a shared "anchor block" to each shard, so blocks can be encoded in parallel across many hosts in a first phase before a second phase performs global query-side attention on a single host. The anchor blocks soak up the attention sinks that would otherwise emerge at the start of each independent block, preserving accuracy. Star Attention reports 97-100% of the accuracy of full global attention on RULER and BABILong with substantial inference speedups on Llama 3.1 8B and 70B at sequence lengths up to 1M tokens.
During autoregressive generation, every new token requires attending over the entire prefix. Recomputing the keys and values of all n prefix tokens at every step would cost O(n) projections per step and O(n^2) over the whole generation. The standard remedy is the KV cache: keys and values are computed once per token and stored in GPU memory, so each step only computes the new token's Q, K, and V and reads the cached K and V of earlier tokens.
With MHA, the cache for one layer at sequence length n holds roughly 2 * h * d_k * n elements (K plus V), so with h * d_k = 1024 in FP16 it takes 4 KB per token per layer. Across L layers and B concurrent requests this adds up quickly: a 70B GQA model serving a 32K-token request easily uses several gigabytes of HBM just for K and V. The cache also turns generation into a memory-bandwidth problem rather than a compute problem: each decoding step reads the entire cache from HBM, runs a small QK^T and softmax, and writes one new K and V. This is the core reason MQA, GQA, MLA, sliding window, and CLA exist: they all shrink the cache and therefore raise effective tokens per second.
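The cache arithmetic can be made concrete with a small calculator. The 70B-class configuration used here (80 layers, 64 query heads, 8 KV heads, d_k = 128) is an assumption for illustration, and the helper name is hypothetical.

```python
def kv_cache_bytes(n_tokens, n_layers, n_kv_heads, d_k, bytes_per_elem=2, batch=1):
    # The leading factor 2 accounts for storing both K and V.
    return 2 * n_kv_heads * d_k * n_tokens * n_layers * bytes_per_elem * batch

# One token, one layer, FP16, with n_kv_heads * d_k = 1024.
per_token_per_layer = kv_cache_bytes(1, 1, n_kv_heads=8, d_k=128)

# Assumed 70B-class model: 80 layers, d_k = 128, one 32K-token request.
gqa_bytes = kv_cache_bytes(32 * 1024, 80, n_kv_heads=8, d_k=128)
mha_bytes = kv_cache_bytes(32 * 1024, 80, n_kv_heads=64, d_k=128)

print(per_token_per_layer, gqa_bytes / 2**30, mha_bytes / 2**30)
```

Under these assumptions the GQA cache for one request is 10 GiB, and the same model with one KV head per query head would need eight times as much.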
PagedAttention, introduced by Woosuk Kwon and colleagues at UC Berkeley at SOSP 2023, treats the KV cache like virtual memory.[11] Instead of allocating a contiguous block of HBM per request (which fragments and over-reserves when generation is shorter than the maximum), PagedAttention stores K and V in fixed-size blocks scattered across HBM and maintains a per-request page table mapping logical positions to physical blocks. The attention kernel walks the page table when computing scores.
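The page-table bookkeeping can be sketched with a toy allocator. This is a hypothetical illustration of the idea only: the class and method names are made up for the example and are not vLLM's API, and no actual K or V tensors are stored.

```python
class PagedKVCache:
    """Toy block allocator; names and structure are illustrative assumptions."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # pool of physical block ids
        self.page_tables = {}                       # request id -> list of physical block ids
        self.lengths = {}                           # request id -> number of tokens stored

    def append_token(self, request_id):
        table = self.page_tables.setdefault(request_id, [])
        length = self.lengths.get(request_id, 0)
        if length % self.block_size == 0:           # no block yet, or the last block is full
            table.append(self.free_blocks.pop())    # grab any free physical block
        self.lengths[request_id] = length + 1
        # (physical block, offset) where this token's K and V would be written.
        return table[-1], length % self.block_size

    def free(self, request_id):
        # Return all of a finished request's blocks to the pool.
        self.free_blocks.extend(self.page_tables.pop(request_id, []))
        self.lengths.pop(request_id, None)

cache = PagedKVCache(num_blocks=8, block_size=4)
slots = [cache.append_token("a") for _ in range(5)]
```

Five tokens occupy two blocks that need not be adjacent in memory, and at most one partially filled block per request is ever wasted.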
PagedAttention powers vLLM, the most widely used open-source LLM serving system. The Kwon paper reports that prior systems wasted 60-80% of KV cache memory to fragmentation; vLLM's paging brings waste close to zero. Combined with continuous batching (dynamically adding and removing requests from a running batch), vLLM reports 2-4x throughput improvements over earlier systems like FasterTransformer and Orca at the same latency.[11] PagedAttention is now also implemented in TensorRT-LLM, SGLang, and other production stacks.
Several lower-level optimizations reduce KV cache size without changing the architecture:
Speculative decoding is not a multi-head variant but is closely tied to attention because it reuses the KV cache across two models. A small draft model generates several candidate tokens per step; the larger target model verifies them in parallel using the same multi-head attention machinery. Each accepted draft token saves a full forward pass of the target model. Frontier deployments routinely combine speculative decoding with FlashAttention and PagedAttention for end-to-end speedups in the 1.5-3x range.
Multi-head attention is not a finished design. Several recent variants change the head structure itself rather than just the sharing pattern.
Differential Transformer, introduced by Microsoft Research and Tsinghua in October 2024 (arXiv 2410.05258, ICLR 2025), modifies each head to compute the difference of two softmax attention maps rather than a single softmax.[14] Each query and key is split in half to feed two parallel softmax operations, and the second is subtracted from the first with a learnable scale. The construction acts like noise cancellation: shared low-amplitude patterns cancel out, and only the meaningful signal survives.
The paper reports that Diff Transformer outperforms standard Transformer baselines on language modeling at matched FLOPs, with notable gains on long-context retrieval, in-context learning, and hallucination mitigation. The authors keep the parameter and FLOP budget close to vanilla Transformer by halving the head dimension feeding each softmax. Differential attention is a candidate replacement for the inner softmax in any of MHA, GQA, or MLA.
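The subtraction of two softmax maps can be sketched for a single head. This simplified version fixes the scale `lam` as a constant rather than learning it and omits the per-head normalization used in the paper; the function names are assumptions.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def differential_attention(Q, K, V, lam=0.5):
    # Q, K: (n, 2 * d); each is split in half to feed two softmax maps.
    d = Q.shape[-1] // 2
    Q1, Q2 = Q[:, :d], Q[:, d:]
    K1, K2 = K[:, :d], K[:, d:]
    A1 = softmax(Q1 @ K1.T / np.sqrt(d))
    A2 = softmax(Q2 @ K2.T / np.sqrt(d))
    weights = A1 - lam * A2          # can be negative; common-mode patterns cancel
    return weights @ V, weights

rng = np.random.default_rng(0)
n, d = 6, 8
Q = rng.standard_normal((n, 2 * d))
K = rng.standard_normal((n, 2 * d))
V = rng.standard_normal((n, 2 * d))
out, weights = differential_attention(Q, K, V, lam=0.5)
```

Unlike ordinary attention weights, each row here sums to 1 - lam rather than 1, since it is the difference of two normalized maps.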
Guangxuan Xiao and colleagues (StreamingLLM, ICLR 2024) found that the first few tokens of a sequence accumulate disproportionately large attention weights, even when they carry no semantic content.[15] They named this the attention sink and showed that preserving the first few tokens in the KV cache, alongside a sliding window of recent tokens, lets a frozen LLM generate over essentially unbounded contexts (over 4M tokens with Llama-2 in their experiments). Several recent open-weight models train with explicit sink tokens to make long-context streaming inference more stable.
Multi-head attention is permutation-equivariant, so position information must be added separately. The dominant choice in modern LLMs is Rotary Position Embedding (RoPE), which rotates each query and key vector by an angle that depends on its absolute position. Every head applies the same rotation to its own d_k-dimensional Q and K. MLA had to invent decoupled RoPE (a separate small head dimension reserved for position) precisely because applying RoPE inside a low-rank latent compression is mathematically inconvenient. iRoPE in LLaMA 4 alternates RoPE-equipped local-attention layers with NoPE global-attention layers as a way to manage long contexts without degrading short-range behavior. ALiBi, used in MPT and BLOOM, adds a fixed linear bias to attention scores per head instead of rotating Q and K, with a different slope per head. See the self-attention parent article for a fuller treatment of position encoding choices.
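A short sketch of RoPE for a single vector shows why rotating by absolute position still yields relative-position behaviour in the dot product. It assumes the interleaved-pair convention and the commonly used base of 10000; the function name is hypothetical.

```python
import numpy as np

def rope(x, position, base=10000.0):
    # x: (d,) float vector with even d. Consecutive pairs (x[0], x[1]), (x[2], x[3]), ...
    # are each rotated by an angle proportional to `position`.
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)   # one rotation frequency per pair
    angle = position * theta
    cos, sin = np.cos(angle), np.sin(angle)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q = rng.standard_normal(8)
k = rng.standard_normal(8)

# Same offset (2) at different absolute positions.
score_near = rope(q, 3) @ rope(k, 1)
score_far = rope(q, 10) @ rope(k, 8)
```

The two scores match because the query-key dot product depends only on the difference between the two positions.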
Not all attention heads contribute equally. Research has shown that many heads are redundant and can be removed with minimal impact on quality.
Michel et al. ("Are sixteen heads really better than one?", 2019) demonstrated that a large fraction of attention heads can be pruned at test time without significantly degrading performance.[16] In BERT, up to 40% of heads could be removed with negligible impact; in machine translation, 20% of heads could be removed while maintaining reasonable quality. The authors used a greedy pruning algorithm based on gradient-derived sensitivity scores. Pruning yielded up to 17.5% inference speedup on BERT-based models.
Voita et al. (2019) found that only a small subset of heads play important, linguistically interpretable roles.[4] These "specialized heads" attend to adjacent words, track syntactic dependencies, or resolve rare or positionally significant tokens. Using a differentiable L0 relaxation, they showed that specialized heads were the last to be pruned. On an English-Russian translation task, 38 out of 48 encoder heads could be pruned with a drop of only 0.15 BLEU.
These findings suggest production models may be over-parameterized in their number of heads, which helped motivate the head-sharing variants above: if many heads are redundant, sharing parameters across them is a sensible architectural prior.
Tools such as BertViz let researchers inspect the attention weights of individual heads across layers, which has become an important way of understanding what Transformer models learn.
Clark et al. (2019) analyzed BERT's attention heads and identified several recurring specialization patterns:[3]
| Head role | Description | Example |
|---|---|---|
| Positional heads | Attend to tokens at fixed relative positions (for example, the previous or next token) | Head attends to the immediately preceding word |
| Syntactic heads | Track grammatical relationships | Head links verbs to their direct objects ("baked" to "cake") |
| Determiner-noun heads | Connect determiners to the nouns they modify | Head links "the" to "dog" |
| Prepositional heads | Attend to objects of prepositions | Head links "in" to "Paris" |
| Coreference heads | Link pronouns to their antecedents | Head connects "she" to "lawyer" |
| Delimiter heads | Attend to special tokens like [SEP] or [CLS] | Head focuses on sentence boundaries |
| Broad attention heads | Distribute attention roughly uniformly across the sequence | Head attends broadly for contextual mixing |
These findings demonstrate that multi-head attention does not simply provide redundant copies of the same computation. Different heads learn qualitatively different functions that correspond to recognized linguistic categories.
Heads also specialize hierarchically across layers. Lower layers in BERT (layers 2-4) capture basic grammatical relationships such as noun-verb and determiner-noun links; middle layers handle complex syntactic structures; upper layers capture higher-level semantic relationships and task-specific patterns.[3]
Research from Anthropic on the Transformer Circuits Thread identified the "induction head," a two-layer circuit in which an early "previous-token" head copies information about the preceding token into each position's residual stream, and a later head uses that information to attend to the token that followed an earlier occurrence of the current token and copy it forward. Induction heads emerge in a sharp phase transition during training and closely track the appearance of in-context learning ability in small attention-only models.[17] They give a concrete mechanistic example of how the multi-head structure supports computation that goes beyond simple weighted averaging.
Multi-head self-attention is increasingly understood through a hardware lens, because the dominant cost on modern accelerators is not arithmetic but data movement.
A modern GPU like an H100 has roughly 80 GB of HBM3 memory at about 3 TB/s, while each streaming multiprocessor has only a few hundred kilobytes of on-chip storage (a 256 KB register file plus a similar amount of L1/shared-memory SRAM) at bandwidth far exceeding HBM. Naive attention round-trips Q, K, V, and the full (n, n) score matrix through HBM; FlashAttention keeps the inner loop entirely in SRAM. The same logic applies on NVIDIA's Blackwell GPUs and on AMD's MI300.
All modern accelerators provide specialized matmul units (Tensor Cores on NVIDIA) that run mixed-precision matrix multiplication an order of magnitude faster than general-purpose ALUs. Multi-head attention's main cost is two batched matmuls per head (QK^T and softmax(QK^T)V), so kernels work hard to keep these on Tensor Cores in FP16, BF16, or FP8. The non-matmul parts (softmax, dropout, masking) run on regular ALUs in higher precision and can become the bottleneck once the matmul is fully utilized; one of the contributions of FlashAttention-2 was to reduce these non-matmul FLOPs.
A single LLM inference call has two phases. Prefill runs attention over the whole prompt at once with high arithmetic intensity and looks like a training forward pass. Decoding generates tokens one at a time and is bounded by HBM bandwidth: each step reads the model weights and the entire KV cache to compute attention over a single new query vector. When the cache dominates that traffic, decoding steps per second per GPU are approximately (HBM bandwidth) / (KV cache bytes per token, summed over layers, times context length), which is why all the head-sharing variants pay off so directly: cutting cache size by 4x raises decoding throughput by up to roughly 4x.
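The bandwidth bound on decoding can be sketched as a back-of-the-envelope estimate. All numbers are assumptions for illustration: about 3 TB/s of HBM bandwidth, a 32K-token context, and 320 KB of KV cache per token across all layers for the larger configuration. The helper name is hypothetical, and real throughput also depends on batching and kernel efficiency.

```python
def decode_steps_per_second(hbm_bytes_per_s, kv_bytes_per_token, context_len, weight_bytes=0):
    # Upper bound: every decoding step must stream the whole KV cache
    # (and, if included, the model weights) from HBM at least once.
    bytes_per_step = kv_bytes_per_token * context_len + weight_bytes
    return hbm_bytes_per_s / bytes_per_step

bandwidth = 3e12                # assumed ~3 TB/s
context_len = 32 * 1024         # assumed context length

baseline = decode_steps_per_second(bandwidth, 320 * 1024, context_len)
shared = decode_steps_per_second(bandwidth, 80 * 1024, context_len)   # 4x smaller cache
speedup = shared / baseline
```

With weight traffic ignored the 4x smaller cache gives a 4x higher bound; passing a nonzero `weight_bytes` shows how the gain shrinks when weights make up a larger share of the bytes read per step.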
For models too large for a single GPU, attention is typically tensor-parallelized along the head dimension: each GPU holds a subset of heads and computes attention for those heads, and the per-head outputs are then combined, commonly by splitting the output projection W^O row-wise and summing the partial results with an all-reduce. This works cleanly with MHA, GQA, and MQA. MLA is more delicate because the latent is shared across heads, so DeepSeek uses a custom sharding strategy. For very long contexts, sequence parallelism (Ring Attention) supplements tensor parallelism by sharding along the sequence dimension.
Multi-head attention is the most heavily optimized component in the deep learning software stack. The default options as of 2026 include:
| Stack | Function or class | Notes |
|---|---|---|
| PyTorch | torch.nn.functional.scaled_dot_product_attention | Dispatches to FlashAttention v2/v3, memory-efficient attention, math, or cuDNN backends |
| flash-attn (Dao-AILab) | flash_attn_func and friends | Reference implementation of FlashAttention 1/2/3 in CUDA and Triton |
| JAX / Flax | flax.linen.MultiHeadDotProductAttention | Used in Google's PaLM/Gemini training stacks via Pallas kernels |
| Hugging Face Transformers | per-model attention classes plus AttentionInterface | Lets users swap between SDPA, FlashAttention, eager, and custom backends per model |
| vLLM | Attention op with PagedAttention kernels | Production inference; handles MHA, GQA, MQA, MLA, sliding window, FP8 KV |
| TensorRT-LLM | fused MHA/GQA plugins | NVIDIA's optimized inference stack with PagedAttention and continuous batching |
| SGLang / FlashInfer | RadixAttention, modular attention kernels | Alternative inference stack with prefix caching; FlashInfer kernels also used by vLLM |
None of these stacks expose MHA, MQA, GQA, or MLA differently at the user level. The variant is determined by how the model defines its head shapes, and the kernel adapts.
Multi-head self-attention is a foundational building block in a wide range of applications:
For the broader sequence operation that multi-head self-attention extends, see the parent article on self-attention. For history and the original paper, see Attention Is All You Need.
Imagine you are in a classroom with lots of friends, and you need to figure out what everyone is talking about. Instead of trying to listen to everything yourself, you send out several helper listeners. Each helper pays attention to something different: one helper listens for who is friends with whom, another listens for what game everyone wants to play, and a third listens for who is sitting next to whom.
After listening, all the helpers come back and tell you what they heard. You combine all their reports to get a complete picture of what is going on in the classroom.
That is what multi-head self-attention does in a computer. Each "head" is like a helper that focuses on a different kind of relationship between words. One head might notice grammar rules, another might notice word meanings, and another might notice which words are close together. By combining all their findings, the computer understands sentences much better than if it only had one helper listening to everything at once.