Build Browser ML Apps with WebGPU in 30 Minutes

Run real machine learning models in the browser using WebGPU for GPU-accelerated inference without server costs.

Problem: Running ML Models Without Server Costs

You need to run image classification or text generation in your web app, but sending data to a server is slow, expensive, and raises privacy concerns.

You'll learn:

  • How WebGPU accelerates ML inference 20-50x vs CPU
  • Setting up ONNX Runtime Web with WebGPU backend
  • Running a real image classifier entirely in-browser

Time: 30 min | Level: Intermediate


Why This Happens

Traditional browser ML uses WebGL (2011 tech) or CPU-only JavaScript. WebGPU (standardized 2023, widely available 2024+) gives direct GPU access with compute shaders, making real-time ML feasible.

Common symptoms:

  • TensorFlow.js models run too slow for production
  • Can't afford inference API costs at scale
  • Privacy-sensitive data can't leave the browser
  • Mobile users have terrible ML performance

Reality check: WebGPU gives 20-50x speedup on simple models, but you still can't run LLaMA 70B in a browser. Think image classification, small vision transformers, or quantized language models under 1GB.


Solution

Step 1: Verify WebGPU Support

// Check browser compatibility
async function checkWebGPU() {
  if (!navigator.gpu) {
    console.error('WebGPU not supported');
    return false;
  }
  
  const adapter = await navigator.gpu.requestAdapter();
  if (!adapter) {
    console.error('No GPU adapter found');
    return false;
  }
  
  console.log('WebGPU ready:', adapter.info);
  return true;
}

checkWebGPU();

Expected: Console shows GPU adapter info (vendor, architecture)

If it fails:

  • Chrome < 113: Update browser (WebGPU stable since Chrome 113)
  • Firefox: Enable dom.webgpu.enabled in about:config (experimental)
  • Safari: Available in Safari 18+ (macOS Sonoma+)

Current support (Feb 2026): Chrome/Edge 113+, Safari 18+, Firefox behind flag
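The support check above can feed directly into the execution-provider list used in later steps. A small sketch (`pickExecutionProviders` is a helper name introduced here, not an ONNX Runtime API):

```javascript
// Choose ONNX Runtime execution providers from the WebGPU check in Step 1.
// 'wasm' (CPU) is always kept as the safety net.
function pickExecutionProviders(webgpuAvailable) {
  return webgpuAvailable ? ['webgpu', 'wasm'] : ['wasm'];
}

// Usage sketch, assuming checkWebGPU from Step 1:
// const providers = pickExecutionProviders(await checkWebGPU());
```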


Step 2: Install ONNX Runtime Web

npm install onnxruntime-web

Why ONNX: an open model format exportable from PyTorch, TensorFlow, and scikit-learn. ONNX Runtime Web has the best WebGPU integration as of 2026.

Alternatives:

  • TensorFlow.js (good WebGPU support, larger bundle size)
  • Transformers.js (great for NLP, limited vision models)

Step 3: Convert Your Model to ONNX

# Example: Export PyTorch ResNet to ONNX
import torch
import torchvision.models as models

model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)  # 'pretrained=True' is deprecated
model.eval()

dummy_input = torch.randn(1, 3, 224, 224)

torch.onnx.export(
    model,
    dummy_input,
    "resnet50.onnx",
    opset_version=17,  # Use latest opset for WebGPU ops
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

Why opset 17: WebGPU backend supports ops from ONNX opset 13+, but 17 includes optimizations for attention mechanisms

Model size warning: Keep under 100MB for reasonable load times. Use quantization for larger models.
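To see why 100MB is the practical ceiling, a quick back-of-the-envelope helper (illustrative JavaScript, not part of any library):

```javascript
// Rough download time for a model file over a given connection speed.
function downloadSeconds(modelBytes, mbps) {
  return (modelBytes * 8) / (mbps * 1e6);
}

// A 100 MB model on a 50 Mbps connection takes roughly 17 seconds before
// the first inference can even start; a 1 GB model would take ~3 minutes.
const t = downloadSeconds(100 * 1024 * 1024, 50); // ≈ 16.8
```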


Step 4: Load and Run Inference

import * as ort from 'onnxruntime-web';

// Point ONNX Runtime at its WASM binaries (used by the CPU fallback)
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/';

async function runInference(imageData) {
  // Session created here for clarity; in production, create it once and reuse it (see Step 5)
  const session = await ort.InferenceSession.create('./resnet50.onnx', {
    executionProviders: ['webgpu', 'wasm'],  // Fallback order
    graphOptimizationLevel: 'all'
  });
  
  // Preprocess image to tensor
  const tensor = preprocessImage(imageData);
  
  // Run inference
  const feeds = { input: tensor };
  const results = await session.run(feeds);
  
  // Get predictions
  const output = results.output.data;  // Float32Array of logits
  return getTopPredictions(output, 5);
}

function preprocessImage(imageData) {
  // imageData: an HTMLImageElement (or canvas/video frame) accepted by drawImage
  // Resize to 224x224, normalize to ImageNet stats
  const canvas = document.createElement('canvas');
  canvas.width = 224;
  canvas.height = 224;
  
  const ctx = canvas.getContext('2d');
  ctx.drawImage(imageData, 0, 0, 224, 224);
  
  const pixels = ctx.getImageData(0, 0, 224, 224).data;
  const float32Data = new Float32Array(3 * 224 * 224);
  
  // Convert RGBA to RGB, normalize
  for (let i = 0; i < 224 * 224; i++) {
    float32Data[i] = (pixels[i * 4] / 255 - 0.485) / 0.229;          // R
    float32Data[224 * 224 + i] = (pixels[i * 4 + 1] / 255 - 0.456) / 0.224;  // G
    float32Data[2 * 224 * 224 + i] = (pixels[i * 4 + 2] / 255 - 0.406) / 0.225;  // B
  }
  
  return new ort.Tensor('float32', float32Data, [1, 3, 224, 224]);
}

function getTopPredictions(logits, k = 5) {
  // Apply softmax and get top-k
  const probabilities = softmax(logits);
  const indices = Array.from(probabilities.keys())
    .sort((a, b) => probabilities[b] - probabilities[a])
    .slice(0, k);
  
  return indices.map(i => ({
    class: IMAGENET_CLASSES[i],
    probability: probabilities[i]
  }));
}

function softmax(arr) {
  const max = Math.max(...arr);
  const exp = arr.map(x => Math.exp(x - max));
  const sum = exp.reduce((a, b) => a + b);
  return exp.map(x => x / sum);
}

Why this works: ONNX Runtime detects WebGPU, compiles compute shaders for each operation, runs on GPU. Falls back to WebAssembly CPU if WebGPU unavailable.

Performance: ResNet-50 inference ~15-30ms on modern GPUs (M1/M2, RTX 30-series) vs 300-500ms on CPU
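As a sanity check on the preprocessing math above: each channel is scaled to [0, 1], then standardized with the ImageNet mean/std, so a pure-white pixel should land around 2.2 in the red channel. Extracted as a standalone helper for testing:

```javascript
// Per-channel ImageNet normalization, same formula as in preprocessImage.
function normalizeChannel(value, mean, std) {
  return (value / 255 - mean) / std;
}

normalizeChannel(255, 0.485, 0.229); // ≈ 2.249 (white pixel, R channel)
normalizeChannel(0, 0.456, 0.224);   // ≈ -2.036 (black pixel, G channel)
```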


Step 5: Optimize for Production

// Cache session globally, load once at app start
let cachedSession = null;

async function initModel() {
  if (cachedSession) return cachedSession;
  
  cachedSession = await ort.InferenceSession.create('./resnet50.onnx', {
    executionProviders: ['webgpu', 'wasm'],
    graphOptimizationLevel: 'all',
    executionMode: 'parallel',  // Use multiple threads
    logSeverityLevel: 3  // Only errors in production
  });
  
  return cachedSession;
}

// Warm up model (first inference is slower due to shader compilation)
async function warmup() {
  const session = await initModel();
  const dummyTensor = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
  await session.run({ input: dummyTensor });
  console.log('Model warmed up');
}

// Call on page load
warmup();

Why warm up: First inference compiles GPU shaders (50-200ms). Subsequent runs use cached shaders.
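The session-caching pattern in initModel() generalizes to any expensive async initializer. A small once() helper (our own utility, not an ONNX Runtime API) makes it reusable and, because callers share the same promise, avoids the duplicate-load race when two callers initialize concurrently:

```javascript
// Memoize an async initializer: every caller gets the same promise,
// so the underlying work runs exactly once even under concurrent calls.
function once(fn) {
  let promise = null;
  return (...args) => (promise ??= fn(...args));
}

// Usage sketch:
// const initModel = once(() => ort.InferenceSession.create('./resnet50.onnx', { ... }));
```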

Bundle optimization:

// Use CDN for ONNX Runtime WASM files (saves ~8MB from bundle)
ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/';

Verification

Test WebGPU Performance

async function benchmark() {
  const session = await initModel();
  const testTensor = new ort.Tensor('float32', new Float32Array(1 * 3 * 224 * 224), [1, 3, 224, 224]);
  
  const runs = 100;
  const start = performance.now();
  
  for (let i = 0; i < runs; i++) {
    await session.run({ input: testTensor });
  }
  
  const elapsed = performance.now() - start;
  console.log(`Average inference time: ${elapsed / runs}ms`);
}

benchmark();
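Averages are easily skewed by the slow first (shader-compiling) run; reporting the median alongside the mean gives a steadier number. A small helper (not part of the benchmark above):

```javascript
// Median of an array of timings; robust to the one-off warmup spike.
function median(times) {
  const sorted = [...times].sort((a, b) => a - b);
  const mid = Math.floor(sorted.length / 2);
  return sorted.length % 2 ? sorted[mid] : (sorted[mid - 1] + sorted[mid]) / 2;
}

// median([300, 16, 15, 17]) → 16.5, even though the mean is ~87
```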

You should see:

  • WebGPU: 15-30ms per inference (desktop), 30-60ms (mobile)
  • CPU fallback: 300-500ms per inference

If slow:

  • Check DevTools → Performance → GPU usage during inference
  • Ensure model loaded from cache (Network tab shows "(disk cache)" or "(memory cache)")
  • Try smaller model (MobileNetV3 is 5x faster, 2% less accurate)

Real-World Example: Image Classifier UI

<!DOCTYPE html>
<html>
<head>
  <title>WebGPU Image Classifier</title>
</head>
<body>
  <input type="file" id="imageInput" accept="image/*">
  <canvas id="preview" width="224" height="224"></canvas>
  <div id="results"></div>
  
  <script type="module">
    import * as ort from 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/esm/ort.min.js';
    
    let session = null;
    
    // Initialize model
    async function init() {
      ort.env.wasm.wasmPaths = 'https://cdn.jsdelivr.net/npm/onnxruntime-web@1.17.0/dist/';
      session = await ort.InferenceSession.create('./resnet50.onnx', {
        executionProviders: ['webgpu', 'wasm']
      });
      console.log('Ready');
    }
    
    // Handle image upload
    document.getElementById('imageInput').onchange = async (e) => {
      const file = e.target.files[0];
      const img = await loadImage(file);
      
      // Show preview
      const canvas = document.getElementById('preview');
      const ctx = canvas.getContext('2d');
      ctx.drawImage(img, 0, 0, 224, 224);
      
      // Classify
      const tensor = preprocessImage(img);
      const results = await session.run({ input: tensor });
      const predictions = getTopPredictions(results.output.data, 5);
      
      // Display results
      document.getElementById('results').innerHTML = predictions
        .map(p => `<div>${p.class}: ${(p.probability * 100).toFixed(1)}%</div>`)
        .join('');
    };
    
    function loadImage(file) {
      return new Promise((resolve) => {
        const img = new Image();
        img.onload = () => resolve(img);
        img.src = URL.createObjectURL(file);
      });
    }
    
    // (Include preprocessImage, getTopPredictions, and softmax from Step 4, plus an IMAGENET_CLASSES label list)
    
    init();
  </script>
</body>
</html>

Test it: Upload a photo, should see predictions in <50ms (WebGPU) or <500ms (CPU fallback)


What You Learned

  • WebGPU enables real-time ML inference in browsers (20-50x faster than CPU)
  • ONNX Runtime Web supports most PyTorch/TensorFlow models with minimal conversion
  • Always provide CPU fallback (WebGPU support still growing)
  • First inference is slow (shader compilation), cache sessions

Limitations:

  • Model size: Keep under 100MB for reasonable UX
  • Browser support: Chrome/Edge/Safari only, Firefox behind a flag
  • Memory: Large models can crash mobile browsers
  • Not for training: Inference only, use PyTorch/JAX for training

When NOT to use this:

  • LLMs over 1-2B parameters (use server or quantization)
  • Models requiring custom ops not in ONNX
  • Legacy browser support required

Production Checklist

  • Add loading spinner during model initialization
  • Show CPU fallback warning if WebGPU unavailable
  • Implement progressive loading for large models
  • Add error handling for OOM crashes
  • Test on mobile Safari and Chrome
  • Monitor bundle size (<500KB for ONNX Runtime)
  • Add Web Worker for non-blocking inference
  • Implement model caching with Service Worker

Resources

Performance:

  • Quantize to INT8: 4x smaller, 2-3x faster, <1% accuracy loss
  • Use MobileNet/EfficientNet for mobile
  • Split large models into chunks for progressive loading
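The chunking idea in the last tip can be as simple as issuing HTTP Range requests for fixed-size slices of the model file. A sketch of computing the byte ranges (the helper name is ours):

```javascript
// Byte ranges for fetching a model file in fixed-size chunks,
// suitable for HTTP Range requests (`Range: bytes=start-end`, inclusive).
function chunkRanges(totalBytes, chunkBytes) {
  const ranges = [];
  for (let start = 0; start < totalBytes; start += chunkBytes) {
    ranges.push([start, Math.min(start + chunkBytes, totalBytes) - 1]);
  }
  return ranges;
}

// e.g. chunkRanges(10, 4) → [[0, 3], [4, 7], [8, 9]]
```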

Tested on Chrome 121, Safari 18, ONNX Runtime Web 1.17, macOS M2 & Windows RTX 4070