# Flash-Decoding

> Source: https://aiwiki.ai/wiki/flash_decoding
> Updated: 2026-05-19
> Categories: AI Inference, Algorithms
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

**Flash-Decoding** is an inference-time variant of the [FlashAttention](/wiki/flash_attention) algorithm that targets the decoding (autoregressive generation) phase of [large language model](/wiki/llm) inference, where each step processes only a single query token against a potentially very long key-value cache. It was introduced in October 2023 by [Tri Dao](/wiki/tri_dao), Daniel Haziza, Francisco Massa, and Grigory Sizov in a blog post co-published on the Stanford CRFM, Princeton NLP, PyTorch, and Together AI websites.[^1][^2][^3][^4] Flash-Decoding extends FlashAttention's I/O-aware tiling with an additional parallelization axis along the keys and values sequence length, enabling near-full utilization of modern [GPUs](/wiki/gpu) even at batch size 1 with long contexts. The original announcement reports up to 8x faster end-to-end token generation and up to 50x faster attention kernels at very long sequence lengths compared with FlashAttention v2.[^1][^3]

## Introduction

Modern transformer-based language models can be deployed with context windows of tens or hundreds of thousands of tokens. During autoregressive generation, the model produces one new token at a time. At each step the new query token must attend to the full history of cached keys and values (the [KV cache](/wiki/kv_cache)). Context lengths grew from roughly 2,000 tokens in 2022 to between 32,000 and 100,000 tokens or more by late 2023. Over that period the [attention](/wiki/attention) operation came to dominate the latency budget of decoding for long-context workloads.[^1][^3]

FlashAttention v1 and v2 were designed primarily to accelerate training, where each call to the attention kernel processes a long sequence of query tokens in parallel. Their parallelization scheme distributes work over the batch dimension, the head dimension, and the query length. During single-step decoding the query length is one. This collapses one of the three axes and leaves only batch and head to occupy the streaming multiprocessors of the GPU. At batch size 1, the blog estimates that FlashAttention v2 on an [NVIDIA A100](/wiki/nvidia_a100) uses less than 1% of the GPU.[^1][^3] Flash-Decoding closes this gap by parallelizing across the keys and values sequence length itself. It splits the long KV cache into chunks that are processed in parallel and then combined with a single reduction.

## Background: FlashAttention and the decoding bottleneck

### FlashAttention v1 and v2

FlashAttention is an exact attention algorithm. It fuses the matrix multiplications, scaling, masking, [softmax](/wiki/softmax), and dropout into a single I/O-aware kernel that streams blocks of queries, keys, and values between the GPU's [high-bandwidth memory](/wiki/hbm) (HBM) and on-chip SRAM. FlashAttention tiles the computation into blocks and combines per-block results with an online softmax, so it never materializes the full attention matrix in HBM. FlashAttention v2 refined the kernel by rebalancing work and parallelism across thread blocks and warps.[^1][^3]

In both versions the natural unit of parallelism is a tile of query tokens. A typical kernel launch assigns one thread block to each (batch, head, query-tile) combination. During training, where queries are long, this fully occupies the GPU. Inference has two phases:

- **Prefill.** The prefill phase still processes a long query (the prompt). It therefore exposes enough query tiles to keep the GPU busy, much as in training.
- **Decoding.** The decoding phase processes only one new query token per step. The number of query tiles therefore drops to one per (batch, head) combination.[^1][^3]

### Why decoding underutilizes the GPU

A modern GPU contains many streaming multiprocessors (SMs); an NVIDIA A100, for example, has 108 SMs. To saturate the device, the kernel needs enough independent thread blocks to occupy all SMs at once.

Consider decoding at batch size 1 with [grouped-query attention](/wiki/grouped_query_attention), using the configuration from the original Flash-Decoding micro-benchmarks: 16 query heads of head dimension 128 sharing 2 key/value heads. The kernel then launches only on the order of 16 thread blocks, far fewer than the device can run in parallel.

Decoding is also dominated by reading the KV cache from HBM. Unless the work is parallelized along that long KV-cache dimension, even memory bandwidth goes largely unused, because only a handful of SMs are actively streaming data.[^1][^3] The blog quantifies this gap with the observation that FlashAttention v2 "uses less than 1% of the GPU" at batch size 1 in long-context decoding scenarios.[^1]
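A rough count makes the occupancy gap concrete. The sketch below is a simplification under two stated assumptions:

- Each (batch, query head, query tile, KV split) work item gets one thread block.
- At most one block is resident per SM. Real kernels often exceed this.

The function names are hypothetical. The result only measures how many SMs receive any work, not how busy each one is, so it is not comparable to the blog's utilization estimate.

```python
A100_SMS = 108  # streaming multiprocessors on an NVIDIA A100

def thread_blocks(batch: int, query_heads: int, query_tiles: int = 1, kv_splits: int = 1) -> int:
    """Independent thread blocks, assuming one block per (batch, head, query tile, KV split)."""
    return batch * query_heads * query_tiles * kv_splits

def sm_fraction_busy(blocks: int, num_sms: int = A100_SMS) -> float:
    """Fraction of SMs that receive a block, assuming (simplistically) one block per SM."""
    return min(blocks, num_sms) / num_sms

# Single-step decoding: batch 1, 16 query heads, one query tile, no KV splitting.
fa2_blocks = thread_blocks(batch=1, query_heads=16)
print(fa2_blocks, f"{sm_fraction_busy(fa2_blocks):.0%}")  # prints: 16 15%

# Splitting the KV cache into 8 chunks multiplies the number of blocks by 8.
fd_blocks = thread_blocks(batch=1, query_heads=16, kv_splits=8)
print(fd_blocks, f"{sm_fraction_busy(fd_blocks):.0%}")  # prints: 128 100%
```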

This bottleneck matters for practical deployments. Long-document summarization, large-codebase code completion, multi-turn chat with long history, and retrieval-augmented generation all stress the decoding phase with long KV caches. As the cache grows, attention latency at each generation step grows roughly linearly under naive implementations, becoming a significant fraction of end-to-end token latency.[^1][^3]

## The Flash-Decoding algorithm

Flash-Decoding adds a new parallelization axis to FlashAttention: the keys and values sequence length. Conceptually, the long KV cache is sliced into several chunks (the blog calls these "splits"), each chunk is attended to in parallel by an independent thread block, and a small reduction step at the end combines the partial outputs into the exact attention result.[^1][^3]

### Three-step procedure

The blog post describes the algorithm in three steps.[^1][^3]

1. **Split.** The keys and values are conceptually divided into smaller chunks along the sequence-length dimension. This is performed as a tensor view and incurs no GPU work; only the indexing changes.
2. **Parallel partial attention.** The single query is paired with each chunk, and FlashAttention runs independently on each (query, K-chunk, V-chunk) triplet. Each of these parallel computations produces a partial attention output for its chunk. It also writes one extra scalar per row per chunk to HBM: the log-sum-exp of that chunk's attention logits, which records the normalization factor of the partial softmax over that chunk in log space.
3. **Reduction.** A small follow-up kernel reads the per-chunk partial outputs and log-sum-exp values. It rescales each partial output by its chunk's share of the total normalization and sums the results to form the final attention output. Because the log-sum-exp values capture the missing normalization, the result is mathematically identical to a full softmax over the entire KV cache, differing only by floating-point rounding.
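The three steps can be sketched in plain Python for a single query and a single attention head. This is an illustrative sketch rather than the FlashAttention or xFormers kernel:

- The function names are hypothetical.
- The "parallel" chunks run in an ordinary loop.
- The logits include the usual 1/sqrt(d) scaling.

The check at the end compares the split computation with unsplit attention for several split counts.

```python
import math
import random

def partial_attention(q, keys, values):
    """Attend one query to one KV chunk.

    Returns (chunk-normalized output, log-sum-exp of the chunk's scaled logits).
    """
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(logits)  # subtract the chunk max for numerical stability
    weights = [math.exp(s - m) for s in logits]
    denom = sum(weights)
    out = [sum(w * v[j] for w, v in zip(weights, values)) / denom
           for j in range(len(values[0]))]
    return out, m + math.log(denom)  # log(sum_j exp(logit_j)) over this chunk

def flash_decode(q, K, V, num_splits):
    """Split -> partial attention per chunk -> log-sum-exp reduction."""
    n = len(K)
    chunk = math.ceil(n / num_splits)
    # Step 1: split. Python slicing copies; a GPU kernel would use views/offsets instead.
    bounds = [(a, min(a + chunk, n)) for a in range(0, n, chunk)]
    # Step 2: chunks are independent (a loop here; separate thread blocks on a GPU).
    partials = [partial_attention(q, K[a:b], V[a:b]) for a, b in bounds]
    # Step 3: reduce. Weight each chunk's output by exp(lse_i - lse_total).
    m = max(lse for _, lse in partials)
    lse_total = m + math.log(sum(math.exp(lse - m) for _, lse in partials))
    return [sum(math.exp(lse - lse_total) * out[j] for out, lse in partials)
            for j in range(len(V[0]))]

def reference_attention(q, K, V):
    """Unsplit softmax attention over the whole KV cache, for comparison."""
    out, _ = partial_attention(q, K, V)
    return out

random.seed(0)
d, n = 8, 1000
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(n)]
ref_out = reference_attention(q, K, V)
for splits in (1, 7, 64):
    err = max(abs(a - b) for a, b in zip(flash_decode(q, K, V, splits), ref_out))
    print(splits, err)  # errors are at floating-point rounding level
```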

The reduction exploits the standard property of online softmax (which FlashAttention already uses internally to combine partial sums along the keys axis within a single thread block). Flash-Decoding lifts that same trick one level outward: instead of one thread block streaming through all keys, multiple thread blocks each stream through a chunk and the cross-chunk combination is performed at the end.[^1][^3]
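The online-softmax rule that FlashAttention applies within a thread block is associative, which is what lets Flash-Decoding apply it across blocks. The sketch below is an illustration; its state layout and names are assumptions. The same `merge` function can fold keys in order, as one streaming block would. It can also combine independently built chunk states, as the split scheme does. Flash-Decoding itself stores each chunk's normalization as a single log-sum-exp scalar, m + log(l), rather than the (m, l) pair used here.

```python
import math
from functools import reduce

# Partial softmax state over a subset of keys (illustrative layout):
#   m   = largest logit in the subset
#   l   = sum of exp(logit - m)
#   acc = sum of exp(logit - m) * value  (unnormalized output vector)
def leaf(logit, value):
    """State for a single key: its logit is the max, so its weight is exp(0) = 1."""
    return (logit, 1.0, list(value))

def merge(a, b):
    """Online-softmax combination rule: rescale both states to a shared max, then add."""
    (ma, la, acc_a), (mb, lb, acc_b) = a, b
    m = max(ma, mb)
    ca, cb = math.exp(ma - m), math.exp(mb - m)
    return (m, la * ca + lb * cb, [x * ca + y * cb for x, y in zip(acc_a, acc_b)])

def finalize(state):
    """Normalize the accumulator to get the attention output."""
    _, l, acc = state
    return [x / l for x in acc]

logits = [0.3, 2.0, -1.5, 4.2, 0.0, 1.1, -0.7, 3.3]
values = [[float(i), float(-i)] for i in range(len(logits))]
leaves = [leaf(s, v) for s, v in zip(logits, values)]

# One thread block streaming through every key in order (FlashAttention-style).
streamed = finalize(reduce(merge, leaves))

# Four chunks reduced independently, then merged (Flash-Decoding-style).
chunk_states = [reduce(merge, leaves[i:i + 2]) for i in range(0, len(leaves), 2)]
split = finalize(reduce(merge, chunk_states))

print(streamed, split)  # the two agree up to floating-point rounding
```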

### Why this is exact

The attention output is `softmax(Q K^T / sqrt(d)) V`, where d is the head dimension. Suppose the keys (and matching values) are partitioned into disjoint chunks. The softmax over the concatenated logits can then be reconstructed exactly from the softmax over each chunk, together with the maximum and sum of exponentials within each chunk. This is the same numerically stable identity that the original online-softmax formulation relies on. Storing one extra scalar per row per chunk, the log-sum-exp, is sufficient to perform the cross-chunk rescaling without ever materializing the full attention matrix.[^1][^3]
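The identity can be written out for a single query vector $q$ with head dimension $d$. The notation here is introduced for illustration: logits are $s_j = q \cdot k_j / \sqrt{d}$, and $C_1, \dots, C_S$ are disjoint key chunks.

$$
\mathrm{LSE}_i = \log \sum_{j \in C_i} e^{s_j},
\qquad
o_i = \sum_{j \in C_i} e^{s_j - \mathrm{LSE}_i}\, v_j,
\qquad
\mathrm{LSE} = \log \sum_{i=1}^{S} e^{\mathrm{LSE}_i}
$$

$$
o = \sum_{j} e^{s_j - \mathrm{LSE}}\, v_j
  = \sum_{i=1}^{S} e^{\mathrm{LSE}_i - \mathrm{LSE}} \sum_{j \in C_i} e^{s_j - \mathrm{LSE}_i}\, v_j
  = \sum_{i=1}^{S} e^{\mathrm{LSE}_i - \mathrm{LSE}}\, o_i
$$

In a stable implementation each $\mathrm{LSE}_i$ would be computed as $m_i + \log \sum_{j \in C_i} e^{s_j - m_i}$ with $m_i = \max_{j \in C_i} s_j$, so that no exponential overflows. The reduction weights $e^{\mathrm{LSE}_i - \mathrm{LSE}}$ sum to one.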

### Memory and compute footprint

The extra memory traffic introduced by Flash-Decoding is small. It consists of one float per (row, chunk) for the log-sum-exp, plus the partial attention outputs themselves, which are written to HBM and then read back during the reduction. For long contexts this overhead is dwarfed by the gains: far fewer SMs sit idle, and the KV-cache read achieves much higher memory bandwidth.[^1][^3]
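A rough byte count shows why the overhead is small at long contexts. The sketch rests on these assumptions:

- K and V are stored in fp16, matching the fp16 benchmarks.
- Partial outputs and log-sum-exp values are fp32. This is an assumption about accumulator precision.
- Batch size is 1.
- Each intermediate is written once and read once.

It ignores the query and final output, which are the same with or without splitting. The function names are hypothetical.

```python
def kv_cache_bytes(seq_len, kv_heads, head_dim, bytes_per_elem=2):
    """HBM bytes for reading K and V once in one decoding step (one sequence, fp16)."""
    return 2 * seq_len * kv_heads * head_dim * bytes_per_elem

def split_overhead_bytes(num_splits, query_heads, head_dim, accum_bytes=4):
    """Extra HBM traffic from splitting: per (row, split) a partial output vector and one
    log-sum-exp scalar, each written by the split kernel and read back by the reduction.
    With one query token, the rows are simply the query heads."""
    partial_outputs = num_splits * query_heads * head_dim * accum_bytes
    lse_scalars = num_splits * query_heads * accum_bytes
    return 2 * (partial_outputs + lse_scalars)  # x2: one write plus one read

# Micro-benchmark shape from the text: 16 query heads, head dim 128, 2 KV heads.
for seq_len in (4_096, 65_536):
    for splits in (8, 64):
        kv = kv_cache_bytes(seq_len, kv_heads=2, head_dim=128)
        extra = split_overhead_bytes(splits, query_heads=16, head_dim=128)
        print(f"seq_len={seq_len:>6} splits={splits:>2} overhead={extra / kv:.2%}")
```

Under these assumptions the extra traffic depends on the number of splits and heads but not on sequence length, so its share shrinks as the context grows. The same arithmetic shows the cost of over-splitting short sequences. With 64 splits at 4,096 tokens, the intermediates amount to roughly a quarter of the KV-cache traffic.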

## Implementation: parallelizing across KV cache chunks

The number of splits is chosen by a launch-time heuristic that balances two pressures:

- **Too few splits.** The kernel still does not fill the GPU.
- **Too many splits.** Each partial attention computation becomes too small to amortize its fixed per-block costs, and the reduction step itself becomes a meaningful fraction of the work.

The xFormers and FlashAttention implementations both expose this as a tunable parameter. Each selects a default based on context length, head count, and device characteristics.[^1][^3][^5]
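One plausible way to pick the split count is to look at GPU wave quantization. Blocks execute in waves of roughly one block per SM, and a partly filled last wave wastes capacity. The function below is a hypothetical heuristic written for illustration. Its thresholds (0.8 of the SM count, 85% of the best efficiency) and its name are illustrative assumptions, and it should not be read as the exact rule either library uses.

```python
import math

def choose_num_splits(batch_heads, num_sms, num_kv_blocks, max_splits=128, tolerance=0.85):
    """Hypothetical split-count heuristic based on GPU wave quantization.

    batch_heads:   independent (batch, head) work items before any KV splitting
    num_sms:       streaming multiprocessors on the device
    num_kv_blocks: KV tiles spanned by the sequence (each split needs at least one)
    """
    if batch_heads >= 0.8 * num_sms:
        return 1  # the GPU is already nearly full; splitting would only add reduction work
    limit = max(1, min(max_splits, num_sms, num_kv_blocks))

    def wave_efficiency(splits):
        waves = batch_heads * splits / num_sms
        return waves / math.ceil(waves)  # 1.0 when the last wave of blocks is full

    best = max(wave_efficiency(s) for s in range(1, limit + 1))
    # Take the fewest splits that come close to the best efficiency, keeping each
    # chunk as large as possible so per-block setup costs stay amortized.
    for s in range(1, limit + 1):
        if wave_efficiency(s) >= tolerance * best:
            return s
    return 1

# 16 (batch, head) items on 108 SMs, with a 65,536-token cache in 128-token tiles.
print(choose_num_splits(batch_heads=16, num_sms=108, num_kv_blocks=65_536 // 128))  # prints: 6
```

With 16 work items on 108 SMs, 6 splits yield 96 blocks, about 89% of one wave. This rule accepts that as close enough to a perfectly full wave.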

### Compatibility with the KV cache layout

Flash-Decoding only changes how the keys/values dimension is sliced. It therefore composes naturally with techniques that change the shape of the KV cache, such as grouped-query attention and multi-query attention, where the number of KV heads is smaller than the number of query heads. The original release added Flash-Decoding alongside experimental support for multi-query and grouped-query attention in xFormers v0.0.22.[^5] The blog also notes that the same parallel structure applies whether the KV cache is stored contiguously or as a more complex paged structure.[^1][^3]
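Splitting acts only on the sequence dimension, so grouped-query and multi-query attention change just which KV head a work item reads. The sketch below uses hypothetical names and a simplification that assigns one work item per query head and split. It enumerates the resulting work for the 16-query-head, 2-KV-head configuration with 8 splits.

```python
from itertools import product

def decode_work_items(batch, query_heads, kv_heads, seq_len, num_splits):
    """Enumerate (batch, query_head, kv_head, key_start, key_end) work items.

    Assumes query_heads is a multiple of kv_heads: grouped-query attention,
    with kv_heads == 1 as the multi-query special case.
    """
    if query_heads % kv_heads != 0:
        raise ValueError("query_heads must be a multiple of kv_heads")
    group = query_heads // kv_heads       # query heads that share one KV head
    chunk = -(-seq_len // num_splits)     # ceiling division: keys per split
    return [
        (b, h, h // group, start, min(start + chunk, seq_len))
        for b, h, start in product(range(batch), range(query_heads), range(0, seq_len, chunk))
    ]

items = decode_work_items(batch=1, query_heads=16, kv_heads=2, seq_len=65_536, num_splits=8)
print(len(items))   # prints: 128
print(items[-1])    # prints: (0, 15, 1, 57344, 65536)
```

A production kernel may instead handle all query heads that share a KV head inside one block, so each KV chunk is read from HBM once per group. This sketch does not model that optimization.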

### Attribution of the idea

The blog post credits Erich Elsen, Ashish Vaswani, and Michaël Benesty for originally suggesting the idea of splitting the KV cache loading across thread blocks. Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov implemented and benchmarked the technique and published the algorithm and code.[^1][^3]

## Speedups at long contexts

The original announcement presents both end-to-end and kernel-level benchmarks on an NVIDIA A100 in fp16.[^1][^3]

### End-to-end generation

End-to-end measurements are reported for [CodeLlama](/wiki/code_llama) 34B at batch size 1 across context lengths from 512 to 64,000 tokens.[^1] At short contexts FlashAttention v2 and Flash-Decoding are roughly comparable, because at short contexts the attention kernel is not the bottleneck and parallelization across keys yields little benefit. As the context grows, FlashAttention v2's per-step attention time grows roughly linearly while Flash-Decoding's per-step attention time stays nearly flat. At 64,000 tokens the blog reports up to 8x faster end-to-end token generation versus the FlashAttention v2 baseline.[^1][^3]

### Kernel micro-benchmarks

For the attention kernel in isolation, the blog presents micro-benchmarks with 16 query heads of head dimension 128 and 2 key/value heads at batch size 1. The sweep covers sequence lengths from 256 to 131,072.[^1][^3] Selected results:

| Sequence length (tokens) | FlashAttention v2 (µs) | Flash-Decoding (µs) | Speedup |
|---|---|---|---|
| 256 | 390.5 | 63.4 | 6.2x |
| 4,096 | 401.7 | 57 | 7.0x |
| 65,536 | 2,300.6 | 64.4 | 35.7x |
| 131,072 | 4,592.2 | 106.6 | 43.1x |

The gap widens dramatically at extreme lengths. The blog's headline summary is that the attention kernel itself can be up to 50x faster than FlashAttention v2.[^1][^3]

Read this way, Flash-Decoding's runtime is nearly independent of context length through 65,536 tokens and begins to grow only at 131,072 tokens. Until the KV cache is large enough to saturate the GPU, additional splits absorb the extra work in parallel. Beyond that point, additional work translates into roughly linear scaling.[^1][^3]

## Adoption

### FlashAttention library

Flash-Decoding shipped in the FlashAttention library starting at version 2.2, where it is exposed through the `flash_attn_with_kvcache` inference function. The same repository continues to evolve the technique alongside the rest of the FlashAttention family.[^1][^3]

### xFormers

The Meta-maintained xFormers library released Flash-Decoding in version 0.0.22 (tag published on September 27, 2023, with the public Flash-Decoding announcement following on October 12 to 13, 2023).[^5] The implementation is exposed through `xformers.ops.memory_efficient_attention`, with a dispatcher that automatically selects either the Flash-Decoding kernel or the standard FlashAttention kernel based on problem size, falling back to an efficient Triton kernel that implements the Flash-Decoding algorithm when the CUDA path is unsupported.[^1][^3][^5] The same v0.0.22 release added experimental support for [multi-head self-attention](/wiki/multi-head_self-attention) variants such as multi-query and grouped-query attention, a local attention bias, an efficient [Triton](/wiki/triton) implementation of [rotary position embeddings](/wiki/rope), and a minimal LLaMa inference example demonstrating the decoding path end-to-end.[^5]

### vLLM and PagedAttention

The [vLLM](/wiki/vllm) inference engine, which is built around [PagedAttention](/wiki/paged_attention) for managing KV cache memory at block granularity, did not initially use the Flash-Decoding kernel; its decoding path used a custom kernel in `csrc/attention/attention_kernels.cu` compatible with the paged KV layout. vLLM later began integrating FlashAttention as an alternative decoding backend; pull request #3648 ("[Kernel] Use flash-attn for decoding") was opened by skrider in 2024 and ultimately merged in May 2024 before being reverted shortly afterward due to a memory-access issue in the LoRA test suite.[^6] Subsequent vLLM releases added stable FlashAttention-family backends alongside the original PagedAttention kernel.

## FlashDecoding++ and related work

In November 2023, a separate group published **FlashDecoding++** ("FlashDecoding++: Faster Large Language Model Inference on GPUs", arXiv:2311.01282) by Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Yuhan Dong, and Yu Wang, with author affiliations to Tsinghua University, Shanghai Jiao Tong University, Peking University, and Infinigence-AI.[^7][^8] Despite the suggestive name, FlashDecoding++ is best understood as an independent line of work rather than a direct successor; it targets several inefficiencies in LLM inference on both NVIDIA and AMD GPUs and is sometimes compared with the original Flash-Decoding as a baseline.

The paper identifies three problems and proposes one fix for each.[^7][^8]

1. **Asynchronized softmax with a unified maximum value.** Standard partial-softmax schemes require synchronization between partial computations to share the running maximum. The paper proposes a unified maximum across partial blocks plus a recomputation fallback, removing the roughly 20% overhead it attributes to this synchronization.
2. **Flat GEMM optimization with double buffering.** Decoding-phase matrix multiplications are extremely flat (small in the M dimension). Padding M up to 64 to match tensor-core tiles therefore wastes most of the tensor-core computation on zeros. The paper reduces the padding to 8 and uses double buffering to hide memory latency, addressing a performance loss of more than 50% caused by zero-padding in the flat case.
3. **Heuristic dataflow with hardware resource adaptation.** Because the optimal kernel for a given matrix shape and hardware mix varies, the system selects between implementations (for example FastGEMV on CUDA cores versus tensor-core kernels) using offline profiling, avoiding the up to 50% degradation observed from a single static dataflow.
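The first technique can be illustrated with a simplified sketch in plain Python. It contrasts a standard softmax, which needs the global maximum before exponentiating, with a variant in which every chunk subtracts one shared constant phi. The variant falls back to the synchronized path when a logit is too far from phi. The function names, the fixed safe range, and how phi is supplied are illustrative assumptions rather than the paper's code. In low-precision GPU arithmetic, the usable range would be much narrower than the 60 used here.

```python
import math

def softmax_synchronized(logits):
    """Standard stable softmax: every chunk must know the global max first."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_unified_max(chunks, phi, safe_range=60.0):
    """Simplified unified-max softmax; returns (probabilities, fell_back).

    Every chunk subtracts the same constant phi instead of its own running max,
    so per-chunk exponential sums can simply be added: no rescaling and no
    synchronization. phi would be chosen ahead of time; here it is a parameter.
    If a logit lies outside [phi - safe_range, phi + safe_range], exp() could
    overflow or underflow, so the sketch falls back to the synchronized method.
    """
    flat = [x for chunk in chunks for x in chunk]
    if any(abs(x - phi) > safe_range for x in flat):
        return softmax_synchronized(flat), True
    chunk_exps = [[math.exp(x - phi) for x in chunk] for chunk in chunks]  # independent
    total = sum(sum(exps) for exps in chunk_exps)
    return [e / total for exps in chunk_exps for e in exps], False

probs, fell_back = softmax_unified_max([[0.5, 1.0], [2.0, -0.3]], phi=1.0)
print(fell_back)  # prints: False
```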

The paper reports up to 4.86x speedup on NVIDIA GPUs and up to 3.93x speedup on AMD GPUs versus the Hugging Face baseline, and an average 1.37x speedup versus state-of-the-art LLM inference engines including the original Flash-Decoding.[^7][^8]

Other related work in the long-context decoding space includes:

- tree-attention variants, such as DeFT, which adapts Flash-Decoding's split structure to tree-structured speculative decoding;
- continued kernel work in the FlashAttention family, including [FlashAttention 3](/wiki/flash_attention_3), which targets the Hopper architecture and adds asynchrony;
- integrations of the split-KV idea into the FP8 KV-cache work in vLLM.

## Significance

Flash-Decoding addressed what was, by late 2023, the principal kernel-level bottleneck for long-context LLM serving at small batch sizes: a single-query attention against a many-thousand-token KV cache leaving most of the GPU idle. By adding one parallelization axis and a tiny reduction step, the algorithm preserves the I/O-awareness and numerical exactness of FlashAttention while making the kernel scale gracefully into the regime of contexts of 100,000 tokens and beyond.[^1][^3] Together with the broader move toward grouped-query and multi-query attention (which reduce KV cache bandwidth per token) and toward block-managed KV caches such as PagedAttention (which reduce memory waste), Flash-Decoding became one of the standard building blocks of long-context inference stacks during 2023 and 2024.[^1][^3][^5]

## References

[^1]: Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov. "Flash-Decoding for long-context inference." Stanford CRFM blog, October 12, 2023. https://crfm.stanford.edu/2023/10/12/flashdecoding.html. Accessed 2026-05-19.

[^2]: Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov. "Flash-Decoding for Long-Context Inference." Princeton NLP Group blog, October 2023. https://princeton-nlp.github.io/flash-decoding/. Accessed 2026-05-19.

[^3]: Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov. "Flash-Decoding for long-context inference." PyTorch blog, October 13, 2023 (updated November 16, 2024). https://pytorch.org/blog/flash-decoding/. Accessed 2026-05-19.

[^4]: Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov. "Flash-Decoding for long-context inference." Together AI blog. https://www.together.ai/blog/flash-decoding-for-long-context-inference. Accessed 2026-05-19.

[^5]: Facebook Research. "Faster LLM inference with Flash-Decoding, Local attention," xFormers v0.0.22 release notes, September 27, 2023. https://github.com/facebookresearch/xformers/releases/tag/v0.0.22. Accessed 2026-05-19.

[^6]: vLLM Project. "[Kernel] Use flash-attn for decoding" (Pull Request #3648), merged and subsequently reverted in May 2024. https://github.com/vllm-project/vllm/pull/3648. Accessed 2026-05-19.

[^7]: Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Yuhan Dong, Yu Wang. "FlashDecoding++: Faster Large Language Model Inference on GPUs." arXiv:2311.01282, November 2, 2023 (last revised January 5, 2024). https://arxiv.org/abs/2311.01282. Accessed 2026-05-19.

[^8]: Ke Hong et al. "FlashDecoding++: Faster Large Language Model Inference on GPUs." HTML version on arXiv. https://arxiv.org/html/2311.01282. Accessed 2026-05-19.

