FlashAttention is an IO-aware exact attention algorithm that accelerates transformer training and inference by reducing memory reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. Developed by Tri Dao and collaborators, FlashAttention computes standard, mathematically exact attention without approximation while achieving significant wall-clock speedups and reducing memory usage from quadratic to linear in sequence length. Since its introduction in 2022, the algorithm has become a foundational component in virtually every major large language model implementation.
The self-attention mechanism in transformers computes pairwise interactions between all tokens in a sequence. For a sequence of length N, standard attention forms an N x N score matrix S = QK^T, applies softmax to produce a probability matrix P, and multiplies by the value matrix V to obtain the output. This process requires O(N^2) memory to store the intermediate matrices S and P, and the quadratic scaling creates severe bottlenecks as sequence lengths grow.
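As a reference point, the standard computation can be sketched in a few lines of pure Python. This is a toy single-head example with made-up 2x2 matrices (real implementations operate on batched tensors and scale scores by 1/sqrt(d), omitted here for brevity):

```python
import math

def softmax(row):
    # Subtract the row max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def standard_attention(Q, K, V):
    # S = Q K^T: the full N x N score matrix is materialized here --
    # exactly the O(N^2) intermediate that FlashAttention avoids.
    S = [[sum(qi * ki for qi, ki in zip(q, k)) for k in K] for q in Q]
    P = [softmax(row) for row in S]  # row-wise softmax
    # O = P V
    return [[sum(p * v[j] for p, v in zip(row, V)) for j in range(len(V[0]))]
            for row in P]

# Toy example: N = 2 tokens, head dimension d = 2.
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
O = standard_attention(Q, K, V)
```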
Conventionally, researchers focused on reducing the computational complexity of attention through approximate methods such as sparse attention and linear attention. These approaches traded model quality for lower arithmetic cost. However, Tri Dao and colleagues observed that the true bottleneck on modern GPUs is not the number of floating-point operations (FLOPs) but rather the movement of data between different levels of the memory hierarchy. Modern GPUs like the NVIDIA A100 can perform far more arithmetic operations per second than they can move data to and from HBM. Attention is therefore memory-bandwidth-bound, not compute-bound, for typical sequence lengths encountered in practice [1].
This insight motivated an algorithm designed around IO-awareness: explicitly accounting for the cost of reading and writing data between slow HBM (with high capacity but limited bandwidth) and fast on-chip SRAM (with limited capacity but very high bandwidth).
The original FlashAttention paper, titled "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," was authored by Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. Dao, Fu, Ermon, and Ré were affiliated with Stanford University, while Rudra was at the University at Buffalo, SUNY. The paper was first released as an arXiv preprint in May 2022 and published at NeurIPS 2022 [1].
FlashAttention uses three core techniques to eliminate the memory bottleneck: tiling (also called blocking), kernel fusion, and selective recomputation.
Rather than computing the full N x N attention matrix at once, FlashAttention divides the Q, K, and V matrices into smaller blocks that fit within the GPU's on-chip SRAM. The algorithm loads one block of K and V from HBM into SRAM, then iterates over blocks of Q. For each pair of blocks, it computes the local attention scores, applies softmax incrementally, and accumulates the partial output. The critical advantage is that the large intermediate matrices S and P are never materialized in HBM. They exist only temporarily in SRAM during each block's computation, then are discarded.
A significant challenge with tiling attention is that softmax is not straightforwardly decomposable across blocks. Computing softmax(QK^T) normally requires knowing the maximum value across the entire row of scores to ensure numerical stability, which would seem to require materializing the full score matrix first.
FlashAttention overcomes this through the online softmax trick, originally proposed by Milakov and Gimelshein (2018). The algorithm maintains running statistics (the row-wise maximum value m and the sum of exponentials l) that are updated incrementally as each new block of K is processed. When a new block produces a larger maximum value, previously accumulated results are rescaled accordingly. This allows softmax to be computed in a single streaming pass over the K blocks, without ever needing to see all scores simultaneously. This incremental approach is arguably the key algorithmic innovation that makes FlashAttention possible [1].
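The running-statistics update can be illustrated with a pure-Python sketch for a single query row. Scores and values arrive one block at a time, and previously accumulated results are rescaled whenever the running maximum changes (the variable names m, l, and acc are illustrative, not taken from any implementation):

```python
import math

def online_softmax_attention(score_blocks, value_blocks):
    """Streaming softmax-weighted sum over blocks for one query row,
    without ever seeing all scores simultaneously."""
    m = float("-inf")   # running row maximum
    l = 0.0             # running sum of exponentials
    acc = 0.0           # running (unnormalized) output accumulator
    for scores, values in zip(score_blocks, value_blocks):
        m_new = max(m, max(scores))
        # Rescale everything accumulated so far to the new maximum.
        correction = math.exp(m - m_new)
        l = l * correction + sum(math.exp(s - m_new) for s in scores)
        acc = acc * correction + sum(
            math.exp(s - m_new) * v for s, v in zip(scores, values))
        m = m_new
    return acc / l  # final normalization

# The streamed result matches softmax over the concatenated scores.
scores = [[1.0, 3.0], [2.0, 0.5]]
values = [[10.0, 20.0], [30.0, 40.0]]
out = online_softmax_attention(scores, values)
```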
During the backward pass for computing gradients, standard implementations would need the stored N x N attention matrix. FlashAttention avoids this by storing only the output O and the compact softmax normalization statistics (m and l) from the forward pass. During backpropagation, it recomputes the attention scores on the fly from blocks of Q, K, and V that are loaded back into SRAM. This can be understood as a selective form of gradient checkpointing, but unlike standard checkpointing (which trades speed for memory), FlashAttention's recomputation is actually faster than reading the stored matrices from HBM because the recomputation happens entirely in fast SRAM [1].
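The key property behind this recomputation is that the probability row can be rebuilt exactly from Q, K, and the stored per-row statistics. A minimal sketch for one query row (toy numbers, unscaled scores):

```python
import math

def forward_stats(q, K):
    """Forward pass for one query row: return the softmax statistics
    (m, l) that are kept instead of the full N x N probability matrix."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l

def recompute_probs(q, K, m, l):
    """Backward pass: rebuild the probability row on the fly from
    Q, K and the stored statistics -- no large matrix was ever saved."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in K]
    return [math.exp(s - m) / l for s in scores]

q = [1.0, 2.0]
K = [[0.5, 1.0], [1.0, 0.0], [0.0, 1.0]]
m, l = forward_stats(q, K)
P_row = recompute_probs(q, K, m, l)  # a valid softmax row
```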
Standard attention requires O(N^2) HBM accesses to read and write the score and probability matrices. FlashAttention requires O(N^2 d^2 / M) HBM accesses, where d is the head dimension and M is the SRAM size. Since d is typically 64 or 128 and M is on the order of hundreds of kilobytes, this represents a substantial reduction. The authors further proved that FlashAttention is asymptotically optimal in the number of HBM accesses among all exact attention algorithms for certain ranges of SRAM sizes [1].
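A back-of-the-envelope calculation makes the scale of the reduction concrete. The numbers below (N, d, and an SRAM size of roughly 192 KB holding fp16 elements) are illustrative assumptions, not measurements:

```python
# Standard attention moves ~N^2 elements of S and P through HBM;
# FlashAttention moves ~N^2 * d^2 / M, so the reduction factor is ~M / d^2.
N = 4096                 # sequence length (illustrative)
d = 64                   # head dimension
M = 192 * 1024 // 2      # ~192 KB of SRAM in 2-byte fp16 elements

standard_io = N * N                # element accesses for S and P
flash_io = N * N * d * d // M      # asymptotic FlashAttention accesses
reduction = standard_io / flash_io # ~M / d^2, i.e. roughly 24x here
```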
The memory footprint drops from O(N^2) to O(N), since no N x N intermediate matrices are stored in HBM. In practice, this translated to 10x memory savings at sequence length 2K and 20x savings at sequence length 4K compared to standard attention baselines [1].
| Benchmark | Speedup | Notes |
|---|---|---|
| BERT-large (seq. 512) | 15% end-to-end | Compared to MLPerf 1.1 training speed record |
| GPT-2 (seq. 1K) | Up to 3x | Wall-clock training speedup |
| Long Range Arena (seq. 1K-4K) | 2.4x | Speedup over standard attention |
| GPT-2 (general) | Up to 7.6x | Kernel-level speedup |
| Long sequences (8K) | 2.2x-2.7x | Compared to PyTorch and Megatron-LM |
FlashAttention-2 was released by Tri Dao in July 2023 in a paper titled "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" [2]. While the first version achieved substantial improvements over standard attention, it still reached only 25-35% of the theoretical maximum throughput of the A100 GPU. FlashAttention-2 closed much of this gap through three key optimizations.
The first optimization involved restructuring the algorithm to minimize operations that are not matrix multiplications. On modern GPUs, tensor cores execute matrix multiply-accumulate operations at extremely high throughput (for example, 312 TFLOPS for FP16 matmul on the A100, versus only 19.5 TFLOPS for non-matmul FP32 operations). FlashAttention-2 reorganized the online softmax computation to perform fewer rescaling operations, shifting a larger fraction of the total work into matmul operations that tensor cores handle efficiently [2].
FlashAttention-1 parallelized across batch size and number of attention heads, leaving the sequence length dimension serial. This limited occupancy when batch sizes were small or heads were few, a common scenario with long sequences. FlashAttention-2 introduced additional parallelism over the sequence length dimension in the forward pass, distributing different blocks of the output across thread blocks. For long sequences with small batch sizes, this yielded a major occupancy improvement [2].
FlashAttention-2 also optimized how work is divided among warps (groups of 32 threads) within each thread block. FlashAttention-1 split K and V across warps, requiring costly shared memory synchronization to combine partial results. FlashAttention-2 instead split Q across warps while having each warp access the full K and V, which eliminated most inter-warp communication and reduced shared memory reads and writes [2].
FlashAttention-2 was rewritten from scratch using NVIDIA's CUTLASS 3.x library and its CuTe abstraction layer. The result was roughly a 2x speedup over FlashAttention-1, reaching 50-73% of the theoretical maximum FLOPS on the A100 GPU. End-to-end training of GPT-style models achieved up to 225 TFLOPS per A100 GPU, corresponding to 72% model FLOPS utilization [2]. FlashAttention-2 also added native support for multi-query attention (MQA) and grouped-query attention (GQA), two widely used techniques that reduce the size of the key-value cache during inference.
FlashAttention-3, titled "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision," was authored by Jay Shah and Ganesh Bikshandi (Colfax Research), Ying Zhang (Meta), Vijay Thakkar and Pradeep Ramani (NVIDIA), and Tri Dao (Princeton University and Together AI). The paper was released as an arXiv preprint in July 2024 and published at NeurIPS 2024 [3].
Despite FlashAttention-2's strong results on the A100, it achieved only about 35% utilization on NVIDIA's newer H100 Hopper GPU. The Hopper architecture introduced new hardware features, including the Tensor Memory Accelerator (TMA) for asynchronous data movement, asynchronous execution of tensor core operations, and native FP8 support. FlashAttention-3 was designed specifically to exploit these capabilities [3].
FlashAttention-3 introduced three main techniques:
Warp Specialization and Asynchronous Pipelining. The algorithm uses warp specialization to overlap data movement (via TMA) with tensor core computation. Dedicated "producer" warps handle loading data from HBM into shared memory, while "consumer" warps execute the matmul and softmax operations. This overlapping hides much of the memory latency that previously limited performance [3].
Interleaved Block-wise Matmul and Softmax. On Hopper, the WGMMA (Warpgroup Matrix Multiply-Accumulate) instructions execute asynchronously, meaning the issuing warp can perform other work while the tensor cores are busy. FlashAttention-3 exploits this by interleaving softmax computation on one block with the matmul of the next block, effectively overlapping the two operations rather than executing them sequentially [3].
FP8 with Block Quantization and Incoherent Processing. FlashAttention-3 leverages the H100's native FP8 tensor cores, which offer roughly double the throughput of FP16. To maintain accuracy with the reduced precision, the algorithm uses block-wise quantization (scaling each block independently) and incoherent processing (multiplying by a random orthogonal matrix to spread outlier values more evenly). These techniques reduce numerical error by 2.6x compared to a naive FP8 attention baseline [3].
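The block-quantization idea can be illustrated with a simple pure-Python sketch. An integer grid stands in for FP8's limited dynamic range, and incoherent processing (the random orthogonal rotation) is omitted; this is a conceptual illustration, not the FlashAttention-3 kernel:

```python
def quantize_block(block, levels=127):
    """Per-block symmetric quantization: each block is scaled by its own
    absolute maximum, so a large outlier only degrades its own block
    rather than the whole tensor."""
    scale = max(abs(x) for x in block) / levels or 1.0
    q = [round(x / scale) for x in block]
    return q, scale

def dequantize_block(q, scale):
    return [x * scale for x in q]

# An outlier in block2 does not hurt block1's precision, because
# block1 keeps its own (much smaller) scale.
block1 = [0.01, -0.02, 0.015, 0.005]
block2 = [0.01, -0.02, 120.0, 0.005]   # contains an outlier
q1, s1 = quantize_block(block1)
deq1 = dequantize_block(q1, s1)
```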
| Configuration | Throughput | Utilization |
|---|---|---|
| FP16 on H100 | Up to 740 TFLOPS | 75% of theoretical max |
| FP8 on H100 | Close to 1.2 PFLOPS | Near peak FP8 throughput |
| vs FlashAttention-2 on H100 | 1.5-2.0x speedup | Improved from 35% to 75% utilization |
In late 2025 and early 2026, FlashAttention-4 emerged as the next iteration, targeting NVIDIA's Blackwell architecture (e.g., B200 GPUs) alongside continued Hopper optimization. A major change is the use of CuTeDSL, a Python domain-specific language released by the NVIDIA CUTLASS team for writing high-performance CUDA kernels. This replaced the lower-level C++ CUTLASS code used in previous versions [4].
The shift to CuTeDSL was partly necessitated by architectural changes: NVIDIA's Blackwell SM100 architecture replaced the WGMMA instructions used on Hopper with new TCGEN05 tensor core instructions, meaning FlashAttention-3 could not run on Blackwell hardware without a rewrite [4].
As of March 2026, FlashAttention-4 serves as a backend for PyTorch's FlexAttention framework, enabling users to define custom attention variants (such as causal masking, sliding window attention, or document masking) that are automatically compiled into efficient FlashAttention-4 kernels. On Blackwell GPUs, the Flash backend achieves 1.6-3.2x speedup over the Triton implementation for forward passes and 1.85-2.3x for backward passes [4].
| Feature | FlashAttention-1 (2022) | FlashAttention-2 (2023) | FlashAttention-3 (2024) | FlashAttention-4 (2025-2026) |
|---|---|---|---|---|
| Primary GPU target | A100 | A100 | H100 (Hopper) | H100, B200 (Blackwell) |
| Max utilization achieved | 25-35% | 50-73% | 75% | Improved further |
| FP8 support | No | No | Yes | Yes |
| Implementation language | CUDA | CUTLASS 3.x / CuTe | CUTLASS 3.x / CuTe | CuTeDSL (Python) |
| Key innovation | IO-aware tiling | Better parallelism | Asynchronous pipelining | Blackwell support, FlexAttention |
| MQA/GQA support | Limited | Native | Native | Native |
One of FlashAttention's most consequential practical effects has been enabling training and inference with much longer sequences. Because standard attention requires O(N^2) memory, doubling the sequence length quadruples memory consumption, quickly exceeding GPU memory limits. FlashAttention's linear memory scaling fundamentally changes this equation.
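The arithmetic behind the blow-up is simple. The configuration below (batch 8, 32 heads, fp16 scores) is a hypothetical example, but it shows why doubling the sequence length quadruples the attention matrix's footprint:

```python
# Memory for the materialized N x N attention matrix in standard attention,
# per forward pass, with 2-byte (fp16) scores.
def attn_matrix_bytes(seq_len, n_heads=32, batch=8, bytes_per_el=2):
    return batch * n_heads * seq_len * seq_len * bytes_per_el

gib = 1024 ** 3
mem_2k = attn_matrix_bytes(2048) / gib  # 2 GiB at 2K context
mem_8k = attn_matrix_bytes(8192) / gib  # 32 GiB at 8K: 16x more for 4x length
```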
With FlashAttention, training with 8K context length is only 7% less hardware-efficient than training with 2K context length, a dramatic improvement over implementations like Megatron-LM where such an increase caused severe degradation [5]. This efficiency has been instrumental in the trend toward longer context windows in modern language models. Models such as GPT-4, LLaMA, and Qwen have all been trained with context lengths of 8K, 32K, 128K, or beyond.
Research has consistently shown that models trained with longer context outperform those with shorter context on both pretraining perplexity and downstream tasks. FlashAttention made this practical by removing the memory barrier that previously constrained context length choices.
FlashAttention is complementary to, rather than competing with, most other attention optimization techniques. It is helpful to distinguish between approaches that change what is computed and approaches that change how it is computed.
Multi-Query Attention (MQA), introduced by Shazeer (2019), reduces the number of key and value heads to one, so all query heads share a single set of keys and values. This dramatically reduces KV cache size during inference. Grouped-Query Attention (GQA), proposed by Ainslie et al. (2023), is an interpolation where groups of query heads share key-value heads. Both MQA and GQA reduce the amount of computation and memory needed for the KV projections. FlashAttention-2 and later versions natively support both MQA and GQA, meaning these techniques can be combined for compounding benefits [2].
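The sharing pattern reduces to a simple head-index mapping, sketched below in pure Python (the function name is illustrative; head counts are examples):

```python
def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    """Map a query head to the KV head it shares under GQA.
    n_kv_heads == 1 recovers MQA; n_kv_heads == n_q_heads is standard
    multi-head attention."""
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

# 32 query heads sharing 8 KV heads: groups of 4 query heads per KV head,
# so the KV cache shrinks by 4x relative to standard multi-head attention.
gqa_map = [kv_head_for_query_head(h, 32, 8) for h in range(32)]
mqa_map = [kv_head_for_query_head(h, 32, 1) for h in range(32)]
```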
Sparse Attention methods, such as those in Longformer (Beltagy et al., 2020) and BigBird (Zaheer et al., 2020), restrict each token to attend only to a subset of other tokens (e.g., local windows plus global tokens). This reduces both computation and memory from O(N^2) to O(N) but at the cost of potentially missing long-range dependencies.
Linear Attention methods, including Katharopoulos et al. (2020), replace the softmax with a kernel function that allows the attention computation to be reformulated as a linear recurrence, reducing complexity to O(N). However, these methods are approximate and have generally not matched standard attention in model quality for language tasks.
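The reformulation can be sketched for the causal case: with a positive feature map phi, the output becomes a ratio of running sums, so each step costs O(d^2) regardless of position. This is a toy pure-Python illustration under the elu(x)+1 feature map used by Katharopoulos et al., not a tuned implementation:

```python
import math

def phi(x):
    # elu(x) + 1: a simple positive feature map.
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def linear_attention(Q, K, V):
    """Causal linear attention as a recurrence: S accumulates
    phi(k) v^T and z accumulates phi(k), replacing the softmax
    over all previous tokens."""
    d, dv = len(Q[0]), len(V[0])
    S = [[0.0] * dv for _ in range(d)]   # running sum of phi(k) v^T
    z = [0.0] * d                        # running sum of phi(k)
    out = []
    for q, k, v in zip(Q, K, V):
        fk = phi(k)
        for i in range(d):
            z[i] += fk[i]
            for j in range(dv):
                S[i][j] += fk[i] * v[j]
        fq = phi(q)
        denom = sum(fq[i] * z[i] for i in range(d))
        out.append([sum(fq[i] * S[i][j] for i in range(d)) / denom
                    for j in range(dv)])
    return out

Q = [[0.1, 0.2], [0.3, -0.1]]
K = [[0.2, 0.1], [-0.2, 0.4]]
V = [[1.0, 0.0], [0.0, 1.0]]
out = linear_attention(Q, K, V)
```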
FlashAttention falls squarely in the second group: it changes how attention is computed, not what is computed. It produces mathematically exact standard attention but reorganizes the computation to be memory-efficient. This means it is fully compatible with any attention variant: standard multi-head attention, MQA, GQA, causal attention, sliding window attention, or cross-attention. Any model using any of these patterns can benefit from FlashAttention's IO-aware tiling.
Meta's xFormers library independently developed memory-efficient attention implementations in a similar spirit. The xFormers "cutlass" attention backend uses tiling and kernel fusion comparable to FlashAttention, and the library dynamically dispatches to whichever backend is available and fastest. In practice, FlashAttention-2 demonstrated roughly 2x speedup over the xFormers cutlass implementation, and the Dao-AILab implementation has become the dominant backend [6].
PyTorch integrated FlashAttention directly into its core API starting with PyTorch 2.0 via torch.nn.functional.scaled_dot_product_attention() (SDPA). This function automatically dispatches to FlashAttention when the inputs and hardware are compatible, making adoption seamless for PyTorch users. The SDPA interface also supports falling back to a memory-efficient C++ implementation or a standard math implementation when FlashAttention is unavailable [7].
FlashInfer is another notable implementation, developed as a library for LLM serving that includes attention kernels optimized for variable-length sequences and paged KV caches. It has been adopted by serving systems including SGLang and MLC-LLM [8].
NVIDIA's cuDNN library also includes its own fused attention implementation, which competes with FlashAttention on throughput for certain configurations. TensorRT-LLM, NVIDIA's inference optimization toolkit, integrates both cuDNN attention and FlashAttention kernels.
FlashAttention's adoption has been remarkably broad. By 2024, it was used in virtually every major open-source LLM implementation. Major models trained with FlashAttention include the LLaMA family (Meta), Falcon (TII), MPT (MosaicML), Mistral, and many others. The algorithm's influence extends beyond language models to vision transformers, diffusion models for image generation, and audio processing models.
As of early 2026, the FlashAttention ecosystem continues to evolve rapidly. FlashAttention-4, written in CuTeDSL, represents the state of the art for Hopper and Blackwell GPUs. The integration with PyTorch's FlexAttention framework means that custom attention patterns (causal masks, sliding windows, document boundaries, and arbitrary user-defined modifications) can be expressed at a high level and compiled into optimized FlashAttention kernels automatically.
Tri Dao, who is an Assistant Professor of Computer Science at Princeton University and Co-founder and Chief Scientist of Together AI, continues to lead the development. The open-source repository at Dao-AILab/flash-attention on GitHub remains the reference implementation.
The broader lesson of FlashAttention, that hardware-aware algorithm design can yield order-of-magnitude improvements without sacrificing mathematical exactness, has influenced work well beyond attention. Researchers have applied similar IO-aware principles to other operations in deep learning, including state-space models and convolution layers. FlashAttention demonstrated that the gap between theoretical peak hardware performance and actual achieved performance is often large, and that closing this gap through careful algorithm-hardware co-design can be as impactful as architectural innovation.