Use Ray Cluster for Distributed AI Training in Python

Scale PyTorch model training across multiple nodes with Ray Cluster. Set up workers, distribute data, and cut wall-clock training time.

Problem: Training AI Models Takes Too Long on One Machine

Your PyTorch model trains for hours — or days — on a single GPU. Adding more GPUs to one box has limits. You need to train across multiple machines without rewriting all your code.

You'll learn:

  • How to set up a Ray Cluster with head and worker nodes
  • How to distribute PyTorch training with ray.train
  • How to monitor jobs and handle failures gracefully

Time: 30 min | Level: Intermediate


Why This Happens

PyTorch's native DistributedDataParallel (DDP) can span GPUs and machines, but you have to launch and coordinate one process per GPU on every node yourself (for example with torchrun), which gets painful as the cluster grows. Ray wraps that complexity: you write normal Python, and Ray handles worker coordination, data sharding, and fault tolerance.
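For contrast, here is a minimal sketch of the per-process setup that raw DDP expects. It assumes the process was launched by torchrun, which exports RANK and WORLD_SIZE (plus MASTER_ADDR and MASTER_PORT) to each process. It falls back to a single local CPU process with the gloo backend so it runs on a laptop. setup_raw_ddp is a hypothetical name, and port 29500 is an assumed default.

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_raw_ddp(in_features: int, out_features: int) -> DDP:
    """Initialize a process group by hand and wrap a model in DDP (hypothetical helper)."""
    # torchrun exports RANK/WORLD_SIZE per process; default to one local process
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    # Every process must agree on where rank 0 listens (assumed local here)
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    # gloo runs on CPU; GPU clusters would normally use "nccl"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    model = nn.Linear(in_features, out_features)
    return DDP(model)  # CPU module, so no device_ids are needed
```

Usage: `model = setup_raw_ddp(8, 2)`, then `model(torch.randn(4, 8))`, and call `dist.destroy_process_group()` when done. On a real cluster, you would run a torchrun command on every node with matching settings and a different node rank. For example: `torchrun --nnodes=2 --nproc_per_node=4 --node_rank=0 --master_addr=192.168.1.10 --master_port=29500 train.py`. You would also switch to nccl and pin each process to its LOCAL_RANK GPU. Keeping those commands in sync across machines is the bookkeeping Ray takes over.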

Common symptoms without distributed training:

  • Out-of-memory errors when you push the batch size up on a single GPU
  • Training loops that take 12+ hours for production datasets
  • No easy way to resume after node failure

Solution

Step 1: Install Ray and Set Up the Cluster

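Install the same Ray and Python versions on every node (head and workers). The default extra pulls in the dashboard used later in this guide:

pip install "ray[default,train]" torch torchvision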

On your head node (the machine you'll connect to):

ray start --head --port=6379

On each worker node:

# Replace with your head node's IP
ray start --address='192.168.1.10:6379'

Expected: Each worker prints "Ray runtime started." and connects to the head.

If it fails:

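  • "Connection refused": Check firewalls. Port 6379 must be reachable from every worker, and Ray also opens other ports between nodes (object manager, node manager, and worker ports), so on a private network it's simplest to allow all traffic between cluster nodes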
  • "RuntimeError: address already in use": Run ray stop first, then retry
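To tell a firewall problem apart from a Ray problem, you can run a quick standard-library check from a worker node. It shows whether the head's port accepts TCP connections at all. This is a minimal sketch: port_reachable is a hypothetical helper, not part of Ray.

```python
import socket


def port_reachable(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        # create_connection resolves the host and completes the TCP handshake
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts
        return False
```

Usage from a worker: `port_reachable("192.168.1.10", 6379)`, with your head node's IP in place of the example address. A True result only proves 6379 is open. Ray's other inter-node ports can still be blocked.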

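Ray dashboard at http://head-node-ip:8265: every node should show "ALIVE". By default, the dashboard listens only on the head's localhost. To open it from another machine on a trusted network, start the head with ray start --head --port=6379 --dashboard-host=0.0.0.0.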
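If you'd rather check from a script than the dashboard, ray.nodes() returns one dict per node, with keys such as "Alive", "NodeManagerAddress", and "Resources". The sketch below summarizes those records. summarize_nodes is a hypothetical helper, and it takes plain dicts so the summary logic can be tried without a cluster.

```python
from typing import Any


def summarize_nodes(nodes: list[dict[str, Any]]) -> dict[str, Any]:
    """Count alive/dead nodes and total GPUs from ray.nodes()-style records."""
    alive = [n for n in nodes if n.get("Alive")]
    return {
        "alive": len(alive),
        "dead": len(nodes) - len(alive),
        # "Resources" maps resource names (e.g. "CPU", "GPU") to per-node totals
        "gpus": sum(n.get("Resources", {}).get("GPU", 0) for n in alive),
        "addresses": sorted(n.get("NodeManagerAddress", "?") for n in alive),
    }
```

Usage on any machine in the cluster: `import ray`, then `ray.init(address="auto")`, then `print(summarize_nodes(ray.nodes()))`. You should see one alive entry per machine you started.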


Step 2: Write a Trainable Function

Ray Train wraps your training loop. The key is using ray.train primitives instead of manual torch.distributed calls.

import ray
from ray import train
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def train_loop_per_worker(config):
    # Ray handles process group init — no manual dist.init_process_group needed
    model = nn.Linear(config["input_size"], config["output_size"])
    
    # Wrap model for distributed training across workers
    model = train.torch.prepare_model(model)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.MSELoss()

    # Synthetic data for the demo; in practice, load your real dataset here
    X = torch.randn(1000, config["input_size"])
    y = torch.randn(1000, config["output_size"])
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=config["batch_size"])
    # Adds a DistributedSampler so each worker reads its own shard, and moves batches to the worker's device
    loader = train.torch.prepare_data_loader(loader)

    for epoch in range(config["epochs"]):
        total_loss = 0
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            pred = model(batch_X)
            loss = loss_fn(pred, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Report metrics back to Ray: shown in the console output and returned in result.metrics
        train.report({"loss": total_loss / len(loader), "epoch": epoch})

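Why prepare_model and prepare_data_loader: prepare_model moves the model to the worker's device and, when there is more than one worker, wraps it in DistributedDataParallel so gradients are averaged across workers. prepare_data_loader adds a DistributedSampler so each worker reads a different shard, and moves each batch to the device. If you skip them, you'll hit one of three problems: device-mismatch errors (CPU tensors fed to a GPU model), workers training independent, unsynchronized models, or every worker iterating over the full dataset.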
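To make the sharding concrete, here is a pure-Python sketch that mirrors how PyTorch's DistributedSampler splits indices with shuffle=False and drop_last=False. This is the sampler prepare_data_loader adds, but the code below is an illustration, not Ray's or PyTorch's implementation. shard_indices is a hypothetical name.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Indices one worker sees, mirroring DistributedSampler(shuffle=False, drop_last=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating from the start so every rank gets the same number of samples
    padding = total_size - dataset_len
    if padding > 0:
        repeats = math.ceil(padding / dataset_len)
        indices += (indices * repeats)[:padding]
    # Rank r takes every num_replicas-th index, starting at r
    return indices[rank:total_size:num_replicas]


if __name__ == "__main__":
    shards = [shard_indices(1000, 4, r) for r in range(4)]
    print([len(s) for s in shards])         # [250, 250, 250, 250]
    print(math.ceil(len(shards[0]) / 64))   # 4 batches per worker at batch_size=64
```

With 1000 samples and 4 workers, each worker sees 250 samples, so len(loader) is 4 batches per worker instead of 16 on a single process. The padding step explains why an uneven dataset repeats a few samples rather than giving one worker fewer.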


Step 3: Launch the Distributed Job

ray.init(address="auto")  # Connects to your running cluster

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={
        "input_size": 128,
        "output_size": 10,
        "lr": 1e-3,
        "batch_size": 64,
        "epochs": 10,
    },
    scaling_config=ScalingConfig(
        num_workers=4,           # Total workers: one per GPU across the cluster
        use_gpu=True,            # On CPU-only nodes: set False and remove resources_per_worker
        resources_per_worker={"GPU": 1},
    ),
)

result = trainer.fit()
print(f"Final loss: {result.metrics['loss']:.4f}")

Expected: Ray schedules workers across nodes and streams loss metrics to your terminal.
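batch_size in train_loop_config is a per-worker value, so the effective batch grows with num_workers. The arithmetic below works through the config above (4 workers, batch_size 64, 1000 samples per worker's DataLoader), assuming the data-parallel setup from Step 2. describe_run is a hypothetical helper.

```python
import math


def describe_run(num_workers: int, per_worker_batch: int, dataset_len: int) -> dict[str, int]:
    """What a data-parallel TorchTrainer config means in aggregate (hypothetical helper)."""
    # Each worker's DistributedSampler shard (padded up to an equal size)
    samples_per_worker = math.ceil(dataset_len / num_workers)
    return {
        # Each optimizer step averages gradients over every worker's batch
        "global_batch": per_worker_batch * num_workers,
        "steps_per_epoch": math.ceil(samples_per_worker / per_worker_batch),
        "single_worker_steps": math.ceil(dataset_len / per_worker_batch),
    }


if __name__ == "__main__":
    print(describe_run(num_workers=4, per_worker_batch=64, dataset_len=1000))
```

For the Step 3 config, this gives a global batch of 256 and 4 steps per epoch, versus 16 steps on one worker. Because the global batch changes with the worker count, lr may need retuning when you scale num_workers. Scaling it with the global batch is a common starting heuristic, not a rule.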

Training metrics in the Ray dashboard: metrics update in real time, and each worker's loss converges together.

If it fails:

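  • Workers stuck pending or "No GPUs available": the cluster has fewer GPUs than num_workers requests. Check ray status, verify CUDA drivers on worker nodes with nvidia-smi, or switch to CPU training with use_gpu=False (and remove resources_per_worker)
  • "ObjectStoreFullError": Give each node a larger object store when you start it, e.g. ray start --head --port=6379 --object-store-memory=10000000000 (add the same flag to the ray start --address=... command on workers). ray.init() only accepts object_store_memory when it starts a new local cluster, not when connecting with address="auto"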
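A pre-flight check before trainer.fit() can catch the GPU case early. ray.cluster_resources() returns cluster-wide totals such as {"CPU": 16.0, "GPU": 4.0, ...}, and ray.available_resources() returns what is currently free. gpu_shortfall is a hypothetical helper that compares either dict against what ScalingConfig will request.

```python
def gpu_shortfall(resources: dict[str, float], num_workers: int,
                  gpus_per_worker: float = 1.0) -> float:
    """GPUs missing for the requested workers; 0.0 means the total request fits."""
    needed = num_workers * gpus_per_worker
    available = resources.get("GPU", 0.0)  # key is absent on CPU-only clusters
    return max(0.0, needed - available)
```

Usage: after `ray.init(address="auto")`, call `gpu_shortfall(ray.available_resources(), num_workers=4)`. The check only compares totals. It can't catch a worker that asks for more GPUs than any single node has, so ray status remains the authoritative view.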

Step 4: Save and Load Checkpoints

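Ray Train has built-in checkpointing, so a job can resume from its last saved epoch after a worker or node failure instead of starting over. Automatic restarts are off by default; you turn them on with FailureConfig in the trainer's RunConfig.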

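import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from ray import train
from ray.train import Checkpoint
from ray.train.torch import prepare_data_loader, prepare_model

def train_loop_per_worker(config):
    model = nn.Linear(config["input_size"], config["output_size"])

    # Resume from a checkpoint if one exists. Load weights into the plain model
    # before prepare_model wraps it in DDP, so the saved keys line up.
    state = None
    checkpoint = train.get_checkpoint()
    if checkpoint:
        with checkpoint.as_directory() as ckpt_dir:
            state = torch.load(os.path.join(ckpt_dir, "checkpoint.pt"), map_location="cpu")
        model.load_state_dict(state["model"])

    model = prepare_model(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    loss_fn = nn.MSELoss()
    start_epoch = 0
    if state is not None:
        optimizer.load_state_dict(state["optimizer"])  # casts state to the params' device
        start_epoch = state["epoch"] + 1

    X = torch.randn(1000, config["input_size"])
    y = torch.randn(1000, config["output_size"])
    loader = prepare_data_loader(DataLoader(TensorDataset(X, y), batch_size=config["batch_size"]))

    for epoch in range(start_epoch, config["epochs"]):
        total_loss = 0.0
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        metrics = {"loss": total_loss / len(loader), "epoch": epoch}

        # Every worker must call train.report; only rank 0 attaches a checkpoint
        with tempfile.TemporaryDirectory() as tmpdir:
            ckpt = None
            if train.get_context().get_world_rank() == 0:
                # Unwrap DDP (if present) so the saved keys match the plain model
                raw_model = model.module if hasattr(model, "module") else model
                torch.save(
                    {"model": raw_model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch},
                    os.path.join(tmpdir, "checkpoint.pt"),
                )
                ckpt = Checkpoint.from_directory(tmpdir)
            train.report(metrics, checkpoint=ckpt)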

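Why world_rank == 0: All workers run the same code, and DDP keeps their weights identical, so one copy of the checkpoint is enough. If every worker attached one, Ray would persist duplicate copies into the same checkpoint directory, and workers could overwrite each other's files. The guard applies only to the checkpoint. Every worker must still call train.report each epoch, or the run can stall waiting for the missing reports.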


Verification

After training completes, check the best checkpoint:

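# best_checkpoints holds (Checkpoint, metrics) pairs; pick the one with the lowest loss
best_checkpoint, best_metrics = min(result.best_checkpoints, key=lambda pair: pair[1]["loss"])
print(f"Best checkpoint path: {best_checkpoint.path} (loss {best_metrics['loss']:.4f})")
print(f"Metrics:\n{result.metrics_dataframe[['epoch', 'loss']].tail()}")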

You should see: A path to the saved checkpoint and a table of loss values decreasing over epochs.

# Check cluster health and confirm the training workers released their GPUs
ray status

You should see: every node listed under Active, "(no failures)" under Recent failures, and 0 GPUs in use.


What You Learned

  • Ray wraps torch.distributed so you don't manage process groups manually
  • prepare_model and prepare_data_loader handle DDP and data sharding — both are required
  • Call train.report on every worker, but attach a checkpoint only on world_rank == 0 to avoid duplicate writes
  • ScalingConfig is where you control worker count and GPU allocation

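Limitation: Ray Train adds scheduling and coordination overhead compared with launching raw DDP yourself, which is most noticeable for small models and short jobs. It pays off at scale: 4+ GPUs and datasets large enough that the coordination it saves outweighs that overhead.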

When NOT to use this: Single-GPU training, quick experiments, or jobs that already train fast enough on one machine. Raw PyTorch is faster to iterate on for prototyping.


Tested on Ray 2.9, PyTorch 2.2, Python 3.11, Ubuntu 22.04