Your model doesn't fit on one A100. Before buying more GPUs, try FSDP: sharding the model across 4 GPUs cuts the per-GPU VRAM requirement for weights by roughly 4x.
You’ve hit the wall. Your training script just vomited a CUDA out of memory error, and nvidia-smi confirms your 80GB A100’s memory is maxed out, weeping under the weight of your 70B parameter model. Your first instinct is to open your wallet—maybe another A100? Or perhaps an H100? But before you commit to a five-figure hardware purchase or a cloud bill that looks like a mortgage payment, you need to understand your software options. Throwing more GPUs at a problem without the right parallelization strategy is like trying to cool a server rack with a desk fan. It’s noisy, expensive, and ultimately ineffective.
This guide is for when you’ve graduated from single-GPU .to('cuda') and need to orchestrate a fleet of silicon. We’ll cut through the hype around Distributed Data Parallel (DDP), Fully Sharded Data Parallel (FSDP), and DeepSpeed. You’ll learn which one to use, the exact commands to run them, and how to tell if your expensive multi-GPU setup is actually working or just burning electricity.
The Distributed Training Trifecta: Picking Your Weapon
Choosing a strategy isn’t about finding the "best" one; it’s about matching the tool to the constraint. Your primary constraint is usually VRAM. Let’s break it down.
Distributed Data Parallel (DDP) is your workhorse. It replicates the entire model on every GPU. Each GPU works on a different slice of the batch, computes gradients, and then synchronizes them across all devices. The overhead is low, and scaling is nearly linear—if your model fits on a single GPU with room to spare for its batch size. Think of it for training ResNet-50 on 8x RTX 4090s. The moment your model no longer fits on one GPU, DDP becomes useless, because full replication is its core mechanic.
Fully Sharded Data Parallel (FSDP) is PyTorch’s answer to the memory wall. It shards the model parameters, gradients, and optimizer states across GPUs. During the forward pass, it gathers the parameters needed for each layer, computes, and then discards them. This means the VRAM requirement per GPU is divided by the number of GPUs. Need to train Llama 3.1 70B? With FP16 weights (~140GB), a single A100 80GB is insufficient. With FSDP across 2 GPUs, you need ~70GB per GPU—still tight. Across 4 GPUs, you need ~35GB per GPU, which is comfortable. The trade-off is significant communication overhead.
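To make that arithmetic reusable, here is a minimal sketch (the helper name is ours, not a PyTorch API) for estimating the per-GPU parameter footprint under full sharding:

```python
def fsdp_param_gb(num_params: float, bytes_per_param: int, num_gpus: int) -> float:
    """Per-GPU parameter memory in GB under FSDP full sharding.

    Counts parameters only: activations, gradients, optimizer states, and the
    transient all-gather buffers FSDP allocates per layer are extra.
    """
    return num_params * bytes_per_param / 1e9 / num_gpus

# Llama 3.1 70B in FP16 (2 bytes/param) is ~140 GB of weights:
print(fsdp_param_gb(70e9, 2, 1))  # 140.0 -> too big for one 80 GB A100
print(fsdp_param_gb(70e9, 2, 4))  # 35.0  -> comfortable across 4 GPUs
```

Run the numbers for your own model before picking a strategy; the answer often decides between DDP and FSDP on its own.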
DeepSpeed ZeRO-3 is the nuclear option. Like FSDP, it shards everything (parameters, gradients, optimizer states), but it goes further by also offloading to CPU RAM or NVMe storage if needed. Its ZeRO (Zero Redundancy Optimizer) partitioning scheme is incredibly sophisticated. Use it when FSDP still isn’t enough—think of training a 1T parameter model on a cluster. The complexity is higher, but for massive models, it’s often the only game in town.
Here’s the brutal truth: NVIDIA holds 80%+ GPU market share in AI training workloads (IDC 2025), so these tools are optimized for their stack. Your communication layer will be NCCL, and your performance will live or die by the hardware links between your GPUs.
DDP in Practice: It’s Just Four Lines of Code
DDP’s beauty is in its simplicity. If your model fits, this is all you need. Let’s convert a naive training script.
First, launch your script with torchrun. This handles process spawning and environment variables.
```shell
torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=localhost:29400 train.py

# Or, the practical shortcut for local single-node training:
torchrun --nproc_per_node=4 train.py
```
Now, modify your script. The changes are minimal but critical.
```python
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler


def setup():
    """Initialize the process group. torchrun sets RANK, WORLD_SIZE, LOCAL_RANK."""
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank


def cleanup():
    dist.destroy_process_group()


def main(your_model, your_dataset, num_epochs=10):
    local_rank = setup()

    # 1. Move the model to this process's GPU and wrap it with DDP
    model = your_model.to(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])

    # 2. Use a DistributedSampler so each rank sees a disjoint slice of the data
    sampler = DistributedSampler(
        your_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
    )
    dataloader = torch.utils.data.DataLoader(
        your_dataset,
        batch_size=64,
        sampler=sampler,
        num_workers=4,    # Critical for throughput
        pin_memory=True,  # Enables faster host-to-device transfer
    )

    optimizer = torch.optim.Adam(ddp_model.parameters())
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Reshuffles the shards differently each epoch
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(local_rank), labels.to(local_rank)
            outputs = ddp_model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            optimizer.zero_grad()
            loss.backward()  # Gradients are all-reduced (averaged) across processes here
            optimizer.step()
    cleanup()


if __name__ == "__main__":
    # your_model and your_dataset are placeholders for your own objects.
    # torchrun spawns one copy of this script per GPU, so no mp.spawn needed.
    main(your_model, your_dataset)
```
That’s the core of it. The four key lines are: initializing the process group, wrapping the model with DDP, using the DistributedSampler, and calling sampler.set_epoch. If you see only 1 GPU used during training after this, you likely forgot the torchrun launch command or didn’t wrap the model.
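Before the first epoch, it is worth confirming each process knows who it is. A tiny sketch (the helper is ours, not a PyTorch API) that reads the environment variables torchrun sets:

```python
import os


def describe_rank(env=os.environ):
    """Summarize the torchrun-provided identity of this process."""
    rank = int(env.get("RANK", 0))          # global rank across all nodes
    world = int(env.get("WORLD_SIZE", 1))   # total number of processes
    local = int(env.get("LOCAL_RANK", 0))   # GPU index on this node
    return f"rank {rank}/{world} (local {local})"


# Simulated torchrun environment for process 2 of 4 on one node:
print(describe_rank({"RANK": "2", "WORLD_SIZE": "4", "LOCAL_RANK": "2"}))
# rank 2/4 (local 2)
```

If every process prints `rank 0/1`, torchrun never ran: you launched with plain `python`.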
When DDP Fails: Implementing FSDP for Model Sharding
When your model’s footprint exceeds VRAM, it’s FSDP time. The concept is "shard, gather, compute, discard." Here’s how to wrap a Hugging Face transformers model directly with PyTorch’s FSDP API (accelerate can generate this setup for you, but the explicit version is clearer).
```python
import functools

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from transformers import AutoModelForCausalLM


def setup_fsdp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Auto-wrap policy: shard any module with over 100M parameters.
    # The policy is passed as a callable, hence functools.partial.
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=100_000_000,
    )

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-70B")

    # FSDP shards parameters, gradients, and optimizer states across all ranks
    fsdp_model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
    )

    optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-5)
    # The optimizer now only holds shards of the state, slashing memory use.
    # The training loop looks the same, but forward/backward trigger all-gathers.
    return fsdp_model, optimizer
```
The `auto_wrap_policy` is crucial. You don’t want to shard at every tiny `nn.Linear` layer; the communication cost would be catastrophic. You shard at major boundaries, typically per transformer block (PyTorch also ships a `transformer_auto_wrap_policy` for exactly this). For a 70B model in FP16 (~140GB), sharding across 4 A100 80GB GPUs reduces the per-GPU parameter footprint to ~35GB, leaving ample room for gradients and optimizer states.
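A quick sanity check on that 100M threshold, under the rough assumption that a 70B Llama-style model spreads its weights over 80 transformer blocks (embeddings ignored; exact layer counts vary):

```python
def params_per_block(total_params: float, num_layers: int) -> float:
    """Rough per-transformer-block parameter count, ignoring embeddings."""
    return total_params / num_layers


# ~70B parameters over ~80 decoder blocks:
print(params_per_block(70e9, 80))  # 875000000.0 -> each block is ~875M params
```

Every block lands well above 100M, so the size-based policy makes each block its own FSDP unit, which is exactly the granularity you want.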
DeepSpeed ZeRO-3: The Memory Efficiency Frontier
When FSDP isn’t enough, you enter DeepSpeed territory. Its configuration is JSON-based and granular. Here’s a minimal ds_config.json for ZeRO-3 with CPU offload:
```json
{
  "train_batch_size": 32,
  "gradient_accumulation_steps": 4,
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "fp16": {
    "enabled": true,
    "loss_scale_window": 100
  }
}
```
Launch with:
```shell
deepspeed --num_gpus=4 train.py --deepspeed ds_config.json
```
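One gotcha worth checking up front: DeepSpeed requires that `train_batch_size` equal the per-GPU micro-batch times `gradient_accumulation_steps` times the number of GPUs, and it will refuse to start otherwise. A small sketch (helper name is ours) to back out the implied micro-batch from the config above:

```python
def micro_batch_per_gpu(train_batch_size: int, grad_accum: int, num_gpus: int) -> int:
    """DeepSpeed invariant: train_batch_size = micro_batch * grad_accum * num_gpus."""
    micro, remainder = divmod(train_batch_size, grad_accum * num_gpus)
    assert remainder == 0, "train_batch_size must divide evenly across steps and GPUs"
    return micro


# The ds_config.json above (batch 32, accumulation 4) launched on 4 GPUs:
print(micro_batch_per_gpu(32, 4, 4))  # 2 samples per GPU per micro-step
```

If that micro-batch still OOMs, raise `gradient_accumulation_steps` rather than shrinking the global batch.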
In your script, you initialize the DeepSpeed engine:
```python
import deepspeed

# `args` is your parsed argparse namespace; `model` is your torch.nn.Module
model_engine, optimizer, _, _ = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)
# Then use model_engine for the forward pass, model_engine.backward(loss),
# and model_engine.step() in place of the usual loss.backward()/optimizer.step().
```
DeepSpeed’s ZeRO-3 will shard parameters across GPUs and optionally page them to CPU RAM, allowing you to train models vastly larger than total GPU VRAM. The cost is latency.
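To see why stage 3 is the last resort before offload, the ZeRO paper’s memory model can be sketched in a few lines. Assuming mixed-precision Adam (2 bytes of FP16 weights, 2 bytes of FP16 gradients, and 12 bytes of FP32 optimizer state per parameter), here is a rough per-GPU estimate, not DeepSpeed’s exact accounting:

```python
def zero_per_gpu_gb(num_params: float, num_gpus: int, stage: int) -> float:
    """Approximate per-GPU GB for params + grads + Adam state under ZeRO.

    Stage 0 (plain data parallel): everything replicated.
    Stage 1: optimizer states sharded.
    Stage 2: optimizer states + gradients sharded.
    Stage 3: optimizer states + gradients + parameters sharded.
    Activations and communication buffers are extra.
    """
    p, g, o = 2.0, 2.0, 12.0  # bytes per parameter
    if stage >= 1:
        o /= num_gpus
    if stage >= 2:
        g /= num_gpus
    if stage >= 3:
        p /= num_gpus
    return num_params * (p + g + o) / 1e9


# 70B parameters across 4 GPUs:
print(zero_per_gpu_gb(70e9, 4, 0))  # 1120.0 GB -> replication is hopeless
print(zero_per_gpu_gb(70e9, 4, 3))  # 280.0 GB  -> still needs offload at this scale
```

Even fully sharded, 4 GPUs cannot hold 70B-parameter Adam training in VRAM, which is precisely when the CPU offload entries in the config above earn their keep.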
The Communication Bottleneck: NVLink vs. PCIe
Your multi-GPU strategy’s efficiency is dictated by the slowest link. This is where your hardware investment matters.
| Interconnect | Bandwidth (Theoretical) | Real-World Impact |
|---|---|---|
| NVLink (Gen 3) | 600 GB/s | Near-linear scaling for DDP/FSDP on same node. |
| PCIe Gen4 x16 | 32 GB/s | Can bottleneck gradient sync for large models. |
| PCIe Gen3 x16 | 16 GB/s | Often the hidden limiter in "cheap" multi-GPU rigs. |
NVLink bandwidth (600 GB/s) versus PCIe Gen4 x16 (32 GB/s) is roughly an 18x gap in multi-GPU communication bandwidth. If your GPUs are connected only via PCIe (common in workstations with 4x RTX 4090s, a card that has no NVLink at all), the all-reduce operations in DDP and the all-gather/reduce-scatter operations in FSDP will saturate that bus. You’ll see gaps in GPU utilization where devices sit idle waiting for gradients.
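You can estimate the damage before buying anything. A ring all-reduce pushes roughly 2(N-1)/N bytes through each GPU’s link per byte of payload, so a back-of-the-envelope lower bound on sync time (illustrative helper, ignoring latency and compute/communication overlap) looks like:

```python
def allreduce_seconds(payload_gb: float, num_gpus: int, link_gb_per_s: float) -> float:
    """Lower-bound time for one ring all-reduce over the given link bandwidth."""
    traffic_gb = 2 * (num_gpus - 1) / num_gpus * payload_gb  # GB through each link
    return traffic_gb / link_gb_per_s


# Syncing 14 GB of FP16 gradients (a 7B model) across 4 GPUs:
print(allreduce_seconds(14, 4, 600))  # 0.035 s on NVLink-class links
print(allreduce_seconds(14, 4, 32))   # 0.65625 s on PCIe Gen4
```

Half a second of pure communication per step on PCIe is why a "cheap" 4x 4090 rig can lose to fewer, better-connected GPUs.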
Check your topology:
```shell
nvidia-smi topo -m
```
This matrix shows how your GPUs are connected. Look for `NV#` entries (NVLink, where # is the number of link bonds, e.g. `NV4`) or `PHB` (PCIe Host Bridge). If you see all `PHB`, your scaling efficiency will plummet after 2-4 GPUs.
Gradient Accumulation: Simulating Larger Batch Sizes Without More VRAM
Sometimes your model fits, but your desired batch size doesn’t. Gradient accumulation is a clever hack: you run N forward/backward passes with a small batch, accumulating gradients, before calling optimizer.step().
```python
accumulation_steps = 4  # Simulate a 4x larger batch

optimizer.zero_grad()
for i, (inputs, labels) in enumerate(dataloader):
    loss = model(inputs, labels)          # model returns the loss directly here
    loss = loss / accumulation_steps      # Scale so accumulated gradients average correctly
    loss.backward()                       # Gradients accumulate in .grad
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
This is essential for DDP when you’re VRAM-bound on per-GPU batch size, and the larger effective batch also reduces gradient noise, smoothing the loss curve. One caveat: under DDP, every `loss.backward()` triggers a gradient all-reduce, so wrap the non-final micro-steps in `ddp_model.no_sync()` to pay the communication cost only once per optimizer step. Combine accumulation with DDP, and you have a global batch size of `per_gpu_batch * accumulation_steps * world_size`.
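The global-batch arithmetic above, pinned down as a helper (naming is ours), since it determines how you should scale the learning rate:

```python
def global_batch_size(per_gpu_batch: int, accumulation_steps: int, world_size: int) -> int:
    """Effective batch size per optimizer step under DDP + gradient accumulation."""
    return per_gpu_batch * accumulation_steps * world_size


# 16 samples/GPU, 4 accumulation steps, 8 GPUs:
print(global_batch_size(16, 4, 8))  # 512
```

If your reference recipe was tuned at batch 128 and you are now at 512, rules of thumb like linear learning-rate scaling come into play.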
The Acid Test: Linear Scaling and Efficiency
You’ve set everything up. Is it working? Or are you paying for 4 GPUs to get the throughput of 2? You need to measure.
- Establish a Baseline: Note the samples/second on a single GPU.
- Run Multi-GPU: Measure samples/second with 2, 4, and 8 GPUs.
- Calculate Scaling Efficiency: `(Throughput_N_GPUs / (N * Throughput_1_GPU)) * 100`.
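That formula, as a helper you can drop into your logging (naming is ours):

```python
def scaling_efficiency(throughput_n: float, throughput_1: float, n_gpus: int) -> float:
    """Percentage of ideal linear scaling achieved with n_gpus."""
    return throughput_n / (n_gpus * throughput_1) * 100


# Single GPU does 1,800 img/s; 4 GPUs together do 5,040 img/s:
print(round(scaling_efficiency(5040, 1800, 4), 1))  # 70.0
```

Anything above ~90% is healthy on a well-connected node; sustained dips below ~75% mean one of the bottlenecks below is eating your money.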
If you’re at 95% with 2 GPUs but 70% with 4, you have a bottleneck. Likely culprits:
- CPU-bound DataLoading: Your GPUs are starved. Fix: increase `DataLoader` `num_workers`, use `pin_memory=True`, or move preprocessing to the GPU.
- Communication Overhead: Especially with FSDP on PCIe. Verify with `nvtop`: high `RX`/`TX` usage during what should be compute time.
- Thermal Throttling: Your hardware is slowing down. Monitor with `watch -n 1 nvidia-smi --query-gpu=temperature.gpu,power.draw,clocks.current.graphics --format=csv`. If you see temperatures above 85°C, you need to address cooling. Fix: check the fan curve in `nvidia-settings`, clean the GPU heatsink, verify case airflow, or set a power limit with `sudo nvidia-smi -pl 320` (for a 4090).
Here’s a real performance comparison for inference, showing why model size and hardware dictate strategy:
| Hardware | Model (FP16) | Throughput (tok/s) | Best Parallel Method |
|---|---|---|---|
| RTX 4090 | Llama 3.1 70B | 28 | FSDP (Model must be sharded) |
| A100 80GB | Llama 3.1 70B | 58 | FSDP (140GB FP16 exceeds one 80GB card) |
| H100 80GB | Llama 3.1 70B | 95 | FSDP (140GB FP16 exceeds one 80GB card) |
(Source: Internal benchmarks, FP16 inference)
For training, the gap widens. ResNet-50 training throughput: an 8x A100 DGX reaches ~12,000 img/s versus ~1,800 img/s on a single RTX 4090, roughly 6.7x. Treat that multiple with care, since the baselines are different cards; the honest takeaway is that even well-scaling workloads land below the naive N-GPU multiple once communication and input-pipeline overhead bite.
Next Steps: From Script to Production
You now have the blueprints. Your next move is to instrument and iterate.
- Profile Relentlessly: Use `torch.profiler` to see the exact timeline of forward pass, backward pass, and NCCL communication. Find the gaps.
- Benchmark Your Hardware: Run an NCCL test to see your real inter-GPU bandwidth: `torch.distributed.init_process_group(...)` followed by a simple all-reduce benchmark.
- Consider the Cloud vs. On-Prem Calculus: Cloud A100 (AWS p4d.24xlarge): $32.77/hr for 8x A100 80GB vs. on-premise ROI break-even at 18 months. If your training job is a one-off, cloud is flexible. If you’re training continuously, hardware pays for itself, provided you can handle maintenance.
- Embrace the Tooling: Use `nvtop` for a better live view than `nvidia-smi`. Use `accelerate config` from Hugging Face to quickly generate distributed training setups. Let `continue.dev` or Copilot help you write the boilerplate.
The goal isn’t to use the most complex tool, but the simplest one that solves your VRAM problem. Start with DDP. If you’re out of memory, move to FSDP. If you’re still out of memory, turn to DeepSpeed. At each step, measure your scaling efficiency. Your GPUs are expensive; make sure they’re working, not just warming your room.