Gradient checkpointing
Last reviewed
May 25, 2026
Sources
No citations yet
Review status
Needs citations
Revision
v1 · 3,602 words
Gradient checkpointing, also called activation checkpointing or activation recomputation, is a memory-saving technique for training deep neural networks that trades additional compute for reduced peak memory consumption.[1] Rather than storing every intermediate activation produced during the forward pass, the technique stores only a small subset of activations (the "checkpoints") and recomputes the others on demand during the backward pass.[1] The technique was popularised for deep learning by Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin in the 2016 paper "Training Deep Nets with Sublinear Memory Cost", which showed that a network of n layers can be trained with O(sqrt(n)) memory at the cost of a single extra forward pass per mini-batch.[1] Gradient checkpointing is implemented in every major training framework, including PyTorch (torch.utils.checkpoint), JAX (jax.checkpoint, aliased to jax.remat), TensorFlow (tf.recompute_grad), DeepSpeed, and PyTorch FSDP (apply_activation_checkpointing).[2][3][4][5] It is a standard component of large language model training, used in LLaMA, GPT-3, and the Megatron-LM family among many others.[6][7]
Training a neural network with backpropagation requires the gradient of the loss with respect to every parameter. Computing those gradients via the chain rule needs intermediate values from the forward pass, called activations, at every layer whose local derivatives depend on them: a linear layer needs its input to form its weight gradient, and a non-linearity needs its input (or output) to evaluate its derivative.[8] In the most direct implementation of reverse-mode automatic differentiation, all activations are stored during the forward pass and read back during the backward pass. Memory consumption therefore grows linearly with the depth of the network and, for most layer types, with the batch size, sequence length, and hidden dimension.[8]
For a transformer block, the activations that must be stored include the input to attention, the queries, keys, and values, the attention probability matrix (of size sequence_length squared per head), the attention output, the input to the feed-forward layer, the intermediate feed-forward activation (typically four times the hidden dimension), and the layer-norm statistics. Korthikanti and colleagues derive an approximate per-layer formula, sbh(34 + 5as/h) bytes for sequence length s, micro-batch size b, hidden size h, and a attention heads, which is linear in batch size and hidden dimension but contains a term, 5as^2 b, that grows with the number of attention heads and with the square of the sequence length.[6] At a context length of 2,048 and hidden size of 12,288 (the GPT-3 configuration), per-layer activation memory exceeds the storage required for the layer's parameters by a wide margin, making activations the dominant memory cost as models grow deeper and their context windows longer.[6]
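The per-layer estimate is easy to evaluate directly. The sketch below assumes the formula reported by Korthikanti et al., sbh(34 + 5as/h) bytes (16-bit activations, no tensor or sequence parallelism, no recomputation), and splits it into its linear and attention terms; the function name is hypothetical and the micro-batch size of 1 is an illustrative assumption.

```python
def layer_activation_bytes(s: int, b: int, h: int, a: int) -> tuple[int, int]:
    """Split Korthikanti et al.'s per-layer estimate sbh(34 + 5as/h) into two terms.

    Assumes 16-bit activations, no tensor/sequence parallelism, no recomputation.
    """
    linear_term = 34 * s * b * h          # scales linearly with sequence length
    attention_term = 5 * a * s * s * b    # QK^T output, softmax output, dropout mask: scales with s**2
    return linear_term, attention_term


# GPT-3 175B shape (s=2048, h=12288, a=96); micro-batch size b=1 is an assumption.
lin, att = layer_activation_bytes(s=2048, b=1, h=12288, a=96)
print(f"linear term:    {lin / 1e9:.2f} GB")    # 0.86 GB
print(f"attention term: {att / 1e9:.2f} GB")    # 2.01 GB

# Doubling the context length doubles the first term but quadruples the second.
lin2, att2 = layer_activation_bytes(s=4096, b=1, h=12288, a=96)
print(lin2 / lin, att2 / att)                   # 2.0 4.0
```

For this configuration the ratio 5as/h equals 80, so the attention term, not the linear term, dominates per-layer activation memory.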
As neural networks grew in depth and width during the 2010s, this linear scaling became a bottleneck. Chen and colleagues observed in 2016 that on a 1,000-layer deep residual network the activations occupied roughly 48 GB, exceeding the memory of available GPUs.[1] The same constraint reappeared at much larger scale with the transformer architecture: when training a GPT-3-style 530-billion-parameter model, activation memory alone exceeded what a single NVIDIA A100 GPU with 80 GB of memory could hold even after applying tensor and pipeline parallelism.[6] Backpropagation thus requires explicit strategies to control activation memory, of which gradient checkpointing is the most widely adopted.[1][6]
The technique is sometimes confused with model checkpointing, which periodically saves model weights to disk for fault tolerance during long training runs. They are unrelated. Gradient checkpointing operates within a single backward pass and concerns intermediate activations, not parameter snapshots.
The core idea (trade compute for memory by recomputing intermediates) predates the 2016 paper by decades. The automatic differentiation literature describes "checkpointing" or "revolve" schemes for adjoint computation in the context of ODE solvers and PDE-constrained optimisation, including Andreas Griewank's work in the 1990s on recursive binomial checkpointing. Chen et al. cite this earlier literature and adapt it to deep network training graphs, but their contribution is the specific O(sqrt(n)) recipe and an automated, framework-level implementation.[1] The reframing for deep learning, and the practical demonstration on a 1,000-layer ResNet, opened the technique to widespread adoption in mainstream training frameworks.[1]
The algorithm proposed by Chen et al. partitions an n-layer feed-forward network into segments and stores only the activation at the boundary of each segment.[1] During the backward pass, when gradients are needed for a layer inside a segment, the forward pass is re-executed from the most recent stored checkpoint up to that layer, regenerating the missing activations on demand.[1] Once the gradient for the layer has been computed, the recomputed activations are discarded.
Formally, consider a network expressed as a composition f = f_n o ... o f_2 o f_1, where each f_i takes the activation a_(i-1) and produces a_i. Standard backpropagation stores all a_i during the forward pass. With checkpointing, the layers are grouped into segments; only the activations at segment boundaries are persisted. To compute the gradient of the loss with respect to the parameters theta_i of a layer i inside a segment, the activations a_(j+1), ..., a_(i-1) (where a_j is the previous checkpoint) are recomputed by re-running the forward pass on the segment.[1] The gradient is then accumulated along the chain rule exactly as in standard backpropagation; the only difference is when and how the activations are materialised.[1]
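The procedure can be traced end to end on a toy problem. The pure-Python sketch below assumes a hypothetical chain of scalar layers a_i = tanh(w_i * a_(i-1)) with the toy loss L = a_n; it runs standard reverse mode and segment-based checkpointing side by side, and counts how many activations each holds at its peak. Function names are illustrative, not any framework's API.

```python
import math


# Hypothetical toy network: layer i maps a scalar a_(i-1) to a_i = tanh(w_i * a_(i-1)),
# and the loss is simply L = a_n. Only the memory pattern matters here.
def layer_forward(w: float, a_in: float) -> float:
    return math.tanh(w * a_in)


def layer_backward(w: float, a_in: float, a_out: float, grad_out: float):
    """Given dL/da_out, return (dL/da_in, dL/dw) for a_out = tanh(w * a_in)."""
    dz = grad_out * (1.0 - a_out * a_out)  # tanh'(z) = 1 - tanh(z)^2
    return dz * w, dz * a_in


def grads_store_all(ws, x):
    """Standard reverse mode: keep every activation a_0..a_n."""
    acts = [x]
    for w in ws:
        acts.append(layer_forward(w, acts[-1]))
    grad, dws = 1.0, [0.0] * len(ws)
    for i in reversed(range(len(ws))):
        grad, dws[i] = layer_backward(ws[i], acts[i], acts[i + 1], grad)
    return dws, len(acts)  # gradients, number of stored activations


def grads_checkpointed(ws, x, segment):
    """Keep only segment-boundary activations; recompute each segment during backward."""
    n = len(ws)
    checkpoints = {0: x}
    a = x
    for i, w in enumerate(ws):  # forward pass: discard everything but boundaries
        a = layer_forward(w, a)
        if (i + 1) % segment == 0 and i + 1 < n:
            checkpoints[i + 1] = a
    grad, dws, peak = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):  # walk segments back to front
        end = min(start + segment, n)
        seg = [checkpoints[start]]
        for i in range(start, end):  # re-run the forward pass from the checkpoint
            seg.append(layer_forward(ws[i], seg[-1]))
        peak = max(peak, len(checkpoints) + (end - start))
        for i in reversed(range(start, end)):
            grad, dws[i] = layer_backward(ws[i], seg[i - start], seg[i - start + 1], grad)
        # `seg` is dropped here, so recomputed activations never accumulate
    return dws, peak  # gradients, peak number of activations held at once


ws = [0.5 + 0.05 * i for i in range(16)]
full, stored = grads_store_all(ws, 0.3)
ckpt, peak = grads_checkpointed(ws, 0.3, segment=4)  # sqrt(16) = 4 layers per segment
print(stored, peak)  # 17 8
```

Both functions return identical gradients; the checkpointed version holds 8 activations at its peak instead of 17, at the price of re-running each segment's forward pass once.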
Chen et al. observed that placing approximately sqrt(n) checkpoints in an n-layer network is the optimum of the simple memory-vs-compute tradeoff: each segment then holds at most sqrt(n) activations in memory at any one time, and the total recomputed work is one extra forward pass.[1] Recursive application of the same strategy reduces memory further to O(log n) at the cost of O(n log n) extra forward computation, useful only in extreme regimes.[1] In an implementation released alongside the paper, the authors reduced memory of a 1,000-layer residual network from 48 GB to 7 GB with 30 percent additional running time on the ImageNet problem.[1]
A widely cited illustrative reimplementation by Tim Salimans and Yaroslav Bulatov demonstrated that a TensorFlow training script using their automatic graph rewrite could fit "more than 10x larger models" onto a single GPU at roughly 20 percent additional compute time, popularising the technique under the name "gradient checkpointing".[9]
Let n be the number of layers and c be the number of checkpoints. Activation memory after checkpointing scales as O(n/c + c), since at any instant the implementation holds c checkpoints plus the activations recomputed for the current segment, which contains n/c layers. Minimising n/c + c over c yields c = sqrt(n) and a memory cost of O(sqrt(n)).[1] Compute, in contrast, stays O(n): the segment-by-segment recomputation adds exactly one extra forward pass, so the backward phase grows from roughly two forward-pass equivalents to three. Because the backward pass already does about twice as much work as the forward pass under standard backpropagation, the total cost of a training step rises from about three forward-pass equivalents to four, roughly 33 percent, in line with the 30 percent number reported by Chen et al.[1]
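The arithmetic behind this argument fits in a few lines. The sketch below models peak memory as c + ceil(n/c) activation slots and a training step as one forward pass plus a backward pass costing about twice as much; both are the simplifying assumptions stated in the paragraph above rather than measurements, and the function names are illustrative.

```python
import math


def peak_activations(n: int, c: int) -> int:
    """c stored checkpoints plus one recomputed segment of about n/c layers."""
    return c + math.ceil(n / c)


def best_checkpoint_count(n: int) -> int:
    """Brute-force the c that minimises n/c + c (the O(sqrt(n)) argument)."""
    return min(range(1, n + 1), key=lambda c: peak_activations(n, c))


def relative_step_cost(backward_over_forward: float = 2.0, extra_forwards: float = 1.0) -> float:
    """Step cost with checkpointing relative to without, in units of one forward pass."""
    baseline = 1.0 + backward_over_forward
    return (baseline + extra_forwards) / baseline


n = 100
c = best_checkpoint_count(n)
print(c, peak_activations(n, c))          # 10 20  (versus 100 without checkpointing)
print(f"{relative_step_cost() - 1:.0%}")  # 33%
```

Real overheads differ from this idealised 33 percent because layers differ in cost and because frameworks overlap recomputation with other work.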
A worked numerical example clarifies the magnitudes. Consider an 80-layer transformer with hidden size h = 8,192, sequence length s = 4,096, and a micro-batch size of b = 4 per GPU, with no tensor or sequence parallelism. Using Korthikanti et al.'s per-layer estimate of sbh(34 + 5as/h) bytes in 16-bit precision, the linear term alone, 34sbh, comes to about 4.6 GB per layer, or roughly 365 GB for the 80-layer stack, before the attention term that grows with s^2 is added; storing every activation is therefore far beyond the capacity of an 80 GB GPU.[6] With full checkpointing at every layer boundary, the implementation stores only each layer's input, 2sbh bytes or about 0.27 GB, so the 80 checkpoints total about 21.5 GB, plus the transient activations of the one layer currently being recomputed.[6] Korthikanti et al. depict this graphically across the 22B, 175B, 530B, and 1T model configurations, showing baseline activation memory above the 80 GB A100 ceiling in every case and the "present work" (sequence parallelism plus selective recomputation) falling below 30 GB.[6]
For more general computation graphs, checkpoint selection becomes a graph-partitioning problem: choose a subset S of intermediate values to store (the "cut") such that every other value can be recomputed from S and the inputs, then minimise both the largest region of the graph that must be recomputed at once and the total recomputation cost. The general problem is NP-hard, but heuristics close to the original sqrt-n approach work well for the layer-stacked computation graphs typical of feed-forward networks and transformer models.[1] More sophisticated dynamic programming solutions exist for chain-structured graphs and are sometimes called Treeverse or revolve algorithms in the automatic differentiation literature.
A second important consideration is the size of each saved activation. Two placements with the same number of checkpoints can have very different memory footprints if the network's hidden dimension changes across layers (e.g., U-Nets or models with bottlenecks). Implementations therefore use byte-size-weighted variants of the cut problem in practice. The original Chen et al. analysis assumes uniform layer cost; this matches deep ResNets and isotropic transformers well but is less appropriate for encoder-decoder architectures with varying feature-map sizes.[1]
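A byte-weighted placement can be illustrated with exhaustive search on a tiny chain. The sketch below uses a deliberately simplified cost model (peak memory equals the bytes of all stored checkpoints plus the largest run of non-checkpointed activations recomputed at once); the layer sizes are hypothetical, loosely shaped like an encoder-decoder with a bottleneck, and brute force is only feasible for a handful of layers.

```python
from itertools import combinations


def peak_bytes(sizes, checkpoints):
    """Simplified cost model: every stored checkpoint plus the largest run of
    non-checkpointed activations that must be recomputed at the same time."""
    marks = set(checkpoints)
    stored = sum(sizes[i] for i in marks)
    largest = current = 0
    for i, size in enumerate(sizes):
        if i in marks:
            current = 0  # a checkpoint starts a new recomputation segment
        else:
            current += size
            largest = max(largest, current)
    return stored + largest


def best_checkpoints(sizes):
    """Exhaustively try every subset of layers (feasible only for tiny chains)."""
    n = len(sizes)
    subsets = (c for k in range(n + 1) for c in combinations(range(n), k))
    return min(subsets, key=lambda c: peak_bytes(sizes, c))


# Hypothetical per-layer activation sizes in MB: wide, narrowing to a bottleneck, widening.
sizes = [8, 4, 2, 1, 1, 2, 4, 8]
best = best_checkpoints(sizes)
print(best, peak_bytes(sizes, best), peak_bytes(sizes, ()))  # (3,) 16 30
```

Under this model the search places a single checkpoint at the 1 MB bottleneck and roughly halves peak memory, whereas checkpointing the two largest tensors, (0, 7), saves nothing at all. A count-based rule that ignores sizes cannot tell these placements apart.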
Vijay Korthikanti and collaborators at NVIDIA introduced selective activation recomputation in the May 2022 paper "Reducing Activation Recomputation in Large Transformer Models".[6] Their observation: in a transformer block, certain operations (notably the QK^T attention-score multiplication, the softmax, the dropout applied to the attention probabilities, and the attention over the values) consume a disproportionate amount of activation memory relative to the compute they require, because their outputs grow with the square of the sequence length, while other operations (the large linear-layer matrix multiplications) are the opposite.[6] The selective strategy stores activations for the compute-expensive operations and recomputes only the memory-heavy, compute-light operations.[6]
Combined with sequence parallelism, this technique reduces activation memory by 5x while keeping the compute overhead of recomputation under 10 percent of the original training step.[6] When training a 530B GPT-3-style model on 2,240 A100 GPUs, the authors achieved 54.2 percent Model FLOPs Utilisation, a 29 percent improvement over the 42.1 percent obtained with full activation recomputation.[6] The implementation is part of the open-source Megatron-LM codebase.[6]
A second variant moves checkpointed activations from GPU memory to either CPU host memory or even disk between the forward and backward passes, prefetching them back asynchronously when needed.[4] This adds a third axis to the tradeoff: GPU memory, CPU memory, and compute. DeepSpeed supports CPU offloading via the checkpoint_in_cpu parameter inside its activation checkpointing configuration, which only operates when activations are partitioned across model parallelism ranks.[4][10] The FSDP checkpoint_wrapper similarly supports an offload_to_cpu option that moves preserved inputs to host RAM and prefetches them back to the GPU immediately before recomputation.[5]
In tensor parallelism setups (as used in Megatron-LM), each GPU holds only a shard of the model. DeepSpeed's partition_activations option distributes the activation checkpoints themselves across the tensor parallel ranks, dividing checkpoint memory per GPU by the tensor parallel size.[4] When combined with contiguous_memory_optimization, the partitioned checkpoints are copied into a single contiguous buffer, reducing memory fragmentation.[4]
In JAX, the jax.checkpoint decorator (alias jax.remat) is implemented as a transformation that interacts with the XLA compiler.[3] Recent versions add a policy argument that controls which intermediate values are saved versus recomputed, with built-in policies such as dots_with_no_batch_dims_saveable, save_only_these_names, and offload_dot_with_no_batch_dims.[11] The JAX documentation notes that when functions are compiled with jax.jit, XLA may apply rematerialisation transforms automatically, so an explicit jax.checkpoint is most useful around staged control flow such as jax.lax.scan, which is common in transformer implementations.[11]
PyTorch exposes the technique through torch.utils.checkpoint.[2] The primary entry point is torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, ...), which runs function(*args) without saving its intermediate activations and arranges for them to be recomputed during the backward pass.[2] A checkpoint_sequential(functions, segments, input) helper applies the technique to an nn.Sequential model by splitting it into a chosen number of segments.[2]
The library supports two modes selected by the use_reentrant flag. The original, reentrant implementation invokes a nested backward pass through the autograd engine; it has limitations around inputs that require gradients and around hooks. The non-reentrant implementation (use_reentrant=False) was added later, supports more functionality, and is the recommended choice; PyTorch documentation states that the parameter must be explicitly specified and that future versions will raise an exception if it is omitted.[2] Checkpointed functions preserve the random number generator state by default so that operations such as dropout produce the same values on the recomputation pass; this can be disabled with preserve_rng_state=False at the cost of correctness for stochastic operations.[2]
For distributed training, PyTorch ships apply_activation_checkpointing in torch.distributed.algorithms._checkpoint.checkpoint_wrapper, which walks a model and wraps submodules selected by a user-supplied check_fn or auto_wrap_policy.[5] For FSDP training, the recommended pattern is to wrap each transformer block (attention plus feed-forward) and to use non-reentrant checkpointing, since the reentrant variant interacts poorly with FSDP's state synchronisation.[5]
JAX provides jax.checkpoint, aliased to jax.remat, with the signature jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=(), concrete=Deprecated).[3] Applying it as a decorator marks the wrapped function so that during reverse-mode autodiff its intermediates are recomputed rather than stored.[3] As noted above, JAX exposes several policies for fine-grained control over what is checkpointed.[11]
TensorFlow exposes the technique through tf.recompute_grad(f), a decorator that wraps a function so that its forward computation is repeated during the backward pass instead of caching intermediate tensors.[12] The early popular implementation of the technique on TensorFlow was the cybertronai/gradient-checkpointing library by Tim Salimans and Yaroslav Bulatov, which automatically rewrote a TensorFlow computation graph to insert checkpoints; the project is now in maintenance mode and the official tf.recompute_grad API is the recommended replacement.[9][12]
DeepSpeed provides a dedicated activation checkpointing API in the deepspeed.checkpointing module. Key functions include deepspeed.checkpointing.configure() to set options, deepspeed.checkpointing.checkpoint() to checkpoint a block, deepspeed.checkpointing.is_configured(), deepspeed.checkpointing.reset() to clear buffers between forward passes, and deepspeed.checkpointing.model_parallel_cuda_manual_seed() to initialise RNG state for model-parallel runs.[4] The configuration options include partition_activations (split checkpoints across tensor-parallel ranks), contiguous_memory_optimization (copy checkpoints into a contiguous buffer, requires partitioning to be enabled), checkpoint_in_cpu (offload checkpoints to CPU memory, requires partitioning), and profile (log forward and backward time per checkpoint invocation).[4][10]
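For orientation, these options correspond to an activation_checkpointing section in a DeepSpeed JSON configuration. The sketch below writes that section as a Python dict of the kind that can be passed to deepspeed.initialize as its config; the key names follow the DeepSpeed configuration documentation as best understood here (cpu_checkpointing being the JSON spelling of checkpoint_in_cpu) and should be checked against the installed version. The validation helper is hypothetical and only encodes the dependencies described above.

```python
# Assumed DeepSpeed JSON keys; verify against the documentation for your version.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "activation_checkpointing": {
        "partition_activations": True,           # shard checkpoints across tensor-parallel ranks
        "contiguous_memory_optimization": True,  # copy shards into one buffer (needs partitioning)
        "cpu_checkpointing": True,               # offload shards to host RAM (needs partitioning)
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,                        # log forward/backward time per checkpoint call
    },
}


def check_activation_checkpointing(config: dict) -> None:
    """Hypothetical helper: reject options that require partition_activations."""
    section = config.get("activation_checkpointing", {})
    for key in ("contiguous_memory_optimization", "cpu_checkpointing"):
        if section.get(key) and not section.get("partition_activations"):
            raise ValueError(f"{key} requires partition_activations=True")


check_activation_checkpointing(ds_config)  # passes: partitioning is enabled
```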
PyTorch's FSDP integrates activation checkpointing through apply_activation_checkpointing from torch.distributed.algorithms._checkpoint.checkpoint_wrapper.[5] The user supplies a checkpoint_wrapper_fn callable that wraps modules and a check_fn or auto_wrap_policy that decides which submodules to wrap. The typical pattern is to wrap each transformer layer, mirroring the strategy used by Megatron-LM and DeepSpeed.[5]
The original Megatron-LM codebase implements activation checkpointing at the granularity of transformer layers via the checkpoint-activations flag and the related checkpoint-num-layers.[13] Setting checkpoint-num-layers to 1 (checkpoint at every layer boundary) is a common configuration; the GPT-NeoX-20B model was trained with exactly these settings.[13] Since 2022, Megatron-LM also implements the selective recomputation strategy of Korthikanti et al. as a configurable mode alongside the original full recomputation path.[6]
Activation checkpointing is effectively universal in training large language models. Public training reports from the major labs document its use as follows.
The original LLaMA paper, "LLaMA: Open and Efficient Foundation Language Models" by Touvron et al., released by Meta in February 2023, states that "To further improve training efficiency, we reduced the amount of activations that are recomputed during the backward pass with checkpointing. More precisely, we save the activations that are expensive to compute, such as the outputs of linear layers."[7] The paper explicitly cites Korthikanti et al. and reports that, in conjunction with sequence and model parallelism, the resulting training pipeline processed roughly 380 tokens per second per GPU on 2,048 A100 GPUs with 80 GB of memory, training the 65B-parameter model in approximately 21 days over 1.4T tokens.[7] The authors add that this selective saving "is achieved by manually implementing the backward function for the transformer layers, instead of relying on the PyTorch autograd", which gives them direct control over which activations are kept and which are recomputed.[7] This pattern of bypassing the framework's default gradient pipeline to gain finer-grained checkpointing control reappears in numerous large-scale training projects.
The 175B-parameter GPT-3 was trained using model parallelism across V100 GPUs; activations exceeded single-GPU capacity, requiring memory-saving strategies. The 530B-parameter Megatron-Turing NLG model, trained jointly by NVIDIA and Microsoft on 2,240 A100 GPUs, used a 3D parallelism strategy with activation checkpointing applied to every transformer block.[6][14] Korthikanti et al. measure that for that model the activation memory required for the baseline (no recomputation) approach is well above 80 GB per GPU, exceeding the A100 HBM budget; full activation recomputation brings it below the budget at the cost of 30 to 40 percent execution time overhead, while their selective scheme achieves a 5x reduction with under 10 percent overhead.[6]
Google's PaLM, GPT-NeoX-20B, BLOOM, and most subsequent open-weight LLMs above 1B parameters use some form of activation checkpointing.[13] The standard configuration in the open-source Megatron-LM and DeepSpeed stacks is to checkpoint at every transformer layer with non-reentrant checkpointing and either CPU offload or partitioned activations as needed.[4][5] GPT-NeoX-20B was trained with checkpoint-activations: True and checkpoint-num-layers: 1, the per-layer setting that maximises memory savings at the cost of more recomputation work.[13]
While the discussion above focuses on pre-training, activation checkpointing is also pervasive in fine-tuning. Hugging Face's transformers library exposes a gradient_checkpointing_enable() method on most models and a gradient_checkpointing=True flag on the TrainingArguments object, both of which route the model's forward pass through torch.utils.checkpoint.[2] LoRA and full fine-tuning pipelines on single-GPU hardware (from 24 GB consumer cards to 80 GB data-centre GPUs) routinely enable it; combined with mixed precision training in BF16 and 8-bit optimisers, it makes fine-tuning of 7B-parameter models feasible on a single GPU.[7]
Activation checkpointing addresses one specific source of training memory: intermediate activations. Other techniques target different sources.
| Technique | Target | Memory savings | Compute overhead |
|---|---|---|---|
| Gradient checkpointing | Activations | From O(n) to O(sqrt(n)) in network depth n | ~20 to 30 percent[9][15] |
| Mixed precision training | Activations + weights + grads | ~2x via FP16/BF16 | Slight speedup |
| ZeRO-1/2/3 | Optimiser state, grads, params | up to ~Nx with N data-parallel ranks | Communication overhead |
| Optimiser offload | Optimiser state | Moves to CPU/NVMe | Communication overhead |
| FlashAttention | Attention activations only | From O(s^2) to O(s) in sequence length s | None (often a speedup) |
These approaches are complementary; the standard recipe for training a multi-billion-parameter LLM combines activation checkpointing, mixed precision, ZeRO or FSDP sharding, and FlashAttention.[4][5][6][7] Korthikanti et al. note that ZeRO-style optimisations and offload techniques "are complementary to the techniques presented here and could be additionally employed for even greater memory savings."[6]
The choice between full and selective activation checkpointing depends on architecture and scale. For pre-transformer architectures (ResNets, large CNNs), the original Chen et al. sqrt-n strategy or its per-block analogue suffices. For transformers above ~1B parameters, selective recomputation or FlashAttention is usually preferable because the attention softmax and dropout layers are exactly the memory-heavy, compute-light operations the selective strategy targets.[6]
Activation checkpointing's main cost is wall-clock time. Korthikanti et al. measured 30 to 40 percent execution time overhead from full activation recomputation on their training runs.[6] PyTorch's distributed documentation gives a somewhat lower rule-of-thumb range of 20 to 30 percent.[5] The cybertronai TensorFlow implementation reported around 20 percent.[9] The exact overhead depends on the architecture (the ratio of compute to activation memory in each block), the chosen checkpoint granularity, and whether selective recomputation is in use.[6]
A subtler limitation concerns randomness. Operations such as dropout produce different outputs on each forward pass; if the second (recomputation) pass uses a different RNG state from the first, the gradient is incorrect. PyTorch's torch.utils.checkpoint preserves the RNG state by default but emits warnings that operations involving randomness across multiple devices may still produce incorrect gradients.[2] Disabling RNG-state preservation via preserve_rng_state=False is an option only when the wrapped function contains no stochastic operations.[2]
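The failure mode is easy to reproduce in miniature. The pure-Python sketch below uses a hypothetical inverted-dropout function driven by random.Random; saving the generator state before the forward pass and restoring it before recomputation mirrors the idea behind PyTorch's default RNG-state preservation, but the code is an illustration, not PyTorch's implementation.

```python
import random


def dropout(values, p, rng):
    """Inverted dropout: zero each value with probability p and rescale the rest."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


rng = random.Random(0)
x = [float(i) for i in range(1, 9)]

saved_state = rng.getstate()     # what a checkpoint wrapper stashes before the forward pass
first = dropout(x, 0.5, rng)     # original forward pass

naive = dropout(x, 0.5, rng)     # recomputation with the advanced generator: generally a new mask
rng.setstate(saved_state)        # restore the stashed state before recomputing
replayed = dropout(x, 0.5, rng)  # recomputation now reproduces the original mask

print(replayed == first)         # True
```

If the backward pass used the naive recomputation, gradients would flow through a different dropout mask than the one that produced the loss, which is exactly the incorrect-gradient case described above.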
The reentrant implementation in PyTorch, the original variant, places additional restrictions: at least one input must have requires_grad=True, certain backward hooks can interfere with the recursive autograd call, and it interacts poorly with FSDP's state synchronisation, potentially producing deadlocks or wrong gradients.[5] These issues motivated the non-reentrant implementation, which the PyTorch documentation strongly recommends.[2]
Activation checkpointing also has no effect on memory consumed by parameters, gradients, or optimiser state. For these, sharding techniques such as ZeRO and FSDP are required. In practice, parameter and optimiser memory often dominate at the smaller end of the LLM range; activation memory dominates as sequence length and batch size grow, which is why checkpointing becomes essential for long-context training.[6]
Finally, the technique can interact awkwardly with certain compilation paths. PyTorch issues track interactions between torch.compile, FSDP, and the various checkpoint wrappers; the recommended workaround is to use non-reentrant checkpointing and to wrap modules at the transformer block boundary.[5]