Understand how federated learning enables collaborative AI training while keeping sensitive data on-device—and build a working example.
Problem: Training AI Requires Centralizing Sensitive Data
Your hospital, bank, or mobile app needs to improve AI models, but regulations prevent sending user data to central servers. Traditional machine learning demands aggregating all training data in one place—creating privacy risks and compliance headaches.
You'll learn:
- How federated learning trains models across distributed devices
- Why this solves privacy issues that centralized training can't
- How to implement a basic federated learning system
Time: 15 min | Level: Intermediate
Why This Matters
Traditional machine learning collects data from thousands of sources into a central database, trains a model, then deploys it. This creates problems:
Privacy risks:
- Medical records, financial data, or personal messages exposed in transit
- Single point of failure for data breaches
- Compliance violations (GDPR, HIPAA, CCPA)
Federated learning flips this: The model travels to the data instead. Each device trains on local data, sends only model updates (not raw data) to a central server, and the server aggregates these updates into a global model.
How Federated Learning Works
Think of it like a study group where students learn from different textbooks at home, then share only their notes—not the actual books.
The process:
- Server sends model → Devices receive current global model weights
- Local training → Each device trains on its private data for a few epochs
- Send updates → Devices send model weight updates (not raw data) back to the server
- Aggregation → Server averages updates using federated averaging (FedAvg)
- Repeat → Updated global model sent back to devices for next round
What's actually transmitted:
- Model parameters (floating point numbers representing neural network weights)
- Gradient updates (changes to those weights)
- Never raw data (no images, text, or personal information)
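The five steps above can be sketched end to end in a few lines of plain Python, with lists standing in for weight tensors. Note that `local_step` and `client_data` are made-up stand-ins for this illustration, not part of any framework:

```python
# Toy federated round trip: three clients, a "model" that is just two weights.
# local_step and client_data are illustrative stand-ins, not Flower's API.

def local_step(weights, data, lr=0.1):
    """Pretend local training: nudge each weight toward the client's data mean."""
    mean = sum(data) / len(data)
    return [w - lr * (w - mean) for w in weights]

global_weights = [0.0, 0.0]
client_data = {0: [1.0, 2.0], 1: [3.0], 2: [5.0, 6.0, 7.0]}  # stays "on device"

for round_num in range(3):
    updates = []
    for cid, data in client_data.items():
        local = local_step(list(global_weights), data)  # steps 1-2: broadcast + train
        updates.append((local, len(data)))              # step 3: ship weights, not data
    total = sum(n for _, n in updates)
    global_weights = [                                  # step 4: FedAvg aggregation
        sum(w[i] * n for w, n in updates) / total
        for i in range(len(global_weights))
    ]

print(global_weights)
```

Only `updates` (weights plus a sample count) ever crosses the client boundary; `client_data` never appears in anything the server touches.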
Real-World Examples
Google Gboard (2017): Your phone's keyboard learns your typing patterns locally. Model updates improve predictions for everyone without Google seeing what you type.
Hospitals collaborating on diagnosis: Five hospitals train a tumor detection model without sharing patient X-rays. Each hospital keeps data on-premises, shares only weight updates.
Financial fraud detection: Banks improve fraud models by learning from each other's transaction patterns without revealing customer data or proprietary detection rules.
Building a Federated Learning System
This example uses Flower (flwr), a modern federated learning framework that works with PyTorch and TensorFlow.
Step 1: Install Dependencies
```bash
# Preferably inside a virtual environment (python -m venv .venv)
pip install flwr torch torchvision
```

Why Flower: A production-ready framework used by researchers and companies; it handles the networking and aggregation plumbing so you don't have to implement the protocol from scratch.
Step 2: Create the Model and Data
```python
# model.py - Simple CNN for image classification
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Lightweight model suitable for edge devices
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def load_partition(client_id, num_clients=3):
    """Simulate data partitioning across clients."""
    from torchvision import datasets, transforms

    # Each client gets a subset of MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST('./data', train=True, download=True,
                             transform=transform)
    # Partition data - each client gets different samples
    partition_size = len(dataset) // num_clients
    start = client_id * partition_size
    end = start + partition_size
    return torch.utils.data.Subset(dataset, range(start, end))
```
Why this structure: Data stays partitioned—each client only accesses their slice, simulating real-world scenarios where devices never share data.
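That slicing arithmetic can be sanity-checked without downloading MNIST. Here is a standalone sketch of the same index math (`partition_indices` is a hypothetical helper written just for this check):

```python
def partition_indices(dataset_len, client_id, num_clients=3):
    """Same arithmetic as load_partition: contiguous, non-overlapping slices."""
    partition_size = dataset_len // num_clients
    start = client_id * partition_size
    return range(start, start + partition_size)

# The MNIST train split has 60,000 samples
parts = [set(partition_indices(60_000, cid)) for cid in range(3)]
assert all(len(p) == 20_000 for p in parts)              # equal shares
assert not (parts[0] & parts[1]) and not (parts[1] & parts[2])  # no overlap
```

In a real deployment the partitions are disjoint by construction (each device only has its own data); an even contiguous split is just the simplest way to simulate that on one machine.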
Step 3: Implement Federated Client
```python
# client.py - Runs on each device/hospital/user
import sys

import flwr as fl
import torch

from model import SimpleCNN, load_partition

class FederatedClient(fl.client.NumPyClient):
    def __init__(self, client_id):
        self.model = SimpleCNN()
        self.client_id = client_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        # Load this client's private data partition
        self.trainloader = torch.utils.data.DataLoader(
            load_partition(client_id), batch_size=32, shuffle=True
        )

    def get_parameters(self, config):
        """Send current model weights to the server."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Receive the updated global model from the server."""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train on local data - THIS DATA NEVER LEAVES THE DEVICE."""
        self.set_parameters(parameters)
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        criterion = torch.nn.CrossEntropyLoss()
        self.model.train()
        for epoch in range(1):  # Just 1 epoch per round
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
        # Return updated weights (not data) and training metrics
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the model on a local test set."""
        self.set_parameters(parameters)
        # Evaluation logic here (stubbed for brevity)
        return 0.0, len(self.trainloader.dataset), {"accuracy": 0.0}

if __name__ == "__main__":
    client_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    # Connect to the federated learning server
    fl.client.start_numpy_client(
        server_address="localhost:8080",
        client=FederatedClient(client_id),
    )
```
Key security feature: The fit() method trains on local data but only returns model parameters—raw data never enters the return statement.
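Stripped of torch, the get_parameters/set_parameters pair is just a key-order-preserving round trip over the state dict's values. A toy mirror of that logic, with plain lists standing in for tensors:

```python
from collections import OrderedDict

# Stand-in for a model state dict (plain lists instead of tensors)
state = OrderedDict([
    ("conv1.weight", [0.1, -0.2, 0.3]),
    ("conv1.bias",   [0.0]),
])

# get_parameters: values only, in insertion order - this is all the wire sees
payload = [v for v in state.values()]

# set_parameters: zip the locally known key order back onto received values
restored = OrderedDict(zip(state.keys(), payload))
assert restored == state
```

The round trip works only because both sides agree on the key order, which PyTorch's `state_dict()` guarantees (it is an ordered mapping).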
Step 4: Create Aggregation Server
```python
# server.py - Coordinates training, aggregates updates
import flwr as fl

def weighted_average(metrics):
    """Aggregate client metrics weighted by dataset size."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# Configure the federated learning strategy
strategy = fl.server.strategy.FedAvg(
    fraction_fit=1.0,        # Use 100% of available clients each round
    fraction_evaluate=1.0,
    min_fit_clients=3,       # Need at least 3 clients
    min_available_clients=3,
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Start the server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=5),  # 5 federated learning rounds
    strategy=strategy,
)
```
What FedAvg does: Takes weight updates from all clients, averages them proportionally to each client's dataset size, produces a new global model.
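This weighting is easy to verify by hand using the weighted_average function from server.py and two hypothetical clients:

```python
def weighted_average(metrics):
    """Same function as in server.py."""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}

# A client with 3x the data pulls the aggregate 3x as hard
result = weighted_average([(100, {"accuracy": 0.80}),
                           (300, {"accuracy": 0.90})])
assert abs(result["accuracy"] - 0.875) < 1e-9  # (100*0.8 + 300*0.9) / 400
```

A plain (unweighted) mean would report 0.85 and over-represent the small client; weighting by sample count keeps the aggregate honest.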
Step 5: Run the Federated System
Open four terminals:
```bash
# Terminal 1: Start the server
python server.py

# Terminals 2-4: Start 3 clients (each with a different data partition)
python client.py 0
python client.py 1
python client.py 2
```
Expected output (illustrative; exact log lines vary by Flower version, and the accuracy figure assumes you have filled in evaluate() — the stub above reports 0.0):

```
Server: Round 1/5 starting...
Client 0: Training on 20000 samples
Client 1: Training on 20000 samples
Client 2: Training on 20000 samples
Server: Round 1/5 complete, aggregated model accuracy: 87.3%
```
If it fails:
- "Address already in use": Port 8080 is taken; change it to 8081 in both server.py and client.py
- "Connection refused": Start the server before the clients
- Out of memory: Reduce the DataLoader batch size from 32 to 16
Verification
After 5 rounds, the global model should reach 90%+ accuracy on MNIST—comparable to centralized training—without any client sharing raw data.
Test the final model:
```python
# test_global.py
import torch
from torchvision import datasets, transforms

from model import SimpleCNN

model = SimpleCNN()
# Load weights from the server's final checkpoint.
# NOTE: the server.py above does not save one by default - you need to
# extend the strategy (e.g., override FedAvg.aggregate_fit) to write it out.
model.load_state_dict(torch.load('global_model.pt'))

testset = datasets.MNIST('./data', train=False, download=True,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.1307,), (0.3081,))
                         ]))
testloader = torch.utils.data.DataLoader(testset, batch_size=1000)

model.eval()
correct = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = model(images)
        correct += (outputs.argmax(1) == labels).sum().item()
print(f"Global model accuracy: {100 * correct / len(testset):.2f}%")
```
What You Learned
Core concept: Model travels to data, not the other way around. Only parameter updates cross the network.
Privacy guarantee: Clients never send raw data. Servers never see individual training samples.
Trade-offs to know:
- Slower than centralized training (network latency per round)
- Requires more rounds to converge (clients may have non-IID data)
- Communication costs can exceed computation costs on slow networks
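The communication cost is easy to estimate for the SimpleCNN used above. A back-of-envelope tally from its layer shapes, assuming float32 weights and one download plus one upload per round:

```python
# Parameter counts for SimpleCNN (weights + biases per layer)
conv1 = 1 * 32 * 3 * 3 + 32       # 320
conv2 = 32 * 64 * 3 * 3 + 64      # 18,496
fc1   = 64 * 7 * 7 * 128 + 128    # 401,536
fc2   = 128 * 10 + 10             # 1,290
total_params = conv1 + conv2 + fc1 + fc2

# float32 = 4 bytes; each round a client downloads and uploads the full model
bytes_per_round_per_client = total_params * 4 * 2
print(f"{total_params:,} params, "
      f"{bytes_per_round_per_client / 1e6:.1f} MB per client per round")
```

Even this small model moves a few megabytes per client per round; production systems routinely compress or quantize updates before sending them.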
When NOT to use this:
- Data is already centralized and non-sensitive
- Real-time training needed (federated adds latency)
- Devices too resource-constrained for local training
Production Considerations
Differential privacy: Add noise to gradients before sending to prevent membership inference attacks:
```python
# In client.py fit(), wrap the model before the training loop
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()
model, optimizer, trainloader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=trainloader,
    noise_multiplier=1.1,  # More noise = stronger privacy, slower convergence
    max_grad_norm=1.0,     # Clip per-sample gradients before adding noise
)
```
Secure aggregation: Encrypt updates so server can't see individual client contributions—only the aggregate. Use libraries like pysyft or TensorFlow Federated's secure aggregation.
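The core idea behind secure aggregation can be shown with pairwise additive masks that cancel in the sum. This is a toy sketch of the masking trick only; real protocols also handle client dropouts, use cryptographic key agreement, and mask full weight vectors rather than scalars:

```python
import random

# Each client's private update (a single scalar here for simplicity)
values = {0: 10.0, 1: 20.0, 2: 30.0}
clients = sorted(values)

# Each pair (i, j) with i < j agrees on a shared random mask;
# client i ADDS it to its update, client j SUBTRACTS it
masks = {(i, j): random.uniform(-100, 100)
         for i in clients for j in clients if i < j}

masked = {}
for c in clients:
    offset = sum(masks[(c, j)] for j in clients if j > c)
    offset -= sum(masks[(i, c)] for i in clients if i < c)
    masked[c] = values[c] + offset  # what the server sees: random-looking per client

# Pairwise masks cancel in the sum, so only the aggregate is recoverable
assert abs(sum(masked.values()) - sum(values.values())) < 1e-6
```

Each individual `masked[c]` tells the server nothing about `values[c]`, yet the sum over all clients is exact, which is all FedAvg needs.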
Client selection: Don't wait for all clients each round (some devices may be offline). FedAvg's fraction_fit parameter handles this.
Modern Alternatives and Tools
TensorFlow Federated (2024): Google's framework with built-in differential privacy and secure aggregation. Better for mobile deployments.
PySyft: Adds encrypted computation and secure multi-party computation on top of PyTorch.
NVIDIA FLARE (2023): Enterprise-grade framework for healthcare and financial services with HIPAA compliance features.
Apple/Google on-device learning: Both mobile platforms ship on-device training infrastructure (on-device Core ML model personalization on iOS; Android's Private Compute Core), bringing federated-style learning closer to app developers.