diff --git a/main.py b/main.py index a69f9c16..15b043b7 100644 --- a/main.py +++ b/main.py @@ -413,6 +413,16 @@ def initialize_training_components(cfg: OmegaConf, metric_logger=None): def run(cfg: OmegaConf, metric_logger=None): setup_enviroment() + # Early-exit when a prior chain step already reached n_steps. + prior_state = load_training_state(cfg.trainer.checkpoint.load) + if prior_state["next_step"] >= cfg.trainer.n_steps: + logger.info( + f"Training already complete " + f"(next_step={prior_state['next_step']} >= n_steps={cfg.trainer.n_steps}). " + f"Exiting." + ) + return + if "distributed" in cfg.trainer and cfg.trainer.distributed is not None: distributed_setup() diff --git a/run_exp.py b/run_exp.py index 1fec94b3..272c1791 100755 --- a/run_exp.py +++ b/run_exp.py @@ -313,35 +313,30 @@ def get_experiment_components( return config_path, config_name -def wait_for_job_id(connection, tmux_pane, tries: int = 3): +def wait_for_job_id(connection, tmux_pane, seen=None, tries: int = 10): """ Wait for a SLURM job ID to appear in the output of a tmux pane. - Repeatedly checks the pane for a successful `sbatch` message and returns - the job ID. Raises RuntimeError if an error is found or if no job ID - appears after the given number of tries. + Returns the first ID not already in `seen`, so this works for chained + submissions where multiple `sbatch` calls share one pane. """ + seen = set(seen or []) while tries > 0: output = connection.run( f"tmux capture-pane -pt {tmux_pane}.0", hide=True ).stdout - match = re.search(r"Submitted batch job (\d+)", output) - if not match: - match_error = re.search(r"sbatch: error: (.*)\n", output) - if not match_error: - time.sleep(0.5) - tries -= 1 - if tries == 0: - raise RuntimeError("Failed to get job ID from sbatch output.") - continue - else: - err_msg = match_error.group(1) - raise RuntimeError(f"Error submitting job: {err_msg}") - else: - job_id = match.group(1) - break - return job_id + for match in re.finditer(r"Submitted batch job (\d+)", output): + if match.group(1) not in seen: + return match.group(1) + + match_error = re.search(r"sbatch: error: (.*)\n", output) + if match_error: + raise RuntimeError(f"Error submitting job: {match_error.group(1)}") + + time.sleep(0.5) + tries -= 1 + raise RuntimeError("Failed to get job ID from sbatch output.") @hydra.main(version_base=None, config_path="configs", config_name="exp") @@ -421,13 +416,34 @@ def submit_experiment( connection.run( f'tmux send -t {experiment_branch_name}.0 "cd {experiment_dir}" ENTER' ) + # EXPERIMENT_ID is forwarded so all chain steps share one + # checkpoint dir; defaults to SLURM_JOB_ID for non-chained runs. + experiment_id = experiment_branch_name + chain_n_jobs = int(cfg.get("chain", {}).get("n_jobs", 1) or 1) + job_ids = [] + for step in range(chain_n_jobs): + # afterany so chain continues even when a SLURM time-limit + # kill produces a non-zero exit (the production case). + dep = ( + f" --dependency=afterany:{job_ids[-1]}" if job_ids else "" + ) + sbatch_cmd = ( + f"sbatch --export=ALL,EXPERIMENT_ID={experiment_id}" + f"{dep} exp.job" + ) + connection.run( + f'tmux send -t {experiment_branch_name}.0 "{sbatch_cmd}" ENTER' + ) + job_id = wait_for_job_id( + connection, experiment_branch_name, seen=job_ids + ) + print( + f"Chain step {step + 1}/{chain_n_jobs}: job_id={job_id}" + f"{f' depends-on={job_ids[-1]}' if job_ids else ''}" + ) + job_ids.append(job_id) connection.run( - f'tmux send -t {experiment_branch_name}.0 "sbatch exp.job" ENTER' - ) - job_id = wait_for_job_id(connection, experiment_branch_name) - print(f"Job ID: {job_id}") - connection.run( - f'tmux send -t {experiment_branch_name}.0 "tail -f --retry slurm-{job_id}_0.out" ENTER' + f'tmux send -t {experiment_branch_name}.0 "tail -f --retry slurm-{job_ids[0]}_0.out" ENTER' ) LOGLEVEL = os.environ.get("LOGLEVEL", "WARNING").upper() if LOGLEVEL == "DEBUG": diff --git a/src/core/checkpointing.py b/src/core/checkpointing.py index 3cc44f60..3ca420d3 100644 --- a/src/core/checkpointing.py +++ b/src/core/checkpointing.py @@ -70,10 +70,12 @@ def save_training_state( def get_full_checkpoint_path(path): slurm_array_task_id = os.getenv("SLURM_ARRAY_TASK_ID") - slurm_job_id = os.getenv("SLURM_JOB_ID") + # EXPERIMENT_ID stays constant across chained sbatch submissions; falls back + # to SLURM_JOB_ID for non-chained runs. + job_id = os.getenv("EXPERIMENT_ID") or os.getenv("SLURM_JOB_ID") - if slurm_array_task_id and slurm_job_id: - return f"{path}/{slurm_job_id}/{slurm_array_task_id}" + if slurm_array_task_id and job_id: + return f"{path}/{job_id}/{slurm_array_task_id}" else: return f"{path}" @@ -98,6 +100,13 @@ def load_training_state(load_config): ) return training_start_config os.makedirs(load_path, exist_ok=True) + resolved = _resolve_load_path(load_path) + if resolved is None: + logger.info( + f"No prior step_* checkpoint under '{load_path}'. Starting training from scratch." + ) + return training_start_config + load_path = resolved training_state_path = f"{load_path}/{load_config.training_state_filename}" if os.path.isfile(training_state_path): @@ -112,17 +121,43 @@ def load_training_state(load_config): def _find_latest_checkpoint(path: str) -> str: - files = [os.path.join(path, f) for f in os.listdir(path)] - if not files: - logger.info(f"No checkpoints in '{path}'") - return - - return max(files, key=os.path.getmtime) + if not os.path.isdir(path): + return None + step_dirs = [] + for f in os.listdir(path): + if not f.startswith("step_"): + continue + try: + step_dirs.append((int(f[len("step_"):]), os.path.join(path, f))) + except ValueError: + continue + if not step_dirs: + logger.info(f"No step_* checkpoints in '{path}'") + return None + return max(step_dirs, key=lambda x: x[0])[1] + + +def _resolve_load_path(load_path: str) -> str: + # Manual mode: load_path explicitly names a step_X dir → use as-is. + # Otherwise apply the same {EXPERIMENT_ID}/{ARRAY_TASK_ID} nesting that + # save uses, then pick the highest step_*. Returns None if no checkpoint. + if load_path is None: + return None + if os.path.basename(load_path.rstrip("/")).startswith("step_"): + return load_path + nested = get_full_checkpoint_path(load_path) + latest = _find_latest_checkpoint(nested) + if latest is not None: + return latest + return _find_latest_checkpoint(load_path) def load_checkpoint_from_file(load_config, model, optimizer, scheduler): - checkpoint_path = load_config.path + checkpoint_path = _resolve_load_path(load_config.path) if checkpoint_path is None: + logger.info( + "No prior checkpoint to load — keeping freshly initialized weights." + ) return # reset_scheduler is used by run_decay.py to swap in a fresh LinearLR schedule.