# Partitioning strategy

> Source: https://aiwiki.ai/wiki/partitioning_strategy
> Updated: 2026-04-25
> Categories: MLOps, Training & Optimization
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

*See also: [Data parallelism](/wiki/data_parallelism), [Model parallelism](/wiki/model_parallelism), [Pipeline parallelism](/wiki/pipeline_parallelism), [Tensor parallelism](/wiki/tensor_parallelism), [Distributed computing](/wiki/distributed_computing)*

A **partitioning strategy** in distributed deep learning is the plan that decides how a model and its training data are split across multiple accelerators (typically [GPUs](/wiki/gpu_computing) or TPUs) so that the workload can be executed in parallel. The choice of strategy determines whether a network fits in device memory at all, how fast each step runs, how much network traffic the cluster generates, and how well the system scales when more devices are added. For modern [large language models](/wiki/llm) with hundreds of billions of parameters, a single GPU cannot hold even one full copy of the weights, let alone the gradients and optimizer state, so a carefully chosen partitioning strategy decides whether training is feasible at all.

The term is also used in classical [machine learning](/wiki/machine_learning) to describe how a dataset is divided into training, validation, and test subsets. That second meaning is covered briefly at the end of this article. Most of the page is about the parallelism sense, since that is where the term is encountered most often in 2024 era systems work.

## Why partitioning is necessary

Three forces push researchers toward partitioning their training jobs:

1. Memory pressure. A 175 billion parameter model in 16 bit precision needs about 350 GB just for weights, another 350 GB for 16 bit gradients, and roughly six times the weight footprint again (about 2.1 TB) for the FP32 master weights and Adam moment estimates used in standard mixed precision training. An 80 GB NVIDIA H100 cannot hold even one copy of the weights.
2. Compute throughput. Training a [GPT-3](/wiki/gpt-3) sized model on a single accelerator would take many years even on a modern GPU. Splitting the work across thousands of devices brings wall clock time down to weeks or months.
3. Activation memory. Even when the weights fit, the intermediate activations saved during the forward pass for reuse in backpropagation can be larger than the parameters themselves, especially for long context lengths.
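The memory arithmetic in item 1 can be made explicit. The sketch below assumes the mixed precision Adam accounting used in the ZeRO paper (2 bytes of weights, 2 bytes of gradients, and 12 bytes of FP32 optimizer state per parameter) and ignores activations and buffers entirely; the constants and helper names are illustrative, not part of any library.

```python
import math

# Per-parameter byte counts for mixed precision Adam (ZeRO paper accounting):
# 16-bit weights and gradients, plus FP32 master weights, first moment, and
# second moment (4 + 4 + 4 = 12 bytes).
WEIGHT_BYTES = 2
GRAD_BYTES = 2
OPTIMIZER_BYTES = 12


def model_state_gb(num_params: float) -> dict:
    """Memory in GB (1 GB = 1e9 bytes) for model states only.

    Activations, communication buffers, and fragmentation are ignored.
    """
    return {
        "weights": num_params * WEIGHT_BYTES / 1e9,
        "gradients": num_params * GRAD_BYTES / 1e9,
        "optimizer": num_params * OPTIMIZER_BYTES / 1e9,
    }


def min_devices(num_params: float, device_gb: float = 80.0) -> int:
    """Lower bound on device count, assuming the states could be split perfectly."""
    total_gb = sum(model_state_gb(num_params).values())
    return math.ceil(total_gb / device_gb)


print(model_state_gb(175e9))  # weights, gradients, optimizer state in GB
print(min_devices(175e9))     # 80 GB devices needed just for model states
```

Even this optimistic lower bound (35 devices for a 175B model) assumes perfect sharding with no room left for activations, which is why the strategies below exist.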

A partitioning strategy decides which of these axes (parameters, gradients, optimizer state, activations, the batch, or the sequence) is split, where the splits are placed, and how the resulting fragments communicate during each training step.

## Major parallelism strategies

### Data parallelism

[Data parallelism](/wiki/data_parallelism) is the simplest and most widely used strategy. Each worker holds a complete copy of the model. The global mini batch is sliced across workers, every worker computes gradients on its shard, and an [all-reduce](/wiki/all_reduce) step averages the gradients before the optimizer updates the weights. Because every replica applies the same update, the replicas stay bitwise identical between steps. PyTorch exposes this through `torch.nn.parallel.DistributedDataParallel`, TensorFlow through `tf.distribute.MirroredStrategy`, and JAX through `jax.pmap` or `jax.jit` with sharded inputs. The downside is obvious: every worker must store the full model, so data parallelism on its own cannot scale to models that exceed device memory.
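The core data parallel loop can be simulated without any framework. This is a sketch, not how `DistributedDataParallel` is implemented: a one-parameter linear model stands in for a network, a Python list stands in for the workers, and `all_reduce_mean` is a hypothetical helper that plays the role of the NCCL all-reduce. With equal shard sizes, the average of per-shard mean gradients equals the full-batch gradient.

```python
from typing import List


def local_gradient(w: float, xs: List[float], ys: List[float]) -> float:
    """d/dw of the mean squared error for the one-parameter model y_hat = w * x."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values: List[float]) -> List[float]:
    """Simulated all-reduce: every rank receives the average over all ranks."""
    avg = sum(values) / len(values)
    return [avg] * len(values)


def data_parallel_step(w: float, xs: List[float], ys: List[float],
                       world_size: int, lr: float = 0.01) -> List[float]:
    """One synchronous data parallel SGD step; returns each replica's new weight."""
    assert len(xs) % world_size == 0, "sketch assumes equal shard sizes"
    shard = len(xs) // world_size
    # Each rank sees only its slice of the global mini batch.
    grads = [local_gradient(w, xs[r * shard:(r + 1) * shard],
                            ys[r * shard:(r + 1) * shard])
             for r in range(world_size)]
    synced = all_reduce_mean(grads)
    # Identical gradients plus an identical update rule keep the replicas identical.
    return [w - lr * g for g in synced]


xs = [float(i) for i in range(1, 9)]
ys = [3.0 * x for x in xs]
print(data_parallel_step(0.5, xs, ys, world_size=4))
```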

### Tensor parallelism

Tensor parallelism splits the weight matrices of individual layers across devices. The technique was popularized by Megatron-LM (Shoeybi et al., 2019), which showed how a transformer feedforward block could be sharded along its hidden dimension by column splitting the first linear layer and row splitting the second. The matrix multiplications are computed in parallel and stitched back together with two collective operations per block, an all-reduce on the forward pass and another on the backward pass. Attention is sharded similarly by partitioning the heads across devices. Tensor parallelism keeps activations small per device but adds heavy intra layer communication, so it is normally confined to GPUs that share a single node and can talk over NVLink. NVIDIA's reference Megatron-LM implementation typically uses tensor parallel groups of size 8, matching the eight GPUs in a DGX style node. See [tensor parallelism](/wiki/tensor_parallelism) and [Megatron-LM](/wiki/megatron-lm) for details.
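The Megatron-style split of a feedforward block can be checked numerically. The sketch below assumes a two-layer MLP `Y = act(X A) B`, uses ReLU as a stand-in for the GeLU used in practice (any elementwise activation works), and represents matrices as nested lists. Column-splitting `A` lets each rank apply the activation locally; row-splitting `B` leaves each rank with a partial sum, and the final addition is the forward-pass all-reduce.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def relu(m: Matrix) -> Matrix:
    return [[max(0.0, v) for v in row] for row in m]


def add(a: Matrix, b: Matrix) -> Matrix:
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def split_columns(m: Matrix, parts: int) -> List[Matrix]:
    width = len(m[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in m] for p in range(parts)]


def split_rows(m: Matrix, parts: int) -> List[Matrix]:
    height = len(m) // parts
    return [m[p * height:(p + 1) * height] for p in range(parts)]


def mlp_reference(x: Matrix, a: Matrix, b: Matrix) -> Matrix:
    return matmul(relu(matmul(x, a)), b)


def mlp_tensor_parallel(x: Matrix, a: Matrix, b: Matrix, tp: int) -> Matrix:
    a_shards = split_columns(a, tp)  # rank r owns a block of A's columns
    b_shards = split_rows(b, tp)     # rank r owns the matching rows of B
    # Each rank computes act(X A_r) B_r with no communication.
    partials = [matmul(relu(matmul(x, a_r)), b_r)
                for a_r, b_r in zip(a_shards, b_shards)]
    # Summing the partials is the single all-reduce in the forward pass.
    out = partials[0]
    for p in partials[1:]:
        out = add(out, p)
    return out


x = [[1.0, -2.0, 0.5], [0.0, 1.0, -1.0]]
a = [[0.5, -1.0, 2.0, 0.1], [1.5, 0.2, -0.3, 0.7], [-0.4, 0.9, 1.1, -2.0]]
b = [[1.0, 0.0, -1.0], [0.5, 2.0, 0.3], [-1.2, 0.4, 0.8], [0.6, -0.7, 1.0]]
print(mlp_reference(x, a, b))
print(mlp_tensor_parallel(x, a, b, tp=2))
```

Splitting `A` by rows instead would force an all-reduce before the nonlinearity, which is why the column-then-row order is used.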

### Pipeline parallelism

Pipeline parallelism splits the model along its depth. Each device hosts a contiguous block of layers (a stage), and mini batches are chopped into smaller micro batches that flow through the stages in a pipeline. GPipe (Huang et al., 2018) introduced synchronous micro batched pipelining, in which all forward passes complete before all backward passes start. PipeDream (Narayanan et al., 2019) added the 1F1B (one forward, one backward) schedule, which alternates the two passes so that stages stay busy in steady state and each stage stashes activations for only a bounded number of in-flight micro batches. The interleaved 1F1B schedule used in modern Megatron-LM (Narayanan et al., 2021) assigns each device several smaller, non-contiguous virtual stages, which shrinks the bubble further at the cost of extra communication. See [pipeline parallelism](/wiki/pipeline_parallelism) for the schedule diagrams. Pipeline parallelism uses only point to point send and receive operations between adjacent stages, so it tolerates slower interconnects and is the natural way to scale across nodes.
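The size of the pipeline bubble can be estimated with a one-line formula. The sketch below uses the estimate from Narayanan et al. (2021): bubble time divided by ideal compute time is `(p - 1) / (v * m)` for `p` stages, `m` micro batches per step, and `v` model chunks per device. It assumes equal stage times and ignores communication; in that paper's analysis, flushed 1F1B has the same bubble as GPipe (its gain is activation memory), so both use `v = 1`.

```python
def pipeline_bubble_overhead(stages: int, microbatches: int,
                             chunks_per_device: int = 1) -> float:
    """Bubble time divided by ideal (bubble-free) compute time.

    stages: pipeline depth p; microbatches: m per step;
    chunks_per_device: v (1 for GPipe / flushed 1F1B, >1 for interleaved).
    """
    if stages < 1 or microbatches < 1 or chunks_per_device < 1:
        raise ValueError("all arguments must be positive")
    return (stages - 1) / (chunks_per_device * microbatches)


for m in (4, 16, 64):
    print(m, pipeline_bubble_overhead(stages=8, microbatches=m))
print("interleaved, v=2:", pipeline_bubble_overhead(8, 16, chunks_per_device=2))
```

The formula shows why pipelines need many more micro batches than stages: with 8 stages and only 4 micro batches, the bubble costs more time than the useful compute.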

### Sharded data parallelism (ZeRO and FSDP)

The Zero Redundancy Optimizer (Rajbhandari et al., 2020) attacks the memory waste in plain data parallelism, where every replica stores its own copy of the optimizer state, gradients, and parameters. [ZeRO](/wiki/zero) defines three progressively more aggressive stages:

- ZeRO-1 partitions the optimizer state across data parallel ranks.
- ZeRO-2 also partitions the gradients.
- ZeRO-3 also partitions the parameters themselves, gathering them on demand for the forward and backward passes through all-gather, then discarding them after the layer finishes.

ZeRO-3 first shipped in [DeepSpeed](/wiki/deepspeed); PyTorch later provided a native implementation of the same idea, Fully Sharded Data Parallel ([FSDP](/wiki/fsdp)), described by Zhao et al. (2023). FSDP wraps a model so that each rank stores only a 1/N shard of every flattened parameter group, and the full parameters are materialized just in time for the layer that needs them. The communication pattern of ZeRO-3 and FSDP is a reduce-scatter on gradients plus all-gathers on parameters. ZeRO-2 moves the same number of bytes per step as a normal data parallel all-reduce, while ZeRO-3 moves about 1.5 times as many because parameters are gathered for both the forward and the backward pass. The memory savings are large enough that single node training of multi billion parameter models becomes possible.
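The per-device savings of each ZeRO stage follow from simple accounting. This sketch reproduces the model-state formulas from the ZeRO paper, with `k = 12` bytes of optimizer state per parameter for mixed precision Adam, and ignores activations. For the paper's example of a 7.5B parameter model on 64 data parallel ranks, it gives about 120, 31.4, 16.6, and 1.9 GB for stages 0 through 3. The function name is hypothetical; DeepSpeed and FSDP expose these stages through configuration, not through a function like this.

```python
def zero_model_state_gb(num_params: float, dp_size: int, stage: int,
                        k: int = 12) -> float:
    """Per-device model-state memory (GB, 1e9 bytes) for ZeRO stages 0-3.

    2 bytes/param for 16-bit weights, 2 for 16-bit gradients, and k for
    optimizer state. Stage 0 is plain data parallelism.
    """
    psi, n = num_params, dp_size
    if stage == 0:      # everything replicated
        total = (2 + 2 + k) * psi
    elif stage == 1:    # shard optimizer state
        total = 2 * psi + 2 * psi + k * psi / n
    elif stage == 2:    # shard optimizer state and gradients
        total = 2 * psi + (2 + k) * psi / n
    elif stage == 3:    # shard optimizer state, gradients, and parameters
        total = (2 + 2 + k) * psi / n
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9


for s in range(4):
    print(f"stage {s}: {zero_model_state_gb(7.5e9, 64, s):.1f} GB")
```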

### Expert parallelism

Mixture of Experts (MoE) models contain many feedforward blocks per layer, only a few of which are activated for any given token. Expert parallelism places different experts on different devices and routes each token to the device that owns its assigned expert. The Switch Transformer (Fedus, Zoph, and Shazeer, 2022) demonstrated trillion parameter models with this approach by routing each token to a single expert, which simplifies routing and lowers communication. Mixtral 8x7B (Jiang et al., 2024) uses eight experts per layer with top 2 routing, giving 47 billion total parameters but only about 13 billion active per token. The all-to-all collective is the dominant communication pattern in expert parallelism, since tokens have to be shuffled to their experts and the results shuffled back. See [expert parallelism](/wiki/expert_parallelism).
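The routing and dispatch half of expert parallelism can be sketched in plain Python. This is an illustration under stated assumptions, not any model's actual router: gates are a softmax over router logits, the top-k experts are kept and their gates renormalized (a Mixtral-style choice; Switch Transformer uses k = 1), and experts are laid out contiguously so device `d` owns experts `d * experts_per_device` and up. The per-device buffers are what an all-to-all would send.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_route(router_logits: List[List[float]],
                k: int) -> List[List[Tuple[int, float]]]:
    """Pick the k highest-probability experts per token and renormalize their gates."""
    routes = []
    for logits in router_logits:
        probs = softmax(logits)
        top = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:k]
        norm = sum(probs[e] for e in top)
        routes.append([(e, probs[e] / norm) for e in top])
    return routes


def dispatch(routes: List[List[Tuple[int, float]]],
             experts_per_device: int) -> Dict[int, List[Tuple[int, int, float]]]:
    """Group (token, expert, gate) triples by the device that owns the expert."""
    buffers = defaultdict(list)
    for token, choices in enumerate(routes):
        for expert, gate in choices:
            buffers[expert // experts_per_device].append((token, expert, gate))
    return dict(buffers)


logits = [[2.0, 1.0, 0.0, -1.0],   # token 0
          [0.0, 0.0, 3.0, 2.5],    # token 1
          [1.0, 0.0, 0.0, 2.0]]    # token 2
routes = top_k_route(logits, k=2)
for device, items in sorted(dispatch(routes, experts_per_device=2).items()):
    print(device, [(t, e) for t, e, _ in items])
```

After the experts run, a second all-to-all returns each output to its token's home device, where the gate-weighted outputs are summed.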

### Sequence parallelism

Sequence parallelism splits activations along the sequence (token) dimension within a tensor parallel group. Korthikanti et al. (2022) introduced it for the LayerNorm and dropout operations of a transformer, where standard tensor parallelism would otherwise replicate large activation tensors across the tensor parallel ranks. By sharding these regions across the same group, activation memory is reduced by roughly a factor equal to the tensor parallel size, with no extra communication beyond what tensor parallelism already pays, because each tensor parallel all-reduce is replaced by an equivalent reduce-scatter and all-gather pair. Combined with selective activation recomputation, the technique cut activation memory by 5x and the execution time overhead of recomputation by more than 90%, as demonstrated on a 530 billion parameter model trained on 2,240 A100 GPUs.
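Sequence parallelism works because LayerNorm (and dropout) act on each token independently, so a rank that holds only a slice of the sequence can compute its part with no communication. The sketch below assumes LayerNorm without learned scale and shift, and simulates the boundary collectives with list slicing (standing in for the reduce-scatter) and concatenation (standing in for the all-gather); it is not Megatron-LM code.

```python
import math
from typing import List


def layer_norm(tokens: List[List[float]], eps: float = 1e-5) -> List[List[float]]:
    """LayerNorm without learned parameters, applied to each token vector separately."""
    out = []
    for vec in tokens:
        mean = sum(vec) / len(vec)
        var = sum((v - mean) ** 2 for v in vec) / len(vec)
        out.append([(v - mean) / math.sqrt(var + eps) for v in vec])
    return out


def sequence_parallel_layer_norm(tokens: List[List[float]],
                                 tp_size: int) -> List[List[float]]:
    assert len(tokens) % tp_size == 0, "sketch assumes the sequence divides evenly"
    chunk = len(tokens) // tp_size
    # Stand-in for the reduce-scatter: rank r receives tokens [r*chunk, (r+1)*chunk).
    shards = [tokens[r * chunk:(r + 1) * chunk] for r in range(tp_size)]
    # Each rank normalizes only its own tokens, so it stores 1/tp_size of the activations.
    local = [layer_norm(shard) for shard in shards]
    # Stand-in for the all-gather before the next tensor parallel matmul.
    return [tok for shard in local for tok in shard]


tokens = [[float(i + j) * (1 if j % 2 else -1) for j in range(4)] for i in range(8)]
print(sequence_parallel_layer_norm(tokens, tp_size=4) == layer_norm(tokens))
```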

### Context parallelism and ring attention

For very long contexts (tens of thousands or millions of tokens), even a sharded activation can exceed device memory. Ring Attention (Liu, Zaharia, and Abbeel, 2023) arranges devices in a ring, gives each device a chunk of the query and key/value tensors, and rotates the key/value chunks around the ring while computing blockwise attention. Communication is overlapped with compute, so each device's memory footprint grows with the sequence length divided by the number of ring members rather than with the full sequence length. Meta's Llama 3 405B run used context parallelism of size 16 to train at 128K tokens without exhausting HBM, although its implementation all-gathers the key and value tensors rather than rotating them around a ring (Llama 3 Herd of Models, 2024).
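The key trick that makes ring attention exact is the online (streaming) softmax: each device keeps a running maximum, normalizer, and weighted sum for its queries, and rescales them as new key/value blocks arrive. The sketch below simulates the ring in a single process, uses non-causal attention, and omits the `1/sqrt(d)` scaling for brevity; it illustrates the math, not the Ring Attention authors' implementation.

```python
import math
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def reference_attention(queries: List[Vec], keys: List[Vec],
                        values: List[Vec]) -> List[Vec]:
    out = []
    for q in queries:
        scores = [dot(q, k) for k in keys]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, values)) / z
                    for j in range(len(values[0]))])
    return out


def ring_attention(q_chunks: List[List[Vec]], k_chunks: List[List[Vec]],
                   v_chunks: List[List[Vec]]) -> List[Vec]:
    n = len(q_chunks)
    d_v = len(v_chunks[0][0])
    outputs = []
    for rank in range(n):
        # Online softmax state per local query: (running max, normalizer, weighted sum).
        state = [(-math.inf, 0.0, [0.0] * d_v) for _ in q_chunks[rank]]
        for step in range(n):
            src = (rank - step) % n  # KV block held by this rank after `step` rotations
            for i, q in enumerate(q_chunks[rank]):
                m, l, acc = state[i]
                for k, v in zip(k_chunks[src], v_chunks[src]):
                    s = dot(q, k)
                    m_new = max(m, s)
                    scale = math.exp(m - m_new)  # rescale earlier contributions
                    p = math.exp(s - m_new)
                    acc = [a * scale + p * vj for a, vj in zip(acc, v)]
                    l = l * scale + p
                    m = m_new
                state[i] = (m, l, acc)
        outputs.extend([[a / l for a in acc] for _, l, acc in state])
    return outputs


qs = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.4], [0.5, 0.5]]
ks = [[0.2, -0.3], [0.1, 0.1], [-0.4, 0.2], [0.3, 0.6]]
vs = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
ring = ring_attention([qs[:2], qs[2:]], [ks[:2], ks[2:]], [vs[:2], vs[2:]])
print(ring)
```

Each simulated device only ever touches one key/value block at a time, which is what keeps per-device memory proportional to the local chunk.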

## Comparison of strategies

| Strategy | What is split | Per-step communication | Memory savings | Typical placement | Reference |
|---|---|---|---|---|---|
| [Data parallelism](/wiki/data_parallelism) | Batch | All-reduce on gradients | None on weights | Across nodes | Krizhevsky 2014 |
| [Tensor parallelism](/wiki/tensor_parallelism) | Weight matrices within a layer | All-reduce per attention or MLP block | Weights, activations | Within a node, NVLink | Shoeybi 2019 |
| [Pipeline parallelism](/wiki/pipeline_parallelism) | Layers (depth) | Send/recv between stages | Weights | Across nodes | Huang 2018, Narayanan 2019 |
| ZeRO-1 | Optimizer state | All-reduce on gradients, all-gather on parameters | 4x | Within data parallel group | Rajbhandari 2020 |
| ZeRO-2 | Optimizer state, gradients | Reduce-scatter, all-gather | 8x | Within data parallel group | Rajbhandari 2020 |
| ZeRO-3 / [FSDP](/wiki/fsdp) | Optimizer state, gradients, parameters | All-gather, reduce-scatter | Linear in data parallel size | Within data parallel group | Rajbhandari 2020, Zhao 2023 |
| [Expert parallelism](/wiki/expert_parallelism) | MoE experts | All-to-all on tokens | Expert weights | Per MoE layer group | Fedus 2022 |
| Sequence parallelism | LayerNorm/dropout activations | Reuses tensor parallel collectives | Activations | Same group as TP | Korthikanti 2022 |
| Context parallelism (ring attention) | Query and KV along sequence | Ring of point-to-point sends | Attention activations | Across nodes | Liu 2023 |

## Hybrid 3D and 4D parallelism

Real world LLM training rarely uses one strategy alone. The standard recipe today is **3D parallelism**, which composes data, tensor, and pipeline parallelism into a single device mesh. The mesh is laid out so that the most communication-heavy collective (the tensor parallel all-reduce) runs on the fastest interconnect (NVLink within a node), pipeline send and receive cross InfiniBand between nodes, and data parallel gradient synchronization is overlapped with compute. ZeRO-1 or FSDP is usually layered on top of the data parallel axis to reclaim the memory that replicated optimizer state would otherwise waste. Adding context parallelism gives a fourth axis, which the Llama 3 paper calls 4D parallelism.
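The placement rule can be illustrated with a tiny rank-to-coordinate mapping. This sketch assumes that ranks 0 to 7 sit on node 0, 8 to 15 on node 1, and so on, and that the tensor parallel index varies fastest; real frameworks let you choose the axis order, and the helper names here are hypothetical. With that ordering, every tensor parallel group lands inside one node, while pipeline and data parallel groups span nodes.

```python
from typing import Dict, List, Tuple

DP, PP, TP = 0, 1, 2  # coordinate positions in the (dp, pp, tp) tuple


def mesh_coords(rank: int, tp: int, pp: int) -> Tuple[int, int, int]:
    """Map a global rank to (dp, pp, tp) coordinates with TP varying fastest."""
    tp_idx = rank % tp
    pp_idx = (rank // tp) % pp
    dp_idx = rank // (tp * pp)
    return dp_idx, pp_idx, tp_idx


def groups(world: int, tp: int, pp: int, dp: int, axis: int) -> List[List[int]]:
    """Ranks that share every coordinate except `axis` form one communication group."""
    assert world == tp * pp * dp, "mesh must cover every rank exactly once"
    buckets: Dict[Tuple[int, ...], List[int]] = {}
    for r in range(world):
        coords = mesh_coords(r, tp, pp)
        key = tuple(c for i, c in enumerate(coords) if i != axis)
        buckets.setdefault(key, []).append(r)
    return sorted(buckets.values())


GPUS_PER_NODE = 8
tp_groups = groups(32, tp=8, pp=2, dp=2, axis=TP)
print(tp_groups)                                # consecutive ranks, one node each
print(groups(32, tp=8, pp=2, dp=2, axis=PP))    # pairs that cross node boundaries
print(all(len({r // GPUS_PER_NODE for r in g}) == 1 for g in tp_groups))
```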

The table below shows the configurations used for several landmark training runs.

| Model | Parameters | Hardware | Reported parallelism | Source |
|---|---|---|---|---|
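| GPT-3 (OpenAI, 2020) | 175B | V100 cluster (size not stated in the GPT-3 paper) | Intra-layer and inter-layer model parallelism, degrees not disclosed; NVIDIA's follow-up study used TP=8 + PP=64 on 3,072 A100 for a 1T parameter GPT | Narayanan et al., 2021 |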
| Megatron-Turing NLG | 530B | 2,240 A100 | DP=8 + TP=8 + PP=35 | Korthikanti 2022 |
| PaLM (Google, 2022) | 540B | 6,144 TPU v4 | 256-way FSDP, 12-way model parallelism per pod, 2-pod data parallel | Chowdhery et al., 2022 |
| Mixtral 8x7B | 47B (13B active) | Not disclosed | Not disclosed (the paper discusses expert parallelism for efficient execution) | Jiang et al., 2024 |
| Llama 3 405B | 405B | Up to 16,384 H100 | DP (via FSDP) + TP=8 + PP=16, plus CP=16 for long context training | Meta AI, 2024 |

## Communication primitives

Every partitioning strategy ultimately reduces to a small set of collective and point to point operations, usually provided by NVIDIA's NCCL library on GPU clusters or by the XLA compiler and runtime on TPUs.

| Primitive | Use case |
|---|---|
| all-reduce | Gradient averaging in [data parallelism](/wiki/data_parallelism); activation sync inside a tensor parallel block |
| all-gather | Materialize sharded parameters in ZeRO-3/FSDP; gather sharded activations |
| reduce-scatter | Distribute partial gradient sums in ZeRO-2/FSDP |
| all-to-all | Token routing in expert parallelism; transposing sharded layouts |
| send / recv | Stage-to-stage hand off in pipeline parallelism; KV rotation in ring attention |
| broadcast | Push initial parameters from rank 0 |

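A classic identity worth memorizing is that an all-reduce is mathematically equal to a reduce-scatter followed by an all-gather. ZeRO-2 exploits this to move the same total bytes as a normal data parallel all-reduce while keeping gradients and optimizer state sharded: the reduce-scatter leaves each rank with only the gradient shard it owns, and the all-gather reassembles the updated parameters. ZeRO-3 and FSDP separate the two halves further in time, gathering parameters only when a layer needs them.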
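The identity can be checked with a semantics-only simulation, where each inner list is one rank's buffer. This is not how NCCL implements the collectives (it pipelines chunks around rings or trees); it only shows that the composition produces the same result, and it assumes the buffer length divides evenly by the number of ranks.

```python
from typing import List


def reduce_scatter(inputs: List[List[int]]) -> List[List[int]]:
    """Rank r ends up with the elementwise sum over all ranks of chunk r."""
    n = len(inputs)
    chunk = len(inputs[0]) // n
    return [[sum(buf[r * chunk + j] for buf in inputs) for j in range(chunk)]
            for r in range(n)]


def all_gather(shards: List[List[int]]) -> List[List[int]]:
    """Every rank receives the concatenation of all ranks' shards."""
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]


def all_reduce(inputs: List[List[int]]) -> List[List[int]]:
    """Every rank receives the elementwise sum over all ranks."""
    total = [sum(vals) for vals in zip(*inputs)]
    return [list(total) for _ in inputs]


buffers = [[10 * r + i for i in range(8)] for r in range(4)]  # 4 ranks, 8 values each
print(reduce_scatter(buffers))  # each rank holds only 2 summed values
print(all_gather(reduce_scatter(buffers)) == all_reduce(buffers))
```

Between the two halves each rank holds only `1/N` of the reduced result, which is exactly the window ZeRO uses to keep optimizer state sharded.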

## Hardware considerations

The partitioning strategy must match the cluster topology, because the cost of a collective scales with the slowest link it crosses.

Inside a node, NVIDIA NVLink and NVSwitch deliver 900 GB/s of bidirectional bandwidth per H100, roughly an order of magnitude more than a single InfiniBand port. Tensor parallelism, which issues two all-reduces per transformer layer in the forward pass (and two more in the backward pass), is therefore almost always confined to one node so that these collectives stay on NVLink.

Between nodes, InfiniBand HDR or NDR and RDMA over Converged Ethernet (RoCE) provide 200 to 400 Gb/s per port. Pipeline parallel send and receive and FSDP all-gathers are placed here because they either move less data per step or can be overlapped with compute.

Google's TPU v4 pods take a different approach: their 3D torus interconnect with reconfigurable optical circuit switches gave PaLM training enough bandwidth to skip pipeline parallelism entirely and rely on 12-way model parallelism plus 256-way FSDP within each pod.

Getting these mappings wrong is one of the most common causes of poor scaling. A tensor parallel group that crosses a node boundary often performs worse than a smaller group that stays inside one DGX, even though the larger group has more aggregate memory.
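A rough cost model shows why link speed dominates tensor parallel placement. The sketch below uses the standard bandwidth-only estimate for a ring all-reduce, in which each rank sends and receives `2 * (N - 1) / N` times the message size, and ignores latency, congestion, and NCCL's choice of algorithm. The 256 MB message and the per-direction bandwidths (about 450 GB/s for H100 NVLink, about 50 GB/s for one 400 Gb/s NDR port) are illustrative assumptions.

```python
def ring_all_reduce_seconds(message_bytes: float, ranks: int,
                            link_gb_per_s: float) -> float:
    """Bandwidth-only time for a ring all-reduce.

    link_gb_per_s is per-direction bandwidth in GB/s (1e9 bytes).
    """
    if ranks < 2:
        return 0.0
    return 2 * (ranks - 1) / ranks * message_bytes / (link_gb_per_s * 1e9)


msg = 256e6  # hypothetical activation all-reduce inside one tensor parallel block
nvlink = ring_all_reduce_seconds(msg, 8, link_gb_per_s=450)
ib = ring_all_reduce_seconds(msg, 8, link_gb_per_s=50)
print(f"NVLink: {nvlink * 1e3:.2f} ms, InfiniBand: {ib * 1e3:.2f} ms")
```

Because a ring is only as fast as its slowest hop, a tensor parallel group with even one cross-node link pays close to the InfiniBand figure on every block, which is the failure mode described above.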

## Frameworks and libraries

| Framework | Strategies supported | Notes |
|---|---|---|
| [PyTorch](/wiki/pytorch) DDP | Data parallelism | Built into `torch.nn.parallel`; uses NCCL all-reduce |
| PyTorch [FSDP](/wiki/fsdp) | ZeRO-3 style sharded data parallelism | Native PyTorch; integrated with torch.compile |
| PyTorch DTensor and DeviceMesh | TP, PP, DP composition | Modern API that torchtitan builds on |
| [DeepSpeed](/wiki/deepspeed) (Microsoft) | ZeRO-1/2/3, ZeRO-Offload, pipeline | Originated ZeRO; widely used with Hugging Face Accelerate |
| [Megatron-LM](/wiki/megatron-lm) (NVIDIA) | TP, PP, sequence parallelism, expert parallelism | Reference implementation for transformer training |
| Megatron-Core | Library form of Megatron-LM | Used by NeMo and several open recipes |
| [JAX](/wiki/jax) jit + shard_map | Automatic SPMD partitioning via GSPMD (jit); manual per-device SPMD code (shard_map) | Underlies Flax; used by MaxText, T5X, and PaLM training |
| [TensorFlow](/wiki/tensorflow) Mesh-TF and DTensor | TP, PP, DP | Used in earlier Google training stacks |
| Hugging Face Accelerate | DDP, FSDP, DeepSpeed wrapper | One-line config for the common cases |
| Colossal-AI | DP, TP, PP, ZeRO, sequence parallelism | Open source training stack from HPC-AI Tech |
| vLLM | TP, PP, expert and decode context parallelism | Inference oriented; shards KV cache too |
| TensorRT-LLM and Text Generation Inference (TGI) | TP, PP | Production inference servers |

## Tradeoffs

No strategy is free. Each one trades along a different axis of the memory, compute, communication triangle.

- Data parallelism scales best in throughput but does nothing for the memory wall.
- Tensor parallelism slashes per device memory and activation footprint but pays heavy intra layer communication, which only NVLink class fabrics make tolerable.
- Pipeline parallelism shifts memory cost onto activation stashes and introduces the pipeline bubble. More micro batches per step and interleaved schedules shrink the bubble but do not eliminate it, while 1F1B mainly bounds how many activation stashes each stage holds.
- ZeRO-3 and FSDP cut memory in proportion to the data parallel size, but every layer pays an all-gather, so step time can rise if the gathers are not overlapped with compute.
- Expert parallelism lets total parameter count grow far beyond active parameter count, but the all-to-all router becomes the bottleneck once experts are scattered across many nodes.
- Sequence and context parallelism unlock long context training but add communication and implementation complexity, so they are usually introduced only after the simpler memory wins from FSDP and tensor parallelism are exhausted.

Gradient checkpointing, also called activation recomputation, is an orthogonal technique that trades extra forward compute for lower activation memory. It is almost always combined with the strategies above. Mixed precision in BF16 or FP16 (and increasingly FP8 on Hopper class GPUs) halves or quarters the bytes moved by every collective relative to FP32 and is essentially mandatory at scale. Parameter and optimizer offload to CPU memory (DeepSpeed ZeRO-Offload) or NVMe (ZeRO-Infinity) is a last resort that lets very large models fit on small clusters at the cost of moving bytes over PCIe.

## Inference time partitioning

The same ideas apply to serving large models, with one twist: inference workloads care about latency and KV cache memory, not gradient throughput. Tensor parallelism is by far the most common strategy for serving multi GPU models in [vLLM](/wiki/vllm), TensorRT-LLM, and Text Generation Inference. Because the weights are a fixed cost, sharding them across `tp_size` GPUs leaves a disproportionately large share of the added HBM for the KV cache, so cache capacity, and with it the achievable batch size and token throughput, can grow faster than linearly with tensor parallel size. For very long contexts, vLLM and others now support context parallelism that shards the KV cache along the sequence axis, mirroring ring attention. Pipeline parallelism is used at inference for models too large to fit even after tensor parallelism, but it does not reduce single request latency, since every token still passes through all stages in order and pays the inter-stage transfers.
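The superlinear growth of KV cache capacity follows from subtracting a fixed weight footprint from a growing pool of HBM. The sketch below assumes a hypothetical 70B parameter model with BF16 weights, 80 layers, 8 KV heads of dimension 128, and 80 GB GPUs, and it ignores activations, runtime workspace, and fragmentation, so real servers fit noticeably fewer tokens.

```python
def kv_bytes_per_token(layers: int, kv_heads: int, head_dim: int,
                       bytes_per_value: int = 2) -> int:
    """Keys plus values (hence the factor 2) across all layers for one token."""
    return 2 * layers * kv_heads * head_dim * bytes_per_value


def kv_cache_tokens(num_params: float, tp: int, hbm_gb: float = 80.0,
                    layers: int = 80, kv_heads: int = 8, head_dim: int = 128) -> int:
    """Tokens of KV cache that fit across `tp` GPUs after loading BF16 weights.

    Assumes the weights and KV heads split evenly across tensor parallel ranks.
    """
    weight_bytes = 2 * num_params
    free_bytes = tp * hbm_gb * 1e9 - weight_bytes
    if free_bytes <= 0:
        return 0
    return int(free_bytes // kv_bytes_per_token(layers, kv_heads, head_dim))


for tp in (1, 2, 4, 8):
    print(tp, kv_cache_tokens(70e9, tp))
```

Under these assumptions, doubling from 2 to 4 GPUs multiplies the cache by about nine, because the first 140 GB of HBM is consumed by weights no matter how many GPUs share them.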

## Choosing a strategy

A practical rule of thumb for new training jobs:

1. If the model fits comfortably on one device, use plain data parallelism (DDP).
2. If the model fits but the optimizer state does not, switch to FSDP or DeepSpeed ZeRO.
3. If a single layer is too large for one device, add tensor parallelism inside the node (TP up to the number of NVLink connected GPUs).
4. If the whole model still does not fit across one node, add pipeline parallelism across nodes.
5. If you need more total parameters without raising per token compute (the active parameter count), consider an MoE design with expert parallelism.
6. If sequences are very long, layer in sequence parallelism, then ring/context parallelism.
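The ladder can be written down as a small decision helper. Everything here is hypothetical: `JobProfile` and `suggest_strategies` are illustrative names, the boolean fields compress what would really be memory and throughput measurements, and the output is a starting point for profiling, not a configuration any framework accepts.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class JobProfile:
    states_fit_one_device: bool          # weights + gradients + optimizer state fit on one GPU
    largest_layer_fits: bool             # the biggest single layer fits on one GPU
    model_fits_one_node: bool            # the sharded model fits within one NVLink node
    want_more_params_same_compute: bool  # grow total params at fixed per-token compute
    long_sequences: bool                 # context length dominates activation memory


def suggest_strategies(job: JobProfile) -> List[str]:
    """Walk the rule-of-thumb ladder and return the parallelism axes to add, in order."""
    plan = ["DDP"] if job.states_fit_one_device else ["FSDP/ZeRO"]
    if not job.largest_layer_fits:
        plan.append("TP within node")
    if not job.model_fits_one_node:
        plan.append("PP across nodes")
    if job.want_more_params_same_compute:
        plan.append("MoE + EP")
    if job.long_sequences:
        plan.append("SP, then CP")
    return plan


print(suggest_strategies(JobProfile(True, True, True, False, False)))
print(suggest_strategies(JobProfile(False, False, False, False, True)))
```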

This ladder roughly mirrors how parallelism is layered in Megatron-LM, DeepSpeed, and the Llama 3 training setup, although production configurations are ultimately tuned by profiling rather than by rule.

## The dataset partitioning meaning

In classical machine learning, the term partitioning strategy is sometimes used in a different sense to describe how a dataset is split into training, validation, and test subsets. The two main approaches in that sense are the **holdout method**, which carves the dataset into fixed train, validation, and test sets (often in 70/15/15 or 80/10/10 ratios), and **k-fold cross-validation**, which divides the data into k equal folds and rotates which fold is used for validation so that every example contributes to both training and evaluation. Stratified k-fold cross-validation preserves class proportions in each fold and is the standard choice for imbalanced classification tasks. These methods are about evaluation rigor rather than scaling and have no direct relationship to the parallelism strategies described above. They are covered in more detail at [training set](/wiki/training_set), [validation set](/wiki/validation_set), and [test set](/wiki/test_set).
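For the dataset sense, plain k-fold splitting is short enough to write by hand. This is a minimal sketch with contiguous folds and no shuffling or stratification; in practice you would shuffle first, and libraries such as scikit-learn provide `KFold` and `StratifiedKFold` for this. When `n` is not divisible by `k`, the first folds get one extra example.

```python
from typing import List, Tuple


def k_fold_indices(n: int, k: int) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, validation_indices) for each of k contiguous folds."""
    if not 1 < k <= n:
        raise ValueError("need 1 < k <= n")
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    splits, start = [], 0
    for size in sizes:
        val = list(range(start, start + size))
        # Every index outside the validation fold is used for training.
        train = list(range(0, start)) + list(range(start + size, n))
        splits.append((train, val))
        start += size
    return splits


for train, val in k_fold_indices(10, 3):
    print(len(train), val)
```

Each example appears in exactly one validation fold and in the training set of the other `k - 1` folds.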

## Explain like I'm 5

Imagine baking a giant cake that is too big for one oven. You can split the work in different ways. Several friends each bake a smaller cake from the same recipe and share the pictures, then average their notes (data parallelism). One friend bakes the bottom layer, another the middle, another the top, and they pass the cake along (pipeline parallelism). Or you let each friend bake one slice of every layer, then glue the slices back together (tensor parallelism). For really, really big cakes you do all three at once. The tricky part is that whenever the friends have to talk to each other, the cake stops baking, so you want to choose the splits that make them talk as little as possible.

## References

1. Shoeybi, M. et al. (2019). Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism. arXiv:1909.08053.
2. Huang, Y. et al. (2018). GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism. arXiv:1811.06965 (NeurIPS 2019).
3. Narayanan, D. et al. (2019). PipeDream: Generalized Pipeline Parallelism for DNN Training. SOSP 2019.
4. Rajbhandari, S., Rasley, J., Ruwase, O., He, Y. (2020). ZeRO: Memory Optimizations Toward Training Trillion Parameter Models. SC20. arXiv:1910.02054.
5. Zhao, Y. et al. (2023). PyTorch FSDP: Experiences on Scaling Fully Sharded Data Parallel. PVLDB 16(12). arXiv:2304.11277.
6. Korthikanti, V. et al. (2022). Reducing Activation Recomputation in Large Transformer Models. arXiv:2205.05198.
7. Fedus, W., Zoph, B., Shazeer, N. (2022). Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity. JMLR 23. arXiv:2101.03961.
8. Liu, H., Zaharia, M., Abbeel, P. (2023). Ring Attention with Blockwise Transformers for Near-Infinite Context. arXiv:2310.01889.
9. Chowdhery, A. et al. (2022). PaLM: Scaling Language Modeling with Pathways. arXiv:2204.02311.
10. Jiang, A. et al. (2024). Mixtral of Experts. arXiv:2401.04088.
11. Meta AI (2024). The Llama 3 Herd of Models. arXiv:2407.21783.
12. Narayanan, D. et al. (2021). Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM. SC21. arXiv:2104.04473.
13. NVIDIA. NCCL User Guide: Collective Operations. https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html
14. vLLM project. Parallelism and Scaling. https://docs.vllm.ai/en/stable/serving/parallelism_scaling/
15. Rasley, J. et al. (2020). DeepSpeed: System Optimizations Enable Training Deep Learning Models with Over 100 Billion Parameters. KDD 2020.

