Train Federated AI Models Without Sharing Data in 15 Minutes

Build collaborative machine learning systems that keep sensitive data on-device using federated learning, Python, and the Flower framework.

Understand how federated learning enables collaborative AI training while keeping sensitive data on-device—and build a working example.


Problem: Training AI Requires Centralizing Sensitive Data

Your hospital, bank, or mobile app needs to improve its AI models, but regulations restrict sending user data to central servers. Traditional machine learning requires aggregating all training data in one place, which creates privacy risks and compliance headaches.

You'll learn:

  • How federated learning trains models across distributed devices
  • Why this solves privacy issues that centralized training can't
  • How to implement a basic federated learning system

Time: 15 min | Level: Intermediate


Why This Matters

Traditional machine learning collects data from thousands of sources into a central database, trains a model, then deploys it. This creates problems:

Privacy risks:

  • Medical records, financial data, or personal messages exposed in transit
  • Single point of failure for data breaches
  • Compliance violations (GDPR, HIPAA, CCPA)

Federated learning flips this: The model travels to the data instead. Each device trains on local data, sends only model updates (not raw data) to a central server, and the server aggregates these updates into a global model.


How Federated Learning Works

Think of it like a study group where students learn from different textbooks at home, then share only their notes—not the actual books.

The process:

  1. Server sends model → Devices receive current global model weights
  2. Local training → Each device trains on its private data for a few epochs
  3. Send updates → Devices send their updated model weights (not data) back to the server
  4. Aggregation → Server averages the weights with federated averaging (FedAvg), weighting each device by its number of training examples
  5. Repeat → Updated global model sent back to devices for next round
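The five steps can be seen end to end in a toy simulation. This sketch is not Flower code: the "model" is a single weight w for y = w·x, each of three hypothetical clients holds its own synthetic data, and only w ever moves between clients and server. All function names are illustrative.

```python
import random


def local_train(w, data, lr=0.01, epochs=5):
    """Step 2: a few epochs of plain SGD on one client's private (x, y) pairs."""
    for _ in range(epochs):
        for x, y in data:
            grad = 2 * (w * x - y) * x  # derivative of (w*x - y)**2 with respect to w
            w -= lr * grad
    return w


def fedavg(updates):
    """Step 4: average client weights, weighted by local dataset size."""
    total = sum(n for _, n in updates)
    return sum(w * n for w, n in updates) / total


def run_rounds(client_data, num_rounds=5):
    w_global = 0.0  # initial global model
    for _ in range(num_rounds):
        # Steps 1-3: send w_global out, train locally, collect (weight, size) pairs
        updates = [(local_train(w_global, data), len(data)) for data in client_data]
        # Steps 4-5: aggregate; the new global weight goes out next round
        w_global = fedavg(updates)
    return w_global


def make_client(rng, lo, n=20, true_w=3.0):
    """Private data for one client: x drawn from [lo, lo + 1], y = true_w * x + noise."""
    xs = [rng.uniform(lo, lo + 1.0) for _ in range(n)]
    return [(x, true_w * x + rng.gauss(0.0, 0.1)) for x in xs]


if __name__ == "__main__":
    rng = random.Random(0)
    clients = [make_client(rng, lo) for lo in (0.0, 1.0, 2.0)]
    print(f"learned w = {run_rounds(clients):.3f} (true w = 3.0)")
```

Each client only ever sees its own 20 points, yet after five rounds the averaged weight lands close to 3.0.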

What's actually transmitted:

  • Model parameters (floating point numbers representing neural network weights)
  • Gradient updates (changes to those weights)
  • Never raw data (no images, text, or personal information)
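To make "model parameters" concrete: the Flower client in Step 3 sends the values of state_dict() as a list of arrays, in the order the layers were registered. For the SimpleCNN defined in Step 2 that list has eight arrays. The sketch below writes their shapes out by hand (so it runs without PyTorch) and totals the payload, assuming float32 values.

```python
from math import prod

# SimpleCNN state_dict entries in registration order (pool has no parameters).
# Shapes are copied by hand from the layer definitions in model.py.
SIMPLE_CNN_SHAPES = [
    ("conv1.weight", (32, 1, 3, 3)),
    ("conv1.bias", (32,)),
    ("conv2.weight", (64, 32, 3, 3)),
    ("conv2.bias", (64,)),
    ("fc1.weight", (128, 64 * 7 * 7)),
    ("fc1.bias", (128,)),
    ("fc2.weight", (10, 128)),
    ("fc2.bias", (10,)),
]


def payload_size(shapes, bytes_per_value=4):
    """Return (number of values, bytes) for one copy of the model weights."""
    n = sum(prod(shape) for _, shape in shapes)
    return n, n * bytes_per_value


if __name__ == "__main__":
    n, nbytes = payload_size(SIMPLE_CNN_SHAPES)
    print(f"{n:,} float32 values = {nbytes / 1e6:.2f} MB per model copy")
```

It prints 421,642 values (about 1.69 MB), of which fc1 alone accounts for 401,536. Nothing in the payload is a pixel or a label.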

Real-World Examples

Google Gboard (2017): Your phone's keyboard learns your typing patterns locally. Model updates improve predictions for everyone without Google seeing what you type.

Hospitals collaborating on diagnosis: Five hospitals train a tumor detection model without sharing patient X-rays. Each hospital keeps data on-premises, shares only weight updates.

Financial fraud detection: Banks improve fraud models by learning from each other's transaction patterns without revealing customer data or proprietary detection rules.


Building a Federated Learning System

This example uses Flower (flwr), a modern federated learning framework that works with PyTorch and TensorFlow.

Step 1: Install Dependencies

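python -m venv .venv
source .venv/bin/activate    # on Windows: .venv\Scripts\activate
pip install flwr torch torchvision

Installing into a virtual environment avoids forcing packages into the system Python with --break-system-packages.

Why Flower: It handles client-server communication, client sampling, and aggregation, so you only write the local training and evaluation code. This tutorial launches the server and clients with Flower's long-standing fl.server / fl.client start functions; recent releases deprecate them in favor of the ServerApp/ClientApp workflow, so if they are missing from your install, pin a Flower 1.x release that still provides them.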


Step 2: Create the Model and Data

# model.py - Simple CNN for image classification
import torch
import torch.nn as nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        # Lightweight model suitable for edge devices
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)
        self.pool = nn.MaxPool2d(2, 2)
        
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

def load_partition(client_id, num_clients=3):
    """Simulate data partitioning across clients"""
    from torchvision import datasets, transforms
    
    # Each client gets a subset of MNIST
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    dataset = datasets.MNIST('./data', train=True, download=True, 
                            transform=transform)
    
    # Partition data - each client gets different samples
    partition_size = len(dataset) // num_clients
    start = client_id * partition_size
    end = start + partition_size
    
    return torch.utils.data.Subset(dataset, range(start, end))

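Why this structure: Each client loads only its own contiguous slice of the training set (20,000 images each with three clients), simulating devices that never share data. These slices are IID, with roughly the same label mix in each, which is easier than most real deployments.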
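Real clients rarely hold such evenly mixed data: one phone keyboard or one hospital sees a skewed slice. A common way to simulate that is a label-sorted shard partition: sort sample indices by label, cut them into shards, and deal a few shards to each client. The helper below is a hypothetical drop-in for the index range in load_partition, written in pure Python.

```python
import random


def label_skew_partition(labels, num_clients, shards_per_client=2, seed=0):
    """Hypothetical non-IID partitioner: return one list of sample indices per client."""
    order = sorted(range(len(labels)), key=lambda i: labels[i])  # group indices by label
    num_shards = num_clients * shards_per_client
    shard_size = len(order) // num_shards
    shards = [order[k * shard_size:(k + 1) * shard_size] for k in range(num_shards)]
    random.Random(seed).shuffle(shards)  # deal shards to clients at random
    return [
        [i for shard in shards[c * shards_per_client:(c + 1) * shards_per_client] for i in shard]
        for c in range(num_clients)
    ]
```

With torchvision's MNIST you could pass dataset.targets.tolist() as labels and wrap one client's indices in torch.utils.data.Subset(dataset, parts[client_id]). With two shards per client, each client sees only a couple of digits, so expect FedAvg to need more rounds.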


Step 3: Implement Federated Client

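# client.py - Runs on each device/hospital/user
import sys

import flwr as fl
import torch
from torchvision import datasets, transforms

from model import SimpleCNN, load_partition

NUM_CLIENTS = 3


def load_test_partition(client_id, num_clients=NUM_CLIENTS):
    """Give each client its own slice of the MNIST test set for local evaluation"""
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    testset = datasets.MNIST('./data', train=False, download=True,
                             transform=transform)
    size = len(testset) // num_clients
    start = client_id * size
    return torch.utils.data.Subset(testset, range(start, start + size))


class FederatedClient(fl.client.NumPyClient):
    def __init__(self, client_id):
        self.model = SimpleCNN()
        self.client_id = client_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Load this client's private training and test partitions
        self.trainloader = torch.utils.data.DataLoader(
            load_partition(client_id, NUM_CLIENTS), batch_size=32, shuffle=True
        )
        self.testloader = torch.utils.data.DataLoader(
            load_test_partition(client_id), batch_size=256
        )

    def get_parameters(self, config):
        """Send current model weights to server"""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters):
        """Receive updated global model from server"""
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v) for k, v in params_dict}
        self.model.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        """Train on local data - THIS DATA NEVER LEAVES THE DEVICE"""
        self.set_parameters(parameters)

        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
        criterion = torch.nn.CrossEntropyLoss()

        self.model.train()
        for epoch in range(1):  # Just 1 epoch per round
            for images, labels in self.trainloader:
                images, labels = images.to(self.device), labels.to(self.device)

                optimizer.zero_grad()
                outputs = self.model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        # Return updated weights (not data), local sample count, and metrics
        return self.get_parameters(config={}), len(self.trainloader.dataset), {}

    def evaluate(self, parameters, config):
        """Evaluate the global model on this client's local test slice"""
        self.set_parameters(parameters)
        criterion = torch.nn.CrossEntropyLoss(reduction="sum")

        self.model.eval()
        loss, correct = 0.0, 0
        with torch.no_grad():
            for images, labels in self.testloader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                loss += criterion(outputs, labels).item()
                correct += (outputs.argmax(1) == labels).sum().item()

        num_examples = len(self.testloader.dataset)
        return loss / num_examples, num_examples, {"accuracy": correct / num_examples}


if __name__ == "__main__":
    client_id = int(sys.argv[1]) if len(sys.argv) > 1 else 0

    # Connect to the federated learning server
    # (start_client + to_client() replaces the deprecated start_numpy_client)
    fl.client.start_client(
        server_address="localhost:8080",
        client=FederatedClient(client_id).to_client(),
    )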

Key privacy property: fit() trains on local data but returns only the model parameters, the number of local training examples, and a metrics dictionary. Raw data never appears in the return value.


Step 4: Create Aggregation Server

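# server.py - Coordinates training, aggregates updates, saves the global model
import flwr as fl
import torch

from model import SimpleCNN


def weighted_average(metrics):
    """Aggregate client metrics weighted by dataset size"""
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]
    return {"accuracy": sum(accuracies) / sum(examples)}


class SaveModelStrategy(fl.server.strategy.FedAvg):
    """FedAvg that also writes the aggregated weights to global_model.pt"""

    def aggregate_fit(self, server_round, results, failures):
        parameters, metrics = super().aggregate_fit(server_round, results, failures)
        if parameters is not None:
            # Map the aggregated arrays back onto SimpleCNN's state_dict keys (same order)
            ndarrays = fl.common.parameters_to_ndarrays(parameters)
            keys = SimpleCNN().state_dict().keys()
            state_dict = {k: torch.tensor(v) for k, v in zip(keys, ndarrays)}
            torch.save(state_dict, "global_model.pt")  # Overwritten each round
        return parameters, metrics


# Configure federated learning strategy
strategy = SaveModelStrategy(
    fraction_fit=1.0,  # Use 100% of available clients each round
    fraction_evaluate=1.0,
    min_fit_clients=3,  # Need at least 3 clients
    min_available_clients=3,
    evaluate_metrics_aggregation_fn=weighted_average,
)

# Start server
fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=5),  # 5 federated learning rounds
    strategy=strategy,
)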

What FedAvg does: Collects the updated weights from every participating client, averages them weighted by the number of training examples each client reported from fit(), and uses the result as the new global model.
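Written out, the new global weights are w = Σₖ nₖ·wₖ / Σₖ nₖ, where wₖ are client k's weights and nₖ is its example count. Below is a minimal pure-Python sketch of that weighted average over per-layer parameter lists. It follows the same idea as Flower's aggregation over NumPy arrays but is not Flower's code; layers are flat lists of floats in the same order on every client.

```python
def fedavg(updates):
    """updates: list of (layers, num_examples); layers is a list of flat float lists."""
    total = sum(n for _, n in updates)
    num_layers = len(updates[0][0])
    averaged = []
    for layer in range(num_layers):
        size = len(updates[0][0][layer])
        # Each coordinate: sum of n_k * w_k over clients, divided by total examples
        averaged.append([
            sum(params[layer][j] * n for params, n in updates) / total
            for j in range(size)
        ])
    return averaged


if __name__ == "__main__":
    client_a = ([[1.0, 2.0]], 100)  # one layer, 100 local examples
    client_b = ([[3.0, 6.0]], 300)  # same layer shape, 300 local examples
    print(fedavg([client_a, client_b]))
```

The client with three times as many examples pulls the result three times as hard: (1·100 + 3·300) / 400 = 2.5 for the first coordinate. weighted_average applies the same weighting to the reported accuracy.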


Step 5: Run the Federated System

Open four terminals:

# Terminal 1: Start server
python server.py

# Terminal 2-4: Start 3 clients (each with different data)
python client.py 0
python client.py 1
python client.py 2

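Expected output: for each of the 5 rounds, the server logs that it is configuring clients for fitting, aggregating their fit results, and then aggregating evaluation results. When training finishes it prints a summary of the per-round distributed loss and of the "accuracy" metric produced by weighted_average. The exact wording depends on your Flower version, so don't expect to match it line for line.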

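If it fails:

  • "Address already in use": another process holds port 8080; stop it, or change the port in both server.py and client.py
  • "Connection refused": start the server before the clients
  • Out of memory: reduce the DataLoader batch size from 32 to 16
  • Corrupted or incomplete MNIST files on the first run: all clients download MNIST into ./data at once, so let the first client finish downloading before starting the others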

Verification

After 5 rounds, the global model should typically exceed 90% accuracy on the MNIST test set, in the same range as centralized training, without any client sharing raw data.

Test the final model:

# test_global.py
from model import SimpleCNN
import torch
from torchvision import datasets, transforms

model = SimpleCNN()
# Load weights from server's final checkpoint
model.load_state_dict(torch.load('global_model.pt'))

testset = datasets.MNIST('./data', train=False, download=True,
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.1307,), (0.3081,))
                        ]))
testloader = torch.utils.data.DataLoader(testset, batch_size=1000)

model.eval()
correct = 0
with torch.no_grad():
    for images, labels in testloader:
        outputs = model(images)
        correct += (outputs.argmax(1) == labels).sum().item()

print(f"Global model accuracy: {100 * correct / len(testset):.2f}%")

What You Learned

Core concept: Model travels to data, not the other way around. Only parameter updates cross the network.

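Privacy property: Clients never send raw data, and the server never sees individual training samples. Model updates can still leak information about the data they were trained on, so sensitive deployments add the protections described under Production Considerations.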

Trade-offs to know:

  • Slower than centralized training (network latency per round)
  • Requires more rounds to converge (clients may have non-IID data)
  • Communication costs can exceed computation costs on slow networks
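The last point is easy to check with arithmetic. SimpleCNN has 421,642 parameters (conv1 320, conv2 18,496, fc1 401,536, fc2 1,290), so one float32 copy is about 1.69 MB. The sketch below assumes each client downloads and uploads one full copy per round, and ignores compression, protocol overhead, and the extra download Flower makes for federated evaluation, so it underestimates real traffic.

```python
def fl_traffic_bytes(num_params, num_clients, num_rounds, bytes_per_value=4):
    """Model traffic if every client downloads and uploads one full copy per round."""
    per_copy = num_params * bytes_per_value
    return per_copy * 2 * num_clients * num_rounds


if __name__ == "__main__":
    simple_cnn_params = 421_642          # total from the layer counts above
    raw_mnist_train = 60_000 * 28 * 28   # uint8 pixels only, labels ignored
    traffic = fl_traffic_bytes(simple_cnn_params, num_clients=3, num_rounds=5)
    print(f"federated traffic: {traffic / 1e6:.1f} MB vs raw MNIST images: {raw_mnist_train / 1e6:.1f} MB")
```

For this tutorial's 3 clients and 5 rounds that comes to about 50.6 MB of model traffic, already more than the 47.0 MB of raw MNIST training images. Larger models or more rounds widen the gap.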

When NOT to use this:

  • Data is already centralized and non-sensitive
  • Real-time training needed (federated adds latency)
  • Devices too resource-constrained for local training

Production Considerations

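Differential privacy: Clip each example's gradient and add calibrated Gaussian noise during local training (DP-SGD). This bounds how much any single record can influence the update a client sends, which limits membership inference attacks rather than ruling them out. Opacus implements DP-SGD for PyTorch: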

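# In client.py: wrap the model, optimizer, and data loader once (e.g., in __init__)
# rather than on every fit() call; fit() then trains with self.optimizer
from opacus import PrivacyEngine

self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)
self.privacy_engine = PrivacyEngine()
self.model, self.optimizer, self.trainloader = self.privacy_engine.make_private(
    module=self.model,
    optimizer=self.optimizer,
    data_loader=self.trainloader,
    noise_multiplier=1.1,  # Noise std relative to max_grad_norm (not the budget itself)
    max_grad_norm=1.0,     # Per-example gradient clipping bound
)

# Privacy budget actually spent so far, for a chosen delta:
# epsilon = self.privacy_engine.get_epsilon(delta=1e-5)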
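What make_private automates is roughly this: per-example gradients are clipped to an L2 norm of at most max_grad_norm, summed, perturbed with Gaussian noise whose standard deviation is noise_multiplier × max_grad_norm, and averaged. The pure-Python sketch below shows that single step for flat gradient lists. It is a simplification: Opacus computes per-example gradients for you, normalizes by the expected batch size under Poisson sampling, and tracks epsilon with a privacy accountant, none of which happens here.

```python
import math
import random


def clip(grad, max_norm):
    """Scale a gradient down so its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]


def dp_average_gradient(per_sample_grads, max_grad_norm=1.0, noise_multiplier=1.1, rng=None):
    """One simplified DP-SGD step: clip each example, sum, add Gaussian noise, average."""
    if rng is None:
        rng = random.Random()
    clipped = [clip(g, max_grad_norm) for g in per_sample_grads]
    dim = len(clipped[0])
    summed = [sum(g[j] for g in clipped) for j in range(dim)]
    sigma = noise_multiplier * max_grad_norm  # noise scale tied to the clipping bound
    noisy = [s + rng.gauss(0.0, sigma) for s in summed]
    return [v / len(per_sample_grads) for v in noisy]


if __name__ == "__main__":
    grads = [[3.0, 4.0], [0.3, 0.4]]  # first example gets clipped, second does not
    print(dp_average_gradient(grads, noise_multiplier=0.0))
```

With noise_multiplier=0 the output is just the average of the clipped gradients, which makes the clipping easy to check; raising noise_multiplier trades accuracy for a smaller epsilon.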

Secure aggregation: Cryptographically mask updates so the server learns only the aggregate, never an individual client's contribution. Flower ships SecAgg+ support, and PySyft and TensorFlow Federated offer their own secure aggregation tooling.
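The core idea behind protocols such as SecAgg and SecAgg+ is pairwise masking: every pair of clients agrees on a random mask, one adds it and the other subtracts it, so each masked update looks random to the server while the masks cancel in the sum. This toy version works on integers modulo 2³² (real protocols quantize float updates first) and draws every mask from one seeded RNG purely for illustration. That makes it insecure: real protocols derive masks from pairwise key agreement and also recover when clients drop out mid-round.

```python
import random

MOD = 2 ** 32  # updates are treated as integers modulo 2**32


def masked_updates(updates, seed=0):
    """Toy pairwise masking. NOT secure: one shared RNG stands in for key agreement."""
    rng = random.Random(seed)
    n, dim = len(updates), len(updates[0])
    masked = [list(u) for u in updates]
    for i in range(n):
        for j in range(i + 1, n):
            mask = [rng.randrange(MOD) for _ in range(dim)]
            for k in range(dim):
                masked[i][k] = (masked[i][k] + mask[k]) % MOD  # client i adds the pair mask
                masked[j][k] = (masked[j][k] - mask[k]) % MOD  # client j subtracts it
    return masked


def server_sum(masked):
    """The server only sums; pairwise masks cancel, leaving the true total."""
    return [sum(col) % MOD for col in zip(*masked)]


if __name__ == "__main__":
    updates = [[1, 2, 3], [10, 20, 30], [100, 200, 300]]
    print(server_sum(masked_updates(updates)))
```

The server recovers the per-coordinate totals without ever handling an unmasked client vector.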

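Client selection: Don't wait for every client each round, since some devices will be offline. FedAvg's fraction_fit parameter sets the share of connected clients sampled per round, while min_fit_clients and min_available_clients set the floor below which a round won't start.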
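A rough sketch of that selection logic, modeled on how FedAvg combines fraction_fit, min_fit_clients, and min_available_clients: once enough clients are connected, sample max(⌊available × fraction_fit⌋, min_fit_clients) of them. The function is hypothetical; in Flower, the server's client manager does the waiting and sampling for you.

```python
import random


def sample_clients(available, fraction_fit, min_fit_clients, min_available_clients, rng=None):
    """Hypothetical per-round sampler. Returns [] when too few clients are online."""
    if len(available) < min_available_clients:
        return []  # the real server would keep waiting for more clients here
    k = max(int(len(available) * fraction_fit), min_fit_clients)
    k = min(k, len(available))  # never ask for more clients than are connected
    if rng is None:
        rng = random.Random()
    return rng.sample(available, k)


if __name__ == "__main__":
    online = [f"client-{i}" for i in range(100)]
    print(len(sample_clients(online, 0.1, 3, 10)))
```

With 100 connected clients and fraction_fit=0.1, ten train per round; with 20 connected, the min_fit_clients floor of 3 takes over.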


Modern Alternatives and Tools

TensorFlow Federated: Google's open-source framework for federated computation, with built-in differential privacy and secure aggregation. It targets research and simulation rather than deployment to real devices.

PySyft: Adds encrypted computation and secure multi-party computation on top of PyTorch.

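NVIDIA FLARE: Enterprise-grade framework aimed at healthcare and financial services, with privacy and security features such as differential privacy filters and homomorphic encryption support. Compliance with regulations like HIPAA still depends on how you deploy it.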

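Apple/Google on-device learning: Both companies use federated learning in their own products, but platform support for third-party apps is limited and varies by OS release. Core ML's updatable models, for instance, support on-device personalization, not cross-device aggregation. Check current iOS and Android documentation before counting on a native API.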