diff --git a/src/reflow/config.py b/src/reflow/config.py index ab20184..50bc774 100644 --- a/src/reflow/config.py +++ b/src/reflow/config.py @@ -55,6 +55,18 @@ # Python interpreter used for worker and dispatch jobs. # python = "/path/to/python" + # Cap on how many array tasks reflow submits at once (per wave). + # Set this to your scheduler's per-user job-submit limit + # (e.g. Slurm AssocMaxSubmitJobLimit). A large array task is then + # submitted in waves of at most this many indices; the remainder stay + # PENDING and are submitted as each wave drains, so the number of + # queued jobs never exceeds the cap. Leave a little headroom below + # the true limit for the dispatch job and any singleton tasks. + # Unset = no cap (submit the whole array at once). + # This is a reflow-internal scheduling cap, not a scheduler flag, so it + # is a plain [executor] key (not under [executor.submit_options]). + # max_submit_jobs = 900 + # Override scheduler command paths if they are not on $PATH. # Slurm: # sbatch = "/usr/bin/sbatch" @@ -313,6 +325,35 @@ def executor_python(self) -> str | None: def executor_mode(self) -> str | None: return self._get("executor", "mode", "REFLOW_MODE") + @property + def max_submit_jobs(self) -> int | None: + """Cap on array tasks submitted per wave. + + Set this to your scheduler's per-user job-submit limit (e.g. Slurm + ``AssocMaxSubmitJobLimit``). When set, a large array task is + submitted in waves of at most this many indices; the remainder stay + PENDING and are submitted by the dependency-triggered follow-up + dispatch as each wave drains, so the number of queued jobs never + exceeds the cap. ``None`` (the default) disables capping and submits + the whole array at once. Leave a little headroom below the true + limit for the follow-up dispatch job itself and any concurrent + singleton tasks. + + This is a reflow-internal scheduling cap, *not* a scheduler flag, so + it is a plain ``[executor]`` key rather than a ``submit_options`` + entry (those are rendered onto the sbatch command line). + + Config: ``[executor] max_submit_jobs``; env: ``REFLOW_MAX_SUBMIT_JOBS``. + """ + raw = self._get("executor", "max_submit_jobs", "REFLOW_MAX_SUBMIT_JOBS") + if raw is None: + return None + try: + value = int(raw) + except ValueError: + return None + return value if value > 0 else None + @property def executor_sbatch(self) -> str | None: return self._get_submit_option("sbatch") diff --git a/src/reflow/stores/__init__.py b/src/reflow/stores/__init__.py index 3e3ee80..ece818b 100644 --- a/src/reflow/stores/__init__.py +++ b/src/reflow/stores/__init__.py @@ -191,8 +191,16 @@ def update_task_submitted( run_id: str, task_name: str, job_id: str, + indices: list[int] | None = None, ) -> None: - """Mark pending/retrying instances as submitted.""" + """Mark pending/retrying instances as submitted. + + When ``indices`` is ``None`` all pending/retrying instances of the + task are marked (singleton tasks and full-array submits). When a + list of array indices is given, only those instances are marked, + so a large array can be submitted in capped waves, each wave with + its own job id. + """ @abc.abstractmethod def update_task_running(self, instance_id: int) -> None: @@ -224,6 +232,25 @@ def update_task_success( def update_task_failed(self, instance_id: int, error_text: str) -> None: """Mark one instance as failed.""" + @abc.abstractmethod + def fail_pending_tasks( + self, + run_id: str, + task_name: str, + error_text: str, + indices: list[int] | None = None, + ) -> int: + """Mark not-yet-running instances of a task as FAILED. + + Mirrors :meth:`update_task_submitted`: with ``indices=None`` every + not-yet-running instance of the task is failed; with a list of + array indices, only those are. Used when the scheduler cannot + place work (e.g. the batch system rejected the submission) so the + run finalises as FAILED instead of hanging with instances stuck in + PENDING/SUBMITTED. Only states that have not started executing are + affected. Returns the number of instances updated. + """ + @abc.abstractmethod def update_task_cancelled(self, instance_id: int) -> None: """Mark one instance as cancelled.""" diff --git a/src/reflow/stores/sqlite.py b/src/reflow/stores/sqlite.py index e4190d5..8014dc1 100644 --- a/src/reflow/stores/sqlite.py +++ b/src/reflow/stores/sqlite.py @@ -620,21 +620,76 @@ def update_task_submitted( run_id: str, task_name: str, job_id: str, + indices: list[int] | None = None, ) -> None: - self.conn.execute( + """Mark pending/retrying instances of a task as SUBMITTED. + + With ``indices=None`` every pending/retrying instance of the task + is marked (singleton tasks and full-array submits). With a list of + array indices, only those instances are marked, so a large array + can be submitted in capped waves, each wave carrying its own + ``job_id``. + """ + params: list[Any] = [ + TaskState.SUBMITTED.value, + job_id, + _utcnow(), + run_id, + task_name, + TaskState.PENDING.value, + TaskState.RETRYING.value, + ] + sql = ( "UPDATE task_instances SET state = ?, job_id = ?, updated_at = ? " - "WHERE run_id = ? AND task_name = ? AND state IN (?, ?)", - ( - TaskState.SUBMITTED.value, - job_id, - _utcnow(), - run_id, - task_name, - TaskState.PENDING.value, - TaskState.RETRYING.value, - ), + "WHERE run_id = ? AND task_name = ? AND state IN (?, ?)" + ) + if indices is not None: + idx = list(indices) + if not idx: + return + sql += " AND array_index IN ({})".format(",".join("?" * len(idx))) + params.extend(idx) + self.conn.execute(sql, params) + self.conn.commit() + + @_retry_on_locked + def fail_pending_tasks( + self, + run_id: str, + task_name: str, + error_text: str, + indices: list[int] | None = None, + ) -> int: + """Mark not-yet-running instances of *task_name* as FAILED. + + Used when the scheduler cannot place work (e.g. the batch system + rejected the submission), so the run finalises as FAILED instead + of hanging with tasks stuck in PENDING/SUBMITTED. Only states that + have not started executing are affected. Returns the row count. + """ + params: list[Any] = [ + TaskState.FAILED.value, + error_text, + _utcnow(), + run_id, + task_name, + TaskState.PENDING.value, + TaskState.RETRYING.value, + TaskState.SUBMITTED.value, + ] + sql = ( + "UPDATE task_instances SET state = ?, error_text = ?, updated_at = ? " + "WHERE run_id = ? AND task_name = ? AND state IN (?, ?, ?)" ) + if indices is not None: + idx = list(indices) + if not idx: + return 0 + sql += " AND array_index IN ({})".format(",".join("?" * len(idx))) + params.extend(idx) + cur = self.conn.execute(sql, params) self.conn.commit() + return cur.rowcount @_retry_on_locked def update_task_running(self, instance_id: int) -> None: diff --git a/src/reflow/workflow/_dispatch.py b/src/reflow/workflow/_dispatch.py index b424116..383267b 100644 --- a/src/reflow/workflow/_dispatch.py +++ b/src/reflow/workflow/_dispatch.py @@ -35,6 +35,27 @@ logger = logging.getLogger(__name__) +def _compress_indices(indices: list[int]) -> str: + """Compress a sorted index list into Slurm array syntax. + + Contiguous runs become ``lo-hi`` and isolated indices stay single, e.g. + ``[0,1,2,5,7,8] -> "0-2,5,7-8"``. This keeps the ``--array`` string short + for large contiguous waves while still expressing sparse leftovers. + """ + if not indices: + return "" + parts: list[str] = [] + start = prev = indices[0] + for value in indices[1:]: + if value == prev + 1: + prev = value + continue + parts.append(f"{start}-{prev}" if start != prev else f"{start}") + start = prev = value + parts.append(f"{start}-{prev}" if start != prev else f"{start}") + return ",".join(parts) + + class DispatchMixin: """Methods that implement the dispatch cycle. @@ -358,6 +379,38 @@ def _collect_upstream_output_hashes( # --- single / array dispatch ------------------------------------------- + def _submit_guarded( + self, + executor: Executor, + resources: Any, + command: list[str], + run_id: str, + spec: TaskSpec, + store: Store, + *, + indices: list[int] | None = None, + ) -> str | None: + """Submit work, converting a submission failure into FAILED tasks. + + If the batch system rejects the submission (e.g. the array exceeds + the cluster's job/array limit), the affected instances are marked + FAILED with the error text instead of being left stuck in PENDING, + so the run finalises as FAILED rather than hanging. Returns the job + id on success, or ``None`` on failure. + """ + try: + return executor.submit(resources, command) + except Exception as exc: # noqa: BLE001 - any rejection must not hang the run + error_text = f"Submission rejected for task {spec.name!r}: {exc}" + store.fail_pending_tasks(run_id, spec.name, error_text, indices=indices) + logger.error( + "Submission of task %s failed; marked affected instance(s) " + "FAILED so the run does not hang: %s", + spec.name, + exc, + ) + return None + def _dispatch_single( self, run_id: str, @@ -407,13 +460,71 @@ def _dispatch_single( state = TaskState(str(row["state"])) if state not in (TaskState.PENDING, TaskState.RETRYING): return None - job_id = executor.submit( + job_id = self._submit_guarded( + executor, self._single_resources(run_dir, spec), # type: ignore[attr-defined] self._worker_command(run_id, run_dir, spec.name, store), # type: ignore[attr-defined] + run_id, + spec, + store, ) + if job_id is None: + return None store.update_task_submitted(run_id, spec.name, job_id) return job_id + def _submit_wave( + self, + run_id: str, + run_dir: Path, + spec: TaskSpec, + store: Store, + executor: Executor, + candidate_indices: list[int], + ) -> str | None: + """Submit up to the configured cap of *candidate_indices* as one array. + + Honors ``Config.max_submit_jobs`` (the AssocMaxSubmitJobLimit-style + cap): only the first ``cap`` indices are submitted now; the rest stay + PENDING and are picked up by the dependency-triggered follow-up + dispatch once this wave drains, so queued jobs never exceed the cap. + Returns the job id, or ``None`` if nothing was submitted or the + submission was rejected (in which case the wave is marked FAILED by + :meth:`_submit_guarded`). + """ + if not candidate_indices: + return None + indices = sorted(candidate_indices) + cap = self.config.max_submit_jobs # type: ignore[attr-defined] + if cap is not None and len(indices) > cap: + wave = indices[:cap] + logger.info( + "Submitting %d of %d pending %s task(s) this wave (cap=%d); " + "the remainder follow as the wave drains.", + len(wave), + len(indices), + spec.name, + cap, + ) + else: + wave = indices + arr_str = _compress_indices(wave) + if spec.config.array_parallelism is not None: + arr_str = f"{arr_str}%{spec.config.array_parallelism}" + job_id = self._submit_guarded( + executor, + self._array_resources(run_dir, spec, arr_str), # type: ignore[attr-defined] + self._worker_command(run_id, run_dir, spec.name, store), # type: ignore[attr-defined] + run_id, + spec, + store, + indices=wave, + ) + if job_id is None: + return None + store.update_task_submitted(run_id, spec.name, job_id, indices=wave) + return job_id + def _dispatch_array( self, run_id: str, @@ -423,24 +534,43 @@ def _dispatch_array( executor: Executor, verify: bool = False, ) -> str | None: - """Dispatch an array task. Returns the job ID or None.""" + """Dispatch an array task. Returns the job ID or None. + + Large fan-outs are submitted in capped waves (see + :meth:`_submit_wave`): the first dispatch creates every instance and + submits the first wave; later dispatches (triggered by the follow-up + dependency) submit the next wave of still-PENDING indices until the + array is drained. + """ if not self._all_deps_satisfied(store, run_id, spec): return None existing = store.list_task_instances(run_id, task_name=spec.name) - retrying = [ - r for r in existing if TaskState(str(r["state"])) == TaskState.RETRYING + + # Resubmit any instances queued for retry first (also capped). + retry_indices = [ + int(r["array_index"]) + for r in existing + if TaskState(str(r["state"])) == TaskState.RETRYING ] - if retrying: - arr = ",".join(str(int(r["array_index"])) for r in retrying) - job_id = executor.submit( - self._array_resources(run_dir, spec, arr), # type: ignore[attr-defined] - self._worker_command(run_id, run_dir, spec.name, store), # type: ignore[attr-defined] + if retry_indices: + return self._submit_wave( + run_id, run_dir, spec, store, executor, retry_indices ) - store.update_task_submitted(run_id, spec.name, job_id) - return job_id - if store.count_task_instances(run_id, spec.name) > 0: + # Instances already exist (e.g. from a previous capped wave): submit + # the next wave of still-PENDING indices. The dependency-triggered + # follow-up dispatch keeps draining the remainder wave by wave. + if existing: + pending = [ + int(r["array_index"]) + for r in existing + if TaskState(str(r["state"])) == TaskState.PENDING + ] + if pending: + return self._submit_wave( + run_id, run_dir, spec, store, executor, pending + ) return None result_inputs = self._resolve_result_inputs(store, run_id, spec) @@ -512,18 +642,9 @@ def _dispatch_array( ) return None - if len(pending_indices) == len(fan_items): - arr_str = f"0-{len(fan_items) - 1}" - else: - arr_str = ",".join(str(i) for i in pending_indices) - if spec.config.array_parallelism is not None: - arr_str = f"{arr_str}%{spec.config.array_parallelism}" - job_id = executor.submit( - self._array_resources(run_dir, spec, arr_str), # type: ignore[attr-defined] - self._worker_command(run_id, run_dir, spec.name, store), # type: ignore[attr-defined] + return self._submit_wave( + run_id, run_dir, spec, store, executor, pending_indices ) - store.update_task_submitted(run_id, spec.name, job_id) - return job_id # --- finalisation ------------------------------------------------------ diff --git a/tests/test_cache.py b/tests/test_cache.py index 202293a..2461afd 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -4,12 +4,15 @@ from pathlib import Path -import pytest - from reflow import ( TaskState, ) -from reflow.cache import compute_input_hash, compute_identity, compute_output_hash, verify_cached_output +from reflow.cache import ( + compute_identity, + compute_input_hash, + compute_output_hash, + verify_cached_output, +) from reflow.stores.sqlite import SqliteStore # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/test_cli.py b/tests/test_cli.py index 72c5cc1..9de5c1f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -12,12 +12,11 @@ from reflow import ( Param, Result, - Run, RunDir, Workflow, ) -from reflow.stores.sqlite import SqliteStore from reflow.cli import build_parser, parse_args, run_command +from reflow.stores.sqlite import SqliteStore # ═══════════════════════════════════════════════════════════════════════════ # Coverage: _dispatch.py (dispatch loop, resolve, fan-out, finalize) @@ -48,7 +47,6 @@ def test_runs_command( capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, ) -> None: - from reflow.cli import parse_args, run_command monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() @@ -65,7 +63,6 @@ def test_status_command( capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, ) -> None: - from reflow.cli import parse_args, run_command monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() @@ -83,7 +80,6 @@ def test_cancel_command( capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, ) -> None: - from reflow.cli import parse_args, run_command monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() @@ -99,7 +95,6 @@ def test_submit_command( capsys: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch, ) -> None: - from reflow.cli import parse_args, run_command monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() @@ -129,7 +124,6 @@ def task_a( return wf def test_force_flag_parses(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -139,7 +133,6 @@ def test_force_flag_parses(self) -> None: assert args.force is True def test_force_tasks_flag_parses(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -150,7 +143,6 @@ def test_force_tasks_flag_parses(self) -> None: assert args.force_tasks == ["task_a"] def test_force_tasks_multiple(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -161,7 +153,6 @@ def test_force_tasks_multiple(self) -> None: assert args.force_tasks == ["task_a", "task_b"] def test_defaults_without_flags(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -176,7 +167,6 @@ def test_force_stored_in_parameters( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - from reflow.cli import parse_args, run_command monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() @@ -202,7 +192,6 @@ def test_force_tasks_stored_in_parameters( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - from reflow.cli import parse_args, run_command monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() @@ -223,7 +212,6 @@ def test_force_tasks_stored_in_parameters( assert "__force__" not in params def test_force_flag_in_help(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() # parse_args with ["submit", "--help"] would SystemExit, @@ -246,8 +234,6 @@ def test_force_not_leaked_as_task_param( monkeypatch: pytest.MonkeyPatch, ) -> None: """--force and --force-tasks should not appear as task parameters.""" - from reflow.cli import parse_args, run_command - monkeypatch.setenv("REFLOW_MODE", "dry-run") wf = self._make_wf() store_path = str(tmp_path / "db.sqlite") @@ -292,12 +278,10 @@ def convert( return wf def test_parser_builds(self) -> None: - from reflow.cli import build_parser assert build_parser(self._make_wf()).prog == "cli_test" def test_submit_parses(self) -> None: - from reflow.cli import build_parser args = build_parser(self._make_wf()).parse_args( [ @@ -316,7 +300,6 @@ def test_submit_parses(self) -> None: assert args.model == "icon" def test_literal_choices_enforced(self) -> None: - from reflow.cli import build_parser with pytest.raises(SystemExit): build_parser(self._make_wf()).parse_args( @@ -334,7 +317,6 @@ def test_literal_choices_enforced(self) -> None: ) def test_hidden_commands(self) -> None: - from reflow.cli import build_parser, parse_args wf = self._make_wf() assert "dispatch" not in build_parser(wf).format_help() @@ -342,7 +324,6 @@ def test_hidden_commands(self) -> None: assert args._command == "dispatch" def test_describe_command(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args(wf, ["describe"]) @@ -364,7 +345,6 @@ def task_a( return wf def test_dag_command(self, capsys: pytest.CaptureFixture[str]) -> None: - from reflow.cli import parse_args, run_command wf = self._make_wf() args = parse_args(wf, ["dag"]) @@ -374,7 +354,6 @@ def test_dag_command(self, capsys: pytest.CaptureFixture[str]) -> None: assert "task_a" in out def test_describe_command(self, capsys: pytest.CaptureFixture[str]) -> None: - from reflow.cli import parse_args, run_command wf = self._make_wf() args = parse_args(wf, ["describe"]) @@ -385,7 +364,6 @@ def test_describe_command(self, capsys: pytest.CaptureFixture[str]) -> None: assert manifest["name"] == "cli_ext" def test_worker_parser_builds(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -404,7 +382,6 @@ def test_worker_parser_builds(self) -> None: assert args.task == "task_a" def test_worker_parser_with_index(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -424,7 +401,6 @@ def test_worker_parser_with_index(self) -> None: assert args.index == 3 def test_dispatch_parser_with_verify(self) -> None: - from reflow.cli import parse_args wf = self._make_wf() args = parse_args( @@ -447,9 +423,9 @@ def test_dispatch_parser_with_verify(self) -> None: class TestResolveRunDir: - def _make_wf(self) -> "Workflow": - from reflow import Workflow, Param - from typing import Annotated + def _make_wf(self) -> Workflow: + + from reflow import Workflow wf = Workflow("wf") @@ -460,18 +436,20 @@ def task(x: Annotated[str, Param(help="X")]) -> str: return wf def test_resolve_from_args_run_dir(self, tmp_path: Path) -> None: - from reflow.cli import _resolve_run_dir from unittest.mock import MagicMock + from reflow.cli import _resolve_run_dir + args = MagicMock() args.run_dir = str(tmp_path) result = _resolve_run_dir(args, None) assert result == tmp_path def test_resolve_from_store_params(self, tmp_path: Path) -> None: + from unittest.mock import MagicMock + from reflow.cli import _resolve_run_dir from reflow.stores.sqlite import SqliteStore - from unittest.mock import MagicMock st = SqliteStore(str(tmp_path / "db.sqlite")) st.init() @@ -485,9 +463,10 @@ def test_resolve_from_store_params(self, tmp_path: Path) -> None: st.close() def test_resolve_no_run_id_raises(self, tmp_path: Path) -> None: - from reflow.cli import _resolve_run_dir from unittest.mock import MagicMock + from reflow.cli import _resolve_run_dir + args = MagicMock() args.run_dir = None args.run_id = None @@ -495,9 +474,10 @@ def test_resolve_no_run_id_raises(self, tmp_path: Path) -> None: _resolve_run_dir(args, None) def test_resolve_missing_run_dir_in_store_raises(self, tmp_path: Path) -> None: + from unittest.mock import MagicMock + from reflow.cli import _resolve_run_dir from reflow.stores.sqlite import SqliteStore - from unittest.mock import MagicMock st = SqliteStore(str(tmp_path / "db.sqlite")) st.init() @@ -512,9 +492,9 @@ def test_resolve_missing_run_dir_in_store_raises(self, tmp_path: Path) -> None: class TestCLIRunCommands: - def _wf(self) -> "Workflow": - from reflow import Workflow, Param - from typing import Annotated + def _wf(self) -> Workflow: + + from reflow import Workflow wf = Workflow("wf") @@ -525,10 +505,9 @@ def task(x: Annotated[str, Param(help="X")]) -> str: return wf def test_retry_command(self, tmp_path: Path) -> None: - from unittest.mock import patch - from reflow.cli import run_command, parse_args - from reflow.stores.sqlite import SqliteStore + from reflow.cli import parse_args, run_command from reflow.executors.util import CommandResult + from reflow.stores.sqlite import SqliteStore wf = self._wf() store_path = str(tmp_path / "db.sqlite") @@ -546,7 +525,6 @@ def test_retry_command(self, tmp_path: Path) -> None: assert rc == 0 def test_dispatch_command(self, tmp_path: Path) -> None: - from reflow.cli import run_command, parse_args from reflow.stores.sqlite import SqliteStore wf = self._wf() @@ -565,7 +543,6 @@ def test_dispatch_command(self, tmp_path: Path) -> None: assert rc == 0 def test_worker_command(self, tmp_path: Path) -> None: - from reflow.cli import run_command, parse_args from reflow.stores.sqlite import SqliteStore wf = self._wf() @@ -587,7 +564,6 @@ def test_worker_command(self, tmp_path: Path) -> None: def test_runs_command_empty( self, tmp_path: Path, capsys: pytest.CaptureFixture[str] ) -> None: - from reflow.cli import run_command, parse_args wf = self._wf() store_path = str(tmp_path / "db.sqlite") @@ -601,10 +577,8 @@ def test_runs_command_empty( class TestCLITaskLocalParams: def test_task_local_params_stored(self, tmp_path: Path) -> None: """Task-local params (--task.param) are stored under __task_params__.""" - from reflow import Workflow, Param, Config - from reflow.cli import parse_args, run_command + from reflow import Config, Workflow from reflow.stores.sqlite import SqliteStore - from typing import Annotated # dry-run so submit_run does not call sbatch wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @@ -634,9 +608,8 @@ def convert( class TestCLIRichImportFallback: def test_arg_formatter_fallback(self) -> None: """When rich_argparse is unavailable the plain HelpFormatter is used.""" - import sys import importlib - from unittest.mock import patch + import sys with patch.dict(sys.modules, {"rich_argparse": None}): import reflow.cli as cli_mod @@ -649,10 +622,9 @@ class TestCLIStatusTaskFilter: def test_status_with_task_filter_shows_instances( self, tmp_path: Path, capsys: pytest.CaptureFixture[str] ) -> None: - from reflow import Workflow, Param, Config - from reflow.cli import parse_args, run_command + + from reflow import Workflow from reflow.stores.sqlite import SqliteStore - from typing import Annotated wf = Workflow("wf") @@ -681,10 +653,9 @@ def test_cancel_multi_run_eof_aborts( self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] ) -> None: - from reflow import Workflow, Param - from reflow.cli import parse_args, run_command + + from reflow import Workflow from reflow.stores.sqlite import SqliteStore - from typing import Annotated wf = Workflow("wf") @@ -731,9 +702,8 @@ def test_main_no_args_prints_message( def test_parse_args_defaults_to_sys_argv( self, monkeypatch: pytest.MonkeyPatch ) -> None: - from reflow import Workflow, Param - from reflow.cli import parse_args - from typing import Annotated + + from reflow import Workflow wf = Workflow("wf") @wf.job() @@ -749,8 +719,8 @@ class TestCLIMethod: def test_workflow_cli_method( self, tmp_path: Path, capsys: pytest.CaptureFixture[str] ) -> None: - from reflow import Workflow, Param - from typing import Annotated + + from reflow import Workflow wf = Workflow("wf") @wf.job() diff --git a/tests/test_config.py b/tests/test_config.py index 0b3c759..dd852f8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,19 +3,16 @@ from __future__ import annotations from pathlib import Path -from unittest.mock import patch import pytest from reflow import ( Config, - Run, RunDir, Workflow, ensure_config_exists, ) -from reflow.config import write_example_config, config_path -from reflow.config import load_config +from reflow.config import config_path, load_config, write_example_config # ═══════════════════════════════════════════════════════════════════════════ # _types.py @@ -24,13 +21,11 @@ class TestConfig: def test_load_missing(self, tmp_path: Path) -> None: - from reflow.config import load_config cfg = load_config(tmp_path / "nonexistent.toml") assert cfg.executor_partition is None def test_env_fallback(self, monkeypatch: pytest.MonkeyPatch) -> None: - from reflow.config import load_config monkeypatch.setenv("REFLOW_PARTITION", "gpu") cfg = load_config(Path("/nonexistent")) @@ -208,7 +203,9 @@ def test_tomli_fallback_on_import_error( self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: """When tomllib is unavailable and tomli import fails, load returns {}.""" - import sys, importlib + import importlib + import sys + import reflow.config as cfg_mod cfg_file = tmp_path / "config.toml" @@ -287,7 +284,8 @@ def test_tomli_fallback_both_unavailable( self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: """When tomllib unavailable and tomli raises ImportError, returns {}.""" - import sys, importlib + import importlib + import sys cfg_file = tmp_path / "config.toml" cfg_file.write_text("[executor]\nmode = \"dry-run\"\n") diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index fea5220..3a3945d 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -13,6 +13,7 @@ from reflow import Config, Param, Result, RunDir, Workflow from reflow._types import TaskState +from reflow.executors.local import LocalExecutor from reflow.stores.sqlite import SqliteStore @@ -22,6 +23,39 @@ def _store(tmp_path: Path) -> SqliteStore: return st +class _RejectingExecutor(LocalExecutor): + """Executor whose submit always fails, mimicking sbatch hitting a QOS cap.""" + + def submit(self, resources, command): # type: ignore[override] + raise RuntimeError( + "AssocMaxSubmitJobLimit: Batch job submission failed: " + "Job violates accounting/QOS policy (job submit limit)" + ) + + +def _states(st: SqliteStore, run_id: str, task: str) -> dict: + out = {} + for r in st.list_task_instances(run_id, task_name=task): + idx = r["array_index"] + out[int(idx) if idx is not None else None] = r["state"] + return out + + +def _array_wf() -> Workflow: + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job() + def source() -> list[str]: + return ["a", "b", "c"] + + @wf.array_job() + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + return wf + + # ═══════════════════════════════════════════════════════════════════════════ # Basic dispatch: singleton and array tasks # ═══════════════════════════════════════════════════════════════════════════ @@ -123,9 +157,7 @@ def downstream(x: Annotated[str, Result(step="upstream")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "upstream", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "upstream", None, TaskState.SUCCESS, {}) st.update_task_success(iid, "ok") assert wf._all_deps_satisfied(st, "r1", wf.tasks["downstream"]) @@ -212,9 +244,7 @@ def downstream(x: Annotated[str, Result(step="upstream")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "upstream", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "upstream", None, TaskState.SUCCESS, {}) st.update_task_success(iid, "ok") hashes = wf._collect_upstream_output_hashes( st, "r1", wf.tasks["downstream"] @@ -242,9 +272,7 @@ def sink(x: Annotated[str, Result(step="source")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) st.update_task_success(iid, "value") resolved = wf._resolve_result_inputs(st, "r1", wf.tasks["sink"]) assert resolved["x"] == "value" @@ -263,9 +291,7 @@ def sink(item: Annotated[str, Result(step="source")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) st.update_task_success(iid, ["a", "b", "c"]) resolved = wf._resolve_result_inputs(st, "r1", wf.tasks["sink"]) assert resolved["item"] == ["a", "b", "c"] @@ -289,9 +315,7 @@ def sink(items: Annotated[list[str], Result(step="source")]) -> str: with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) for i, val in enumerate(["A", "B"]): - iid = st.insert_task_instance( - "r1", "source", i, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "source", i, TaskState.SUCCESS, {}) st.update_task_success(iid, val) resolved = wf._resolve_result_inputs(st, "r1", wf.tasks["sink"]) assert sorted(resolved["items"]) == ["A", "B"] @@ -390,9 +414,7 @@ def process(item: Annotated[str, Result(step="prep")]) -> str: instances = st.list_task_instances(run_id2) assert len(instances) >= 1 - def test_array_dispatch_no_result_inputs_returns_none( - self, tmp_path: Path - ) -> None: + def test_array_dispatch_no_result_inputs_returns_none(self, tmp_path: Path) -> None: """Array task with no upstream results should not be dispatched.""" wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @@ -486,9 +508,7 @@ def step() -> str: with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "step", None, TaskState.CANCELLED, {} - ) + iid = st.insert_task_instance("r1", "step", None, TaskState.CANCELLED, {}) wf._maybe_finalise_run("r1", st) # Cancelled is not "all ok" and not FAILED, status stays unchanged status = st.get_run("r1")["status"] @@ -521,9 +541,7 @@ def stage2(item: Annotated[str, Result(step="stage1")]) -> str: with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) for i, val in enumerate(["A", "B"]): - iid = st.insert_task_instance( - "r1", "stage1", i, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "stage1", i, TaskState.SUCCESS, {}) st.update_task_success(iid, val) resolved = wf._resolve_result_inputs(st, "r1", wf.tasks["stage2"]) # CHAIN: list of upstream array outputs @@ -590,6 +608,7 @@ def test_ingest_results_logs_when_n_gt_zero( ) -> None: """dispatch() logs an INFO message when worker results are ingested.""" import logging + from reflow.results import write_result wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @@ -599,9 +618,7 @@ def step() -> str: return "ok" with _store(tmp_path) as st: - run_id = wf.submit_run( - run_dir=tmp_path / "r", store=st, parameters={} - ) + run_id = wf.submit_run(run_dir=tmp_path / "r", store=st, parameters={}) row = st.get_task_instance(run_id, "step", None) iid = int(row["id"]) write_result( @@ -640,14 +657,13 @@ def downstream(x: Annotated[str, Result(step="upstream")]) -> str: with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) from reflow.executors.local import LocalExecutor + jid = wf._dispatch_single( "r1", tmp_path, wf.tasks["downstream"], st, LocalExecutor() ) assert jid is None - def test_dispatch_single_cache_hit_returns_none( - self, tmp_path: Path - ) -> None: + def test_dispatch_single_cache_hit_returns_none(self, tmp_path: Path) -> None: """Cache hit skips submission and returns None.""" wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) call_count = {"n": 0} @@ -670,9 +686,7 @@ def step(x: Annotated[str, Param(help="X")]) -> str: class TestDispatchArrayEdgeCases: - def test_array_parallelism_appended_to_arr_str( - self, tmp_path: Path - ) -> None: + def test_array_parallelism_appended_to_arr_str(self, tmp_path: Path) -> None: """array_parallelism config appends %N to the array string.""" wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @@ -746,9 +760,8 @@ def process( ) st.update_task_success(iid_src, ["a", "b"]) from reflow.executors.local import LocalExecutor - wf._dispatch_array( - "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() - ) + + wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, LocalExecutor()) instances = st.list_task_instances("r1", task_name="process") assert len(instances) == 2 # Each element should have broadcast cfg @@ -759,6 +772,7 @@ def process( class TestDispatchTryCache: def test_try_cache_verify_stale_reruns(self, tmp_path: Path) -> None: from typing import Any + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) call_count = {"n": 0} verify_count = {"n": 0} @@ -797,9 +811,7 @@ def step(x: Annotated[str, Param(help="x")]) -> str: row = st.get_task_instance(run2.run_id, "step", None) assert row["state"] == "SUCCESS" - def test_try_cache_output_hash_computed_when_missing( - self, tmp_path: Path - ) -> None: + def test_try_cache_output_hash_computed_when_missing(self, tmp_path: Path) -> None: wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @wf.job(cache=True) @@ -832,9 +844,7 @@ def downstream(x: Annotated[str, Result(step="upstream")]) -> str: wf.validate() with _store(tmp_path) as st: - run_id = wf.submit_run( - run_dir=tmp_path / "r", store=st, parameters={} - ) + run_id = wf.submit_run(run_dir=tmp_path / "r", store=st, parameters={}) row = st.get_task_instance(run_id, "upstream", None) st.update_task_success(int(row["id"]), "value") wf.dispatch(run_id, st, tmp_path / "r") @@ -843,9 +853,7 @@ def downstream(x: Annotated[str, Result(step="upstream")]) -> str: class TestDispatchArrayFanItems: - def test_array_element_value_list_match_per_element( - self, tmp_path: Path - ) -> None: + def test_array_element_value_list_match_per_element(self, tmp_path: Path) -> None: wf = Workflow("wf") @wf.job() @@ -866,15 +874,12 @@ def process( wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid_l = st.insert_task_instance( - "r1", "labels", None, TaskState.SUCCESS, {} - ) + iid_l = st.insert_task_instance("r1", "labels", None, TaskState.SUCCESS, {}) st.update_task_success(iid_l, ["x", "y"]) - iid_i = st.insert_task_instance( - "r1", "items", None, TaskState.SUCCESS, {} - ) + iid_i = st.insert_task_instance("r1", "items", None, TaskState.SUCCESS, {}) st.update_task_success(iid_i, ["a", "b"]) from reflow.executors.local import LocalExecutor + wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, LocalExecutor()) instances = st.list_task_instances("r1", task_name="process") assert len(instances) == 2 @@ -920,9 +925,7 @@ def process(item: Annotated[str, Result(step="source")]) -> str: return item with _store(tmp_path) as st: - run_id = wf.submit_run( - run_dir=tmp_path / "r", store=st, parameters={} - ) + run_id = wf.submit_run(run_dir=tmp_path / "r", store=st, parameters={}) row = st.get_task_instance(run_id, "source", None) st.update_task_success(int(row["id"]), ["a", "b", "c"]) wf.dispatch(run_id, st, tmp_path / "r") @@ -959,11 +962,12 @@ def process( wf.validate() import copy + spec = copy.copy(wf.tasks["process"]) from reflow.params import Result as R + spec.result_deps = { - k: R(step=v.steps[0], broadcast=True) - for k, v in spec.result_deps.items() + k: R(step=v.steps[0], broadcast=True) for k, v in spec.result_deps.items() } result = wf._find_fan_out_param(spec) assert result is None @@ -987,12 +991,12 @@ def step(x: Annotated[str, Param(help="x")]) -> str: assert row["state"] == "SUCCESS" assert row["output"] == "HI" - def test_dispatch_single_cache_miss_creates_pending( - self, tmp_path: Path - ) -> None: + def test_dispatch_single_cache_miss_creates_pending(self, tmp_path: Path) -> None: """When no cache hit, _dispatch_single creates a PENDING instance.""" from unittest.mock import patch + from reflow.executors.util import CommandResult + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @wf.job() @@ -1012,6 +1016,7 @@ def step(x: Annotated[str, Param(help="x")]) -> str: def test_dispatch_single_row_none_raises(self, tmp_path: Path) -> None: """If get_task_instance returns None after insert, RuntimeError is raised.""" from unittest.mock import patch + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @wf.job() @@ -1019,22 +1024,23 @@ def step() -> str: return "ok" with _store(tmp_path) as st: - run_id = wf.submit_run( - run_dir=tmp_path / "r", store=st, parameters={} - ) + run_id = wf.submit_run(run_dir=tmp_path / "r", store=st, parameters={}) with patch.object(st, "get_task_instance", return_value=None): with pytest.raises(RuntimeError, match="Could not create instance"): from reflow.executors.local import LocalExecutor + wf._dispatch_single( - run_id, tmp_path, wf.tasks["step"], st, + run_id, + tmp_path, + wf.tasks["step"], + st, LocalExecutor(), ) - def test_array_dispatch_no_fan_param_returns_none( - self, tmp_path: Path - ) -> None: + def test_array_dispatch_no_fan_param_returns_none(self, tmp_path: Path) -> None: """_dispatch_array returns None when _find_fan_out_param is None.""" from unittest.mock import patch + wf = Workflow("wf") @wf.job() @@ -1048,20 +1054,17 @@ def process(item: Annotated[str, Result(step="source")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) st.update_task_success(iid, ["a"]) from reflow.executors.local import LocalExecutor + with patch.object(wf, "_find_fan_out_param", return_value=None): result = wf._dispatch_array( "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() ) assert result is None - def test_array_dispatch_empty_fan_items_returns_none( - self, tmp_path: Path - ) -> None: + def test_array_dispatch_empty_fan_items_returns_none(self, tmp_path: Path) -> None: """_dispatch_array returns None when fan_items is empty list.""" wf = Workflow("wf") @@ -1076,19 +1079,16 @@ def process(item: Annotated[str, Result(step="source")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) st.update_task_success(iid, []) from reflow.executors.local import LocalExecutor + result = wf._dispatch_array( "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() ) assert result is None - def test_array_dispatch_fan_items_not_list_wraps( - self, tmp_path: Path - ) -> None: + def test_array_dispatch_fan_items_not_list_wraps(self, tmp_path: Path) -> None: """A scalar fan_items value is wrapped to a list before dispatch.""" wf = Workflow("wf") @@ -1103,14 +1103,11 @@ def process(item: Annotated[str, Result(step="source")]) -> str: wf.validate() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) st.update_task_success(iid, "single_scalar") from reflow.executors.local import LocalExecutor - wf._dispatch_array( - "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() - ) + + wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, LocalExecutor()) instances = st.list_task_instances("r1", task_name="process") assert len(instances) == 1 @@ -1126,11 +1123,10 @@ def step() -> str: run = wf.run_local(run_dir=tmp_path / "r", store=st) assert st.get_run(run.run_id)["status"] == "SUCCESS" - def test_resolve_result_inputs_hints_exception( - self, tmp_path: Path - ) -> None: + def test_resolve_result_inputs_hints_exception(self, tmp_path: Path) -> None: """_resolve_result_inputs returns {} when get_type_hints raises.""" from unittest.mock import patch + wf = Workflow("wf") @wf.job() @@ -1154,403 +1150,106 @@ def sink(x: Annotated[str, Result(step="source")]) -> str: assert isinstance(result, dict) -# ═══════════════════════════════════════════════════════════════════════════ -# Targeted coverage for remaining _dispatch.py branches -# ═══════════════════════════════════════════════════════════════════════════ - - -class TestDispatchBranchCoverage: - def test_resolve_result_inputs_hints_exception(self, tmp_path: Path) -> None: - """_resolve_result_inputs: get_type_hints raising → hints = {} (145-146).""" - from unittest.mock import patch - wf = Workflow("wf") - - @wf.job() - def source() -> str: - return "v" - - @wf.job() - def sink(x: Annotated[str, Result(step="source")]) -> str: - return x +class TestArraySubmitFailure: + def _seed_source_done(self, st: SqliteStore) -> None: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) + st.update_task_success(iid, ["a", "b", "c"]) - wf.validate() + def test_rejection_marks_created_instances_failed(self, tmp_path: Path) -> None: + wf = _array_wf() with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} + self._seed_source_done(st) + jid = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, _RejectingExecutor() ) - st.update_task_success(iid, "v") - with patch( - "reflow.workflow._dispatch.get_type_hints", - side_effect=Exception("bad"), - ): - result = wf._resolve_result_inputs(st, "r1", wf.tasks["sink"]) - # With hints unavailable the dep can't be wired; result is a dict - assert isinstance(result, dict) - - def test_find_fan_out_param_hints_exception_returns_none( - self, tmp_path: Path - ) -> None: - """_find_fan_out_param: get_type_hints raising → None (203-204).""" - from unittest.mock import patch - wf = Workflow("wf") - - @wf.job() - def source() -> list[str]: - return ["a"] - - @wf.array_job() - def process(item: Annotated[str, Result(step="source")]) -> str: - return item - - wf.validate() - with patch( - "reflow.workflow._dispatch.get_type_hints", - side_effect=Exception("bad"), - ): - assert wf._find_fan_out_param(wf.tasks["process"]) is None - - def test_find_fan_out_param_broadcast_skipped(self, tmp_path: Path) -> None: - """Broadcast dep is skipped in the first loop (211).""" - wf = Workflow("wf") - - @wf.job() - def cfg() -> list[str]: - return ["c"] - - @wf.job() - def source() -> list[str]: - return ["a", "b"] - - @wf.array_job() - def process( - item: Annotated[str, Result(step="source")], - extra: Annotated[list[str], Result(step="cfg", broadcast=True)], - ) -> str: - return item - - wf.validate() - # The fan-out param is "item"; "extra" (broadcast) is skipped. - assert wf._find_fan_out_param(wf.tasks["process"]) == "item" - - def test_find_fan_out_param_unknown_upstream_skipped( - self, tmp_path: Path - ) -> None: - """A dep whose upstream has no return type is skipped (218, 227-228).""" - from unittest.mock import patch - wf = Workflow("wf") - - @wf.job() - def source(): # no return annotation - return ["a", "b"] - - @wf.array_job() - def process(item: Annotated[str, Result(step="source")]) -> str: - return item - - wf.validate() - # upstream return_type is empty → first loop continues; falls back to - # the first non-broadcast dep, which is still "item". - assert wf._find_fan_out_param(wf.tasks["process"]) == "item" - - def test_find_fan_out_param_infer_type_error_skipped( - self, tmp_path: Path - ) -> None: - """infer_wire_mode raising TypeError skips the dep, fallback returns it (234).""" - from unittest.mock import patch - wf = Workflow("wf") - - @wf.job() - def source() -> list[str]: - return ["a", "b"] - - @wf.array_job() - def process(item: Annotated[str, Result(step="source")]) -> str: - return item - - wf.validate() - with patch( - "reflow.workflow._dispatch.infer_wire_mode", - side_effect=TypeError("bad"), - ): - # First loop continues past the TypeError; fallback returns "item". - assert wf._find_fan_out_param(wf.tasks["process"]) == "item" - - def test_try_cache_verify_stale(self, tmp_path: Path) -> None: - """_try_cache via dispatch: verify + stale callable → False (311-317).""" - from typing import Any - from reflow.executors.local import LocalExecutor - wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) - - def stale(output: Any) -> bool: - return False - - @wf.job(cache=True, verify=stale) - def step(x: Annotated[str, Param(help="x")]) -> str: - return x.upper() + assert jid is None + states = _states(st, "r1", "process") + assert states # instances were created + assert all(s == "FAILED" for s in states.values()) - wf.validate() + def test_rejection_records_error_text(self, tmp_path: Path) -> None: + wf = _array_wf() with _store(tmp_path) as st: - # Seed a cached record by completing the task once via dispatch. - st.insert_run("r1", "wf", "u", {"x": "v"}) - ex = LocalExecutor() - wf._dispatch_single("r1", tmp_path, wf.tasks["step"], st, ex) - row = st.get_task_instance("r1", "step", None) - st.update_task_success(int(row["id"]), "V", output_hash="abc") - - # New run: cache exists but verify=stale → _try_cache returns False. - st.insert_run("r2", "wf", "u", {"x": "v"}) - hit = wf._try_cache( - st, "r2", wf.tasks["step"], {}, None, [], verify=True + self._seed_source_done(st) + wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, _RejectingExecutor() ) - assert hit is False - - def test_try_cache_hit_recomputes_output_hash(self, tmp_path: Path) -> None: - """_try_cache via dispatch: missing output_hash recomputed (324-326).""" - from reflow.executors.local import LocalExecutor - wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) - - @wf.job(cache=True) - def step(x: Annotated[str, Param(help="x")]) -> str: - return x.upper() + rows = st.list_task_instances("r1", task_name="process") + assert any("submit limit" in (r.get("error_text") or "") for r in rows) - wf.validate() + def test_rejection_finalises_run_failed(self, tmp_path: Path) -> None: + wf = _array_wf() with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {"x": "v"}) - ex = LocalExecutor() - wf._dispatch_single("r1", tmp_path, wf.tasks["step"], st, ex) - row = st.get_task_instance("r1", "step", None) - iid = int(row["id"]) - st.update_task_success(iid, "V") - # Clear the output_hash so _try_cache must recompute it. - st.conn.execute( - "UPDATE task_instances SET output_hash = '' WHERE id = ?", (iid,) - ) - st.conn.commit() - - # _dispatch_single resolves payload via _resolve_result_inputs, which - # is {} for a Param-only task — so the cache identity is built from - # an empty payload. Match that here to hit the cached record. - st.insert_run("r2", "wf", "u", {"x": "v"}) - hit = wf._try_cache( - st, "r2", wf.tasks["step"], {}, None, [] + self._seed_source_done(st) + wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, _RejectingExecutor() ) - assert hit is True - assert st.get_task_instance("r2", "step", None)["state"] == "SUCCESS" - - def test_try_cache_no_cache_config_returns_false(self, tmp_path: Path) -> None: - """_try_cache returns False immediately when caching is off (295).""" - wf = Workflow("wf") - - @wf.job() # cache defaults to False - def step(x: Annotated[str, Param(help="x")]) -> str: - return x + wf._maybe_finalise_run("r1", st) + assert st.get_run("r1")["status"] == "FAILED" - wf.validate() + def test_success_path_returns_job_id(self, tmp_path: Path) -> None: + wf = _array_wf() with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {"x": "v"}) - hit = wf._try_cache( - st, "r1", wf.tasks["step"], {"x": "v"}, None, [] + self._seed_source_done(st) + jid = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() ) - assert hit is False - - def test_try_cache_miss_returns_false(self, tmp_path: Path) -> None: - """_try_cache returns False when nothing is cached (307-308).""" - wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) - - @wf.job(cache=True) - def step(x: Annotated[str, Param(help="x")]) -> str: - return x + assert jid is not None + # nothing left FAILED on the happy path + assert "FAILED" not in set(_states(st, "r1", "process").values()) - wf.validate() + def test_retrying_array_rejection_marks_failed(self, tmp_path: Path) -> None: + wf = _array_wf() with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {"x": "v"}) - hit = wf._try_cache( - st, "r1", wf.tasks["step"], {"x": "v"}, None, [] + self._seed_source_done(st) + # one already-done index and one queued for retry + iid0 = st.insert_task_instance( + "r1", "process", 0, TaskState.SUCCESS, {"item": "a"} ) - assert hit is False - - def test_dispatch_single_creates_pending(self, tmp_path: Path) -> None: - """_dispatch_single inserts a PENDING instance then submits (387+).""" - wf = Workflow("wf") - - @wf.job() - def step(x: Annotated[str, Param(help="x")]) -> str: - return x - - wf.validate() - with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {"x": "v"}) - from reflow.executors.local import LocalExecutor - wf._dispatch_single( - "r1", tmp_path, wf.tasks["step"], st, LocalExecutor() + st.update_task_success(iid0, "a") + st.insert_task_instance( + "r1", "process", 1, TaskState.RETRYING, {"item": "b"} ) - inst = st.get_task_instance("r1", "step", None) - assert inst is not None - - def test_dispatch_array_no_result_inputs(self, tmp_path: Path) -> None: - """_dispatch_array returns None when result_inputs empty (444).""" - wf = Workflow("wf") - - @wf.job() - def source() -> list[str]: - return ["a"] - - @wf.array_job() - def process(item: Annotated[str, Result(step="source")]) -> str: - return item - - wf.validate() - with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {}) - # source not done → no result inputs - from reflow.executors.local import LocalExecutor - result = wf._dispatch_array( - "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() + jid = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, _RejectingExecutor() ) - assert result is None - - def test_dispatch_array_empty_fan_items(self, tmp_path: Path) -> None: - """_dispatch_array returns None when fan list empty (454).""" - wf = Workflow("wf") - - @wf.job() - def source() -> list[str]: - return [] - - @wf.array_job() - def process(item: Annotated[str, Result(step="source")]) -> str: - return item + assert jid is None + states = _states(st, "r1", "process") + assert states[1] == "FAILED" # the retrying one failed to submit + assert states[0] == "SUCCESS" # the done one is untouched - wf.validate() - with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) - st.update_task_success(iid, []) - from reflow.executors.local import LocalExecutor - result = wf._dispatch_array( - "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() - ) - assert result is None - def test_dispatch_array_all_cached_returns_none(self, tmp_path: Path) -> None: - """All elements cached → 'all resolved from cache' log + None (508).""" - from reflow.executors.local import LocalExecutor +class TestSingletonSubmitFailure: + def _singleton_wf(self) -> Workflow: wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) @wf.job() - def source() -> list[str]: - return ["a", "b"] - - @wf.array_job(cache=True) - def process(item: Annotated[str, Result(step="source")]) -> str: - return item.upper() + def step(x: Annotated[str, Param(help="X")]) -> str: + return x wf.validate() - with _store(tmp_path) as st: - # Run 1: dispatch the array and complete both elements. - st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) - st.update_task_success(iid, ["a", "b"]) - ex = LocalExecutor() - wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, ex) - for inst in st.list_task_instances("r1", task_name="process"): - st.update_task_success(int(inst["id"]), inst["input"]["item"].upper()) - - # Run 2: both elements should resolve from cache → returns None. - st.insert_run("r2", "wf", "u", {}) - iid2 = st.insert_task_instance( - "r2", "source", None, TaskState.SUCCESS, {} - ) - st.update_task_success(iid2, ["a", "b"]) - result = wf._dispatch_array( - "r2", tmp_path, wf.tasks["process"], st, ex - ) - assert result is None - - def test_dispatch_array_partial_cached(self, tmp_path: Path) -> None: - """Some cached, some pending → comma-joined arr_str (513).""" - from reflow.executors.local import LocalExecutor - wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) - - @wf.job() - def source() -> list[str]: - return ["a", "b", "c"] + return wf - @wf.array_job(cache=True) - def process(item: Annotated[str, Result(step="source")]) -> str: - return item.upper() - - wf.validate() + def test_singleton_rejection_marks_failed(self, tmp_path: Path) -> None: + wf = self._singleton_wf() with _store(tmp_path) as st: - # Run 1: complete only the first element so its identity is cached. st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) - st.update_task_success(iid, ["a", "b", "c"]) - ex = LocalExecutor() - wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, ex) - insts = sorted( - st.list_task_instances("r1", task_name="process"), - key=lambda i: i["array_index"], - ) - # Cache just element 0 ("a"). - st.update_task_success(int(insts[0]["id"]), "A") - - # Run 2: element 0 hits cache, elements 1,2 pending → partial arr_str. - st.insert_run("r2", "wf", "u", {}) - iid2 = st.insert_task_instance( - "r2", "source", None, TaskState.SUCCESS, {} - ) - st.update_task_success(iid2, ["a", "b", "c"]) - job_id = wf._dispatch_array( - "r2", tmp_path, wf.tasks["process"], st, ex + st.insert_task_instance("r1", "step", None, TaskState.PENDING, {"x": "v"}) + jid = wf._dispatch_single( + "r1", tmp_path, wf.tasks["step"], st, _RejectingExecutor() ) - assert job_id is not None - # All three instances exist (one cached SUCCESS, two newly created). - assert len(st.list_task_instances("r2", task_name="process")) == 3 - - def test_dispatch_array_parallelism_suffix(self, tmp_path: Path) -> None: - """array_parallelism appends %N to arr_str (518).""" - wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) - - @wf.job() - def source() -> list[str]: - return ["a", "b", "c"] - - @wf.array_job(array_parallelism=2) - def process(item: Annotated[str, Result(step="source")]) -> str: - return item + assert jid is None + row = st.get_task_instance("r1", "step", None) + assert row is not None and row["state"] == "FAILED" + def test_singleton_rejection_finalises_run_failed(self, tmp_path: Path) -> None: + wf = self._singleton_wf() with _store(tmp_path) as st: st.insert_run("r1", "wf", "u", {}) - iid = st.insert_task_instance( - "r1", "source", None, TaskState.SUCCESS, {} - ) - st.update_task_success(iid, ["a", "b", "c"]) - from reflow.executors.local import LocalExecutor - job_id = wf._dispatch_array( - "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() + st.insert_task_instance("r1", "step", None, TaskState.PENDING, {"x": "v"}) + wf._dispatch_single( + "r1", tmp_path, wf.tasks["step"], st, _RejectingExecutor() ) - assert job_id is not None - instances = st.list_task_instances("r1", task_name="process") - assert len(instances) == 3 - - def test_maybe_finalise_returns_on_empty_summary(self, tmp_path: Path) -> None: - """_maybe_finalise_run returns early on empty summary (533 region).""" - wf = Workflow("wf") - - @wf.job() - def step() -> str: - return "ok" - - with _store(tmp_path) as st: - st.insert_run("r1", "wf", "u", {}) - # No task instances → empty summary → early return wf._maybe_finalise_run("r1", st) - assert st.get_run("r1")["status"] in ("RUNNING", "PENDING", "SUBMITTED") + assert st.get_run("r1")["status"] == "FAILED" diff --git a/tests/test_dispatch_waves.py b/tests/test_dispatch_waves.py new file mode 100644 index 0000000..1192efa --- /dev/null +++ b/tests/test_dispatch_waves.py @@ -0,0 +1,153 @@ +"""test_dispatch_waves.py — capped-wave array submission. + +With ``Config.max_submit_jobs`` set, a large array task is submitted in +waves of at most ``cap`` indices; the remainder stay PENDING and are +submitted by later dispatches (which the real scheduler triggers via the +follow-up dependency). These tests drive ``_dispatch_array`` directly and +simulate each wave completing, asserting the ``--array`` string and the +PENDING/SUBMITTED bookkeeping. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Annotated + +from reflow import Config, Result, Workflow +from reflow._types import TaskState +from reflow.executors.local import LocalExecutor +from reflow.stores.sqlite import SqliteStore + + +def _store(tmp_path: Path) -> SqliteStore: + st = SqliteStore(str(tmp_path / "db.sqlite")) + st.init() + return st + + +class _RecordingExecutor(LocalExecutor): + """Records each submitted ``--array`` string and returns a fake job id.""" + + def __init__(self) -> None: + super().__init__() + self.arrays: list[str] = [] + self._n = 0 + + def submit(self, resources, command): # type: ignore[override] + self.arrays.append(resources.array) + self._n += 1 + return f"job{self._n}" + + +def _wf(cap: int, n_items: int) -> tuple[Workflow, int]: + wf = Workflow( + "wf", + config=Config( + {"executor": {"mode": "dry-run", "max_submit_jobs": cap}} + ), + ) + + @wf.job() + def source() -> list[str]: + return [f"x{i}" for i in range(n_items)] + + @wf.array_job(array_parallelism=4) + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + return wf, n_items + + +def _state_counts(st: SqliteStore) -> dict: + counts: dict[str, int] = {} + for r in st.list_task_instances("r1", task_name="process"): + counts[r["state"]] = counts.get(r["state"], 0) + 1 + return counts + + +def _complete(st: SqliteStore, indices: list[int]) -> None: + """Mark the given submitted process indices as SUCCESS (a wave finished).""" + for idx in indices: + row = st.get_task_instance("r1", "process", idx) + st.update_task_success(int(row["id"]), f"out{idx}") + + +class TestArrayWaveCapping: + def test_first_wave_respects_cap(self, tmp_path: Path) -> None: + wf, _ = _wf(cap=2, n_items=5) + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) + st.update_task_success(iid, [f"x{i}" for i in range(5)]) + ex = _RecordingExecutor() + + jid = wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, ex) + + assert jid == "job1" + # all 5 instances created; only 2 submitted this wave, 3 still pending + assert _state_counts(st) == {"SUBMITTED": 2, "PENDING": 3} + # array string is the capped, parallelism-tagged range + assert ex.arrays == ["0-1%4"] + + def test_waves_drain_until_empty(self, tmp_path: Path) -> None: + wf, _ = _wf(cap=2, n_items=5) + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) + st.update_task_success(iid, [f"x{i}" for i in range(5)]) + ex = _RecordingExecutor() + spec = wf.tasks["process"] + + # wave 1: indices 0,1 + wf._dispatch_array("r1", tmp_path, spec, st, ex) + _complete(st, [0, 1]) + # wave 2: indices 2,3 + wf._dispatch_array("r1", tmp_path, spec, st, ex) + _complete(st, [2, 3]) + # wave 3: index 4 (single) + wf._dispatch_array("r1", tmp_path, spec, st, ex) + _complete(st, [4]) + # nothing left to submit + jid_final = wf._dispatch_array("r1", tmp_path, spec, st, ex) + + assert ex.arrays == ["0-1%4", "2-3%4", "4%4"] + assert jid_final is None + assert _state_counts(st) == {"SUCCESS": 5} + + def test_no_cap_submits_whole_array_at_once(self, tmp_path: Path) -> None: + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job() + def source() -> list[str]: + return [f"x{i}" for i in range(5)] + + @wf.array_job(array_parallelism=4) + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) + st.update_task_success(iid, [f"x{i}" for i in range(5)]) + ex = _RecordingExecutor() + wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, ex) + # cap unset -> single submission of the full range + assert ex.arrays == ["0-4%4"] + assert _state_counts(st) == {"SUBMITTED": 5} + + def test_wave_uses_local_executor_for_real(self, tmp_path: Path) -> None: + # sanity: the capped path also works with a real executor (LocalExecutor) + wf, _ = _wf(cap=3, n_items=4) + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance("r1", "source", None, TaskState.SUCCESS, {}) + st.update_task_success(iid, [f"x{i}" for i in range(4)]) + jid = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() + ) + assert jid is not None + # 3 submitted (the cap), 1 still pending for the next wave + counts = _state_counts(st) + assert counts.get("PENDING", 0) == 1 diff --git a/tests/test_executor_live.py b/tests/test_executor_live.py index a1b6f0b..4e6b7da 100644 --- a/tests/test_executor_live.py +++ b/tests/test_executor_live.py @@ -8,9 +8,7 @@ from __future__ import annotations -from unittest.mock import patch, MagicMock - -import pytest +from unittest.mock import patch from reflow.executors import JobResources from reflow.executors.flux import FluxExecutor @@ -439,7 +437,7 @@ def test_submit_live_returns_job_id(self) -> None: assert jid == "f123abc456" def test_submit_live_multiline_takes_last(self) -> None: - """flux submit may print info before the job ID.""" + """Flux submit may print info before the job ID.""" exc = FluxExecutor(mode="flux") res = JobResources(job_name="test") with patch( diff --git a/tests/test_flow.py b/tests/test_flow.py index 5daf77e..64c2c59 100644 --- a/tests/test_flow.py +++ b/tests/test_flow.py @@ -2,16 +2,13 @@ from __future__ import annotations -from pathlib import Path from typing import Annotated import pytest from reflow import ( - Config, Flow, Result, - Run, RunDir, Workflow, ) diff --git a/tests/test_local_runner.py b/tests/test_local_runner.py index c1738d6..7db2f8e 100644 --- a/tests/test_local_runner.py +++ b/tests/test_local_runner.py @@ -10,7 +10,6 @@ from reflow import ( Param, Result, - Run, Workflow, ) from reflow.stores.sqlite import SqliteStore diff --git a/tests/test_local_runner_extended.py b/tests/test_local_runner_extended.py index 2715e56..e0c8e99 100644 --- a/tests/test_local_runner_extended.py +++ b/tests/test_local_runner_extended.py @@ -106,7 +106,8 @@ class TestParallelExecution: def test_parallel_closure_fails_elements(self, tmp_path: Path) -> None: """Closures defined in test methods can't be resolved by _run_task_func (qualname contains '' which has no getattr match). Each element - fails individually and the run is marked FAILED.""" + fails individually and the run is marked FAILED. + """ wf = Workflow("wf") @wf.job() @@ -173,7 +174,6 @@ def process(item: Annotated[str, Result(step="source")]) -> str: on_error="continue", ) # process should have been skipped / marked failed at the run level - from reflow._types import RunState with _store(tmp_path) as st2: row = st2.get_run(run.run_id) @@ -231,7 +231,6 @@ def ok_c() -> str: store=st, on_error="continue", ) - from reflow._types import TaskState with _store(tmp_path) as st2: a = st2.get_task_instance(run.run_id, "fail_a", None) @@ -475,7 +474,6 @@ def test_parallel_sequential_fallback_on_unpicklable( self, tmp_path: Path ) -> None: from unittest.mock import patch - from concurrent.futures import ProcessPoolExecutor wf = Workflow("wf") @@ -599,6 +597,7 @@ def test_local_resolve_inputs_hints_exception( ) -> None: """_local_resolve_inputs handles get_type_hints exception gracefully.""" from unittest.mock import patch + from reflow._types import TaskState wf = Workflow("wf") @@ -669,7 +668,8 @@ def test_local_find_fan_out_param_type_error_skipped( ) -> None: """_local_find_fan_out_param skips a dep when infer_wire_mode raises TypeError in the first loop, then falls back to the first - non-broadcast param in the second loop.""" + non-broadcast param in the second loop. + """ from unittest.mock import patch wf = Workflow("wf") diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 78e7f02..72bfa28 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -9,25 +9,21 @@ import pytest from reflow import ( - Param, - Run, RunState, TaskState, - Workflow, ) from reflow._types import TaskState as TS +from reflow.config import _rosetta_stone from reflow.manifest import ( DEFAULT_CODEC, CliParamDescription, - ManifestCodec, TaskDescription, WorkflowDescription, canonical_manifest_dumps, manifest_dumps, manifest_loads, ) -from reflow.config import _rosetta_stone -from reflow.stores.records import RunRecord, TaskInstanceRecord, TaskSpecRecord +from reflow.stores.records import RunRecord # ═══════════════════════════════════════════════════════════════════════════ # _types.py @@ -90,7 +86,6 @@ def test_enum_roundtrip(self) -> None: assert loaded == TaskState.SUCCESS def test_dataclass_roundtrip(self) -> None: - from reflow.stores.records import RunRecord r = RunRecord( run_id="r1", diff --git a/tests/test_misc.py b/tests/test_misc.py index 76c4b40..0c9d0b6 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -3,12 +3,10 @@ from __future__ import annotations import signal as sig -from pathlib import Path import pytest from reflow import ( - Run, RunDir, Workflow, ) @@ -69,7 +67,6 @@ def test_version_exists(self) -> None: assert hasattr(reflow, "__version__") def test_cli_version(self) -> None: - from reflow.cli import build_parser wf = Workflow("test") diff --git a/tests/test_new_features.py b/tests/test_new_features.py index 5bb8d38..3171115 100644 --- a/tests/test_new_features.py +++ b/tests/test_new_features.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from typing import Annotated, Any +from typing import Annotated from unittest.mock import patch import pytest @@ -20,7 +20,6 @@ ) from reflow.stores.sqlite import SqliteStore - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -237,7 +236,7 @@ def test_runs_default_last_20( self, tmp_path: Path, capsys: pytest.CaptureFixture[str] ) -> None: wf, store = self._setup(tmp_path) - from reflow.cli import build_parser, run_command, parse_args + from reflow.cli import parse_args, run_command args = parse_args(wf, ["runs", "--store-path", str(store.path)]) with patch( diff --git a/tests/test_params.py b/tests/test_params.py index 1f81ac3..96a6295 100644 --- a/tests/test_params.py +++ b/tests/test_params.py @@ -9,11 +9,9 @@ import pytest from reflow import ( - Config, Flow, Param, Result, - Run, RunDir, Workflow, ) @@ -383,9 +381,10 @@ def test_resolved_param_local_flag(self) -> None: class TestParamCliBuildExtended: def test_short_flag_included_in_names(self) -> None: """short_flag on a ResolvedParam adds a short flag to add_to_parser.""" - import argparse, inspect - from typing import Annotated - from reflow import Workflow, Param + import argparse + import inspect + + from reflow import Workflow from reflow.params import collect_cli_params wf = Workflow("wf") @@ -406,9 +405,10 @@ def task(x: Annotated[str, Param(help="X", short="-x")]) -> str: def test_list_type_gets_nargs_plus(self) -> None: """A list[str] param produces a ResolvedParam with is_list=True.""" - import argparse, inspect - from typing import Annotated - from reflow import Workflow, Param + import argparse + import inspect + + from reflow import Workflow from reflow.params import collect_cli_params wf = Workflow("wf") @@ -440,30 +440,33 @@ def test_bool_argparse_type_false(self) -> None: def test_bool_argparse_type_invalid(self) -> None: import argparse + from reflow.params import _parse_bool with pytest.raises(argparse.ArgumentTypeError, match="Invalid boolean"): _parse_bool("maybe") def test_datetime_argparse_type_valid(self) -> None: - from reflow.params import _parse_datetime from datetime import datetime + + from reflow.params import _parse_datetime result = _parse_datetime("2024-01-15T12:00:00") assert isinstance(result, datetime) assert result.year == 2024 def test_datetime_argparse_type_invalid(self) -> None: import argparse + from reflow.params import _parse_datetime with pytest.raises(argparse.ArgumentTypeError, match="Invalid datetime"): _parse_datetime("not-a-date") def test_collect_cli_params_dedup_required(self) -> None: """merge_resolved_params deduplicates global params across tasks.""" - from typing import Annotated - from reflow import Workflow, Param - from reflow.params import collect_cli_params, merge_resolved_params import inspect + from reflow import Workflow + from reflow.params import collect_cli_params, merge_resolved_params + wf = Workflow("wf") @wf.job() @@ -484,11 +487,11 @@ def task_b(x: Annotated[str, Param(help="X")]) -> str: def test_hints_exception_falls_back_gracefully(self) -> None: """collect_cli_params handles functions where get_type_hints raises.""" - from reflow.params import collect_cli_params - from unittest.mock import patch import inspect - from typing import Annotated - from reflow import Workflow, Param + from unittest.mock import patch + + from reflow import Workflow + from reflow.params import collect_cli_params wf = Workflow("wf") @@ -519,9 +522,9 @@ def test_is_run_dir_type_error_returns_false(self) -> None: def test_collect_cli_params_skips_result_deps(self) -> None: import inspect - from reflow import Workflow, Result + + from reflow import Workflow from reflow.params import collect_cli_params - from typing import Annotated wf = Workflow("wf") @@ -541,9 +544,9 @@ def sink(x: Annotated[str, Result(step="source")]) -> str: def test_collect_cli_params_hints_exception(self) -> None: import inspect from unittest.mock import patch - from reflow import Workflow, Param + + from reflow import Workflow from reflow.params import collect_cli_params - from typing import Annotated wf = Workflow("wf") diff --git a/tests/test_results_io.py b/tests/test_results_io.py index 6685daf..bb8d9b2 100644 --- a/tests/test_results_io.py +++ b/tests/test_results_io.py @@ -8,12 +8,11 @@ import pytest from reflow import ( - Result, TaskState, ) from reflow.manifest import DEFAULT_CODEC -from reflow.stores.sqlite import SqliteStore from reflow.results import _result_filename, _results_dir, ingest_results, write_result +from reflow.stores.sqlite import SqliteStore # --- helpers --------------------------------------------------------------- diff --git a/tests/test_store.py b/tests/test_store.py index 5d48b2f..5e24d4d 100644 --- a/tests/test_store.py +++ b/tests/test_store.py @@ -8,13 +8,35 @@ import pytest from reflow import ( - Run, RunState, TaskState, ) -from reflow.flow import TaskSpec -from reflow.stores.sqlite import SqliteStore from reflow.stores.records import RunRecord, TaskInstanceRecord, TaskSpecRecord +from reflow.stores.sqlite import SqliteStore + + +def _seed(tmp_path: Path, n: int, task: str = "conv") -> SqliteStore: + store = SqliteStore.for_run_dir(tmp_path) + store.init() + store.insert_run("r1", "g", "u", {}) + for idx in range(n): + store.insert_task_instance("r1", task, idx, TaskState.PENDING, {}) + return store + + +def _states(store: SqliteStore, task: str = "conv") -> dict: + return { + int(r["array_index"]): r["state"] + for r in store.list_task_instances("r1", task_name=task) + } + + +def _job_ids(store: SqliteStore, task: str = "conv") -> dict: + return { + int(r["array_index"]): r.get("job_id") + for r in store.list_task_instances("r1", task_name=task) + } + # ═══════════════════════════════════════════════════════════════════════════ # _types.py @@ -324,6 +346,7 @@ class TestRetryOnLocked: def test_non_lock_error_raises_immediately(self, tmp_path: Path) -> None: """OperationalError not about locking propagates without retry.""" import sqlite3 + from reflow.stores.sqlite import SqliteStore st = SqliteStore(str(tmp_path / "db.sqlite")) @@ -335,8 +358,9 @@ def test_retry_on_busy_eventually_succeeds(self, tmp_path: Path) -> None: """_retry_on_locked retries on lock errors and ultimately succeeds.""" import sqlite3 from unittest.mock import patch - from reflow.stores.sqlite import SqliteStore + from reflow._types import RunState + from reflow.stores.sqlite import SqliteStore st = SqliteStore(str(tmp_path / "db.sqlite")) st.init() @@ -414,8 +438,9 @@ def test_locked_warning_logged_and_retried( import logging import sqlite3 from unittest.mock import patch - from reflow.stores.sqlite import SqliteStore + from reflow._types import RunState + from reflow.stores.sqlite import SqliteStore st = SqliteStore(str(tmp_path / "db.sqlite")) st.init() @@ -445,13 +470,15 @@ def flaky(self, run_id, status): def test_wal_fallback_on_wal_pragma_error(self, tmp_path: Path) -> None: """Lines 137-141: WAL pragma fails, DELETE fallback is attempted.""" import sqlite3 - from unittest.mock import patch, MagicMock, call + from unittest.mock import MagicMock, patch + from reflow.stores.sqlite import SqliteStore # sqlite3.Connection.execute is a read-only C slot in Python 3.13. # Instead, return a MagicMock whose execute raises on the WAL pragma. - real_conn = sqlite3.connect(str(tmp_path / "real.sqlite"), - check_same_thread=False) + real_conn = sqlite3.connect( + str(tmp_path / "real.sqlite"), check_same_thread=False + ) real_conn.row_factory = sqlite3.Row executed_sqls: list[str] = [] @@ -487,11 +514,13 @@ def selective_execute(sql, *args): class TestSqliteRemainingGaps: def test_retry_raises_on_max_retries(self, tmp_path: Path) -> None: """_retry_on_locked re-raises after exhausting all retries and sleeps - (_MAX_RETRIES - 1) times in between.""" + (_MAX_RETRIES - 1) times in between. + """ import sqlite3 from unittest.mock import patch - from reflow.stores.sqlite import SqliteStore, _MAX_RETRIES + from reflow._types import RunState + from reflow.stores.sqlite import _MAX_RETRIES, SqliteStore # A Connection subclass whose execute can be overridden (the base # sqlite3.Connection.execute is a read-only C slot). @@ -516,10 +545,11 @@ def patched_connect(path, **kwargs): # Re-open the connection through the patched connect so that the # locked-execute subclass is used for update_run_status. st.close() - with patch("reflow.stores.sqlite.sqlite3.connect", - side_effect=patched_connect): - with patch("reflow.stores.sqlite.time.sleep", - side_effect=lambda d: sleep_calls.append(d)): + with patch("reflow.stores.sqlite.sqlite3.connect", side_effect=patched_connect): + with patch( + "reflow.stores.sqlite.time.sleep", + side_effect=lambda d: sleep_calls.append(d), + ): with pytest.raises(sqlite3.OperationalError, match="locked"): st.update_run_status("r1", RunState.SUCCESS) @@ -527,10 +557,7 @@ def patched_connect(path, **kwargs): assert len(sleep_calls) == _MAX_RETRIES - 1 st.close() - - def test_get_task_instance_missing_returns_none( - self, tmp_path: Path - ) -> None: + def test_get_task_instance_missing_returns_none(self, tmp_path: Path) -> None: """get_task_instance returns None for a non-existent task.""" from reflow.stores.sqlite import SqliteStore @@ -541,9 +568,7 @@ def test_get_task_instance_missing_returns_none( assert result is None st.close() - def test_get_singleton_output_missing_returns_none( - self, tmp_path: Path - ) -> None: + def test_get_singleton_output_missing_returns_none(self, tmp_path: Path) -> None: """get_singleton_output returns None when task has no record.""" from reflow.stores.sqlite import SqliteStore @@ -558,8 +583,8 @@ def test_dependency_is_satisfied_returns_false_with_failed( self, tmp_path: Path ) -> None: """dependency_is_satisfied returns False when task has FAILED instances.""" - from reflow.stores.sqlite import SqliteStore from reflow._types import TaskState + from reflow.stores.sqlite import SqliteStore st = SqliteStore(str(tmp_path / "db.sqlite")) st.init() @@ -567,3 +592,98 @@ def test_dependency_is_satisfied_returns_false_with_failed( st.insert_task_instance("r1", "task", None, TaskState.FAILED, {}) assert st.dependency_is_satisfied("r1", "task") is False st.close() + + +class TestUpdateTaskSubmittedIndices: + def test_none_marks_all_instances(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 3) + store.update_task_submitted("r1", "conv", "jobAll") + assert set(_states(store).values()) == {"SUBMITTED"} + assert set(_job_ids(store).values()) == {"jobAll"} + + def test_subset_marks_only_that_wave(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 4) + store.update_task_submitted("r1", "conv", "jobW1", indices=[0, 1]) + states = _states(store) + assert states[0] == states[1] == "SUBMITTED" + assert states[2] == states[3] == "PENDING" + job_ids = _job_ids(store) + assert job_ids[0] == "jobW1" and job_ids[2] is None + + def test_successive_waves_keep_their_own_job_ids(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 4) + store.update_task_submitted("r1", "conv", "jobW1", indices=[0, 1]) + store.update_task_submitted("r1", "conv", "jobW2", indices=[2, 3]) + job_ids = _job_ids(store) + assert job_ids[0] == "jobW1" and job_ids[1] == "jobW1" + assert job_ids[2] == "jobW2" and job_ids[3] == "jobW2" + + def test_empty_indices_is_noop(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 2) + store.update_task_submitted("r1", "conv", "x", indices=[]) + assert set(_states(store).values()) == {"PENDING"} + + def test_singleton_none_index(self, tmp_path: Path) -> None: + store = SqliteStore.for_run_dir(tmp_path) + store.init() + store.insert_run("r1", "g", "u", {}) + store.insert_task_instance("r1", "prep", None, TaskState.PENDING, {}) + store.update_task_submitted("r1", "prep", "jobS") + row = store.get_task_instance("r1", "prep", None) + assert row is not None and row["state"] == "SUBMITTED" + assert row.get("job_id") == "jobS" + + def test_retrying_is_also_submittable(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 2) + # flip index 1 to RETRYING; update_task_submitted should include it + store.mark_for_retry(int(store.get_task_instance("r1", "conv", 1)["id"])) + store.update_task_submitted("r1", "conv", "jobR") + assert set(_states(store).values()) == {"SUBMITTED"} + + +class TestFailPendingTasks: + def test_fail_all_pending(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 3) + n = store.fail_pending_tasks("r1", "conv", "rejected") + assert n == 3 + assert set(_states(store).values()) == {"FAILED"} + + def test_fail_only_subset(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 4) + n = store.fail_pending_tasks("r1", "conv", "rejected", indices=[0, 2]) + assert n == 2 + states = _states(store) + assert states[0] == states[2] == "FAILED" + assert states[1] == states[3] == "PENDING" + + def test_error_text_recorded(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 1) + store.fail_pending_tasks("r1", "conv", "sbatch: job submit limit") + row = store.list_task_instances("r1", task_name="conv")[0] + assert "submit limit" in (row.get("error_text") or "") + + def test_running_instance_is_protected(self, tmp_path: Path) -> None: + store = SqliteStore.for_run_dir(tmp_path) + store.init() + store.insert_run("r1", "g", "u", {}) + iid = store.insert_task_instance("r1", "conv", 0, TaskState.PENDING, {}) + store.update_task_running(iid) + store.insert_task_instance("r1", "conv", 1, TaskState.PENDING, {}) + n = store.fail_pending_tasks("r1", "conv", "rejected") + states = _states(store) + assert states[0] == "RUNNING" # not touched + assert states[1] == "FAILED" + assert n == 1 + + def test_submitted_instances_can_be_failed(self, tmp_path: Path) -> None: + # whole-submission rejection after instances were marked SUBMITTED + store = _seed(tmp_path, 2) + store.update_task_submitted("r1", "conv", "job1") + n = store.fail_pending_tasks("r1", "conv", "whole submission rejected") + assert n == 2 + assert set(_states(store).values()) == {"FAILED"} + + def test_empty_indices_returns_zero(self, tmp_path: Path) -> None: + store = _seed(tmp_path, 2) + assert store.fail_pending_tasks("r1", "conv", "x", indices=[]) == 0 + assert set(_states(store).values()) == {"PENDING"} diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 8beb917..3acf614 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -31,9 +31,13 @@ is_run_dir, unwrap_optional, ) -from reflow.stores.sqlite import SqliteStore -from reflow.workflow._helpers import default_executor, make_run_id, resolve_executor, resolve_index from reflow.results import ingest_results +from reflow.stores.sqlite import SqliteStore +from reflow.workflow._helpers import ( + make_run_id, + resolve_executor, + resolve_index, +) # ═══════════════════════════════════════════════════════════════════════════ # Coverage: _dispatch.py (dispatch loop, resolve, fan-out, finalize) @@ -629,7 +633,6 @@ def test_worker_success(self, tmp_path: Path) -> None: wf.worker(run_id, st, tmp_path, "task_ok") # Result file should have been written and can be ingested - from reflow.results import ingest_results n = ingest_results(run_id, st) assert n == 1 @@ -650,7 +653,6 @@ def test_worker_failure(self, tmp_path: Path) -> None: with pytest.raises(RuntimeError, match="worker boom"): wf.worker(run_id, st, tmp_path, "task_fail") - from reflow.results import ingest_results n = ingest_results(run_id, st) assert n == 1 @@ -817,8 +819,6 @@ def test_status_topo_value_error_fallback( capsys: pytest.CaptureFixture[str], ) -> None: """ValueError from _topological_order is caught; tasks sorted alphabetically.""" - from unittest.mock import patch - wf = Workflow("w") @wf.job() @@ -845,8 +845,6 @@ def test_worker_continues_if_update_running_fails( self, tmp_path: Path ) -> None: """update_task_running exception is swallowed; task still executes.""" - from unittest.mock import patch - wf = Workflow("w") @wf.job() @@ -873,9 +871,8 @@ def task(val: Annotated[str, Param(help="V")]) -> str: class TestValidateEdgeCases: def test_validate_skips_task_when_hints_raises(self) -> None: - from unittest.mock import patch - from reflow import Workflow, Result - from typing import Annotated + + from reflow import Workflow wf = Workflow("wf") @@ -903,8 +900,8 @@ def task() -> str: wf.validate() def test_validate_upstream_no_return_type_skipped(self) -> None: - from reflow import Workflow, Result - from typing import Annotated + + from reflow import Workflow wf = Workflow("wf") @@ -923,10 +920,9 @@ class TestCancelRunsJobId: def test_cancel_run_calls_executor_cancel(self, tmp_path: Path) -> None: """cancel_run calls executor.cancel for instances with a job_id.""" from unittest.mock import MagicMock - from reflow import Workflow, Param + + from reflow import Workflow from reflow.stores.sqlite import SqliteStore - from reflow._types import TaskState - from typing import Annotated wf = Workflow("wf")