Waiting 30 seconds for your AI chatbot to respond? That's so 2022. Modern users expect instant, streaming responses that feel like real conversations. Here's how to build lightning-fast transformer inference with WebSocket streaming that delivers tokens as they're generated.
Why Real-time Transformer Inference Matters
Traditional HTTP-based inference creates a frustrating user experience. Users submit a prompt, stare at a loading spinner, then receive a wall of text all at once. WebSocket streaming transforms this into a natural conversation flow where responses appear word-by-word, just like ChatGPT.
Primary Benefits:
- Instant feedback: Users see responses immediately
- Better UX: Streaming text feels more interactive
- Scalable connections: WebSockets handle thousands of concurrent users
- Reduced perceived latency: Users engage with partial responses
Understanding WebSocket Streaming Architecture
WebSocket connections enable bidirectional communication between client and server. For transformer inference, this means:
- Client sends prompt via WebSocket
- Server generates tokens incrementally
- Each token streams back immediately
- Client displays tokens in real-time
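Concretely, the wire protocol can be a handful of tiny JSON frames. A minimal sketch of the message shapes (field names mirror what the server code in this post sends):

```python
import json

# Client -> server: request a generation
request = json.dumps({"type": "generate", "prompt": "Hello!", "max_length": 100})

# Server -> client: a stream of small JSON frames
started = json.dumps({"type": "started", "content": "Generation started", "finished": False})
token = json.dumps({"type": "token", "content": " world", "finished": False})
complete = json.dumps({"type": "complete", "content": "", "finished": True})

# The client reassembles the response by concatenating "token" frames
frames = [json.loads(token)]
response = "".join(f["content"] for f in frames if f["type"] == "token")
print(response)  # " world"
```

Keeping every frame self-describing (`type` plus a `finished` flag) means the client never has to guess whether more data is coming.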
Key Components for Implementation
- WebSocket Server: Handles connections and inference requests
- Transformer Model: Generates tokens with streaming capability
- Token Buffer: Manages partial responses
- Client Handler: Processes incoming tokens
- Error Management: Handles connection failures gracefully
Setting Up the WebSocket Server
Let's build a complete WebSocket server that streams transformer outputs:
```python
import asyncio
import websockets
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import logging

class TransformerWebSocketServer:
    def __init__(self, model_name="microsoft/DialoGPT-medium"):
        """Initialize transformer model and tokenizer for streaming inference"""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        # Add padding token if missing
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    async def generate_streaming_response(self, prompt, websocket, max_length=100):
        """Generate tokens one at a time and stream them via WebSocket"""
        try:
            # Tokenize the input prompt
            inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            generated = inputs

            for _ in range(max_length):
                # Generate the next token (this re-runs the full sequence each
                # step; a KV cache would make long generations much faster)
                with torch.no_grad():
                    outputs = self.model(generated)
                predictions = outputs.logits[0, -1, :]
                next_token_id = torch.multinomial(torch.softmax(predictions, dim=-1), 1)

                # Decode the token to text and stream it immediately
                token_text = self.tokenizer.decode(next_token_id, skip_special_tokens=True)
                await websocket.send(json.dumps({
                    "type": "token",
                    "content": token_text,
                    "finished": False
                }))

                # Append the token to the generated sequence
                generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)

                # Stop at the end-of-sequence token
                if next_token_id.item() == self.tokenizer.eos_token_id:
                    break

                # Small delay to avoid overwhelming the client
                await asyncio.sleep(0.05)

            # Send completion signal
            await websocket.send(json.dumps({
                "type": "complete",
                "content": "",
                "finished": True
            }))
        except Exception as e:
            self.logger.error(f"Generation error: {e}")
            await websocket.send(json.dumps({
                "type": "error",
                "content": f"Generation failed: {e}",
                "finished": True
            }))

    async def handle_client(self, websocket, path=None):
        """Handle incoming WebSocket connections.

        Older versions of the websockets library pass a `path` argument to
        the handler; newer versions call it with the connection only, so
        `path` defaults to None for compatibility with both.
        """
        self.logger.info(f"New client connected: {websocket.remote_address}")
        try:
            async for message in websocket:
                data = json.loads(message)
                if data.get("type") == "generate":
                    prompt = data.get("prompt", "")
                    max_length = data.get("max_length", 100)

                    # Send acknowledgment
                    await websocket.send(json.dumps({
                        "type": "started",
                        "content": "Generation started",
                        "finished": False
                    }))

                    # Start streaming generation
                    await self.generate_streaming_response(prompt, websocket, max_length)
        except websockets.exceptions.ConnectionClosed:
            self.logger.info(f"Client disconnected: {websocket.remote_address}")
        except Exception as e:
            self.logger.error(f"Connection error: {e}")

    def start_server(self, host="localhost", port=8765):
        """Start the WebSocket server"""
        self.logger.info(f"Starting WebSocket server on {host}:{port}")

        async def main():
            async with websockets.serve(self.handle_client, host, port):
                await asyncio.Future()  # run forever

        # asyncio.run replaces the deprecated get_event_loop()/run_forever() pattern
        asyncio.run(main())

# Initialize and start server
if __name__ == "__main__":
    server = TransformerWebSocketServer()
    server.start_server()
```
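Before wiring up a browser UI, it helps to smoke-test the server from the command line. A minimal sketch of a Python client (assumes the server above is running on `ws://localhost:8765`):

```python
import asyncio
import json

def build_request(prompt, max_length=100):
    """Build the JSON request the server's handle_client expects"""
    return json.dumps({"type": "generate", "prompt": prompt, "max_length": max_length})

async def chat(prompt, uri="ws://localhost:8765"):
    """Send a prompt and print tokens as they stream in"""
    import websockets  # third-party dependency, imported lazily

    async with websockets.connect(uri) as ws:
        await ws.send(build_request(prompt))
        async for message in ws:
            data = json.loads(message)
            if data["type"] == "token":
                print(data["content"], end="", flush=True)
            elif data.get("finished"):  # "complete" or "error" ends the turn
                print()
                break

# Run against a live server, e.g.:
# asyncio.run(chat("Hello there!"))
```

Watching tokens appear one by one in the terminal confirms the streaming path works before any frontend code exists.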
Building the Client-Side Interface
Create a responsive web interface that handles streaming responses:
```html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Real-time Transformer Chat</title>
    <style>
        body {
            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
            background: #f5f5f5;
        }
        .chat-container {
            background: white;
            border-radius: 10px;
            padding: 20px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
        }
        .messages {
            height: 400px;
            overflow-y: auto;
            border: 1px solid #ddd;
            padding: 15px;
            margin-bottom: 20px;
            border-radius: 5px;
        }
        .message {
            margin-bottom: 15px;
            padding: 10px;
            border-radius: 5px;
        }
        .user-message {
            background: #007bff;
            color: white;
            margin-left: 20%;
        }
        .ai-message {
            background: #f8f9fa;
            border-left: 4px solid #28a745;
            margin-right: 20%;
        }
        .input-area {
            display: flex;
            gap: 10px;
        }
        input[type="text"] {
            flex: 1;
            padding: 10px;
            border: 1px solid #ddd;
            border-radius: 5px;
            font-size: 16px;
        }
        button {
            padding: 10px 20px;
            background: #007bff;
            color: white;
            border: none;
            border-radius: 5px;
            cursor: pointer;
            font-size: 16px;
        }
        button:hover {
            background: #0056b3;
        }
        button:disabled {
            background: #6c757d;
            cursor: not-allowed;
        }
        .status {
            margin-top: 10px;
            padding: 5px;
            font-size: 14px;
            color: #666;
        }
        .typing-indicator {
            color: #28a745;
            font-style: italic;
        }
    </style>
</head>
<body>
    <div class="chat-container">
        <h1>Real-time Transformer Chat</h1>
        <div id="messages" class="messages"></div>
        <div class="input-area">
            <input type="text" id="prompt-input" placeholder="Enter your message..." />
            <button id="send-button" onclick="sendMessage()">Send</button>
        </div>
        <div id="status" class="status">Ready to chat</div>
    </div>

    <script>
        class TransformerWebSocketClient {
            constructor() {
                this.ws = null;
                this.isConnected = false;
                this.currentResponse = '';
                this.messagesContainer = document.getElementById('messages');
                this.statusElement = document.getElementById('status');
                this.sendButton = document.getElementById('send-button');
                this.promptInput = document.getElementById('prompt-input');
                this.connect();
                this.setupEventListeners();
            }

            connect() {
                this.updateStatus('Connecting to server...');
                this.ws = new WebSocket('ws://localhost:8765');

                this.ws.onopen = () => {
                    this.isConnected = true;
                    this.updateStatus('Connected - Ready to chat');
                    this.sendButton.disabled = false;
                };

                this.ws.onmessage = (event) => {
                    const data = JSON.parse(event.data);
                    this.handleMessage(data);
                };

                this.ws.onclose = () => {
                    this.isConnected = false;
                    this.updateStatus('Disconnected - Attempting to reconnect...');
                    this.sendButton.disabled = true;
                    // Reconnect after 3 seconds
                    setTimeout(() => this.connect(), 3000);
                };

                this.ws.onerror = (error) => {
                    console.error('WebSocket error:', error);
                    this.updateStatus('Connection error');
                };
            }

            handleMessage(data) {
                switch (data.type) {
                    case 'started':
                        this.currentResponse = '';
                        this.addMessage('', 'ai-message', 'ai-current');
                        this.updateStatus('AI is thinking...');
                        break;
                    case 'token':
                    case 'tokens': // batched frames from the optimized server variant
                        this.currentResponse += data.content;
                        this.updateCurrentMessage(this.currentResponse);
                        break;
                    case 'complete':
                        this.finalizeCurrentMessage();
                        this.updateStatus('Ready to chat');
                        this.sendButton.disabled = false;
                        break;
                    case 'error':
                        this.finalizeCurrentMessage();
                        this.updateStatus(`Error: ${data.content}`);
                        this.sendButton.disabled = false;
                        break;
                }
            }

            sendMessage() {
                const prompt = this.promptInput.value.trim();
                if (!prompt || !this.isConnected) return;

                // Add the user message to the chat
                this.addMessage(prompt, 'user-message');

                // Send the prompt to the server
                this.ws.send(JSON.stringify({
                    type: 'generate',
                    prompt: prompt,
                    max_length: 100
                }));

                // Clear the input and disable the button while generating
                this.promptInput.value = '';
                this.sendButton.disabled = true;
                this.updateStatus('Generating response...');
            }

            addMessage(content, className, id = null) {
                const messageDiv = document.createElement('div');
                messageDiv.className = `message ${className}`;
                if (id) messageDiv.id = id;
                messageDiv.textContent = content;
                this.messagesContainer.appendChild(messageDiv);
                this.scrollToBottom();
            }

            updateCurrentMessage(content) {
                const currentMsg = document.getElementById('ai-current');
                if (currentMsg) {
                    currentMsg.textContent = content;
                    this.scrollToBottom();
                }
            }

            finalizeCurrentMessage() {
                const currentMsg = document.getElementById('ai-current');
                if (currentMsg) {
                    currentMsg.removeAttribute('id');
                }
            }

            updateStatus(message) {
                this.statusElement.textContent = message;
            }

            scrollToBottom() {
                this.messagesContainer.scrollTop = this.messagesContainer.scrollHeight;
            }

            setupEventListeners() {
                this.promptInput.addEventListener('keypress', (e) => {
                    if (e.key === 'Enter' && !e.shiftKey) {
                        e.preventDefault();
                        this.sendMessage();
                    }
                });
            }
        }

        // Initialize the client when the page loads
        let client;
        window.onload = () => {
            client = new TransformerWebSocketClient();
        };

        function sendMessage() {
            client.sendMessage();
        }
    </script>
</body>
</html>
```
Advanced WebSocket Optimizations
Token Batching for Better Performance
Instead of sending a WebSocket message for every single token, group a few tokens per message to cut per-message overhead:
```python
async def generate_streaming_response_batched(self, prompt, websocket,
                                              max_length=100, batch_size=3):
    """Generate tokens in batches for optimized streaming.

    Note: this sends "tokens" (plural) frames, so the client's message
    handler needs a case for that type alongside "token".
    """
    token_batch = []
    try:
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        generated = inputs

        for _ in range(max_length):
            with torch.no_grad():
                outputs = self.model(generated)
            predictions = outputs.logits[0, -1, :]
            next_token_id = torch.multinomial(torch.softmax(predictions, dim=-1), 1)
            token_text = self.tokenizer.decode(next_token_id, skip_special_tokens=True)
            token_batch.append(token_text)

            # Send the batch when it is full
            if len(token_batch) >= batch_size:
                await websocket.send(json.dumps({
                    "type": "tokens",
                    "content": "".join(token_batch),
                    "finished": False
                }))
                token_batch = []

            generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
            if next_token_id.item() == self.tokenizer.eos_token_id:
                break

        # Send any remaining tokens
        if token_batch:
            await websocket.send(json.dumps({
                "type": "tokens",
                "content": "".join(token_batch),
                "finished": False
            }))

        # Send completion
        await websocket.send(json.dumps({
            "type": "complete",
            "content": "",
            "finished": True
        }))
    except Exception as e:
        await websocket.send(json.dumps({
            "type": "error",
            "content": f"Generation failed: {e}",
            "finished": True
        }))
```
Connection Pool Management
Handle multiple concurrent connections efficiently:
```python
import asyncio
import logging
import time

class ConnectionManager:
    def __init__(self):
        self.active_connections = set()
        self.connection_stats = {}

    async def connect(self, websocket):
        """Register a new WebSocket connection"""
        self.active_connections.add(websocket)
        self.connection_stats[websocket] = {
            "connected_at": time.time(),
            "messages_sent": 0,
            "tokens_generated": 0
        }
        logging.info(f"New connection: {len(self.active_connections)} active")

    async def disconnect(self, websocket):
        """Remove a WebSocket connection"""
        self.active_connections.discard(websocket)
        self.connection_stats.pop(websocket, None)
        logging.info(f"Connection closed: {len(self.active_connections)} active")

    async def broadcast(self, message):
        """Send a message to all connected clients"""
        if self.active_connections:
            await asyncio.gather(
                *[ws.send(message) for ws in self.active_connections],
                return_exceptions=True  # one dead connection shouldn't break the rest
            )

    def get_stats(self):
        """Return connection statistics"""
        return {
            "active_connections": len(self.active_connections),
            "total_stats": self.connection_stats
        }
```
Deployment Considerations
Production Server Configuration
Deploy your WebSocket server with proper process management:
```python
# production_server.py
import asyncio
import json
import multiprocessing
from concurrent.futures import ThreadPoolExecutor

import torch
import uvloop
import websockets

class ProductionTransformerServer(TransformerWebSocketServer):
    def __init__(self, model_name, workers=None):
        super().__init__(model_name)
        self.workers = workers or multiprocessing.cpu_count()
        self.executor = ThreadPoolExecutor(max_workers=self.workers)

    def _next_token(self, generated):
        """Blocking forward pass; safe to run in a worker thread"""
        with torch.no_grad():
            outputs = self.model(generated)
        predictions = outputs.logits[0, -1, :]
        return torch.multinomial(torch.softmax(predictions, dim=-1), 1)

    async def generate_streaming_response(self, prompt, websocket, max_length=100):
        """Off-load each forward pass to the thread pool so the event loop
        stays responsive. Only the blocking model call runs in the executor;
        websocket.send is a coroutine and must be awaited on the event loop."""
        loop = asyncio.get_running_loop()
        try:
            inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
            generated = inputs
            for _ in range(max_length):
                next_token_id = await loop.run_in_executor(
                    self.executor, self._next_token, generated
                )
                token_text = self.tokenizer.decode(next_token_id, skip_special_tokens=True)
                await websocket.send(json.dumps({
                    "type": "token", "content": token_text, "finished": False
                }))
                generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
                if next_token_id.item() == self.tokenizer.eos_token_id:
                    break
            await websocket.send(json.dumps({
                "type": "complete", "content": "", "finished": True
            }))
        except Exception as e:
            await websocket.send(json.dumps({
                "type": "error", "content": f"Generation failed: {e}", "finished": True
            }))

    def start_production_server(self, host="0.0.0.0", port=8765):
        """Start the production server with uvloop as the event loop"""
        uvloop.install()

        async def main():
            async with websockets.serve(
                self.handle_client,
                host,
                port,
                max_size=10**7,     # 10 MB max message size
                ping_interval=20,   # keep connections alive
                ping_timeout=10,
                close_timeout=10,
            ):
                await asyncio.Future()  # run forever

        asyncio.run(main())

if __name__ == "__main__":
    server = ProductionTransformerServer("microsoft/DialoGPT-medium")
    server.start_production_server()
```
Docker Deployment
Create a containerized deployment:
```dockerfile
FROM python:3.9-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose WebSocket port
EXPOSE 8765

# Run production server
CMD ["python", "production_server.py"]
```
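The Dockerfile copies a requirements.txt that isn't shown above. A plausible version would list the packages used in this post (pin each to the exact versions you have tested against):

```text
torch
transformers
websockets
uvloop
```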
Load Balancing WebSocket Connections
Use nginx for WebSocket load balancing:
```nginx
upstream websocket_backend {
    server localhost:8765;
    server localhost:8766;
    server localhost:8767;
}

server {
    listen 80;

    location / {
        proxy_pass http://websocket_backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;

        # WebSocket specific settings
        proxy_read_timeout 86400;
        proxy_send_timeout 86400;
    }
}
```
Performance Monitoring and Analytics
Real-time Metrics Collection
Track streaming performance with custom metrics:
```python
import time
from collections import defaultdict

class StreamingMetrics:
    def __init__(self):
        self.metrics = defaultdict(list)
        self.current_sessions = {}

    def start_session(self, session_id):
        """Start tracking a new streaming session"""
        self.current_sessions[session_id] = {
            "start_time": time.time(),
            "tokens_sent": 0,
            "bytes_sent": 0,
            "client_ip": None
        }

    def record_token(self, session_id, token_text):
        """Record a token being sent"""
        if session_id in self.current_sessions:
            session = self.current_sessions[session_id]
            session["tokens_sent"] += 1
            session["bytes_sent"] += len(token_text.encode("utf-8"))

    def end_session(self, session_id):
        """Finish tracking a session"""
        if session_id in self.current_sessions:
            session = self.current_sessions.pop(session_id)
            session["duration"] = time.time() - session["start_time"]
            # Guard against division by zero for near-instant sessions
            session["tokens_per_second"] = session["tokens_sent"] / max(session["duration"], 1e-6)
            # Store the completed session
            self.metrics["completed_sessions"].append(session)

    def get_performance_report(self):
        """Generate performance analytics"""
        completed = self.metrics["completed_sessions"]
        if not completed:
            return {"error": "No completed sessions"}

        avg_duration = sum(s["duration"] for s in completed) / len(completed)
        avg_tokens_per_sec = sum(s["tokens_per_second"] for s in completed) / len(completed)
        total_tokens = sum(s["tokens_sent"] for s in completed)

        return {
            "total_sessions": len(completed),
            "active_sessions": len(self.current_sessions),
            "average_duration": round(avg_duration, 2),
            "average_tokens_per_second": round(avg_tokens_per_sec, 2),
            "total_tokens_generated": total_tokens,
            "average_tokens_per_session": round(total_tokens / len(completed), 2)
        }
```
Troubleshooting Common Issues
Connection Handling Problems
Problem: Clients disconnect unexpectedly during long generations.
Solution: Implement heartbeat pings and connection recovery:
```python
async def handle_client_with_heartbeat(self, websocket, path=None):
    """Handle a client with heartbeat monitoring"""
    # Start the heartbeat task before the try block so it is always
    # defined when the finally clause runs
    heartbeat_task = asyncio.create_task(self.heartbeat(websocket))
    try:
        async for message in websocket:
            data = json.loads(message)
            if data.get("type") == "pong":
                # Client responded to a ping
                continue
            elif data.get("type") == "generate":
                # Handle the generation request
                await self.generate_streaming_response(
                    data.get("prompt", ""), websocket, data.get("max_length", 100)
                )
    except websockets.exceptions.ConnectionClosed:
        logging.info("Client disconnected")
    finally:
        heartbeat_task.cancel()

async def heartbeat(self, websocket):
    """Send periodic ping messages"""
    while True:
        try:
            await websocket.send(json.dumps({"type": "ping"}))
            await asyncio.sleep(30)  # ping every 30 seconds
        except websockets.exceptions.ConnectionClosed:
            break
```
Memory Management Issues
Problem: Memory usage grows during long inference sessions.
Solution: Implement token buffer limits and cleanup:
```python
def generate_with_memory_management(self, prompt, max_length=100, buffer_size=50):
    """Generate tokens as a generator while capping the context window.

    Trimming the context to buffer_size tokens bounds memory use, at the
    cost of the model forgetting the oldest tokens.
    """
    inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
    generated = inputs

    for step in range(max_length):
        # Trim the context if it has grown too long
        if generated.size(1) > buffer_size:
            generated = generated[:, -buffer_size:]

        # Generate the next token
        with torch.no_grad():
            outputs = self.model(generated)
        next_token_id = torch.multinomial(
            torch.softmax(outputs.logits[0, -1, :], dim=-1), 1
        )

        # Yield the token immediately
        yield self.tokenizer.decode(next_token_id, skip_special_tokens=True)

        generated = torch.cat([generated, next_token_id.unsqueeze(0)], dim=-1)
        if next_token_id.item() == self.tokenizer.eos_token_id:
            break

        # Clear the GPU cache periodically (skip on CPU-only machines)
        if step % 10 == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
```
Conclusion
Real-time transformer inference with WebSocket streaming transforms static AI interactions into dynamic conversations. This implementation provides instant feedback, better user experience, and scalable architecture for production deployment.
The combination of WebSocket connections, streaming token generation, and optimized client handling creates responsive applications that feel natural and engaging. Users receive immediate feedback, while servers efficiently manage multiple concurrent connections.
Deploy this streaming transformer solution to build modern AI applications that meet user expectations for real-time responsiveness. The architecture scales from development prototypes to production systems handling thousands of simultaneous conversations.