Exporting PyTorch Models to ONNX and Serving with 3x Lower Latency

Complete guide to PyTorch model export — ONNX export with dynamic axes, validating numerical equivalence, ONNX Runtime optimization for CPU/GPU, and deploying with FastAPI for production inference.

Your PyTorch model takes 45ms per inference. The same model in ONNX Runtime takes 14ms. Exporting is 20 lines of code.

Your model is trained, validated, and checkpointed. You’ve wrangled DataLoaders, survived OOM errors, and maybe even flirted with torch.compile. But in production, that elegant nn.Module is now a sluggish API endpoint, bottlenecking your entire service. The problem isn’t your model architecture; it’s the Python interpreter and PyTorch’s eager execution overhead. PyTorch may dominate research (it appears in roughly 77% of ML papers that publish code), but research ergonomics don't translate to production latency.

You could reach for TorchScript, but for standardized, hardware-accelerated inference, ONNX (Open Neural Network Exchange) is the exit ramp. This guide is about converting your PyTorch model into a portable, optimized computation graph and serving it with ONNX Runtime for a 3x latency cut. We’ll cover the export, validation, optimization, and serving—including the inevitable error messages and their exact fixes.

Exporting Your Model with Dynamic Axes for Real-World Batches

The first step is getting your model out of PyTorch. The torch.onnx.export function is your primary tool, but a naive export will bake in your example input’s batch size and sequence length. Real inference has variable sizes. You need dynamic axes.

Here’s a concrete example exporting a simple vision model, but the principle applies to any nn.Module:

import torch
import torch.nn as nn
import torch.onnx

class EfficientClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        self.head = nn.Linear(64, 10)

    def forward(self, pixel_values):
        features = self.backbone(pixel_values).flatten(1)
        return self.head(features)

model = EfficientClassifier()
model.load_state_dict(torch.load('model.pt', map_location='cpu'))
model.eval()  # Disable dropout/batchnorm training behavior before export


dummy_input = torch.randn(1, 3, 224, 224)

# Define dynamic axes: batch and spatial dims vary for the input; only batch for output.
# Note: the channel dim cannot be dynamic here, since Conv2d(3, ...) requires 3 channels.
dynamic_axes = {
    'input': {0: 'batch_size', 2: 'height', 3: 'width'},  # keys must match input_names/output_names
    'output': {0: 'batch_size'}
}

torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=dynamic_axes,
    opset_version=17,  # Use a recent, stable opset
    do_constant_folding=True  # Optimize constants
)

The dynamic_axes dictionary is critical. It tells ONNX that dimension 0 (batch) and the spatial dimensions of the input are variable; the channel count stays fixed at 3 because the first Conv2d requires it. Now your model.onnx can handle batches and image sizes it was never traced with. Common gotcha: the keys in dynamic_axes must exactly match the names in input_names and output_names. A mismatched key is ignored (you get a warning at best), and the exported graph silently keeps that dimension fixed.
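Because a typo'd key degrades silently, a cheap pre-export guard pays for itself. A minimal sketch (check_dynamic_axes is a hypothetical helper, not part of torch):

```python
def check_dynamic_axes(dynamic_axes, input_names, output_names):
    """Raise if a dynamic_axes key doesn't match a declared input/output name."""
    known = set(input_names) | set(output_names)
    unknown = set(dynamic_axes) - known
    if unknown:
        raise ValueError(f"dynamic_axes keys with no matching name: {sorted(unknown)}")

# Run this right before torch.onnx.export; a typo fails loudly instead of silently:
check_dynamic_axes({'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
                   input_names=["input"], output_names=["output"])  # OK
```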

Validating the Export: Don’t Trust, Verify

An exported model that runs is not necessarily correct. Floating-point differences between PyTorch and ONNX Runtime (ORT) backends are expected, but they must be within tolerance. You must validate numerically.

import onnxruntime as ort
import numpy as np

# Run PyTorch inference
with torch.no_grad():
    torch_output = model(dummy_input).numpy()

# Create ONNX Runtime session
ort_session = ort.InferenceSession("model.onnx", providers=['CPUExecutionProvider'])
# Prepare input: convert to numpy, ensure correct type
ort_inputs = {ort_session.get_inputs()[0].name: dummy_input.numpy().astype(np.float32)}
ort_output = ort_session.run(None, ort_inputs)[0]

# Compare with relative (rtol) and absolute (atol) tolerance
np.testing.assert_allclose(torch_output, ort_output, rtol=1e-03, atol=1e-05)
print("✓ Export validated within rtol=1e-03, atol=1e-05")

Use rtol (relative tolerance) and atol (absolute tolerance). For FP32 models, rtol=1e-3 and atol=1e-5 are reasonable starting points. If this fails, your export is wrong. Common causes are mismatched opset versions or unsupported operations.
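When assert_allclose does fail, the raw maximum errors tell you whether you are looking at float noise or a genuinely broken graph. A small numpy helper (report_divergence is illustrative, not an ORT API):

```python
import numpy as np

def report_divergence(reference, candidate, eps=1e-12):
    """Return (max absolute error, max relative error) between two arrays."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    abs_err = np.abs(reference - candidate)
    rel_err = abs_err / (np.abs(reference) + eps)  # eps guards zero-valued reference entries
    return float(abs_err.max()), float(rel_err.max())

max_abs, max_rel = report_divergence([1.0, 2.0], [1.0, 2.002])
print(f"max abs err {max_abs:.2e}, max rel err {max_rel:.2e}")
```

As a rule of thumb, a max absolute error around 1e-6 is FP32 noise; around 1e-1 suggests a wrong operator or opset mismatch.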

Choosing the Right Optimization Level in ONNX Runtime

ONNX Runtime isn’t a single runtime; it’s a stack of optimizers. It already applies ORT_ENABLE_ALL graph optimizations by default, but configuring SessionOptions explicitly lets you pin the level, which matters for debugging and for apples-to-apples benchmarking.

providers = ['CPUExecutionProvider']  # We'll cover GPU next
sess_options = ort.SessionOptions()

# Level 1: Basic graph optimizations (fuse nodes, eliminate redundancy)
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC

# Level 2: Extended optimizations (layout changes, more fusion)
# sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Level 3: All available optimizations, including device-specific
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

# For reproducible benchmarking, disable them
# sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

optimized_session = ort.InferenceSession("model.onnx", sess_options=sess_options, providers=providers)

ORT_ENABLE_ALL is typically what you want for production. It applies transformations like kernel fusion, which can dramatically reduce overhead. The performance gain varies by model architecture.

Optimization Level | ResNet-50 Inference Latency (CPU) | Notes
ORT_DISABLE_ALL    | 120 ms                            | Baseline, no optimizations
ORT_ENABLE_BASIC   | 95 ms                             | ~1.26x speedup
ORT_ENABLE_ALL     | 78 ms                             | ~1.54x speedup vs. disabled

Table: Example inference latency improvement from ONNX Runtime graph optimizations on an Intel Xeon CPU. Your gains depend on model complexity and operator fusion opportunities.
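To reproduce numbers like the table above, time many runs and report percentiles rather than a single mean. A stdlib sketch (benchmark is a hypothetical helper; in practice you'd pass it a closure over session.run):

```python
import time
import statistics

def benchmark(fn, warmup=10, iters=100):
    """Return (p50_ms, p95_ms) latency for fn() over `iters` timed runs."""
    for _ in range(warmup):  # warm caches, allocator arenas, lazy initialization
        fn()
    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - t0) * 1000.0)
    samples.sort()
    p50 = statistics.median(samples)
    p95 = samples[int(0.95 * len(samples)) - 1]
    return p50, p95

# Usage sketch: benchmark(lambda: optimized_session.run(None, ort_inputs))
p50, p95 = benchmark(lambda: sum(range(1000)))
print(f"p50={p50:.3f} ms  p95={p95:.3f} ms")
```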

Pushing Speed Further: INT8 Quantization for CPU Inference

If you’re serving on CPUs, FP32 is wasteful. INT8 quantization reduces precision to shrink the model and speed up computation. ONNX Runtime provides a quantization API that requires a calibration dataset.

from onnxruntime.quantization import quantize_dynamic, QuantType

# Dynamic quantization: weights are quantized to INT8, activations are quantized on-the-fly
quantized_model_path = "model.quant.onnx"
quantize_dynamic(
    "model.onnx",
    quantized_model_path,
    weight_type=QuantType.QInt8  # QInt8 weights generally work well on x86; QUInt8 is the library default
)

# Load and run the quantized model
quant_session = ort.InferenceSession(quantized_model_path, providers=['CPUExecutionProvider'])
ort_inputs = {quant_session.get_inputs()[0].name: dummy_input.numpy().astype(np.float32)}
quant_output = quant_session.run(None, ort_inputs)[0]

# Sanity-check: quantized outputs should stay close to FP32 (loose tolerance)
np.testing.assert_allclose(torch_output, quant_output, rtol=1e-01, atol=1e-02)
print("✓ Quantized model validated against the FP32 baseline.")

Dynamic quantization is fast and requires no retraining; speedups of roughly 2–4x are typical on modern x86 CPUs, but measure on your hardware. The accuracy drop is often negligible, but always validate on your actual task, since elementwise tolerances are only a sanity check. The tolerance here is looser (rtol=1e-1) because quantization introduces far larger numerical divergence than an FP32 export.

GPU Inference: Hooking into the CUDAExecutionProvider

For real latency cuts, you need GPU inference. ONNX Runtime supports this via execution providers. The CUDAExecutionProvider is key.

import onnxruntime as ort

# Check available providers
print(ort.get_available_providers())
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] when the CUDA EP is installed

# Create a session explicitly using CUDA
cuda_providers = [
    ('CUDAExecutionProvider', {
        'device_id': 0,  # Use GPU 0
        'arena_extend_strategy': 'kNextPowerOfTwo',
        'gpu_mem_limit': 4 * 1024 * 1024 * 1024,  # Cap the memory arena at 4 GB
        'cudnn_conv_algo_search': 'EXHAUSTIVE',  # Benchmark all conv algorithms up front
        'do_copy_in_default_stream': True,
        # CUDA graph capture is a provider option, not a SessionOptions flag.
        # It requires fixed input shapes and I/O binding, so enable it only
        # for static-shape serving: 'enable_cuda_graph': '1',
    }),
    'CPUExecutionProvider'  # Fallback provider
]

sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

ort_session_gpu = ort.InferenceSession(
    "model.onnx",
    sess_options=sess_options,
    providers=cuda_providers
)

# Ensure input is on CPU, ORT handles GPU transfer
ort_inputs = {ort_session_gpu.get_inputs()[0].name: dummy_input.numpy()}
ort_output_gpu = ort_session_gpu.run(None, ort_inputs)[0]

CUDA graph capture (the enable_cuda_graph provider option) is a pro move: after the first run it replays the captured kernel sequence, eliminating per-kernel launch overhead for subsequent runs with identical shapes. It requires fixed input shapes and I/O binding, which is why it pairs well with fixed-size batching. The performance leap over PyTorch eager mode is significant, often hitting that 3x goal.
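Since graph capture needs stable shapes, a common serving trick is to pad incoming batches up to a small set of bucket sizes, so the runtime only ever sees a handful of fixed shapes. A numpy sketch (pad_to_bucket and BUCKETS are hypothetical, not ORT APIs):

```python
import numpy as np

BUCKETS = (1, 2, 4, 8, 16)  # the only batch sizes the GPU will ever see

def pad_to_bucket(batch, buckets=BUCKETS):
    """Zero-pad a (N, ...) batch up to the next bucket size; return (padded, N).

    Assumes N <= max(buckets); larger batches would need splitting first.
    """
    n = batch.shape[0]
    target = next(b for b in buckets if b >= n)
    if target == n:
        return batch, n
    pad = np.zeros((target - n,) + batch.shape[1:], dtype=batch.dtype)
    return np.concatenate([batch, pad], axis=0), n

batch = np.ones((3, 3, 224, 224), dtype=np.float32)
padded, real_n = pad_to_bucket(batch)
# Run the padded batch through the session, then keep only outputs[:real_n]
print(padded.shape[0], real_n)
```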

Serving with FastAPI: Async Endpoints with Smart Batching

A fast model is useless behind a slow server. FastAPI with async endpoints lets you handle concurrent requests efficiently. The trick is to use a singleton session and batch requests where possible.

from fastapi import FastAPI, HTTPException
from contextlib import asynccontextmanager
import numpy as np
import onnxruntime as ort
from loguru import logger
import asyncio
from typing import List

# Lifespan management: load model once on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup: load the ONNX session (note: the INT8 model mostly executes on CPU;
    # point the CUDA EP at the FP32 model.onnx if you're serving on GPU)
    providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    sess_options = ort.SessionOptions()
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    app.state.model_session = ort.InferenceSession("model.quant.onnx", sess_options=sess_options, providers=providers)
    logger.info("Model loaded and ready.")
    yield
    # Shutdown: clean up
    del app.state.model_session
    logger.info("Model unloaded.")

app = FastAPI(lifespan=lifespan)

@app.post("/predict")
async def predict_batch(request_data: List[List[float]]):
    """Accepts a batch of inputs for prediction."""
    session = app.state.model_session
    input_name = session.get_inputs()[0].name

    # Convert to numpy array (sync and cheap for small batches)
    np_batch = np.array(request_data, dtype=np.float32)

    # Reshape if needed (e.g., for vision models: (batch, 3, 224, 224))
    # np_batch = np_batch.reshape(-1, 3, 224, 224)

    ort_inputs = {input_name: np_batch}
    try:
        # session.run blocks; push it to a worker thread so the event loop stays free
        loop = asyncio.get_running_loop()
        outputs = await loop.run_in_executor(None, lambda: session.run(None, ort_inputs))
        predictions = outputs[0].tolist()
        return {"predictions": predictions}
    except Exception as e:
        logger.error(f"Inference failed: {e}")
        raise HTTPException(status_code=500, detail="Inference failed")

# For true async batching, you'd use a queue system, but this pattern works for many workloads.

This setup avoids reloading the model per request. For higher throughput, implement a batching queue that collects requests over a short window (e.g., 10ms) and runs them as a single batch, amortizing the GPU kernel launch cost.
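The batching queue described above can be sketched with asyncio alone: requests land in a queue, a worker drains whatever arrives within the window, runs one batched call, and resolves each caller's future. MicroBatcher and the stub infer_batch below are illustrative, not FastAPI or ORT APIs:

```python
import asyncio

class MicroBatcher:
    def __init__(self, infer_batch, window_s=0.01, max_batch=32):
        self.infer_batch = infer_batch  # callable: list of inputs -> list of outputs
        self.window_s = window_s
        self.max_batch = max_batch
        self.queue = asyncio.Queue()

    async def submit(self, item):
        """Called per request; resolves when the batched result is ready."""
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((item, fut))
        return await fut

    async def worker(self):
        while True:
            batch = [await self.queue.get()]  # block until the first request
            deadline = asyncio.get_running_loop().time() + self.window_s
            # Drain whatever else arrives inside the window, up to max_batch
            while len(batch) < self.max_batch:
                timeout = deadline - asyncio.get_running_loop().time()
                if timeout <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), timeout))
                except asyncio.TimeoutError:
                    break
            outputs = self.infer_batch([item for item, _ in batch])
            for (_, fut), out in zip(batch, outputs):
                fut.set_result(out)

async def demo():
    batcher = MicroBatcher(lambda xs: [x * 2 for x in xs])  # stands in for session.run
    worker = asyncio.create_task(batcher.worker())
    results = await asyncio.gather(*(batcher.submit(i) for i in range(5)))
    worker.cancel()
    return results

print(asyncio.run(demo()))
```

In a FastAPI app you would start the worker in the lifespan handler and have the endpoint call batcher.submit; the amortization comes from infer_batch seeing one concatenated batch instead of five single-item calls.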

Debugging the Inevitable Export Failures

Export will fail. Here are the two most common errors and their exact fixes.

1. RuntimeError: Expected all tensors to be on the same device. This happens during torch.onnx.export if your dummy input is on CPU but the model has parameters on GPU, or vice-versa.

# FIX: Explicitly manage device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
dummy_input = dummy_input.to(device)  # CRITICAL: same device as model
torch.onnx.export(model, dummy_input, ...)

2. Training loss is NaN after making changes for export. You might be tweaking your model or data pipeline for export and suddenly training diverges.

# FIX: Systematic debugging.
# 1. Check learning rate. Try 10x smaller.
# 2. Add gradient clipping to prevent explosion.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 3. Inspect inputs for NaN before the forward pass.
if torch.isnan(dummy_input).any():
    raise ValueError("NaN in input tensor")

For ONNX-specific failures like unsupported ops, you’ll see a clear error (e.g., Exporting the operator 'aten::my_custom_op' to ONNX opset version 17 is not supported). Solutions include:

  • Simplify the Model: Replace the unsupported PyTorch op with a composition of supported ones.
  • Use a Different Opset: Try a higher opset_version in torch.onnx.export.
  • Register a Custom Op: For advanced use, you can implement the op for ONNX Runtime, but this is a deep rabbit hole.

Next Steps: From Optimized Inference to Full Pipeline

You’ve exported, validated, optimized, and served your model. The latency should be significantly lower. Where next?

  1. Benchmark Rigorously: Use a tool like locust or wrk to stress-test your FastAPI endpoint under concurrent load. Measure p50, p95, and p99 latencies, not just averages.
  2. Explore torch.compile for Training: If you’re still in the training phase, remember that torch.compile (TorchDynamo) commonly speeds up training by 1.3–2x, depending on the architecture. It’s a separate optimization from inference but crucial for development speed.
  3. Consider the Full Stack: For maximum throughput, look beyond ONNX Runtime. TorchServe is a dedicated serving framework from PyTorch that can also use ONNX models. For massive models, integrate DeepSpeed or FSDP (Fully Sharded Data Parallel) during training, though their inference stories are different.
  4. Monitor and Iterate: Deploy your optimized model with monitoring (e.g., Weights & Biases for tracking performance drift). Model optimization isn’t a one-time event.

The goal isn’t just a faster model, but a predictable, scalable service. ONNX Runtime gets you the raw speed; a robust serving architecture built around it ensures that speed translates to a better product. Stop letting Python interpreter overhead be your bottleneck. Export, optimize, and serve.