Fully Sharded Data Parallel (FSDP)
Last reviewed
Apr 30, 2026
Sources
25 citations
Review status
Source-backed
Revision
v2 · 3,717 words
Fully Sharded Data Parallel (FSDP) is a distributed training technique implemented in PyTorch that shards a model's parameters, gradients, and optimizer states across data-parallel workers, allowing models with billions or trillions of parameters to be trained on commodity GPU clusters without resorting to complex tensor or pipeline parallelism. FSDP was first developed at Meta AI and Facebook AI Research (initially as part of the FairScale library in 2021) and was integrated into core PyTorch as a beta feature in PyTorch 1.11, released on March 10, 2022 [1][2][3].
FSDP is the PyTorch counterpart to Microsoft DeepSpeed's ZeRO (Zero Redundancy Optimizer) Stage 3, sharing the same memory-saving idea of partitioning model state across data-parallel ranks while preserving the data-parallel programming model [4]. Since its release, FSDP has become the de facto choice for large-scale training of foundation models in PyTorch and is used for large language models such as the Llama and Mistral families and diffusion models like Stable Diffusion. It is integrated into higher-level frameworks including Hugging Face Accelerate, PyTorch Lightning, and Mosaic Composer [5][6][7]. In July 2024, PyTorch 2.4 introduced FSDP2, a redesigned API at torch.distributed._composable.fsdp.fully_shard that uses per-parameter sharding via DTensor instead of the original FlatParameter design, providing better composability with tensor parallelism and clearer state semantics [8].
Training ever-larger neural networks on multiple GPUs has historically relied on three families of parallelism, each with distinct trade-offs.
Data parallelism, exemplified by PyTorch's DistributedDataParallel (DDP), replicates the entire model on each worker and synchronizes gradients via an all-reduce after every backward pass. DDP is simple, but every rank holds a full copy of parameters, gradients, and optimizer states. For an Adam-style optimizer in mixed precision the per-parameter footprint is roughly 16 bytes (2 bytes for fp16 parameters, 2 bytes for fp16 gradients, 4 bytes for the fp32 master copy, and 8 bytes for the two Adam moments), so a 7-billion-parameter model already needs more than 100 GB just for state, before activations [4][9]. DDP alone cannot train modern foundation models even on H100-class hardware.
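The 16-bytes-per-parameter estimate can be checked with a few lines of arithmetic. The following is a minimal sketch that uses only the byte counts quoted above; the function name is hypothetical and activations are ignored.

```python
# Per-parameter state for mixed-precision Adam under DDP, in bytes,
# using the breakdown given in the text above.
BYTES_PER_PARAM = {
    "fp16 parameters": 2,
    "fp16 gradients": 2,
    "fp32 master copy": 4,
    "Adam moments (2 x fp32)": 8,
}


def ddp_state_bytes(num_params: int) -> int:
    """Model-state bytes that every DDP rank holds (activations excluded)."""
    return num_params * sum(BYTES_PER_PARAM.values())


total = ddp_state_bytes(7_000_000_000)
print(f"{total / 1e9:.0f} GB per rank")  # 112 GB per rank
```

Because DDP replicates everything, this figure is the same on every rank regardless of how many GPUs participate.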
Model parallelism splits individual layers across GPUs. Its modern incarnation, tensor parallelism as popularized by Megatron-LM, partitions matrix multiplications along an input or output dimension and requires all-reduce operations inside the forward and backward of each linear layer. Tensor parallelism scales well intra-node where NVLink bandwidth is high, but it is fragile across nodes and demands rewrites of attention and feed-forward layers. Pipeline parallelism splits the layer stack into stages, each placed on a different device, and feeds micro-batches through the pipeline; it avoids the bandwidth requirements of tensor parallelism but introduces pipeline bubbles and complicates gradient accumulation.
In 2019 and 2020, researchers at Microsoft led by Samyam Rajbhandari proposed ZeRO (Zero Redundancy Optimizer) as a way to recover the memory savings of model parallelism while keeping the data-parallel programming model. ZeRO observed that DDP's full replication of parameters, gradients, and optimizer states is wasteful: in principle each rank only needs the slice it is updating during the optimizer step. By sharding optimizer states (Stage 1), gradients (Stage 2), and parameters themselves (Stage 3) across the data-parallel group, ZeRO reduces per-rank memory by a factor proportional to the world size at the cost of extra collective communication [4]. ZeRO-Infinity and ZeRO-Offload extended this idea by spilling sharded state to CPU memory or NVMe storage [10].
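The effect of each ZeRO stage on per-rank memory can be sketched with the same 16-byte budget used for mixed-precision Adam. This is an illustrative estimate, not a measurement: the function name is hypothetical, and it assumes 2 bytes of parameters, 2 bytes of gradients, and 12 bytes of optimizer state per parameter.

```python
def zero_state_bytes_per_rank(num_params: int, world_size: int, stage: int) -> float:
    """Approximate model-state bytes per rank for mixed-precision Adam.

    Assumes 2 bytes of fp16 parameters, 2 bytes of fp16 gradients, and
    12 bytes of optimizer state (fp32 master copy plus two Adam moments).
    """
    if stage not in (0, 1, 2, 3):
        raise ValueError("stage must be 0 (plain DDP), 1, 2, or 3")
    params, grads, optim = 2.0, 2.0, 12.0
    if stage >= 1:
        optim /= world_size   # Stage 1: shard optimizer states
    if stage >= 2:
        grads /= world_size   # Stage 2: also shard gradients
    if stage >= 3:
        params /= world_size  # Stage 3: also shard parameters
    return num_params * (params + grads + optim)


for stage in range(4):
    gb = zero_state_bytes_per_rank(7_000_000_000, world_size=64, stage=stage) / 1e9
    print(f"stage {stage}: {gb:.2f} GB per rank")
```

Only Stage 3 makes the whole footprint shrink with the world size; Stages 1 and 2 leave a fixed replicated term that does not.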
In July 2021, Meta's FairScale team published an initial PyTorch implementation of ZeRO-3 called Fully Sharded Data Parallel, written by Myle Ott and colleagues for use in fairseq's training of large language models [1]. That work was ported into core PyTorch and shipped as torch.distributed.fsdp.FullyShardedDataParallel in the 1.11 release on March 10, 2022 [2][3]. A 2023 VLDB paper by Zhao et al. documents the design choices and lessons learned from scaling FSDP to large foundation-model workloads at Meta [9].
FSDP organizes a model into a tree of FSDP units, where each unit is a subtree of the nn.Module graph wrapped with FullyShardedDataParallel (FSDP1) or registered via the fully_shard API (FSDP2). Each unit's parameters are flattened, concatenated, and split into equal shards across the data-parallel ranks. At the start of training, every rank holds only its 1/N slice of every unit's parameters.
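The flatten, concatenate, and split step can be illustrated with plain Python lists standing in for tensors. This is a simplified sketch of the idea rather than FSDP's actual implementation; the helper name is hypothetical, and the zero-padding to an even split is an assumption of the sketch.

```python
from typing import List


def flatten_and_shard(params: List[List[float]], world_size: int) -> List[List[float]]:
    """Concatenate a unit's parameters into one flat list and split it evenly.

    The flat list is zero-padded so that its length is divisible by
    world_size, which gives every rank a shard of identical size.
    """
    flat = [x for p in params for x in p]
    shard_len = -(-len(flat) // world_size)  # ceiling division
    flat += [0.0] * (shard_len * world_size - len(flat))
    return [flat[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


# A toy unit with a 6-element weight and a 3-element bias, sharded 4 ways.
weight = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
bias = [7.0, 8.0, 9.0]
for rank, shard in enumerate(flatten_and_shard([weight, bias], world_size=4)):
    print(rank, shard)
```

In this toy case the last rank holds only padding, which shows that shard boundaries ignore the boundaries between the original parameters.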
During the forward pass, FSDP walks the model unit by unit. Just before each unit's forward is invoked, FSDP issues an all_gather collective so that every rank temporarily reconstructs the full parameter tensor for that unit. The forward computation runs on the full parameters as if the model were replicated. Once the forward returns, FSDP frees the gathered parameters and each rank again holds only its shard. Peak parameter memory at any moment is therefore the size of the largest single unit's full parameters, plus the sharded copies of all other units, rather than the full model.
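The peak-memory statement above translates directly into a small estimate. The sketch below follows that description literally and ignores prefetching, which can keep a second unit gathered at the same time; the function name and the example sizes are hypothetical.

```python
from typing import Sequence


def peak_param_elements(unit_sizes: Sequence[int], world_size: int) -> float:
    """Peak parameter elements per rank: the largest unit gathered in full,
    every other unit held only as a 1/world_size shard."""
    largest = max(unit_sizes)
    others_sharded = (sum(unit_sizes) - largest) / world_size
    return largest + others_sharded


blocks = [200_000_000] * 32                     # 32 equal units, 6.4B parameters
print(peak_param_elements(blocks, 8))           # 975000000.0
print(peak_param_elements([sum(blocks)], 8))    # 6400000000.0 when wrapped as one unit
```

The second call shows why wrapping the whole model as a single unit gives no parameter-memory benefit during computation.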
During the backward pass, FSDP walks the units in reverse. Before each unit's backward, FSDP again all-gathers the unit's parameters because the backward formula for many layers needs the weights. After the backward computes the gradient with respect to those parameters, FSDP issues a reduce_scatter collective: this both averages the gradient across ranks (the equivalent of DDP's all-reduce) and scatters the result so that each rank ends up holding only the gradient slice corresponding to its parameter shard. The full gradient is never materialized on any single rank. The gathered parameters are then freed.
During the optimizer step, each rank updates only its local shard of parameters using its local shard of gradients and optimizer states. No cross-rank communication is required for the step itself. Optimizer states are effectively sharded as a side effect of sharding parameters and gradients: each rank only ever instantiates the moments for the slice it owns.
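The forward, backward, and optimizer phases described above can be simulated for a single unit without any GPUs, using Python lists as shards and ordinary functions in place of the collectives. This is a toy model under stated assumptions: the loss 0.5 * (w . x)^2 and the plain SGD update are chosen only for brevity, all names are hypothetical, and the parameter count is assumed to be divisible by the number of ranks.

```python
from typing import List


def all_gather(shards: List[List[float]]) -> List[float]:
    """Every rank ends up with the concatenation of all ranks' shards."""
    return [x for shard in shards for x in shard]


def reduce_scatter(full_grads: List[List[float]]) -> List[List[float]]:
    """Average the ranks' full gradients, then hand rank r only slice r."""
    world_size = len(full_grads)
    shard_len = len(full_grads[0]) // world_size
    avg = [sum(col) / world_size for col in zip(*full_grads)]
    return [avg[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def fsdp_step(param_shards, inputs, lr=0.1):
    """One simulated step; rank r owns param_shards[r] and sees inputs[r]."""
    # Forward: gather the full weights and compute y = w . x on each rank.
    full_w = all_gather(param_shards)
    outputs = [sum(w * x for w, x in zip(full_w, xs)) for xs in inputs]
    # Backward: for loss = 0.5 * y**2 the gradient w.r.t. w is y * x,
    # so each rank produces a full-size local gradient.
    local_grads = [[y * x for x in xs] for y, xs in zip(outputs, inputs)]
    # Reduce-scatter: average across ranks, keep only the owned slice.
    grad_shards = reduce_scatter(local_grads)
    # Optimizer step (plain SGD): uses local shards only, no communication.
    return [
        [w - lr * g for w, g in zip(ws, gs)]
        for ws, gs in zip(param_shards, grad_shards)
    ]


shards = [[1.0, 0.0], [0.0, 1.0]]  # 4 weights spread over 2 ranks
batches = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 2.0]]
print(fsdp_step(shards, batches))
```

Concatenating the returned shards gives the same weights a replicated SGD step on the averaged gradient would produce, which is the sense in which FSDP preserves data-parallel semantics.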
A core performance technique is overlapping these collectives with computation. FSDP supports prefetching the all-gather for the next unit while the current unit's forward or backward is still running on the GPU, controlled by the forward_prefetch and backward_prefetch arguments [11]. With suitable wrapping granularity and prefetching, the all-gather and reduce-scatter time can be largely hidden behind compute, so FSDP achieves throughput comparable to DDP at much lower memory.
FSDP exposes a ShardingStrategy enum that lets users dial back the degree of sharding, mirroring the ZeRO stages. The most common strategies are summarized below [11][12].
| Strategy | Sharded state | ZeRO equivalent | Typical use |
|---|---|---|---|
| FULL_SHARD | Parameters, gradients, optimizer states | ZeRO-3 | Default. Maximum memory savings, used for large models. |
| SHARD_GRAD_OP | Gradients and optimizer states (parameters are gathered in forward and kept unsharded until backward finishes) | ZeRO-2 | When parameters fit on each GPU but optimizer state does not. Less collective overhead than full shard. |
| NO_SHARD | Nothing (full replication) | DDP | Equivalent to DDP. Useful for debugging or comparison runs. |
| HYBRID_SHARD | Sharded inside each node, replicated across nodes | ZeRO-3 within node | Reduces inter-node bandwidth. Good fit for clusters where intra-node NVLink is much faster than InfiniBand. |
| _HYBRID_SHARD_ZERO2 | Like HYBRID_SHARD with ZeRO-2 inside each node | ZeRO-2 within node | Tuning bandwidth vs. memory. |
HYBRID_SHARD is important at large scale. With it, FSDP shards parameters across GPUs inside a single node and replicates the sharded view across nodes, so most heavy collectives stay on intra-node NVLink while a single all-reduce handles cross-node synchronization [9].
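The two communication groups implied by HYBRID_SHARD can be enumerated explicitly. The sketch below is illustrative only: the helper name is hypothetical, and it assumes global ranks are numbered node by node.

```python
from typing import List, Tuple


def hybrid_shard_groups(
    num_nodes: int, gpus_per_node: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) as lists of global ranks.

    Assumes node n owns ranks n * gpus_per_node through
    (n + 1) * gpus_per_node - 1.
    """
    # All-gather and reduce-scatter run inside one node.
    shard_groups = [
        [node * gpus_per_node + local for local in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    # The cross-node all-reduce links ranks that hold the same shard index.
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


shard, replicate = hybrid_shard_groups(num_nodes=2, gpus_per_node=4)
print(shard)      # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(replicate)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

Each rank belongs to exactly one group of each kind, so the heavy collectives stay on the intra-node interconnect and only the gradient all-reduce crosses nodes.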
FSDP only saves memory if the model is split into multiple units. If the entire model is wrapped as a single unit, the full parameter tensor is gathered every forward and the only savings come from sharded optimizer state (essentially ZeRO-1 behavior) [11]. Choosing the right wrapping granularity is therefore one of the most important tuning decisions when adopting FSDP. FSDP1 ships several auto-wrap policies in torch.distributed.fsdp.wrap:
- size_based_auto_wrap_policy traverses the module tree and wraps any submodule whose parameter count exceeds a min_num_params threshold (typically tens or hundreds of millions). It is a sensible default for unfamiliar architectures.
- transformer_auto_wrap_policy wraps each instance of a specified set of transformer-block classes (such as LlamaDecoderLayer, GPT2Block, or T5Block). This produces one FSDP unit per transformer layer, the standard layout for large language models because every block has roughly the same parameter count and compute time, so all-gather overlap is uniform across the model.
- lambda_auto_wrap_policy accepts a user-defined callable for fine-grained control over which submodules become units.

Users can also wrap submodules manually by calling FSDP(submodule, ...) directly. Manual wrapping is common when combining FSDP with tensor parallelism, where certain layers need a specific wrapping order. FSDP2's fully_shard API replaces wrapping with a function call that registers a module as a sharded unit, avoiding the FlatParameter machinery that FSDP1 used to glue many nn.Parameter objects into a single contiguous shard [8].
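A per-block policy for FSDP1 is usually built with functools.partial. The following is a minimal sketch, assuming a user-defined transformer block class (called MyBlock in the usage comment, a hypothetical name); the helper function is also illustrative and not part of PyTorch.

```python
import functools


def make_block_wrap_policy(block_cls):
    """Build an auto-wrap policy that makes every block_cls instance a unit."""
    # Imported inside the function so that defining the helper does not
    # itself require PyTorch to be importable.
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    return functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={block_cls},
    )


# Usage inside an initialized process group (not executed here):
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   model = FSDP(model, auto_wrap_policy=make_block_wrap_policy(MyBlock))
```

Passing the block class as a set lets the same policy cover models with several block types, such as encoder and decoder layers.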
FSDP integrates mixed precision training through the MixedPrecision configuration object, which exposes three independent dtypes [11][13]:
- param_dtype is used when parameters are gathered for forward and backward computation. Setting this to bfloat16 or float16 cuts the all-gather payload in half versus fp32 and lets matrix multiplications run on tensor cores at higher throughput.
- reduce_dtype is used during the reduce-scatter of gradients. Many large-model recipes set this to float32 even when param_dtype is bf16, because reductions over many ranks are sensitive to numeric error.
- buffer_dtype controls the dtype of non-parameter buffers (such as batch-norm running statistics).

A common recipe for large language models is to keep parameters in bfloat16 and reduce gradients in float32, while the optimizer holds master parameters and Adam moments in float32. This delivers nearly the throughput of pure bf16 with stability comparable to full fp32. FSDP performs the dtype casts internally during all-gather and reduce-scatter, so users do not need to manually cast tensors in the model code.
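The bf16-compute, fp32-reduce recipe can be written as a MixedPrecision object. This is a sketch of one reasonable configuration rather than a recommended default; the helper name is hypothetical, and keeping buffers in float32 is an assumption of this example.

```python
def make_bf16_mixed_precision():
    """bf16 parameters for compute and all-gather, fp32 gradient reduction."""
    # Imported inside the function so that defining the helper does not
    # itself require PyTorch to be importable.
    import torch
    from torch.distributed.fsdp import MixedPrecision

    return MixedPrecision(
        param_dtype=torch.bfloat16,   # dtype of gathered parameters
        reduce_dtype=torch.float32,   # dtype used in the gradient reduce-scatter
        buffer_dtype=torch.float32,   # assumption: leave buffers in full precision
    )


# Usage inside an initialized process group (not executed here):
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   model = FSDP(model, mixed_precision=make_bf16_mixed_precision())
```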
Sharding parameters, gradients, and optimizer states attacks a substantial portion of training memory, but for transformers the activations stored for backward can still dwarf parameter memory at sequence lengths of 4k or more. Activation checkpointing, also called gradient checkpointing, recomputes activations during the backward pass instead of storing them, trading roughly one extra forward pass (about a third more compute per step) for a large activation-memory reduction [14].
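The "one third" figure follows from a common rule of thumb that a backward pass costs about twice as much as a forward pass, so re-running the forward adds one unit of work to a three-unit step. The sketch below encodes that assumption; the function and parameter names are hypothetical.

```python
def relative_step_cost(backward_to_forward: float = 2.0,
                       recompute_fraction: float = 1.0) -> float:
    """Step compute with checkpointing divided by step compute without it.

    backward_to_forward: backward cost in units of one forward pass
        (2.0 is a rule of thumb, assumed here).
    recompute_fraction: share of the forward pass re-run during backward.
    """
    baseline = 1.0 + backward_to_forward
    return (baseline + recompute_fraction) / baseline


print(f"{relative_step_cost():.3f}")                        # 1.333
print(f"{relative_step_cost(recompute_fraction=0.5):.3f}")  # 1.167
```

The second call illustrates why recomputing only part of the forward pass lowers the overhead.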
FSDP composes with activation checkpointing through torch.distributed.algorithms._checkpoint.checkpoint_wrapper. The standard recipe wraps each transformer block first with checkpoint_wrapper and then with FSDP, so that the same block is both sharded and recomputed. Recent PyTorch releases also support selective activation checkpointing, which keeps cheap-to-store activations and recomputes only the expensive ones [9]. Combined with FULL_SHARD and bf16 mixed precision, activation checkpointing is what makes it possible to train 70-billion-parameter models on clusters of 64 to 512 NVIDIA H100 GPUs without resorting to tensor or pipeline parallelism.
For extreme cases where even a sharded model does not fit on the GPU, FSDP can offload sharded parameters and gradients to CPU memory between forward and backward, mirroring DeepSpeed ZeRO-Offload [10]. This is enabled with cpu_offload=CPUOffload(offload_params=True) [11]. The optimizer step then runs on the CPU, with shards copied back to GPU during the next iteration's all-gather. The cost is significant: every iteration pays for two PCIe transfers per parameter, and CPU optimizer steps are slow compared to GPU. CPU offload is most useful for fine-tuning very large models on a small number of NVIDIA A100 or H100 GPUs, where the alternative would be no training at all.
Serializing an FSDP-trained model requires special care because no rank holds the full parameter tensor at rest. PyTorch supports several state-dict modes via FSDP.state_dict_type [11][15]:
- FULL_STATE_DICT materializes the unsharded parameter tensors and writes them out; setting rank0_only=True in FullStateDictConfig restricts this to rank 0. Convenient for inference checkpoints loaded outside the training cluster, but for very large models the full tensors may not fit in a single rank's CPU memory.
- SHARDED_STATE_DICT writes each rank's shard separately, producing a checkpoint shaped like the runtime layout. This avoids the all-gather and rank-0 memory pressure of the full mode and is recommended for distributed checkpointing of foundation models.
- LOCAL_STATE_DICT exposes the raw FlatParameter shards; deprecated in favor of SHARDED_STATE_DICT.

Since PyTorch 2.0, the recommended path is the Distributed Checkpoint API (DCP), exposed as torch.distributed.checkpoint. DCP saves and loads sharded state dicts in a format that decouples the save layout from the load layout: a checkpoint saved on 32 GPUs can be reloaded on 64 GPUs with a different parallelism strategy without manual re-sharding [16]. DCP supports both FSDP1 and FSDP2 in modern PyTorch.
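Why a sharded checkpoint can be reloaded on a different number of ranks is easiest to see in a toy model where each saved chunk records its global offset. The sketch below is conceptual only and is not DCP's on-disk format or API; all names are hypothetical.

```python
from typing import Dict, List, Tuple

Checkpoint = Dict[Tuple[int, int], List[float]]  # (offset, length) -> values


def save_sharded(flat: List[float], world_size: int) -> Checkpoint:
    """Each rank writes its slice, keyed by global offset and length."""
    shard_len = -(-len(flat) // world_size)  # ceiling division
    ckpt: Checkpoint = {}
    for rank in range(world_size):
        chunk = flat[rank * shard_len:(rank + 1) * shard_len]
        if chunk:
            ckpt[(rank * shard_len, len(chunk))] = chunk
    return ckpt


def load_sharded(ckpt: Checkpoint, total_len: int, world_size: int, rank: int) -> List[float]:
    """Rebuild the slice that `rank` owns under a possibly different world size."""
    shard_len = -(-total_len // world_size)
    lo, hi = rank * shard_len, min((rank + 1) * shard_len, total_len)
    out: List[float] = []
    for (offset, length), chunk in sorted(ckpt.items()):
        start, stop = max(lo, offset), min(hi, offset + length)
        if start < stop:  # this saved chunk overlaps the requested slice
            out.extend(chunk[start - offset:stop - offset])
    return out


flat = [float(i) for i in range(16)]
ckpt = save_sharded(flat, world_size=4)                # saved on 4 ranks
print(load_sharded(ckpt, 16, world_size=8, rank=3))    # [6.0, 7.0]
```

Because every chunk carries its global position, the loader never needs to know how many ranks wrote the checkpoint.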
In July 2024, PyTorch 2.4 introduced a redesigned FSDP API called FSDP2, exposed as torch.distributed._composable.fsdp.fully_shard [8]. The most fundamental change is the move from FlatParameter to per-parameter sharding via DTensor. FSDP1 concatenated all parameters of a unit into a one-dimensional FlatParameter and split that flat buffer evenly across ranks. The flat-buffer approach was communication-efficient but caused several pain points: parameters smaller than the world size could not be cleanly sharded, introspection required reasoning about the flat layout, and integration with tensor parallelism was clumsy because parameters lost their original shape until unflattened.
FSDP2 instead represents each parameter as a DTensor carrying an explicit sharding spec, sharded along its leading dimension across the data-parallel mesh, and parameters retain their original logical shape throughout training. This makes FSDP2 compose naturally with tensor parallelism, also implemented on DTensor: a parameter can be tensor-parallel-sharded along one mesh dimension and FSDP-sharded along another in the same spec. This 2D parallelism, exemplified by TorchTitan, is the modern PyTorch recipe for training models above 100 billion parameters across many nodes [8]. Other FSDP2 improvements include lazy initialization, clearer mixed-precision semantics, and a more explicit lifecycle for the all-gather buffers. FSDP2 is the recommended choice for new projects; FSDP1 remains supported for backward compatibility.
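Sharding along the leading dimension while keeping the logical shape can be pictured with a nested list standing in for a 2-D weight. This is a simplified sketch, not the DTensor implementation; the helper name is hypothetical, and the ceiling-division chunking rule is an assumption of the sketch.

```python
from typing import List

Matrix = List[List[float]]


def shard_dim0(param: Matrix, world_size: int) -> List[Matrix]:
    """Split a 2-D parameter by rows; each shard keeps the full column count."""
    rows_per_rank = -(-len(param) // world_size)  # ceiling division
    return [
        param[r * rows_per_rank:(r + 1) * rows_per_rank]
        for r in range(world_size)
    ]


weight = [[float(4 * r + c) for c in range(4)] for r in range(8)]  # shape (8, 4)
local = shard_dim0(weight, world_size=4)
print([(len(s), len(s[0])) for s in local])  # [(2, 4), (2, 4), (2, 4), (2, 4)]
```

Each local shard is still a matrix with recognizable rows and columns, which is what lets a second sharding along another mesh dimension be layered on top.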
FSDP sits in a family of distributed-training approaches that solve overlapping but distinct problems [4][9][17][18][19].
| Technique | Primary memory savings | Programming model | Typical scale |
|---|---|---|---|
| DDP | None (full replication) | Data parallel | Models that fit on one GPU. |
| FSDP / ZeRO-3 | Parameters, gradients, optimizer states | Data parallel | 10M to 100B+ parameters. |
| DeepSpeed | Same as FSDP plus ZeRO-Infinity, ZeRO-Offload, MoE | Custom engine | Same as FSDP, plus extreme offload. |
| Megatron-LM | Tensor and pipeline parallelism | Custom framework | 1B to 1T parameter dense LLMs. |
| Megatron-DeepSpeed | Tensor + pipeline + ZeRO | Custom hybrid | Trillion-parameter dense and MoE. |
| Colossal-AI | ZeRO + tensor + pipeline | Custom framework | Open-source large-model training. |
| JAX pjit / shard_map | Per-tensor SPMD sharding | Functional SPMD | TPU-first foundation models. |
FSDP shards storage but keeps compute data-parallel, while tensor and pipeline parallelism shard compute itself. FSDP can be combined with both: 2D and 3D parallelism layouts where FSDP shards across one mesh dimension and tensor or pipeline parallelism handles another are now standard at trillion-parameter scale.
FSDP has been adopted across academia and industry as the dominant PyTorch path to large-model training.
Meta has reported using FSDP variants for parts of the Llama training pipeline. The Llama 2 paper (July 2023) describes training on Meta's Research Super Cluster and internal production clusters with PyTorch-based infrastructure that includes FSDP for the data-parallel dimension [20]. The Llama 3 herd-of-models paper (2024) describes a 4D parallelism stack (tensor, context, pipeline, data) for the 405B model, where FSDP-style data-parallel sharding remains a component [21]. The PyTorch FSDP VLDB 2023 paper itself reports scaling experiments at Meta with models up to 1 trillion parameters [9].
The Hugging Face ecosystem integrates FSDP through Accelerate, which exposes FSDP configuration through a YAML file or CLI prompts and handles wrapping policy, mixed precision, and state-dict serialization automatically. Hugging Face also documents FSDP support in the Transformers Trainer class for fine-tuning models like Llama, Mistral, Falcon, and Mixtral on multi-GPU nodes [22]. PyTorch Lightning ships an FSDPStrategy that selects FSDP as the distributed backend with one keyword argument [23]. Mosaic Composer, the training library behind MosaicML's MPT and Databricks's DBRX, defaults to FSDP for large-model runs and was instrumental in popularizing the HYBRID_SHARD strategy for cost-efficient training [24].
Research labs including Stanford CRFM, the Allen Institute for AI (AI2), and EleutherAI use FSDP in their open-source training stacks. AI2's OLMo training code, released in 2024, is built on FSDP with transformer_auto_wrap_policy and DCP checkpointing [25]. TorchTitan, a Meta-led reference repository released in 2024, demonstrates FSDP2 plus tensor parallelism plus pipeline parallelism on the Llama architecture as a canonical example of modern PyTorch large-model training. FSDP is also used outside language modeling: Stable Diffusion XL and Stable Diffusion 3 fine-tuners routinely use FSDP through Hugging Face's diffusers and Accelerate for full-parameter fine-tuning of the larger U-Net or DiT backbones.
FSDP is significantly more complex than DDP, and several pitfalls trip up new users [9][11].
All-gather overhead can dominate at small batch sizes or when wrapping granularity is too fine. If every layer is its own unit, the per-step number of collectives explodes; if too few units are used, peak parameter memory rises because more parameters are held in fully gathered form simultaneously. Tuning the wrapping policy along with forward_prefetch and backward_prefetch is often the difference between FSDP being faster or slower than DDP at a given memory budget.
Gradient accumulation interacts with communication. By default, FSDP reduce-scatters gradients on every micro-batch backward, which is numerically correct but pays the communication cost once per micro-batch. Inside the no_sync() context manager, FSDP keeps gradients local and unsharded, and the reduce-scatter happens on the first backward outside the context, typically the final micro-batch. The trade-off is that during no_sync() each rank holds unsharded gradients rather than gradient shards, raising peak memory.
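The communication and memory sides of the no_sync() trade-off can be put into numbers with a rough model. The sketch below counts collectives and resident gradient elements only; it is not a profile of real FSDP behavior, and the function names and example sizes are hypothetical.

```python
from typing import Sequence


def reduce_scatters_per_step(num_units: int, micro_batches: int, use_no_sync: bool) -> int:
    """Reduce-scatter collectives issued per optimizer step."""
    synced_backwards = 1 if use_no_sync else micro_batches
    return num_units * synced_backwards


def grad_elements_held(unit_sizes: Sequence[int], world_size: int, use_no_sync: bool) -> float:
    """Gradient elements resident on each rank while accumulating."""
    total = sum(unit_sizes)
    # Under no_sync gradients stay unsharded; otherwise only shards persist.
    return float(total) if use_no_sync else total / world_size


units = [200_000_000] * 32
for flag in (False, True):
    print(
        flag,
        reduce_scatters_per_step(len(units), micro_batches=8, use_no_sync=flag),
        grad_elements_held(units, world_size=8, use_no_sync=flag),
    )
```

With these example numbers, no_sync() cuts the collective count by the number of micro-batches while multiplying resident gradient memory by the world size.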
Mixed precision configuration is subtle. Setting reduce_dtype to fp16 (rather than bf16 or fp32) is a common cause of training divergence in large runs, because fp16 has too little dynamic range for averaged gradients across many ranks. Buffers (especially batch-norm statistics) often need to be left in fp32. State-dict serialization is also non-trivial: FULL_STATE_DICT requires assembling a full unsharded copy on rank 0, which can run that rank out of CPU memory at trillion-parameter scale. SHARDED_STATE_DICT plus DCP is the recommended path, but it produces a directory of files rather than a single .pt file, sometimes complicating downstream tooling.
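The dynamic-range concern with fp16 reductions noted above can be demonstrated with the standard library, which can encode IEEE half precision directly. This sketch only shows the range of the number formats; it does not model how FSDP scales or orders its reductions. The bfloat16 emulation by truncating a float32 is an approximation, and all helper names are hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip through IEEE half precision; raises OverflowError if too large."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding.

    This truncates; real hardware typically rounds to nearest.
    """
    return struct.unpack(">f", struct.pack(">f", x)[:2] + b"\x00\x00")[0]


def fits_fp16(x: float) -> bool:
    """True if x is within the finite range of half precision."""
    try:
        to_fp16(x)
        return True
    except OverflowError:
        return False


summed = 1024.0 * 128           # a gradient value of 1024 summed over 128 ranks
print(fits_fp16(65504.0))       # True: the largest finite fp16 value
print(fits_fp16(summed))        # False: 131072 is out of range
print(to_bf16(summed))          # 131072.0
print(to_fp16(1e-8))            # 0.0: underflows in fp16
print(to_bf16(1e-8) > 0.0)      # True: still representable in bf16
```

bf16 trades mantissa bits for the same exponent range as fp32, which is why it tolerates both the large sums and the tiny values that fp16 loses.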
FSDP1 has long-standing FlatParameter quirks for parameters whose first dimension is smaller than the world size: the flat buffer is padded to be divisible by the world size, and that padding leaks into introspection and mixed-precision casts. FSDP2 addresses this by sharding parameters individually as DTensors, though not every third-party library has migrated.
Cross-node communication can be the bottleneck even with the best wrapping policy. On clusters where nodes are connected by 200 Gbps InfiniBand but each node has eight GPUs on NVLink at terabits per second, inter-node FULL_SHARD collectives can stall the step. HYBRID_SHARD mitigates this at the cost of higher per-rank parameter memory. The choice depends on model size relative to per-node GPU memory and on network topology, and is one of the most consequential knobs in large FSDP runs.
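A back-of-the-envelope estimate shows why the interconnect matters. The sketch below uses a simple bandwidth-only model in which each rank must receive the part of the payload it does not already hold; latency and contention are ignored. The 200 Gbps figure comes from the text (25 GB/s), while the 300 GB/s NVLink figure and the unit size are assumptions chosen for illustration.

```python
def all_gather_seconds(payload_bytes: float, world_size: int,
                       bandwidth_bytes_per_s: float) -> float:
    """Bandwidth-only lower bound for an all-gather.

    Each rank receives the (world_size - 1) / world_size share of the
    payload that it does not already hold.
    """
    return payload_bytes * (world_size - 1) / world_size / bandwidth_bytes_per_s


unit_bytes = 1_000_000_000 * 2  # a 1B-parameter unit in bf16 (assumed size)

inter_node = all_gather_seconds(unit_bytes, world_size=64, bandwidth_bytes_per_s=25e9)
intra_node = all_gather_seconds(unit_bytes, world_size=8, bandwidth_bytes_per_s=300e9)
print(f"FULL_SHARD over 64 ranks at 25 GB/s: {inter_node * 1e3:.1f} ms")
print(f"HYBRID_SHARD over 8 ranks at 300 GB/s: {intra_node * 1e3:.1f} ms")
```

Under these assumptions the cross-node gather is more than ten times slower per unit, which is the gap that HYBRID_SHARD is designed to avoid.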
Finally, FSDP composes with torch.compile, but the integration was not always seamless: early PyTorch 2.x releases had graph breaks at FSDP boundaries that undid much of the compile gain. As of PyTorch 2.4 and FSDP2 this has improved, but FSDP1 plus torch.compile should still be profiled carefully.