Flash-Decoding
Last reviewed
May 19, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v1 · 2,587 words
Flash-Decoding is an inference-time variant of the FlashAttention algorithm that targets the decoding (autoregressive generation) phase of large language model inference, where each step processes only a single query token against a potentially very long key-value cache. It was introduced in October 2023 by Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov in a blog post co-published on the Stanford CRFM, Princeton NLP, PyTorch, and Together AI websites.[^1][^2][^3][^4] Flash-Decoding extends FlashAttention's I/O-aware tiling with an additional parallelization axis along the keys and values sequence length, enabling near-full utilization of modern GPUs even at batch size 1 with long contexts. The original announcement reports up to 8x faster end-to-end token generation and up to 50x faster attention kernels at very long sequence lengths compared with FlashAttention v2.[^1][^3]
Modern transformer-based language models can be deployed with context windows of tens or hundreds of thousands of tokens. During autoregressive generation, the model produces one new token at a time, and at each step the new query token must attend to the full history of cached keys and values (the KV cache). As context lengths grew from roughly 2,000 tokens in 2022 to between 32,000 and 100,000 or more by late 2023, the attention operation came to dominate the latency budget of decoding for long-context workloads.[^1][^3]
FlashAttention v1 and v2 were designed primarily to accelerate training, where each call to the attention kernel processes a long sequence of query tokens in parallel; their parallelization scheme distributes work over the batch dimension, the head dimension, and the query length. During single-step decoding the query length is one, which collapses one of the three axes and leaves only batch and head to occupy the streaming multiprocessors of the GPU. With batch size 1, this leaves an NVIDIA A100 running FlashAttention v2 at less than 1% of its compute capacity.[^1][^3] Flash-Decoding closes this gap by parallelizing across the keys and values sequence length itself, splitting the long KV cache into chunks that are processed in parallel and then combined with a single reduction.
FlashAttention is an exact attention algorithm that fuses the matrix multiplications, scaling, masking, softmax, and dropout into a single I/O-aware kernel that streams blocks of queries, keys, and values between the GPU's high-bandwidth memory (HBM) and on-chip SRAM. By tiling along the query dimension and using an iterative online softmax, FlashAttention avoids materializing the full attention matrix in HBM. FlashAttention v2 refined the kernel by rebalancing work and parallelism across thread blocks and warps.[^1][^3]
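The online softmax at the heart of this tiling can be sketched for a single query row in pure Python. This is a minimal illustration assuming the standard 1/sqrt(d) logit scaling; it models only the arithmetic, not SRAM tiles, warps, or HBM traffic, and `block_size` is a hypothetical stand-in for a key tile.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Reference: materialize every logit, then apply softmax and weight the values."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * dot(q, k) for k in keys]
    m = max(logits)
    weights = [math.exp(s - m) for s in logits]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total
            for j in range(len(values[0]))]


def online_softmax_attention(q: Vector, keys: List[Vector], values: List[Vector],
                             block_size: int = 4) -> Vector:
    """Same result, but keys/values are visited one block at a time and only a
    running max, a running sum, and an output accumulator are kept."""
    scale = 1.0 / math.sqrt(len(q))
    m = -math.inf                  # largest logit seen so far
    l = 0.0                        # sum of exp(logit - m) seen so far
    acc = [0.0] * len(values[0])   # sum of exp(logit - m) * value seen so far
    for start in range(0, len(keys), block_size):
        logits = [scale * dot(q, k) for k in keys[start:start + block_size]]
        m_new = max(m, max(logits))
        correction = math.exp(m - m_new)   # rescale the old state to the new max
        l *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(logits, values[start:start + block_size]):
            p = math.exp(s - m_new)
            l += p
            acc = [a + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]
```

Whatever `block_size` is used, the result matches the reference up to floating-point rounding, which is why the tiled kernel is exact rather than approximate.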
In both versions the natural unit of parallelism is a tile of query tokens. A typical kernel launch distributes work over the cross product of (batch, head, query-tile) thread blocks. During training, where queries are long, this fully occupies the GPU. During inference, the picture splits in two. The prefill phase still processes a long query (the prompt), so query-tile parallelism continues to keep the GPU busy under FlashAttention. The decoding phase, by contrast, processes only one new query token per step, so the number of query tiles drops to one per (batch, head) combination.[^1][^3]
A modern GPU contains many streaming multiprocessors (SMs); an NVIDIA A100, for example, has 108 SMs. To saturate the device, the kernel needs enough independent thread blocks to occupy all SMs at once. Consider decoding at batch size 1 with the attention configuration used in the original Flash-Decoding micro-benchmarks: 16 query heads of head dimension 128 with grouped-query attention over 2 key/value heads. The kernel then launches only on the order of 16 thread blocks, well below what the device can run in parallel. Decoding is dominated by reading the KV cache from HBM, and because only a handful of SMs are actively streaming that cache, even the available memory bandwidth goes largely unused unless the work is also parallelized across the long KV-cache dimension.[^1][^3] The blog post quantifies this gap with the observation that FlashAttention v2 "uses less than 1% of the GPU" at batch size 1 in long-context decoding scenarios.[^1]
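The block-count arithmetic can be sketched as follows. It assumes a simplified launch model (one thread block per batch element, query head, and query tile, multiplied by the number of KV splits); `block_q`, `num_splits`, and the 4,096-token prefill are illustrative assumptions, not parameters taken from FlashAttention's source.

```python
import math

NUM_SMS = 108  # NVIDIA A100


def thread_blocks(batch: int, q_heads: int, query_len: int,
                  block_q: int = 128, num_splits: int = 1) -> int:
    """Simplified grid-size model: one thread block per (batch, head, query
    tile), times the number of KV-cache splits (1 = no splitting)."""
    query_tiles = math.ceil(query_len / block_q)
    return batch * q_heads * query_tiles * num_splits


prefill = thread_blocks(batch=1, q_heads=16, query_len=4096)                  # 512 blocks
decode = thread_blocks(batch=1, q_heads=16, query_len=1)                      # 16 blocks
decode_split = thread_blocks(batch=1, q_heads=16, query_len=1, num_splits=8)  # 128 blocks

for name, blocks in [("prefill", prefill), ("decode", decode),
                     ("decode, 8 splits", decode_split)]:
    # Number of SMs that receive at least one block; real kernels may also
    # co-schedule several blocks on the same SM.
    print(f"{name}: {blocks} blocks, work on {min(blocks, NUM_SMS)}/{NUM_SMS} SMs")
```

In this model plain decoding leaves most SMs without any work, while splitting the KV cache into a handful of chunks is enough to give every SM at least one block.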
This bottleneck matters for practical deployments. Long-document summarization, large-codebase code completion, multi-turn chat with long history, and retrieval-augmented generation all stress the decoding phase with long KV caches. As the cache grows, attention latency at each generation step grows roughly linearly under naive implementations, becoming a significant fraction of end-to-end token latency.[^1][^3]
Flash-Decoding adds a new parallelization axis to FlashAttention: the keys and values sequence length. Conceptually, the long KV cache is sliced into several chunks (the blog calls these "splits"), each chunk is attended to in parallel by an independent thread block, and a small reduction step at the end combines the partial outputs into the exact attention result.[^1][^3]
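The blog post describes the algorithm in three steps.[^1][^3] First, the keys and values are split into smaller chunks. Second, the attention of the query against each chunk is computed in parallel using FlashAttention, and each thread block also writes one extra scalar per row and per split: the log-sum-exp of the attention logits in that split. Finally, the exact output is computed by reducing over all splits, using the log-sum-exp values to scale the contribution of each split.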
The reduction exploits the standard property of online softmax (which FlashAttention already uses internally to combine partial sums along the keys axis within a single thread block). Flash-Decoding lifts that same trick one level outward: instead of one thread block streaming through all keys, multiple thread blocks each stream through a chunk and the cross-chunk combination is performed at the end.[^1][^3]
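A minimal pure-Python sketch of the split, attend, and reduce structure for one query row follows. It assumes 1/sqrt(d) logit scaling and contiguous, nearly equal chunks; the function names are hypothetical, and the loop over chunks stands in for the independent thread blocks a GPU kernel would launch. It illustrates the arithmetic only and is not the FlashAttention or xFormers code.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def attend_chunk(q: Vector, keys: List[Vector],
                 values: List[Vector]) -> Tuple[Vector, float]:
    """Step 2: attention of one query against one chunk of the KV cache.
    Returns the chunk-normalized output and the chunk's log-sum-exp (LSE)."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * dot(q, k) for k in keys]
    m = max(logits)
    weights = [math.exp(s - m) for s in logits]
    total = sum(weights)
    out = [sum(w * v[j] for w, v in zip(weights, values)) / total
           for j in range(len(values[0]))]
    return out, m + math.log(total)


def combine(partials: List[Tuple[Vector, float]]) -> Vector:
    """Step 3: reduce the per-split outputs, weighting each by exp(LSE_j - LSE)."""
    lse_max = max(lse for _, lse in partials)
    global_lse = lse_max + math.log(sum(math.exp(lse - lse_max) for _, lse in partials))
    dim = len(partials[0][0])
    out = [0.0] * dim
    for o, lse in partials:
        w = math.exp(lse - global_lse)  # this split's share of the full softmax mass
        for j in range(dim):
            out[j] += w * o[j]
    return out


def split_kv_decode(q: Vector, keys: List[Vector], values: List[Vector],
                    num_splits: int) -> Vector:
    """Step 1: split the KV cache into contiguous chunks. On a GPU each chunk
    would go to its own thread block; here they simply run in a loop."""
    n = len(keys)
    chunk = math.ceil(n / num_splits)
    partials = [attend_chunk(q, keys[s:s + chunk], values[s:s + chunk])
                for s in range(0, n, chunk)]
    return combine(partials)
```

Because each chunk returns only a normalized output row and one LSE scalar, the combine step never needs the chunk's individual logits; with `num_splits=1` it reduces to ordinary softmax attention.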
The attention output is softmax(Q K^T / sqrt(d)) V, where d is the head dimension. If the keys (and matching values) are partitioned into disjoint chunks, the softmax over the concatenated logits can be reconstructed exactly from the softmax within each chunk together with that chunk's maximum logit and sum of exponentials. This is the same numerically stable identity that the original online-softmax formulation relies on. Storing one extra scalar per row per chunk, the log-sum-exp, is sufficient to perform the cross-chunk rescaling without ever materializing the full attention matrix.[^1][^3]
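Written out for a single query row, the identity looks as follows. The notation is introduced here for illustration: s_i are the logits with the usual scaled-dot-product factor 1/sqrt(d), the N key indices are partitioned into chunks C_1 through C_S, and O_j and LSE_j are the per-chunk output and log-sum-exp.

```latex
\begin{aligned}
s_i &= \frac{q \cdot k_i}{\sqrt{d}}, \qquad
m_j = \max_{i \in C_j} s_i, \qquad
\ell_j = \sum_{i \in C_j} e^{s_i - m_j}, \qquad
\mathrm{LSE}_j = m_j + \log \ell_j, \\
O_j &= \frac{1}{\ell_j} \sum_{i \in C_j} e^{s_i - m_j}\, v_i, \\
\mathrm{LSE} &= \log \sum_{j=1}^{S} e^{\mathrm{LSE}_j} = \log \sum_{i=1}^{N} e^{s_i}, \\
O &= \sum_{j=1}^{S} e^{\mathrm{LSE}_j - \mathrm{LSE}}\, O_j
   = \frac{\sum_{i=1}^{N} e^{s_i}\, v_i}{\sum_{i=1}^{N} e^{s_i}}.
\end{aligned}
```

The last equality holds because e^{LSE_j} = e^{m_j} ℓ_j equals the chunk's sum of e^{s_i}, so each rescaled partial output contributes exactly its chunk's share of the full softmax numerator. In practice LSE is itself evaluated with the same max-subtraction trick to avoid overflow.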
The extra memory traffic introduced by Flash-Decoding is small: one float per (row, chunk) for the log-sum-exp, plus the partial attention outputs themselves, which are written to HBM and then read back during the reduction. For long contexts, this overhead is dwarfed by the reduction in idle SMs and by the much higher achieved memory bandwidth on the KV-cache read.[^1][^3]
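To put rough numbers on this, the sketch below compares the per-layer KV-cache read for one decoding step with the extra write and read-back of partial outputs and log-sum-exps, using the micro-benchmark shape described above (16 query heads, 2 KV heads, head dimension 128) at 65,536 tokens. The fp16 KV cache, fp32 partial results, and 64 splits are assumptions chosen for illustration, and the read of the query itself is ignored.

```python
def decode_step_bytes(seq_len: int, q_heads: int, kv_heads: int, head_dim: int,
                      num_splits: int, kv_bytes: int = 2,
                      partial_bytes: int = 4) -> dict:
    """Per-layer HBM traffic for one decoding step of one sequence.

    kv_bytes=2 models an fp16 KV cache; partial_bytes=4 assumes partial outputs
    and LSEs are kept in fp32 (an assumption, not a documented detail)."""
    kv_cache = 2 * seq_len * kv_heads * head_dim * kv_bytes         # read K and V once
    partial_out = num_splits * q_heads * head_dim * partial_bytes   # one row per head per split
    lse = num_splits * q_heads * partial_bytes                      # one scalar per head per split
    extra = 2 * (partial_out + lse)                                  # written, then read back
    return {"kv_cache": kv_cache, "extra": extra, "overhead": extra / kv_cache}


stats = decode_step_bytes(seq_len=65_536, q_heads=16, kv_heads=2,
                          head_dim=128, num_splits=64)
print(stats)
```

Under these assumptions the extra traffic is about 1.6% of the KV-cache read, and it grows with the number of splits rather than with the context length.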
The number of splits is chosen by a launch-time heuristic that balances two pressures. With too few splits, the kernel still does not fill the GPU. With too many, each partial attention computation becomes too small to amortize per-block launch and bookkeeping overhead, and the reduction step itself becomes a meaningful fraction of the work. The xFormers and FlashAttention implementations both expose this as a tunable parameter and select a default based on context length, head count, and device characteristics.[^1][^3][^5]
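The trade-off can be illustrated with a deliberately simple, hypothetical heuristic. The function `choose_num_splits` and its parameters `min_keys_per_split` and `max_splits` are invented for this sketch and do not reproduce the heuristics actually shipped in FlashAttention or xFormers.

```python
import math


def choose_num_splits(batch: int, q_heads: int, seq_len: int, num_sms: int,
                      min_keys_per_split: int = 256, max_splits: int = 128) -> int:
    """Hypothetical heuristic: use the fewest splits that give every SM at
    least one thread block, but never make a split shorter than
    min_keys_per_split keys, and cap the total at max_splits."""
    base_blocks = batch * q_heads                   # blocks without any splitting
    wanted = math.ceil(num_sms / base_blocks)       # splits needed to cover all SMs
    limit = max(1, seq_len // min_keys_per_split)   # keep each split large enough
    return max(1, min(wanted, limit, max_splits))
```

With batch size 1 and 16 query heads on a 108-SM device, this sketch picks 7 splits at 65,536 tokens (112 blocks), only 2 at 512 tokens because of the minimum split length, and 1 once the batch alone supplies enough blocks.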
Because Flash-Decoding only changes how the keys/values dimension is sliced, it composes naturally with techniques that affect how the KV cache is laid out in memory, such as grouped-query attention and multi-query attention, where the number of KV heads is smaller than the number of query heads. The original release added Flash-Decoding alongside experimental support for multi-query and grouped-query attention in xFormers v0.0.22.[^5] The blog also notes that the same parallel structure applies whether the KV cache is stored contiguously or as a more complex paged structure.[^1][^3]
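The sketch below illustrates, with hypothetical helper names, the two index mappings a split-KV kernel has to perform: which KV head a query head reads under grouped-query attention, and where a split's logical token range lives when the cache is paged through a per-sequence block table (the layout style associated with PagedAttention). None of these helpers is an API of xFormers, FlashAttention, or vLLM.

```python
from typing import List, Tuple


def kv_head_for(q_head: int, q_heads: int, kv_heads: int) -> int:
    """GQA/MQA: consecutive groups of q_heads // kv_heads query heads share one KV head."""
    assert q_heads % kv_heads == 0
    return q_head // (q_heads // kv_heads)


def split_ranges(seq_len: int, num_splits: int) -> List[Tuple[int, int]]:
    """Logical [start, end) token ranges handled by each split."""
    chunk = -(-seq_len // num_splits)  # ceiling division
    return [(s, min(s + chunk, seq_len)) for s in range(0, seq_len, chunk)]


def physical_slots(start: int, end: int, block_table: List[int],
                   block_size: int) -> List[Tuple[int, int]]:
    """Paged cache: map each logical token t to (physical block, offset) via a
    block table. A contiguous cache is the special case block_table = [0, 1, 2, ...]."""
    return [(block_table[t // block_size], t % block_size) for t in range(start, end)]
```

Splitting only changes which logical range each thread block covers; the head mapping and the block-table lookup are the same whether or not the cache is split, which is why the technique composes with both layouts.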
The blog post credits Erich Elsen, Ashish Vaswani, and Michaël Benesty for originally suggesting the idea of splitting the KV cache loading across thread blocks. Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov implemented and benchmarked the technique and published the algorithm and code.[^1][^3]
The original announcement presents both end-to-end and kernel-level benchmarks on an NVIDIA A100 in fp16.[^1][^3]
End-to-end measurements are reported for CodeLlama 34B at batch size 1 across context lengths from 512 to 64,000 tokens.[^1] At short contexts FlashAttention v2 and Flash-Decoding perform roughly the same, because there the attention kernel is not the bottleneck and parallelizing across keys yields little benefit. As the context grows, FlashAttention v2's per-step attention time grows roughly linearly while Flash-Decoding's stays nearly flat. At 64,000 tokens the blog reports up to 8x faster end-to-end token generation than the FlashAttention v2 baseline.[^1][^3]
For the attention kernel in isolation the blog presents micro-benchmarks with the configuration of 16 query heads at 128 head dimension with 2 key/value heads at batch size 1, sweeping sequence length from 256 to 131,072.[^1][^3] At 256 tokens, FlashAttention v2 takes 390.5 microseconds and Flash-Decoding takes 63.4 microseconds. At 4,096 tokens, FlashAttention v2 takes 401.7 microseconds versus 57 microseconds for Flash-Decoding. The gap widens dramatically at extreme lengths: at 65,536 tokens FlashAttention v2 takes 2,300.6 microseconds against 64.4 microseconds for Flash-Decoding, and at 131,072 tokens FlashAttention v2 takes 4,592.2 microseconds versus 106.6 microseconds for Flash-Decoding. The blog summarizes that the attention kernel itself can be up to 50x faster than FlashAttention v2 at the longest contexts tested.[^1][^3]
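The ratios implied by the quoted kernel timings can be recomputed directly. The sketch below only divides the numbers given in the preceding paragraph and introduces no new measurements.

```python
from typing import Dict

# (sequence length, FlashAttention v2 time in microseconds, Flash-Decoding time
# in microseconds), copied from the micro-benchmark figures quoted above.
REPORTED = [
    (256, 390.5, 63.4),
    (4_096, 401.7, 57.0),
    (65_536, 2_300.6, 64.4),
    (131_072, 4_592.2, 106.6),
]


def speedups() -> Dict[int, float]:
    """Kernel-time ratio FlashAttention v2 / Flash-Decoding for each quoted row."""
    return {seq_len: fa2 / fd for seq_len, fa2, fd in REPORTED}


for seq_len, ratio in speedups().items():
    print(f"{seq_len:>7} tokens: {ratio:.1f}x")
```

For these four rows the ratio grows from roughly 6x at 256 tokens to roughly 43x at 131,072 tokens.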
A useful way to read these numbers is that Flash-Decoding's runtime is nearly independent of context length through 65,536 tokens (between 57 and 64.4 microseconds in the quoted measurements) and grows noticeably only at 131,072 tokens. While the GPU still has idle capacity, a longer KV cache is absorbed by launching more splits; once the device is saturated, additional KV-cache reads translate into roughly linear growth in runtime.[^1][^3]
Flash-Decoding shipped in the FlashAttention library starting at version 2.2, exposed for inference through the flash_attn_with_kvcache function; the same repository continues to evolve the technique alongside the rest of the FlashAttention family.[^1][^3]
The Meta-maintained xFormers library released Flash-Decoding in version 0.0.22 (tag published on September 27, 2023, with the public Flash-Decoding announcement following on October 12 to 13, 2023).[^5] The implementation is exposed through xformers.ops.memory_efficient_attention, with a dispatcher that automatically selects either the Flash-Decoding kernel or the standard FlashAttention kernel based on problem size, falling back to an efficient Triton kernel that implements the Flash-Decoding algorithm when the CUDA path is unsupported.[^1][^3][^5] The same v0.0.22 release added experimental support for multi-head self-attention variants such as multi-query and grouped-query attention, a local attention bias, an efficient Triton implementation of rotary position embeddings, and a minimal LLaMa inference example demonstrating the decoding path end-to-end.[^5]
The vLLM inference engine, which is built around PagedAttention for managing KV cache memory at block granularity, did not initially use the Flash-Decoding kernel; its decoding path used a custom kernel in csrc/attention/attention_kernels.cu compatible with the paged KV layout. vLLM later began integrating FlashAttention as an alternative decoding backend; pull request #3648 ("[Kernel] Use flash-attn for decoding") was opened by skrider in 2024 and ultimately merged in May 2024 before being reverted shortly afterward due to a memory-access issue in the LoRA test suite.[^6] Subsequent vLLM releases added stable FlashAttention-family backends alongside the original PagedAttention kernel.
In November 2023, a separate group published FlashDecoding++ ("FlashDecoding++: Faster Large Language Model Inference on GPUs", arXiv:2311.01282) by Ke Hong, Guohao Dai, Jiaming Xu, Qiuli Mao, Xiuhong Li, Jun Liu, Kangdi Chen, Yuhan Dong, and Yu Wang, with author affiliations to Tsinghua University, Shanghai Jiao Tong University, Peking University, and Infinigence-AI.[^7][^8] Despite the suggestive name, FlashDecoding++ is best understood as an independent line of work rather than a direct successor; it targets several inefficiencies in LLM inference on both NVIDIA and AMD GPUs and is sometimes compared with the original Flash-Decoding as a baseline.
The paper identifies three problems and proposes one fix for each.[^7][^8]
The paper reports up to 4.86x speedup on NVIDIA GPUs and up to 3.93x speedup on AMD GPUs versus the Hugging Face baseline, and an average 1.37x speedup versus state-of-the-art LLM inference engines including the original Flash-Decoding.[^7][^8]
Other related work in the long-context decoding space includes tree-attention variants (such as DeFT, which adapts Flash-Decoding's split structure to tree-structured speculative decoding), continued kernel work in the FlashAttention family including FlashAttention 3 which targets the Hopper architecture and adds asynchrony, and integrations of the split-KV idea into the FP8 KV-cache work in vLLM.
Flash-Decoding addressed what was, by late 2023, the principal kernel-level bottleneck for long-context LLM serving at small batch sizes: a single-query attention against a many-thousand-token KV cache leaving most of the GPU idle. By adding one parallelization axis and a tiny reduction step, the algorithm preserves the I/O-awareness and numerical exactness of FlashAttention while making the kernel scale gracefully into the regime of contexts of 100,000 tokens and beyond.[^1][^3] Together with the broader move toward grouped-query and multi-query attention (which reduce KV cache bandwidth per token) and toward block-managed KV caches such as PagedAttention (which reduce memory waste), Flash-Decoding became one of the standard building blocks of long-context inference stacks during 2023 and 2024.[^1][^3][^5]