See also: attention, transformer, self-attention, GPU computing, softmax
Flash Attention is a family of IO-aware, exact attention algorithms designed to accelerate the attention mechanism in transformer models and to reduce its memory footprint. The original Flash Attention algorithm was introduced in 2022 by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré, researchers at Stanford University and the University at Buffalo. Rather than approximating attention (as many prior methods did), Flash Attention computes mathematically exact attention while reordering the computation to minimize memory reads and writes between different levels of GPU memory.
The core insight behind Flash Attention is that standard attention implementations are bottlenecked not by arithmetic computation but by memory access. Modern GPUs can perform far more floating-point operations per second than they can read and write data from their main memory. By restructuring the attention computation to account for the GPU memory hierarchy, Flash Attention achieves wall-clock speedups of 2 to 7.6 times over standard attention while reducing memory usage from quadratic to linear in the sequence length.
Flash Attention has been widely adopted across the deep learning ecosystem. It is integrated into PyTorch as part of the scaled_dot_product_attention function, supported natively in Hugging Face Transformers, and used in the training and inference of models such as LLaMA, Mistral, GPT variants, and many others. The algorithm has gone through four major versions (Flash Attention, FlashAttention-2, FlashAttention-3, and FlashAttention-4), each targeting newer GPU hardware and achieving higher utilization.
Imagine you have a huge jigsaw puzzle spread across a large table in the garage, far away from where you sit, and a small desk in front of you. The standard way to solve the puzzle is to keep running back and forth to the garage table to grab pieces, look at them, and run back again. You do this thousands of times because you can only carry a few pieces at once.
Flash Attention is like being smarter about it. Instead of running back and forth for every single piece, you grab a small group of pieces, bring them to your desk, figure out how they fit together, write down your progress, and then go get the next group. You still solve the exact same puzzle with the exact same answer, but you make far fewer trips to the garage. The small desk is the fast memory on a GPU (called SRAM), and the garage table is the large but slow main memory (called HBM). Flash Attention makes fewer trips between them, so everything goes faster.
The self-attention mechanism in transformers computes pairwise interactions between all positions in a sequence. Given a sequence of length N, the standard attention algorithm materializes an N x N attention score matrix, applies softmax normalization, and then uses the result to compute a weighted sum of value vectors. This leads to O(N^2) time and memory complexity, making it the primary bottleneck when scaling transformers to long sequences.
The standard multi-head attention computation proceeds as follows (per head, with Q, K, and V each of size N x d):

1. Compute the score matrix S = Q K^T / sqrt(d), of size N x N.
2. Apply a row-wise softmax to obtain the attention probability matrix P = softmax(S), also of size N x N.
3. Compute the output O = P V, a weighted sum of the value vectors, of size N x d.
The intermediate matrices S and P are each of size N x N. For a sequence length of 4,096 with 16-bit floating-point precision, a single attention head's S matrix alone consumes 32 MB. With multiple heads and batch elements, this adds up quickly, often exceeding available GPU memory for longer sequences.
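The effect is easy to see in a few lines of PyTorch. The sketch below is illustrative rather than any library's implementation; it simply spells out the three steps above and shows where the large intermediate matrices appear.

```python
import math
import torch

def standard_attention(Q, K, V):
    # Q, K, V: (N, d) for a single attention head
    d = Q.shape[-1]
    S = Q @ K.T / math.sqrt(d)      # (N, N) score matrix, materialized in memory
    P = torch.softmax(S, dim=-1)    # (N, N) probability matrix, also materialized
    return P @ V                    # (N, d) output

N, d = 4096, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
O = standard_attention(Q, K, V)
# At 16-bit precision, S alone would occupy N * N * 2 bytes = 32 MB per head.
```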
Modern GPUs have a hierarchical memory system with dramatically different capacities and bandwidths at each level:
| Memory level | Capacity (A100) | Bandwidth (A100) | Relative speed |
|---|---|---|---|
| HBM (high bandwidth memory) | 40 to 80 GB | 1.5 to 2.0 TB/s | 1x (baseline) |
| SRAM (on-chip, per SM) | 192 KB per SM (20 MB total across 108 SMs) | ~19 TB/s | ~10x faster |
| Registers | ~256 KB per SM | Even higher | Fastest |
Standard attention writes the full N x N matrices S and P to HBM, then reads them back, resulting in a large number of slow memory accesses. Flash Attention's key contribution is recognizing that these intermediate matrices do not need to be materialized in HBM at all.
The standard attention algorithm performs O(N^2 * d) floating-point operations. Its memory access (IO) complexity is O(N * d + N^2), where the N^2 term comes from reading and writing the attention matrices S and P.
Flash Attention performs the same O(N^2 * d) floating-point operations (and sometimes slightly more due to recomputation), but its IO complexity is O(N^2 * d^2 / M), where M is the SRAM size. For typical values where d is between 64 and 128 and M is around 100 KB, this works out to significantly fewer HBM accesses. Dao et al. proved that this IO complexity is optimal: no algorithm computing exact attention can asymptotically reduce the number of HBM accesses below this bound for all values of M.
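A back-of-the-envelope calculation (with illustrative values for N, d, and M, and ignoring constant factors) shows the scale of the difference:

```python
N, d = 4096, 64
M = 50_000                        # assumed SRAM size in elements (~100 KB of fp16)
standard_io = N * d + N * N       # dominated by reading/writing the N x N matrices
flash_io = N * N * d * d / M      # Flash Attention's O(N^2 * d^2 / M) bound
print(f"{standard_io:,.0f}")      # about 17.0 million element accesses
print(f"{flash_io:,.0f}")         # about 1.4 million element accesses (~12x fewer)
```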
Since modern GPUs are often memory-bandwidth-bound rather than compute-bound during attention (especially at moderate sequence lengths), reducing memory accesses yields direct wall-clock speedups even though the total FLOP count may be slightly higher.
Flash Attention avoids materializing the full N x N attention matrix by splitting Q, K, and V into blocks that fit into SRAM and computing attention one block at a time. The algorithm fuses all of the attention operations (matrix multiply, softmax, another matrix multiply) into a single GPU kernel, eliminating intermediate reads and writes to HBM.
The tiling proceeds as follows:

1. Split Q into blocks of B_r rows each, and split K and V into blocks of B_c rows each.
2. For each block of Q, load it from HBM into SRAM once.
3. Loop over the blocks of K and V, loading each pair into SRAM and computing the corresponding B_r x B_c block of attention scores entirely on-chip.
4. Update the running softmax statistics and the output accumulator for the Q block (see the online softmax below), without ever writing the score block to HBM.
5. After the last K/V block, write the finished output block back to HBM.
Block sizes B_r and B_c are chosen to ensure that the Q block, K block, V block, and the intermediate score block all fit within the available SRAM.
A major challenge in tiling attention is the softmax normalization, which in its standard form requires knowing the maximum value and the sum of exponentials across the entire row of the score matrix. Since Flash Attention processes the score matrix in blocks, it cannot see the full row at once.
Flash Attention solves this using an online softmax algorithm based on the work of Milakov and Gimelshein (2018). The technique maintains two running statistics for each row:

- m: the maximum score seen so far in that row (used for numerical stability), and
- l: the running sum of exponentials exp(s - m) over the scores seen so far.
When a new block of scores arrives, the algorithm:

1. Computes the maximum of the new block and updates the running maximum m accordingly.
2. Rescales the previously accumulated sum l and output contribution by exp(m_old - m_new), so that everything is expressed relative to the new maximum.
3. Adds the new block's exponentials exp(s - m_new) to l and accumulates the new block's weighted contribution to the output.
This rescaling ensures numerical stability and produces results identical to computing softmax over the entire row at once. The "telescoping" property of the exponential function guarantees that the correction factors compose correctly regardless of processing order.
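A compact reference implementation in PyTorch illustrates how tiling and the online softmax fit together. Variable names, block sizes, and the loop structure are for exposition only; the real kernels fuse this entire computation into a single CUDA kernel whose working set lives in SRAM.

```python
import math
import torch

def flash_attention_reference(Q, K, V, block_q=128, block_k=128):
    # Q, K, V: (N, d) for a single head. The full N x N score matrix is never formed.
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros_like(Q)
    m = torch.empty(N)                                  # per-row max, saved for backward
    l = torch.empty(N)                                  # per-row sum of exponentials
    for i in range(0, N, block_q):
        Qi = Q[i:i + block_q]
        Oi = torch.zeros_like(Qi)                       # unnormalized output accumulator
        mi = torch.full((Qi.shape[0],), float("-inf"))  # running max for this Q block
        li = torch.zeros(Qi.shape[0])                   # running sum for this Q block
        for j in range(0, N, block_k):
            Kj, Vj = K[j:j + block_k], V[j:j + block_k]
            Sij = (Qi @ Kj.T) * scale                   # (B_r, B_c) score block
            m_new = torch.maximum(mi, Sij.max(dim=-1).values)
            Pij = torch.exp(Sij - m_new[:, None])       # exponentials w.r.t. the new max
            correction = torch.exp(mi - m_new)          # rescale previous statistics
            li = li * correction + Pij.sum(dim=-1)
            Oi = correction[:, None] * Oi + Pij @ Vj
            mi = m_new
        O[i:i + block_q] = Oi / li[:, None]             # normalize once per Q block
        m[i:i + block_q], l[i:i + block_q] = mi, li
    return O
```

Comparing this against the standard_attention sketch above on the same inputs gives outputs that agree up to floating-point rounding, which is exactly the exactness guarantee discussed in this section.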
In the forward pass, Flash Attention does not store the N x N attention matrix P. During backpropagation, the gradients with respect to Q, K, and V require access to P. Rather than storing P (which would negate the memory savings), Flash Attention recomputes it from the saved Q, K, V blocks and the softmax normalization statistics (m and l) during the backward pass.
This recomputation trades additional FLOPs for reduced memory. The key observation is that even with the extra arithmetic, the backward pass is faster in wall-clock time because it avoids the costly HBM reads of the large attention matrix. The only values saved from the forward pass are the final output O (size N x d) and the softmax statistics m and l (each of size N), giving a total memory footprint of O(N) instead of O(N^2).
| Property | Standard attention | Flash Attention |
|---|---|---|
| Computation | Exact | Exact (identical output) |
| Time complexity (FLOPs) | O(N^2 * d) | O(N^2 * d) (same, slightly more due to recomputation) |
| Memory complexity | O(N^2) | O(N) |
| HBM IO complexity | O(N * d + N^2) | O(N^2 * d^2 / M) |
| Intermediate matrices in HBM | S and P (N x N each) | None |
| Memory savings (seq len 2K) | Baseline | ~10x reduction |
| Memory savings (seq len 4K) | Baseline | ~20x reduction |
| Wall-clock speedup (attention only) | Baseline | 2 to 7.6x faster |
| Backward pass storage | O(N^2) for P | O(N) for (O, m, l) |
FlashAttention-2, published by Tri Dao in July 2023, improved on the original algorithm with three optimizations aimed at closing the gap between Flash Attention's throughput and the theoretical peak of the GPU hardware.
The original Flash Attention spent a meaningful fraction of its time on non-matrix-multiplication operations such as softmax rescaling, bounds checking, and causal masking. These operations run on general-purpose CUDA cores rather than specialized Tensor Cores, which are 4 to 16 times faster for matrix operations. FlashAttention-2 restructured the algorithm to defer the final softmax rescaling to the end of the loop, reducing the number of intermediate rescaling steps.
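Schematically (an illustrative contrast, not the actual kernels), the v1-style inner-loop update renormalizes the output block on every iteration, while the v2-style update keeps the output unnormalized and divides by the softmax denominator once after the loop; the reference sketch earlier in this article already uses the deferred form.

```python
import torch

def v1_style_update(O, m_old, l_old, l_new, m_new, P_block, V_block):
    # The output block is kept normalized at every step: undo the old normalization,
    # rescale for the new maximum, add the new contribution, then renormalize.
    O = l_old[:, None] * torch.exp(m_old - m_new)[:, None] * O + P_block @ V_block
    return O / l_new[:, None]

def v2_style_update(O_unnormalized, m_old, m_new, P_block, V_block):
    # The output block stays unnormalized inside the loop; the division by the final
    # softmax denominator happens once, after the last key/value block.
    return torch.exp(m_old - m_new)[:, None] * O_unnormalized + P_block @ V_block
```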
The original Flash Attention parallelized work across the batch dimension and the number of attention heads. When the batch size or head count was small (common with long sequences), many GPU streaming multiprocessors (SMs) sat idle. FlashAttention-2 added parallelism over the sequence length dimension, distributing different blocks of Q to different thread blocks. This improved GPU occupancy, especially for long-context workloads with small batch sizes.
Within a thread block, the original Flash Attention used a "split-K" scheme: it distributed K and V across warps while keeping Q shared. This required all warps to write their partial results to shared memory, synchronize, and sum, adding overhead. FlashAttention-2 reversed this, splitting Q across warps instead. Since each warp now computes a separate portion of the output, no cross-warp synchronization or reduction is needed during the inner loop.
| Metric | FlashAttention (v1) | FlashAttention-2 |
|---|---|---|
| GPU utilization (A100) | 25 to 40% of theoretical max | 50 to 73% of theoretical max |
| Attention TFLOPs/s (A100) | ~120 TFLOPs/s | ~230 TFLOPs/s |
| End-to-end training (A100) | ~125 TFLOPs/s | Up to 225 TFLOPs/s (72% MFU) |
| Speed vs. standard PyTorch | 3x faster | Up to 9x faster |
| Max head dimension supported | 128 | 256 |
FlashAttention-2 also introduced support for multi-query attention (MQA) and grouped-query attention (GQA), which are used in many modern large language models for more efficient inference.
FlashAttention-3, published in July 2024 by Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, and Tri Dao, targeted NVIDIA's Hopper architecture (H100 GPUs). FlashAttention-2 achieved only about 35% utilization on the H100 because it did not exploit several new hardware features available on Hopper.
Hopper GPUs introduced the Tensor Memory Accelerator (TMA), a dedicated hardware unit that can load data from HBM to shared memory independently of the Tensor Cores. FlashAttention-3 exploits this by splitting warps into producer warps (which issue TMA loads) and consumer warps (which run matrix multiplications on the Tensor Cores). This overlaps data movement with computation, hiding memory latency.
FlashAttention-3 uses two warp groups that alternate ("ping-pong") between the two phases of the computation: while one warp group runs its matrix multiplications on the Tensor Cores, the other performs the softmax for its own block, and the roles swap on the next iteration. This interleaving keeps the Tensor Cores busy instead of leaving them idle during the comparatively slow exponential computations of the softmax.
Hopper's Tensor Cores support FP8 arithmetic at double the throughput of FP16. However, naive FP8 quantization introduces significant numerical error, especially when the input contains outlier values. FlashAttention-3 addresses this with a technique called "incoherent processing": before quantization, the query and key matrices are multiplied by a random orthogonal matrix (implemented efficiently via a Hadamard transform with random signs). This "spreads out" outlier values across dimensions, reducing quantization error. The operation has O(d log d) complexity per vector and is negligible compared to the O(N^2 * d) attention cost.
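A toy demonstration of the idea (the dimensions, the scipy-based Hadamard construction, and the injected outlier are all illustrative) shows that the rotation preserves the attention scores while flattening outliers:

```python
import torch
from scipy.linalg import hadamard

d = 64
H = torch.tensor(hadamard(d), dtype=torch.float32) / d ** 0.5   # orthogonal Hadamard matrix
signs = torch.randint(0, 2, (d,), dtype=torch.float32) * 2 - 1  # random +-1 per dimension
R = H * signs                                                    # random orthogonal rotation

Q = torch.randn(1024, d)
Q[:, 0] *= 20                                                    # inject an outlier channel
K = torch.randn(1024, d)
Q_rot, K_rot = Q @ R, K @ R

# Because R is orthogonal and applied to both Q and K, the scores are unchanged ...
print((Q @ K.T - Q_rot @ K_rot.T).abs().max())   # tiny (floating-point rounding only)
# ... but the outlier is spread across dimensions, so FP8 quantization loses less precision.
print(Q.abs().max(), Q_rot.abs().max())          # Q_rot's largest entry is far smaller
```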
FP8 FlashAttention-3 achieves 2.6 times lower numerical error than standard FP8 attention using per-tensor quantization.
| Metric | FlashAttention-2 (H100) | FlashAttention-3 FP16 (H100) | FlashAttention-3 FP8 (H100) |
|---|---|---|---|
| GPU utilization | ~35% | ~75% | Higher (FP8 doubles throughput) |
| TFLOPs/s | ~350 | ~740 | ~1,200 (1.2 PFLOPs/s) |
| Speedup vs. FA-2 | Baseline | 1.5 to 2.0x | Higher with FP8 |
FlashAttention-3 was presented at NeurIPS 2024.
FlashAttention-4 was presented by Tri Dao at Hot Chips in September 2025 and targets NVIDIA's Blackwell architecture (B200 GPUs). It is written in CuTe-DSL (CUDA Templates Domain-Specific Language) rather than raw CUDA C++.
FlashAttention-4 introduces a new online softmax variant that skips approximately 90% of output rescaling operations. In the standard online softmax, every time a new block produces a larger row maximum, all previously accumulated outputs must be rescaled. FlashAttention-4 only performs this rescaling when the new maximum is "sufficiently larger" than the current one, reducing overhead without affecting numerical correctness.
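A schematic sketch of the idea follows (the threshold and the structure are illustrative, not the FlashAttention-4 kernel): the running maximum is only advanced, and the accumulators only rescaled, when the new block maximum exceeds it by more than a safety margin; otherwise exp(s - m_old) remains safely representable and the result is unchanged.

```python
import torch

def lazy_online_softmax_update(m_old, l_old, O_old, S_block, V_block, margin=8.0):
    block_max = S_block.max(dim=-1).values
    needs_rescale = block_max > m_old + margin        # rescale only on a large jump
    m_new = torch.where(needs_rescale, block_max, m_old)
    correction = torch.exp(m_old - m_new)             # equals 1 wherever no rescale occurs
    P = torch.exp(S_block - m_new[:, None])           # bounded by exp(margin) when skipped
    l_new = l_old * correction + P.sum(dim=-1)
    O_new = correction[:, None] * O_old + P @ V_block
    return m_new, l_new, O_new
```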
Additionally, FlashAttention-4 uses a software approximation of the exponential function, reducing its reliance on the special function units that execute the MUFU.EX2 instruction and allowing the softmax computation to overlap more fully with Tensor Core operations.
On B200 GPUs, FlashAttention-4 achieves approximately 1,600 TFLOPs/s (71% utilization), up to 2.7 times faster than previous methods. It is also reported to be 22% faster than the attention kernel in NVIDIA's cuDNN library on Blackwell hardware.
Flash Attention has become the de facto standard for attention computation in modern deep learning frameworks and models.
Starting with PyTorch 2.0, the torch.nn.functional.scaled_dot_product_attention (SDPA) function provides native access to Flash Attention kernels. The function automatically selects the most performant backend (Flash Attention, memory-efficient attention, or a math fallback) based on the input dimensions and hardware. PyTorch 2.2 integrated FlashAttention-2 into SDPA, providing roughly 2x additional speedup. Flash Attention is fully composable with torch.compile() for further optimization.
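Typical usage looks like the following; the sdpa_kernel context manager lives under torch.nn.attention in recent PyTorch releases (2.3 and later), so the exact import path depends on the installed version.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16) for _ in range(3))

# Let PyTorch choose the fastest available backend (often Flash Attention on supported GPUs).
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Or restrict dispatch to the Flash Attention backend; the call errors if it cannot be used.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```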
Hugging Face Transformers supports multiple attention backends: sdpa (which may dispatch to Flash Attention), flash_attention_2 (direct FlashAttention-2), flex_attention, and eager (standard matrix multiplication). Users can enable FlashAttention-2 directly via:
from transformers import AutoModelForCausalLM
import torch

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # the FlashAttention-2 kernels require fp16 or bf16
    attn_implementation="flash_attention_2",
)
Many widely used models were trained or are served with Flash Attention:
| Model | Flash Attention usage |
|---|---|
| LLaMA 2 and LLaMA 3 | Used during training; supports FA-2 for inference |
| Mistral 7B | Trained with Flash Attention and sliding window attention |
| Mixtral | Uses Flash Attention with grouped-query attention |
| GPT-2 replication studies | Demonstrated 3x end-to-end speedup |
| Stable Diffusion | Integrated via xFormers and SDPA |
| BERT training | 15% end-to-end speedup on MLPerf benchmark |
Flash Attention implementations exist for JAX (Kvax), AMD GPUs (via ROCm), and NVIDIA's Transformer Engine. The official repository is maintained at Dao-AILab/flash-attention on GitHub.
Flash Attention extends naturally to block-sparse attention patterns. In block-sparse mode, the algorithm skips entire blocks of the K/V matrices that are masked out, avoiding both the computation and the memory access for those blocks. This yields an approximate attention method that is 2 to 4 times faster than even dense Flash Attention and scales to sequence lengths of 64K and beyond.
Block-sparse Flash Attention supports common sparsity patterns including causal masks, local (sliding window) masks, and arbitrary block-sparse masks. This makes it compatible with architectures like Mistral that use sliding window attention.
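Conceptually, the skipping logic amounts to a cheap test per block pair before any data is loaded. The helper below is an illustrative sketch (block indices, sizes, and the window parameter are not part of any specific API) of which key/value blocks can be skipped under a causal mask combined with a sliding window:

```python
def block_is_needed(q_start, k_start, block_q, block_k, window=None):
    # Returns False when the whole (B_r x B_c) score block is masked out, so the block's
    # loads, matrix multiplies, and softmax updates can all be skipped.
    last_q = q_start + block_q - 1
    first_k, last_k = k_start, k_start + block_k - 1
    if first_k > last_q:                               # causal: block lies entirely in the future
        return False
    if window is not None and last_k < q_start - window + 1:
        return False                                   # sliding window: block is entirely too old
    return True
```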
Flash Attention's CUDA kernels require relatively recent NVIDIA GPUs: the original release supported Turing and Ampere (compute capability 7.5 or higher), while FlashAttention-2 and later target Ampere, Ada, Hopper, and Blackwell. The FP8 features in FlashAttention-3 require Hopper (H100) or newer. AMD GPU support exists but may lag behind NVIDIA in feature parity.
The Flash Attention CUDA kernels historically required long compilation times (sometimes hours) when building from source, particularly for FlashAttention-2 and 3 which use heavy C++ template metaprogramming. FlashAttention-4 addresses this by using CuTe-DSL, reducing installation time significantly.
While Flash Attention supports causal masking, sliding window attention, and block-sparse patterns, it does not natively support arbitrary per-element attention masks with full efficiency. Complex masking patterns (for example, document-level masking within a packed batch) require workarounds. PyTorch's FlexAttention API was introduced to address this gap, allowing users to define custom attention masks that compile into fused kernels with Flash Attention-like performance.
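The sketch below shows the general shape of the FlexAttention API for document-level masking in a packed batch; the module path (torch.nn.attention.flex_attention) and exact signatures apply to recent PyTorch releases (2.5 and later), and the per-token document_id tensor is an assumption of this example rather than part of the API.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask

N = 4096
# Assumed per-token document IDs for a packed batch (filled in by the data pipeline).
document_id = torch.zeros(N, dtype=torch.long, device="cuda")

def document_causal(b, h, q_idx, kv_idx):
    # Allow attention only within the same document, and only to earlier positions.
    return (document_id[q_idx] == document_id[kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(document_causal, B=None, H=None, Q_LEN=N, KV_LEN=N)
q = k = v = torch.randn(1, 8, N, 64, device="cuda", dtype=torch.float16)
out = flex_attention(q, k, v, block_mask=block_mask)
```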
Flash Attention does not change the O(N^2 * d) computational complexity of attention. It only reduces memory usage and memory access overhead. For applications that need truly sub-quadratic attention, methods such as sparse attention, linear attention, or other approximate techniques are necessary.
When sliding window attention is configured for causal self-attention, the window size parameter may be incorrectly applied to cross-attention layers as well, enforcing causal masking where bidirectional attention is intended. Users should verify attention mask behavior when combining different attention types in encoder-decoder architectures.
Although Flash Attention computes exact attention in theory, different floating-point operation orderings can produce small numerical differences compared to standard attention (on the order of 1e-6 for FP32 or 1e-3 for FP16). These differences are within the range of normal floating-point non-associativity and do not affect model quality.
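This is straightforward to check by comparing backends directly (assuming a GPU where the Flash Attention backend of SDPA is available):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

q, k, v = (torch.randn(2, 8, 2048, 64, device="cuda", dtype=torch.float16) for _ in range(3))
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v)
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)
print((out_flash - out_math).abs().max())   # typically on the order of 1e-3 for fp16
```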
| Version | Year | Authors | Target GPU | Peak utilization | Key technique |
|---|---|---|---|---|---|
| FlashAttention | 2022 | Dao, Fu, Ermon, Rudra, Ré | A100 (Ampere) | 25 to 40% | Tiling, online softmax, recomputation |
| FlashAttention-2 | 2023 | Dao | A100 (Ampere) | 50 to 73% | Better parallelism, warp partitioning |
| FlashAttention-3 | 2024 | Shah, Bikshandi, Zhang, Thakkar, Ramani, Dao | H100 (Hopper) | ~75% (FP16) | Asynchrony, ping-pong, FP8 |
| FlashAttention-4 | 2025 | Dao et al. | B200 (Blackwell) | ~71% | Lazy rescaling, CuTe-DSL |
Flash Attention builds on and has inspired several related lines of work: