See also: Data parallelism, Model parallelism, Pipeline parallelism, Tensor parallelism, Distributed computing
A partitioning strategy in distributed deep learning is the plan that decides how a model and its training data are split across multiple accelerators (typically GPUs or TPUs) so that the workload can be executed in parallel. The choice of strategy determines whether a network can fit in device memory at all, how fast each step runs, how much network traffic the cluster generates, and how well the system scales when more devices are added. For modern large language models with hundreds of billions of parameters, a single GPU cannot store the full set of parameters, let alone the gradients, optimizer state, and activations, so a carefully chosen partitioning strategy is the gating factor that decides whether training is feasible at all.
The term is also used in classical machine learning to describe how a dataset is divided into training, validation, and test subsets. That second meaning is covered briefly at the end of this article. Most of the page is about the parallelism sense, since that is where the term is encountered most often in 2024-era systems work.
Three forces push researchers toward partitioning their training jobs: the parameters, gradients, and optimizer state of a large model can exceed the memory of any single accelerator; the activations produced during the forward pass grow with batch size and sequence length; and the total compute required would take impractically long on a single device.
A partitioning strategy decides which of these axes (parameters, gradients, optimizer state, activations, the batch, or the sequence) is split, where the splits are placed, and how the resulting fragments communicate during each training step.
Data parallelism is the simplest and most widely used strategy. Each worker holds a complete copy of the model. The global mini batch is sliced across workers, every worker computes gradients on its shard, and an all-reduce step averages the gradients before the optimizer updates the weights. Because every replica applies the same update, the replicas stay bitwise identical between steps. PyTorch exposes this through torch.nn.parallel.DistributedDataParallel, TensorFlow through tf.distribute.MirroredStrategy, and JAX through jax.pmap or jax.jit with sharded inputs. The downside is obvious: every worker must store the full model, so data parallelism on its own cannot scale to models that exceed device memory.
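The pattern can be sketched in a few lines of PyTorch. The snippet below is a minimal illustration, assuming it is launched with torchrun so that one process drives each GPU; the model, dataset, and hyperparameters are placeholders.

```python
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dist.init_process_group("nccl")                      # one process per GPU (torchrun)
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = DDP(torch.nn.Linear(1024, 1024).cuda())      # every rank holds a full replica
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)

dataset = TensorDataset(torch.randn(4096, 1024), torch.randn(4096, 1024))
sampler = DistributedSampler(dataset)                # slices the global batch across ranks
loader = DataLoader(dataset, batch_size=32, sampler=sampler)

for x, y in loader:
    loss = torch.nn.functional.mse_loss(model(x.cuda()), y.cuda())
    loss.backward()                                  # DDP all-reduces gradients here
    opt.step()
    opt.zero_grad()
```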
Tensor parallelism splits the weight matrices of individual layers across devices. The technique was popularized by Megatron-LM (Shoeybi et al., 2019), which showed how a transformer feedforward block could be sharded along its hidden dimension by column splitting the first linear layer and row splitting the second. The matrix multiplications are computed in parallel and stitched back together with two collective operations per block, an all-reduce on the forward pass and another on the backward pass. Attention is sharded similarly by partitioning the heads across devices. Tensor parallelism keeps activations small per device but adds heavy intra layer communication, so it is normally confined to GPUs that share a single node and can talk over NVLink. NVIDIA's reference Megatron-LM implementation typically uses tensor parallel groups of size 8, matching the eight GPUs in a DGX style node. See tensor parallelism and Megatron-LM for details.
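The column/row split can be sketched directly against torch.distributed collectives. The class below is a simplified illustration rather than the Megatron-LM implementation, and it assumes a tensor parallel process group of size tp_size has already been initialized; the backward-pass all-reduce on the input gradient, which real implementations insert with a custom autograd function, is omitted.

```python
import torch
import torch.distributed as dist

class TensorParallelMLP(torch.nn.Module):
    """Transformer feedforward block sharded across a tensor parallel group."""

    def __init__(self, d_model=4096, d_ff=16384, tp_size=8):
        super().__init__()
        # First linear: column split, each rank owns d_ff / tp_size output columns.
        self.fc1 = torch.nn.Linear(d_model, d_ff // tp_size)
        # Second linear: row split, each rank owns the matching input rows.
        self.fc2 = torch.nn.Linear(d_ff // tp_size, d_model)

    def forward(self, x):
        h = torch.nn.functional.gelu(self.fc1(x))   # local slice of the hidden activation
        y = self.fc2(h)                             # partial sum of the output projection
        dist.all_reduce(y, op=dist.ReduceOp.SUM)    # one all-reduce stitches the block together
        return y
```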
Pipeline parallelism splits the model along its depth. Each device hosts a contiguous block of layers (a stage), and mini batches are chopped into smaller micro batches that flow through the stages in a pipeline. GPipe (Huang et al., 2018) introduced synchronous micro batched pipelining, in which all forward passes complete before all backward passes start. PipeDream (Narayanan et al., 2019) added the 1F1B (one forward, one backward) schedule that interleaves the two passes to keep all stages busy and shrink the pipeline bubble. The interleaved 1F1B schedule used in modern Megatron-LM splits each stage into multiple virtual stages to reduce bubble time further. See pipeline parallelism for the schedule diagrams. Pipeline parallelism uses only point to point send and receive operations between adjacent stages, so it tolerates slower interconnects and is the natural way to scale across nodes.
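The GPipe-style schedule is easy to see in a toy form. The snippet below keeps both stages in one process so the schedule is readable; in a real pipeline each stage would sit on its own device and the activation hand-off would be a send/recv pair.

```python
import torch

stage0 = torch.nn.Linear(512, 512)          # layers 0..k-1 on device 0 in a real pipeline
stage1 = torch.nn.Linear(512, 10)           # layers k..N-1 on device 1

batch, target = torch.randn(64, 512), torch.randint(0, 10, (64,))
micro_batches = list(zip(batch.chunk(8), target.chunk(8)))   # 8 micro batches

# GPipe schedule: run every forward pass first, then every backward pass.
losses = []
for mb, tgt in micro_batches:
    out = stage1(stage0(mb))                # stage0's output would be sent to stage1
    losses.append(torch.nn.functional.cross_entropy(out, tgt))

for loss in losses:
    (loss / len(losses)).backward()         # gradients accumulate across micro batches
```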
The Zero Redundancy Optimizer (ZeRO; Rajbhandari et al., 2020) attacks the memory waste in plain data parallelism, where every replica stores its own copy of the optimizer state, gradients, and parameters. ZeRO defines three progressively aggressive stages: ZeRO-1 shards the optimizer state across the data parallel ranks, ZeRO-2 additionally shards the gradients, and ZeRO-3 shards the parameters themselves.
ZeRO-3 was first shipped in DeepSpeed and then re-implemented in PyTorch as Fully Sharded Data Parallel (FSDP) by Zhao et al. in 2023. FSDP wraps a model so that each parameter is stored in a flat sharded form on one rank, then materialized just in time for the layer that needs it. The communication pattern of ZeRO-3 and FSDP looks like reduce-scatter on gradients followed by all-gather on parameters. Total bytes moved per step are similar to a normal data parallel all-reduce, but the memory savings are large enough that single node training of multi billion parameter models becomes possible.
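A minimal FSDP sketch, assuming the process group has already been initialized (for example via torchrun); the stack of linear layers stands in for a real model, and production setups would also pass an auto-wrap policy and mixed precision settings.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Sequential(                    # placeholder stack of layers
    *[torch.nn.Linear(4096, 4096) for _ in range(8)]
).cuda()

# Parameters are flattened and sharded across ranks; FSDP all-gathers a unit's
# weights just before its forward/backward pass and frees them afterwards.
model = FSDP(model)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)   # optimizer sees only the local shard

loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()                                 # gradients are reduce-scattered
opt.step()
```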
Mixture of Experts (MoE) models contain many feedforward blocks per layer, only a few of which are activated for any given token. Expert parallelism places different experts on different devices and routes each token to the device that owns its assigned expert. The Switch Transformer (Fedus, Zoph, and Shazeer, 2022) demonstrated trillion parameter models with this approach by routing each token to a single expert, which simplifies routing and lowers communication. Mixtral 8x7B (Jiang et al., 2024) uses eight experts per layer with top 2 routing, giving 47 billion total parameters but only about 13 billion active per token. The all-to-all collective is the dominant communication pattern in expert parallelism, since tokens have to be shuffled to their experts and the results shuffled back. See expert parallelism.
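The gating side of expert parallelism can be sketched without any actual communication. The snippet below shows top-2 routing and the per-expert token grouping that would feed the all-to-all; the expert count, router, and batch are illustrative.

```python
import torch

num_experts, d_model = 8, 1024
router = torch.nn.Linear(d_model, num_experts)        # learned gating network
tokens = torch.randn(16, d_model)                      # 16 tokens in this micro batch

probs = torch.softmax(router(tokens), dim=-1)
weights, experts = torch.topk(probs, k=2, dim=-1)      # top-2 routing (Mixtral style)
weights = weights / weights.sum(dim=-1, keepdim=True)  # renormalize the two gates

# Group tokens by destination expert; in expert parallelism these groups are what
# the all-to-all exchanges, so each device receives the tokens for the experts it owns.
for e in range(num_experts):
    routed = (experts == e).any(dim=-1)
    print(f"expert {e}: {int(routed.sum())} tokens")
```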
Sequence parallelism splits activations along the sequence (token) dimension within a tensor parallel group. Korthikanti et al. (2022) introduced it for the LayerNorm and dropout operations of a transformer, where standard tensor parallelism would otherwise replicate large activation tensors across the tensor parallel ranks. By sharding these regions across the same group, activation memory is reduced by roughly a factor equal to the tensor parallel size, with no extra communication beyond what tensor parallelism already pays. The technique combined with selective activation recomputation cut activation memory by 5x and the recomputation overhead by more than 90% on a 530 billion parameter run on 2,240 A100 GPUs.
For very long contexts (tens of thousands to millions of tokens), even a sharded activation can blow past device memory. Ring Attention (Liu, Zaharia, and Abbeel, 2023) arranges devices in a ring, gives each device a chunk of the query and key/value tensors, and rotates the key/value chunks around the ring while computing blockwise attention. Communication is overlapped with compute, so the per device memory footprint scales down with the number of ring members rather than with sequence length. Meta's Llama 3 405B run used a context parallel scheme of size 16 to support 128k token training without exhausting HBM (Llama 3 Herd of Models, 2024).
| Strategy | What is split | Per-step communication | Memory savings | Typical placement | Reference |
|---|---|---|---|---|---|
| Data parallelism | Batch | All-reduce on gradients | None on weights | Across nodes | Krizhevsky 2014 |
| Tensor parallelism | Weight matrices within a layer | All-reduce per matmul | Weights, activations | Within a node, NVLink | Shoeybi 2019 |
| Pipeline parallelism | Layers (depth) | Send/recv between stages | Weights | Across nodes | Huang 2018, Narayanan 2019 |
| ZeRO-1 | Optimizer state | All-reduce | 4x | Within data parallel group | Rajbhandari 2020 |
| ZeRO-2 | Optimizer state, gradients | Reduce-scatter, all-reduce | 8x | Within data parallel group | Rajbhandari 2020 |
| ZeRO-3 / FSDP | Optimizer state, gradients, parameters | All-gather, reduce-scatter | Linear in rank | Within data parallel group | Rajbhandari 2020, Zhao 2023 |
| Expert parallelism | MoE experts | All-to-all on tokens | Active params per device | Per MoE layer group | Fedus 2022 |
| Sequence parallelism | LayerNorm/dropout activations | Reuses tensor parallel collectives | Activations | Same group as TP | Korthikanti 2022 |
| Context parallelism (ring attention) | Query and KV along sequence | Ring of point-to-point sends | Attention activations | Across nodes | Liu 2023 |
Real world LLM training rarely uses one strategy alone. The standard recipe today is 3D parallelism, which composes data, tensor, and pipeline parallelism into a single device mesh. The mesh is sized so that the most communication-intensive collective (the tensor parallel all-reduce) lives on the fastest interconnect (NVLink within a node), pipeline send and receive crosses InfiniBand between nodes, and data parallel gradient synchronization runs in the background overlapped with compute. ZeRO-1 or FSDP is usually layered on top of the data parallel axis to reclaim the memory the optimizer state would otherwise duplicate. Adding context or sequence parallelism gives a fourth axis, which Meta calls 4D parallelism in the Llama 3 paper.
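In PyTorch, the composed mesh can be declared explicitly with the DeviceMesh API. The sketch below assumes 64 GPUs arranged as 4-way data parallel, 2-way pipeline parallel, and 8-way tensor parallel, with the tensor axis innermost so that its ranks share a node; the sizes are illustrative.

```python
from torch.distributed.device_mesh import init_device_mesh

# 64 GPUs = 4 (data) x 2 (pipeline) x 8 (tensor); the last axis varies fastest,
# so the 8 tensor parallel ranks land on the same NVLink-connected node.
mesh = init_device_mesh("cuda", (4, 2, 8), mesh_dim_names=("dp", "pp", "tp"))

dp_group = mesh["dp"].get_group()   # gradient all-reduce / ZeRO sharding
pp_group = mesh["pp"].get_group()   # stage-to-stage send/recv
tp_group = mesh["tp"].get_group()   # Megatron-style all-reduces
```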
The table below shows the configurations used for several landmark training runs.
| Model | Parameters | Hardware | Reported parallelism | Source |
|---|---|---|---|---|
| GPT-3 (OpenAI, 2020) | 175B | ~10,000 V100 (later replicated on 3,072 A100) | DP + TP=8 + PP=64 in NVIDIA's reproduction | Narayanan et al., 2021 |
| Megatron-Turing NLG | 530B | 2,240 A100 | DP=8 + TP=8 + PP=35 | Korthikanti 2022 |
| PaLM (Google, 2022) | 540B | 6,144 TPU v4 | 256-way FSDP, 12-way model parallelism per pod, 2-pod data parallel | Chowdhery et al., 2022 |
| Mixtral 8x7B | 47B (13B active) | Not disclosed | DP + Expert Parallelism + TP | Jiang et al., 2024 |
| Llama 3 405B | 405B | 16,384 H100 (up to 24,576) | DP + TP=8 + PP=16 + CP=16, FSDP ZeRO-1 | Meta AI, 2024 |
Every partitioning strategy reduces in the end to a small set of collective and point to point operations, almost always provided by NVIDIA's NCCL library on GPU clusters or by the XLA compiler on TPUs.
| Primitive | Use case |
|---|---|
| all-reduce | Gradient averaging in data parallelism; activation sync inside a tensor parallel block |
| all-gather | Materialize sharded parameters in ZeRO-3/FSDP; gather sharded activations |
| reduce-scatter | Distribute partial gradient sums in ZeRO-2/FSDP |
| all-to-all | Token routing in expert parallelism; transposing sharded layouts |
| send / recv | Stage-to-stage hand off in pipeline parallelism; KV rotation in ring attention |
| broadcast | Push initial parameters from rank 0 |
A classic identity worth memorizing is that an all-reduce is mathematically equal to a reduce-scatter followed by an all-gather. ZeRO-2 and ZeRO-3 exploit this to express the same total work as a normal data parallel all-reduce while keeping each rank's memory footprint sharded the rest of the time.
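The identity is easy to verify numerically. The sketch below simulates four ranks with plain PyTorch tensors, with no actual communication.

```python
import torch

# Four simulated data parallel ranks, each holding a local gradient of length 8.
world_size, n = 4, 8
grads = [torch.randn(n) for _ in range(world_size)]

# all-reduce: every rank ends up with the elementwise sum of all local gradients.
allreduce_result = torch.stack(grads).sum(dim=0)

# reduce-scatter: rank c keeps only its 1/world_size chunk of that sum...
chunks = [g.chunk(world_size) for g in grads]
scattered = [sum(chunks[r][c] for r in range(world_size)) for c in range(world_size)]

# ...and an all-gather reassembles the full summed vector on every rank.
gathered = torch.cat(scattered)

assert torch.allclose(allreduce_result, gathered)
```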
The partitioning strategy must match the cluster topology, because the cost of a collective scales with the slowest link it crosses.
Inside a node, NVIDIA NVLink and NVSwitch deliver 900 GB/s of bidirectional bandwidth per H100, which is roughly an order of magnitude faster than InfiniBand. Tensor parallelism, which issues two all-reduces per transformer block, is almost always confined to one node so that these collectives stay on NVLink. Between nodes, InfiniBand HDR or NDR and RDMA over Converged Ethernet (RoCE) provide 200 to 400 Gb/s per port. Pipeline parallel send and receive and FSDP all-gathers are scheduled here because they tolerate lower bandwidth and are easier to overlap with compute. Google's TPU v4 pods take a different approach: their 3D torus optical interconnect with reconfigurable optical circuit switches gives PaLM training the bandwidth profile needed to skip pipeline parallelism entirely and rely on 12-way model parallelism plus 256-way FSDP within a pod.
Getting these mappings wrong is one of the most common causes of poor scaling. A tensor parallel group that crosses a node boundary often performs worse than a smaller group that stays inside one DGX, even though the larger group has more aggregate memory.
| Framework | Strategies supported | Notes |
|---|---|---|
| PyTorch DDP | Data parallelism | Built into torch.nn.parallel; uses NCCL all-reduce |
| PyTorch FSDP | ZeRO-3 style sharded data parallelism | Native PyTorch; integrated with torch.compile |
| PyTorch DTensor and DeviceMesh | TP, PP, DP composition | Modern API used by torchtitan and Llama 3 training stack |
| DeepSpeed (Microsoft) | ZeRO-1/2/3, ZeRO-Offload, pipeline | Originated ZeRO; widely used with Hugging Face Accelerate |
| Megatron-LM (NVIDIA) | TP, PP, sequence parallelism, expert parallelism | Reference implementation for transformer training |
| Megatron-Core | Library form of Megatron-LM | Used by NeMo, Llama 3, and several open recipes |
| JAX jit + shard_map | Arbitrary SPMD partitioning via GSPMD | Powers Flax, MaxText, T5X, PaLM training |
| TensorFlow Mesh-TF and DTensor | TP, PP, DP | Used in earlier Google training stacks |
| Hugging Face Accelerate | DDP, FSDP, DeepSpeed wrapper | One-line config for the common cases |
| Colossal-AI | DP, TP, PP, ZeRO, sequence parallelism | Open source training stack from HPC-AI Tech |
| vLLM | TP, PP, expert and decode context parallelism | Inference oriented; shards KV cache too |
| TensorRT-LLM and Text Generation Inference (TGI) | TP, PP | Production inference servers |
No strategy is free. Each one trades along a different axis of the memory, compute, communication triangle.
Gradient checkpointing, also called activation recomputation, is an orthogonal technique that trades extra forward compute for lower activation memory. It is almost always combined with the strategies above. Mixed precision in BF16 or FP16 (and increasingly FP8 on Hopper class GPUs) halves or quarters the bytes moved by every collective and is essentially mandatory at scale. Parameter and optimizer offload to CPU memory or NVMe, supported in DeepSpeed ZeRO-Infinity, is a memory of last resort that lets very large models fit on tiny clusters at the cost of moving bytes over PCIe.
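Gradient checkpointing is a one-line change in PyTorch. The sketch below wraps a placeholder feedforward block with torch.utils.checkpoint; only the block's inputs are saved, and its internal activations are recomputed during the backward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(            # placeholder transformer-sized MLP block
    torch.nn.Linear(4096, 16384),
    torch.nn.GELU(),
    torch.nn.Linear(16384, 4096),
)

x = torch.randn(8, 4096, requires_grad=True)

# The activations inside `block` are not stored during the forward pass;
# they are recomputed on the fly during backward, trading compute for memory.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```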
The same ideas apply to serving large models, with one twist: inference workloads care about latency and KV cache memory, not gradient throughput. Tensor parallelism is by far the most common strategy for serving multi GPU models in vLLM, TensorRT-LLM, and Text Generation Inference. As a side effect, sharding the weights across tp_size GPUs frees a large fraction of HBM for the KV cache, so the achievable token throughput can rise faster than linearly with the tensor parallel size. For very long contexts, vLLM and others now support context parallelism that shards the KV cache along the sequence axis, mirroring ring attention. Pipeline parallelism is used at inference for models too large to fit even after tensor parallelism, but it adds pipeline fill latency before the first token is produced.
A practical rule of thumb for new training jobs:

1. Start with plain data parallelism if the model, gradients, and optimizer state fit on one device.
2. If they do not fit, switch to ZeRO or FSDP sharded data parallelism.
3. If a single layer's weights or activations are still too large, add tensor parallelism within a node (typically size 8).
4. If the model still does not fit, add pipeline parallelism across nodes.
5. Add sequence or context parallelism for very long sequences, and expert parallelism for MoE architectures.
This ladder reflects how Megatron-LM, DeepSpeed, and Llama 3's training stack are configured in practice.
In classical machine learning, the term partitioning strategy is sometimes used in a different sense to describe how a dataset is split into training, validation, and test subsets. The two main approaches in that sense are the holdout method, which carves the dataset into fixed train, validation, and test sets (often in 70/15/15 or 80/10/10 ratios), and k-fold cross-validation, which divides the data into k equal folds and rotates which fold is used for validation so that every example contributes to both training and evaluation. Stratified k-fold cross-validation preserves class proportions in each fold and is the standard choice for imbalanced classification tasks. These methods are about evaluation rigor rather than scaling and have no direct relationship to the parallelism strategies described above. They are covered in more detail at training set, validation set, and test set.
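For completeness, the dataset-splitting sense can be sketched with scikit-learn; the synthetic data and the 70/15/15 ratio below are illustrative.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold, train_test_split

X = np.random.rand(1000, 20)                 # synthetic features
y = np.random.randint(0, 2, size=1000)       # synthetic binary labels

# Holdout method: a fixed 70/15/15 train / validation / test split.
X_train, X_rest, y_train, y_rest = train_test_split(X, y, test_size=0.30, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_rest, y_rest, test_size=0.50, stratify=y_rest)

# Stratified 5-fold cross-validation: every example is validated on exactly once,
# with class proportions preserved in each fold.
for fold, (train_idx, val_idx) in enumerate(StratifiedKFold(n_splits=5, shuffle=True).split(X, y)):
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} validation examples")
```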
Imagine baking a giant cake that is too big for one oven. You can split the work in different ways. Several friends each bake a smaller cake from the same recipe, then compare notes and average their adjustments (data parallelism). One friend bakes the bottom layer, another the middle, another the top, and they pass the cake along (pipeline parallelism). Or you let each friend bake one slice of every layer, then glue the slices back together (tensor parallelism). For really big cakes you do all three at once. The tricky part is that whenever the friends have to talk to each other, the cake stops baking, so you want to choose the splits that make them talk as little as possible.