How to Build Nyströmformer: Approximating Attention with Landmarks

Learn to build Nyströmformer for efficient attention approximation using landmarks. Reduce transformer complexity from O(n²) to O(n). Complete guide with code.

Imagine your transformer model taking forever to process long sequences. Your attention mechanism burns through computational resources like a gas-guzzling truck. Welcome to the O(n²) attention complexity nightmare that keeps ML engineers awake at night.

The Nyströmformer solves this problem by approximating attention using landmarks. This technique reduces computational complexity from O(n²) to O(n), making long sequence processing actually feasible.

This guide shows you how to build a complete Nyströmformer implementation. You'll learn the mathematical foundation, implement the core components, and optimize performance for real-world applications.

What is Nyströmformer and Why Use Landmarks?

Traditional self-attention computes relationships between every token pair. For a sequence of length n, this creates n² computations. A 1000-token sequence requires 1 million attention calculations.

Nyströmformer approximates the full attention matrix using landmark tokens. Instead of computing all pairwise relationships, it:

  • Selects a small subset of landmark tokens
  • Computes attention between landmarks and all tokens
  • Reconstructs the full attention matrix using the Nyström method

This approach maintains attention quality while dramatically reducing computational cost.

Key Benefits of Landmark-Based Attention

  • Linear complexity: O(n·m) for m landmarks, i.e. linear in sequence length n
  • Memory efficiency: no n × n attention matrix is ever materialized
  • Scalability: handles sequences of tens of thousands of tokens
  • Quality preservation: accuracy close to full self-attention in the original paper's experiments

Mathematical Foundation: The Nyström Method

The Nyström method approximates a matrix using a low-rank decomposition. For attention matrix A, we select m landmark tokens where m << n.

# Mathematical representation
# A ≈ A[:, landmarks] @ pinv(A[landmarks, landmarks]) @ A[landmarks, :]

The approximation quality depends on landmark selection and the intrinsic rank of the attention matrix.
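A quick way to see the method at work, independent of attention: reconstruct a low-rank matrix from a small set of its columns. The sketch below uses NumPy for clarity; the matrix, rank, and landmark count are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(0)
n, rank, m = 200, 8, 16

# A rank-8 surrogate for an attention-like matrix
U = rng.standard_normal((n, rank))
A = U @ U.T  # (n, n), symmetric, rank <= 8

# Pick m landmark indices (uniformly at random here)
landmarks = np.sort(rng.choice(n, size=m, replace=False))

C = A[:, landmarks]                   # (n, m) column slab
W = A[np.ix_(landmarks, landmarks)]   # (m, m) core block
A_hat = C @ np.linalg.pinv(W) @ C.T   # Nystrom reconstruction

# Relative error is essentially zero once m >= rank(A)
print(np.linalg.norm(A - A_hat) / np.linalg.norm(A))
```

When the matrix has low intrinsic rank and the landmarks span it, the reconstruction is exact up to floating-point error; attention matrices are not exactly low-rank, which is where the approximation error comes from.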

Landmark Selection Strategies

Three main approaches work for selecting landmarks:

  1. Uniform sampling: Select every k-th token
  2. Random sampling: Choose landmarks randomly
  3. Learned selection: Train a network to select optimal landmarks
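The first two strategies are one-liners; a short sketch of both (learned selection is implemented inside the attention module later in this guide):

```python
import torch

seq_len, num_landmarks = 1024, 64

# 1. Uniform sampling: every k-th token (first landmarks: 0, 16, 32, 48, ...)
uniform_idx = torch.arange(0, seq_len, seq_len // num_landmarks)[:num_landmarks]

# 2. Random sampling: sorted so downstream gathers stay in sequence order
random_idx = torch.randperm(seq_len)[:num_landmarks].sort().values
```

Uniform sampling is the cheapest and works surprisingly well in practice; random sampling can help when important tokens cluster unevenly.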

Core Implementation: Building the Nyströmformer

Let's implement the essential components step by step.

Step 1: Landmark Attention Module

import torch
import torch.nn as nn
import torch.nn.functional as F

class LandmarkAttention(nn.Module):
    def __init__(self, dim, num_landmarks=64, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.num_landmarks = num_landmarks
        self.head_dim = dim // heads
        
        # Query, Key, Value projections
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim)
        
        # Landmark selection (learnable)
        self.landmark_proj = nn.Linear(dim, 1)
        
    def select_landmarks(self, x):
        """Select landmark tokens using learned selection"""
        # x shape: (batch, seq_len, dim)
        batch_size, seq_len = x.shape[:2]
        
        # Compute landmark scores
        scores = self.landmark_proj(x).squeeze(-1)  # (batch, seq_len)
        
        # Select top-k landmarks. Note: hard top-k passes no gradient back to
        # landmark_proj; the original paper instead forms landmarks as segment means.
        _, indices = torch.topk(scores, self.num_landmarks, dim=-1)
        indices = indices.sort(dim=-1)[0]  # Sort for stability
        
        return indices
    
    def forward(self, x, mask=None):
        batch_size, seq_len, dim = x.shape
        
        # Project to Q, K, V
        q = self.q_proj(x).view(batch_size, seq_len, self.heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.heads, self.head_dim)
        
        # Select landmarks
        landmark_indices = self.select_landmarks(x)  # (batch, num_landmarks)
        gather_idx = landmark_indices.unsqueeze(-1).unsqueeze(-1).expand(
            -1, -1, self.heads, self.head_dim
        )
        
        # Extract landmark queries and keys (the full K and V are still needed)
        landmark_q = torch.gather(q, 1, gather_idx)  # (batch, num_landmarks, heads, head_dim)
        landmark_k = torch.gather(k, 1, gather_idx)  # (batch, num_landmarks, heads, head_dim)
        
        # Reshape to (batch, heads, tokens, head_dim) for attention computation
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        landmark_q = landmark_q.transpose(1, 2)  # (batch, heads, num_landmarks, head_dim)
        landmark_k = landmark_k.transpose(1, 2)  # (batch, heads, num_landmarks, head_dim)
        
        # Compute the Nyström-approximated attention
        return self.nystrom_attention(q, k, v, landmark_q, landmark_k)
    
    def nystrom_attention(self, q, k, v, landmark_q, landmark_k):
        """Compute the Nyström approximation of softmax attention:
        
        softmax(QK^T) ≈ softmax(Q K_lm^T) @ pinv(softmax(Q_lm K_lm^T)) @ softmax(Q_lm K^T)
        """
        batch_size, heads, seq_len, head_dim = q.shape
        
        # Scale factor for attention
        scale = head_dim ** -0.5
        
        # kernel_1: queries vs. landmark keys (batch, heads, seq_len, num_landmarks)
        kernel_1 = F.softmax(torch.einsum('bhid,bhjd->bhij', q, landmark_k) * scale, dim=-1)
        
        # kernel_2: landmark queries vs. landmark keys (batch, heads, m, m)
        kernel_2 = F.softmax(torch.einsum('bhid,bhjd->bhij', landmark_q, landmark_k) * scale, dim=-1)
        
        # kernel_3: landmark queries vs. all keys (batch, heads, m, seq_len)
        kernel_3 = F.softmax(torch.einsum('bhid,bhjd->bhij', landmark_q, k) * scale, dim=-1)
        
        # Multiply right-to-left so no n × n matrix is ever materialized: O(n·m) throughout
        output = kernel_1 @ torch.linalg.pinv(kernel_2) @ (kernel_3 @ v)
        
        # Reshape output back to (batch, seq_len, dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        
        return self.out_proj(output)
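To make the data flow concrete, here is a stripped-down, single-head functional version using uniformly spaced landmarks. It is a sketch for intuition, not a drop-in replacement for the module; the function name and landmark choice are illustrative.

```python
import torch
import torch.nn.functional as F

def nystrom_attention_1head(q, k, v, num_landmarks=8):
    """Single-head Nyström attention with uniformly spaced landmarks (sketch)."""
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    idx = torch.arange(0, seq_len, seq_len // num_landmarks)[:num_landmarks]
    lq, lk = q[idx], k[idx]  # landmark queries and keys, (m, d)

    kernel_1 = F.softmax(q @ lk.T * scale, dim=-1)   # (n, m)
    kernel_2 = F.softmax(lq @ lk.T * scale, dim=-1)  # (m, m)
    kernel_3 = F.softmax(lq @ k.T * scale, dim=-1)   # (m, n)

    # Right-to-left multiplication keeps every intermediate at O(n·m)
    return kernel_1 @ torch.linalg.pinv(kernel_2) @ (kernel_3 @ v)

out = nystrom_attention_1head(torch.randn(128, 32), torch.randn(128, 32),
                              torch.randn(128, 32))
print(out.shape)  # torch.Size([128, 32])
```

Note that the output stays (seq_len, head_dim) even though only three small kernel matrices are ever formed.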

Step 2: Complete Nyströmformer Block

class NystromformerBlock(nn.Module):
    def __init__(self, dim, num_landmarks=64, heads=8, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.attention = LandmarkAttention(dim, num_landmarks, heads)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        
        # MLP block
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x, mask=None):
        # Attention with residual connection
        attn_out = self.attention(self.norm1(x), mask)
        x = x + self.dropout(attn_out)
        
        # MLP with residual connection
        mlp_out = self.mlp(self.norm2(x))
        x = x + mlp_out
        
        return x

Step 3: Full Nyströmformer Model

class Nystromformer(nn.Module):
    def __init__(self, vocab_size, dim=512, depth=6, num_landmarks=64, 
                 heads=8, max_seq_len=4096, num_classes=None):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        
        # Token and position embeddings
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)
        
        # Transformer blocks
        self.blocks = nn.ModuleList([
            NystromformerBlock(dim, num_landmarks, heads)
            for _ in range(depth)
        ])
        
        # Output layers
        self.norm = nn.LayerNorm(dim)
        
        if num_classes is not None:
            self.classifier = nn.Linear(dim, num_classes)
        else:
            self.classifier = None
    
    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape
        
        # Create embeddings
        tokens = self.token_emb(input_ids)
        positions = self.pos_emb(torch.arange(seq_len, device=input_ids.device))
        x = tokens + positions
        
        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, attention_mask)
        
        x = self.norm(x)
        
        # Classification head (if provided)
        if self.classifier is not None:
            # Use [CLS] token (first token) for classification
            x = self.classifier(x[:, 0])
        
        return x

Advanced Optimization Techniques

Adaptive Landmark Selection

Improve landmark selection by making it context-aware:

class AdaptiveLandmarkSelector(nn.Module):
    def __init__(self, dim, num_landmarks=64):
        super().__init__()
        self.num_landmarks = num_landmarks
        self.query_proj = nn.Linear(dim, dim // 4)
        self.key_proj = nn.Linear(dim, dim // 4)
        self.score_proj = nn.Linear(dim // 4, 1)
        
    def forward(self, x):
        # Compute context-aware scores
        q = self.query_proj(x)  # (batch, seq_len, dim//4)
        k = self.key_proj(x)    # (batch, seq_len, dim//4)
        
        # Attention-based scoring
        scores = torch.einsum('bid,bjd->bij', q, k)  # (batch, seq_len, seq_len)
        context_scores = scores.mean(dim=-1)  # (batch, seq_len)
        
        # Combine with learned features
        feature_scores = self.score_proj(x).squeeze(-1)  # (batch, seq_len)
        final_scores = context_scores + feature_scores
        
        # Select landmarks
        _, indices = torch.topk(final_scores, self.num_landmarks, dim=-1)
        return indices.sort(dim=-1)[0]

Memory-Efficient Implementation

Optimize memory usage for very long sequences:

def efficient_nystrom_attention(q, k, v, num_landmarks=64, chunk_size=1024):
    """Memory-efficient Nyström attention for long sequences.

    Assumes a functional `standard_nystrom_attention(q, k, v, num_landmarks)`
    helper (e.g. a wrapper around the module above). Only the queries are
    chunked; keys and values stay global so landmarks cover the full sequence.
    """
    batch_size, heads, seq_len, head_dim = q.shape
    
    if seq_len <= chunk_size:
        # Use the standard implementation for short sequences
        return standard_nystrom_attention(q, k, v, num_landmarks)
    
    # Process in chunks
    output = torch.zeros_like(q)
    
    for i in range(0, seq_len, chunk_size):
        end_idx = min(i + chunk_size, seq_len)
        chunk_q = q[:, :, i:end_idx]
        
        # Compute attention for this chunk
        chunk_output = standard_nystrom_attention(
            chunk_q, k, v, num_landmarks
        )
        output[:, :, i:end_idx] = chunk_output
    
    return output

Performance Comparison and Benchmarks

Computational Complexity Analysis

Compare Nyströmformer with standard attention:

def analyze_complexity():
    """Rough wall-clock comparison: dense attention vs. Nyström factorization"""
    import time
    
    sequence_lengths = [512, 1024, 2048, 4096, 8192]
    dim = 512
    num_landmarks = 64
    scale = dim ** -0.5
    
    results = []
    
    for seq_len in sequence_lengths:
        x = torch.randn(1, seq_len, dim)
        
        # Dense O(n²) attention: materializes an n × n score matrix
        start_time = time.time()
        scores = F.softmax(torch.matmul(x, x.transpose(-1, -2)) * scale, dim=-1)
        _ = torch.matmul(scores, x)
        standard_time = time.time() - start_time
        
        # Nyström O(n·m): only n × m and m × m matrices are formed
        start_time = time.time()
        landmarks = x[:, ::seq_len // num_landmarks][:, :num_landmarks]
        kernel_1 = F.softmax(torch.matmul(x, landmarks.transpose(-1, -2)) * scale, dim=-1)
        kernel_2 = F.softmax(torch.matmul(landmarks, landmarks.transpose(-1, -2)) * scale, dim=-1)
        kernel_3 = F.softmax(torch.matmul(landmarks, x.transpose(-1, -2)) * scale, dim=-1)
        _ = kernel_1 @ torch.linalg.pinv(kernel_2) @ (kernel_3 @ x)
        nystrom_time = time.time() - start_time
        
        results.append({
            'seq_len': seq_len,
            'standard_time': standard_time,
            'nystrom_time': nystrom_time,
            'speedup': standard_time / nystrom_time
        })
    
    return results

# Illustrative trend (actual numbers are hardware- and backend-dependent):
# the dense path scales as O(n²) while the Nyström path scales as O(n·m),
# so the measured speedup should grow roughly linearly with sequence length.

Quality Preservation Metrics

Measure how well Nyströmformer preserves attention quality:

def evaluate_attention_quality(model_standard, model_nystrom, test_data):
    """Compare attention quality between models.

    Assumes both models expose a `get_attention_weights(batch)` helper
    that returns attention maps of matching shape.
    """
    import numpy as np
    
    quality_metrics = {
        'cosine_similarity': [],
        'mse_loss': [],
        'attention_entropy': []
    }
    
    for batch in test_data:
        # Get attention weights from both models
        with torch.no_grad():
            attn_standard = model_standard.get_attention_weights(batch)
            attn_nystrom = model_nystrom.get_attention_weights(batch)
            
            # Cosine similarity
            cos_sim = F.cosine_similarity(
                attn_standard.flatten(), 
                attn_nystrom.flatten(), 
                dim=0
            )
            quality_metrics['cosine_similarity'].append(cos_sim.item())
            
            # MSE loss
            mse = F.mse_loss(attn_standard, attn_nystrom)
            quality_metrics['mse_loss'].append(mse.item())
            
            # Attention entropy (diversity measure)
            entropy_std = -torch.sum(attn_standard * torch.log(attn_standard + 1e-9), dim=-1).mean()
            entropy_nys = -torch.sum(attn_nystrom * torch.log(attn_nystrom + 1e-9), dim=-1).mean()
            quality_metrics['attention_entropy'].append(abs(entropy_std - entropy_nys).item())
    
    return {k: np.mean(v) for k, v in quality_metrics.items()}

Training and Fine-tuning Best Practices

Learning Rate Scheduling

Use different learning rates for attention components:

def create_optimizer(model, base_lr=1e-4):
    """Create optimizer with component-specific learning rates"""
    
    param_groups = [
        {
            'params': [p for n, p in model.named_parameters() if 'landmark' in n],
            'lr': base_lr * 0.1,  # Lower LR for landmark selection
            'weight_decay': 1e-4
        },
        {
            'params': [p for n, p in model.named_parameters() if 'attention' in n and 'landmark' not in n],
            'lr': base_lr,
            'weight_decay': 1e-5
        },
        {
            'params': [p for n, p in model.named_parameters() if 'attention' not in n],
            'lr': base_lr * 2,  # Higher LR for other components
            'weight_decay': 1e-4
        }
    ]
    
    return torch.optim.AdamW(param_groups)

Gradual Sequence Length Scaling

Train with progressively longer sequences:

class ProgressiveTrainer:
    def __init__(self, model, start_len=512, max_len=4096, steps_per_stage=1000):
        self.model = model
        self.start_len = start_len
        self.max_len = max_len
        self.steps_per_stage = steps_per_stage
        self.current_step = 0
        
    def get_current_seq_len(self):
        """Calculate current sequence length based on training progress"""
        stage = self.current_step // self.steps_per_stage
        growth_factor = min(stage * 0.5, 3.0)  # Cap growth
        current_len = int(self.start_len * (2 ** growth_factor))
        return min(current_len, self.max_len)
    
    def train_step(self, batch):
        # Truncate batch to current sequence length
        # (assumes the wrapped model exposes a `training_step(batch)` method)
        seq_len = self.get_current_seq_len()
        truncated_batch = {k: v[:, :seq_len] for k, v in batch.items()}
        
        # Regular training step
        loss = self.model.training_step(truncated_batch)
        self.current_step += 1
        
        return loss
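It helps to sanity-check the schedule in isolation. The same growth rule as a standalone function, with the default hyperparameters from the class above:

```python
def progressive_seq_len(step, start_len=512, max_len=4096, steps_per_stage=1000):
    """Sequence length after `step` training steps under the progressive schedule."""
    stage = step // steps_per_stage
    growth_factor = min(stage * 0.5, 3.0)  # cap growth at 8x the starting length
    return min(int(start_len * 2 ** growth_factor), max_len)

lengths = [progressive_seq_len(s) for s in range(0, 7000, 1000)]
print(lengths)  # [512, 724, 1024, 1448, 2048, 2896, 4096]
```

The length doubles every two stages and saturates at `max_len`, so the model sees its longest sequences only after warming up on short ones.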

Real-World Applications and Use Cases

Document Processing Pipeline

Implement Nyströmformer for long document analysis:

class DocumentProcessor:
    def __init__(self, model_path, max_doc_length=16384):
        # Assumes a `load_pretrained` checkpoint helper on the model class and
        # a Hugging Face tokenizer saved alongside the weights
        from transformers import AutoTokenizer
        self.model = Nystromformer.load_pretrained(model_path)
        self.max_doc_length = max_doc_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
    
    def process_document(self, document_text):
        """Process long documents with overlapping windows"""
        
        # Tokenize document
        tokens = self.tokenizer.encode(document_text, return_tensors='pt')
        
        if tokens.shape[1] <= self.max_doc_length:
            # Single pass for short documents
            return self.model(tokens)
        
        # Sliding window approach for long documents
        window_size = self.max_doc_length
        overlap = window_size // 4
        stride = window_size - overlap
        
        outputs = []
        for start in range(0, tokens.shape[1], stride):
            end = min(start + window_size, tokens.shape[1])
            window_tokens = tokens[:, start:end]
            
            with torch.no_grad():
                window_output = self.model(window_tokens)
                outputs.append(window_output)
        
        # Merge overlapping outputs
        return self.merge_outputs(outputs, overlap)
    
    def merge_outputs(self, outputs, overlap):
        """Intelligently merge overlapping window outputs"""
        if len(outputs) == 1:
            return outputs[0]
        
        # Weighted averaging in overlap regions
        merged = outputs[0]
        for output in outputs[1:]:
            # `weighted_merge` is left abstract; its implementation
            # depends on the output format (logits, embeddings, etc.)
            merged = self.weighted_merge(merged, output, overlap)
        
        return merged
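The windowing arithmetic is easy to get wrong, so it is worth checking on its own. A standalone sketch (the helper name is hypothetical) that reproduces the overlap/stride logic above:

```python
def window_spans(total_len, window=1024, overlap_frac=0.25):
    """(start, end) index pairs for overlapping sliding windows."""
    overlap = int(window * overlap_frac)
    stride = window - overlap
    return [(s, min(s + window, total_len)) for s in range(0, total_len, stride)]

spans = window_spans(2500, window=1024)
print(spans)  # [(0, 1024), (768, 1792), (1536, 2500), (2304, 2500)]
```

Consecutive windows share `window // 4` tokens of context; note the loop can emit a short tail window, which the merge step has to tolerate.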

Code Analysis System

Use Nyströmformer for analyzing long source code files:

class CodeAnalyzer:
    def __init__(self, nystrom_model):
        # `CodeTokenizer`, `get_code_files`, and `analyze_code` are assumed
        # project-specific helpers; they sketch the shape of the pipeline
        self.model = nystrom_model
        self.code_tokenizer = CodeTokenizer()
    
    def analyze_repository(self, repo_path):
        """Analyze entire code repository"""
        
        results = {}
        for file_path in self.get_code_files(repo_path):
            with open(file_path, 'r') as f:
                code_content = f.read()
            
            # Tokenize code preserving structure
            tokens = self.code_tokenizer.tokenize(code_content)
            
            # Analyze with Nyströmformer
            analysis = self.model.analyze_code(tokens)
            results[file_path] = {
                'complexity_score': analysis.complexity,
                'function_boundaries': analysis.functions,
                'dependencies': analysis.imports,
                'potential_bugs': analysis.issues
            }
        
        return results

Troubleshooting Common Issues

Landmark Selection Problems

Fix poor landmark selection:

def debug_landmark_selection(model, input_data):
    """Debug and visualize landmark selection"""
    
    with torch.no_grad():
        # Get landmark indices
        landmark_indices = model.attention.select_landmarks(input_data)
        
        # Analyze distribution
        print("Landmark distribution:")
        print(f"Min index: {landmark_indices.min().item()}")
        print(f"Max index: {landmark_indices.max().item()}")
        print(f"Unique landmarks: {len(torch.unique(landmark_indices))}")
        
        # Check for clustering
        sorted_indices = torch.sort(landmark_indices)[0]
        gaps = sorted_indices[:, 1:] - sorted_indices[:, :-1]
        print(f"Average gap: {gaps.float().mean().item():.2f}")
        print(f"Max gap: {gaps.max().item()}")
        
        # Visualize selection (if sequence is short enough)
        if input_data.shape[1] <= 1000:
            import matplotlib.pyplot as plt
            plt.figure(figsize=(12, 4))
            plt.scatter(range(input_data.shape[1]), [0] * input_data.shape[1], alpha=0.3, label='All tokens')
            plt.scatter(landmark_indices[0].cpu(), [0] * len(landmark_indices[0]), 
                       color='red', s=50, label='Landmarks')
            plt.legend()
            plt.title('Landmark Selection Pattern')
            plt.show()

Memory Issues

Handle out-of-memory errors:

def memory_efficient_training(model, dataloader, optimizer, max_memory_mb=8000):
    """Train with memory monitoring and gradient accumulation"""
    
    accumulation_steps = 1
    
    for batch_idx, batch in enumerate(dataloader):
        # Monitor memory usage
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated() / 1024 / 1024  # MB
            
            if memory_used > max_memory_mb:
                # Increase gradient accumulation
                accumulation_steps = min(accumulation_steps * 2, 16)
                torch.cuda.empty_cache()
                print(f"Increased accumulation steps to {accumulation_steps}")
        
        # Forward pass with gradient accumulation
        loss = model(batch) / accumulation_steps
        loss.backward()
        
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
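Gradient accumulation itself can be verified on a toy model. A minimal self-contained loop (toy linear model and random data, SGD chosen only for illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

data, targets = torch.randn(8, 4), torch.randn(8, 1)

for i in range(8):
    loss = F.mse_loss(model(data[i:i + 1]), targets[i:i + 1])
    # Scale each micro-batch loss so the accumulated gradient matches a mean
    (loss / accumulation_steps).backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()       # one update per 4 micro-batches
        optimizer.zero_grad()
```

Each `backward()` adds into `.grad`, so four micro-batches of size 1 produce the same update as one batch of size 4, at a quarter of the peak activation memory.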

Conclusion

The Nyströmformer transforms how we handle long sequences in transformer models. By approximating attention with landmarks, you achieve linear complexity while preserving attention quality.

Key implementation points:

  • Landmark selection determines approximation quality
  • Memory optimization enables processing very long sequences
  • Progressive training improves convergence on long sequences
  • Quality metrics help validate attention preservation

This attention approximation technique makes transformer models practical for applications requiring long context understanding. The landmarks approach scales efficiently while maintaining the representational power that makes transformers effective.

Start with the basic implementation and gradually add optimizations based on your specific use case. The computational savings make Nyströmformer essential for production systems handling long sequences.

Ready to implement Nyströmformer in your project? Begin with sequences around 2048 tokens and scale up as you optimize the landmark selection strategy.