diff --git a/docs/cli-reference.md b/docs/cli-reference.md index 01d6272..2434563 100644 --- a/docs/cli-reference.md +++ b/docs/cli-reference.md @@ -98,8 +98,74 @@ Print the task dependency graph: ```console $ python pipeline.py dag + gather_sources + prepare_shared <- gather_sources + download_source [array] <- gather_sources + convert_source [array] <- download_source, prepare_shared + gather_temp_levels <- convert_source + finalize [array] <- gather_temp_levels ``` +Array tasks are marked `[array]`; the `<-` arrow lists each task's +dependencies. + +Use `--format` to choose a different rendering. The `text`, `mermaid`, +and `dot` formats have no extra dependencies: + +```console +$ python pipeline.py dag --format mermaid +flowchart TD + gather_sources[gather_sources] + download_source[[download_source]] + convert_source[[convert_source]] + gather_sources --> download_source + download_source --> convert_source + prepare_shared --> convert_source +``` + +`mermaid` output renders directly in GitHub and in mkdocs-material, so it +is handy for embedding a workflow diagram in documentation. `dot` emits +Graphviz source you can pipe to the `dot` binary: + +```console +$ python pipeline.py dag --format dot | dot -Tpng -o dag.png +``` + +In both `mermaid` and `dot`, array tasks are drawn with a distinct shape +(a subroutine box `[[...]]` in Mermaid, a doubled border in Graphviz) so +they stand out from singleton tasks. + +The `phart` format draws a pretty Unicode diagram directly in the +terminal, with the multi-parent fan-in rendered without duplicating +shared nodes: + +```console +$ python pipeline.py dag --format phart + [gather_sources] + ↓ + <>───────+───────────→[prepare_shared] + ↓ ↓ + +────────→<>──────────+ + ↓ + →[gather_temp_levels] + ↓ + <> +``` + +Array tasks appear as `<>` and singletons as `[name]`. This format +needs an optional dependency: + +```console +$ pip install 'reflow-hpc[pretty]' +``` + +If the extra is not installed, `--format phart` prints the plain `text` +diagram to stdout and a one-line install hint to stderr, then exits +successfully — so redirecting the output (for example +`dag --format phart > graph.txt`) still produces a clean diagram. Pass +`--ascii` to force 7-bit ASCII instead of Unicode box characters, which +is useful for terminals or logs that do not handle Unicode well. + ### `describe` Print the full workflow manifest as JSON: @@ -165,6 +231,17 @@ Only applies to `runs`. Skip the confirmation prompt when cancelling multiple runs. +### `--format` + +Output format for `dag`: `text` (default), `mermaid`, `dot`, or +`phart`. Only applies to `dag`. See the [`dag`](#dag) command above for +details and examples. + +### `--ascii` + +For `dag --format phart`, force 7-bit ASCII output instead of Unicode +box-drawing characters. Only applies to `dag`. + ### `--version` Print the reflow version. diff --git a/docs/getting-started.md b/docs/getting-started.md index 78d013c..427cac5 100644 --- a/docs/getting-started.md +++ b/docs/getting-started.md @@ -8,6 +8,13 @@ Build a complete workflow from scratch in five minutes. pip install reflow-hpc ``` +For a prettier terminal diagram from the [`dag`](cli-reference.md#dag) +command, install the optional `pretty` extra: + +```console +pip install 'reflow-hpc[pretty]' +``` + ## Step 1 — create a workflow ```python diff --git a/pyproject.toml b/pyproject.toml index 8614e1c..ca20b55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ dev = [ "ruff", "mypy", ] -pretty = ["rich-argparse"] +pretty = ["rich-argparse", "phart", "networkx"] docs = [ "mkdocs>=1.5", "mkdocs-material>=9.5", diff --git a/src/reflow/_dag_render.py b/src/reflow/_dag_render.py new file mode 100644 index 0000000..bc1d747 --- /dev/null +++ b/src/reflow/_dag_render.py @@ -0,0 +1,134 @@ +"""DAG rendering for the ``dag`` CLI command. + +Renders a workflow's task graph in one of several formats: + +- ``text`` : indented adjacency list (default, zero dependencies) +- ``mermaid`` : Mermaid ``flowchart`` source (zero dependencies) +- ``dot`` : Graphviz DOT source (zero dependencies) +- ``phart`` : pretty ASCII/Unicode rendered in the terminal, requires the + optional ``reflow[pretty]`` extra (phart + networkx) + +The text, mermaid, and dot renderers are pure string emission. The phart +renderer imports lazily and the caller is responsible for falling back to +text when the import is unavailable. + +Array tasks are marked consistently across formats: a ``[array]`` suffix in +text, a distinct node shape in mermaid/dot, and angle-bracket decorators in +phart. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .workflow import Workflow + +FORMATS = ("text", "mermaid", "dot", "phart") + + +def _edges(wf: Workflow) -> list[tuple[str, str]]: + """Return (dependency, task) edges in topological order.""" + edges: list[tuple[str, str]] = [] + for tname in wf._topological_order(): + spec = wf.tasks[tname] + for dep in wf._effective_dependencies(spec): + edges.append((dep, tname)) + return edges + + +def _array_tasks(wf: Workflow) -> set[str]: + return {name for name, spec in wf.tasks.items() if spec.config.array} + + +def render_text(wf: Workflow) -> str: + """Indented adjacency list (the original format).""" + lines: list[str] = [] + for tname in wf._topological_order(): + spec = wf.tasks[tname] + deps = wf._effective_dependencies(spec) + dep_str = f" <- {', '.join(deps)}" if deps else "" + tag = " [array]" if spec.config.array else "" + lines.append(f" {tname}{tag}{dep_str}") + return "\n".join(lines) + + +def render_mermaid(wf: Workflow) -> str: + """Mermaid ``flowchart TD`` source. + + Array tasks use the subroutine shape ``[[name]]``; singletons use the + default box ``[name]``. Renders natively in mkdocs-material and GitHub. + """ + array = _array_tasks(wf) + lines = ["flowchart TD"] + # Declare nodes first so isolated tasks (no edges) still appear. + for tname in wf._topological_order(): + if tname in array: + lines.append(f" {tname}[[{tname}]]") + else: + lines.append(f" {tname}[{tname}]") + for dep, tname in _edges(wf): + lines.append(f" {dep} --> {tname}") + return "\n".join(lines) + + +def render_dot(wf: Workflow) -> str: + """Graphviz DOT source. + + Array tasks are drawn as boxes with doubled borders (``peripheries=2``) + to distinguish them from singleton tasks. + """ + array = _array_tasks(wf) + lines = ["digraph reflow {", " rankdir=TB;", " node [shape=box];"] + for tname in wf._topological_order(): + if tname in array: + lines.append(f' "{tname}" [peripheries=2];') + else: + lines.append(f' "{tname}";') + for dep, tname in _edges(wf): + lines.append(f' "{dep}" -> "{tname}";') + lines.append("}") + return "\n".join(lines) + + +def render_phart(wf: Workflow, *, use_ascii: bool = False) -> str: + """Pretty ASCII/Unicode DAG via the optional phart + networkx extra. + + Array tasks are decorated with angle brackets ``<>``; singletons + with square brackets ``[name]``. Raises ImportError if the extra is + not installed; the caller should catch this and fall back to text. + """ + import networkx as nx # noqa: PLC0415 - optional dependency + from phart import ASCIIRenderer, NodeStyle # noqa: PLC0415 + + array = _array_tasks(wf) + g = nx.DiGraph() + # Add all nodes so isolated tasks still render. + for tname in wf._topological_order(): + g.add_node(tname) + g.add_edges_from(_edges(wf)) + + decorators = { + name: (("<<", ">>") if name in array else ("[", "]")) for name in g.nodes + } + renderer = ASCIIRenderer( + g, + node_style=NodeStyle.CUSTOM, + custom_decorators=decorators, + use_ascii=use_ascii, + ) + result: str = renderer.render() + return result.rstrip("\n") + + +def render(wf: Workflow, fmt: str, *, use_ascii: bool = False) -> str: + """Render the DAG in *fmt*. phart import errors propagate to the caller.""" + if fmt == "text": + return render_text(wf) + if fmt == "mermaid": + return render_mermaid(wf) + if fmt == "dot": + return render_dot(wf) + if fmt == "phart": + return render_phart(wf, use_ascii=use_ascii) + raise ValueError(f"Unknown DAG format: {fmt!r}") diff --git a/src/reflow/cli.py b/src/reflow/cli.py index dc226a5..e0ba5a7 100644 --- a/src/reflow/cli.py +++ b/src/reflow/cli.py @@ -217,10 +217,27 @@ def _add_runs_parser(sp: Any) -> None: def _add_dag_parser(sp: Any) -> None: + from ._dag_render import FORMATS + p = sp.add_parser( "dag", help="Print the task DAG.", ) + p.add_argument( + "--format", + choices=FORMATS, + default="text", + help=( + "Output format. 'text' is the default adjacency list; 'mermaid' " + "and 'dot' emit graph source for external renderers; 'phart' " + "draws a pretty terminal diagram (needs the 'pretty' extra)." + ), + ) + p.add_argument( + "--ascii", + action="store_true", + help="For --format phart, force 7-bit ASCII instead of Unicode.", + ) p.set_defaults(_command="dag") @@ -552,13 +569,21 @@ def _cmd_runs(wf: Any, args: argparse.Namespace) -> int: def _cmd_dag(wf: Any, args: argparse.Namespace) -> int: - order = wf._topological_order() - for tname in order: - spec = wf.tasks[tname] - deps = wf._effective_dependencies(spec) - dep_str = f" <- {', '.join(deps)}" if deps else "" - tag = " [array]" if spec.config.array else "" - print(f" {tname}{tag}{dep_str}") + from . import _dag_render + + fmt = getattr(args, "format", "text") + if fmt == "phart": + try: + print(_dag_render.render_phart(wf, use_ascii=getattr(args, "ascii", False))) + return 0 + except ImportError: + print( + "phart not installed; showing plain text. " + "For a prettier diagram: pip install 'reflow[pretty]'", + file=sys.stderr, + ) + fmt = "text" + print(_dag_render.render(wf, fmt, use_ascii=getattr(args, "ascii", False))) return 0 diff --git a/tests/test_cli.py b/tests/test_cli.py index 37b905c..216f38c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -761,3 +761,209 @@ def task(x: Annotated[str, Param(help="X")]) -> str: out = capsys.readouterr().out assert rc == 0 assert "task" in out or "wf" in out + + +# ═══════════════════════════════════════════════════════════════════════════ +# dag command: --format {text,mermaid,dot,phart} +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestCLIDagFormats: + @staticmethod + def _make_wf() -> Workflow: + """A workflow with an array task and a multi-parent node, mirroring + the shape of a real conversion pipeline.""" + wf = Workflow("dagfmt") + + @wf.job() + def gather_sources() -> list[str]: + return ["a", "b"] + + @wf.job() + def prepare_shared( + s: Annotated[list[str], Result(step="gather_sources")], + ) -> str: + return "shared" + + @wf.array_job() + def download_source( + item: Annotated[str, Result(step="gather_sources")], + ) -> str: + return item + + @wf.array_job() + def convert_source( + d: Annotated[str, Result(step="download_source")], + p: Annotated[str, Result(step="prepare_shared", broadcast=True)], + ) -> str: + return d + + return wf + + def test_dag_default_is_text( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + args = parse_args(wf, ["dag"]) + assert args.format == "text" + rc = run_command(wf, args) + out = capsys.readouterr().out + assert rc == 0 + # Array tasks carry the [array] suffix; multi-parent shows both deps + assert "download_source [array]" in out + assert "convert_source [array] <- download_source, prepare_shared" in out + + def test_dag_format_text_explicit( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + rc = run_command(wf, parse_args(wf, ["dag", "--format", "text"])) + out = capsys.readouterr().out + assert rc == 0 + assert "gather_sources" in out + + def test_dag_format_mermaid( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + rc = run_command(wf, parse_args(wf, ["dag", "--format", "mermaid"])) + out = capsys.readouterr().out + assert rc == 0 + assert out.startswith("flowchart TD") + # Array tasks use the subroutine shape [[...]] + assert "download_source[[download_source]]" in out + assert "convert_source[[convert_source]]" in out + # Singletons use the plain box + assert "gather_sources[gather_sources]" in out + # Edges render with arrows + assert "gather_sources --> download_source" in out + assert "prepare_shared --> convert_source" in out + + def test_dag_format_dot( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + rc = run_command(wf, parse_args(wf, ["dag", "--format", "dot"])) + out = capsys.readouterr().out + assert rc == 0 + assert out.startswith("digraph reflow {") + assert out.rstrip().endswith("}") + # Array tasks get doubled borders + assert '"download_source" [peripheries=2];' in out + assert '"convert_source" [peripheries=2];' in out + # Singletons have no peripheries attribute + assert '"gather_sources";' in out + # Edges are quoted + assert '"gather_sources" -> "download_source";' in out + + def test_dag_invalid_format_rejected(self) -> None: + from reflow.cli import parse_args + + wf = self._make_wf() + with pytest.raises(SystemExit): + parse_args(wf, ["dag", "--format", "nonsense"]) + + def test_dag_format_phart_when_available( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """If phart + networkx are importable, phart output is produced.""" + pytest.importorskip("phart") + pytest.importorskip("networkx") + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + rc = run_command(wf, parse_args(wf, ["dag", "--format", "phart"])) + out = capsys.readouterr().out + assert rc == 0 + # All task names appear in the rendered diagram + for name in ( + "gather_sources", + "prepare_shared", + "download_source", + "convert_source", + ): + assert name in out + # Array tasks are decorated with angle brackets, singletons with [] + assert "<>" in out + assert "[gather_sources]" in out + + def test_dag_format_phart_ascii_flag( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """--ascii forces 7-bit output (no Unicode arrows).""" + pytest.importorskip("phart") + pytest.importorskip("networkx") + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + rc = run_command( + wf, parse_args(wf, ["dag", "--format", "phart", "--ascii"]) + ) + out = capsys.readouterr().out + assert rc == 0 + # No Unicode downward arrow in ASCII mode + assert "\u2193" not in out + assert "gather_sources" in out + + def test_dag_phart_missing_falls_back_to_text( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """When phart can't be imported, dag prints text to stdout and a hint + to stderr, still returning 0.""" + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + + real_import = __builtins__["__import__"] if isinstance( + __builtins__, dict + ) else __builtins__.__import__ + + def blocked_import(name, *a, **k): + if name == "phart" or name.startswith("phart."): + raise ImportError("No module named 'phart'") + return real_import(name, *a, **k) + + with patch("builtins.__import__", side_effect=blocked_import): + rc = run_command(wf, parse_args(wf, ["dag", "--format", "phart"])) + + captured = capsys.readouterr() + assert rc == 0 + # Text diagram on stdout + assert "download_source [array]" in captured.out + # Actionable hint on stderr, not stdout + assert "pip install" in captured.err + assert "pretty" in captured.err + assert "phart" not in captured.out + + def test_dag_phart_hint_goes_to_stderr_only( + self, capsys: pytest.CaptureFixture[str] + ) -> None: + """The fallback hint must not pollute stdout (so redirects stay clean).""" + from reflow.cli import parse_args, run_command + + wf = self._make_wf() + + real_import = __builtins__["__import__"] if isinstance( + __builtins__, dict + ) else __builtins__.__import__ + + def blocked_import(name, *a, **k): + if name == "phart" or name.startswith("phart."): + raise ImportError("blocked") + return real_import(name, *a, **k) + + with patch("builtins.__import__", side_effect=blocked_import): + run_command(wf, parse_args(wf, ["dag", "--format", "phart"])) + + captured = capsys.readouterr() + # stdout is a valid text diagram with no warning text mixed in + assert "not installed" not in captured.out + assert "not installed" in captured.err diff --git a/tests/test_dag_render.py b/tests/test_dag_render.py new file mode 100644 index 0000000..473df04 --- /dev/null +++ b/tests/test_dag_render.py @@ -0,0 +1,218 @@ +"""test_dag_render.py - unit tests for the DAG rendering module. + +These drive reflow._dag_render directly (rather than through the CLI) so the +render() dispatcher, the format guard, and each renderer are exercised in +isolation. The CLI-level behaviour (--format flag, phart fallback to text, +stderr hint) is covered separately in test_cli.py. +""" + +from __future__ import annotations + +from typing import Annotated + +import pytest + +from reflow import Result, Workflow +from reflow import _dag_render + + +def _make_wf() -> Workflow: + """Workflow with an array task and a multi-parent node.""" + wf = Workflow("dagrender") + + @wf.job() + def gather_sources() -> list[str]: + return ["a", "b"] + + @wf.job() + def prepare_shared( + s: Annotated[list[str], Result(step="gather_sources")], + ) -> str: + return "shared" + + @wf.array_job() + def download_source( + item: Annotated[str, Result(step="gather_sources")], + ) -> str: + return item + + @wf.array_job() + def convert_source( + d: Annotated[str, Result(step="download_source")], + p: Annotated[str, Result(step="prepare_shared", broadcast=True)], + ) -> str: + return d + + return wf + + +# ═══════════════════════════════════════════════════════════════════════════ +# render() dispatcher +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestRenderDispatcher: + def test_render_text(self) -> None: + out = _dag_render.render(_make_wf(), "text") + assert "gather_sources" in out + assert "download_source [array]" in out + + def test_render_mermaid(self) -> None: + out = _dag_render.render(_make_wf(), "mermaid") + assert out.startswith("flowchart TD") + + def test_render_dot(self) -> None: + out = _dag_render.render(_make_wf(), "dot") + assert out.startswith("digraph reflow {") + + def test_render_phart_via_dispatcher(self) -> None: + """The phart branch of render() (line 132-133) is reached.""" + pytest.importorskip("phart") + pytest.importorskip("networkx") + out = _dag_render.render(_make_wf(), "phart") + assert "gather_sources" in out + assert "<>" in out + + def test_render_phart_via_dispatcher_ascii(self) -> None: + pytest.importorskip("phart") + pytest.importorskip("networkx") + out = _dag_render.render(_make_wf(), "phart", use_ascii=True) + assert "\u2193" not in out # no Unicode arrow + assert "gather_sources" in out + + def test_render_unknown_format_raises(self) -> None: + """The format guard (line 134) raises ValueError.""" + with pytest.raises(ValueError, match="Unknown DAG format"): + _dag_render.render(_make_wf(), "nonsense") + + def test_formats_constant(self) -> None: + assert _dag_render.FORMATS == ("text", "mermaid", "dot", "phart") + + +# ═══════════════════════════════════════════════════════════════════════════ +# Individual renderers +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestRenderText: + def test_dependencies_arrow(self) -> None: + out = _dag_render.render_text(_make_wf()) + assert ( + "convert_source [array] <- download_source, prepare_shared" in out + ) + + def test_root_has_no_arrow(self) -> None: + out = _dag_render.render_text(_make_wf()) + lines = out.splitlines() + root = next(l for l in lines if "gather_sources" in l) + assert "<-" not in root + + +class TestRenderMermaid: + def test_array_subroutine_shape(self) -> None: + out = _dag_render.render_mermaid(_make_wf()) + assert "download_source[[download_source]]" in out + assert "convert_source[[convert_source]]" in out + + def test_singleton_box_shape(self) -> None: + out = _dag_render.render_mermaid(_make_wf()) + assert "gather_sources[gather_sources]" in out + + def test_edges_present(self) -> None: + out = _dag_render.render_mermaid(_make_wf()) + assert "gather_sources --> download_source" in out + assert "prepare_shared --> convert_source" in out + + def test_isolated_node_still_declared(self) -> None: + """A task with no edges still appears as a declared node.""" + wf = Workflow("solo") + + @wf.job() + def lonely() -> str: + return "x" + + out = _dag_render.render_mermaid(wf) + assert "lonely[lonely]" in out + + +class TestRenderDot: + def test_array_doubled_border(self) -> None: + out = _dag_render.render_dot(_make_wf()) + assert '"download_source" [peripheries=2];' in out + assert '"convert_source" [peripheries=2];' in out + + def test_singleton_no_peripheries(self) -> None: + out = _dag_render.render_dot(_make_wf()) + assert '"gather_sources";' in out + + def test_edges_quoted(self) -> None: + out = _dag_render.render_dot(_make_wf()) + assert '"gather_sources" -> "download_source";' in out + + def test_well_formed(self) -> None: + out = _dag_render.render_dot(_make_wf()) + assert out.startswith("digraph reflow {") + assert out.rstrip().endswith("}") + assert "rankdir=TB;" in out + + +class TestRenderPhart: + def test_renders_all_nodes(self) -> None: + pytest.importorskip("phart") + pytest.importorskip("networkx") + out = _dag_render.render_phart(_make_wf()) + for name in ( + "gather_sources", + "prepare_shared", + "download_source", + "convert_source", + ): + assert name in out + + def test_array_and_singleton_decorators(self) -> None: + pytest.importorskip("phart") + pytest.importorskip("networkx") + out = _dag_render.render_phart(_make_wf()) + assert "<>" in out + assert "[gather_sources]" in out + + def test_ascii_mode_no_unicode(self) -> None: + pytest.importorskip("phart") + pytest.importorskip("networkx") + out = _dag_render.render_phart(_make_wf(), use_ascii=True) + assert "\u2193" not in out + assert "\u2192" not in out + + def test_missing_dependency_raises_importerror(self) -> None: + """render_phart propagates ImportError when phart is unavailable.""" + import builtins + + real_import = builtins.__import__ + + def blocked(name, *a, **k): + if name == "phart" or name.startswith("phart."): + raise ImportError("No module named 'phart'") + return real_import(name, *a, **k) + + with pytest.MonkeyPatch.context() as mp: + mp.setattr(builtins, "__import__", blocked) + with pytest.raises(ImportError): + _dag_render.render_phart(_make_wf()) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Edge helpers +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestEdgeHelpers: + def test_edges_topological(self) -> None: + edges = _dag_render._edges(_make_wf()) + # Every edge is (dependency, task); gather_sources has no incoming edge + assert ("gather_sources", "download_source") in edges + assert ("download_source", "convert_source") in edges + assert ("prepare_shared", "convert_source") in edges + + def test_array_tasks_detected(self) -> None: + arrays = _dag_render._array_tasks(_make_wf()) + assert arrays == {"download_source", "convert_source"} diff --git a/tests/test_dispatch.py b/tests/test_dispatch.py index 81a529d..fea5220 100644 --- a/tests/test_dispatch.py +++ b/tests/test_dispatch.py @@ -1152,3 +1152,405 @@ def sink(x: Annotated[str, Result(step="source")]) -> str: with patch("typing.get_type_hints", side_effect=Exception("bad")): result = wf._resolve_result_inputs(st, "r1", wf.tasks["sink"]) assert isinstance(result, dict) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Targeted coverage for remaining _dispatch.py branches +# ═══════════════════════════════════════════════════════════════════════════ + + +class TestDispatchBranchCoverage: + def test_resolve_result_inputs_hints_exception(self, tmp_path: Path) -> None: + """_resolve_result_inputs: get_type_hints raising → hints = {} (145-146).""" + from unittest.mock import patch + wf = Workflow("wf") + + @wf.job() + def source() -> str: + return "v" + + @wf.job() + def sink(x: Annotated[str, Result(step="source")]) -> str: + return x + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance( + "r1", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid, "v") + with patch( + "reflow.workflow._dispatch.get_type_hints", + side_effect=Exception("bad"), + ): + result = wf._resolve_result_inputs(st, "r1", wf.tasks["sink"]) + # With hints unavailable the dep can't be wired; result is a dict + assert isinstance(result, dict) + + def test_find_fan_out_param_hints_exception_returns_none( + self, tmp_path: Path + ) -> None: + """_find_fan_out_param: get_type_hints raising → None (203-204).""" + from unittest.mock import patch + wf = Workflow("wf") + + @wf.job() + def source() -> list[str]: + return ["a"] + + @wf.array_job() + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + with patch( + "reflow.workflow._dispatch.get_type_hints", + side_effect=Exception("bad"), + ): + assert wf._find_fan_out_param(wf.tasks["process"]) is None + + def test_find_fan_out_param_broadcast_skipped(self, tmp_path: Path) -> None: + """Broadcast dep is skipped in the first loop (211).""" + wf = Workflow("wf") + + @wf.job() + def cfg() -> list[str]: + return ["c"] + + @wf.job() + def source() -> list[str]: + return ["a", "b"] + + @wf.array_job() + def process( + item: Annotated[str, Result(step="source")], + extra: Annotated[list[str], Result(step="cfg", broadcast=True)], + ) -> str: + return item + + wf.validate() + # The fan-out param is "item"; "extra" (broadcast) is skipped. + assert wf._find_fan_out_param(wf.tasks["process"]) == "item" + + def test_find_fan_out_param_unknown_upstream_skipped( + self, tmp_path: Path + ) -> None: + """A dep whose upstream has no return type is skipped (218, 227-228).""" + from unittest.mock import patch + wf = Workflow("wf") + + @wf.job() + def source(): # no return annotation + return ["a", "b"] + + @wf.array_job() + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + # upstream return_type is empty → first loop continues; falls back to + # the first non-broadcast dep, which is still "item". + assert wf._find_fan_out_param(wf.tasks["process"]) == "item" + + def test_find_fan_out_param_infer_type_error_skipped( + self, tmp_path: Path + ) -> None: + """infer_wire_mode raising TypeError skips the dep, fallback returns it (234).""" + from unittest.mock import patch + wf = Workflow("wf") + + @wf.job() + def source() -> list[str]: + return ["a", "b"] + + @wf.array_job() + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + with patch( + "reflow.workflow._dispatch.infer_wire_mode", + side_effect=TypeError("bad"), + ): + # First loop continues past the TypeError; fallback returns "item". + assert wf._find_fan_out_param(wf.tasks["process"]) == "item" + + def test_try_cache_verify_stale(self, tmp_path: Path) -> None: + """_try_cache via dispatch: verify + stale callable → False (311-317).""" + from typing import Any + from reflow.executors.local import LocalExecutor + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + def stale(output: Any) -> bool: + return False + + @wf.job(cache=True, verify=stale) + def step(x: Annotated[str, Param(help="x")]) -> str: + return x.upper() + + wf.validate() + with _store(tmp_path) as st: + # Seed a cached record by completing the task once via dispatch. + st.insert_run("r1", "wf", "u", {"x": "v"}) + ex = LocalExecutor() + wf._dispatch_single("r1", tmp_path, wf.tasks["step"], st, ex) + row = st.get_task_instance("r1", "step", None) + st.update_task_success(int(row["id"]), "V", output_hash="abc") + + # New run: cache exists but verify=stale → _try_cache returns False. + st.insert_run("r2", "wf", "u", {"x": "v"}) + hit = wf._try_cache( + st, "r2", wf.tasks["step"], {}, None, [], verify=True + ) + assert hit is False + + def test_try_cache_hit_recomputes_output_hash(self, tmp_path: Path) -> None: + """_try_cache via dispatch: missing output_hash recomputed (324-326).""" + from reflow.executors.local import LocalExecutor + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job(cache=True) + def step(x: Annotated[str, Param(help="x")]) -> str: + return x.upper() + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {"x": "v"}) + ex = LocalExecutor() + wf._dispatch_single("r1", tmp_path, wf.tasks["step"], st, ex) + row = st.get_task_instance("r1", "step", None) + iid = int(row["id"]) + st.update_task_success(iid, "V") + # Clear the output_hash so _try_cache must recompute it. + st.conn.execute( + "UPDATE task_instances SET output_hash = '' WHERE id = ?", (iid,) + ) + st.conn.commit() + + # _dispatch_single resolves payload via _resolve_result_inputs, which + # is {} for a Param-only task — so the cache identity is built from + # an empty payload. Match that here to hit the cached record. + st.insert_run("r2", "wf", "u", {"x": "v"}) + hit = wf._try_cache( + st, "r2", wf.tasks["step"], {}, None, [] + ) + assert hit is True + assert st.get_task_instance("r2", "step", None)["state"] == "SUCCESS" + + def test_try_cache_no_cache_config_returns_false(self, tmp_path: Path) -> None: + """_try_cache returns False immediately when caching is off (295).""" + wf = Workflow("wf") + + @wf.job() # cache defaults to False + def step(x: Annotated[str, Param(help="x")]) -> str: + return x + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {"x": "v"}) + hit = wf._try_cache( + st, "r1", wf.tasks["step"], {"x": "v"}, None, [] + ) + assert hit is False + + def test_try_cache_miss_returns_false(self, tmp_path: Path) -> None: + """_try_cache returns False when nothing is cached (307-308).""" + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job(cache=True) + def step(x: Annotated[str, Param(help="x")]) -> str: + return x + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {"x": "v"}) + hit = wf._try_cache( + st, "r1", wf.tasks["step"], {"x": "v"}, None, [] + ) + assert hit is False + + def test_dispatch_single_creates_pending(self, tmp_path: Path) -> None: + """_dispatch_single inserts a PENDING instance then submits (387+).""" + wf = Workflow("wf") + + @wf.job() + def step(x: Annotated[str, Param(help="x")]) -> str: + return x + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {"x": "v"}) + from reflow.executors.local import LocalExecutor + wf._dispatch_single( + "r1", tmp_path, wf.tasks["step"], st, LocalExecutor() + ) + inst = st.get_task_instance("r1", "step", None) + assert inst is not None + + def test_dispatch_array_no_result_inputs(self, tmp_path: Path) -> None: + """_dispatch_array returns None when result_inputs empty (444).""" + wf = Workflow("wf") + + @wf.job() + def source() -> list[str]: + return ["a"] + + @wf.array_job() + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + # source not done → no result inputs + from reflow.executors.local import LocalExecutor + result = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() + ) + assert result is None + + def test_dispatch_array_empty_fan_items(self, tmp_path: Path) -> None: + """_dispatch_array returns None when fan list empty (454).""" + wf = Workflow("wf") + + @wf.job() + def source() -> list[str]: + return [] + + @wf.array_job() + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + wf.validate() + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance( + "r1", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid, []) + from reflow.executors.local import LocalExecutor + result = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() + ) + assert result is None + + def test_dispatch_array_all_cached_returns_none(self, tmp_path: Path) -> None: + """All elements cached → 'all resolved from cache' log + None (508).""" + from reflow.executors.local import LocalExecutor + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job() + def source() -> list[str]: + return ["a", "b"] + + @wf.array_job(cache=True) + def process(item: Annotated[str, Result(step="source")]) -> str: + return item.upper() + + wf.validate() + with _store(tmp_path) as st: + # Run 1: dispatch the array and complete both elements. + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance( + "r1", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid, ["a", "b"]) + ex = LocalExecutor() + wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, ex) + for inst in st.list_task_instances("r1", task_name="process"): + st.update_task_success(int(inst["id"]), inst["input"]["item"].upper()) + + # Run 2: both elements should resolve from cache → returns None. + st.insert_run("r2", "wf", "u", {}) + iid2 = st.insert_task_instance( + "r2", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid2, ["a", "b"]) + result = wf._dispatch_array( + "r2", tmp_path, wf.tasks["process"], st, ex + ) + assert result is None + + def test_dispatch_array_partial_cached(self, tmp_path: Path) -> None: + """Some cached, some pending → comma-joined arr_str (513).""" + from reflow.executors.local import LocalExecutor + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job() + def source() -> list[str]: + return ["a", "b", "c"] + + @wf.array_job(cache=True) + def process(item: Annotated[str, Result(step="source")]) -> str: + return item.upper() + + wf.validate() + with _store(tmp_path) as st: + # Run 1: complete only the first element so its identity is cached. + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance( + "r1", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid, ["a", "b", "c"]) + ex = LocalExecutor() + wf._dispatch_array("r1", tmp_path, wf.tasks["process"], st, ex) + insts = sorted( + st.list_task_instances("r1", task_name="process"), + key=lambda i: i["array_index"], + ) + # Cache just element 0 ("a"). + st.update_task_success(int(insts[0]["id"]), "A") + + # Run 2: element 0 hits cache, elements 1,2 pending → partial arr_str. + st.insert_run("r2", "wf", "u", {}) + iid2 = st.insert_task_instance( + "r2", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid2, ["a", "b", "c"]) + job_id = wf._dispatch_array( + "r2", tmp_path, wf.tasks["process"], st, ex + ) + assert job_id is not None + # All three instances exist (one cached SUCCESS, two newly created). + assert len(st.list_task_instances("r2", task_name="process")) == 3 + + def test_dispatch_array_parallelism_suffix(self, tmp_path: Path) -> None: + """array_parallelism appends %N to arr_str (518).""" + wf = Workflow("wf", config=Config({"executor": {"mode": "dry-run"}})) + + @wf.job() + def source() -> list[str]: + return ["a", "b", "c"] + + @wf.array_job(array_parallelism=2) + def process(item: Annotated[str, Result(step="source")]) -> str: + return item + + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + iid = st.insert_task_instance( + "r1", "source", None, TaskState.SUCCESS, {} + ) + st.update_task_success(iid, ["a", "b", "c"]) + from reflow.executors.local import LocalExecutor + job_id = wf._dispatch_array( + "r1", tmp_path, wf.tasks["process"], st, LocalExecutor() + ) + assert job_id is not None + instances = st.list_task_instances("r1", task_name="process") + assert len(instances) == 3 + + def test_maybe_finalise_returns_on_empty_summary(self, tmp_path: Path) -> None: + """_maybe_finalise_run returns early on empty summary (533 region).""" + wf = Workflow("wf") + + @wf.job() + def step() -> str: + return "ok" + + with _store(tmp_path) as st: + st.insert_run("r1", "wf", "u", {}) + # No task instances → empty summary → early return + wf._maybe_finalise_run("r1", st) + assert st.get_run("r1")["status"] in ("RUNNING", "PENDING", "SUBMITTED")