3.2.1. The Training Process: Epochs, Batches, and Steps
💡 First Principle: An epoch is one complete pass through the training data, a batch is a subset processed together, and a step is one weight update. These three quantities determine how long training takes, how much memory you need, and how smoothly the model converges. Getting them wrong wastes compute or produces poor models.
| Concept | Definition | Effect of Increasing | Effect of Decreasing |
|---|---|---|---|
| Epoch | One full pass through training data | More learning, risk of overfitting | Less learning, may underfit |
| Batch size | Number of samples per weight update | Faster training (parallelism), less noisy gradients, more memory | Noisier gradients (can help escape local minima), less memory |
| Steps per epoch | dataset_size / batch_size | More frequent updates (smaller batches) | Fewer updates (larger batches) |
| Learning rate | How much weights change per update | Faster convergence, risk of divergence | More stable, slower convergence |
The relationship: total_steps = epochs × (dataset_size / batch_size), rounding the quotient up when the dataset size is not an exact multiple of the batch size (the final partial batch still counts as a step). A training job with 100,000 samples, batch size 100, and 10 epochs performs 10,000 weight updates. Understanding this relationship helps you predict training time and diagnose convergence issues.
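The arithmetic above can be sketched in a few lines of Python (the function name is illustrative, not a SageMaker API):

```python
import math

def total_steps(dataset_size: int, batch_size: int, epochs: int) -> int:
    """Total weight updates across a training run.

    ceil() accounts for a partial final batch when batch_size
    does not divide dataset_size evenly.
    """
    steps_per_epoch = math.ceil(dataset_size / batch_size)
    return epochs * steps_per_epoch

# The worked example from the text: 100,000 samples, batch 100, 10 epochs
print(total_steps(100_000, 100, 10))  # → 10000
```

Note that an uneven split changes the count: 100,001 samples at batch size 100 yields 1,001 steps per epoch, not 1,000.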
Learning Rate Scheduling: A fixed learning rate is rarely optimal. Common strategies include step decay (reduce by factor every N epochs), cosine annealing (smooth reduction following a cosine curve), and warm-up (start low, increase, then decay). SageMaker supports these through framework-level configuration in Script Mode.
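As a framework-agnostic illustration (the function name, base rate, and warm-up length here are assumptions for the sketch, not a SageMaker configuration), warm-up followed by cosine annealing can be expressed as a function of the current step:

```python
import math

def lr_schedule(step: int, total_steps: int,
                base_lr: float = 1e-3, warmup_steps: int = 500) -> float:
    """Linear warm-up to base_lr, then cosine annealing toward zero."""
    if step < warmup_steps:
        # Warm-up: ramp linearly from near zero up to base_lr
        return base_lr * (step + 1) / warmup_steps
    # Cosine annealing: smooth decay over the remaining steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

In practice you would pass a function like this to your framework's scheduler hook (e.g. a per-step callback) rather than call it by hand; step decay would replace the cosine term with a factor applied every N epochs.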
⚠️ Exam Trap: Increasing batch size does NOT always speed up training proportionally. Larger batches require more GPU memory and can lead to poorer generalization (the "generalization gap"). If a question describes a model that trains fast but generalizes poorly, overly large batch size may be the cause—not insufficient epochs.
Reflection Question: A training job with batch_size=32 converges well but takes 12 hours. The team increases batch_size to 512 to speed it up. Training now takes 2 hours but validation accuracy drops 5%. What happened, and how would you fix it?
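One widely cited mitigation, offered here as an assumption since the text does not prescribe a fix, is the linear scaling rule (Goyal et al., "Accurate, Large Minibatch SGD"): when you multiply the batch size, multiply the learning rate by the same factor, typically combined with warm-up for stability. A minimal sketch:

```python
def scaled_lr(base_lr: float, base_batch_size: int, new_batch_size: int) -> float:
    """Linear scaling rule: grow the learning rate in proportion to batch size.

    Larger batches produce less noisy gradient estimates, so each update
    can safely be larger; usually paired with a warm-up phase.
    """
    return base_lr * (new_batch_size / base_batch_size)

# Scenario above: batch 32 -> 512 is a 16x increase, so the LR scales 16x
print(scaled_lr(1e-3, 32, 512))  # → 0.016
```

This addresses the scenario in the reflection question: with the learning rate left unchanged, the 16×-larger batch takes 16× fewer, equally small steps per epoch, so the model effectively trains less and lands in a poorer minimum.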