Problem: Writing PyTorch Tests Takes Hours
You built a neural network but writing comprehensive unit tests for tensor shapes, gradient flows, and edge cases eats up development time. Manual testing misses obscure bugs that break in production.
You'll learn:
- Generate unit tests automatically using Claude API
- Test tensor operations, gradients, and model outputs
- Catch shape mismatches and NaN errors before deployment
Time: 20 min | Level: Intermediate
Why This Happens
PyTorch models have complex failure modes - gradient vanishing, shape broadcasting errors, device mismatches. Writing tests for every edge case manually is slow and error-prone. AI can analyze your model code and generate targeted tests automatically.
Common symptoms:
- Model works in notebook but fails in production
- Shape errors only appear with specific batch sizes
- Gradient explosions in rare input conditions
- No test coverage for edge cases
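One concrete example of such a symptom: a hardcoded batch size inside view() works for the batch size used during development and raises a RuntimeError for any other. The BuggyFlatten class here is a hypothetical illustration, not from the tutorial's code:

```python
import torch
import torch.nn as nn

class BuggyFlatten(nn.Module):
    """Hypothetical model with a hardcoded batch size in view() --
    the kind of bug that only appears with specific batch sizes."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x):
        x = x.view(32, 16)  # bug: should be x.view(x.size(0), -1)
        return self.fc(x)

model = BuggyFlatten()
print(model(torch.randn(32, 4, 4)).shape)  # torch.Size([32, 4])
try:
    model(torch.randn(8, 4, 4))  # 128 elements cannot fill shape (32, 16)
except RuntimeError as e:
    print("shape bug:", e)
```

A notebook that only ever runs batch size 32 never hits the second path; a generated batch-size-flexibility test does.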
Solution
Step 1: Install Dependencies
# Install required packages
pip install torch pytest anthropic --break-system-packages
# Verify installation
python -c "import torch; import anthropic; print('Ready')"
Expected: Should print Ready with no errors
If it fails:
- "No module named 'anthropic'": run pip install anthropic --break-system-packages --upgrade
- torch not found: install the CPU build with pip install torch --index-url https://download.pytorch.org/whl/cpu
Step 2: Create Test Generator Script
# test_generator.py
import anthropic
import os
from pathlib import Path
def generate_tests(model_code: str, model_name: str) -> str:
    """
    Generate pytest tests for a PyTorch model using the Claude API.

    Args:
        model_code: Source code of the PyTorch model
        model_name: Name of the model class

    Returns:
        Generated test code as a string
    """
    client = anthropic.Anthropic(
        api_key=os.environ.get("ANTHROPIC_API_KEY")
    )

    # This prompt is optimized for PyTorch-specific test generation
    prompt = f"""Generate comprehensive pytest unit tests for this PyTorch model.

MODEL CODE:
{model_code}

REQUIREMENTS:
1. Test tensor shape consistency across the forward pass
2. Test gradient flow (no NaN, no vanishing)
3. Test with different batch sizes (1, 16, 32)
4. Test device compatibility (CPU/CUDA if available)
5. Test edge cases (zero input, negative values, extreme values)
6. Use fixtures for model instantiation
7. Include docstrings explaining what each test validates

MODEL CLASS: {model_name}

Return ONLY the pytest code, no explanations."""

    message = client.messages.create(
        model="claude-sonnet-4-20250514",
        max_tokens=4000,
        messages=[{"role": "user", "content": prompt}]
    )
    return message.content[0].text

def save_tests(test_code: str, output_path: str):
    """Save generated tests to a file"""
    Path(output_path).write_text(test_code)
    print(f"✓ Tests saved to {output_path}")
if __name__ == "__main__":
    # Example usage
    sample_model = """
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
"""
    tests = generate_tests(sample_model, "SimpleClassifier")
    save_tests(tests, "test_classifier.py")
Why this works: Claude analyzes the model architecture and generates tests for common PyTorch failure modes. The structured prompt ensures tests cover shape checking, gradient validation, and edge cases.
Environment setup:
export ANTHROPIC_API_KEY="your-api-key-here"
Step 3: Generate Tests for Your Model
# your_model.py
import torch
import torch.nn as nn

class CustomCNN(nn.Module):
    """Example: Image classification CNN"""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 8 * 8, 10)

    def forward(self, x):
        # x shape: (batch, 3, 32, 32)
        x = self.pool(torch.relu(self.conv1(x)))  # (batch, 32, 16, 16)
        x = self.pool(torch.relu(self.conv2(x)))  # (batch, 64, 8, 8)
        x = x.view(x.size(0), -1)                 # flatten
        x = self.fc(x)
        return x
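Before asking the API for tests, a quick standalone smoke check of the commented shape arithmetic can save a round trip. This sketch rebuilds the same layers outside the class so it runs on its own:

```python
import torch
import torch.nn as nn

# Recreate CustomCNN's layers to verify the shape comments:
# 3x3 convs with padding=1 preserve spatial size; each 2x2 pool halves it.
conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
pool = nn.MaxPool2d(2, 2)

x = torch.randn(4, 3, 32, 32)
x = pool(torch.relu(conv1(x)))
print(x.shape)  # torch.Size([4, 32, 16, 16])
x = pool(torch.relu(conv2(x)))
print(x.shape)  # torch.Size([4, 64, 8, 8])
x = x.view(x.size(0), -1)
print(x.shape)  # torch.Size([4, 4096]) -- matches nn.Linear(64 * 8 * 8, 10)
```

If the flattened size here disagrees with the Linear layer's input dimension, the model itself is broken and no generated test will pass.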
Run the generator:
python -c "
from test_generator import generate_tests, save_tests
from pathlib import Path
model_code = Path('your_model.py').read_text()
tests = generate_tests(model_code, 'CustomCNN')
save_tests(tests, 'test_custom_cnn.py')
"
Expected output:
✓ Tests saved to test_custom_cnn.py
Step 4: Review Generated Tests
The AI generates tests like this:
# test_custom_cnn.py (AI-generated example)
import pytest
import torch
from your_model import CustomCNN

@pytest.fixture
def model():
    """Fixture providing a fresh model instance for each test"""
    return CustomCNN()

def test_forward_pass_shape(model):
    """Verify output shape matches expected num_classes"""
    batch_size = 16
    x = torch.randn(batch_size, 3, 32, 32)
    output = model(x)
    assert output.shape == (batch_size, 10), \
        f"Expected (16, 10), got {output.shape}"

def test_gradient_flow(model):
    """Ensure gradients propagate without NaN or explosion"""
    x = torch.randn(4, 3, 32, 32, requires_grad=True)
    output = model(x)
    loss = output.sum()
    loss.backward()
    # Check for NaN or Inf gradients
    for name, param in model.named_parameters():
        assert not torch.isnan(param.grad).any(), \
            f"NaN gradient in {name}"
        assert not torch.isinf(param.grad).any(), \
            f"Inf gradient in {name}"

def test_batch_size_flexibility(model):
    """Model should handle different batch sizes"""
    for batch_size in [1, 16, 32]:
        x = torch.randn(batch_size, 3, 32, 32)
        output = model(x)
        assert output.shape[0] == batch_size

def test_zero_input_handling(model):
    """Model should process zero tensors without errors"""
    x = torch.zeros(8, 3, 32, 32)
    output = model(x)
    assert not torch.isnan(output).any()
Manual review checklist:
- Tests match your model's actual input/output dimensions
- Edge cases are relevant to your use case
- Test names clearly describe what they validate
Common fixes needed:
# If AI assumes wrong input shape, update the test:
# Change this:
x = torch.randn(batch_size, 3, 32, 32)
# To your actual input shape:
x = torch.randn(batch_size, 1, 28, 28) # MNIST example
Step 5: Run the Tests
# Run all generated tests
pytest test_custom_cnn.py -v
# Run with coverage report
pytest test_custom_cnn.py --cov=your_model --cov-report=term-missing
Expected output:
test_custom_cnn.py::test_forward_pass_shape PASSED
test_custom_cnn.py::test_gradient_flow PASSED
test_custom_cnn.py::test_batch_size_flexibility PASSED
test_custom_cnn.py::test_zero_input_handling PASSED
======== 4 passed in 2.31s ========
If tests fail:
- Shape mismatch: update your model or fix the test's input dimensions
- Gradient NaN: check for division by zero or unstable activations in your model
- CUDA errors: tests auto-detect GPU availability; verify torch.cuda.is_available() returns True
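The gradient assertions in the generated tests flag exactly the kind of non-finite values a division by zero produces. A minimal standalone reproduction:

```python
import torch

# Minimal reproduction of a non-finite gradient: dividing by a zero
# input makes the forward value inf, and backward() propagates -inf.
x = torch.zeros(3, requires_grad=True)
loss = (1.0 / x).sum()
loss.backward()
print(x.grad)                     # tensor([-inf, -inf, -inf])
print(torch.isinf(x.grad).any())  # the condition test_gradient_flow asserts against
```

In a real model the division is usually buried inside a normalization or a custom loss, which is why an automated check over every parameter's gradient is worth having.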
Verification
Test the complete workflow:
# 1. Generate tests for a new model
python test_generator.py
# 2. Run generated tests
pytest test_classifier.py -v
# 3. Check coverage
pytest --cov=. --cov-report=html
open htmlcov/index.html # View coverage report
You should see:
- All tests passing
- Coverage report showing tested code paths
- No warnings about tensor shape mismatches
Advanced: CI/CD Integration
Add to .github/workflows/test.yml:
name: Auto-Generated Tests
on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'
      - name: Install dependencies
        run: |
          pip install torch pytest anthropic --break-system-packages
      - name: Generate tests
        env:
          ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
        run: python test_generator.py
      - name: Run tests
        run: pytest -v --cov=. --cov-report=xml
      - name: Upload coverage
        uses: codecov/codecov-action@v4
Why this works: Every commit auto-generates fresh tests based on current model code, catching regressions immediately.
What You Learned
- AI can generate PyTorch-specific tests (shapes, gradients, edge cases)
- Claude API analyzes model architecture to create targeted tests
- Automated testing catches device mismatches and shape errors before production
Limitations:
- Generated tests need manual review for domain-specific edge cases
- AI can't test business logic (e.g., "output should be probability distribution")
- Complex models (transformers, GANs) may need custom test templates
When NOT to use this:
- Simple models with <10 lines of code (manual tests are faster)
- Models with unusual architectures AI hasn't seen in training data
- Projects requiring specific test frameworks beyond pytest
Production Tips
1. Test Template Customization
Create reusable templates for your team:
# test_templates.py
TEMPLATE_PROMPT = """Generate pytest tests for this PyTorch model.
CRITICAL REQUIREMENTS:
- Test on our standard input shape: (batch, 512, 768)
- Verify output is probability distribution (sum to 1)
- Test with our custom loss function: FocalLoss
- Include tests for mixed precision (fp16)
MODEL CODE:
{model_code}
Return pytest code only."""
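Filling the template is a plain str.format call. A quick sketch, with the template shortened and TinyNet standing in as a placeholder model:

```python
# Shortened copy of the template so this snippet runs on its own.
TEMPLATE_PROMPT = """Generate pytest tests for this PyTorch model.
MODEL CODE:
{model_code}
Return pytest code only."""

model_source = "class TinyNet(nn.Module): ..."  # placeholder model source
prompt = TEMPLATE_PROMPT.format(model_code=model_source)
print(prompt)
```

The finished prompt is what you would pass to the messages.create call inside generate_tests in place of the default prompt.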
2. Batch Test Generation
Generate tests for entire model directory:
# batch_generate.py
from pathlib import Path
from test_generator import generate_tests, save_tests
models_dir = Path("models/")
for model_file in models_dir.glob("*.py"):
    code = model_file.read_text()
    tests = generate_tests(code, model_file.stem)
    save_tests(tests, f"tests/test_{model_file.stem}.py")
3. Cost Optimization
Each generated test file costs roughly a cent in API usage (the exact amount varies with model and prompt length). Reduce costs:
# Only regenerate tests when the model changes
import hashlib
from pathlib import Path

def should_regenerate(model_path, test_path):
    """Check whether the model changed since the last test generation"""
    model_hash = hashlib.sha256(Path(model_path).read_bytes()).hexdigest()
    if not Path(test_path).exists():
        return True
    # The hash is stored in a comment inside the test file
    test_content = Path(test_path).read_text()
    if f"# model_hash: {model_hash}" in test_content:
        return False
    return True
Troubleshooting
Issue: Tests generate but fail immediately
# Debug: Check if model imports correctly
python -c "from your_model import CustomCNN; print('OK')"
Solution: Ensure the model file has no import errors and that all its dependencies are installed
Issue: AI generates wrong input shapes
Solution: Make input shape explicit in docstrings:
class CustomCNN(nn.Module):
    """
    Image classifier for 32x32 RGB images.

    Input shape: (batch_size, 3, 32, 32)
    Output shape: (batch_size, 10)
    """
Issue: Rate limit errors from Anthropic API
# Add retry logic
import time
from anthropic import RateLimitError

def generate_with_retry(model_code, model_name, max_retries=3):
    for attempt in range(max_retries):
        try:
            return generate_tests(model_code, model_name)
        except RateLimitError:
            if attempt < max_retries - 1:
                time.sleep(2 ** attempt)  # exponential backoff
            else:
                raise
Tested on PyTorch 2.2.0, Python 3.11, Claude API 2026-02-01, Ubuntu 24.04 & macOS Sonoma