Skip to content

Add chained sbatch submissions for runs that exceed cluster time limits#184

Open
j321m wants to merge 1 commit into
ctx_sclfrom
chain_jobs
Open

Add chained sbatch submissions for runs that exceed cluster time limits#184
j321m wants to merge 1 commit into
ctx_sclfrom
chain_jobs

Conversation

@j321m
Copy link
Copy Markdown
Collaborator

@j321m j321m commented May 1, 2026

Summary

Long training runs that won't fit in one SLURM time slot can now be submitted as N chained jobs that share one checkpoint dir and one wandb run. Just set chain.n_jobs: N in the experiment yaml — no code changes per experiment.

How it works

  • run_exp.py generates an EXPERIMENT_ID once per submission (the existing experiment_branch_name), exports it via sbatch --export=ALL,EXPERIMENT_ID=..., and submits N sbatch calls in the same tmux pane chained with --dependency=afterany:<prev>.
    • afterany (not aftercorr) is required so a SLURM time-limit kill (non-zero exit) doesn't break the chain.
  • checkpointing.get_full_checkpoint_path prefers EXPERIMENT_ID over SLURM_JOB_ID, so the checkpoint dir is stable across chain steps. Falls back to SLURM_JOB_ID for non-chained runs.
  • _resolve_load_path applies the same {EXPERIMENT_ID}/{ARRAY_TASK_ID} nesting that save uses, then picks the highest step_*. First chain step finds nothing → fresh start; subsequent steps find the latest checkpoint and resume.
  • main.run() exits early when next_step >= n_steps, so chain steps after training is complete are cheap no-ops.

Disk layout

{save.path}/
└── {EXPERIMENT_ID}/         ← stable across chain steps (was SLURM_JOB_ID)
    └── {ARRAY_TASK_ID}/     ← grid cell, isolates parallel sweep cells
        ├── step_20/
        ├── step_40/
        └── step_499/        ← chain step 2+ resumes from highest

Grid + chain composition

For a grid of M cells × chain length N, total = N sbatch arrays of M tasks each. Each grid cell K has its own consistent dir at {save.path}/{EXPERIMENT_ID}/K/step_* regardless of which chain step is running.

Caveats

  • afterany is whole-array level — chain step K+1 waits for ALL tasks of chain step K to finish. SLURM has no per-task afterany, so independent per-cell chains aren't possible without restructuring as M independent chains. For most uses this is fine since grid cells are similarly costly.
  • only_weights: false path is still flagged "not tested" in the existing checkpoint configs. Verified end-to-end on entropy with the smoke test.
  • Rewind cost compounds across chain steps — chain step K rewinds ~K × (steps-per-job) batches before training. With long chains this is real time. Out of scope for this PR.

Tested

End-to-end on entropy h100, n_steps=500, chain.n_jobs=2:

  • Chain step 1: trained from scratch, saved checkpoints at every step_20 through step_499. wandb run opened.
  • Chain step 2: dependency resolved after step 1 completed, found step_499 via auto-resolve, hit early-exit branch (next_step=500 >= n_steps=500), exited cleanly in 10s.
  • Both jobs COMPLETED 0:0.

Verified path resolution and EXPERIMENT_ID propagation produced the expected disk layout: chain_smoke/chain_smoke_2026-05-01_19-40-36/0/step_*.

🤖 Generated with Claude Code

A long training run that won't fit in one SLURM time slot can now be
submitted as N chained jobs that share one checkpoint dir and one wandb
run. Set `chain.n_jobs: N` in the experiment yaml.

How it works:
- run_exp.py generates an EXPERIMENT_ID once per submission, exports it via
  `sbatch --export=ALL,EXPERIMENT_ID=...`, and submits N sbatch calls in
  the same tmux pane chained with `--dependency=afterany:<prev>`. afterany
  (not aftercorr) is required so a SLURM time-limit kill (non-zero exit)
  doesn't break the chain.
- checkpointing.get_full_checkpoint_path prefers EXPERIMENT_ID over
  SLURM_JOB_ID, so the checkpoint dir is stable across chain steps.
- _resolve_load_path applies the same {EXPERIMENT_ID}/{ARRAY_TASK_ID}
  nesting that save uses, then picks the highest step_*. First chain step
  finds nothing → fresh start; subsequent steps find the latest
  checkpoint and resume.
- main.run() exits early when next_step >= n_steps, so chain steps after
  training is complete are cheap no-ops.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant