# Gradient checkpointing

> Source: https://aiwiki.ai/wiki/gradient_checkpointing
> Updated: 2026-06-27
> Categories: Deep Learning, Training & Optimization

**Gradient checkpointing**, also called **activation checkpointing**, **activation recomputation**, or **rematerialization**, is a memory-saving technique for [training](/wiki/training) deep neural networks that trades extra compute for much lower peak memory: instead of storing every intermediate activation from the forward pass, it stores only a small subset (the "checkpoints") and recomputes the rest on demand during the backward pass.[^1][^2] The technique was introduced by Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin in the 2016 paper "Training Deep Nets with Sublinear Memory Cost" (arXiv:1604.06174), which showed that a network of n layers can be trained with O(sqrt(n)) activation memory at the cost of a single extra forward pass per mini-batch.[^1] In their experiments this cut the memory of a 1,000-layer residual network from roughly 48 GB to about 7 GB for around 30 percent additional running time.[^1] As the paper states, "we propose a novel method to trade computation for memory."[^1]

Gradient checkpointing is implemented in every major training framework, including [pytorch](/wiki/pytorch) (`torch.utils.checkpoint`), [jax](/wiki/jax) (`jax.checkpoint`, aliased to `jax.remat`), [tensorflow](/wiki/tensorflow) (`tf.recompute_grad`), [deepspeed](/wiki/deepspeed), and [fsdp](/wiki/fsdp) (`apply_activation_checkpointing`).[^2][^3][^4][^5] Higher-level libraries expose it as a single flag: Hugging Face `transformers` enables it with `gradient_checkpointing=True` on a model config or on `TrainingArguments`.[^2] It is a standard component of large language model training, used in [llama](/wiki/llama), [gpt-3](/wiki/gpt-3), and the [megatron lm](/wiki/megatron_lm) family among many others.[^6][^7]

## What is gradient checkpointing? (background)

Training a neural network with [backpropagation](/wiki/backpropagation) requires the gradient of the loss with respect to every parameter. Computing those gradients via the chain rule needs intermediate values from the forward pass, called activations, at every layer where the gradient is propagated through a non-linear or otherwise non-invertible operation.[^8] In the most direct implementation of reverse-mode [automatic differentiation](/wiki/automatic_differentiation), all activations are stored during the forward pass and read back during the backward pass. Memory consumption therefore grows linearly with the depth of the network and with the batch size, sequence length, and hidden dimension.[^8]

For a [transformer](/wiki/transformer) block, the activations that must be stored include the input to attention, the queries, keys, and values, the attention probability matrix (of size sequence_length squared per head), the attention output, the input to the feed-forward layer, the intermediate feed-forward activation (typically four times the hidden dimension), and the layer-norm statistics. Korthikanti and colleagues derive an approximate per-layer formula of sbh(34 + 5as/h) bytes for sequence length s, micro-batch size b, hidden dimension h, and a attention heads with 16-bit activations. Activation memory therefore grows linearly in batch size and hidden dimension, while the attention term grows linearly in the number of heads and quadratically in sequence length.[^6] At a context length of 2,048 and hidden size of 12,288 (the GPT-3 configuration), per-layer activation memory can exceed the storage required for the layer's parameters once more than a few sequences are processed per micro-batch, making activations the dominant memory cost as models grow deeper and longer-context.[^6]
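The per-layer estimate can be evaluated directly. The sketch below assumes the sbh(34 + 5as/h) approximation from Korthikanti et al., which covers 16-bit activations without tensor or sequence parallelism; the function name and the micro-batch size of 1 are illustrative choices, not values taken from the paper.

```python
def activation_bytes_per_layer(s: int, b: int, h: int, a: int) -> int:
    """Approximate activation bytes for one transformer layer: s*b*h*(34 + 5*a*s/h).

    s: sequence length, b: micro-batch size, h: hidden size, a: attention heads.
    The expression is expanded so that it stays in exact integer arithmetic.
    """
    linear_term = 34 * s * b * h        # terms proportional to s * b * h
    attention_term = 5 * a * s * s * b  # attention scores, softmax, and dropout
    return linear_term + attention_term


# GPT-3 shape (s=2048, h=12288, a=96) at an assumed micro-batch size of 1.
gpt3_layer = activation_bytes_per_layer(s=2048, b=1, h=12288, a=96)
print(f"{gpt3_layer / 2**30:.2f} GiB of activations per layer")

# Doubling the sequence length more than doubles the footprint,
# because the attention term is quadratic in s.
longer = activation_bytes_per_layer(s=4096, b=1, h=12288, a=96)
print(f"ratio after doubling s: {longer / gpt3_layer:.2f}")
```

Multiplying the result by the number of layers and the micro-batch size gives a rough upper bound on what must be held without recomputation.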

As neural networks grew in depth and width during the 2010s, this linear scaling became a bottleneck. Chen and colleagues observed in 2016 that on a 1,000-layer deep residual network the activations occupied roughly 48 GB, exceeding the memory of available GPUs.[^1] The same constraint reappeared at much larger scale with the [transformer](/wiki/transformer) architecture: when training a [gpt-3](/wiki/gpt-3) style 530-billion-parameter model, activation memory alone exceeded what a single [nvidia a100](/wiki/nvidia_a100) GPU with 80 GB of memory could hold even after applying tensor and pipeline parallelism.[^6] Backpropagation thus requires explicit strategies to control activation memory, of which gradient checkpointing is the most widely adopted.[^1][^6]

### How is it different from a model checkpoint?

Gradient checkpointing is often confused with model checkpointing, but the two are unrelated. A model [checkpoint](/wiki/checkpoint) periodically saves model weights (and optimizer state) to disk for fault tolerance and resumption during long training runs. Gradient checkpointing, by contrast, operates entirely within a single backward pass and concerns intermediate activations held in GPU memory, not parameter snapshots written to disk. One is about saving progress; the other is about saving memory during the gradient computation.

### Predecessor ideas

The core idea (trade compute for memory by recomputing intermediates) predates the 2016 paper by decades. The [automatic differentiation](/wiki/automatic_differentiation) literature describes "checkpointing" or "revolve" schemes for adjoint computation in the context of ODE solvers and PDE-constrained optimisation, including Andreas Griewank's work in the 1990s on recursive binomial checkpointing. Chen et al. cite this earlier literature and adapt it to deep network training graphs; their specific contribution is the O(sqrt(n)) recipe and an automated, framework-level implementation.[^1] The reframing for [deep learning](/wiki/deep_learning), and the practical demonstration on a 1,000-layer ResNet, opened the technique to widespread adoption in mainstream training frameworks.[^1]

## How does gradient checkpointing work? (algorithm)

The algorithm proposed by Chen et al. partitions an n-layer feed-forward network into segments and stores only the activation at the boundary of each segment.[^1] During the backward pass, when gradients are needed for a layer inside a segment, the forward pass is re-executed from the most recent stored checkpoint up to that layer, regenerating the missing activations on demand.[^1] Once the gradient for the layer has been computed, the recomputed activations are discarded. In the PyTorch implementation, as the documentation puts it, "the unsaved tensors are recomputed by re-invoking `function` in the backward pass as needed for gradient computation."[^2]

Formally, consider an n-layer network expressed as a composition f = f_n o ... o f_2 o f_1, where each f_i has parameters theta_i, takes the activation a_(i-1), and produces a_i. Standard backpropagation stores all a_i during the forward pass. With checkpointing, the layers are grouped into segments; only the activations at segment boundaries are persisted. To compute the gradient of the loss with respect to theta_i inside a segment, the activations a_(j+1), ..., a_(i-1) (where a_j is the nearest preceding checkpoint) are recomputed by re-running the forward pass on the segment.[^1] The gradient is then accumulated along the chain rule exactly as in standard backpropagation; the only difference is when and how the activations are materialised.[^1]
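The segment scheme can be illustrated on a toy chain of scalar layers. This is a minimal sketch, assuming layers of the form a_i = tanh(w_i * a_(i-1)) and a loss equal to the final activation; the function names are hypothetical and the code is not taken from Chen et al.

```python
import math


def layer(w, a_prev):
    return math.tanh(w * a_prev)


def layer_backward(w, a_prev, a_out, grad_out):
    # d tanh(z)/dz = 1 - tanh(z)^2, with z = w * a_prev and a_out = tanh(z).
    dz = grad_out * (1.0 - a_out * a_out)
    return dz * a_prev, dz * w  # (gradient w.r.t. w, gradient w.r.t. a_prev)


def backprop_full(weights, x):
    """Standard backpropagation: every activation is stored."""
    acts = [x]
    for w in weights:
        acts.append(layer(w, acts[-1]))
    grad, grads = 1.0, [0.0] * len(weights)
    for i in reversed(range(len(weights))):
        grads[i], grad = layer_backward(weights[i], acts[i], acts[i + 1], grad)
    return grads, len(acts)  # gradients and number of stored activations


def backprop_checkpointed(weights, x, segment_len):
    """Store only segment inputs; recompute each segment during the backward pass."""
    n = len(weights)
    checkpoints, a = {}, x
    for i, w in enumerate(weights):
        if i % segment_len == 0:
            checkpoints[i] = a  # activation entering the segment
        a = layer(w, a)
    grad, grads, peak = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment_len, n)
        acts = [checkpoints[start]]
        for i in range(start, end):  # re-run the forward pass on this segment
            acts.append(layer(weights[i], acts[-1]))
        peak = max(peak, len(checkpoints) + len(acts) - 1)
        for i in reversed(range(start, end)):
            grads[i], grad = layer_backward(
                weights[i], acts[i - start], acts[i - start + 1], grad
            )
        # acts goes out of scope here: the recomputed activations are discarded.
    return grads, peak


weights = [0.5 + 0.05 * i for i in range(16)]
full_grads, full_stored = backprop_full(weights, 0.3)
ckpt_grads, ckpt_peak = backprop_checkpointed(weights, 0.3, segment_len=4)
print(full_stored, ckpt_peak)
```

Both routines return the same gradients; only the number of activations alive at once differs.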

Chen et al. observed that placing approximately sqrt(n) checkpoints in an n-layer network is the optimum of the simple memory-vs-compute tradeoff: each segment then holds at most sqrt(n) activations in memory at any one time, and the total recomputed work is one extra forward pass.[^1] Recursive application of the same strategy reduces memory further to O(log n) at the cost of O(n log n) extra forward computation, useful only in extreme regimes.[^1] In an implementation released alongside the paper, the authors reduced memory of a 1,000-layer residual network from 48 GB to 7 GB with 30 percent additional running time on the ImageNet problem.[^1]

A widely cited illustrative reimplementation by Tim Salimans and Yaroslav Bulatov demonstrated that a TensorFlow training script using their automatic graph rewrite could fit "more than 10x larger models" onto a single GPU at roughly 20 percent additional compute time, popularising the technique under the name "gradient checkpointing".[^9]

## How much memory does it save? (mathematical analysis)

Let n be the number of layers and c be the number of checkpoints. Activation memory after checkpointing scales as O(n/c + c), since at any instant the implementation holds c checkpoints plus the activations recomputed for the current segment, which contains n/c layers. Minimising n/c + c over c yields c = sqrt(n) and a memory cost of O(sqrt(n)).[^1] Compute, in contrast, remains O(n): the segment-by-segment recomputation adds the equivalent of one extra forward pass to each training step. Because the backward pass costs about twice as much as the forward pass under standard backpropagation, a step grows from roughly three forward-pass units of work to four, so the total wall-clock cost of training with checkpointing is roughly 33 percent higher than without it, in line with the 30 percent number reported by Chen et al.[^1] The paper frames the result precisely: it gives "an algorithm that costs O(sqrt(n)) memory to train a n layer network, with only the computational cost of an extra forward pass per mini-batch."[^1]
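Both figures in this analysis can be checked numerically. The sketch below assumes uniform layer sizes, counts activations rather than bytes, and takes the backward pass to cost twice the forward pass; the helper names are illustrative.

```python
import math


def peak_activations(n: int, c: int) -> int:
    """Checkpoints held for the whole step plus the largest recomputed segment."""
    return c + math.ceil(n / c)


def best_checkpoint_count(n: int) -> int:
    # Brute-force search over every possible number of checkpoints.
    return min(range(1, n + 1), key=lambda c: peak_activations(n, c))


n = 1024
c_star = best_checkpoint_count(n)
print(c_star, math.isqrt(n), peak_activations(n, c_star))

# Relative step cost, in units of one forward pass (assumed ratio 1 : 2).
FORWARD, BACKWARD = 1.0, 2.0
baseline = FORWARD + BACKWARD
with_recompute = FORWARD + BACKWARD + FORWARD  # one extra forward pass
overhead = with_recompute / baseline - 1.0
print(f"overhead: {overhead:.1%}")
```

The brute-force minimum coincides with sqrt(n) for this n, and the overhead evaluates to one third under the assumed cost ratio.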

A worked numerical example clarifies the magnitudes. Consider an 80-layer transformer with hidden size 8,192, sequence length 4,096, and a micro-batch size of 4 per GPU. Without checkpointing, the activations per layer occupy roughly 0.6 GB in FP16, summing to nearly 50 GB for the layer stack alone, dominating optimiser and parameter memory on an 80 GB GPU.[^6] With full checkpointing at every layer boundary, the implementation stores only the layer input at each checkpoint (a small fraction of the roughly 0.6 GB that a layer would otherwise keep) plus the full activations of the one layer being recomputed; total activation memory drops below 10 GB.[^6] Korthikanti et al. depict this graphically across the 22B, 175B, 530B, and 1T model configurations, showing baseline activation memory above the 80 GB A100 ceiling in every case and the "present work" (sequence parallelism plus selective recomputation) falling below 30 GB.[^6]

For more general computation graphs, the optimisation problem becomes a graph partition: pick a subset S of edges (the "cut") such that every node has an ancestor in S, minimise the maximum size of a connected component of (graph minus S), and minimise total recomputation cost. The general problem is NP-hard, but heuristics close to the original sqrt-n approach work well for the layer-stacked computation graphs typical of feed-forward networks and [transformer](/wiki/transformer) models.[^1] More sophisticated dynamic programming solutions exist for chain-structured graphs and are sometimes called Treeverse or revolve algorithms in the [automatic differentiation](/wiki/automatic_differentiation) literature.

A second important consideration is the size of each saved activation. Two checkpoints of equal "count" can have very different memory footprints if the network's hidden dimension changes across layers (e.g., U-Nets or models with bottlenecks). Implementations therefore use byte-size-weighted variants of the cut problem in practice. The original Chen et al. analysis assumes uniform layer cost; this matches deep ResNets and isotropic transformers well but is less appropriate for encoder-decoder architectures with varying feature-map sizes.[^1]

## What variants of gradient checkpointing exist?

### Selective activation recomputation

Vijay Korthikanti and collaborators at NVIDIA introduced **selective activation recomputation** in the May 2022 paper "Reducing Activation Recomputation in Large Transformer Models".[^6] Their observation: in a [transformer](/wiki/transformer) block, certain operations (notably the softmax and dropout inside attention, and certain GeLU and dropout activations in the MLP block) consume a disproportionate amount of activation memory relative to the compute they require, while other operations (large matrix multiplications) are the opposite.[^6] The selective strategy stores activations for the compute-expensive operations and recomputes only the memory-heavy, compute-light operations.[^6]

Combined with sequence parallelism, this technique reduces activation memory by 5x while keeping the compute overhead of recomputation under 10 percent of the original training step.[^6] When training a 530B GPT-3 style model on 2,240 A100 GPUs, the authors achieved 54.2 percent Model FLOPs Utilisation, a 29 percent improvement over the 42.1 percent obtained with full activation recomputation.[^6] The implementation is part of the open-source [megatron lm](/wiki/megatron_lm) codebase.[^6]

### Asynchronous offload

A second variant moves checkpointed activations from GPU memory to either CPU host memory or even disk between the forward and backward passes, prefetching them back asynchronously when needed.[^4] This adds a third axis to the tradeoff: GPU memory, CPU memory, and compute. [deepspeed](/wiki/deepspeed) supports CPU offloading via the `checkpoint_in_cpu` argument of `deepspeed.checkpointing.configure()` (the `cpu_checkpointing` key in the JSON configuration), which only operates when activations are partitioned across [model parallelism](/wiki/model_parallelism) ranks.[^4][^10] The [fsdp](/wiki/fsdp) `checkpoint_wrapper` similarly supports an `offload_to_cpu` option that moves preserved inputs to host RAM and prefetches them back to the GPU immediately before recomputation.[^5]

### Partitioned activation checkpointing

In [tensor parallelism](/wiki/tensor_parallelism) setups (as used in [megatron lm](/wiki/megatron_lm)), each GPU holds only a shard of the model. DeepSpeed's `partition_activations` option distributes the activation checkpoints themselves across the tensor parallel ranks, dividing activation memory per GPU by the tensor parallel size.[^4] When combined with `contiguous_memory_optimization`, the partitioned checkpoints are copied into a single contiguous buffer, reducing memory fragmentation.[^4]

### Compiler-managed rematerialisation

In [jax](/wiki/jax), the `jax.checkpoint` decorator (alias `jax.remat`) is implemented as a transformation that interacts with the XLA compiler.[^3] Recent versions add a `policy` argument that controls which intermediate values are saved versus recomputed, with built-in policies such as `dots_with_no_batch_dims_saveable`, `save_only_these_names`, and `offload_dot_with_no_batch_dims`.[^11] The JAX documentation notes that when functions are compiled with `jax.jit`, XLA may apply rematerialisation transforms automatically, so an explicit `jax.checkpoint` is most useful around staged control flow such as `jax.lax.scan`, which is common in [transformer](/wiki/transformer) implementations.[^11]

## How is gradient checkpointing implemented in PyTorch, JAX, and TensorFlow?

### PyTorch

[pytorch](/wiki/pytorch) exposes the technique through `torch.utils.checkpoint`.[^2] The PyTorch documentation summarises the idea in one line: "Activation checkpointing is a technique that trades compute for memory."[^2] The primary entry point is `torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, ...)`, which runs `function(*args)` without saving its intermediate activations and arranges for them to be recomputed during the backward pass.[^2] A `checkpoint_sequential(functions, segments, input)` helper applies the technique to an `nn.Sequential` model by splitting it into a chosen number of segments.[^2]

The library supports two modes selected by the `use_reentrant` flag. The original, reentrant implementation invokes a nested backward pass through the autograd engine; it has limitations around inputs that require gradients and around hooks. The non-reentrant implementation (`use_reentrant=False`) was added later, supports more functionality, and is the recommended choice; PyTorch documentation states that the parameter must be explicitly specified and that future versions will raise an exception if it is omitted.[^2] Checkpointed functions preserve the random number generator state by default so that operations such as dropout produce the same values on the recomputation pass; this can be disabled with `preserve_rng_state=False` at the cost of correctness for stochastic operations.[^2]
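A minimal sketch of the non-reentrant entry point follows. The `Block` and `Stack` modules and the `use_checkpoint` attribute are hypothetical stand-ins for a real model; only `torch.utils.checkpoint.checkpoint` and its `use_reentrant` argument come from the PyTorch API described above.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    """Toy residual MLP block (illustrative, not a full transformer layer)."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        return x + self.net(x)


class Stack(nn.Module):
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])
        self.use_checkpoint = False  # hypothetical switch for this example

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint:
                # Intermediates inside the block are dropped after the forward
                # pass and recomputed when the backward pass needs them.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def parameter_grads(model, x):
    model.zero_grad(set_to_none=True)
    model(x).pow(2).mean().backward()
    return [p.grad.clone() for p in model.parameters()]


torch.manual_seed(0)
model = Stack(dim=16, depth=4)
x = torch.randn(8, 16)

plain = parameter_grads(model, x)
model.use_checkpoint = True
recomputed = parameter_grads(model, x)
print(all(torch.allclose(a, b) for a, b in zip(plain, recomputed)))
```

Wrapping at block granularity keeps one input tensor per block alive, so the gradients are unchanged while the saved intermediates shrink.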

For distributed training, PyTorch ships `apply_activation_checkpointing` in `torch.distributed.algorithms._checkpoint.checkpoint_wrapper`, which walks a model and wraps submodules selected by a user-supplied `check_fn` or `auto_wrap_policy`.[^5] For [fsdp](/wiki/fsdp) training, the recommended pattern is to wrap each transformer block (attention plus feed-forward) and to use non-reentrant checkpointing, since the reentrant variant interacts poorly with FSDP's state synchronisation.[^5] The wrapper accepts an `offload_to_cpu` argument that moves preserved inputs to host RAM and prefetches them before recomputation.[^5]

The Hugging Face `transformers` library exposes the same PyTorch mechanism through a single switch: setting `gradient_checkpointing=True` on a model configuration or on the `TrainingArguments` object routes the model's forward through `torch.utils.checkpoint`, which is why the flag is the most common way practitioners turn the technique on for fine-tuning.[^2]

### JAX

[jax](/wiki/jax) provides `jax.checkpoint`, aliased to `jax.remat`, with the signature `jax.checkpoint(fun, *, prevent_cse=True, policy=None, static_argnums=(), concrete=Deprecated)`.[^3] Applying it as a decorator marks the wrapped function so that during reverse-mode autodiff its intermediates are recomputed rather than stored.[^3] As noted above, JAX exposes several policies for fine-grained control over what is checkpointed.[^11]
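The decorator and the `policy` argument can be sketched as follows. The block function, the parameter shapes, and the helper names are illustrative assumptions; `jax.checkpoint` and `jax.checkpoint_policies.dots_with_no_batch_dims_saveable` are the API elements named in the text.

```python
import jax
import jax.numpy as jnp


def mlp_block(w1, w2, x):
    # Residual MLP block; the tanh output is an intermediate that can be recomputed.
    return x + jnp.tanh(x @ w1) @ w2


def make_loss(block):
    def loss(params, x):
        for w1, w2 in params:
            x = block(w1, w2, x)
        return jnp.sum(x ** 2)
    return loss


# Recompute everything inside the block during the backward pass.
remat_block = jax.checkpoint(mlp_block)

# Keep matmul outputs without batch dimensions, recompute the rest.
policy_block = jax.checkpoint(
    mlp_block,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)

keys = jax.random.split(jax.random.PRNGKey(0), 7)
params = [
    (0.1 * jax.random.normal(keys[2 * i], (8, 32)),
     0.1 * jax.random.normal(keys[2 * i + 1], (32, 8)))
    for i in range(3)
]
x = jax.random.normal(keys[6], (4, 8))

grads_plain = jax.grad(make_loss(mlp_block))(params, x)
grads_remat = jax.grad(make_loss(remat_block))(params, x)
grads_policy = jax.grad(make_loss(policy_block))(params, x)

leaves = jax.tree_util.tree_leaves
print(all(bool(jnp.allclose(a, b, rtol=1e-4, atol=1e-5))
          for a, b in zip(leaves(grads_plain), leaves(grads_remat))))
```

All three variants compute the same gradients; they differ only in which residuals are kept between the forward and backward passes.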

### TensorFlow

[tensorflow](/wiki/tensorflow) exposes the technique through `tf.recompute_grad(f)`, a decorator that wraps a function so that its forward computation is repeated during the backward pass instead of caching intermediate tensors.[^12] An early popular implementation of the technique on TensorFlow was the `cybertronai/gradient-checkpointing` library by Tim Salimans and Yaroslav Bulatov, which automatically rewrote a TensorFlow computation graph to insert checkpoints; the project is now in maintenance mode and the official `tf.recompute_grad` API is the recommended replacement.[^9][^12]

### DeepSpeed

[deepspeed](/wiki/deepspeed) provides a dedicated activation checkpointing API in the `deepspeed.checkpointing` module. Key functions include `deepspeed.checkpointing.configure()` to set options, `deepspeed.checkpointing.checkpoint()` to checkpoint a block, `deepspeed.checkpointing.is_configured()`, `deepspeed.checkpointing.reset()` to clear buffers between forward passes, and `deepspeed.checkpointing.model_parallel_cuda_manual_seed()` to initialise RNG state for model-parallel runs.[^4] The configuration options include `partition_activations` (split checkpoints across tensor-parallel ranks), `contiguous_memory_optimization` (copy checkpoints into a contiguous buffer, requires partitioning to be enabled; passed to `configure()` as `contiguous_checkpointing`), `checkpoint_in_cpu` (offload checkpoints to CPU memory, requires partitioning; spelled `cpu_checkpointing` in the JSON configuration), and `profile` (log forward and backward time per checkpoint invocation).[^4][^10]
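As a configuration sketch, the options above map onto the `activation_checkpointing` section of a DeepSpeed JSON config. The values are illustrative assumptions rather than recommended settings, and the key names should be checked against the installed DeepSpeed version; the snippet only builds and serialises the dictionary and does not import DeepSpeed.

```python
import json

# Illustrative values only: partitioning is enabled because the CPU-offload
# and contiguous-buffer options are documented as depending on it.
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }
}

config_text = json.dumps(ds_config, indent=2)
print(config_text)
```

The serialised text can be saved to a file and passed to the launcher like any other DeepSpeed configuration.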

### FSDP

PyTorch's [fsdp](/wiki/fsdp) integrates activation checkpointing through `apply_activation_checkpointing` from `torch.distributed.algorithms._checkpoint.checkpoint_wrapper`.[^5] The user supplies a `checkpoint_wrapper_fn` callable that wraps modules and a `check_fn` or `auto_wrap_policy` that decides which submodules to wrap. The typical pattern is to wrap each transformer layer, mirroring the strategy used by [megatron lm](/wiki/megatron_lm) and DeepSpeed.[^5]
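The wrapping pattern can be tried on a single process without FSDP itself. This sketch assumes the private module path quoted above still exports `apply_activation_checkpointing`, `checkpoint_wrapper`, `CheckpointImpl`, and `CheckpointWrapper` in the installed PyTorch version; `BlockStub` is a hypothetical stand-in for a transformer layer.

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    CheckpointWrapper,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)


class BlockStub(nn.Module):
    """Stand-in for a transformer layer (attention plus feed-forward)."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        return x + self.ff(x)


model = nn.Sequential(BlockStub(16), BlockStub(16), BlockStub(16), nn.Linear(16, 4))

# Use the non-reentrant implementation for every wrapped block.
wrapper_fn = functools.partial(
    checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
)

# Wrap only the block modules; the final projection is left untouched.
apply_activation_checkpointing(
    model,
    checkpoint_wrapper_fn=wrapper_fn,
    check_fn=lambda module: isinstance(module, BlockStub),
)

out = model(torch.randn(2, 16))
out.sum().backward()
print(type(model[0]).__name__, type(model[3]).__name__)
```

In an FSDP run the same call is typically made on the model after it has been wrapped for sharding, with `check_fn` selecting the transformer layer class.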

### Megatron-LM

The original [megatron lm](/wiki/megatron_lm) codebase implements activation checkpointing at the granularity of transformer layers via the `checkpoint-activations` flag and the related `checkpoint-num-layers`.[^13] Setting `checkpoint-num-layers` to 1 (checkpoint at every layer boundary) is a common configuration; the GPT-NeoX-20B model was trained with exactly these settings.[^13] Since 2022, Megatron-LM also implements the selective recomputation strategy of Korthikanti et al. as a configurable mode alongside the original full recomputation path.[^6]

## Where is gradient checkpointing used in modern training?

Activation checkpointing is effectively universal in training large language models. Public training reports from the major labs document its use as follows.

### LLaMA

The original [llama](/wiki/llama) paper, "LLaMA: Open and Efficient Foundation Language Models" by Touvron et al., released by Meta in February 2023, states that "To further improve training efficiency, we reduced the amount of activations that are recomputed during the backward pass with checkpointing. More precisely, we save the activations that are expensive to compute, such as the outputs of linear layers."[^7] The paper explicitly cites Korthikanti et al. and reports that, in conjunction with sequence and model parallelism, the resulting training pipeline processed roughly 380 tokens per second per GPU on 2,048 A100 GPUs with 80 GB of memory, training the 65B-parameter model in approximately 21 days over 1.4T tokens.[^7] The authors add that this "is achieved by manually implementing the backward function for the transformer layers, instead of relying on the PyTorch autograd", a workaround that allowed them to keep selective recomputation across the boundaries between attention and feed-forward operations.[^7] This pattern of bypassing the framework's default gradient pipeline to gain finer-grained checkpointing control reappears in numerous large-scale training projects.
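The reported throughput and training time are mutually consistent, as a quick calculation shows. The inputs are the figures quoted above; the variable names are illustrative.

```python
TOKENS_PER_SEC_PER_GPU = 380   # reported throughput
NUM_GPUS = 2048                # A100 80 GB GPUs
DATASET_TOKENS = 1.4e12        # 1.4T training tokens
SECONDS_PER_DAY = 24 * 60 * 60

cluster_tokens_per_sec = TOKENS_PER_SEC_PER_GPU * NUM_GPUS
training_days = DATASET_TOKENS / cluster_tokens_per_sec / SECONDS_PER_DAY

print(f"{cluster_tokens_per_sec:,} tokens/s across the cluster")
print(f"{training_days:.1f} days for one pass over the dataset")
```

The result lands just under 21 days, matching the paper's "approximately 21 days".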

### GPT-3 and Megatron-Turing NLG

The 175B-parameter GPT-3 was trained using model parallelism across V100 GPUs; activations exceeded single-GPU capacity, requiring memory-saving strategies. The 530B-parameter Megatron-Turing NLG model, trained jointly by NVIDIA and Microsoft on 2,240 A100 GPUs, used a 3D parallelism strategy with activation checkpointing applied to every transformer block.[^6][^14] Korthikanti et al. measure that for that model the activation memory required for the baseline (no recomputation) approach is well above 80 GB per GPU, exceeding the A100 HBM budget; full activation recomputation brings it below the budget at the cost of 30 to 40 percent execution time overhead, while their selective scheme achieves a 5x reduction with under 10 percent overhead.[^6]

### PaLM and others

Google's PaLM, GPT-NeoX-20B, BLOOM, and most subsequent open-weight LLMs above 1B parameters use some form of activation checkpointing.[^13] The standard configuration in the open-source [megatron lm](/wiki/megatron_lm) and DeepSpeed stacks is to checkpoint at every transformer layer with non-reentrant checkpointing and either CPU offload or partitioned activations as needed.[^4][^5] GPT-NeoX-20B was trained with `checkpoint-activations: True` and `checkpoint-num-layers: 1`, the per-layer setting that maximises memory savings at the cost of more recomputation work.[^13]

### Training versus fine-tuning

While the discussion above focuses on pre-training, activation checkpointing is also pervasive in fine-tuning. Hugging Face's `transformers` library exposes a `gradient_checkpointing=True` flag on most model configurations and on the `TrainingArguments` object, which routes the model's forward through `torch.utils.checkpoint`.[^2] LoRA and full fine-tuning pipelines on single GPUs with 24 GB to 80 GB of memory routinely enable it; combined with [mixed precision training](/wiki/mixed_precision_training) in BF16 and 8-bit optimisers, it makes fine-tuning of 7B-parameter models feasible on a single GPU.[^7]

## How does it compare to other memory-saving techniques?

Activation checkpointing addresses one specific source of training memory: intermediate activations. Other techniques target different sources.

| Technique | Target | Memory savings | Compute overhead |
|---|---|---|---|
| Gradient checkpointing | Activations | O(n) reduced to O(sqrt(n)) in depth | ~20 to 40 percent[^6][^9][^15] |
| [mixed precision training](/wiki/mixed_precision_training) | Activations + weights + grads | ~2x via FP16/BF16 | Slight speedup |
| ZeRO-1/2/3 | Optimiser state, grads, params | up to ~Nx with N data-parallel ranks | Communication overhead |
| Optimiser offload | Optimiser state | Moves to CPU/NVMe | Communication overhead |
| [flash attention](/wiki/flash_attention) | Attention activations only | Attention reduced from O(N^2) to O(N) | None (often a speedup) |

These approaches are complementary; the standard recipe for training a multi-billion-parameter LLM combines activation checkpointing, mixed precision, ZeRO or FSDP sharding, and FlashAttention.[^4][^5][^6][^7] Korthikanti et al. note that ZeRO-style optimisations and offload techniques "are complementary to the techniques presented here and could be additionally employed for even greater memory savings."[^6]

The choice between full and selective activation checkpointing depends on architecture and scale. For pre-transformer architectures (ResNets, large CNNs), the original Chen et al. sqrt-n strategy or its per-block analogue suffices. For [transformer](/wiki/transformer)s above ~1B parameters, selective recomputation or [flash attention](/wiki/flash_attention) is usually preferable because the attention softmax and dropout layers are exactly the memory-heavy, compute-light operations the selective strategy targets.[^6]

## What are the limitations of gradient checkpointing?

Activation checkpointing's main cost is wall-clock time. Korthikanti et al. measured 30 to 40 percent execution time overhead from full activation recomputation on their training runs.[^6] PyTorch's distributed documentation gives a similar 20 to 30 percent range as a rule of thumb.[^5] Cybertronai's TensorFlow implementation reported around 20 percent.[^9] The exact overhead depends on the architecture (the ratio of compute to activation memory in each block), the chosen checkpoint granularity, and whether selective recomputation is in use.[^6]

A subtler limitation concerns randomness. Operations such as dropout produce different outputs on each forward pass; if the second (recomputation) pass uses a different RNG state from the first, the gradient is incorrect. PyTorch's `torch.utils.checkpoint` preserves the RNG state by default but emits warnings that operations involving randomness across multiple devices may still produce incorrect gradients.[^2] Disabling RNG-state preservation via `preserve_rng_state=False` is an option only when the wrapped function contains no stochastic operations.[^2]
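The effect of RNG-state preservation can be observed with a single dropout call. This is a sketch on CPU with a hypothetical helper, `masks_match`, which checks that the dropout mask seen in the forward output is the same mask that shaped the gradient.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def noisy(x):
    # Dropout draws a fresh random mask every time this function runs.
    return F.dropout(x, p=0.5, training=True)


def masks_match(preserve: bool) -> bool:
    x = torch.ones(1000, requires_grad=True)
    y = checkpoint(noisy, x, use_reentrant=False, preserve_rng_state=preserve)
    y.sum().backward()
    # Forward mask: entries of y that survived dropout.
    # Backward mask: entries of x that received a non-zero gradient.
    return torch.equal(y != 0, x.grad != 0)


torch.manual_seed(0)
print(masks_match(preserve=True))
```

With `preserve=True` the recomputation replays the original mask, so the two masks agree. Calling `masks_match(preserve=False)` would be expected to report a mismatch, because the recomputed dropout then draws a new mask.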

The reentrant implementation in PyTorch, the original variant, places additional restrictions: at least one input must have `requires_grad=True`, certain backward hooks can interfere with the recursive autograd call, and it interacts poorly with [fsdp](/wiki/fsdp)'s state synchronisation, potentially producing deadlocks or wrong gradients.[^5] These issues motivated the non-reentrant implementation, which the PyTorch documentation strongly recommends.[^2]

Activation checkpointing also has no effect on memory consumed by parameters, gradients, or optimiser state. For these, sharding techniques such as ZeRO and [fsdp](/wiki/fsdp) are required. In practice, parameter and optimiser memory often dominate at the smaller end of the LLM range; activation memory dominates as sequence length and batch size grow, which is why checkpointing becomes essential for long-context training.[^6]

Finally, the technique can interact awkwardly with certain compilation paths. PyTorch issues track interactions between `torch.compile`, FSDP, and the various checkpoint wrappers; the recommended workaround is to use non-reentrant checkpointing and to wrap modules at the [transformer](/wiki/transformer) block boundary.[^5]

## ELI5: gradient checkpointing in plain terms

Imagine solving a long math problem and writing down every single intermediate step so you can check your work later. That uses a lot of paper. Gradient checkpointing is like keeping only a few key intermediate answers and, when you need a step you erased, quickly redoing the small piece of arithmetic that produces it. You spend a little more time recomputing, but you need far less paper (memory). For a deep network, the trick is to keep about sqrt(n) of the n intermediate results, which is why a model that would not fit in a GPU's memory suddenly does, in exchange for roughly one extra forward pass of compute.[^1]

## See also

- [backpropagation](/wiki/backpropagation)
- [automatic differentiation](/wiki/automatic_differentiation)
- [mixed precision training](/wiki/mixed_precision_training)
- [flash attention](/wiki/flash_attention)
- [fsdp](/wiki/fsdp)
- [megatron lm](/wiki/megatron_lm)
- [deepspeed](/wiki/deepspeed)
- [data parallelism](/wiki/data_parallelism)
- [tensor parallelism](/wiki/tensor_parallelism)
- [pipeline parallelism](/wiki/pipeline_parallelism)
- [model parallelism](/wiki/model_parallelism)
- [pytorch](/wiki/pytorch)
- [jax](/wiki/jax)
- [tensorflow](/wiki/tensorflow)
- [checkpoint](/wiki/checkpoint)
- [deep learning](/wiki/deep_learning)
- [training](/wiki/training)

## References

[^1]: Tianqi Chen, Bing Xu, Chiyuan Zhang, Carlos Guestrin, "Training Deep Nets with Sublinear Memory Cost", arXiv, 2016-04-21. https://arxiv.org/abs/1604.06174. Accessed 2026-06-27.
[^2]: PyTorch documentation, "torch.utils.checkpoint", PyTorch Foundation, 2026. https://docs.pytorch.org/docs/main/checkpoint.html. Accessed 2026-06-27.
[^3]: JAX documentation, "jax.checkpoint", JAX project, 2026. https://docs.jax.dev/en/latest/_autosummary/jax.checkpoint.html. Accessed 2026-06-27.
[^4]: DeepSpeed documentation, "Activation Checkpointing", DeepSpeed project, 2026. https://deepspeed.readthedocs.io/en/latest/activation-checkpointing.html. Accessed 2026-06-27.
[^5]: PyTorch source, "checkpoint_wrapper.py", PyTorch GitHub, 2026. https://github.com/pytorch/pytorch/blob/main/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py. Accessed 2026-06-27.
[^6]: Vijay Korthikanti, Jared Casper, Sangkug Lym, Lawrence McAfee, Michael Andersch, Mohammad Shoeybi, Bryan Catanzaro, "Reducing Activation Recomputation in Large Transformer Models", arXiv, 2022-05-10. https://arxiv.org/abs/2205.05198. Accessed 2026-06-27.
[^7]: Hugo Touvron et al., "LLaMA: Open and Efficient Foundation Language Models", arXiv, 2023-02-27. https://arxiv.org/abs/2302.13971. Accessed 2026-06-27.
[^8]: David E. Rumelhart, Geoffrey E. Hinton, Ronald J. Williams, "Learning representations by back-propagating errors", Nature, 1986-10-09. https://www.nature.com/articles/323533a0. Accessed 2026-06-27.
[^9]: Tim Salimans, Yaroslav Bulatov, "cybertronai/gradient-checkpointing: Make huge neural nets fit in memory", GitHub, 2018. https://github.com/cybertronai/gradient-checkpointing. Accessed 2026-06-27.
[^10]: DeepSpeed documentation, "DeepSpeed Configuration JSON", DeepSpeed project, 2026. https://www.deepspeed.ai/docs/config-json/. Accessed 2026-06-27.
[^11]: JAX documentation, "Gradient checkpointing with jax.checkpoint (aka jax.remat)", JAX project, 2026. https://docs.jax.dev/en/latest/gradient-checkpointing.html. Accessed 2026-06-27.
[^12]: TensorFlow documentation, "tf.recompute_grad", Google, 2026. https://www.tensorflow.org/api_docs/python/tf/recompute_grad. Accessed 2026-06-27.
[^13]: Sid Black et al., "GPT-NeoX-20B: An Open-Source Autoregressive Language Model", arXiv, 2022-04-14. https://arxiv.org/abs/2204.06745. Accessed 2026-06-27.
[^14]: Shaden Smith et al., "Using DeepSpeed and Megatron to Train Megatron-Turing NLG 530B, A Large-Scale Generative Language Model", arXiv, 2022-01-28. https://arxiv.org/abs/2201.11990. Accessed 2026-06-27.
[^15]: ApxML, "Activation Checkpointing Mechanics", apxml.com distributed training course, 2025. https://apxml.com/courses/distributed-training-pytorch-fsdp/chapter-3-mixed-precision-memory-optimization/activation-checkpointing-mechanics. Accessed 2026-06-27.

