Automate PyTorch Unit Tests with AI in 20 Minutes

Generate intelligent unit tests for PyTorch models using Claude API - catch bugs before production with automated test generation.

Problem: Writing PyTorch Tests Takes Hours

You built a neural network, but writing comprehensive unit tests for tensor shapes, gradient flow, and edge cases eats up development time. Manual testing misses obscure bugs that only surface in production.

You'll learn:

  • Generate unit tests automatically using Claude API
  • Test tensor operations, gradients, and model outputs
  • Catch shape mismatches and NaN errors before deployment

Time: 20 min | Level: Intermediate


Why This Happens

PyTorch models have complex failure modes - gradient vanishing, shape broadcasting errors, device mismatches. Writing tests for every edge case manually is slow and error-prone. AI can analyze your model code and generate targeted tests automatically.

Common symptoms:

  • Model works in notebook but fails in production
  • Shape errors only appear with specific batch sizes
  • Gradient explosions in rare input conditions
  • No test coverage for edge cases
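
The batch-size symptom, for example, often comes from a hard-coded dimension in a reshape. A minimal sketch (flatten_badly and flatten_correctly are hypothetical functions for illustration):

```python
import torch

def flatten_badly(x):
    # Bug: hard-codes batch size 32 instead of reading it from the tensor
    return x.view(32, -1)

def flatten_correctly(x):
    return x.view(x.size(0), -1)

x = torch.randn(16, 64, 8, 8)          # batch of 16, not 32
print(flatten_badly(x).shape)          # torch.Size([32, 2048]) -- silently wrong
print(flatten_correctly(x).shape)      # torch.Size([16, 4096])
```

With a batch of 32 both versions return the same shape, which is why the bug hides in the notebook and only breaks when a different batch size arrives.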

Solution

Step 1: Install Dependencies

# Install required packages
pip install torch pytest pytest-cov anthropic --break-system-packages

# Verify installation
python -c "import torch; import anthropic; print('Ready')"

Expected: Should print Ready with no errors

If it fails:

  • Error "No module named 'anthropic'": Run pip install anthropic --break-system-packages --upgrade
  • torch not found: Install with pip install torch --index-url https://download.pytorch.org/whl/cpu

Step 2: Create Test Generator Script

# test_generator.py
import anthropic
import os
from pathlib import Path

def generate_tests(model_code: str, model_name: str) -> str:
    """
    Generate pytest tests for PyTorch model using Claude API.
    
    Args:
        model_code: Source code of the PyTorch model
        model_name: Name of the model class
    
    Returns:
        Generated test code as string
    """
    client = anthropic.Anthropic(
        api_key=os.environ.get("ANTHROPIC_API_KEY")
    )
    
    # This prompt is optimized for PyTorch-specific test generation
    prompt = f"""Generate comprehensive pytest unit tests for this PyTorch model.

MODEL CODE:
{model_code}

REQUIREMENTS:
1. Test tensor shape consistency across forward pass
2. Test gradient flow (no NaN, no vanishing)
3. Test with different batch sizes (1, 16, 32)
4. Test device compatibility (CPU/CUDA if available)
5. Test edge cases (zero input, negative values, extreme values)
6. Use fixtures for model instantiation
7. Include docstrings explaining what each test validates

MODEL CLASS: {model_name}

Return ONLY the pytest code, no explanations."""

    message = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=4000,
        messages=[{"role": "user", "content": prompt}]
    )
    
    return message.content[0].text

def save_tests(test_code: str, output_path: str):
    """Save generated tests to file"""
    Path(output_path).write_text(test_code)
    print(f"✓ Tests saved to {output_path}")

if __name__ == "__main__":
    # Example usage
    sample_model = """
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
"""
    
    tests = generate_tests(sample_model, "SimpleClassifier")
    save_tests(tests, "test_classifier.py")

Why this works: Claude analyzes the model architecture and generates tests for common PyTorch failure modes. The structured prompt ensures tests cover shape checking, gradient validation, and edge cases.

Environment setup:

export ANTHROPIC_API_KEY="your-api-key-here"

Step 3: Generate Tests for Your Model

# your_model.py
import torch
import torch.nn as nn

class CustomCNN(nn.Module):
    """Example: Image classification CNN"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 8 * 8, 10)
    
    def forward(self, x):
        # x shape: (batch, 3, 32, 32)
        x = self.pool(torch.relu(self.conv1(x)))  # (batch, 32, 16, 16)
        x = self.pool(torch.relu(self.conv2(x)))  # (batch, 64, 8, 8)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)
        return x

Run the generator:

python -c "
from test_generator import generate_tests, save_tests
from pathlib import Path

model_code = Path('your_model.py').read_text()
tests = generate_tests(model_code, 'CustomCNN')
save_tests(tests, 'test_custom_cnn.py')
"

Expected output:

✓ Tests saved to test_custom_cnn.py

Step 4: Review Generated Tests

The AI generates tests like this:

# test_custom_cnn.py (AI-generated example)
import pytest
import torch
from your_model import CustomCNN

@pytest.fixture
def model():
    """Fixture providing fresh model instance for each test"""
    return CustomCNN()

def test_forward_pass_shape(model):
    """Verify output shape matches expected num_classes"""
    batch_size = 16
    x = torch.randn(batch_size, 3, 32, 32)
    output = model(x)
    
    assert output.shape == (batch_size, 10), \
        f"Expected ({batch_size}, 10), got {output.shape}"

def test_gradient_flow(model):
    """Ensure gradients propagate without NaN or explosion"""
    x = torch.randn(4, 3, 32, 32, requires_grad=True)
    output = model(x)
    loss = output.sum()
    loss.backward()
    
    # Check gradients exist and contain no NaN or Inf values
    for name, param in model.named_parameters():
        assert param.grad is not None, \
            f"Missing gradient for {name}"
        assert not torch.isnan(param.grad).any(), \
            f"NaN gradient in {name}"
        assert not torch.isinf(param.grad).any(), \
            f"Inf gradient in {name}"

def test_batch_size_flexibility(model):
    """Model should handle different batch sizes"""
    for batch_size in [1, 16, 32]:
        x = torch.randn(batch_size, 3, 32, 32)
        output = model(x)
        assert output.shape[0] == batch_size

def test_zero_input_handling(model):
    """Model should process zero tensors without errors"""
    x = torch.zeros(8, 3, 32, 32)
    output = model(x)
    assert not torch.isnan(output).any()

Manual review checklist:

  • Tests match your model's actual input/output dimensions
  • Edge cases are relevant to your use case
  • Test names clearly describe what they validate

Common fixes needed:

# If AI assumes wrong input shape, update the test:
# Change this:
x = torch.randn(batch_size, 3, 32, 32)
# To your actual input shape:
x = torch.randn(batch_size, 1, 28, 28)  # MNIST example

Step 5: Run the Tests

# Run all generated tests
pytest test_custom_cnn.py -v

# Run with coverage report
pytest test_custom_cnn.py --cov=your_model --cov-report=term-missing

Expected output:

test_custom_cnn.py::test_forward_pass_shape PASSED
test_custom_cnn.py::test_gradient_flow PASSED
test_custom_cnn.py::test_batch_size_flexibility PASSED
test_custom_cnn.py::test_zero_input_handling PASSED

======== 4 passed in 2.31s ========

If tests fail:

  • Shape mismatch: Update your model or fix the test's input dimensions
  • Gradient NaN: Check for division by zero or unstable activations in your model
  • CUDA errors: Tests auto-detect GPU availability; ensure torch.cuda.is_available() returns True on machines that actually have a GPU
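
For the NaN-gradient case, PyTorch's anomaly detection can name the backward operation that produced the NaN. A debugging sketch (the sqrt-at-zero setup is contrived purely to force a NaN gradient):

```python
import torch

caught = None
try:
    # detect_anomaly() makes backward() raise as soon as any op emits a NaN
    with torch.autograd.detect_anomaly():
        x = torch.tensor([0.0], requires_grad=True)
        y = torch.sqrt(x)       # dy/dx = 1/(2*sqrt(x)), infinite at x = 0
        loss = (y * 0).sum()    # chain rule: 0/0 inside sqrt's backward -> NaN
        loss.backward()
except RuntimeError as e:
    caught = str(e)             # message names the offending backward op

print(caught)
```

Anomaly detection adds noticeable overhead, so enable it only while debugging, not in the committed test suite.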

Verification

Test the complete workflow:

# 1. Generate tests for a new model
python test_generator.py

# 2. Run generated tests
pytest test_classifier.py -v

# 3. Check coverage
pytest --cov=. --cov-report=html
open htmlcov/index.html  # View coverage report

You should see:

  • All tests passing
  • Coverage report showing tested code paths
  • No warnings about tensor shape mismatches

Advanced: CI/CD Integration

Add to .github/workflows/test.yml:

name: Auto-Generated Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      
      - name: Install dependencies
        run: |
          pip install torch pytest pytest-cov anthropic --break-system-packages
      
      - name: Generate tests
        env:
          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
        run: python test_generator.py
      
      - name: Run tests
        run: pytest -v --cov=. --cov-report=xml
      
      - name: Upload coverage
        uses: codecov/codecov-action@v4

Why this works: Every commit auto-generates fresh tests based on current model code, catching regressions immediately.


What You Learned

  • AI can generate PyTorch-specific tests (shapes, gradients, edge cases)
  • Claude API analyzes model architecture to create targeted tests
  • Automated testing catches device mismatches and shape errors before production

Limitations:

  • Generated tests need manual review for domain-specific edge cases
  • AI can't test business logic (e.g., "output should be probability distribution")
  • Complex models (transformers, GANs) may need custom test templates
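
The business-logic limitation means checks like "output is a probability distribution" still have to be written by hand. A sketch of such a test, using a stand-in nn.Linear as a hypothetical classifier (substitute your real model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 10)               # stand-in for your real classifier

logits = model(torch.randn(4, 8))
probs = torch.softmax(logits, dim=-1)

# Hand-written business-logic checks the AI cannot infer from code alone
assert (probs >= 0).all(), "probabilities must be non-negative"
assert torch.allclose(probs.sum(dim=-1), torch.ones(4), atol=1e-5), \
    "each row must sum to 1"
print("business-logic checks passed")
```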

When NOT to use this:

  • Simple models with <10 lines of code (manual tests are faster)
  • Models with unusual architectures AI hasn't seen in training data
  • Projects requiring specific test frameworks beyond pytest

Production Tips

1. Test Template Customization

Create reusable templates for your team:

# test_templates.py
TEMPLATE_PROMPT = """Generate pytest tests for this PyTorch model.

CRITICAL REQUIREMENTS:
- Test on our standard input shape: (batch, 512, 768)
- Verify output is probability distribution (sum to 1)
- Test with our custom loss function: FocalLoss
- Include tests for mixed precision (fp16)

MODEL CODE:
{model_code}

Return pytest code only."""
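
To use a template, fill the placeholder before sending the prompt. A sketch with a trimmed copy of the template and a stub model string (in practice, read the real source with Path('your_model.py').read_text() and pass the filled prompt where generate_tests builds its own):

```python
TEMPLATE_PROMPT = """Generate pytest tests for this PyTorch model.

MODEL CODE:
{model_code}

Return pytest code only."""

model_code = "class Stub(nn.Module): ..."   # stand-in for the real source
prompt = TEMPLATE_PROMPT.format(model_code=model_code)
print(prompt)
```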

2. Batch Test Generation

Generate tests for entire model directory:

# batch_generate.py
from pathlib import Path
from test_generator import generate_tests, save_tests

models_dir = Path("models/")
for model_file in models_dir.glob("*.py"):
    code = model_file.read_text()
    tests = generate_tests(code, model_file.stem)
    save_tests(tests, f"tests/test_{model_file.stem}.py")

3. Cost Optimization

Claude API calls cost roughly $0.01 per generated test file. To keep costs down, only regenerate tests when the model actually changes:

# Only regenerate tests when model changes
import hashlib
from pathlib import Path

def should_regenerate(model_path, test_path):
    """Check if model changed since last test generation"""
    model_hash = hashlib.sha256(Path(model_path).read_bytes()).hexdigest()
    
    if not Path(test_path).exists():
        return True
    
    # Store hash in test file comment
    test_content = Path(test_path).read_text()
    if f"# model_hash: {model_hash}" in test_content:
        return False
    return True
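
should_regenerate() looks for a "# model_hash:" comment, so the generator has to embed one when it saves tests. One way to do that (save_tests_with_hash is a hypothetical helper; the demo uses throwaway temp files):

```python
import hashlib
import tempfile
from pathlib import Path

def save_tests_with_hash(test_code: str, model_path: str, output_path: str):
    """Write tests with the model's hash on the first line, so
    should_regenerate() can skip unchanged models later."""
    model_hash = hashlib.sha256(Path(model_path).read_bytes()).hexdigest()
    Path(output_path).write_text(f"# model_hash: {model_hash}\n{test_code}")

# Demo with throwaway files
with tempfile.TemporaryDirectory() as d:
    model_file = Path(d) / "model.py"
    model_file.write_text("class M: pass")
    out_file = Path(d) / "test_m.py"
    save_tests_with_hash("def test_ok(): pass\n", str(model_file), str(out_file))
    first_line = out_file.read_text().splitlines()[0]
    print(first_line)   # "# model_hash: <64-char sha256>"
```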

Troubleshooting

Issue: Tests generate but fail immediately

# Debug: Check if model imports correctly
python -c "from your_model import CustomCNN; print('OK')"

Solution: Ensure the model file has no import errors and that all of its dependencies are installed


Issue: AI generates wrong input shapes

Solution: Make input shape explicit in docstrings:

class CustomCNN(nn.Module):
    """
    Image classifier for 32x32 RGB images.
    
    Input shape: (batch_size, 3, 32, 32)
    Output shape: (batch_size, 10)
    """

Issue: Rate limit errors from Anthropic API

# Add retry logic
import time

from anthropic import RateLimitError
from test_generator import generate_tests

def generate_with_retry(model_code, model_name, max_retries=3):
    for attempt in range(max_retries):
        try:
            return generate_tests(model_code, model_name)
        except RateLimitError:
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # Exponential backoff
            else:
                raise

Tested on PyTorch 2.2.0, Python 3.11, Claude API 2026-02-01, Ubuntu 24.04 & macOS Sonoma