Flash Attention
Last reviewed
May 9, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v8 · 6,149 words
See also: attention, transformer, self-attention, GPU computing, softmax, Tri Dao, Christopher Ré, H100, Hopper architecture, PyTorch, Triton, vLLM, Together AI
Flash Attention is a family of IO-aware, exact attention algorithms designed to accelerate and reduce the memory footprint of the attention mechanism in transformer models. The original Flash Attention algorithm was introduced by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré at Stanford University in May 2022. Rather than approximating attention (as many prior methods attempted), Flash Attention computes mathematically exact attention while reordering the computation to minimize memory reads and writes between different levels of GPU memory.
The core insight behind Flash Attention is that standard attention implementations are bottlenecked not by arithmetic computation but by memory access. Modern GPUs can perform far more floating-point operations per second than they can read and write data from their main memory. By restructuring the attention computation to account for the GPU memory hierarchy, Flash Attention achieves wall-clock speedups of 2 to 7.6 times over standard attention while reducing memory usage from quadratic to linear in the sequence length.
Flash Attention has been widely adopted across the deep learning ecosystem. It is integrated into PyTorch as part of the scaled_dot_product_attention function, supported natively in Hugging Face Transformers, and used in the training and inference of models such as LLaMA, LLaMA 3, Mistral, Mixtral, Gemma, Phi, Qwen, DBRX, StarCoder, and many others. The algorithm has gone through four major versions: FlashAttention (2022), FlashAttention-2 (2023), FlashAttention-3 (2024), and FlashAttention-4 (2025), each targeting newer GPU hardware and achieving higher utilization. The official open-source implementation is maintained at the Dao-AILab/flash-attention repository on GitHub and was awarded the inaugural Stanford Open Source Software Prize in 2024.
Imagine you have a huge jigsaw puzzle spread across a large table in the garage (that is far away), and a small desk in front of you. The standard way to solve the puzzle is to keep running back and forth to the garage table to grab pieces, look at them, and run back again. You do this thousands of times because you can only carry a few pieces at once.
Flash Attention is like being smarter about it. Instead of running back and forth for every single piece, you grab a small group of pieces, bring them to your desk, figure out how they fit together, write down your progress, and then go get the next group. You still solve the exact same puzzle with the exact same answer, but you make far fewer trips to the garage. The small desk is the fast memory on a GPU (called SRAM), and the garage table is the large but slow main memory (called HBM). Flash Attention makes fewer trips between them, so everything goes faster.
The self-attention mechanism in transformers computes pairwise interactions between all positions in a sequence. Given a sequence of length N, the standard attention algorithm materializes an N x N attention score matrix, applies softmax normalization, and then uses the result to compute a weighted sum of value vectors. This leads to O(N^2) time and memory complexity, making it the primary bottleneck when scaling transformers to long sequences.
The quadratic memory cost is particularly punishing. For a sequence of 32,768 tokens with 32 attention heads at 16-bit precision, a single N x N attention matrix takes 2 GiB per head, or 64 GiB per layer for one batch element, which exceeds a 40 GB A100 and takes up most of an 80 GB one. Prior to Flash Attention, this bottleneck forced researchers to truncate context length, reduce batch size, or rely on approximate attention methods that traded model quality for memory savings.
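The standard multi-head attention computation proceeds as follows for each head, given query, key, and value matrices Q, K, and V of shape N x d: compute the score matrix S = Q K^T / sqrt(d), apply a row-wise softmax to obtain the attention weights P = softmax(S), and compute the output O = P V.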
The intermediate matrices S and P are each of size N x N. For a sequence length of 4,096 with 16-bit floating-point precision, a single attention head's S matrix alone consumes 32 MB. With multiple heads and batch elements, this adds up quickly, often exceeding available GPU memory for longer sequences.
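The sizes quoted above follow from simple arithmetic. A minimal sketch that counts the bytes needed to materialize one N x N matrix for every head (the helper name `score_matrix_bytes` is hypothetical and chosen only for illustration):

```python
def score_matrix_bytes(seq_len: int, num_heads: int = 1, bytes_per_element: int = 2) -> int:
    """Bytes needed to hold one N x N matrix (S or P) for every head."""
    return seq_len * seq_len * bytes_per_element * num_heads


MIB = 2**20
GIB = 2**30

# One head, N = 4,096, 16-bit elements.
print(score_matrix_bytes(4096) / MIB)                  # 32.0

# 32 heads, N = 32,768, 16-bit elements, one batch element.
print(score_matrix_bytes(32768, num_heads=32) / GIB)   # 64.0
```

Standard attention holds both S and P, so the real footprint is roughly double these figures, multiplied again by the batch size.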
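In PyTorch:
# Standard attention (PyTorch); shapes are illustrative
import math
import torch
N, d = 1024, 64
Q, K, V = torch.randn(3, N, d).unbind(0)    # three (N, d) matrices
S = Q @ K.transpose(-2, -1) / math.sqrt(d)  # write S to HBM, O(N^2)
P = torch.softmax(S, dim=-1)                # read S, write P, O(N^2)
O = P @ V                                   # read P, write O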
Each of the three operations above launches at least one separate GPU kernel and makes a full pass over an N x N matrix in HBM. The intermediate matrices S and P never need to be inspected by the user; they exist only to feed the next operation.
Modern GPUs have a hierarchical memory system with dramatically different capacities and bandwidths at each level:
| Memory level | Capacity (A100) | Bandwidth (A100) | Relative speed |
|---|---|---|---|
| HBM (high bandwidth memory) | 40 to 80 GB | 1.5 to 2.0 TB/s | 1x (baseline) |
| SRAM (on-chip, per SM) | 192 KB per SM (about 20 MB total across 108 SMs) | about 19 TB/s | about 10x faster |
| Registers | about 256 KB per SM | even higher | fastest |
Standard attention writes the full N x N matrices S and P to HBM, then reads them back, resulting in a large number of slow memory accesses. Flash Attention's key contribution is recognizing that these intermediate matrices do not need to be materialized in HBM at all. Instead, they can be computed incrementally inside SRAM and discarded after use, saving both memory capacity and bandwidth.
The H100 introduced larger SRAM (228 KB per SM) and the Tensor Memory Accelerator (TMA), and the Blackwell B200 added a separate Tensor Memory (TMEM) of 256 KB per SM. Each generation increases the relative gap between on-chip SRAM bandwidth and off-chip HBM bandwidth, making IO-aware algorithms like Flash Attention more important over time, not less.
The standard attention algorithm performs O(N^2 d) floating-point operations. Its memory access (IO) complexity is O(N d + N^2), where the N^2 term comes from reading and writing the attention matrices S and P.
Flash Attention performs the same O(N^2 d) floating-point operations (and sometimes slightly more due to recomputation), but its IO complexity is O(N^2 d^2 / M), where M is the SRAM size. For typical values where d is between 64 and 128 and M is around 100 KB, this works out to significantly fewer HBM accesses. Dao et al. proved that this IO complexity is optimal: no algorithm computing exact attention can asymptotically reduce the number of HBM accesses below this bound for all values of M.
Since modern GPUs are often memory-bandwidth-bound rather than compute-bound during attention (especially at moderate sequence lengths), reducing memory accesses yields direct wall-clock speedups even though the total FLOP count may be slightly higher.
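To get a feel for the two IO bounds, the sketch below plugs sample values into the asymptotic expressions. It ignores the constant factors that big-O notation hides, so the result is an order-of-magnitude indication only. The function names are hypothetical, and M is counted in elements under the assumption that 100 KB of SRAM holds 2-byte values.

```python
def standard_io(n: int, d: int) -> float:
    """Asymptotic HBM accesses (in elements) for standard attention: N*d + N^2."""
    return n * d + n * n


def flash_io(n: int, d: int, sram_elements: int) -> float:
    """Asymptotic HBM accesses (in elements) for Flash Attention: N^2 * d^2 / M."""
    return n * n * d * d / sram_elements


n, d = 4096, 64
m = 100 * 1024 // 2   # assumption: 100 KB of SRAM, 2 bytes per element
ratio = standard_io(n, d) / flash_io(n, d, m)
print(f"standard: {standard_io(n, d):.3g}  flash: {flash_io(n, d, m):.3g}  ratio: {ratio:.1f}x")
```

Because the Flash Attention bound grows with d^2 / M, the advantage shrinks as the head dimension grows or the usable SRAM shrinks.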
Flash Attention was not the first effort to reduce attention's memory footprint. Approximate methods such as Reformer (Kitaev et al., 2020), Performer (Choromanski et al., 2021), Linformer (Wang et al., 2020), and Longformer (Beltagy et al., 2020) reduced the asymptotic complexity but at a cost in model quality and often without delivering wall-clock speedups on GPUs. A separate line of work by Markus Rabe and Charles Staats ("Self-attention Does Not Need O(n^2) Memory", 2021) showed that exact attention could be computed with O(log N) auxiliary memory, but the proposed implementation did not yield substantial speedups in practice. Flash Attention combined the recomputation idea from this line with the online softmax of Milakov and Gimelshein (2018), and added the crucial step of fusing all operations into a single GPU kernel that is fully aware of the SRAM/HBM split.
Flash Attention avoids materializing the full N x N attention matrix by splitting Q, K, and V into blocks that fit into SRAM and computing attention one block at a time. The algorithm fuses all of the attention operations (matrix multiply, softmax, another matrix multiply) into a single GPU kernel, eliminating intermediate reads and writes to HBM.
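The tiling proceeds as follows: Q is split into blocks of B_r rows, and K and V are split into blocks of B_c rows. For each pairing of a Q block with a K/V block, the kernel loads the blocks from HBM into SRAM, computes the B_r x B_c block of scores, updates the running softmax statistics and the output accumulator for those rows, and discards the score block. Once a Q block has been combined with every K/V block, its rows of the output O are final.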
Block sizes B_r and B_c are chosen to ensure that the Q block, K block, V block, and the intermediate score block all fit within the available SRAM. Typical choices on A100 are B_r = B_c = 128 for head dimension 64.
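A rough way to check that a block-size choice fits is to add up the tile sizes. The sketch below is a simplified budget, assuming 2-byte elements and one copy of each tile; real kernels also keep FP32 accumulators, often in registers, and may double-buffer loads, so the true requirement differs. The helper name `tile_bytes` is hypothetical.

```python
def tile_bytes(b_r: int, b_c: int, d: int, bytes_per_element: int = 2) -> dict:
    """Bytes per tile for one (Q block, K block, V block) step."""
    return {
        "Q": b_r * d * bytes_per_element,
        "K": b_c * d * bytes_per_element,
        "V": b_c * d * bytes_per_element,
        "S": b_r * b_c * bytes_per_element,
    }


tiles = tile_bytes(b_r=128, b_c=128, d=64)
total_kib = sum(tiles.values()) / 1024
print(tiles, total_kib)

# The A100 figure quoted in the memory table above is 192 KB per SM.
assert total_kib <= 192
```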
A major challenge in tiling attention is the softmax normalization, which in its standard form requires knowing the maximum value and the sum of exponentials across the entire row of the score matrix. Since Flash Attention processes the score matrix in blocks, it cannot see the full row at once.
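Flash Attention solves this using an online softmax algorithm based on the work of Maxim Milakov and Natalia Gimelshein (NVIDIA, 2018). The technique maintains two running statistics for each row: m, the maximum score seen so far, and l, the sum of the exponentials exp(x_j - m) over the scores seen so far.
When a new block of scores arrives, the algorithm computes the new maximum m_new = max(m_old, block maximum), rescales l and the accumulated output by exp(m_old - m_new), and then adds the new block's exponentials exp(x_j - m_new) to l and the correspondingly weighted value rows to the output.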
Mathematically, the rescaling exploits the identity:
exp(x_i - m_new) / Sum_j exp(x_j - m_new) = [exp(x_i - m_old) exp(m_old - m_new)] / [Sum_j exp(x_j - m_old) exp(m_old - m_new)]
The correction factors cancel between numerator and denominator, so the result is identical to computing softmax over the entire row at once, while subtracting the running maximum keeps every exponent at or below zero and so avoids overflow. Because exp(a - b) exp(b - c) = exp(a - c), successive correction factors compose into a single factor relative to the latest maximum, regardless of the order in which blocks are processed.
A cleaner derivation is given in Zihao Ye's note "From Online Softmax to FlashAttention" (University of Washington, 2023), which is a standard reference among practitioners.
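The tiling and online softmax described above can be sketched in plain Python. This is an illustrative reference, not the CUDA kernel: it uses nested lists, a hypothetical `block` parameter in place of B_c, and processes one query row at a time. It then checks that the blockwise result matches attention computed from fully materialized score rows.

```python
import math
import random


def naive_attention(q, k, v):
    """Reference: materialize every full score row, then softmax and weight V."""
    d = len(q[0])
    out = []
    for qi in q:
        s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        l = sum(p)
        out.append([sum(p[j] * v[j][c] for j in range(len(v))) / l
                    for c in range(len(v[0]))])
    return out


def tiled_attention(q, k, v, block=3):
    """Process K/V in blocks, keeping only m, l and an output accumulator per row."""
    d = len(q[0])
    out = []
    for qi in q:
        m = -math.inf              # running row maximum
        l = 0.0                    # running sum of exp(score - m)
        acc = [0.0] * len(v[0])    # unnormalized output accumulator
        for start in range(0, len(k), block):
            kb, vb = k[start:start + block], v[start:start + block]
            s = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in kb]
            m_new = max(m, max(s))
            scale = math.exp(m - m_new)          # correction for earlier blocks
            p = [math.exp(x - m_new) for x in s]
            l = l * scale + sum(p)
            acc = [a * scale + sum(pj * vj[c] for pj, vj in zip(p, vb))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])         # normalize once at the end
    return out


random.seed(0)
N, d = 8, 4
q, k, v = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(N)] for _ in range(3))
ref, tiled = naive_attention(q, k, v), tiled_attention(q, k, v)
assert all(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
           for r, t in zip(ref, tiled) for a, b in zip(r, t))
```

Only `m`, `l`, and `acc` persist across blocks, so the working set per row is O(d) no matter how long the sequence is.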
In the forward pass, Flash Attention does not store the N x N attention matrix P. During backpropagation, the gradients with respect to Q, K, and V require access to P. Rather than storing P (which would negate the memory savings), Flash Attention recomputes it block by block during the backward pass from Q, K, V and the saved softmax normalization statistics (m and l, which can be combined into a single logsumexp value L = m + log l per row).
This recomputation trades additional FLOPs for reduced memory. The key observation is that even with the extra arithmetic, the backward pass is faster in wall-clock time because it avoids the costly HBM reads of the large attention matrix. Beyond the inputs Q, K, and V, the only values saved from the forward pass are the final output O (size N x d) and the logsumexp statistic L per row (size N), giving a total memory footprint of O(N) instead of O(N^2).
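The reason one number per row suffices is that P_ij = exp(S_ij - L_i), where L_i is the row's logsumexp. A minimal sketch for a single row follows. The function names are hypothetical, and the scores are given directly here, whereas the real kernel recomputes them from Q and K blocks.

```python
import math


def forward_row(scores):
    """Forward pass for one row: return softmax probabilities and the logsumexp L."""
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    lse = m + math.log(l)           # the only per-row statistic that is saved
    probs = [math.exp(s - m) / l for s in scores]
    return probs, lse


def recompute_block(score_block, lse):
    """Backward pass: rebuild any block of P from its scores and the saved L."""
    return [math.exp(s - lse) for s in score_block]


scores = [0.5, -1.2, 3.0, 0.1, 2.2, -0.7]
probs, lse = forward_row(scores)

# Rebuild columns 2..4 of P without touching the rest of the row.
block = recompute_block(scores[2:5], lse)
assert all(math.isclose(a, b, rel_tol=1e-9) for a, b in zip(block, probs[2:5]))
```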
A critical implementation detail is that the entire attention computation runs in a single CUDA kernel. In standard PyTorch, each of Q @ K^T, softmax, and ... @ V launches a separate kernel and writes its result to HBM. Flash Attention fuses these into one kernel that holds intermediate results in registers and SRAM. Kernel fusion eliminates the HBM round-trips for S and P and avoids three separate kernel launch overheads.
| Property | Standard attention | Flash Attention |
|---|---|---|
| Computation | Exact | Exact (identical output up to FP rounding) |
| Time complexity (FLOPs) | O(N^2 d) | O(N^2 d) (same, slightly more due to recomputation) |
| Memory complexity | O(N^2) | O(N) |
| HBM IO complexity | O(N d + N^2) | O(N^2 d^2 / M) |
| Intermediate matrices in HBM | S and P (N x N each) | None |
| Memory savings (seq len 2K) | Baseline | about 10x reduction |
| Memory savings (seq len 4K) | Baseline | about 20x reduction |
| Wall-clock speedup (attention only) | Baseline | 2 to 7.6x faster |
| Backward pass storage | O(N^2) for P | O(N) for (O, L) |
| Number of kernel launches | 3 or more | 1 |
The original 2022 paper reported a 15% end-to-end wall-clock speedup on BERT-large (sequence length 512) over the MLPerf 1.1 record, a 3x speedup on GPT-2 at sequence length 1K, and a 2.4x speedup on the Long-Range Arena benchmark at sequence lengths 1K to 4K. Quality also improved: the paper reported a 0.7-point perplexity reduction on GPT-2 language modeling, a 6.4-point lift on long-document classification, and the first Transformer models to achieve better-than-chance accuracy on the Path-X (16K) and Path-256 (64K) tasks.
FlashAttention-2, published by Tri Dao in July 2023, improved on the original algorithm with three optimizations aimed at closing the gap between Flash Attention's throughput and the theoretical peak of the GPU hardware. The paper reported approximately 2x speedup over the original FlashAttention.
The original Flash Attention spent a meaningful fraction of its time on non-matrix-multiplication operations such as softmax rescaling, bound checking, and causal masking. These operations run on general-purpose CUDA cores rather than specialized Tensor Cores, which are 4 to 16 times faster for matrix operations on the A100 (and a much larger ratio on newer hardware). FlashAttention-2 restructured the algorithm to defer the final softmax rescaling to the end of the loop, reducing the number of intermediate rescaling steps and allowing the inner loop to be dominated by Tensor Core matmuls.
The original Flash Attention parallelized work across the batch dimension and the number of attention heads. When the batch size or head count was small (common with long sequences), many GPU streaming multiprocessors (SMs) sat idle. FlashAttention-2 added parallelism over the sequence length dimension, distributing different blocks of Q to different thread blocks. This improved GPU occupancy, especially for long-context workloads with small batch sizes.
Within a thread block, the original Flash Attention used a "split-K" scheme: it distributed K and V across warps while keeping Q shared. This required all warps to write their partial results to shared memory, synchronize, and sum, adding overhead. FlashAttention-2 reversed this, splitting Q across warps instead. Since each warp now computes a separate portion of the output, no cross-warp synchronization or reduction is needed during the inner loop.
| Metric | FlashAttention (v1) | FlashAttention-2 |
|---|---|---|
| GPU utilization (A100) | 25 to 40% of theoretical max | 50 to 73% of theoretical max |
| Attention TFLOPs/s (A100, FP16) | about 120 TFLOPs/s | about 230 TFLOPs/s |
| End-to-end training (A100) | about 125 TFLOPs/s | up to 225 TFLOPs/s (72% MFU) |
| Speed vs. standard PyTorch | 3x faster | up to 9x faster |
| Max head dimension supported | 128 | 256 |
FlashAttention-2 also introduced support for multi-query attention (MQA) and grouped-query attention (GQA), which are used in many modern large language models for more efficient inference, and added optional ALiBi positional biases.
FlashAttention-3, published in July 2024 by Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao, targeted NVIDIA's Hopper architecture (H100 GPUs). FlashAttention-2 achieved only about 35% utilization on the H100 because it did not exploit several new hardware features available on Hopper, including the Tensor Memory Accelerator (TMA), Warpgroup MMA (WGMMA) instructions, and FP8 Tensor Cores.
Hopper GPUs introduced the Tensor Memory Accelerator (TMA), a dedicated hardware unit that can load data from HBM to shared memory independently of the Tensor Cores. FlashAttention-3 exploits this by splitting warps into producer warps (which issue TMA loads) and consumer warps (which run matrix multiplications on the Tensor Cores). This overlaps data movement with computation, hiding memory latency. The technique is sometimes called persistent kernels with warp specialization.
FlashAttention-3 uses two warp groups that alternate ("ping-pong") between computing matrix multiplications: while warp group 1 performs its GEMMs on the current block, warp group 2 performs its softmax and data loading for the next block, and the roles then swap. This interleaving keeps both the Tensor Cores and the special function units (used for the exponential in softmax) busy simultaneously. With ping-pong scheduling alone, FA-3 achieves roughly 620 TFLOPs/s in FP16 forward at head dimension 128.
A further optimization runs softmax operations concurrently with GEMM computations within a single warp group, raising FP16 forward throughput to 640 to 660 TFLOPs/s at the cost of higher register pressure.
Hopper's Tensor Cores support FP8 arithmetic at double the throughput of FP16. However, naive FP8 quantization introduces significant numerical error, especially when the input contains outlier values. FlashAttention-3 addresses this with a technique called "incoherent processing": before quantization, the query and key matrices are multiplied by a random orthogonal matrix (implemented efficiently via a Hadamard transform with random signs). This spreads outlier values across dimensions, reducing quantization error. The operation has O(d log d) complexity per vector and is negligible compared to the O(N^2 d) attention cost.
FP8 FlashAttention-3 achieves 2.6 times lower numerical error than baseline FP8 attention using per-tensor quantization.
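The effect of incoherent processing can be illustrated without any FP8 arithmetic. The sketch below is an assumption-laden toy, not the FlashAttention-3 code: it applies a normalized fast Walsh-Hadamard transform with random signs to a query that has one large outlier, then checks that the dot product with a key is unchanged while the largest magnitude shrinks. The function names are hypothetical.

```python
import math
import random


def fwht(x):
    """Fast Walsh-Hadamard transform in O(d log d), normalized to be orthogonal."""
    x = list(x)
    n = len(x)                      # must be a power of two
    h = 1
    while h < n:
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                x[j], x[j + h] = x[j] + x[j + h], x[j] - x[j + h]
        h *= 2
    return [val / math.sqrt(n) for val in x]


def incoherent(x, signs):
    """Apply M = H * diag(signs): flip signs, then Hadamard-transform."""
    return fwht([s * val for s, val in zip(signs, x)])


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


random.seed(0)
d = 8
signs = [random.choice((-1.0, 1.0)) for _ in range(d)]
q = [0.1, -0.2, 50.0, 0.3, 0.0, -0.1, 0.2, 0.1]   # one large outlier
k = [random.gauss(0, 1) for _ in range(d)]

qm, km = incoherent(q, signs), incoherent(k, signs)

# M is orthogonal, so (Mq) . (Mk) = q . k and the attention scores are unchanged.
assert math.isclose(dot(q, k), dot(qm, km), rel_tol=1e-9, abs_tol=1e-9)
# The outlier is spread across all dimensions, shrinking the largest magnitude.
assert max(abs(x) for x in qm) < max(abs(x) for x in q)
```

A smaller dynamic range means a per-tensor or per-block scale wastes fewer of the few bits FP8 offers.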
| Metric | FlashAttention-2 (H100) | FlashAttention-3 FP16 (H100) | FlashAttention-3 FP8 (H100) |
|---|---|---|---|
| GPU utilization | about 35% | about 75% | higher (FP8 doubles peak) |
| TFLOPs/s | about 350 | about 740 | close to 1,200 (1.2 PFLOPs/s) |
| Speedup vs. FA-2 | Baseline | 1.5 to 2.0x | higher with FP8 |
| Numerical error vs. baseline FP8 | not applicable | not applicable | 2.6x lower |
FlashAttention-3 was presented at NeurIPS 2024.
FlashAttention-4 was presented by Tri Dao at Hot Chips in August 2025 and targets NVIDIA's Blackwell architecture (B200 and B300 GPUs). It is written in CuTe-DSL (CUTLASS' Python kernel domain-specific language) rather than raw CUDA C++, which reduces compilation time by roughly 20 to 30 times compared to traditional template metaprogramming.
The motivating problem for FlashAttention-4 is what Dao calls "asymmetric hardware scaling": from Hopper to Blackwell, Tensor Core throughput roughly doubled while shared memory bandwidth and special function unit (SFU) throughput stayed roughly flat. As a result, the softmax exponential and the rescaling steps account for a larger share of kernel time on the newer hardware. The kernel must aggressively overlap these non-Tensor-Core operations with matmuls, or the Tensor Cores will sit idle.
FlashAttention-4 introduces a new online softmax variant that skips approximately 90% of output rescaling operations. In standard online softmax, every time a new block produces a larger row maximum, all previously accumulated outputs must be rescaled by exp(m_old - m_new). FlashAttention-4 only performs this rescaling when the new maximum is sufficiently larger than the current one to affect numerical precision in BF16, and otherwise carries the previous statistics forward. The result is bit-equivalent or near-equivalent to ordinary online softmax in BF16 but with about 10x fewer correction operations.
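The idea can be sketched for a single row. Softmax gives the same answer for any reference value subtracted inside the exponentials, so the running maximum only has to be refreshed often enough to keep the exponentials in a safe range. The code below illustrates that principle and is not the FlashAttention-4 kernel; the threshold `tau` and the function name are hypothetical.

```python
import math


def streaming_softmax_sum(scores, values, block=2, tau=0.0):
    """Softmax-weighted sum of `values`, processed block by block.

    tau = 0 reproduces ordinary online softmax (rescale on every new maximum).
    tau > 0 keeps the stale reference m until a block maximum exceeds it by
    more than tau. tau is a hypothetical knob for this sketch.
    """
    m, l, acc, rescales = -math.inf, 0.0, 0.0, 0
    for start in range(0, len(scores), block):
        s = scores[start:start + block]
        v = values[start:start + block]
        block_max = max(s)
        if block_max > m + tau:
            if m > -math.inf:
                rescales += 1                  # the first block only initializes m
            scale = math.exp(m - block_max)
            l, acc, m = l * scale, acc * scale, block_max
        # exp(x - m) may exceed 1 here, but it is bounded by exp(tau).
        p = [math.exp(x - m) for x in s]
        l += sum(p)
        acc += sum(pj * vj for pj, vj in zip(p, v))
    return acc / l, rescales


scores = [0.0, 0.5, 0.7, 1.0, 1.2, 1.5, 9.0, 9.1]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

exact, n_exact = streaming_softmax_sum(scores, values, tau=0.0)
lazy, n_lazy = streaming_softmax_sum(scores, values, tau=4.0)
assert math.isclose(exact, lazy, rel_tol=1e-9)
print(n_exact, n_lazy)   # number of rescaling steps with and without the threshold
```

In low precision the threshold must be small enough that exp(tau) stays comfortably representable, which is why the choice is tied to the number format.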
To avoid bottlenecking on the SFU, FlashAttention-4 splits the exponential between the hardware MUFU.EX2 instruction and a software polynomial approximation evaluated on the FMA pipeline. For small head dimensions, the software path uses a cubic polynomial of the form 0.07711909 r^3 + 0.22756439 r^2 + 0.69514614 r + 1.0 selected with the Sollya tool. This shifts work off the SFU and onto the much more abundant FMA units.
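A software exponential of this kind typically splits 2^x into an exact power of two and a polynomial on the fractional part. The sketch below uses the cubic coefficients quoted above with a standard floor-based range reduction. The reduction scheme and the function name are assumptions made for illustration, not a description of the kernel's exact instruction sequence.

```python
import math

# Coefficients quoted above; the polynomial approximates 2**r for r in [0, 1).
C3, C2, C1, C0 = 0.07711909, 0.22756439, 0.69514614, 1.0


def exp2_poly(x: float) -> float:
    """2**x via range reduction: 2**x = 2**floor(x) * 2**frac(x)."""
    n = math.floor(x)
    r = x - n                                  # fractional part in [0, 1)
    p = ((C3 * r + C2) * r + C1) * r + C0      # Horner form: three multiply-adds
    return math.ldexp(p, n)                    # scale by 2**n exactly


# Softmax only exponentiates non-positive arguments (score minus row maximum).
worst = max(abs(exp2_poly(x) - 2.0 ** x) / 2.0 ** x
            for x in (i / 1000 - 20 for i in range(20001)))
print(f"max relative error on [-20, 0]: {worst:.2e}")
```

The relative error comes out on the order of 1e-4, which is below the resolution of a BF16 mantissa; that gives a plausible reason a cubic is enough at this precision.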
The Blackwell kernel orchestrates five specialized warp roles in a producer-consumer pipeline:
On B200 GPUs, FlashAttention-4 achieves approximately 1,605 to 1,613 TFLOPs/s in BF16 (about 71% of theoretical peak). It is roughly 1.3x faster than NVIDIA's cuDNN 9.13 attention kernel and up to 2.7x faster than Triton implementations on Blackwell.
FlashDecoding is an inference-time variant introduced by Tri Dao, Daniel Haziza, Francisco Massa, and Grigory Sizov in October 2023. The original Flash Attention parallelizes over batch and heads, which is fine during training (large batches, many heads) but inefficient during single-query autoregressive decoding. With a query length of 1 and a batch size of 1, FlashAttention can occupy less than 1% of the GPU.
FlashDecoding addresses this by adding a third axis of parallelism over the K/V sequence dimension. It splits the keys and values into chunks, processes each chunk in parallel on different SMs, and then combines partial results using log-sum-exp normalization. Reported speedups during long-context decoding reach up to 8x for very long sequences. FlashDecoding shipped in xFormers starting at version 0.0.22, where the dispatcher selects between FlashAttention and FlashDecoding based on workload shape.
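The combination step can be sketched for a single query. Each chunk returns its locally normalized output together with its log-sum-exp, and the final reduction weights each chunk by its share of the total softmax mass. The function names are hypothetical and the values are scalars for brevity.

```python
import math


def partial_attention(scores, values):
    """Attention over one K/V chunk: return the chunk-local output and its logsumexp."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    l = sum(w)
    out = sum(wi * vi for wi, vi in zip(w, values)) / l
    return out, m + math.log(l)


def combine(partials):
    """Merge chunk results, weighting each by exp(lse_chunk - lse_max)."""
    lse_max = max(lse for _, lse in partials)
    weights = [math.exp(lse - lse_max) for _, lse in partials]
    total = sum(weights)
    return sum(w * out for w, (out, _) in zip(weights, partials)) / total


scores = [0.3, 2.0, -1.0, 0.7, 1.5, -0.2, 0.9, 3.1]   # one query against 8 keys
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]     # scalar values for brevity

# Each chunk could run on a different SM; here they simply run in sequence.
chunks = [partial_attention(scores[i:i + 3], values[i:i + 3]) for i in range(0, 8, 3)]
full, _ = partial_attention(scores, values)
assert math.isclose(combine(chunks), full, rel_tol=1e-9)
```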
FlashDecoding++ (Hong et al., MLSys 2024) generalizes the idea to handle dynamic softmax statistics across K/V chunks more robustly and adds optimizations for flat-GEMM-bound decoding shapes. It provides further speedups on top of FlashDecoding for inference workloads.
Flash Attention has become the de facto standard for attention computation in modern deep learning frameworks and models.
Starting with PyTorch 2.0, the torch.nn.functional.scaled_dot_product_attention (SDPA) function provides native access to Flash Attention kernels. The function automatically selects the most performant backend (Flash Attention, memory-efficient attention by Rabe and Staats, or a math fallback) based on the input dimensions, dtype, and hardware. PyTorch 2.2 integrated FlashAttention-2 into SDPA, providing roughly 2x additional speedup. Flash Attention is fully composable with torch.compile() for further optimization.
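import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# q, k, v: CUDA tensors of shape (batch, num_heads, seq_len, head_dim) in fp16 or bf16
# Restrict SDPA to the Flash Attention backend; unsupported inputs raise an error.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)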
Hugging Face Transformers supports multiple attention backends: sdpa (which may dispatch to Flash Attention), flash_attention_2 (direct FlashAttention-2), flex_attention, and eager (standard matrix multiplication). Users can enable FlashAttention-2 directly with one keyword argument:
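from transformers import AutoModelForCausalLM
# model_name is a placeholder for a Hub model ID or local path.
# FlashAttention-2 needs the flash-attn package and fp16 or bf16 weights.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
)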
| Library / framework | Status | Notes |
|---|---|---|
| flash-attn (PyPI, Dao-AILab) | reference | Official package by Tri Dao; CUDA kernels for FA-2/FA-3/FA-4 |
| PyTorch SDPA | shipped | Auto-selects Flash Attention when applicable |
| Hugging Face Transformers | shipped | attn_implementation="flash_attention_2" |
| Triton tutorial kernel | reference | Educational; basis for many forks |
| FlexAttention (PyTorch) | shipped (2.5+) | User-defined score_mod; compiles to fused Triton |
| xFormers (Meta) | shipped | Hosts FlashAttention and FlashDecoding |
| JAX/Pallas | shipped | TPU-friendly Pallas kernels; Kvax for GPU |
| NVIDIA cuDNN | shipped | Closed-source attention kernel; Flash-style |
| NVIDIA Transformer Engine | shipped | Includes FA-2/FA-3 kernels with FP8 paths |
| AMD ROCm CK Flash Attention | shipped | Composable Kernel port for MI200/MI250/MI300/MI355 |
| vLLM | shipped | Combines Flash Attention with PagedAttention |
| TensorRT-LLM | shipped | NVIDIA inference stack with FA-derived kernels |
| GPU family | Architecture | FA-1 | FA-2 | FA-3 | FA-4 | Notes |
|---|---|---|---|---|---|---|
| V100 | Volta | yes | partial | no | no | sm_70; FA-2 limited |
| T4, RTX 20-series | Turing | yes | partial | no | no | sm_75 |
| A100, A800 | Ampere | yes | yes | no | no | sm_80 |
| RTX 30-series | Ampere | yes | yes | no | no | sm_86; consumer-class |
| RTX 40-series | Ada Lovelace | yes | yes | no | no | sm_89 |
| H100, H200, H800 | Hopper | yes | yes | yes | partial | TMA, WGMMA, FP8 |
| B100, B200, B300 | Blackwell | yes | yes | yes | yes | TMEM, 5th-gen Tensor Cores |
| MI200, MI250 | CDNA 2 | yes | yes | n/a | n/a | via ROCm Composable Kernel |
| MI300, MI355 | CDNA 3 / CDNA 4 | yes | yes | n/a | n/a | ROCm Composable Kernel |
Flash Attention is used in essentially every major open-weight large language model trained from late 2022 onward, and is widely understood to be used in the training of frontier closed models as well, although exact training recipes are not always public.
| Model | Flash Attention usage |
|---|---|
| LLaMA, LLaMA 2, LLaMA 3, LLaMA 3.1 | trained with FA; FA-2 supported for inference |
| Mistral 7B | trained with FA; uses sliding window attention |
| Mixtral 8x7B / 8x22B | uses FA with grouped-query attention |
| Falcon | trained with FA |
| Gemma, Gemma 2, Gemma 3 | trained with FA-2 |
| Phi, Phi-2, Phi-3 | trained with FA-2 |
| Qwen, Qwen 2, Qwen 2.5 | trained with FA |
| DBRX (Databricks) | trained with FA |
| DeepSeek, DeepSeek-V2, V3 | trained with FA |
| StarCoder, StarCoder 2 | trained with FA |
| StableLM, OLMo, Granite (IBM) | trained with FA |
| BERT (replication studies) | 15% end-to-end speedup on MLPerf |
| GPT-2 replication | demonstrated 3x end-to-end speedup |
| Stable Diffusion | integrated via xFormers and SDPA |
Flash Attention extends naturally to block-sparse attention patterns. In block-sparse mode, the algorithm skips entire blocks of the K/V matrices that are masked out, avoiding both the computation and the memory access for those blocks. This yields an approximate attention method that is 2 to 4 times faster than even dense Flash Attention and scales to sequence lengths of 64K and beyond.
Block-sparse Flash Attention supports common sparsity patterns including causal masks, local (sliding window) masks, and arbitrary block-sparse masks. This makes it compatible with architectures like Mistral that use sliding window attention. The Mistral team reported a 2x speed improvement at sequence length 16K with a 4K window after collaborating on FlashAttention and xFormers integration.
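The saving from skipping blocks can be estimated by counting which K/V blocks each Q block has to visit. The sketch below assumes a causal sliding window in which query i attends to keys j with i - window < j <= i. That convention, the block size of 128, and the helper name are assumptions for illustration.

```python
def needed_kv_blocks(q_block: int, block: int, window: int, seq_len: int) -> range:
    """K/V block indices that a Q block must visit under a causal sliding window.

    Query i attends to keys j with i - window < j <= i (an assumed convention).
    """
    q_first = q_block * block
    q_last = min(q_first + block, seq_len) - 1
    lo_key = max(0, q_first - window + 1)   # earliest key seen by the first query
    hi_key = q_last                         # causal: nothing after the last query
    return range(lo_key // block, hi_key // block + 1)


seq_len, block, window = 16384, 128, 4096
num_blocks = seq_len // block

visited = sum(len(needed_kv_blocks(b, block, window, seq_len)) for b in range(num_blocks))
dense_causal = num_blocks * (num_blocks + 1) // 2
print(visited, dense_causal, dense_causal / visited)
```

With these settings the windowed kernel visits roughly 2.2 times fewer blocks than dense causal attention. Blocks at the edges of the window are only partly masked, so they still need an element-level mask inside the block.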
The official flash-attn package on PyPI provides Python bindings to the CUDA kernels. As of mid-2026 the recommended setup is:
Supported features in the PyPI package include forward and backward passes with variable sequence lengths, MQA and GQA, causal masking, sliding window local attention, paged KV cache, rotary embeddings, ALiBi, attention dropout, deterministic backward pass, softcapping, and an FP8 forward pass on Hopper.
A minimal example:
from flash_attn import flash_attn_func
# q, k, v: (batch, seq_len, num_heads, head_dim), fp16 or bf16, on a CUDA device
out = flash_attn_func(q, k, v, causal=True)
Flash Attention's CUDA kernels require NVIDIA GPUs with compute capability 7.0 or higher (Volta, Turing, Ampere, Ada, Hopper, Blackwell). The FP8 features in FlashAttention-3 require Hopper (H100) or newer. AMD GPU support exists via the ROCm Composable Kernel port and covers MI200, MI250, MI300, and MI355 series, although feature parity with the NVIDIA backend can lag.
The Flash Attention CUDA kernels historically required long compilation times (sometimes hours) when building from source, particularly for FlashAttention-2 and 3, which use heavy C++ template metaprogramming. FlashAttention-4 addresses this by using CuTe-DSL, reducing installation time by roughly 20 to 30x.
While Flash Attention supports causal masking, sliding window attention, and block-sparse patterns, it does not natively support arbitrary per-element attention masks with full efficiency. Complex masking patterns (for example, document-level masking within a packed batch) require workarounds. PyTorch's FlexAttention API was introduced in 2024 to address this gap, allowing users to define custom score_mod and mask_mod functions that compile into fused Triton kernels with Flash-Attention-like performance (about 90% of FA-2 forward and 85% of FA-2 backward on causal attention).
Flash Attention does not change the O(N^2 d) computational complexity of attention. It only reduces memory usage and memory access overhead. For applications that need truly sub-quadratic attention, methods such as sparse attention, linear attention, Mamba state-space models, or other approximate techniques are necessary.
When sliding window attention is configured for causal self-attention, the window size parameter may be incorrectly applied to cross-attention layers as well, enforcing causal masking where bidirectional attention is intended. Users should verify attention mask behavior when combining different attention types in encoder-decoder architectures.
Although Flash Attention computes exact attention in theory, different floating-point operation orderings can produce small numerical differences compared to standard attention (on the order of 1e-6 for FP32 or 1e-3 for FP16). These differences are within the range of normal floating-point non-associativity and do not affect model quality. FlashAttention-4's selective rescaling can introduce slightly larger differences but they remain bit-equivalent or near-equivalent at BF16 precision.
| Version | Year | Authors | Target GPU | Peak utilization | Key technique |
|---|---|---|---|---|---|
| FlashAttention | 2022 | Dao, Fu, Ermon, Rudra, Ré | A100 (Ampere) | 25 to 40% | Tiling, online softmax, recomputation |
| FlashAttention-2 | 2023 | Dao | A100 (Ampere) | 50 to 73% | Better parallelism, warp partitioning |
| FlashAttention-3 | 2024 | Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao | H100 (Hopper) | about 75% (FP16) | Asynchrony, ping-pong, FP8, incoherent processing |
| FlashAttention-4 | 2025 | Dao et al. | B200 (Blackwell) | about 71% (BF16) | Selective rescaling, software-emulated exp, CuTe-DSL |
Tri Dao is the principal author of all four versions and the central figure in the project. He completed his PhD at Stanford under Christopher Ré, focusing on hardware-aware machine learning algorithms. He is now an Assistant Professor of Computer Science at Princeton University, where he directs the Dao AI Lab, and is co-founder and Chief Scientist of Together AI. His other widely cited contributions include the Mamba state-space model (with Albert Gu), Monarch matrices for structured linear layers, and the Hyena operator. Recent honors include the Schmidt Sciences AI2050 Fellowship, a Google ML and Systems Junior Faculty Award, and Outstanding Paper or Best Paper recognition at MLSys, COLM, and ICML.
The original 2022 paper's coauthors include Daniel Y. Fu (Stanford PhD, then ML researcher), Stefano Ermon (Stanford), Atri Rudra (University at Buffalo), and Christopher Ré (Stanford, MacArthur Fellow). FlashAttention-3's lead author Jay Shah is at Colfax Research, with collaborators from NVIDIA (Vijay Thakkar, Pradeep Ramani) and Meta (Ying Zhang).
Flash Attention builds on and has inspired several related lines of work:
Flash Attention is widely regarded as one of the most important systems contributions to modern deep learning. By making long-context attention practical at scale, it enabled the rapid expansion of context windows in large language models from a few thousand tokens in 2022 to hundreds of thousands of tokens by 2024 and beyond. The algorithm's IO-aware design principle has influenced subsequent work on efficient kernels for normalization layers, embedding lookups, and routing in mixture-of-experts models.
The project's open-source nature, fast iteration cycle, and willingness to track new GPU generations made it a reference example of how academic research can drive industrial deep learning infrastructure. The original repository was awarded the inaugural Stanford Open Source Software Prize in 2024.