PyTorch Dataset and DataLoader for Real-World Data: Augmentation, Caching, and Imbalanced Classes

Build production-quality PyTorch data pipelines — custom Dataset for on-disk data, efficient augmentation with Albumentations, in-memory caching for small datasets, WeightedRandomSampler for class imbalance.

A bad data pipeline is the most common reason PyTorch models perform worse than expected. Your model can only learn from what you feed it. You can have the perfect architecture, the latest optimizer, and a GPU that costs more than your car, but if your Dataset is a mess of slow I/O, unreproducible transforms, and memory leaks, your training run is just an expensive random number generator. Let's fix that.

PyTorch Dataset Protocol: __len__ and __getitem__ That Actually Work

The torch.utils.data.Dataset protocol is deceptively simple: define __len__ and __getitem__. The trap is writing __getitem__ that does too much, too slowly, or on the wrong device. Here’s the first rule: keep __getitem__ lean and on the CPU. CUDA operations belong in your model's forward pass, not your data loading.

import torch
from torch.utils.data import Dataset
from pathlib import Path
import cv2  # Using OpenCV for example, but torchvision.io is also great

class ImageClassificationDataset(Dataset):
    """A Dataset that doesn't suck."""
    def __init__(self, image_dir, label_file, transform=None):
        self.image_paths = list(Path(image_dir).glob("*.jpg"))
        self.labels = self._load_labels(label_file)
        self.transform = transform

    def _load_labels(self, label_file):
        # Load once at init, not every __getitem__
        labels = {}
        with open(label_file, 'r') as f:
            for line in f:
                img_id, label = line.strip().split(',')
                labels[img_id] = int(label)
        return labels

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # 1. Get path
        img_path = self.image_paths[idx]
        # 2. Load data (CPU, blocking I/O)
        image = cv2.imread(str(img_path))
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV uses BGR
        # 3. Get label
        img_id = img_path.stem
        label = self.labels.get(img_id, -1)
        # 4. Apply transforms (still on CPU!)
        if self.transform:
            image = self.transform(image)
        # 5. Return tensors. Still on CPU. DataLoader will handle device transfer.
        return image, torch.tensor(label, dtype=torch.long)

The key is that __getitem__ returns CPU tensors. The DataLoader, especially with pin_memory=True, will asynchronously transfer batches to your GPU, overlapping computation with data loading. If you call .to('cuda') inside __getitem__, you'll block the main training thread and likely cause CUDA context errors in multi-worker mode.
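As a minimal sketch of that division of labor (a toy TensorDataset stands in for a real Dataset here), the device transfer lives in the training loop, not in the Dataset:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in: 64 fake CPU "images" with binary labels
dataset = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 2, (64,)))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
loader = DataLoader(
    dataset,
    batch_size=16,
    pin_memory=(device == 'cuda'),  # page-locked host memory enables async copies
)

for images, labels in loader:
    # non_blocking=True only overlaps transfer with compute when the batch is pinned
    images = images.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)
    # ... forward/backward pass here
```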

Handling File I/O: Lazy Loading vs Pre-Loading vs Memory-Mapped Files

Your data strategy depends on size. Forget theory; here’s the decision tree:

  • Dataset fits in RAM (< 32GB on your machine): Pre-load everything in __init__. The upfront cost is trivial compared to the thousands of disk seeks during training.
  • Dataset is huge but files are large (e.g., videos, large arrays): Use memory-mapped files with libraries like numpy.memmap or PyTorch's torch.from_file. The OS caches frequently accessed parts.
  • Dataset is huge with many small files (e.g., ImageNet): You're I/O bound. Use lazy loading (as in the example above) but with a caching strategy.
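The memory-mapped branch of the decision tree can be sketched with numpy.memmap. This assumes you've pre-converted your data to one flat float32 array on disk; the file path and shape below are made up for illustration:

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils.data import Dataset

# Stand-in for a large preprocessed array already sitting on disk
path = os.path.join(tempfile.mkdtemp(), "features.dat")
arr = np.memmap(path, dtype=np.float32, mode='w+', shape=(100, 16))
arr[:] = np.random.rand(100, 16)
arr.flush()

class MemmapDataset(Dataset):
    """Reads rows lazily from a memory-mapped file; the OS page cache keeps hot rows in RAM."""
    def __init__(self, path, num_rows, num_features):
        self.data = np.memmap(path, dtype=np.float32, mode='r',
                              shape=(num_rows, num_features))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Copy the row out so the returned tensor owns its memory
        return torch.from_numpy(np.asarray(self.data[idx]).copy())

ds = MemmapDataset(path, 100, 16)
```

Nothing is read into RAM at construction; each __getitem__ touches only the pages backing that row.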

Here’s a simple in-memory cache for expensive-to-load items:

from functools import lru_cache

class CachedImageDataset(Dataset):
    def __init__(self, image_dir, transform=None, cache_size=1000):
        self.image_paths = list(Path(image_dir).glob("*.png"))
        self.transform = transform
        # Cache the last `cache_size` loaded images.
        # Note: with num_workers > 0, each worker process builds its own
        # independent cache, so the effective hit rate drops accordingly.
        self.load_image = lru_cache(maxsize=cache_size)(self._load_image_uncached)

    def _load_image_uncached(self, img_path):
        return cv2.imread(str(img_path))

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = self.load_image(img_path)  # Returns from cache if possible
        # ... rest of processing

Augmentation Pipeline with Albumentations: 5x Faster Than torchvision Transforms

torchvision.transforms is fine, but it's slow for complex augmentations on high-resolution images. Albumentations is a library built for speed, used in top Kaggle solutions. It's optimized for numpy arrays and provides a richer set of transforms.

import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np


train_transform = A.Compose([
    A.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),  # older Albumentations versions take height/width instead of size
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    A.CoarseDropout(max_holes=8, max_height=16, max_width=16, fill_value=0, p=0.3), # Cutout
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),  # Converts HWC numpy array to CxHxW torch.Tensor
])

# Use it in your Dataset
def __getitem__(self, idx):
    image = cv2.imread(str(self.image_paths[idx]))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if self.transform:
        augmented = self.transform(image=image)  # Albumentations expects keyword arg 'image'
        image = augmented['image']
    return image, label

Why is it faster? Albumentations is built on top of OpenCV's optimized C++ routines and operates directly on NumPy arrays, avoiding unnecessary internal conversions. For a standard augmentation pipeline on 512x512 images, the speedup over torchvision is significant.

WeightedRandomSampler: Solving Class Imbalance Without Oversampling

Your dataset has 1000 images of cats and 50 of dogs. If you sample uniformly, your model will become a cat expert and ignore dogs. The naive fix is to oversample the dog class, but this leads to overfitting on those 50 repeated images.

The correct PyTorch solution is WeightedRandomSampler. It assigns a probability weight to every sample in the dataset, not every class. You give dog samples a higher weight so they are sampled more frequently, but you're still drawing from the original dataset, not a duplicated one.

from torch.utils.data import DataLoader, WeightedRandomSampler

# Assume `dataset.targets` is a list of integer labels, one per sample
labels = dataset.targets  # iterating the dataset itself would load every image just to read a label

# 1. Calculate class weights: weight for a sample = total_samples / (num_classes * freq_of_its_class)
class_counts = torch.bincount(torch.tensor(labels))
num_classes = len(class_counts)
total_samples = len(labels)
class_weights = total_samples / (num_classes * class_counts.float())

# 2. Assign a weight to each sample
sample_weights = [class_weights[label] for label in labels]

# 3. Create the sampler
sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(sample_weights),  # Typically you sample the full dataset per "epoch"
    replacement=True  # Sample with replacement so minority-class items can be drawn more than once
)

# 4. Use it in your DataLoader. DO NOT use `shuffle=True` with a sampler.
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

Now each batch will be approximately class-balanced in expectation, encouraging the model to learn all classes equally.
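A quick empirical check makes this concrete. The sketch below builds a synthetic 95/5 imbalanced toy dataset and applies the same weighting recipe as above:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Synthetic 95/5 imbalance: 950 samples of class 0, 50 of class 1
labels = torch.cat([torch.zeros(950, dtype=torch.long),
                    torch.ones(50, dtype=torch.long)])
dataset = TensorDataset(torch.randn(1000, 4), labels)

class_counts = torch.bincount(labels)
class_weights = len(labels) / (len(class_counts) * class_counts.float())
sample_weights = class_weights[labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels),
                                replacement=True)
loader = DataLoader(dataset, batch_size=1000, sampler=sampler)

_, sampled = next(iter(loader))
minority_fraction = (sampled == 1).float().mean().item()
# minority_fraction lands near 0.5 instead of the raw 0.05
```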

DataLoader Determinism: Fixed Seeds for Reproducibility

You got a great result. Can you get it again? Not if your data pipeline is non-deterministic. To fix this:

  1. Set all random seeds: Python, NumPy, PyTorch.
  2. Use worker_init_fn in DataLoader: This ensures each subprocess (worker) also starts with a seeded state.
  3. Disable CUDA convolution benchmarking: torch.backends.cudnn.benchmark = False. Benchmarking finds the fastest algorithm, which can vary between runs.

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def get_deterministic_dataloader(dataset, batch_size=32):
    generator = torch.Generator()
    generator.manual_seed(42)  # Seed for the sampler's RNG

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,  # Shuffling is now seeded
        num_workers=4,
        worker_init_fn=seed_worker,
        generator=generator,  # Pass the generator to DataLoader
        pin_memory=True
    )
    return loader

# At the start of your script
import random
import numpy as np
import torch

def set_global_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True  # Slower, but reproducible
    torch.backends.cudnn.benchmark = False

set_global_seed(42)
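With the seeding in place, two freshly constructed loaders should produce identical shuffle orders — a cheap check worth running once (toy dataset for illustration):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))

def make_loader(seed=42):
    g = torch.Generator()
    g.manual_seed(seed)
    # num_workers=0 keeps the check simple; worker_init_fn covers workers > 0
    return DataLoader(dataset, batch_size=10, shuffle=True, generator=g)

order_a = [batch[0].tolist() for batch in make_loader()]
order_b = [batch[0].tolist() for batch in make_loader()]
assert order_a == order_b  # same seed, same shuffle order
```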

Debugging Dataset Issues: Visualizing Batches Before Training Starts

Never start a 10-hour training job without looking at your data first. Write a short sanity-check function and run it.

import matplotlib.pyplot as plt

def visualize_batch(dataloader, class_names, num_images=8):
    images, labels = next(iter(dataloader))  # Get one batch
    num_images = min(num_images, len(images), 8)  # the 2x4 grid below holds at most 8

    fig, axes = plt.subplots(2, 4, figsize=(12, 6))
    axes = axes.ravel()

    for i in range(num_images):
        # images[i] is CxHxW. Convert to HxWxC for matplotlib.
        img = images[i].permute(1, 2, 0).numpy()
        # If normalized, denormalize for visualization
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        img = std * img + mean
        img = np.clip(img, 0, 1)

        axes[i].imshow(img)
        axes[i].set_title(f"Label: {class_names[labels[i].item()]}")
        axes[i].axis('off')
    plt.tight_layout()
    plt.show()

# Use it before training
train_loader = get_deterministic_dataloader(train_dataset)
visualize_batch(train_loader, class_names=['cat', 'dog'])

This catches 90% of data bugs: misaligned labels, broken augmentations (e.g., black images), incorrect normalization, or tensor shape errors.

Custom Collate Function: Handling Variable-Length Sequences

The default collate_fn in DataLoader stacks tensors into a batch. It fails if your samples have variable dimensions (e.g., sentences of different lengths). You must write a custom one.

Real Error Message: RuntimeError: stack expects each tensor to be equal size, but got [10] at entry 0 and [15] at entry 1

Fix: Create a collate_fn that pads sequences to the longest in the batch.

from torch.nn.utils.rnn import pad_sequence

def collate_variable_length(batch):
    """Collate function for sequences (data, label) where data is a 1D tensor of variable length."""
    sequences, labels = zip(*batch)  # batch is a list of (sequence, label) tuples

    # Pad sequences to the max length in this batch
    sequences_padded = pad_sequence(sequences, batch_first=True, padding_value=0)

    # Stack labels into a tensor
    labels = torch.stack(labels)

    return sequences_padded, labels

# Use it
dataloader = DataLoader(
    variable_length_dataset,
    batch_size=32,
    collate_fn=collate_variable_length,
    shuffle=True,
    num_workers=4
)

For more complex structures (e.g., graphs, nested dictionaries), your collate_fn should return a dictionary or a custom Batch object that your model's forward method knows how to unpack.
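One possible shape for such a dictionary collate — the field names ("tokens", "label") are illustrative, not from any particular library:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_dict(batch):
    """Pads the variable-length 'tokens' field and stacks fixed-size fields."""
    tokens = pad_sequence([b["tokens"] for b in batch],
                          batch_first=True, padding_value=0)
    lengths = torch.tensor([len(b["tokens"]) for b in batch])
    labels = torch.stack([b["label"] for b in batch])
    return {"tokens": tokens, "lengths": lengths, "labels": labels}

batch = [
    {"tokens": torch.tensor([1, 2, 3]), "label": torch.tensor(0)},
    {"tokens": torch.tensor([4, 5]),    "label": torch.tensor(1)},
]
out = collate_dict(batch)
# out["tokens"] has shape (2, 3); the shorter sequence is zero-padded
```

Keeping the original lengths alongside the padded tensor lets the model mask out padding later (e.g. via pack_padded_sequence or an attention mask).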

Performance Benchmarks: Why Your Choices Matter

Choosing the right tools isn't guesswork. Here’s what the benchmarks say about critical pipeline choices:

Component & Configuration            | Performance Impact                                                   | Source / Context
-------------------------------------|----------------------------------------------------------------------|-------------------------------------------
DataLoader: num_workers=8 vs 0       | 3.7x training throughput                                             | ImageNet loading, I/O-bound workload
Training: torch.compile on ResNet-50 | 1.8x speedup on A100, 1.4x on RTX 4090                               | PyTorch 2.3 benchmark
Memory: AMP (FP16/BF16) vs FP32      | 2.1x training throughput, 40-50% VRAM reduction, ~0.3% accuracy drop | A100 benchmark
Inference: torch.jit.script for LSTM | 1.3x faster than eager mode                                          | PyTorch 2.3 (Note: 1.0x for Transformers)

These aren't minor gains. A slow DataLoader (num_workers=0) means your $10,000 GPU spends most of its time waiting for your CPU to load the next batch. A good starting point is num_workers roughly equal to the number of CPU cores (benchmark values between 4 and 16 for I/O-heavy tasks), combined with pin_memory=True. Together they let batches be staged in page-locked memory and transferred to the GPU asynchronously.

Another Real Error Message: RuntimeError: DataLoader worker (pid(s) XXXX) exited unexpectedly with a CUDA error.

Fix: This often happens because you have CUDA operations in your Dataset.__getitem__. Move all .to(device) calls and model operations out of the dataset. If you must debug, set num_workers=0 first.

Next Steps: From Working Pipeline to Production Pipeline

You now have a correct, fast, and debuggable data pipeline. Where next?

  1. Scale with PyTorch Lightning: The stats show 68% of PyTorch training projects use PyTorch Lightning (JetBrains 2025) to eliminate boilerplate. Its LightningDataModule formalizes the setup of your datasets, transforms, and dataloaders, making your code reusable and shareable.
  2. Profile with Weights & Biases: Use W&B's system metrics to see if you're truly GPU-bound or still data-bound. Their charts will show you if your GPU utilization dips while waiting for data.
  3. Compile for Speed: Remember, torch.compile (TorchDynamo) speeds up training 1.5–2x on typical model architectures. Once your pipeline is solid, wrap your model in torch.compile for a free speed boost.
  4. Push to Hugging Face Datasets: For collaboration, use the datasets library. It handles streaming, caching, and versioning for massive datasets, integrating seamlessly with PyTorch DataLoader.

The goal is to make your data pipeline invisible—a silent, efficient conveyor belt feeding perfect tensors to your model. When it works, you stop thinking about it and start thinking about what actually matters: your model's architecture and loss function. Get the data right first. Everything else is built on top of it.