CHECKPOINTS

This guide shows safe, minimal checkpointing patterns that work with:

  • long training jobs
  • Slurm time limits
  • multi-process / shared-GPU setups

  1. What to save in a checkpoint (minimum set)

Always save:

{
  "epoch": int,
  "step": int,
  "model": model.state_dict(),
  "optimizer": optimizer.state_dict(),
  "scheduler": scheduler.state_dict(),  # optional
}
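
Because this dictionary has to reflect the current epoch and step, it helps to rebuild it right before every save. The sketch below assumes only that each object has a state_dict() method, which PyTorch modules, optimizers, LR schedulers and GradScaler all provide. build_checkpoint is a hypothetical helper name, not a PyTorch API.

```python
def build_checkpoint(epoch, step, model, optimizer, scheduler=None, scaler=None):
    """Collect everything needed to resume training into one dict.

    Hypothetical helper: any object with a state_dict() method works,
    e.g. torch.nn.Module, torch.optim.Optimizer, LR schedulers, GradScaler.
    """
    state = {
        "epoch": epoch,
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    # Optional components are stored only when they are in use
    if scheduler is not None:
        state["scheduler"] = scheduler.state_dict()
    if scaler is not None:
        state["scaler"] = scaler.state_dict()
    return state
```

Typical use: torch.save(build_checkpoint(epoch + 1, global_step, model, optimizer, scheduler), path).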

Optional but useful:

  • RNG state
  • AMP scaler state
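
Saving RNG state lets a resumed run continue the same random sequence for shuffling, dropout and augmentation instead of starting from a fresh seed. This is a minimal sketch with hypothetical helper names. Storing the result under an extra "rng" key is an assumption, not part of the minimum set above. The torch calls run only if PyTorch is installed.

```python
import random


def capture_rng_state():
    """Snapshot random number generator states for resuming.

    Python's `random` state is always captured; torch CPU/CUDA states are
    added only if torch can be imported (an assumption, so this sketch also
    runs without PyTorch installed).
    """
    state = {"python": random.getstate()}
    try:
        import torch
    except ImportError:
        return state
    state["torch"] = torch.get_rng_state()
    if torch.cuda.is_available():
        state["cuda"] = torch.cuda.get_rng_state_all()
    return state


def restore_rng_state(state):
    """Restore the generator states saved by capture_rng_state()."""
    random.setstate(state["python"])
    if "torch" in state:
        import torch
        torch.set_rng_state(state["torch"])
        if "cuda" in state and torch.cuda.is_available():
            torch.cuda.set_rng_state_all(state["cuda"])
```

Save capture_rng_state() alongside the other entries, and call restore_rng_state(ckpt["rng"]) after loading. NumPy's np.random.get_state() and np.random.set_state() can be added the same way if your data pipeline uses NumPy randomness.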

  2. Where to store checkpoints

✔️ Use home directory during training
✔️ Sync to shared storage at job end

Example:

import os

CKPT_DIR = "/home/username/checkpoints"  # replace "username" with your own

os.makedirs(CKPT_DIR, exist_ok=True)
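
A simple copy loop is enough for the "sync to shared storage at job end" step. This is a sketch: sync_checkpoints and SHARED_DIR are hypothetical names, and only *.pt files are copied.

```python
import os
import shutil


def sync_checkpoints(local_dir, shared_dir):
    """Copy every *.pt checkpoint from local_dir into shared_dir.

    Hypothetical helper for the "sync at job end" step. shutil.copy2 keeps
    modification times, so mtime-based "newest checkpoint" logic still
    works on the copies.
    """
    os.makedirs(shared_dir, exist_ok=True)
    copied = []
    for name in sorted(os.listdir(local_dir)):
        src = os.path.join(local_dir, name)
        if os.path.isfile(src) and name.endswith(".pt"):
            shutil.copy2(src, os.path.join(shared_dir, name))
            copied.append(name)
    return copied
```

Call sync_checkpoints(CKPT_DIR, SHARED_DIR) as the last step of the training script, where SHARED_DIR is a placeholder for your site's shared storage path.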

  3. Epoch-based checkpointing (simple & reliable)

Save every N epochs

SAVE_EVERY_EPOCHS = 5

for epoch in range(start_epoch, num_epochs):
    train_one_epoch()
    if (epoch + 1) % SAVE_EVERY_EPOCHS == 0:
        ckpt_path = os.path.join(CKPT_DIR, f"epoch_{epoch+1:03d}.pt")
        torch.save({
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }, ckpt_path)

✔️ Best for stable jobs
✔️ Very low overhead

  4. Time-based checkpointing (Slurm-safe)

Save every N minutes

import time

SAVE_EVERY_SEC = 15 * 60  # 15 minutes
last_save_time = time.time()

Inside training loop:

now = time.time()

if now - last_save_time >= SAVE_EVERY_SEC:
    ckpt_path = os.path.join(CKPT_DIR, f"time_{int(now)}.pt")
    torch.save({
        "epoch": epoch,  # epoch still in progress; it restarts on resume
        "step": global_step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }, ckpt_path)
    last_save_time = now

✔️ Essential for long jobs
✔️ Limits work lost to preemption to at most one save interval
✔️ Works with unknown epoch length

  5. Combined: epoch + time (recommended)

Use both — whichever triggers first.

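for epoch in range(start_epoch, num_epochs):
    for batch in loader:
        train_step(batch)
        global_step += 1
        if time.time() - last_save_time >= SAVE_EVERY_SEC:  # after each step
            save_checkpoint()  # placeholder for your save routine
            last_save_time = time.time()
    if (epoch + 1) % SAVE_EVERY_EPOCHS == 0:  # once, at epoch end
        save_checkpoint()
        last_save_time = time.time()  # any save resets the timer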

This is the production default.
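
The same rule can be packaged as a small object, which keeps the trigger state out of the training loop. This is a sketch: CheckpointTrigger is a hypothetical name, and measuring the interval with time.monotonic() is a deliberate change from time.time(). Wall-clock time is still fine for the time_*.pt file names.

```python
import time


class CheckpointTrigger:
    """Decide when to save: every N epochs or every N seconds.

    Hypothetical helper. time.monotonic() is used for the interval because,
    unlike time.time(), it cannot jump when the system clock is adjusted.
    """

    def __init__(self, every_epochs, every_sec, clock=time.monotonic):
        self.every_epochs = every_epochs
        self.every_sec = every_sec
        self.clock = clock
        self.last_save = clock()

    def time_due(self):
        # Check after each optimizer step
        return self.clock() - self.last_save >= self.every_sec

    def epoch_due(self, epoch):
        # Check at the end of each epoch (epoch is 0-based)
        return (epoch + 1) % self.every_epochs == 0

    def mark_saved(self):
        # Reset the timer after any save, epoch- or time-triggered
        self.last_save = self.clock()
```

In the loop, check trigger.time_due() after each step and trigger.epoch_due(epoch) after each epoch, and call trigger.mark_saved() after every save so that an epoch save also resets the timer.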

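  6. Loading the latest checkpoint

ckpt = torch.load(path, map_location="cpu")  # path: checkpoint file to resume from

model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])

start_epoch = ckpt["epoch"]
global_step = ckpt.get("step", 0)

⚠️ Move the model to the GPU and create the optimizer before loading. model.load_state_dict() copies the weights onto whatever device the model is already on, and optimizer.load_state_dict() moves its state to the device of the parameters, so map_location="cpu" is safe. If you load optimizer state first and call .cuda() afterwards, that state stays on the CPU, and stateful optimizers such as Adam then fail with a device-mismatch error.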

  7. AMP (mixed precision) checkpoint support

If using AMP:

torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}, path)

On load:

scaler.load_state_dict(ckpt["scaler"])

  8. Checkpoint naming convention (recommended)

checkpoints/
├── epoch_010.pt
├── epoch_020.pt
├── time_1709811223.pt
└── last.pt   ← symlink or copy

Always keep a last.pt pointer.
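
Replacing last.pt in place is where a kill at the wrong moment does the most damage. One safe way to update it, sketched here with a hypothetical update_last helper, is to build the new link or copy under a temporary name and then rename it over last.pt. os.replace() performs that rename atomically on POSIX filesystems.

```python
import os
import shutil


def update_last(ckpt_dir, name, use_symlink=True):
    """Point ckpt_dir/last.pt at ckpt_dir/name with no half-written moment.

    The new link (or copy) is created under a temporary name and then moved
    over last.pt with os.replace(), which renames atomically on POSIX
    filesystems: readers see either the old last.pt or the new one.
    """
    last = os.path.join(ckpt_dir, "last.pt")
    tmp = os.path.join(ckpt_dir, "last.pt.tmp")
    if os.path.lexists(tmp):
        os.remove(tmp)  # leftover from an interrupted update
    if use_symlink:
        # Relative target, so the directory can be moved or synced as a whole
        os.symlink(name, tmp)
    else:
        shutil.copyfile(os.path.join(ckpt_dir, name), tmp)
    os.replace(tmp, last)
```

A symlink costs no disk space but dangles if rotation later deletes its target. A copy doubles the size of one checkpoint but stands on its own.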

  9. Slurm-aware final checkpoint

Catch termination signals:

import signal
import sys

def save_and_exit(signum, frame):
    save_checkpoint("last.pt")  # placeholder for your save routine
    sys.exit(0)

signal.signal(signal.SIGTERM, save_and_exit)
signal.signal(signal.SIGINT, save_and_exit)

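✔️ Works with scancel
✔️ Works with time limit

⚠️ At the time limit, Slurm follows SIGTERM with SIGKILL after a short grace period (KillWait in slurm.conf), so the final save must finish within it. For large checkpoints, sbatch's --signal option can request a warning signal earlier.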

  10. Best practices (short list)

✔️ Save state_dict, not full model
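✔️ Load with map_location="cpu" so restoring does not depend on the GPU that saved the checkpoint
✔️ In multi-process runs, write checkpoints from one process only (e.g. rank 0)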
✔️ Keep checkpoints small & frequent
✔️ Always test restore path
✔️ Log checkpoint events

  11. One-sentence rule

Checkpoint by time for safety, by epoch for structure — and always save before Slurm kills you.

PyTorch Checkpoint Rotation & Cleanup

Checkpoint rotation prevents disks from filling up by keeping only the most recent N checkpoints.

  1. Rotation policy (recommended)

Use two layers:

Type          Keep
Epoch-based   last K epoch checkpoints
Time-based    last T time checkpoints
Always        last.pt

Example:

  • Keep last 5 epoch checkpoints
  • Keep last 3 time checkpoints

  2. Directory layout

checkpoints/
├── epoch_010.pt
├── epoch_015.pt
├── epoch_020.pt
├── time_1709811223.pt
├── time_1709812123.pt
└── last.pt

  3. Helper functions (drop-in)

List checkpoints by prefix

import os

def list_ckpts(ckpt_dir, prefix):
    files = [
        f for f in os.listdir(ckpt_dir)
        if f.startswith(prefix) and f.endswith(".pt")
    ]
    files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
    return files

Cleanup old checkpoints

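def cleanup_ckpts(ckpt_dir, prefix, keep_last):
    ckpts = list_ckpts(ckpt_dir, prefix)  # oldest first
    # ckpts[:-keep_last] would delete nothing when keep_last == 0
    for ckpt in ckpts[:max(len(ckpts) - keep_last, 0)]:
        path = os.path.join(ckpt_dir, ckpt)
        try:
            os.remove(path)
        except OSError:
            pass  # e.g. already removed by another process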
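
list_ckpts() orders files by modification time, which changes if checkpoints are copied without preserving timestamps (for example with a plain cp). A hedged alternative is to sort by the number already embedded in the file name and compute the whole two-layer policy in one pure function. plan_rotation is a hypothetical name.

```python
import re

# Hypothetical alternative to mtime ordering: sort by the number embedded in
# the file name (epoch_020.pt -> 20, time_1709811223.pt -> 1709811223).
# Names survive copies that reset modification times.
_CKPT_RE = re.compile(r"^(epoch|time)_(\d+)\.pt$")


def plan_rotation(names, epoch_keep, time_keep):
    """Return the checkpoint names to delete under the two-layer policy.

    last.pt and any name not matching the pattern are never selected.
    """
    groups = {"epoch": [], "time": []}
    for name in names:
        m = _CKPT_RE.match(name)
        if m:
            groups[m.group(1)].append((int(m.group(2)), name))

    to_delete = []
    for kind, keep in (("epoch", epoch_keep), ("time", time_keep)):
        ordered = sorted(groups[kind])  # oldest first
        # max(..., 0) also makes keep=0 mean "delete all of this kind"
        to_delete += [name for _, name in ordered[: max(len(ordered) - keep, 0)]]
    return to_delete
```

Because the function only returns names, you can log them before calling os.remove(os.path.join(ckpt_dir, name)) for each one.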
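
  4. Unified save + rotate function (recommended)

import os
import shutil
import torch

def save_checkpoint(
    ckpt_dir,
    name,
    state,
    epoch_keep=5,
    time_keep=3,
):
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, name)

    # Write to a temporary file, then rename it: os.replace() is atomic on
    # POSIX filesystems, so a job killed mid-write never leaves a truncated .pt
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, path)

    # Update "last.pt" by copying the finished file (same atomic rename)
    if name != "last.pt":
        last_tmp = os.path.join(ckpt_dir, "last.pt.tmp")
        shutil.copyfile(path, last_tmp)
        os.replace(last_tmp, os.path.join(ckpt_dir, "last.pt"))

    # Rotate only after the new checkpoint is safely on disk
    cleanup_ckpts(ckpt_dir, "epoch_", epoch_keep)
    cleanup_ckpts(ckpt_dir, "time_", time_keep)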

  5. Usage examples

Epoch-based

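if (epoch + 1) % SAVE_EVERY_EPOCHS == 0:
    # Build the state at save time so epoch/step are current
    state = {
        "epoch": epoch + 1,
        "step": global_step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(
        CKPT_DIR,
        f"epoch_{epoch+1:03d}.pt",
        state,
        epoch_keep=5,
        time_keep=3,
    )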

Time-based

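now = time.time()
if now - last_save_time >= SAVE_EVERY_SEC:
    state = {
        "epoch": epoch,  # epoch still in progress
        "step": global_step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(
        CKPT_DIR,
        f"time_{int(now)}.pt",
        state,
        epoch_keep=5,
        time_keep=3,
    )
    last_save_time = time.time()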
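
  6. Slurm-safe termination (with rotation)

import signal
import sys

def handle_exit(signum, frame):
    # `state` must describe the current epoch/step: rebuild it from the
    # model, optimizer and counters here instead of reusing an old dict
    save_checkpoint(
        CKPT_DIR,
        "last.pt",
        state,
        epoch_keep=5,
        time_keep=3,
    )
    sys.exit(0)

signal.signal(signal.SIGTERM, handle_exit)
signal.signal(signal.SIGINT, handle_exit)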
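
  7. Advanced (optional): size-based cleanup

Your home directory has a storage limit, so you can also cap the total size of the checkpoint directory:

import os

def cleanup_by_size(ckpt_dir, max_gb):
    # Oldest first; only checkpoint files, and never last.pt
    files = sorted(
        (
            os.path.join(ckpt_dir, f)
            for f in os.listdir(ckpt_dir)
            if f.endswith(".pt") and f != "last.pt"
        ),
        key=os.path.getmtime,
    )
    total = sum(os.path.getsize(f) for f in files)
    last = os.path.join(ckpt_dir, "last.pt")
    # A regular-file last.pt also uses space; a symlink does not
    if os.path.isfile(last) and not os.path.islink(last):
        total += os.path.getsize(last)

    # Delete the oldest checkpoints until the directory fits the limit
    while files and total > max_gb * 1024**3:
        victim = files.pop(0)
        total -= os.path.getsize(victim)
        os.remove(victim)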

  8. Best practices (short)

✔️ Always keep last.pt
✔️ Rotate after successful save
✔️ Never delete inside training step
✔️ Log checkpoint creation + deletion
✔️ Test resume path after cleanup

  9. One-line rule

Keep what you can resume from, delete everything else automatically.

Auto-Resume: Load the Newest Checkpoint Automatically

Auto-resume allows a training job to restart itself from the latest checkpoint without manually specifying a path.

Priority order (recommended):

  1. last.pt (if present)
  2. Newest epoch_*.pt or time_*.pt (by modification time)
  3. Start from scratch

  1. Checkpoint discovery logic

Helper: find newest checkpoint

import os

def find_latest_checkpoint(ckpt_dir):
    if not os.path.isdir(ckpt_dir):
        return None

    last_path = os.path.join(ckpt_dir, "last.pt")
    if os.path.isfile(last_path):
        return last_path

    candidates = []
    for f in os.listdir(ckpt_dir):
        if f.endswith(".pt") and (f.startswith("epoch_") or f.startswith("time_")):
            path = os.path.join(ckpt_dir, f)
            candidates.append(path)

    if not candidates:
        return None

    candidates.sort(key=os.path.getmtime)
    return candidates[-1]
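
find_latest_checkpoint() returns the newest file even when it is truncated, for example because the job was killed mid-save. In that case torch.load() fails and the job cannot resume. The sketch below is a more forgiving variant. load_latest_valid is a hypothetical helper, and pickle.load is the default loader only so that the example runs without PyTorch.

```python
import os
import pickle


def load_latest_valid(ckpt_dir, load_fn=pickle.load):
    """Return (path, checkpoint) for the newest file that loads cleanly.

    Hypothetical helper: tries last.pt first, then epoch_*/time_* files
    newest-first by mtime, skipping files that fail to load. With PyTorch,
    pass load_fn=lambda f: torch.load(f, map_location="cpu").
    """
    if not os.path.isdir(ckpt_dir):
        return None, None
    candidates = [
        os.path.join(ckpt_dir, f)
        for f in os.listdir(ckpt_dir)
        if f.endswith(".pt") and f.startswith(("epoch_", "time_"))
    ]
    candidates.sort(key=os.path.getmtime, reverse=True)
    last = os.path.join(ckpt_dir, "last.pt")
    if os.path.isfile(last):
        candidates.insert(0, last)

    for path in candidates:
        try:
            with open(path, "rb") as f:
                return path, load_fn(f)
        except Exception as exc:  # corrupt or truncated file
            print(f"Skipping unreadable checkpoint {path}: {exc}")
    return None, None
```

Each skipped file is printed, so a fallback to an older checkpoint shows up in the job log.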

  2. Resume function (model + optimizer + scheduler)

import torch

def auto_resume(
    ckpt_dir,
    model,
    optimizer=None,
    scheduler=None,
    scaler=None,
    device="cpu",
):
    ckpt_path = find_latest_checkpoint(ckpt_dir)
    if ckpt_path is None:
        print("No checkpoint found — starting from scratch")
        return 0, 0  # epoch, global_step

    print(f"Resuming from checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location=device)

    model.load_state_dict(ckpt["model"])

    if optimizer and "optimizer" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer"])

    if scheduler and "scheduler" in ckpt:
        scheduler.load_state_dict(ckpt["scheduler"])

    if scaler and "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])

    epoch = ckpt.get("epoch", 0)
    step = ckpt.get("step", 0)

    return epoch, step

  3. How to use it (minimal example)

Before training starts

start_epoch, global_step = auto_resume(
    CKPT_DIR,
    model,
    optimizer=optimizer,
    scheduler=scheduler,
    scaler=scaler,
    device="cpu",
)

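⚠️ Call auto_resume() after moving the model to the GPU and creating the optimizer. optimizer.load_state_dict() places the restored state on the parameters' current device.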

Training loop

for epoch in range(start_epoch, num_epochs):
    for batch in loader:
        loss = train_step(batch)
        global_step += 1
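
With time-based checkpoints, "epoch" is the epoch that was still running, so restarting it from its first batch repeats work. If every epoch has the same number of steps, the position inside that epoch follows from the two counters. The sketch below uses hypothetical helpers. It assumes that the checkpoint stores the 0-based epoch in progress and that the loader yields the same batch order for a given epoch (for example, a sampler seeded per epoch).

```python
import itertools


def batches_to_skip(epoch, global_step, steps_per_epoch):
    """Batches of `epoch` already processed before the checkpoint.

    Assumes every epoch has exactly steps_per_epoch optimizer steps and that
    `epoch` in the checkpoint is the (0-based) epoch that was in progress.
    """
    done = global_step - epoch * steps_per_epoch
    if not 0 <= done <= steps_per_epoch:
        raise ValueError("epoch/step do not match steps_per_epoch")
    return done


def resume_iter(loader, skip):
    """Iterate over loader, dropping the first `skip` batches."""
    return itertools.islice(loader, skip, None)
```

Use resume_iter(loader, batches_to_skip(start_epoch, global_step, len(loader))) for the first resumed epoch only, then iterate over loader as usual. steps_per_epoch equals len(loader) only without gradient accumulation. Also note that islice still fetches the skipped batches and simply discards them.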

  4. Required checkpoint format (recap)

Your torch.save() must include:

{
  "epoch": epoch,
  "step": global_step,
  "model": model.state_dict(),
  "optimizer": optimizer.state_dict(),
  "scheduler": scheduler.state_dict(),  # optional
  "scaler": scaler.state_dict(),        # optional (AMP)
}
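
A quick check after torch.load() catches files in an older format before load_state_dict() fails partway through. This is a sketch: check_checkpoint is a hypothetical helper whose key list mirrors the recap above. Checkpoints from the earlier epoch-only example, which omits "step", would be rejected unless you relax REQUIRED_KEYS.

```python
REQUIRED_KEYS = ("epoch", "step", "model", "optimizer")
OPTIONAL_KEYS = ("scheduler", "scaler")


def check_checkpoint(ckpt):
    """Raise if a loaded checkpoint dict is missing a required key.

    Returns the optional keys that are present, so callers can decide
    whether to restore a scheduler or AMP scaler.
    """
    if not isinstance(ckpt, dict):
        raise TypeError(f"expected a dict, got {type(ckpt).__name__}")
    missing = [k for k in REQUIRED_KEYS if k not in ckpt]
    if missing:
        raise KeyError(f"checkpoint is missing required keys: {missing}")
    return [k for k in OPTIONAL_KEYS if k in ckpt]
```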

  5. Slurm-safe auto-resume workflow

Typical Slurm lifecycle:

Job starts
├─ checkpoint exists? → resume
├─ training runs
├─ SIGTERM (time limit)
└─ save last.pt → exit

On resubmission (or requeue):

Job starts → finds last.pt → resumes automatically

No checkpoint path to pass and no manual steps: the same job script resumes.

  6. Common pitfalls (avoid these)

❌ Loading optimizer state before moving the model to the GPU
❌ Saving full model instead of state_dict
❌ Deleting last.pt during rotation
❌ Forgetting to save optimizer state
❌ Assuming epoch == step

  7. Optional: strict / non-strict resume

If model changed slightly:

model.load_state_dict(ckpt["model"], strict=False)

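Useful for fine-tuning or architecture tweaks. strict=False only tolerates missing or unexpected keys; a tensor whose shape changed still raises an error. The call returns the missing_keys and unexpected_keys it skipped, which are worth logging.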

  8. One-line rule

If your job can be killed, it must be able to resume without asking questions.

 
