Imagine your transformer model taking forever to process long sequences. Your attention mechanism burns through computational resources like a gas-guzzling truck. Welcome to the O(n²) attention complexity nightmare that keeps ML engineers awake at night.
The Nyströmformer solves this problem by approximating attention using landmarks. This technique reduces computational complexity from O(n²) to O(n), making long sequence processing actually feasible.
This guide shows you how to build a complete Nyströmformer implementation. You'll learn the mathematical foundation, implement the core components, and optimize performance for real-world applications.
What is Nyströmformer and Why Use Landmarks?
Traditional self-attention computes relationships between every token pair. For a sequence of length n, this creates n² computations. A 1000-token sequence requires 1 million attention calculations.
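To make the quadratic-versus-linear contrast concrete, here is a back-of-the-envelope operation count. The cost model for landmark attention (two thin n×m score matrices plus an m×m correction term) is a rough sketch of the idea, not a measured benchmark:

```python
# Rough operation counts: full attention scales quadratically in sequence
# length, while landmark attention scales linearly (m landmarks is fixed).
def full_attention_ops(n: int) -> int:
    # One score per token pair.
    return n * n

def landmark_attention_ops(n: int, m: int = 64) -> int:
    # Two thin (n x m) score matrices plus a small (m x m) correction term.
    return 2 * n * m + m * m

for n in (1_000, 10_000, 100_000):
    print(f"n={n:>7}: full={full_attention_ops(n):>12,} "
          f"landmark={landmark_attention_ops(n):>10,}")
```

At n = 1,000 the full attention count is the 1 million calculations mentioned above; at n = 100,000 the landmark count is smaller by roughly three orders of magnitude.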
Nyströmformer approximates the full attention matrix using landmark tokens. Instead of computing all pairwise relationships, it:
- Selects a small subset of landmark tokens
- Computes attention between landmarks and all tokens
- Reconstructs the full attention matrix using the Nyström method
This approach maintains attention quality while dramatically reducing computational cost.
Key Benefits of Landmark-Based Attention
- Linear complexity: O(n) instead of O(n²)
- Memory efficiency: Significantly lower memory usage
- Scalability: Handles sequences of tens of thousands of tokens on a single GPU
- Quality preservation: Achieves accuracy comparable to full self-attention on standard long-sequence benchmarks
Mathematical Foundation: The Nyström Method
The Nyström method approximates a matrix using a low-rank decomposition. For attention matrix A, we select m landmark tokens where m << n.
# Mathematical representation
# A ≈ A[:, landmarks] @ pinv(A[landmarks, landmarks]) @ A[landmarks, :]
The approximation quality depends on landmark selection and the intrinsic rank of the attention matrix.
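You can see both factors at work in a small standalone experiment. For a symmetric positive semi-definite matrix whose intrinsic rank is below the number of landmarks, the Nyström reconstruction is essentially exact:

```python
import torch

torch.manual_seed(0)
n, rank, m = 256, 8, 32

# Build a symmetric PSD matrix with low intrinsic rank, mimicking an
# attention matrix that a small set of landmarks can summarize well.
U = torch.randn(n, rank, dtype=torch.float64)
A = U @ U.T

# Uniformly spaced landmark indices.
idx = torch.linspace(0, n - 1, m).long()
C = A[:, idx]        # (n, m) cross block: all rows vs. landmark columns
W = A[idx][:, idx]   # (m, m) landmark-landmark block

# Nystroem reconstruction: A ~= C @ pinv(W) @ C^T
A_hat = C @ torch.linalg.pinv(W) @ C.T

rel_err = torch.linalg.norm(A - A_hat) / torch.linalg.norm(A)
print(f"relative reconstruction error: {rel_err.item():.2e}")
```

With rank 8 and 32 landmarks the relative error is at numerical-precision level; raise the rank above the landmark count and the error grows, which is exactly the failure mode to watch for in practice.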
Landmark Selection Strategies
Three main approaches work for selecting landmarks:
- Uniform sampling: Select every k-th token
- Random sampling: Choose landmarks randomly
- Learned selection: Train a network to select optimal landmarks
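The first two strategies need no trainable parameters and can be sketched in a few lines (learned selection is implemented as part of the attention module below):

```python
import torch

def uniform_landmarks(seq_len: int, m: int) -> torch.Tensor:
    # Every k-th token: m evenly spaced indices across the sequence.
    return torch.linspace(0, seq_len - 1, m).long()

def random_landmarks(seq_len: int, m: int) -> torch.Tensor:
    # Sample m distinct positions; sort for stable downstream gathers.
    return torch.randperm(seq_len)[:m].sort().values

idx_u = uniform_landmarks(1024, 64)
idx_r = random_landmarks(1024, 64)
print(idx_u[:5], idx_r[:5])
```

Uniform sampling guarantees coverage of the whole sequence; random sampling avoids aliasing with periodic structure in the input.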
Core Implementation: Building the Nyströmformer
Let's implement the essential components step by step.
Step 1: Landmark Attention Module
import torch
import torch.nn as nn
import torch.nn.functional as F


class LandmarkAttention(nn.Module):
    def __init__(self, dim, num_landmarks=64, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.num_landmarks = num_landmarks
        self.head_dim = dim // heads

        # Query, Key, Value projections
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim)

        # Landmark selection (learnable scoring head)
        self.landmark_proj = nn.Linear(dim, 1)

    def select_landmarks(self, x):
        """Select landmark tokens using learned scores.

        Note: top-k index selection is not differentiable; the original
        Nystromformer paper instead uses segment means as landmarks, which
        is a simpler and fully differentiable alternative.
        """
        # x: (batch, seq_len, dim)
        scores = self.landmark_proj(x).squeeze(-1)  # (batch, seq_len)
        _, indices = torch.topk(scores, self.num_landmarks, dim=-1)
        return indices.sort(dim=-1).values  # sorted for stable gathers

    def forward(self, x, mask=None):
        batch_size, seq_len, dim = x.shape

        # Project to Q, K, V and split heads: (batch, heads, seq_len, head_dim)
        q = self.q_proj(x).view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)

        landmark_indices = self.select_landmarks(x)  # (batch, num_landmarks)
        return self.nystrom_attention(q, k, v, landmark_indices)

    def nystrom_attention(self, q, k, v, landmark_indices):
        """Nystrom approximation: softmax(QK^T) ~= F_mat @ pinv(A_mat) @ B_mat."""
        batch_size, heads, seq_len, head_dim = q.shape
        scale = head_dim ** -0.5

        # Gather landmark queries and keys: (batch, heads, num_landmarks, head_dim)
        idx = landmark_indices.unsqueeze(1).unsqueeze(-1).expand(-1, heads, -1, head_dim)
        landmark_q = torch.gather(q, 2, idx)
        landmark_k = torch.gather(k, 2, idx)

        # F = softmax(Q @ K_landmarks^T): (batch, heads, seq_len, num_landmarks)
        F_mat = F.softmax(torch.einsum('bhid,bhjd->bhij', q, landmark_k) * scale, dim=-1)
        # A = softmax(Q_landmarks @ K_landmarks^T): (batch, heads, m, m)
        A_mat = F.softmax(torch.einsum('bhid,bhjd->bhij', landmark_q, landmark_k) * scale, dim=-1)
        # B = softmax(Q_landmarks @ K^T): (batch, heads, m, seq_len)
        B_mat = F.softmax(torch.einsum('bhid,bhjd->bhij', landmark_q, k) * scale, dim=-1)

        # Fold V in right-to-left so the (seq_len x seq_len) matrix is
        # never materialized; every intermediate is O(n * m).
        output = F_mat @ (torch.linalg.pinv(A_mat) @ (B_mat @ v))

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.dim)
        return self.out_proj(output)
Step 2: Complete Nyströmformer Block
class NystromformerBlock(nn.Module):
    def __init__(self, dim, num_landmarks=64, heads=8, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.attention = LandmarkAttention(dim, num_landmarks, heads)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        mlp_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x, mask=None):
        # Pre-norm attention with residual connection
        x = x + self.dropout(self.attention(self.norm1(x), mask))
        # Pre-norm MLP with residual connection
        x = x + self.mlp(self.norm2(x))
        return x
Step 3: Full Nyströmformer Model
class Nystromformer(nn.Module):
    def __init__(self, vocab_size, dim=512, depth=6, num_landmarks=64,
                 heads=8, max_seq_len=4096, num_classes=None):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # Token and position embeddings
        self.token_emb = nn.Embedding(vocab_size, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            NystromformerBlock(dim, num_landmarks, heads)
            for _ in range(depth)
        ])

        # Output layers
        self.norm = nn.LayerNorm(dim)
        self.classifier = nn.Linear(dim, num_classes) if num_classes else None

    def forward(self, input_ids, attention_mask=None):
        batch_size, seq_len = input_ids.shape

        # Token + learned position embeddings
        positions = torch.arange(seq_len, device=input_ids.device)
        x = self.token_emb(input_ids) + self.pos_emb(positions)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, attention_mask)
        x = self.norm(x)

        # Classification head (if configured): use the first token as [CLS]
        if self.classifier is not None:
            x = self.classifier(x[:, 0])
        return x
Advanced Optimization Techniques
Adaptive Landmark Selection
Improve landmark selection by making it context-aware:
class AdaptiveLandmarkSelector(nn.Module):
    def __init__(self, dim, num_landmarks=64):
        super().__init__()
        self.num_landmarks = num_landmarks
        self.query_proj = nn.Linear(dim, dim // 4)
        self.key_proj = nn.Linear(dim, dim // 4)
        self.score_proj = nn.Linear(dim // 4, 1)

    def forward(self, x):
        # Context-aware scores via a lightweight attention pass.
        # Caution: this scoring is itself O(n^2); for very long sequences,
        # score against pooled keys instead of all tokens.
        q = self.query_proj(x)  # (batch, seq_len, dim//4)
        k = self.key_proj(x)    # (batch, seq_len, dim//4)

        scores = torch.einsum('bid,bjd->bij', q, k)  # (batch, seq_len, seq_len)
        context_scores = scores.mean(dim=-1)         # (batch, seq_len)

        # Combine with learned per-token features (scored on the projected q)
        feature_scores = self.score_proj(q).squeeze(-1)  # (batch, seq_len)
        final_scores = context_scores + feature_scores

        # Select landmarks
        _, indices = torch.topk(final_scores, self.num_landmarks, dim=-1)
        return indices.sort(dim=-1).values
Memory-Efficient Implementation
Optimize memory usage for very long sequences:
def efficient_nystrom_attention(q, k, v, num_landmarks=64, chunk_size=1024):
    """Memory-efficient Nystrom attention for long sequences.

    `standard_nystrom_attention` stands in for the nystrom_attention
    method above, refactored as a free function over (q, k, v).
    """
    batch_size, heads, seq_len, head_dim = q.shape

    if seq_len <= chunk_size:
        # Standard implementation suffices for short sequences
        return standard_nystrom_attention(q, k, v, num_landmarks)

    # Process query chunks against the full key/value set so peak
    # activation memory is bounded by chunk_size, not seq_len
    output = torch.zeros_like(q)
    for i in range(0, seq_len, chunk_size):
        end_idx = min(i + chunk_size, seq_len)
        output[:, :, i:end_idx] = standard_nystrom_attention(
            q[:, :, i:end_idx], k, v, num_landmarks
        )
    return output
Performance Comparison and Benchmarks
Computational Complexity Analysis
Compare Nyströmformer with standard attention:
def analyze_complexity():
    """Rough wall-clock comparison (illustrative, not a rigorous benchmark)."""
    import time

    sequence_lengths = [512, 1024, 2048, 4096, 8192]
    dim = 512
    num_landmarks = 64
    results = []

    for seq_len in sequence_lengths:
        x = torch.randn(1, seq_len, dim)

        # Standard attention: materializes the full (n x n) matrix
        start_time = time.time()
        attn_matrix = torch.randn(seq_len, seq_len)
        output = torch.matmul(attn_matrix, x)
        standard_time = time.time() - start_time

        # Nystrom attention (approximate_attention: your O(n) implementation)
        start_time = time.time()
        landmarks = torch.randn(1, num_landmarks, dim)
        nystrom_output = approximate_attention(x, landmarks)
        nystrom_time = time.time() - start_time

        results.append({
            'seq_len': seq_len,
            'standard_time': standard_time,
            'nystrom_time': nystrom_time,
            'speedup': standard_time / nystrom_time,
        })
    return results

# The speedup grows roughly linearly with sequence length: on the order
# of 2x at 512 tokens up to ~16x at 4096, depending on hardware.
Quality Preservation Metrics
Measure how well Nyströmformer preserves attention quality:
import numpy as np

def evaluate_attention_quality(model_standard, model_nystrom, test_data):
    """Compare attention quality between models.

    Assumes both models expose a get_attention_weights helper that
    returns the (approximated) attention maps for a batch.
    """
    quality_metrics = {
        'cosine_similarity': [],
        'mse_loss': [],
        'attention_entropy': [],
    }

    for batch in test_data:
        # Get attention weights from both models
        with torch.no_grad():
            attn_standard = model_standard.get_attention_weights(batch)
            attn_nystrom = model_nystrom.get_attention_weights(batch)

        # Cosine similarity between flattened attention maps
        cos_sim = F.cosine_similarity(
            attn_standard.flatten(), attn_nystrom.flatten(), dim=0
        )
        quality_metrics['cosine_similarity'].append(cos_sim.item())

        # MSE between attention maps
        quality_metrics['mse_loss'].append(
            F.mse_loss(attn_standard, attn_nystrom).item()
        )

        # Entropy gap: how much attention "diversity" the approximation loses
        entropy_std = -torch.sum(attn_standard * torch.log(attn_standard + 1e-9), dim=-1).mean()
        entropy_nys = -torch.sum(attn_nystrom * torch.log(attn_nystrom + 1e-9), dim=-1).mean()
        quality_metrics['attention_entropy'].append(abs(entropy_std - entropy_nys).item())

    return {k: np.mean(v) for k, v in quality_metrics.items()}
Training and Fine-tuning Best Practices
Learning Rate Scheduling
Use different learning rates for attention components:
def create_optimizer(model, base_lr=1e-4):
    """Create optimizer with component-specific learning rates."""
    param_groups = [
        {   # Lower LR for landmark selection: it is sensitive to noisy updates
            'params': [p for n, p in model.named_parameters() if 'landmark' in n],
            'lr': base_lr * 0.1,
            'weight_decay': 1e-4,
        },
        {   # Base LR for the rest of the attention stack
            'params': [p for n, p in model.named_parameters()
                       if 'attention' in n and 'landmark' not in n],
            'lr': base_lr,
            'weight_decay': 1e-5,
        },
        {   # Higher LR for embeddings, MLPs, and output layers
            'params': [p for n, p in model.named_parameters() if 'attention' not in n],
            'lr': base_lr * 2,
            'weight_decay': 1e-4,
        },
    ]
    return torch.optim.AdamW(param_groups)
Gradual Sequence Length Scaling
Train with progressively longer sequences:
class ProgressiveTrainer:
    def __init__(self, model, start_len=512, max_len=4096, steps_per_stage=1000):
        self.model = model
        self.start_len = start_len
        self.max_len = max_len
        self.steps_per_stage = steps_per_stage
        self.current_step = 0

    def get_current_seq_len(self):
        """Calculate current sequence length based on training progress."""
        stage = self.current_step // self.steps_per_stage
        growth_factor = min(stage * 0.5, 3.0)  # cap growth at 2^3 = 8x
        current_len = int(self.start_len * (2 ** growth_factor))
        return min(current_len, self.max_len)

    def train_step(self, batch):
        # Truncate the batch to the current sequence length
        seq_len = self.get_current_seq_len()
        truncated_batch = {k: v[:, :seq_len] for k, v in batch.items()}

        # Regular training step
        loss = self.model.training_step(truncated_batch)
        self.current_step += 1
        return loss
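As a sanity check, the schedule logic can be reproduced as a standalone function (same constants as the class above) and printed for a few step counts:

```python
def scheduled_seq_len(step: int, start_len: int = 512, max_len: int = 4096,
                      steps_per_stage: int = 1000) -> int:
    # Mirrors ProgressiveTrainer.get_current_seq_len: the length doubles
    # every two stages, capped at 8x growth and at max_len.
    stage = step // steps_per_stage
    growth_factor = min(stage * 0.5, 3.0)
    return min(int(start_len * (2 ** growth_factor)), max_len)

for step in (0, 1000, 2000, 4000, 6000):
    print(step, scheduled_seq_len(step))
```

With the defaults, training starts at 512 tokens, reaches 1024 after 2,000 steps, 2048 after 4,000, and plateaus at the 4096-token cap from step 6,000 on.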
Real-World Applications and Use Cases
Document Processing Pipeline
Implement Nyströmformer for long document analysis:
class DocumentProcessor:
    def __init__(self, model_path, max_doc_length=16384):
        # load_pretrained and AutoTokenizer are placeholders for your own
        # checkpoint-loading and tokenization utilities
        self.model = Nystromformer.load_pretrained(model_path)
        self.max_doc_length = max_doc_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

    def process_document(self, document_text):
        """Process long documents with overlapping windows."""
        tokens = self.tokenizer.encode(document_text, return_tensors='pt')

        if tokens.shape[1] <= self.max_doc_length:
            # Single pass for short documents
            return self.model(tokens)

        # Sliding window for long documents: 25% overlap between windows
        window_size = self.max_doc_length
        overlap = window_size // 4
        stride = window_size - overlap

        outputs = []
        for start in range(0, tokens.shape[1], stride):
            end = min(start + window_size, tokens.shape[1])
            with torch.no_grad():
                outputs.append(self.model(tokens[:, start:end]))

        # Merge overlapping outputs
        return self.merge_outputs(outputs, overlap)

    def merge_outputs(self, outputs, overlap):
        """Merge overlapping window outputs."""
        if len(outputs) == 1:
            return outputs[0]

        # Weighted averaging in the overlap regions; weighted_merge is a
        # placeholder whose implementation depends on the output format
        merged = outputs[0]
        for output in outputs[1:]:
            merged = self.weighted_merge(merged, output, overlap)
        return merged
Code Analysis System
Use Nyströmformer for analyzing long source code files:
class CodeAnalyzer:
    def __init__(self, nystrom_model):
        self.model = nystrom_model
        self.code_tokenizer = CodeTokenizer()  # placeholder: any code tokenizer

    def analyze_repository(self, repo_path):
        """Analyze every source file in a code repository."""
        results = {}
        for file_path in self.get_code_files(repo_path):
            with open(file_path, 'r') as f:
                code_content = f.read()

            # Tokenize code, preserving structure
            tokens = self.code_tokenizer.tokenize(code_content)

            # analyze_code is a task-specific head you add on top of the model
            analysis = self.model.analyze_code(tokens)
            results[file_path] = {
                'complexity_score': analysis.complexity,
                'function_boundaries': analysis.functions,
                'dependencies': analysis.imports,
                'potential_bugs': analysis.issues,
            }
        return results
Troubleshooting Common Issues
Landmark Selection Problems
Fix poor landmark selection:
def debug_landmark_selection(model, input_data):
    """Debug and visualize landmark selection."""
    with torch.no_grad():
        landmark_indices = model.attention.select_landmarks(input_data)

    # Analyze distribution
    print("Landmark distribution:")
    print(f"Min index: {landmark_indices.min().item()}")
    print(f"Max index: {landmark_indices.max().item()}")
    print(f"Unique landmarks: {len(torch.unique(landmark_indices))}")

    # Check for clustering: large gaps mean whole regions have no landmark
    sorted_indices = torch.sort(landmark_indices)[0]
    gaps = sorted_indices[:, 1:] - sorted_indices[:, :-1]
    print(f"Average gap: {gaps.float().mean().item():.2f}")
    print(f"Max gap: {gaps.max().item()}")

    # Visualize selection (if the sequence is short enough to plot)
    if input_data.shape[1] <= 1000:
        import matplotlib.pyplot as plt
        plt.figure(figsize=(12, 4))
        plt.scatter(range(input_data.shape[1]), [0] * input_data.shape[1],
                    alpha=0.3, label='All tokens')
        plt.scatter(landmark_indices[0].cpu(), [0] * len(landmark_indices[0]),
                    color='red', s=50, label='Landmarks')
        plt.legend()
        plt.title('Landmark Selection Pattern')
        plt.show()
Memory Issues
Handle out-of-memory errors:
def memory_efficient_training(model, dataloader, optimizer, max_memory_mb=8000):
    """Train with memory monitoring and adaptive gradient accumulation."""
    accumulation_steps = 1

    for batch_idx, batch in enumerate(dataloader):
        # Monitor GPU memory and back off when usage climbs too high
        if torch.cuda.is_available():
            memory_used = torch.cuda.memory_allocated() / 1024 / 1024  # MB
            if memory_used > max_memory_mb:
                accumulation_steps = min(accumulation_steps * 2, 16)
                torch.cuda.empty_cache()
                print(f"Increased accumulation steps to {accumulation_steps}")

        # Scale the loss so accumulated gradients average correctly
        # (assumes the model's forward returns the loss for a batch)
        loss = model(batch) / accumulation_steps
        loss.backward()

        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
Conclusion
The Nyströmformer transforms how we handle long sequences in transformer models. By approximating attention with landmarks, you achieve linear complexity while preserving attention quality.
Key implementation points:
- Landmark selection determines approximation quality
- Memory optimization enables processing very long sequences
- Progressive training improves convergence on long sequences
- Quality metrics help validate attention preservation
This attention approximation technique makes transformer models practical for applications requiring long context understanding. The landmarks approach scales efficiently while maintaining the representational power that makes transformers effective.
Start with the basic implementation and gradually add optimizations based on your specific use case. The computational savings make Nyströmformer essential for production systems handling long sequences.
Ready to implement Nyströmformer in your project? Begin with sequences around 2048 tokens and scale up as you optimize the landmark selection strategy.