Real-time Inference with Transformers: WebSocket Implementation Guide

Build real-time transformer inference with WebSocket streaming. Complete implementation guide with code examples for instant AI responses.

Waiting 30 seconds for your AI chatbot to respond? That's so 2022. Modern users expect instant, streaming responses that feel like real conversations. Here's how to build lightning-fast transformer inference with WebSocket streaming that delivers tokens as they're generated.

Why Real-time Transformer Inference Matters

Traditional HTTP-based inference creates a frustrating user experience. Users submit a prompt, stare at a loading spinner, then receive a wall of text all at once. WebSocket streaming transforms this into a natural conversation flow where responses appear word-by-word, just like ChatGPT.

Primary Benefits:

  • Instant feedback: Users see responses immediately
  • Better UX: Streaming text feels more interactive
  • Scalable connections: WebSockets handle thousands of concurrent users
  • Reduced perceived latency: Users engage with partial responses

Understanding WebSocket Streaming Architecture

WebSocket connections enable bidirectional communication between client and server. For transformer inference, the flow looks like this (the concrete message shapes follow the list):

  1. Client sends prompt via WebSocket
  2. Server generates tokens incrementally
  3. Each token streams back immediately
  4. Client displays tokens in real-time
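
The messages themselves are small JSON objects. These are the exact shapes the server code later in this guide sends and receives:

# JSON message shapes used throughout this guide (mirroring the server code below)
request = {"type": "generate", "prompt": "Hello!", "max_length": 100}  # client -> server
token   = {"type": "token", "content": " there", "finished": False}    # server -> client, one per token
done    = {"type": "complete", "content": "", "finished": True}        # server -> client, end of stream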

Key Components for Implementation

  • WebSocket Server: Handles connections and inference requests
  • Transformer Model: Generates tokens with streaming capability
  • Token Buffer: Manages partial responses
  • Client Handler: Processes incoming tokens
  • Error Management: Handles connection failures gracefully

Setting Up the WebSocket Server

Let's build a complete WebSocket server that streams transformer outputs:

import asyncio
import websockets
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

class TransformerWebSocketServer:
    def __init__(self, model_name="microsoft/DialoGPT-medium"):
        """Initialize transformer model and tokenizer for streaming inference"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference mode: disables dropout
        
        # Add padding token if missing
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
    
    async def generate_streaming_response(self, prompt, websocket, max_length=100):
        """Generate tokens and stream them via WebSocket"""
        try:
            # Tokenize input prompt
            inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            
            # Initialize generation parameters
            generated = inputs
            
            for _ in range(max_length):
                # Generate next token
                with torch.no_grad():
                    outputs = self.model(generated)
                    predictions = outputs.logits[0, -1, :]
                    next_token_id = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
                
                # Decode token to text
                token_text = self.tokenizer.decode(next_token_id, skip_special_tokens=True)
                
                # Stream token immediately
                await websocket.send(json.dumps({
                    "type": "token",
                    "content": token_text,
                    "finished": False
                }))
                
                # Add token to generated sequence
                generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
                
                # Check for end token
                if next_token_id.item() == self.tokenizer.eos_token_id:
                    break
                
                # Small delay to prevent overwhelming client
                await asyncio.sleep(0.05)
            
            # Send completion signal
            await websocket.send(json.dumps({
                "type": "complete",
                "content": "",
                "finished": True
            }))
            
        except Exception as e:
            self.logger.error(f"Generation error: {str(e)}")
            await websocket.send(json.dumps({
                "type": "error",
                "content": f"Generation failed: {str(e)}",
                "finished": True
            }))
    
    async def handle_client(self, websocket):  # recent versions of websockets pass only the connection (no path argument)
        """Handle incoming WebSocket connections"""
        self.logger.info(f"New client connected: {websocket.remote_address}")
        
        try:
            async for message in websocket:
                data = json.loads(message)
                
                if data.get("type") == "generate":
                    prompt = data.get("prompt", "")
                    max_length = data.get("max_length", 100)
                    
                    # Send acknowledgment
                    await websocket.send(json.dumps({
                        "type": "started",
                        "content": "Generation started",
                        "finished": False
                    }))
                    
                    # Start streaming generation
                    await self.generate_streaming_response(prompt, websocket, max_length)
                    
        except websockets.exceptions.ConnectionClosed:
            self.logger.info(f"Client disconnected: {websocket.remote_address}")
        except Exception as e:
            self.logger.error(f"Connection error: {str(e)}")
    
    async def _serve(self, host, port):
        """Run the server until externally stopped"""
        async with websockets.serve(self.handle_client, host, port):
            await asyncio.Future()  # run forever

    def start_server(self, host="localhost", port=8765):
        """Start the WebSocket server"""
        self.logger.info(f"Starting WebSocket server on {host}:{port}")
        asyncio.run(self._serve(host, port))

# Initialize and start server
if __name__ == "__main__":
    server = TransformerWebSocketServer()
    server.start_server()
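
One note on efficiency: the loop above re-runs the entire sequence through the model for every new token. GPT-2-family models (DialoGPT included) support a key-value cache so each step only processes the newest token. Below is a minimal sketch of the same loop with caching enabled, as a drop-in for the body of generate_streaming_response; Hugging Face's TextIteratorStreamer is a higher-level alternative when using model.generate:

# Sketch: incremental decoding with the KV cache (use_cache / past_key_values)
past_key_values = None
input_ids = inputs  # the full prompt, but only on the first step
for _ in range(max_length):
    with torch.no_grad():
        outputs = self.model(input_ids, past_key_values=past_key_values, use_cache=True)
    past_key_values = outputs.past_key_values
    next_token_id = torch.multinomial(torch.softmax(outputs.logits[0, -1, :], dim=-1), 1)
    # ... decode and stream the token exactly as before ...
    if next_token_id.item() == self.tokenizer.eos_token_id:
        break
    input_ids = next_token_id.unsqueeze(0)  # feed only the new token from here on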

Building the Client-Side Interface

Create a responsive web interface that handles streaming responses:

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Real-time Transformer Chat</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            background: #f5f5f5;
        }
        .chat-container {
            background: white;
            border-radius: 10px;
            padding: 20px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }
        .messages {
            height: 400px;
            overflow-y: auto;
            border: 1px solid #ddd;
            padding: 15px;
            margin-bottom: 20px;
            border-radius: 5px;
        }
        .message {
            margin-bottom: 15px;
            padding: 10px;
            border-radius: 5px;
        }
        .user-message {
            background: #007bff;
            color: white;
            margin-left: 20%;
        }
        .ai-message {
            background: #f8f9fa;
            border-left: 4px solid #28a745;
            margin-right: 20%;
        }
        .input-area {
            display: flex;
            gap: 10px;
        }
        input[type="text"] {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 5px;
            font-size: 16px;
        }
        button {
            padding: 10px 20px;
            background: #007bff;
            color: white;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background: #0056b3;
        }
        button:disabled {
            background: #6c757d;
            cursor: not-allowed;
        }
        .status {
            margin-top: 10px;
            padding: 5px;
            font-size: 14px;
            color: #666;
        }
        .typing-indicator {
            color: #28a745;
            font-style: italic;
        }
    </style>
</head>
<body>
    <div class="chat-container">
        <h1>Real-time Transformer Chat</h1>
        <div id="messages" class="messages"></div>
        <div class="input-area">
            <input type="text" id="prompt-input" placeholder="Enter your message..." />
            <button id="send-button" onclick="sendMessage()">Send</button>
        </div>
        <div id="status" class="status">Ready to chat</div>
    </div>

    <script>
        class TransformerWebSocketClient {
            constructor() {
                this.ws = null;
                this.isConnected = false;
                this.currentResponse = '';
                this.messagesContainer = document.getElementById('messages');
                this.statusElement = document.getElementById('status');
                this.sendButton = document.getElementById('send-button');
                this.promptInput = document.getElementById('prompt-input');
                
                this.connect();
                this.setupEventListeners();
            }
            
            connect() {
                this.updateStatus('Connecting to server...');
                this.ws = new WebSocket('ws://localhost:8765');
                
                this.ws.onopen = () => {
                    this.isConnected = true;
                    this.updateStatus('Connected - Ready to chat');
                    this.sendButton.disabled = false;
                };
                
                this.ws.onmessage = (event) => {
                    const data = JSON.parse(event.data);
                    this.handleMessage(data);
                };
                
                this.ws.onclose = () => {
                    this.isConnected = false;
                    this.updateStatus('Disconnected - Attempting to reconnect...');
                    this.sendButton.disabled = true;
                    // Reconnect after 3 seconds
                    setTimeout(() => this.connect(), 3000);
                };
                
                this.ws.onerror = (error) => {
                    console.error('WebSocket error:', error);
                    this.updateStatus('Connection error');
                };
            }
            
            handleMessage(data) {
                switch (data.type) {
                    case 'started':
                        this.currentResponse = '';
                        this.addMessage('', 'ai-message', 'ai-current');
                        this.updateStatus('AI is thinking...');
                        break;
                        
                    case 'token':
                        this.currentResponse += data.content;
                        this.updateCurrentMessage(this.currentResponse);
                        break;
                        
                    case 'complete':
                        this.finalizeCurrentMessage();
                        this.updateStatus('Ready to chat');
                        this.sendButton.disabled = false;
                        break;
                        
                    case 'error':
                        this.updateStatus(`Error: ${data.content}`);
                        this.sendButton.disabled = false;
                        break;
                }
            }
            
            sendMessage() {
                const prompt = this.promptInput.value.trim();
                if (!prompt || !this.isConnected) return;
                
                // Add user message
                this.addMessage(prompt, 'user-message');
                
                // Send to server
                this.ws.send(JSON.stringify({
                    type: 'generate',
                    prompt: prompt,
                    max_length: 100
                }));
                
                // Clear input and disable button
                this.promptInput.value = '';
                this.sendButton.disabled = true;
                this.updateStatus('Generating response...');
            }
            
            addMessage(content, className, id = null) {
                const messageDiv = document.createElement('div');
                messageDiv.className = `message ${className}`;
                if (id) messageDiv.id = id;
                messageDiv.textContent = content;
                this.messagesContainer.appendChild(messageDiv);
                this.scrollToBottom();
            }
            
            updateCurrentMessage(content) {
                const currentMsg = document.getElementById('ai-current');
                if (currentMsg) {
                    currentMsg.textContent = content;
                    this.scrollToBottom();
                }
            }
            
            finalizeCurrentMessage() {
                const currentMsg = document.getElementById('ai-current');
                if (currentMsg) {
                    currentMsg.id = '';
                }
            }
            
            updateStatus(message) {
                this.statusElement.textContent = message;
            }
            
            scrollToBottom() {
                this.messagesContainer.scrollTop = this.messagesContainer.scrollHeight;
            }
            
            setupEventListeners() {
                this.promptInput.addEventListener('keypress', (e) => {
                    if (e.key === 'Enter' && !e.shiftKey) {
                        e.preventDefault();
                        this.sendMessage();
                    }
                });
            }
        }
        
        // Initialize client when page loads
        let client;
        window.onload = () => {
            client = new TransformerWebSocketClient();
        };
        
        function sendMessage() {
            client.sendMessage();
        }
    </script>
</body>
</html>
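
Before wiring up the browser, it helps to smoke-test the server from a terminal. Here is a minimal command-line client (a sketch, assuming the server above is running on localhost:8765):

# Minimal smoke-test client for the server above
import asyncio
import json
import websockets

async def main():
    async with websockets.connect("ws://localhost:8765") as ws:
        await ws.send(json.dumps({"type": "generate", "prompt": "Hello!", "max_length": 50}))
        async for message in ws:
            data = json.loads(message)
            if data["type"] == "token":
                print(data["content"], end="", flush=True)  # print tokens as they stream
            elif data["type"] in ("complete", "error"):
                break

asyncio.run(main())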

Advanced WebSocket Optimizations

Token Batching for Better Performance

Instead of sending every single token, batch them for smoother streaming. Note that this variant emits a new "tokens" message type, so extend the client's handleMessage switch to append data.content for 'tokens' exactly as it does for 'token':

async def generate_streaming_response_batched(self, prompt, websocket, max_length=100, batch_size=3):
    """Generate tokens in batches for optimized streaming"""
    token_batch = []
    
    try:
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = inputs
        
        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(generated)
                predictions = outputs.logits[0, -1, :]
                next_token_id = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
            
            token_text = self.tokenizer.decode(next_token_id, skip_special_tokens=True)
            token_batch.append(token_text)
            
            # Send batch when full
            if len(token_batch) >= batch_size:
                await websocket.send(json.dumps({
                    "type": "tokens",
                    "content": "".join(token_batch),
                    "finished": False
                }))
                token_batch = []
            
            generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
            
            if next_token_id.item() == self.tokenizer.eos_token_id:
                break
        
        # Send remaining tokens
        if token_batch:
            await websocket.send(json.dumps({
                "type": "tokens",
                "content": "".join(token_batch),
                "finished": False
            }))
        
        # Send completion
        await websocket.send(json.dumps({
            "type": "complete",
            "content": "",
            "finished": True
        }))
        
    except Exception as e:
        await websocket.send(json.dumps({
            "type": "error",
            "content": f"Generation failed: {str(e)}",
            "finished": True
        }))

Connection Pool Management

Handle multiple concurrent connections efficiently:

import logging
import time  # connection timestamps below rely on these

class ConnectionManager:
    def __init__(self):
        self.active_connections = set()
        self.connection_stats = {}
    
    async def connect(self, websocket):
        """Register new WebSocket connection"""
        self.active_connections.add(websocket)
        self.connection_stats[websocket] = {
            "connected_at": time.time(),
            "messages_sent": 0,
            "tokens_generated": 0
        }
        logging.info(f"New connection: {len(self.active_connections)} active")
    
    async def disconnect(self, websocket):
        """Remove WebSocket connection"""
        self.active_connections.discard(websocket)
        if websocket in self.connection_stats:
            del self.connection_stats[websocket]
        logging.info(f"Connection closed: {len(self.active_connections)} active")
    
    async def broadcast(self, message):
        """Send message to all connected clients"""
        if self.active_connections:
            await asyncio.gather(
                *[ws.send(message) for ws in self.active_connections],
                return_exceptions=True
            )
    
    def get_stats(self):
        """Return connection statistics"""
        return {
            "active_connections": len(self.active_connections),
            "total_stats": self.connection_stats
        }
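
Here is a sketch of how the manager might plug into the connection handler (hypothetical wiring; the server class above does not use it out of the box):

# Hypothetical wiring: register each connection with the manager
manager = ConnectionManager()

async def handle_client(websocket):
    await manager.connect(websocket)
    try:
        async for message in websocket:
            ...  # dispatch "generate" requests as before
    finally:
        await manager.disconnect(websocket)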

Deployment Considerations

Production Server Configuration

Deploy your WebSocket server with proper process management:

# production_server.py
import asyncio
import json
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

import torch
import uvloop
import websockets

# Module name assumed here; import the TransformerWebSocketServer class defined earlier
from server import TransformerWebSocketServer

class ProductionTransformerServer(TransformerWebSocketServer):
    def __init__(self, model_name, workers=None):
        super().__init__(model_name)
        self.workers = workers or multiprocessing.cpu_count()
        self.executor = ThreadPoolExecutor(max_workers=self.workers)
    
    async def generate_streaming_response(self, prompt, websocket, max_length=100):
        """Run each forward pass in the thread pool so the event loop never blocks"""
        loop = asyncio.get_running_loop()
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = inputs

        for _ in range(max_length):
            # Offload the blocking model call; keep WebSocket I/O on the event loop
            next_token_id = await loop.run_in_executor(self.executor, self._next_token, generated)
            token_text = self.tokenizer.decode(next_token_id, skip_special_tokens=True)
            await websocket.send(json.dumps({"type": "token", "content": token_text, "finished": False}))
            generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
            if next_token_id.item() == self.tokenizer.eos_token_id:
                break

        await websocket.send(json.dumps({"type": "complete", "content": "", "finished": True}))

    def _next_token(self, generated):
        """Blocking single-step sampling; safe to run off the event loop"""
        with torch.no_grad():
            outputs = self.model(generated)
            return torch.multinomial(torch.softmax(outputs.logits[0, -1, :], dim=-1), 1)
    
    def start_production_server(self, host="0.0.0.0", port=8765):
        """Start the production server with uvloop"""
        asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
        asyncio.run(self._serve_production(host, port))

    async def _serve_production(self, host, port):
        async with websockets.serve(
            self.handle_client,
            host,
            port,
            max_size=10**7,    # 10 MB max message size
            ping_interval=20,  # keep connections alive
            ping_timeout=10,
            close_timeout=10,
        ):
            await asyncio.Future()  # run forever

if __name__ == "__main__":
    server = ProductionTransformerServer("microsoft/DialoGPT-medium")
    server.start_production_server()

Docker Deployment

Create a containerized deployment:

FROM python:3.9-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose WebSocket port
EXPOSE 8765

# Run production server
CMD ["python", "production_server.py"]
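
The requirements.txt referenced above needs at least the libraries imported in this guide (contents assumed from those imports; pin the versions you actually test against):

# requirements.txt
torch
transformers
websockets
uvloop

Build and run (the image name is arbitrary; --gpus all requires the NVIDIA Container Toolkit):

docker build -t transformer-ws .
docker run --gpus all -p 8765:8765 transformer-ws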

Load Balancing WebSocket Connections

Use nginx for WebSocket load balancing:

upstream websocket_backend {
    server localhost:8765;
    server localhost:8766;
    server localhost:8767;
}

server {
    listen 80;
    
    location / {
        proxy_pass http://websocket_backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        
        # WebSocket specific settings
        proxy_read_timeout 86400;
        proxy_send_timeout 86400;
    }
}
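
Because a WebSocket stays pinned to whichever backend accepted the handshake, open connections need no special affinity; if you want reconnecting clients to land on the same backend, add ip_hash; inside the upstream block.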

Performance Monitoring and Analytics

Real-time Metrics Collection

Track streaming performance with custom metrics:

import time
from collections import defaultdict
import json

class StreamingMetrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.current_sessions = {}
    
    def start_session(self, session_id):
        """Start tracking a new streaming session"""
        self.current_sessions[session_id] = {
            "start_time": time.time(),
            "tokens_sent": 0,
            "bytes_sent": 0,
            "client_ip": None
        }
    
    def record_token(self, session_id, token_text):
        """Record a token being sent"""
        if session_id in self.current_sessions:
            session = self.current_sessions[session_id]
            session["tokens_sent"] += 1
            session["bytes_sent"] += len(token_text.encode('utf-8'))
    
    def end_session(self, session_id):
        """Finish tracking a session"""
        if session_id in self.current_sessions:
            session = self.current_sessions[session_id]
            session["duration"] = time.time() - session["start_time"]
            session["tokens_per_second"] = session["tokens_sent"] / session["duration"]
            
            # Store completed session
            self.metrics["completed_sessions"].append(session)
            del self.current_sessions[session_id]
    
    def get_performance_report(self):
        """Generate performance analytics"""
        completed = self.metrics["completed_sessions"]
        if not completed:
            return {"error": "No completed sessions"}
        
        avg_duration = sum(s["duration"] for s in completed) / len(completed)
        avg_tokens_per_sec = sum(s["tokens_per_second"] for s in completed) / len(completed)
        total_tokens = sum(s["tokens_sent"] for s in completed)
        
        return {
            "total_sessions": len(completed),
            "active_sessions": len(self.current_sessions),
            "average_duration": round(avg_duration, 2),
            "average_tokens_per_second": round(avg_tokens_per_sec, 2),
            "total_tokens_generated": total_tokens,
            "average_tokens_per_session": round(total_tokens / len(completed), 2)
        }
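
Hooking the metrics into the streaming loop takes three calls (hypothetical wiring; any unique per-connection value works as the session key):

# Hypothetical wiring: record metrics around the token loop
metrics = StreamingMetrics()
session_id = id(websocket)  # any unique per-connection key

metrics.start_session(session_id)
# ...inside the generation loop, after each websocket.send(...):
metrics.record_token(session_id, token_text)
# ...after the completion message:
metrics.end_session(session_id)
print(metrics.get_performance_report())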

Troubleshooting Common Issues

Connection Handling Problems

Problem: Clients disconnect unexpectedly during long generations.

Solution: The websockets library already sends protocol-level pings (the ping_interval setting in the production config), but an application-level heartbeat lets you detect dead connections explicitly and gives the client something it can observe:

async def handle_client_with_heartbeat(self, websocket):
    """Handle client with heartbeat monitoring"""
    # Start the heartbeat before the receive loop so it is always bound in finally
    heartbeat_task = asyncio.create_task(self.heartbeat(websocket))
    try:
        
        async for message in websocket:
            data = json.loads(message)
            
            if data.get("type") == "pong":
                # Client responded to ping
                continue
            elif data.get("type") == "generate":
                # Handle generation request
                await self.generate_streaming_response(
                    data.get("prompt"), websocket, data.get("max_length", 100)
                )
    
    except websockets.exceptions.ConnectionClosed:
        logging.info("Client disconnected")
    finally:
        heartbeat_task.cancel()

async def heartbeat(self, websocket):
    """Send periodic ping messages"""
    while True:
        try:
            await websocket.send(json.dumps({"type": "ping"}))
            await asyncio.sleep(30)  # Ping every 30 seconds
        except websockets.exceptions.ConnectionClosed:
            break

Memory Management Issues

Problem: Memory usage grows during long inference sessions.

Solution: Implement token buffer limits and cleanup:

def generate_with_memory_management(self, prompt, max_length=100, buffer_size=50):
    """Generate tokens with memory management"""
    inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
    generated = inputs
    
    for step in range(max_length):
        # Trim context to a sliding window (note: once generation runs past
        # buffer_size tokens, this also drops the oldest prompt tokens)
        if generated.size(1) > buffer_size:
            generated = generated[:, -buffer_size:]
        
        # Generate next token
        with torch.no_grad():
            outputs = self.model(generated)
            next_token_id = torch.multinomial(
                torch.softmax(outputs.logits[0, -1, :], dim=-1), 1
            )
        
        # Yield token immediately
        yield self.tokenizer.decode(next_token_id, skip_special_tokens=True)
        
        generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
        
        if next_token_id.item() == self.tokenizer.eos_token_id:
            break
        
        # Clear the GPU cache periodically (guarded so CPU-only runs are unaffected)
        if step % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
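
The method is a synchronous generator, so a caller consumes it with a plain for loop; inside the async server, pair it with the thread-pool pattern from the production section so each blocking step stays off the event loop. A usage sketch:

# Usage sketch: consume the generator and print tokens as they arrive
server = TransformerWebSocketServer()
for token_text in server.generate_with_memory_management("Hello!", max_length=50):
    print(token_text, end="", flush=True)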

Conclusion

Real-time transformer inference with WebSocket streaming transforms static AI interactions into dynamic conversations. This implementation provides instant feedback, better user experience, and scalable architecture for production deployment.

The combination of WebSocket connections, streaming token generation, and optimized client handling creates responsive applications that feel natural and engaging. Users receive immediate feedback, while servers efficiently manage multiple concurrent connections.

Deploy this streaming transformer solution to build modern AI applications that meet user expectations for real-time responsiveness. The architecture scales from development prototypes to production systems handling thousands of simultaneous conversations.