Your training loss hits NaN in epoch 2. Or it converges to 65% accuracy and stops improving. Or it hits 98% train and 55% val. Each has a different fix. You’re not just fighting for better metrics; you’re fighting for your sanity, your GPU budget, and the will to not rm -rf the entire project directory. The problem isn't that your model is broken—it's that you're reading the wrong signals. Let's translate the screams of your training log into an exact repair manual.
The Three Horsemen of the Training Apocalypse
Every training failure falls into one of three camps, and misdiagnosis wastes weeks.
- NaN Loss: This isn't a bug; it's a scream. Your model's internal state has become a mathematical impossibility: division by zero, log of a non-positive number, or values that overflow float32 precision. It's a catastrophic, immediate failure.
- Underfitting: Your model is lazy. Training loss is high, validation loss is high, and the curves look like two flat, sad lines holding hands. The model hasn't learned the signal, often because it can't (insufficient capacity) or won't (excessive regularization, poor data).
- Overfitting: Your model is a cheater. It has memorized the training set's noise and quirks. Training loss plunges while validation loss stagnates or rises: the dreaded divergence. This is the most common failure mode, especially with modern architectures like Transformers, which underpin 94% of top-performing models on 15 major benchmarks (Papers with Code 2025). They have the capacity to memorize your entire dataset if you let them.
NaN Loss: Dissecting the Instant Crash
When your loss becomes nan, stop everything. Do not pass go. Do not collect 200 epochs. The debug loop is: inspect, hypothesize, fix with a tiny experiment.
First, rule out the data. A corrupted image file or a text sample with weird Unicode can cause silent havoc. Use a sanity-check script that iterates through your DataLoader and validates every batch. Next, check your loss function inputs. Are there any zeros where you're taking a log? For classification, ensure your model's final softmax isn't producing pure zeros due to extreme logits.
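A minimal sketch of such a sanity-check script, assuming a standard PyTorch DataLoader yielding (inputs, labels) pairs; the magnitude threshold is illustrative, not a standard value:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def sanity_check_loader(loader: DataLoader) -> None:
    """Iterate the whole loader once and fail fast on bad batches."""
    for i, (x, y) in enumerate(loader):
        if torch.isnan(x).any() or torch.isinf(x).any():
            raise ValueError(f"Batch {i}: NaN/Inf in inputs")
        if x.abs().max() > 1e4:
            print(f"Batch {i}: suspiciously large input {x.abs().max():.1f}")
        if y.min() < 0:
            raise ValueError(f"Batch {i}: negative label {y.min()}")

# Usage with dummy data; swap in your real DataLoader
ds = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 10, (32,)))
sanity_check_loader(DataLoader(ds, batch_size=8))
```

Run this once before any training run; it is cheap insurance against hours of silent corruption.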
The most common culprits, however, are the learning rate and gradients. A learning rate that's too high causes the optimizer to take such a massive step that it catapults the parameters into a numerical abyss. This is especially lethal in the early, unstable phase of training. The fix is often a warmup. Don't just drop your LR; schedule it.
```python
import math

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5
):
    def lr_lambda(current_step):
        # Linear warmup from 0 up to the base LR
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay from the base LR down to 0
        progress = float(current_step - num_warmup_steps) / float(
            max(1, num_training_steps - num_warmup_steps)
        )
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda)

# In your LightningModule
def configure_optimizers(self):
    optimizer = AdamW(self.parameters(), lr=2e-5, weight_decay=1e-4)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=100,  # 100 steps of linear warmup
        num_training_steps=self.trainer.estimated_stepping_batches,
    )
    return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
```
If the learning rate isn't the villain, you likely have exploding gradients. This is classic in deep networks and vanilla RNNs. The gradients grow exponentially through successive layers, turning your weight update into a numerical grenade.
Your NaN Prevention Toolkit: Clipping and Initialization
When you suspect exploding gradients, your first line of defense is gradient clipping. This forcibly rescales the gradient vector if its norm exceeds a threshold, preventing the catastrophic update.
```python
# In PyTorch Lightning, it's a one-liner in the Trainer
trainer = pl.Trainer(
    gradient_clip_val=1.0,  # clip gradient norm to 1.0
    gradient_clip_algorithm="norm",
)

# For a more manual approach in vanilla PyTorch
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```
Pair this with proper weight initialization. A bad initialization can doom training from step zero. For ReLU networks, use He initialization (kaiming_normal_). For Transformers, the default initializations in libraries like Hugging Face are usually correct, but if you're building from scratch, pay attention.
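If you are building from scratch, applying He initialization is a few lines with `model.apply`; the `nn.Sequential` model here is a stand-in for your own architecture:

```python
import torch.nn as nn

def init_weights_he(module: nn.Module) -> None:
    """He (Kaiming) initialization for ReLU networks; zero the biases."""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
        if module.bias is not None:
            nn.init.zeros_(module.bias)

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 10))
model.apply(init_weights_he)  # applies recursively to every submodule
```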
Real Error Fix: Exploding loss in Transformer training — fix: reduce learning rate to 1e-5, add warmup steps (10% of total), clip gradient norm to 1.0.
Finally, normalize your inputs. For vision models, ensure pixel values are scaled to a sane range (e.g., [-1, 1] or [0, 1]). For language models, watch out for out-of-range token IDs. Batch normalization, cited in 88% of competitive CV architectures as of 2025 (arXiv survey), was introduced to reduce internal covariate shift, but it's not a silver bullet for bad input data.
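Input normalization for images is two lines of arithmetic; a sketch using the standard ImageNet channel statistics (substitute your own dataset's mean/std if you train from scratch):

```python
import torch

# ImageNet channel statistics, shaped for broadcasting over (N, 3, H, W)
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

def normalize_images(batch_uint8: torch.Tensor) -> torch.Tensor:
    """Map uint8 pixels in [0, 255] to standardized floats near zero."""
    x = batch_uint8.float() / 255.0  # scale to [0, 1]
    return (x - MEAN) / STD          # roughly zero-mean, unit-variance per channel

x = normalize_images(torch.randint(0, 256, (4, 3, 32, 32), dtype=torch.uint8))
```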
Underfitting: When Your Model Just Won't Learn
Underfitting means high bias: your model is too simple for the problem. The first check is model capacity. Are you using a tiny CNN for a 1000-class image task? Consider swapping ResNet-50 (76.1% ImageNet top-1) for EfficientNet-B4 (82.9% top-1), which is actually smaller (19M parameters vs 25M). More capacity isn't always more parameters.
| Architecture | ImageNet Top-1 | Parameters | Inference Speed (V100) |
|---|---|---|---|
| ResNet-50 | 76.1% | 25.6M | ~1.2k img/sec |
| EfficientNet-B4 | 82.9% | 19.3M | ~850 img/sec |
| ViT-Base/16 | 81.8%* | 86.6M | ~650 img/sec |
*ViT requires large-scale pre-training to outperform CNNs.
If capacity seems fine, you might be over-regularizing. Are you using Dropout(0.8) on a small model? Is your weight decay (AdamW's weight_decay) set to an aggressive 1e-2? Dial it back. Start with 1e-4 for Transformers, 5e-4 for CNNs.
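When dialing weight decay back, a common convention (not the only valid one) is to apply it only to weight matrices and exempt biases and normalization parameters; a sketch of the param-group split:

```python
import torch.nn as nn
from torch.optim import AdamW

def param_groups(model: nn.Module, weight_decay: float = 1e-4):
    """Decay weight matrices only; skip biases and norm-layer parameters."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # 1-D params are biases or LayerNorm/BatchNorm scales: no decay
        if p.ndim == 1 or name.endswith(".bias"):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
opt = AdamW(param_groups(model, weight_decay=1e-4), lr=2e-5)
```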
The most painful source of underfitting is bad or insufficient data. Transfer learning reduces required training data by 10–100x vs training from scratch (DeepMind survey 2025). If you're training from scratch on 10,000 images, you're likely underfitting. Use a pre-trained model from timm or Hugging Face.
Real Error Fix: Training plateaus after epoch 5 — fix: use CosineAnnealingLR with T_max=total_epochs, check if learning rate is too high or dataset has label noise.
Overfitting: The Art of Controlled Forgetting
Overfitting is high variance. Your model has too much freedom. Your toolkit here is about constraining that freedom intelligently.
- Data Augmentation: Artificially expand your dataset with label-preserving transformations. For images, use Albumentations. For text, use synonym replacement or back-translation.
- Dropout: The classic. It randomly zeros activations during training, forcing the network to learn redundant representations. The rate is critical: 0.3–0.5 for dense layers, often lower (0.1) for attention layers in Transformers.
- Weight Decay (L2 Regularization): This penalizes large weights, encouraging the model to find simpler solutions. AdamW decouples weight decay from the adaptive learning rate, making it the optimizer of choice.
The trade-off is a dance. More dropout/weight decay reduces overfitting but can cause underfitting. You must measure.
Real Error Fix: Overfitting with 98% train / 62% val accuracy — fix: add Dropout(0.3–0.5) after dense layers, use data augmentation, reduce model capacity or apply weight decay 1e-4.
TensorBoard: Your Training MRI Machine
Printing loss values is like checking a fever with your hand. You need precise diagnostics. Set up TensorBoard (or Weights & Biases) to log:
- Weight/Gradient Histograms: Are your layer weights all drifting to zero or blowing up to 1e10? Are gradients vanishing (all near zero) or exploding?
- Gradient Norms per Layer: This pinpoints where in the network the explosion/vanishing happens. A great use for Ctrl+Shift+P in VS Code to launch the TensorBoard integration.
- Activation Distributions: Watch for "dead ReLUs" where a large portion of neurons output zero.
Here’s a quick setup with PyTorch Lightning and TorchMetrics:
```python
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)
        # Log per-layer gradient norms (SLOWS TRAINING, use for debugging)
        self.log_grad_norms = False

    def on_after_backward(self):
        if self.log_grad_norms and self.global_step % 50 == 0:
            for name, param in self.model.named_parameters():
                if param.grad is not None:
                    # A norm is a scalar, so log it with add_scalar; use
                    # add_histogram on param.grad itself for full distributions
                    grad_norm = param.grad.data.norm(2).item()
                    self.logger.experiment.add_scalar(
                        f"grad_norm/{name}", grad_norm, self.global_step
                    )

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.train_acc(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)
        return loss

# Trainer will auto-log to TensorBoard
trainer = pl.Trainer(
    max_epochs=10,
    logger=pl.loggers.TensorBoardLogger("logs/"),
)
```
The Golden Rule: Overfit One Batch First
Before you scale to your 10TB dataset, run this systematic experiment:
- Isolate: Select a single, small batch (e.g., 8 samples).
- Overfit: Turn off all regularization (dropout=0, weight_decay=0). Use a standard optimizer like vanilla SGD with a moderate LR (0.01).
- Train: Run for 50-100 steps. Your training loss should go to near zero, and accuracy to 100%. If it can't overfit this tiny batch, your model has a fundamental architectural flaw, a broken loss function, or a data loading bug. This is the fastest possible debug cycle.
- Scale: Once it overfits, reintroduce regularization, switch to your preferred optimizer (AdamW), and scale to the full dataset. Compared with a constant LR, the 1-cycle schedule converges 15% faster and reaches 0.8% higher final accuracy (fastai study); apply it now.
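Steps 1–3 collapse into a reusable check; a sketch with a toy model and a synthetic 8-sample batch standing in for your real ones:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def can_overfit_one_batch(model, x, y, steps=100, lr=0.01):
    """No regularization, vanilla SGD; loss should head toward zero."""
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        opt.step()
    return loss.item()

# One small batch, as in step 1; a model that can't drive this down is broken
torch.manual_seed(0)
x, y = torch.randn(8, 32), torch.randint(0, 4, (8,))
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 4))
final_loss = can_overfit_one_batch(model, x, y)
```

If `final_loss` refuses to fall well below the random-guess baseline (ln 4 ≈ 1.39 here), debug the model and data pipeline before anything else.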
Next Steps: From Debugged to Deployed
You've slain the NaN, bridged the generalization gap, and have a model that learns. Now, operationalize your diagnostics. Automate the "overfit one batch" test as a CI check for new model code. Integrate gradient norm logging into your standard training template, even if only sampled. Remember, knowledge distillation achieves 95% of teacher model accuracy at 30% model size (average across 20 papers, 2025)—consider it as a final step to compress your now-well-trained model.
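If you do reach for distillation, the classic Hinton-style loss is compact; a sketch where the temperature `T` and mixing weight `alpha` are tuning knobs, not values taken from the cited survey:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Hinton-style KD: soften both logit sets, mix KL term with hard-label CE."""
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # T^2 rescales gradients to match the hard-label term
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1 - alpha) * hard
```

The student trains against this combined objective while the teacher runs in inference mode.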
The goal isn't to avoid problems; it's to build a mental map that turns every weird training curve into a specific action. When your validation loss plateaus, you'll know to check label noise before blindly dropping the LR. When gradients explode, you'll reach for clipping and initialization checks. Your training loop becomes a dialogue, not a mystery. Now go fix it.