This guide shows safe, minimal checkpointing patterns that work with:
- long training jobs
- Slurm time limits
- multi-process / shared-GPU setups
- What to save in a checkpoint (minimum set)
Always save:
{
    "epoch": int,
    "step": int,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),  # optional
}
Optional but useful:
- RNG state
- AMP scaler state
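If you want resumes to be closer to bit-for-bit reproducible, a minimal sketch of capturing these extras might look like this (the extra_state name is illustrative; merge it into the checkpoint dict above):
import random
import numpy as np
import torch

# Extra entries to merge into the checkpoint dict (sketch):
extra_state = {
    "rng_python": random.getstate(),
    "rng_numpy": np.random.get_state(),
    "rng_torch": torch.get_rng_state(),
    "scaler": scaler.state_dict(),  # only if using AMP
}

# On resume:
random.setstate(ckpt["rng_python"])
np.random.set_state(ckpt["rng_numpy"])
torch.set_rng_state(ckpt["rng_torch"])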
- Where to store checkpoints
✔️ Use home directory during training
✔️ Sync to shared storage at job end
Example:
import os

CKPT_DIR = "/home/username/checkpoints/"
os.makedirs(CKPT_DIR, exist_ok=True)
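For the "sync to shared storage at job end" step, a simple sketch might be (the shared path is a placeholder for your site's storage):
import shutil

SHARED_DIR = "/shared/username/checkpoints/"  # placeholder: adjust to your cluster's shared storage

def sync_to_shared(ckpt_dir=CKPT_DIR, shared_dir=SHARED_DIR):
    # Copy the whole checkpoint directory; dirs_exist_ok requires Python 3.8+.
    shutil.copytree(ckpt_dir, shared_dir, dirs_exist_ok=True)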
- Epoch-based checkpointing (simple & reliable)
Save every N epochs
SAVE_EVERY_EPOCHS = 5
for epoch in range(start_epoch, num_epochs):
    train_one_epoch()
    if (epoch + 1) % SAVE_EVERY_EPOCHS == 0:
        ckpt_path = f"{CKPT_DIR}/epoch_{epoch+1}.pt"
        torch.save({
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }, ckpt_path)
✔️ Best for stable jobs
✔️ Very low overhead
- Time-based checkpointing (Slurm-safe)
Save every N minutes
import time

SAVE_EVERY_SEC = 15 * 60  # 15 minutes
last_save_time = time.time()
Inside training loop:
now = time.time()
if now - last_save_time >= SAVE_EVERY_SEC:
    ckpt_path = f"{CKPT_DIR}/time_{int(now)}.pt"
    torch.save({
        "epoch": epoch,
        "step": global_step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, ckpt_path)
    last_save_time = now
✔️ Essential for long jobs
✔️ Survives node preemption
✔️ Works with unknown epoch length
- Combined: epoch + time (recommended)
Use both — whichever triggers first.
if (
    (epoch + 1) % SAVE_EVERY_EPOCHS == 0
    or time.time() - last_save_time >= SAVE_EVERY_SEC
):
    save_checkpoint()
This is the production default.
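Sketched inside the training loop, with the timer reset after either trigger (assuming a save_checkpoint() helper like the unified one defined in the rotation section below):
last_save_time = time.time()

for epoch in range(start_epoch, num_epochs):
    train_one_epoch()
    if (
        (epoch + 1) % SAVE_EVERY_EPOCHS == 0
        or time.time() - last_save_time >= SAVE_EVERY_SEC
    ):
        save_checkpoint()             # epoch- or time-triggered, same code path
        last_save_time = time.time()  # reset the timer regardless of which trigger fired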
- Loading the latest checkpoint
ckpt = torch.load(path, map_location="cpu")
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
start_epoch = ckpt["epoch"]
global_step = ckpt.get("step", 0)
⚠️ Load before moving model to GPU.
- AMP (mixed precision) checkpoint support
If using AMP:
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}, path)
On load:
scaler.load_state_dict(ckpt["scaler"])
- Checkpoint naming convention (recommended)
checkpoints/
├── epoch_010.pt
├── epoch_020.pt
├── time_1709811223.pt
└── last.pt ← symlink or copy
Always keep a last.pt pointer.
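One way to keep the last.pt pointer safe is to write to a temporary file first and atomically replace it, so an interrupted save never leaves a corrupt pointer behind. A minimal sketch:
import os
import torch

def update_last_pointer(state, ckpt_dir):
    tmp_path = os.path.join(ckpt_dir, "last.pt.tmp")
    final_path = os.path.join(ckpt_dir, "last.pt")
    torch.save(state, tmp_path)
    os.replace(tmp_path, final_path)  # atomic rename on the same filesystem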
- Slurm-aware final checkpoint
Catch termination signals:
import signal
import sys

def save_and_exit(signum, frame):
    save_checkpoint("last.pt")
    sys.exit(0)

signal.signal(signal.SIGTERM, save_and_exit)
signal.signal(signal.SIGINT, save_and_exit)
✔️ Works with scancel
✔️ Works with time limit
- Best practices (short list)
✔️ Save state_dict, not full model
✔️ Load to CPU first (map_location="cpu")
✔️ Keep checkpoints small & frequent
✔️ Always test restore path (see the smoke-test sketch below)
✔️ Log checkpoint events
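A minimal smoke test for the restore path, assuming a build_model() factory that recreates your architecture:
import torch

def smoke_test_restore(ckpt_path, build_model):
    # Rebuild the architecture from scratch and try to load the saved weights;
    # mismatched keys or shapes raise immediately instead of at resume time.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    fresh_model = build_model()
    fresh_model.load_state_dict(ckpt["model"])
    print(f"Restore OK: {ckpt_path} (epoch={ckpt.get('epoch')}, step={ckpt.get('step', 0)})")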
- One-sentence rule
Checkpoint by time for safety, by epoch for structure — and always save before Slurm kills you.
PyTorch Checkpoint Rotation & Cleanup
Checkpoint rotation prevents disks from filling up by keeping only the most recent N checkpoints.
- Rotation policy (recommended)
Use two layers:
| Type        | Keep               |
|-------------|--------------------|
| Epoch-based | last K epochs      |
| Time-based  | last T checkpoints |
| Always      | last.pt            |
Example:
- Keep last 5 epoch checkpoints
- Keep last 3 time checkpoints
- Directory layout
checkpoints/
├── epoch_010.pt
├── epoch_015.pt
├── epoch_020.pt
├── time_1709811223.pt
├── time_1709812123.pt
└── last.pt
- Helper functions (drop-in)
List checkpoints by prefix
import os
def list_ckpts(ckpt_dir, prefix):
    files = [
        f for f in os.listdir(ckpt_dir)
        if f.startswith(prefix) and f.endswith(".pt")
    ]
    files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
    return files
Cleanup old checkpoints
def cleanup_ckpts(ckpt_dir, prefix, keep_last):
    ckpts = list_ckpts(ckpt_dir, prefix)
    for ckpt in ckpts[:-keep_last]:
        path = os.path.join(ckpt_dir, ckpt)
        try:
            os.remove(path)
        except OSError:
            pass
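Example usage, matching the rotation policy above:
cleanup_ckpts(CKPT_DIR, "epoch_", keep_last=5)
cleanup_ckpts(CKPT_DIR, "time_", keep_last=3)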
- Unified save + rotate function (recommended)
import torch
import os
import time
def save_checkpoint(
    ckpt_dir,
    name,
    state,
    epoch_keep=5,
    time_keep=3,
):
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, name)
    torch.save(state, path)

    # Update "last.pt"
    torch.save(state, os.path.join(ckpt_dir, "last.pt"))

    # Rotate
    cleanup_ckpts(ckpt_dir, "epoch_", epoch_keep)
    cleanup_ckpts(ckpt_dir, "time_", time_keep)
- Usage examples
Epoch-based
if (epoch + 1) % SAVE_EVERY_EPOCHS == 0:
    save_checkpoint(
        CKPT_DIR,
        f"epoch_{epoch+1:03d}.pt",
        state,
        epoch_keep=5,
        time_keep=3,
    )
Time-based
if time.time() - last_save_time >= SAVE_EVERY_SEC:
    save_checkpoint(
        CKPT_DIR,
        f"time_{int(time.time())}.pt",
        state,
        epoch_keep=5,
        time_keep=3,
    )
    last_save_time = time.time()
- Slurm-safe termination (with rotation)
import signal
import sys

def handle_exit(signum, frame):
    save_checkpoint(
        CKPT_DIR,
        "last.pt",
        state,
        epoch_keep=5,
        time_keep=3,
    )
    sys.exit(0)
signal.signal(signal.SIGTERM, handle_exit)
signal.signal(signal.SIGINT, handle_exit)
- Advanced (optional): size-based cleanup
Remember, you have a home directory limit, so:
def cleanup_by_size(ckpt_dir, max_gb):
    files = sorted(
        (os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)),
        key=os.path.getmtime,
    )

    def total_size():
        return sum(os.path.getsize(f) for f in files) / 1024**3

    # Delete oldest files first until the directory fits under the limit.
    while total_size() > max_gb and files:
        os.remove(files.pop(0))
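Example usage (the 20 GB cap is only an illustration). Note that nothing in this function protects last.pt, so you may want to filter it out explicitly:
cleanup_by_size(CKPT_DIR, max_gb=20)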
- Best practices (short)
✔️ Always keep last.pt
✔️ Rotate after successful save
✔️ Never delete inside training step
✔️ Log checkpoint creation + deletion
✔️ Test resume path after cleanup
- One-line rule
Keep what you can resume from, delete everything else automatically.
Auto-Resume: Load the Newest Checkpoint Automatically
Auto-resume allows a training job to restart itself from the latest checkpoint without manually specifying a path.
Priority order (recommended):
- last.pt (if present)
- Newest epoch_*.pt
- Newest time_*.pt
- Start from scratch
- Checkpoint discovery logic
Helper: find newest checkpoint
import os
def find_latest_checkpoint(ckpt_dir):
    if not os.path.isdir(ckpt_dir):
        return None

    last_path = os.path.join(ckpt_dir, "last.pt")
    if os.path.isfile(last_path):
        return last_path

    candidates = []
    for f in os.listdir(ckpt_dir):
        if f.endswith(".pt") and (f.startswith("epoch_") or f.startswith("time_")):
            path = os.path.join(ckpt_dir, f)
            candidates.append(path)

    if not candidates:
        return None

    candidates.sort(key=os.path.getmtime)
    return candidates[-1]
- Resume function (model + optimizer + scheduler)
import torch
def auto_resume(
    ckpt_dir,
    model,
    optimizer=None,
    scheduler=None,
    scaler=None,
    device="cpu",
):
    ckpt_path = find_latest_checkpoint(ckpt_dir)
    if ckpt_path is None:
        print("No checkpoint found — starting from scratch")
        return 0, 0  # epoch, global_step

    print(f"Resuming from checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(ckpt["model"])
    if optimizer and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])
    if scheduler and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])
    if scaler and "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])

    epoch = ckpt.get("epoch", 0)
    step = ckpt.get("step", 0)
    return epoch, step
- How to use it (minimal example)
Before training starts
start_epoch, global_step = auto_resume(
    CKPT_DIR,
    model,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
    device="cpu",
)
⚠️ Load before moving model to GPU.
Training loop
for epoch in range(start_epoch, num_epochs):
    for batch in loader:
        loss = train_step(batch)
        global_step += 1
- Required checkpoint format (recap)
Your torch.save() must include:
{
    "epoch": epoch,
    "step": global_step,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),  # optional
    "scaler": scaler.state_dict(),        # optional (AMP)
}
- Slurm-safe auto-resume workflow
Typical Slurm lifecycle:
Job starts
├─ checkpoint exists? → resume
├─ training runs
├─ SIGTERM (time limit)
└─ save last.pt → exit
On restart:
Job starts → finds last.pt → resumes automatically
No flags, no arguments, no manual steps.
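Putting it together, a minimal sketch of the entry point (reusing auto_resume and save_checkpoint from this guide; loader, train_step, and num_epochs are assumed to exist in your training script):
import signal
import sys

# Resume if a checkpoint exists; otherwise start from scratch.
start_epoch, global_step = auto_resume(
    CKPT_DIR, model, optimizer=optimizer, scheduler=scheduler, device="cpu"
)
model.to("cuda")  # move to GPU only after loading

state = {
    "epoch": start_epoch,
    "step": global_step,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
}

def handle_exit(signum, frame):
    # Flush the most recent state to last.pt before Slurm kills the job.
    save_checkpoint(CKPT_DIR, "last.pt", state)
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_exit)
signal.signal(signal.SIGINT, handle_exit)

for epoch in range(start_epoch, num_epochs):
    for batch in loader:
        loss = train_step(batch)
        global_step += 1
    # Keep `state` current so the handler always saves fresh weights.
    state.update(
        epoch=epoch + 1,
        step=global_step,
        model=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )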
- Common pitfalls (avoid these)
❌ Loading checkpoint after .cuda()
❌ Saving full model instead of state_dict (see the example after this list)
❌ Deleting last.pt during rotation
❌ Forgetting to save optimizer state
❌ Assuming epoch == step
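For the full-model pitfall, the difference in code is small but matters for portability:
# ❌ Pickles the entire module; breaks if the class definition moves or changes
torch.save(model, "model.pt")

# ✔️ Saves only tensors keyed by parameter name; robust across refactors
torch.save({"model": model.state_dict()}, "model.pt")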
- Optional: strict / non-strict resume
If model changed slightly:
model.load_state_dict(ckpt["model"], strict=False)
Useful for fine-tuning or architecture tweaks.
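With strict=False, load_state_dict also returns the keys it could not match, which are worth logging when resuming into a modified architecture:
result = model.load_state_dict(ckpt["model"], strict=False)
if result.missing_keys:
    print("Missing keys (kept at freshly initialized values):", result.missing_keys)
if result.unexpected_keys:
    print("Unexpected keys (ignored):", result.unexpected_keys)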
- One-line rule
If your job can be killed, it must be able to resume without asking questions.
