See also: Attention (machine learning), Transformer, KV cache
Grouped-query attention (GQA) is a variant of the multi-head attention mechanism used in Transformer models. It was introduced by Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebron, and Sumit Sanghai in their 2023 paper "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints," published at EMNLP 2023. GQA sits between standard multi-head attention (MHA) and multi-query attention (MQA) on a spectrum of key-value sharing strategies. It divides the query heads into groups, where each group shares a single set of key and value projections. This design reduces the size of the key-value (KV) cache during autoregressive inference, lowering memory consumption and increasing throughput, while preserving model quality much better than the more aggressive MQA approach.
Since its publication, GQA has been adopted as the default attention mechanism in most major open-weight large language models, including Llama 2 (in its 70B variant) and Llama 3, Mistral 7B, Gemma 2, Qwen 2, and Falcon 40B/180B.
During autoregressive text generation, a language model produces one token at a time. At each step, the model computes attention over all previously generated tokens. To avoid redundantly recomputing key and value projections for past tokens, modern implementations store these projections in a KV cache. The cache grows linearly with sequence length, model dimension, number of layers, and batch size.
For large models deployed at long context lengths, the KV cache can consume tens of gigabytes of GPU memory. For example, a model with 80 layers, 64 attention heads, a head dimension of 128, and a sequence length of 8,192 tokens requires storing 80 x 64 x 128 x 8,192 x 2 (keys and values) x 2 bytes (FP16) of data, which comes to roughly 20 GB for a single sequence.
More importantly, autoregressive decoding at small batch sizes is memory-bandwidth bound rather than compute bound. At each decoding step, the model must load the full KV cache from GPU high-bandwidth memory (HBM) into on-chip SRAM to compute attention scores against the single new query. The speed of token generation is therefore limited by how quickly the GPU can read this data, not by how fast it can perform arithmetic. Reducing the KV cache size directly translates to faster inference because there is less data to transfer.
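To make the bandwidth argument concrete, the following back-of-envelope sketch estimates the floor on per-token decode time imposed by streaming the KV cache from HBM. The cache size (the ~20 GB example above) and the bandwidth figure are illustrative assumptions, not measurements from any specific GPU:

```python
# Back-of-envelope estimate: per-token decode time implied by KV cache reads.
# Both figures below are illustrative assumptions (cache from the example
# above; ~2 TB/s is a rough modern-HBM ballpark).
kv_cache_bytes = 20e9          # ~20 GB KV cache for one long sequence
hbm_bandwidth = 2e12           # ~2 TB/s of HBM read bandwidth

seconds_per_token = kv_cache_bytes / hbm_bandwidth
print(f"Lower bound from KV reads alone: {seconds_per_token * 1e3:.1f} ms/token")
# ~10 ms/token, i.e. at most ~100 tokens/s, before any arithmetic is counted.
# An 8x smaller cache (GQA-8) lowers this floor to ~1.25 ms/token.
```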
Standard multi-head attention, introduced by Vaswani et al. (2017) in the original Transformer paper "Attention Is All You Need," uses h independent attention heads. Each head i has its own learned projection matrices for queries, keys, and values:
Q_i = X W_i^Q, K_i = X W_i^K, V_i = X W_i^V
head_i = softmax(Q_i K_i^T / sqrt(d_k)) V_i
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W^O
In MHA, each head maintains its own KV projections. The total KV cache size per layer is proportional to h (the number of heads) times d_k (the head dimension) times the sequence length.
Multi-query attention was proposed by Noam Shazeer in 2019 in the paper "Fast Transformer Decoding: One Write-Head Is All You Need." MQA keeps h independent query heads but collapses all keys and values into a single shared head. Every query head attends to the same set of keys and values:
K = X W^K, V = X W^V (shared across all heads)
Q_i = X W_i^Q (unique per head)
head_i = softmax(Q_i K^T / sqrt(d_k)) V
This reduces the KV cache by a factor of h compared to MHA. For a model with 64 heads, MQA stores only 1/64th of the key-value data. The result is a large speedup during inference. However, MQA comes with a quality trade-off: sharing a single KV representation across all heads limits the model's representational capacity, and empirical studies have shown measurable quality degradation compared to MHA on tasks like summarization, translation, and question answering.
GQA generalizes both MHA and MQA by introducing a parameter G (the number of KV head groups). The h query heads are partitioned into G equal-sized groups. Within each group, all query heads share a single key head and a single value head:
For a model with h = 32 query heads and G = 8 KV groups, each group contains 32 / 8 = 4 query heads that share one key-value pair. The KV cache is reduced by a factor of h / G = 4 compared to MHA.
Let h denote the number of query heads, G the number of KV groups, and d_k the head dimension. For each group g (where g = 1, ..., G):
K_g = X W_g^K, V_g = X W_g^V
For each query head i in group g:
Q_i = X W_i^Q
head_i = softmax(Q_i K_g^T / sqrt(d_k)) V_g
The outputs of all heads are concatenated and projected through a learned output matrix W^O:
GQA(X) = Concat(head_1, ..., head_h) W^O
The asymptotic computational complexity of GQA remains O(n^2 * d) during the full forward pass, the same as MHA and MQA. The benefit is in the reduced memory footprint and memory bandwidth requirements during inference.
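To make the grouping concrete, here is a minimal PyTorch sketch of the per-group equations above. The tensor sizes and the per-head weight layout are illustrative assumptions, and the final output projection W^O is omitted; production implementations fuse the projections into single matrices and add masking and KV caching:

```python
import math
import torch

# Illustrative sizes (not tied to any particular model)
batch, seq_len, d_model = 2, 16, 512
h, G, d_k = 8, 2, 64                  # query heads, KV groups, head dimension
heads_per_group = h // G

X = torch.randn(batch, seq_len, d_model)

# One W^Q per query head; one W^K and W^V per group (W^O omitted for brevity)
W_Q = torch.randn(h, d_model, d_k)
W_K = torch.randn(G, d_model, d_k)
W_V = torch.randn(G, d_model, d_k)

heads = []
for i in range(h):
    g = i // heads_per_group                      # the group head i belongs to
    Q_i = X @ W_Q[i]                              # (batch, seq_len, d_k)
    K_g, V_g = X @ W_K[g], X @ W_V[g]             # shared by all heads in group g
    scores = Q_i @ K_g.transpose(-2, -1) / math.sqrt(d_k)
    heads.append(torch.softmax(scores, dim=-1) @ V_g)

out = torch.cat(heads, dim=-1)                    # (batch, seq_len, h * d_k)
# Setting G = h recovers standard MHA; G = 1 recovers MQA.
```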
The KV cache size per layer can be expressed as:
KV cache per layer = 2 x n_kv_heads x d_k x seq_len x bytes_per_element
where the factor of 2 accounts for both keys and values. The following table shows how KV cache size scales under each attention variant for a model with h = 64 query heads, d_k = 128, sequence length 8,192, and FP16 storage (2 bytes per element):
| Attention variant | KV heads | KV cache per layer | Reduction vs. MHA |
|---|---|---|---|
| MHA (G = h = 64) | 64 | 64 x 128 x 8192 x 2 x 2 = 256 MB | 1x (baseline) |
| GQA-8 (G = 8) | 8 | 8 x 128 x 8192 x 2 x 2 = 32 MB | 8x |
| GQA-4 (G = 4) | 4 | 4 x 128 x 8192 x 2 x 2 = 16 MB | 16x |
| MQA (G = 1) | 1 | 1 x 128 x 8192 x 2 x 2 = 4 MB | 64x |
For a model with 80 layers (like Llama 2 70B), multiply these per-layer values by 80 to get the total KV cache size. An 8x reduction (GQA-8) would save roughly 18 GB of memory for a single sequence at this scale.
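The per-layer figures in the table above and the whole-model totals can be reproduced with a few lines of Python; the 80-layer total corresponds to the Llama 2 70B-scale example used in the text:

```python
def kv_cache_bytes(n_kv_heads, head_dim, seq_len, n_layers=1, bytes_per_elem=2):
    """KV cache size in bytes; the leading factor of 2 covers keys and values."""
    return 2 * n_kv_heads * head_dim * seq_len * bytes_per_elem * n_layers

cfg = dict(head_dim=128, seq_len=8192)            # values from the table above
for name, n_kv in [("MHA", 64), ("GQA-8", 8), ("GQA-4", 4), ("MQA", 1)]:
    per_layer = kv_cache_bytes(n_kv, **cfg)
    total = kv_cache_bytes(n_kv, **cfg, n_layers=80)   # e.g. an 80-layer model
    print(f"{name:6s} {per_layer / 2**20:6.0f} MB/layer  {total / 2**30:5.1f} GB total")
# MHA: 256 MB/layer and 20.0 GB total; GQA-8: 32 MB/layer and 2.5 GB total, etc.
```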
One of the paper's key contributions is a practical recipe for converting an existing MHA model into a GQA model without training from scratch. This procedure is called uptraining and requires only about 5% of the original pre-training compute (approximately 600 TPUv3 chip-days for a T5-XXL scale model).
The conversion from MHA to GQA involves restructuring the key and value projection matrices. Ainslie et al. evaluated three methods for initializing the GQA key-value heads from the original MHA heads:
Mean pooling (best): The key and value projection matrices of the original heads within each group are averaged. For a group containing heads {i_1, i_2, ..., i_k}, the group's key projection is W_g^K = (1/k) * sum(W_{i_j}^K). This preserves the most information from the pretrained checkpoint.
First head selection (middle): The projection matrices of the first head in each group are used, and the rest are discarded.
Random initialization (worst): New key and value projections are initialized randomly, discarding all pretrained KV information.
The paper found that mean pooling consistently outperformed the other two methods, achieving the best quality after uptraining. This makes sense because averaging the learned projections retains information from all original heads rather than discarding most of it.
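As a rough illustration, the following sketch applies mean pooling to a stack of per-head key projection matrices. The (n_heads, d_model, d_k) weight layout is an assumption made for clarity; real checkpoints store fused projection matrices that must first be split per head, and the same operation applies to the value projections:

```python
import torch

def mean_pool_kv(W_per_head: torch.Tensor, n_groups: int) -> torch.Tensor:
    """Average MHA key (or value) projections within each group.

    W_per_head: (n_heads, d_model, d_k) -- one projection matrix per MHA head.
    Returns:    (n_groups, d_model, d_k) -- one projection matrix per GQA group.
    """
    n_heads, d_model, d_k = W_per_head.shape
    assert n_heads % n_groups == 0, "query heads must split evenly into groups"
    heads_per_group = n_heads // n_groups
    # Group consecutive heads together and average their projections
    return W_per_head.reshape(n_groups, heads_per_group, d_model, d_k).mean(dim=1)

# e.g. converting 64 MHA key heads into 8 GQA key heads
W_K_mha = torch.randn(64, 512, 128)
W_K_gqa = mean_pool_kv(W_K_mha, n_groups=8)   # shape: (8, 512, 128)
```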
After checkpoint conversion, the model is trained for an additional alpha fraction of the original pre-training steps; the paper uses alpha = 0.05 (i.e., 5%) and finds this modest amount of uptraining sufficient for the converted models to approach the quality of the original MHA checkpoint.
The original GQA paper evaluated the approach on T5-XXL (a model with approximately 11 billion parameters) across summarization, translation, and question-answering benchmarks. The MQA and GQA variants were uptrained from the same MHA-XXL checkpoint.
| Model | Inference time (s) | CNN/DM (R1) | arXiv (R1) | PubMed (R1) | MediaSum (R1) | MultiNews (R1) | WMT (BLEU) | TriviaQA (F1) |
|---|---|---|---|---|---|---|---|---|
| MHA-Large | 0.37 | 46.0 | 42.9 | 44.6 | 46.2 | 35.5 | 46.6 | 78.2 |
| MHA-XXL | 1.51 | 47.2 | 43.8 | 45.6 | 47.5 | 36.4 | 46.9 | 81.9 |
| MQA-XXL | 0.24 | 46.6 | 43.0 | 45.0 | 46.9 | 36.1 | 46.5 | 81.3 |
| GQA-8-XXL | 0.28 | 47.1 | 43.5 | 45.4 | 47.7 | 36.3 | 47.2 | 81.6 |
Key observations:

- GQA-8-XXL matches MHA-XXL to within a few tenths of a point on every benchmark while running roughly five times faster at inference (0.28 s vs. 1.51 s).
- MQA-XXL is slightly faster still (0.24 s) but shows a larger, more consistent quality drop relative to MHA-XXL.
- Both uptrained XXL variants are faster than MHA-Large and match or exceed it on nearly every benchmark, suggesting that a larger model with fewer KV heads is a better trade-off than a smaller model with full multi-head attention.
The paper also measured how inference time changes with the number of groups G. Going from G = 1 (MQA) to G = 8 adds only modest overhead. Increasing G beyond 8 starts approaching MHA inference costs. G = 8 was identified as a practical sweet spot that balances quality and speed for models with 32 to 128 query heads.
The three attention variants form a spectrum of KV sharing strategies, trading off model expressiveness against inference efficiency:
| Property | MHA | GQA | MQA |
|---|---|---|---|
| KV heads | h (one per query head) | G (one per group) | 1 (shared) |
| KV cache size | Largest | Intermediate | Smallest |
| Inference speed | Slowest | Near MQA | Fastest |
| Model quality | Best | Near MHA | Slightly degraded |
| Representational capacity | Full | Near-full | Limited |
| Parameters (KV projections) | h x d_k x d_model x 2 | G x d_k x d_model x 2 | 1 x d_k x d_model x 2 |
| Uptraining from MHA? | N/A | Yes (5% compute) | Yes (5% compute) |
GQA occupies the best trade-off point for most practical applications: it achieves quality within a fraction of a point of MHA while running at nearly the same speed as MQA.
In practice, GQA is implemented by computing Q, K, and V projections with h query heads and G KV heads, then expanding the KV heads to match the number of query heads before computing attention. This expansion is performed by a function commonly called repeat_kv, which duplicates each KV head h/G times along the head dimension:
```python
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand KV heads to match query head count."""
    # x: (batch, seq_len, n_kv_heads, head_dim); each KV head is repeated
    # n_rep = h / G times so that attention sees one KV head per query head.
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x  # G == h (standard MHA); no expansion needed
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
```
The key insight is that while the KV cache stores only G sets of key-value vectors (saving memory), the actual attention computation still operates with h effective KV heads by repeating each cached head. This means the forward pass arithmetic is identical to MHA; the savings come entirely from the reduced cache storage and the reduced amount of data that must be loaded from memory during decoding.
As of PyTorch 2.5+, the built-in torch.nn.functional.scaled_dot_product_attention function supports GQA directly through an enable_gqa=True parameter. This eliminates the need for explicit repeat_kv calls in user code and allows the kernel to handle the head expansion internally, potentially with optimized memory access patterns.
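For example (shapes below are illustrative), the grouped layout can be passed to the built-in kernel directly, with the query tensor carrying more heads than the key and value tensors:

```python
import torch
import torch.nn.functional as F

batch, seq_len, d_k = 2, 128, 64
n_q_heads, n_kv_heads = 32, 8                 # GQA with 4 query heads per group

q = torch.randn(batch, n_q_heads, seq_len, d_k)
k = torch.randn(batch, n_kv_heads, seq_len, d_k)
v = torch.randn(batch, n_kv_heads, seq_len, d_k)

# Requires PyTorch 2.5+; no explicit repeat_kv call is needed.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # torch.Size([2, 32, 128, 64]) -- one output per query head
```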
FlashAttention 2 and 3 natively support GQA. When the number of KV heads differs from the number of query heads, FlashAttention automatically handles the head grouping without requiring the user to expand the KV tensors. This combination of GQA with FlashAttention provides both the memory savings from reduced KV heads and the IO-efficiency of tiled attention computation.
GQA has been widely adopted across major open-weight language models released since mid-2023, and is reported to be used in many proprietary models as well. The following table summarizes GQA configurations in notable open-weight models:
| Model | Release | Parameters | Query heads | KV heads (G) | Head dim | Queries per group |
|---|---|---|---|---|---|---|
| Llama 2 7B | Jul 2023 | 7B | 32 | 32 (MHA) | 128 | 1 |
| Llama 2 13B | Jul 2023 | 13B | 40 | 40 (MHA) | 128 | 1 |
| Llama 2 70B | Jul 2023 | 70B | 64 | 8 | 128 | 8 |
| Llama 3 8B | Apr 2024 | 8B | 32 | 8 | 128 | 4 |
| Llama 3 70B | Apr 2024 | 70B | 64 | 8 | 128 | 8 |
| Llama 3.1 405B | Jul 2024 | 405B | 128 | 8 | 128 | 16 |
| Mistral 7B | Sep 2023 | 7B | 32 | 8 | 128 | 4 |
| Mixtral 8x7B | Dec 2023 | 46.7B (MoE) | 32 | 8 | 128 | 4 |
| Gemma 2B | Feb 2024 | 2B | 8 | 1 (MQA) | 256 | 8 |
| Gemma 2 9B | Jun 2024 | 9B | 16 | 8 | 256 | 2 |
| Gemma 2 27B | Jun 2024 | 27B | 32 | 16 | 128 | 2 |
| Falcon 7B | May 2023 | 7B | 71 | 1 (MQA) | 64 | 71 |
| Falcon 40B | May 2023 | 40B | 128 | 8 | 64 | 16 |
| Falcon 180B | Sep 2023 | 180B | 232 | 8 | 64 | 29 |
| Qwen 2.5 7B | Sep 2024 | 7B | 28 | 4 | 128 | 7 |
| Qwen 2.5 14B | Sep 2024 | 14B | 40 | 8 | 128 | 5 |
| Qwen 2.5 72B | Sep 2024 | 72B | 64 | 8 | 128 | 8 |
A few patterns stand out from this table:

- G = 8 KV heads has become the most common configuration, appearing across model families and scales from 7B to 405B parameters.
- As models grow, designers tend to hold G fixed rather than scale it with the query head count, so the number of query heads per group rises (16 for Llama 3.1 405B, 29 for Falcon 180B).
- The smallest models sometimes go all the way to MQA (Gemma 2B, Falcon 7B), and Llama 2 7B and 13B retained full MHA, while the 70B variant and every later Llama generation use GQA.
Multi-head latent attention (MLA), introduced in DeepSeek-V2 (2024), takes a different approach to KV cache reduction. Instead of reducing the number of KV heads, MLA compresses the key and value representations into a low-dimensional latent vector before storing them in the cache. At inference time, the compressed representation is projected back to produce unique keys and values for each head.
MLA achieved a 93.3% reduction in KV cache size compared to standard MHA in DeepSeek-V2 while maintaining or even improving model quality. Unlike GQA, which slightly underperforms MHA, MLA has been reported to match or exceed MHA quality in ablation studies by the DeepSeek team. However, MLA adds computational overhead through the decompression step and is more complex to implement. GQA remains easier to implement, train, and integrate with existing attention kernels like FlashAttention.
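The caching idea behind MLA can be sketched in a few lines. This is a heavily simplified illustration that ignores DeepSeek's decoupled rotary position embeddings and other details; all names and sizes are assumptions:

```python
import torch

batch, seq_len, d_model = 2, 16, 512
h, d_k, d_latent = 8, 64, 128             # d_latent is much smaller than h * d_k

# Down-projection to a small shared latent; only this latent is cached.
W_down = torch.randn(d_model, d_latent)
# Per-head up-projections reconstruct full keys and values at attention time.
W_up_k = torch.randn(h, d_latent, d_k)
W_up_v = torch.randn(h, d_latent, d_k)

X = torch.randn(batch, seq_len, d_model)
latent = X @ W_down                        # (batch, seq_len, d_latent) -> KV cache

# At attention time, each head derives its own keys/values from the shared latent.
K = torch.einsum("bsl,hlk->bhsk", latent, W_up_k)   # (batch, h, seq_len, d_k)
V = torch.einsum("bsl,hlk->bhsk", latent, W_up_v)
# Cache cost per token: d_latent values, vs. 2 * h * d_k for standard MHA.
```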
| Property | GQA | MLA |
|---|---|---|
| KV cache reduction method | Fewer KV heads | Low-rank compression |
| KV cache reduction (typical) | 4x to 8x | Up to 15x |
| Quality vs. MHA | Slightly below | Comparable or better |
| Implementation complexity | Low (simple head sharing) | Higher (factorized projections) |
| Kernel support | Mature (FlashAttention, PyTorch) | Growing |
| Adopted by | Llama, Mistral, Gemma, Qwen, Falcon | DeepSeek-V2, DeepSeek-V3, DeepSeek-R1 |
Sliding window attention restricts each token to attend only to a fixed window of w neighboring tokens rather than the full context. Mistral 7B combines GQA (G = 8, i.e., 4 query heads per KV head) with sliding window attention (w = 4,096), stacking the two optimizations: a 4x reduction from GQA plus an additional 2x from the windowed cache at an 8K context length, for a combined reduction of roughly 8x in KV cache size compared to a full-MHA model.
GQA can be combined with KV cache quantization, which stores cached keys and values in lower numerical precision (e.g., FP8 or INT4 instead of FP16). These two techniques are orthogonal and multiplicative in their savings: GQA reduces the number of cached vectors, while quantization reduces the size of each vector.
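A quick illustration of how these savings compose; the specific factors are assumptions for a hypothetical configuration:

```python
# Hypothetical example of multiplicative savings from combining techniques.
gqa_reduction = 64 / 8        # 64 query heads cached as 8 KV heads -> 8x
quant_reduction = 16 / 4      # FP16 -> INT4 cache storage          -> 4x

combined = gqa_reduction * quant_reduction
print(f"Combined KV cache reduction: {combined:.0f}x")   # 32x vs. an FP16 MHA cache
```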
PagedAttention, used in serving frameworks like vLLM, manages KV cache memory in non-contiguous pages (similar to virtual memory in operating systems) to reduce memory fragmentation. GQA and PagedAttention work together naturally: GQA reduces the total cache size, and PagedAttention ensures that the allocated memory is used efficiently.
Selecting the right value of G involves balancing several considerations: model quality (more KV heads behave more like MHA), memory and bandwidth savings (fewer KV heads shrink the cache), the requirement that G divide the number of query heads evenly, and, in multi-GPU serving, the ability to shard the KV heads evenly across tensor-parallel devices.
In practice, G = 8 has become a de facto standard for models with 32 or more query heads. Some smaller models use G = 4, while larger models (like Llama 3.1 405B with 128 query heads) still use G = 8, resulting in 16 query heads per group.
GQA is not without drawbacks: quality typically remains slightly below full MHA, since query heads within a group no longer have independent key-value representations; converting an existing MHA checkpoint requires additional uptraining compute; and the number of query heads must be divisible by the number of groups, which constrains architecture and parallelism choices.
Imagine you are in a classroom and the teacher hands out answer sheets. In regular attention (MHA), every student gets their own personal copy of the answer sheet. That works well, but it uses a lot of paper. In multi-query attention (MQA), the teacher prints just one answer sheet and everyone has to share it. That saves a lot of paper, but some students might not find the answers they need because there is only one sheet.
Grouped-query attention is the middle ground. The teacher puts students into small groups of four or five. Each group gets its own answer sheet. This way, you still save a lot of paper (you need only a few sheets instead of one per student), but each group has an answer sheet that is more relevant to their questions. Everyone gets their work done almost as quickly as with one shared sheet, and the quality of answers is almost as good as when everyone had their own.