diff --git a/docs/superpowers/plans/2026-04-07-eng379-function-node-refactor.md b/docs/superpowers/plans/2026-04-07-eng379-function-node-refactor.md new file mode 100644 index 00000000..9bc13d42 --- /dev/null +++ b/docs/superpowers/plans/2026-04-07-eng379-function-node-refactor.md @@ -0,0 +1,1630 @@ +# ENG-379 FunctionNode Refactor Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Refactor `FunctionNode` so `iter_packets()` is strictly read-only, computation only triggers via `run()` / `execute()` / `async_execute()`, and all duplicated DB join logic is consolidated into one helper `_load_cached_entries()`. + +**Architecture:** A new internal `_load_cached_entries()` method replaces four copies of the pipeline-DB + result-DB join. `_cached_output_packets` switches from `dict[int, tuple]` (position-keyed) to `dict[str, tuple]` (entry_id-keyed). `iter_packets()` becomes a pure yield from the in-memory store with a one-shot hot-load from DB on empty cache; `run()` becomes a load-status guard that delegates to `execute()`. 
+ +**Tech Stack:** Python, PyArrow, Polars (for DB joins), pytest, asyncio + +--- + +## File Map + +| Action | Path | +|---|---| +| **Modify** | `src/orcapod/core/nodes/function_node.py` | +| **Create** | `tests/test_core/nodes/test_function_node_iteration.py` | +| **Modify** | `tests/test_core/function_pod/test_function_pod_node_stream.py` | +| **Modify** | `tests/test_core/function_pod/test_function_pod_node.py` | +| **Modify** | `tests/test_core/function_pod/test_function_node_attach_db.py` | +| **Modify** | `tests/test_data/test_polars_nullability/test_function_node_nullability.py` | + +`tests/test_core/nodes/test_function_node_get_cached.py` and `tests/test_core/nodes/test_node_execute.py` need minor verification but no code changes — the key type change (`int → str`) in `_cached_output_packets` still satisfies `len(node._cached_output_packets) == N` assertions. + +`tests/test_core/test_regression_fixes.py::TestConcurrentFallbackInRunningLoop` tests `FunctionPodStream._iter_packets_concurrent()` in `function_pod.py` — a **different class** that is not changed. Leave it as-is and add equivalent coverage in the new test file. + +--- + +## Concepts for implementers + +**Entry ID**: `compute_pipeline_entry_id(tag, packet)` — a hash over `(tag columns + system_tags + input_packet_hash + node_content_hash)`. Uniquely identifies one (input, node) pair. Already computed in `execute()` for every upstream packet. + +**`_cached_output_packets`**: Session-level result store. After the refactor it is `dict[str, tuple[TagProtocol, PacketProtocol | None]]` keyed by entry_id. Populated by `_process_packet_internal()` (computation) and by `_load_cached_entries()` (DB hot-load). Never cleared except by `clear_cache()` or `attach_databases()`. Overwriting an existing key is safe because in-memory and DB results for the same entry_id are always semantically equivalent. + +**`_load_cached_entries(entry_ids=None)`**: The new single DB join helper. 
Returns `dict[str, tuple[Tag, Packet]]`. Does NOT mutate `_cached_output_packets`; callers do `self._cached_output_packets.update(loaded)`. + +**`iter_packets()` after the refactor**: Strictly read-only. Never calls `_process_packet_internal()`. On first call with empty `_cached_output_packets` and a DB attached, hot-loads via `_load_cached_entries()`. Otherwise yields from `_cached_output_packets.values()`. On a fresh node with no prior `run()` and empty DB, yields nothing — this is correct. + +**`run()`**: Guards on `load_status` (UNAVAILABLE → RuntimeError, CACHE_ONLY → no-op), then calls `execute(self._input_stream)`. + +**Tests that called `iter_packets()` to trigger computation**: Must be updated to call `node.run()` or `node.execute(node._input_stream)` first. The new `iter_packets()` is read-only. + +**`as_table()` is affected**: It calls `iter_packets()` internally. Without a prior `run()`, it returns an empty table. Tests that previously relied on `as_table()` triggering computation must be updated to call `run()` first. + +--- + +## Task 1: Write failing tests for new iteration semantics + +**Files:** +- Create: `tests/test_core/nodes/test_function_node_iteration.py` + +- [ ] **Step 1: Write the test file** + +```python +"""Tests for the refactored FunctionNode iteration semantics. 
+ +After ENG-379: +- iter_packets() is strictly read-only — never triggers computation +- Computation only via run() / execute() / async_execute() +- execute() is always sequential; async/concurrent path is only in async_execute() +""" +from __future__ import annotations + +import asyncio +from unittest.mock import patch + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.executors import LocalExecutor + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + }, + schema=pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.int64(), nullable=False), + ] + ), + ) + return ArrowTableSource(table, tag_columns=["id"]) + + +def _make_node(n: int = 3, db: InMemoryArrowDatabase | None = None) -> FunctionNode: + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + pipeline_db = db if db is not None else InMemoryArrowDatabase() + return FunctionNode(pod, _make_source(n=n), pipeline_database=pipeline_db) + + +class TestIterPacketsReadOnly: + def test_fresh_node_no_db_yields_nothing(self): + """iter_packets() on a fresh node with no run() and empty DB yields nothing.""" + node = _make_node() + assert list(node.iter_packets()) == [] + + def test_iter_does_not_call_process_packet_internal(self): + """iter_packets() never calls _process_packet_internal under any non-compute path.""" + node = _make_node() + with patch.object(node, "_process_packet_internal") as mock_proc: + list(node.iter_packets()) + mock_proc.assert_not_called() + + def 
test_iter_after_db_populated_hot_loads_without_compute(self): + """iter_packets() on a node with DB records hot-loads without _process_packet_internal.""" + db = InMemoryArrowDatabase() + node1 = _make_node(n=3, db=db) + node1.run() # populate DB + + node2 = _make_node(n=3, db=db) + with patch.object(node2, "_process_packet_internal") as mock_proc: + results = list(node2.iter_packets()) + mock_proc.assert_not_called() + assert len(results) == 3 + + def test_after_run_iter_yields_from_cache_no_db_query(self): + """After run(), iter_packets() yields from _cached_output_packets without DB query.""" + node = _make_node() + node.run() + initial_count = len(node._cached_output_packets) + assert initial_count == 3 + + with patch.object(node, "_load_cached_entries") as mock_load: + results = list(node.iter_packets()) + mock_load.assert_not_called() + assert len(results) == 3 + + def test_iter_twice_same_order_db_queried_once(self): + """Two successive iter_packets() calls return same order; DB queried at most once.""" + db = InMemoryArrowDatabase() + node1 = _make_node(n=3, db=db) + node1.run() + + node2 = _make_node(n=3, db=db) + with patch.object(node2, "_load_cached_entries", wraps=node2._load_cached_entries) as mock_load: + first = [(t["id"], p["result"]) for t, p in node2.iter_packets()] + second = [(t["id"], p["result"]) for t, p in node2.iter_packets()] + assert mock_load.call_count <= 1 # at most one DB query + assert first == second + + def test_cached_output_packets_keyed_by_entry_id_strings(self): + """After run(), _cached_output_packets keys are entry_id strings, not ints.""" + node = _make_node() + node.run() + assert len(node._cached_output_packets) == 3 + for key in node._cached_output_packets: + assert isinstance(key, str), f"Expected str key, got {type(key)}: {key!r}" + + def test_as_table_fresh_node_returns_empty_no_compute(self): + """as_table() on a fresh node with no run() and empty DB returns empty table.""" + node = _make_node() + with 
patch.object(node, "_process_packet_internal") as mock_proc: + table = node.as_table() + mock_proc.assert_not_called() + assert isinstance(table, pa.Table) + assert len(table) == 0 + + def test_run_cache_only_is_noop(self): + """run() on a CACHE_ONLY node returns without error and without computation.""" + from orcapod.pipeline.serialization import LoadStatus + + node = _make_node() + node._load_status = LoadStatus.CACHE_ONLY + node._input_stream = None # simulate no upstream + + with patch.object(node, "execute") as mock_exec: + node.run() + mock_exec.assert_not_called() + + def test_run_unavailable_raises(self): + """run() on an UNAVAILABLE node raises RuntimeError.""" + from orcapod.pipeline.serialization import LoadStatus + + node = _make_node() + node._load_status = LoadStatus.UNAVAILABLE + with pytest.raises(RuntimeError, match="unavailable"): + node.run() + + def test_execute_error_policy_continue_skips_failures(self): + """execute() sequential path: on_packet_crash fires per failing packet with error_policy='continue'.""" + errors = [] + + def sometimes_fail(x: int) -> int: + if x == 1: + raise ValueError("intentional failure") + return x * 2 + + pf = PythonPacketFunction(sometimes_fail, output_keys="result") + pf.executor = LocalExecutor() # non-concurrent; tests sequential execute() path + pod = FunctionPod(pf) + db = InMemoryArrowDatabase() + node = FunctionNode(pod, _make_source(n=3), pipeline_database=db) + + from orcapod.pipeline.observer import NoOpObserver + + class CapturingObserver(NoOpObserver): + def on_packet_crash(self, node_label, tag, packet, exc): + errors.append(exc) + + results = node.execute(node._input_stream, observer=CapturingObserver(), error_policy="continue") + assert len(errors) == 1 + assert isinstance(errors[0], ValueError) + # Two non-failing packets should succeed + assert len(results) == 2 +``` + +- [ ] **Step 2: Run the tests to confirm they all fail** + +```bash +cd 
/home/kurouto/kurouto-jobs/d6d5fe6c-886b-4517-b09e-1371bda77ecd/orcapod-python +python -m pytest tests/test_core/nodes/test_function_node_iteration.py -v 2>&1 | head -60 +``` + +Expected: Multiple failures — `_load_cached_entries` does not exist, `iter_packets()` currently computes, `run()` currently calls `iter_packets()`, etc. + +--- + +## Task 2: Implement `_load_cached_entries()` — the single DB join helper + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` + +Insert a new method `_load_cached_entries()` in the `# Cache-only helpers` section (after `_require_pipeline_database()`, before `_load_all_cached_records()` — around line 1051). Place it just before `_load_all_cached_records`. + +- [ ] **Step 1: Add `_load_cached_entries()` to `function_node.py`** + +Place it in the `# Cache-only helpers` section, just before `_load_all_cached_records` (around line 1051), matching the location stated above. (Inserting after line 249 — the end of `_filter_by_content_hash`, before the `# from_descriptor` block — would also work, but the `# Cache-only helpers` grouping is preferred.) + +```python +def _load_cached_entries( + self, + entry_ids: list[str] | None = None, +) -> "dict[str, tuple[TagProtocol, PacketProtocol]]": + """Load (tag, packet) pairs from pipeline DB + result DB. + + Args: + entry_ids: If provided, load only these specific entry IDs. + If ``None``, load all records for this node. + + Returns: + dict mapping entry_id → (tag, packet). Empty dict when either + database is None, records are empty, or no rows match. + + Does NOT mutate ``_cached_output_packets``. + Callers merge via ``self._cached_output_packets.update(loaded)``. 
+ """ + if self._cached_function_pod is None or self._pipeline_database is None: + return {} + + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + + taginfo = self._pipeline_database.get_all_records( + self.node_identity_path, + record_id_column=PIPELINE_ENTRY_ID_COL, + ) + results = self._cached_function_pod._result_database.get_all_records( + self._cached_function_pod.record_path, + record_id_column=constants.PACKET_RECORD_ID, + ) + + if taginfo is None or results is None: + return {} + + taginfo = self._filter_by_content_hash(taginfo) + taginfo_schema = taginfo.schema + results_schema = results.schema + + joined_df = pl.DataFrame(taginfo).join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + if entry_ids is not None: + joined_df = joined_df.filter( + pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids) + ) + joined = joined_df.to_arrow() + joined = arrow_utils.restore_schema_nullability( + joined, taginfo_schema, results_schema + ) + + if joined.num_rows == 0: + return {} + + # Derive tag keys: prefer input_stream when available; fall back to + # taginfo column exclusion for CACHE_ONLY / deserialized nodes. 
+ if self._input_stream is not None: + tag_keys = self._input_stream.keys()[0] + else: + tag_keys = tuple( + c + for c in taginfo.column_names + if not c.startswith(constants.META_PREFIX) + and not c.startswith(constants.SOURCE_PREFIX) + and not c.startswith(constants.SYSTEM_TAG_PREFIX) + and c != PIPELINE_ENTRY_ID_COL + and c != constants.NODE_CONTENT_HASH_COL + ) + + # Drop internal columns (SOURCE_PREFIX is kept — ArrowTableStream needs it) + drop_cols = [ + c + for c in joined.column_names + if c.startswith(constants.META_PREFIX) + or c == PIPELINE_ENTRY_ID_COL + or c == constants.NODE_CONTENT_HASH_COL + ] + data_table = joined.drop([c for c in drop_cols if c in joined.column_names]) + + entry_ids_col = joined.column(PIPELINE_ENTRY_ID_COL).to_pylist() + stream = ArrowTableStream(data_table, tag_columns=tag_keys) + + loaded: dict[str, tuple[TagProtocol, PacketProtocol]] = {} + for eid, (tag, packet) in zip(entry_ids_col, stream.iter_packets()): + loaded[eid] = (tag, packet) + return loaded +``` + +- [ ] **Step 2: Run existing test suite to verify no regressions from just adding the method** + +```bash +python -m pytest tests/test_core/nodes/ tests/test_core/function_pod/ -x -q 2>&1 | tail -20 +``` + +Expected: Same pass/fail ratio as before (method is new; nothing calls it yet). + +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "feat(ENG-379): add _load_cached_entries() single DB join helper" +``` + +--- + +## Task 3: Simplify `_process_packet_internal()` and `_async_process_packet_internal()` + +Remove the `cache_index: int | None` parameter from both methods. Store results by `entry_id` string instead of by integer position. Remove the lines that reset `_cached_input_iterator` and `_needs_iterator` (those fields are going away in Task 6). 
+ +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 652–852 + +- [ ] **Step 1: Update `_process_packet_internal()`** + +Replace the entire method (lines 652–708) with: + +```python +def _process_packet_internal( + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: PacketExecutionLoggerProtocol | None = None, +) -> tuple[TagProtocol, PacketProtocol | None]: + """Core compute + persist + cache. + + Used by ``execute_packet`` and ``execute``. + Stores result in ``_cached_output_packets`` keyed by entry_id. + Exceptions propagate to the caller — no error handling here. + + Returns: + A ``(tag, output_packet)`` 2-tuple. + """ + if self._cached_function_pod is not None: + tag_out, output_packet = self._cached_function_pod.process_packet( + tag, packet, logger=logger + ) + + if output_packet is not None: + result_computed = bool( + output_packet.get_meta_value( + self._cached_function_pod.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + else: + tag_out, output_packet = self._function_pod.process_packet( + tag, packet, logger=logger + ) + + # Store by entry_id and invalidate derived caches + entry_id = self.compute_pipeline_entry_id(tag, packet) + self._cached_output_packets[entry_id] = (tag_out, output_packet) + self._cached_output_table = None + self._cached_content_hash_column = None + + return tag_out, output_packet +``` + +- [ ] **Step 2: Update `_async_process_packet_internal()`** + +Replace lines 793–852 with: + +```python +async def _async_process_packet_internal( + self, + tag: TagProtocol, + packet: PacketProtocol, + *, + logger: PacketExecutionLoggerProtocol | None = None, +) -> tuple[TagProtocol, PacketProtocol | None]: + """Async counterpart of ``_process_packet_internal``. + + Computes via async path, writes pipeline provenance, caches by entry_id. + Exceptions propagate. 
+ + Returns: + A ``(tag, output_packet)`` 2-tuple. + """ + if self._cached_function_pod is not None: + tag_out, output_packet = ( + await self._cached_function_pod.async_process_packet( + tag, packet, logger=logger + ) + ) + + if output_packet is not None: + result_computed = bool( + output_packet.get_meta_value( + self._cached_function_pod.RESULT_COMPUTED_FLAG, False + ) + ) + self.add_pipeline_record( + tag, + packet, + packet_record_id=output_packet.datagram_id, + computed=result_computed, + ) + else: + tag_out, output_packet = ( + await self._function_pod.async_process_packet( + tag, packet, logger=logger + ) + ) + + # Store by entry_id and invalidate derived caches + entry_id = self.compute_pipeline_entry_id(tag, packet) + self._cached_output_packets[entry_id] = (tag_out, output_packet) + self._cached_output_table = None + self._cached_content_hash_column = None + + return tag_out, output_packet +``` + +- [ ] **Step 3: Run tests to verify `execute_packet()` still works** + +`execute_packet()` delegates to `_process_packet_internal()`. The `TestFunctionNodeExecutePacket` tests verify it stores results and writes DB records. + +```bash +python -m pytest tests/test_core/nodes/test_node_execute.py::TestFunctionNodeExecutePacket -v 2>&1 +``` + +Expected: PASS (the key type change from int to str means `len(_cached_output_packets) == 1` still holds). + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): store _cached_output_packets by entry_id, remove cache_index" +``` + +--- + +## Task 4: Rewrite `run()` — load_status guard + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 1347–1350 + +- [ ] **Step 1: Replace `run()` (lines 1347–1350)** + +```python +def run(self) -> None: + """Eagerly compute all input packets, filling pipeline and result databases. + + Raises: + RuntimeError: If ``load_status`` is UNAVAILABLE (no pod, no DB). 
+ """ + from orcapod.pipeline.serialization import LoadStatus + + if self._load_status == LoadStatus.UNAVAILABLE: + raise RuntimeError( + f"FunctionNode {self.label!r} is unavailable: " + "no function pod and no database attached." + ) + if self._load_status == LoadStatus.CACHE_ONLY: + # Upstream is unavailable; computation requires a live input stream. + # Callers should use iter_packets() to serve existing DB results. + return + self.execute(self._input_stream) +``` + +- [ ] **Step 2: Verify new `run()` tests pass** + +```bash +python -m pytest tests/test_core/nodes/test_function_node_iteration.py::TestIterPacketsReadOnly::test_run_cache_only_is_noop tests/test_core/nodes/test_function_node_iteration.py::TestIterPacketsReadOnly::test_run_unavailable_raises -v 2>&1 +``` + +Expected: PASS. + +- [ ] **Step 3: Verify existing `test_run_fills_database` still passes** + +```bash +python -m pytest tests/test_core/function_pod/test_function_pod_node.py::TestFunctionNodeStreamInterface::test_run_fills_database -v 2>&1 +``` + +Expected: PASS (`run()` now calls `execute(self._input_stream)` which does the same computation). + +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): rewrite run() with load_status guard delegating to execute()" +``` + +--- + +## Task 5: Rewrite `iter_packets()` — strictly read-only + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 1172–1285 + +- [ ] **Step 1: Replace `iter_packets()` (lines 1172–1285) with the read-only version** + +```python +def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Yield all computed (tag, packet) pairs for this node. + + Strictly read-only — never triggers computation. Callers must call + ``run()`` or ``execute()`` first if they want results computed. + + On the first call with an empty in-memory store and a DB attached, + hot-loads all existing records from the DB (one-shot, no recompute). 
+ + Raises: + RuntimeError: If ``load_status`` is UNAVAILABLE. + """ + from orcapod.pipeline.serialization import LoadStatus + + status = self.load_status + if status == LoadStatus.UNAVAILABLE: + raise RuntimeError( + f"FunctionNode {self.label!r} is unavailable: " + "no function pod and no database attached." + ) + + if status == LoadStatus.CACHE_ONLY: + # Upstream unavailable; serve entirely from DB. + if not self._cached_output_packets: + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + yield from ( + (tag, pkt) + for tag, pkt in self._cached_output_packets.values() + if pkt is not None + ) + return + + # FULL / READ_ONLY — in-memory store may be populated from computation + # (via execute/run) or hot-loaded from DB. + if self.is_stale: + self.clear_cache() + + if not self._cached_output_packets and self._cached_function_pod is not None: + # Hot-load from DB on the first call when store is empty. + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + yield from ( + (tag, pkt) + for tag, pkt in self._cached_output_packets.values() + if pkt is not None + ) +``` + +- [ ] **Step 2: Run new iteration tests** + +```bash +python -m pytest tests/test_core/nodes/test_function_node_iteration.py -v 2>&1 | grep -E "PASS|FAIL|ERROR" +``` + +Expected: Most of the 10 tests pass now. `test_execute_error_policy_continue_skips_failures` may still fail (Task 9). + +- [ ] **Step 3: Quick smoke-test on execute path** + +```bash +python -m pytest tests/test_core/nodes/test_node_execute.py -v 2>&1 +``` + +Expected: PASS (execute is unchanged; iter_packets refactor doesn't break execute). 
+ +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): rewrite iter_packets() as strictly read-only with DB hot-load" +``` + +--- + +## Task 6: Remove iterator state from `__init__`, `from_descriptor()`, `clear_cache()` + +Remove `_cached_input_iterator` and `_needs_iterator` — they were only used by the now-deleted computation path in `iter_packets()`. + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` + +- [ ] **Step 1: Remove from `__init__` (lines 121–124)** + +Delete these two lines: +```python + self._cached_input_iterator: ( + Iterator[tuple[TagProtocol, PacketProtocol]] | None + ) = None + self._needs_iterator = True +``` + +Also remove the `Iterator` import from the top-level imports if it is no longer used anywhere else. Check with: +```bash +grep -n "Iterator" src/orcapod/core/nodes/function_node.py +``` +If `Iterator` still appears in type annotations on other methods, keep the import. + +- [ ] **Step 2: Remove from `from_descriptor()` (lines 358–359)** + +Delete these two lines in the read-only mode `__new__` block: +```python + node._cached_input_iterator = None + node._needs_iterator = True +``` + +- [ ] **Step 3: Remove from `clear_cache()` (lines 543–545)** + +Change `clear_cache()` from: +```python + def clear_cache(self) -> None: + self._cached_input_iterator = None + self._needs_iterator = True + self._cached_output_packets.clear() + self._cached_output_table = None + self._cached_content_hash_column = None + self._node_identity_path_cache = None + self._update_modified_time() +``` + +To: +```python + def clear_cache(self) -> None: + self._cached_output_packets.clear() + self._cached_output_table = None + self._cached_content_hash_column = None + self._node_identity_path_cache = None + self._update_modified_time() +``` + +- [ ] **Step 4: Remove `_ensure_iterator()` (lines 536–541)** + +Delete the entire method: +```python + def _ensure_iterator(self) -> None: + 
"""Lazily acquire the upstream iterator on first use.""" + if self._needs_iterator: + self._cached_input_iterator = self._input_stream.iter_packets() + self._needs_iterator = False + self._update_modified_time() +``` + +- [ ] **Step 5: Update `_cached_output_packets` type annotation in `__init__` (lines 125–127)** + +Change the type annotation from: +```python + self._cached_output_packets: dict[ + int, tuple[TagProtocol, PacketProtocol | None] + ] = {} +``` +To: +```python + self._cached_output_packets: dict[ + str, tuple[TagProtocol, PacketProtocol | None] + ] = {} +``` + +- [ ] **Step 6: Run core tests** + +```bash +python -m pytest tests/test_core/nodes/ tests/test_core/function_pod/test_function_node_caching.py -v -q 2>&1 | tail -20 +``` + +Expected: The existing tests that don't call `iter_packets()` for computation pass. Some tests in `test_function_pod_node_stream.py` will now fail (they used `iter_packets()` as a computation trigger) — those are fixed in Task 12. + +- [ ] **Step 7: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): remove _cached_input_iterator/_needs_iterator iterator state" +``` + +--- + +## Task 7: Simplify `get_cached_results()` — delegate to `_load_cached_entries()` + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 710–791 + +- [ ] **Step 1: Replace `get_cached_results()` (lines 710–791) with the simplified version** + +```python +def get_cached_results( + self, entry_ids: list[str] +) -> dict[str, tuple[TagProtocol, PacketProtocol]]: + """Retrieve cached results for specific pipeline entry IDs. + + Checks in-memory cache first. Loads only truly missing entries from DB. + Add-only semantics: existing in-memory entries are never cleared or + overwritten (overwrite is safe since in-memory and DB entries for the + same entry_id are always semantically equivalent). + + Args: + entry_ids: Pipeline entry IDs to look up. 
+ + Returns: + Mapping from entry_id to ``(tag, output_packet)`` for found entries. + Empty dict if no DB is attached or no matches found. + """ + if self._cached_function_pod is None or not entry_ids: + return {} + + missing = [eid for eid in entry_ids if eid not in self._cached_output_packets] + if missing: + loaded = self._load_cached_entries(missing) + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + return { + eid: self._cached_output_packets[eid] + for eid in entry_ids + if eid in self._cached_output_packets + } +``` + +- [ ] **Step 2: Run `test_function_node_get_cached.py`** + +```bash +python -m pytest tests/test_core/nodes/test_function_node_get_cached.py -v 2>&1 +``` + +Expected: All 5 tests pass. The `test_get_cached_results_populates_internal_cache` test manually calls `node._cached_output_packets.clear()` before calling `get_cached_results()` — this is compatible with the new add-only semantics. After clearing, `get_cached_results()` loads missing entries from DB and adds them. + +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): simplify get_cached_results() using _load_cached_entries(), add-only semantics" +``` + +--- + +## Task 8: Simplify `_async_execute_cache_only()` + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 1136–1166 + +- [ ] **Step 1: Replace `_async_execute_cache_only()` (lines 1136–1166)** + +```python +async def _async_execute_cache_only( + self, + output: "WritableChannel[tuple[TagProtocol, PacketProtocol]]", + *, + observer: Any | None = None, +) -> None: + """Send all DB-cached (tag, packet) pairs to *output*. + + Used in ``CACHE_ONLY`` mode when the upstream is unavailable. + Does not access ``_input_stream``. 
+ """ + from orcapod.pipeline.observer import NoOpObserver + + obs = observer if observer is not None else NoOpObserver() + node_label = self.label + node_hash = self.content_hash().to_string() + ctx_obs = obs.contextualize(*self.node_identity_path) + + ctx_obs.on_node_start(node_label, node_hash, tag_schema=None) + try: + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + for tag, packet in self._cached_output_packets.values(): + if packet is not None: + ctx_obs.on_packet_start(node_label, tag, packet) + ctx_obs.on_packet_end(node_label, tag, packet, packet, cached=True) + await output.send((tag, packet)) + ctx_obs.on_node_end(node_label, node_hash) + finally: + await output.close() +``` + +- [ ] **Step 2: Run async-related tests** + +```bash +python -m pytest tests/ -k "async_execute or cache_only" -v -q 2>&1 | tail -20 +``` + +Expected: No regressions on async execution paths. + +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): simplify _async_execute_cache_only() using _load_cached_entries()" +``` + +--- + +## Task 9: Refactor `execute()` — selective DB reload (always sequential) + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 577–650 + +This is the most complex change. The goals: +1. After collecting all upstream entry_ids, load only the **missing** ones from DB (not all). +2. After DB load, compute any entries still missing (truly new computation). +3. When `_executor_supports_concurrent()` returns True, dispatch `_async_process_packet_internal` concurrently via `asyncio.gather(..., return_exceptions=True)`. +4. Observer hooks (`on_packet_start`, `on_packet_end`, `on_packet_crash`) fire correctly in both paths. 
+ +- [ ] **Step 1: Add `import concurrent.futures` at the top of `execute()`** + +In the body of `execute()`, add `import concurrent.futures` alongside the existing `from orcapod.pipeline.observer import NoOpObserver`. + +- [ ] **Step 2: Replace `execute()` (lines 577–650)** + +```python +def execute( + self, + input_stream: StreamProtocol, + *, + observer: ExecutionObserverProtocol | None = None, + error_policy: Literal["continue", "fail_fast"] = "continue", +) -> list[tuple[TagProtocol, PacketProtocol]]: + """Execute all packets from a stream: compute, persist, and cache. + + Computation order: + 1. Collect all (tag, packet, entry_id) from input_stream. + 2. Load only missing entry_ids from DB (selective reload). + 3. Compute truly missing entries — concurrent if executor supports it, + sequential otherwise. + 4. Build output list, firing observer hooks for every packet. + + Args: + input_stream: The input stream to process. + observer: Optional execution observer for hooks. + error_policy: ``"continue"`` skips failed packets; + ``"fail_fast"`` re-raises on the first failure. + + Returns: + Materialized list of (tag, output_packet) pairs, excluding + ``None`` outputs and failed packets. 
+ """ + import concurrent.futures + from orcapod.pipeline.observer import NoOpObserver + + node_label = self.label + node_hash = self.content_hash().to_string() + + obs = observer if observer is not None else NoOpObserver() + ctx_obs = obs.contextualize(*self.node_identity_path) + + tag_schema = input_stream.output_schema(columns={"system_tags": True})[0] + ctx_obs.on_node_start(node_label, node_hash, tag_schema=tag_schema) + + # --- Step 1: Collect upstream entries --- + upstream_entries: list[tuple[TagProtocol, PacketProtocol, str]] = [ + (tag, packet, self.compute_pipeline_entry_id(tag, packet)) + for tag, packet in input_stream.iter_packets() + ] + + # --- Step 2: Selective DB reload for missing entries --- + if self._cached_function_pod is not None: + missing_eids = [ + eid + for _, _, eid in upstream_entries + if eid not in self._cached_output_packets + ] + if missing_eids: + loaded = self._load_cached_entries(missing_eids) + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + # --- Step 3: Compute truly missing entries --- + to_compute = [ + (tag, pkt, eid) + for tag, pkt, eid in upstream_entries + if eid not in self._cached_output_packets + ] + + if to_compute and _executor_supports_concurrent(self._packet_function): + # Concurrent path: dispatch all missing packets via asyncio.gather + async def _gather(): + return await asyncio.gather( + *[ + self._async_process_packet_internal(tag, pkt) + for tag, pkt, _ in to_compute + ], + return_exceptions=True, + ) + + try: + asyncio.get_running_loop() + # Already in event loop — run in a separate thread + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + gather_results = pool.submit(asyncio.run, _gather()).result() + except RuntimeError: + gather_results = asyncio.run(_gather()) + + for (tag, pkt, eid), result in zip(to_compute, gather_results): + if isinstance(result, BaseException): + logger.warning( + "Packet 
execution failed in %s: %s", + node_label, + result, + exc_info=result, + ) + ctx_obs.on_packet_crash(node_label, tag, pkt, result) + if error_policy == "fail_fast": + ctx_obs.on_node_end(node_label, node_hash) + raise result + else: + # Sequential path + for tag, pkt, eid in to_compute: + pkt_logger = ctx_obs.create_packet_logger(tag, pkt) + try: + self._process_packet_internal(tag, pkt, logger=pkt_logger) + except Exception as exc: + logger.warning( + "Packet execution failed in %s: %s", + node_label, + exc, + exc_info=True, + ) + ctx_obs.on_packet_crash(node_label, tag, pkt, exc) + if error_policy == "fail_fast": + ctx_obs.on_node_end(node_label, node_hash) + raise + + # --- Step 4: Build output, fire observer hooks for all packets --- + to_compute_eids = {eid for _, _, eid in to_compute} + output: list[tuple[TagProtocol, PacketProtocol]] = [] + for tag, pkt, eid in upstream_entries: + ctx_obs.on_packet_start(node_label, tag, pkt) + if eid in self._cached_output_packets: + tag_out, result = self._cached_output_packets[eid] + if result is not None: + ctx_obs.on_packet_end( + node_label, + tag, + pkt, + result, + cached=(eid not in to_compute_eids), + ) + output.append((tag_out, result)) + # Packets that failed are absent from _cached_output_packets — silently skipped + + ctx_obs.on_node_end(node_label, node_hash) + return output +``` + +- [ ] **Step 3: Run `test_node_execute.py` tests to verify execute still works** + +```bash +python -m pytest tests/test_core/nodes/test_node_execute.py::TestFunctionNodeExecute -v 2>&1 +``` + +Expected: PASS. + +- [ ] **Step 4: Run the concurrent execute test** + +```bash +python -m pytest tests/test_core/nodes/test_function_node_iteration.py::TestIterPacketsReadOnly::test_execute_error_policy_continue_skips_failures -v 2>&1 +``` + +Expected: PASS. 
+ +- [ ] **Step 5: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): refactor execute() with selective DB reload (always sequential)" +``` + +--- + +## Task 10: Simplify `async_execute()` Phase 1 + +Replace the inline DB join in `async_execute()` Phase 1 (lines 1519–1571) with a call to `_load_cached_entries()`. + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — lines 1519–1572 + +- [ ] **Step 1: Replace Phase 1 of the DB-backed branch in `async_execute()`** + +Locate the block starting with: +```python + if self._cached_function_pod is not None: + # DB-backed async execution: + # Phase 1: build cache lookup from pipeline DB + PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" + cached_by_entry_id: dict[...] = {} + + taginfo = self._pipeline_database.get_all_records(...) + ... +``` + +Replace all of Phase 1 (the inline join from `taginfo = ...` through `cached_by_entry_id[eid] = (tag_out, pkt_out)`) with: + +```python + if self._cached_function_pod is not None: + # Phase 1: build cache lookup from pipeline DB + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + cached_by_entry_id: dict[str, tuple[TagProtocol, PacketProtocol]] = dict(loaded) +``` + +Phase 2 (the `_process_one_db` async closure + TaskGroup) remains unchanged except that `cached_by_entry_id` now comes from `_load_cached_entries()`. + +- [ ] **Step 2: Run async pipeline integration test if available, otherwise smoke test** + +```bash +python -m pytest tests/ -k "async" -v -q 2>&1 | tail -30 +``` + +Expected: No regressions on async paths. 
+
+- [ ] **Step 3: Commit**
+
+```bash
+git add src/orcapod/core/nodes/function_node.py
+git commit -m "refactor(ENG-379): simplify async_execute() Phase 1 using _load_cached_entries()"
+```
+
+---
+
+## Task 11: Remove dead methods
+
+**Files:**
+- Modify: `src/orcapod/core/nodes/function_node.py`
+
+The following four methods are no longer used and should be deleted entirely (a fifth, `_ensure_iterator()`, was already removed in Task 6):
+
+| Method | Location | Replaced by |
+|---|---|---|
+| `_iter_packets_sequential()` | ~line 1287 | `execute()` sequential path |
+| `_iter_packets_concurrent()` | ~line 1301 | removed — concurrent path did not move to `execute()`; async execution is in `async_execute()` only |
+| `_iter_all_from_database()` | ~line 1119 | `_load_cached_entries()` |
+| `_load_all_cached_records()` | ~line 1051 | `_load_cached_entries()` |
+
+`_ensure_iterator()` was already removed in Task 6.
+
+- [ ] **Step 1: Delete all four dead methods from `function_node.py`**
+
+Delete:
+1. `_iter_packets_sequential()` (~lines 1287–1299)
+2. `_iter_packets_concurrent()` (~lines 1301–1345)
+3. `_iter_all_from_database()` (~lines 1119–1134)
+4. `_load_all_cached_records()` (~lines 1051–1117)
+
+Verify nothing else calls these methods:
+```bash
+grep -n "_iter_packets_sequential\|_iter_packets_concurrent\|_iter_all_from_database\|_load_all_cached_records\|_ensure_iterator" src/orcapod/core/nodes/function_node.py
+```
+Expected: no results.
+
+- [ ] **Step 2: Run full test suite**
+
+```bash
+python -m pytest tests/ -x -q 2>&1 | tail -30
+```
+
+Expected: Some tests fail in `test_function_pod_node_stream.py`, `test_function_pod_node.py`, `test_function_node_attach_db.py`, `test_function_node_nullability.py` — those are fixed in Tasks 12–13.
+ +- [ ] **Step 3: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "refactor(ENG-379): remove dead methods (iter_sequential/concurrent, iter_all_from_database, load_all_cached_records)" +``` + +--- + +## Task 12: Fix `test_function_pod_node_stream.py` + +All tests below failed because they called `iter_packets()` or `as_table()` to trigger computation. Fix: add `node.run()` before the assertion (or where computation is expected). + +**Files:** +- Modify: `tests/test_core/function_pod/test_function_pod_node_stream.py` + +- [ ] **Step 1: Fix `TestFunctionNodeStreamBasic` fixture — add `node.run()` in the fixture** + +Change the fixture from: +```python + @pytest.fixture + def node(self, double_pf) -> FunctionNode: + db = InMemoryArrowDatabase() + return FunctionNode( + function_pod=FunctionPod(packet_function=double_pf), + input_stream=make_int_stream(n=3), + pipeline_database=db, + ) +``` +To: +```python + @pytest.fixture + def node(self, double_pf) -> FunctionNode: + db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=FunctionPod(packet_function=double_pf), + input_stream=make_int_stream(n=3), + pipeline_database=db, + ) + node.run() + return node +``` + +This fixes: `test_iter_packets_yields_correct_count`, `test_iter_packets_correct_values`, `test_iter_is_repeatable`, `test_dunder_iter_delegates_to_iter_packets`, `test_as_table_returns_pyarrow_table`, `test_as_table_has_correct_row_count`, `test_as_table_contains_tag_columns`, `test_as_table_contains_packet_columns`. 
+ +- [ ] **Step 2: Fix `TestFunctionNodeColumnConfig` tests** + +Add `node.run()` before `as_table()` in both tests: + +`test_as_table_content_hash_column`: +```python + def test_as_table_content_hash_column(self, double_pf): + node = _make_node(double_pf, n=3) + node.run() + table = node.as_table(columns={"content_hash": True}) + assert "_content_hash" in table.column_names + assert len(table.column("_content_hash")) == 3 +``` + +`test_as_table_sort_by_tags`: +```python + def test_as_table_sort_by_tags(self, double_pf): + db = InMemoryArrowDatabase() + reversed_table = pa.table(...) + input_stream = ArrowTableStream(reversed_table, tag_columns=["id"]) + node = FunctionNode( + function_pod=FunctionPod(packet_function=double_pf), + input_stream=input_stream, + pipeline_database=db, + ) + node.run() + result = node.as_table(columns={"sort_by_tags": True}) + ids: list[int] = result.column("id").to_pylist() + assert ids == sorted(ids) +``` + +- [ ] **Step 3: Fix `TestFunctionNodeInactive::test_as_table_returns_cached_results_when_packet_function_inactive`** + +Change from calling `node1.as_table()` to calling `node1.run()` to populate DB: +```python + def test_as_table_returns_cached_results_when_packet_function_inactive(self, double_pf): + n = 3 + db = InMemoryArrowDatabase() + node1 = _make_node(double_pf, n=n, db=db) + node1.run() # populate DB + table1 = node1.as_table() # now hot-loads from cache + assert len(table1) == n + + double_pf.set_active(False) + + node2 = _make_node(double_pf, n=n, db=db) + table2 = node2.as_table() # hot-loads from DB (function inactive but DB has results) + + assert isinstance(table2, pa.Table) + assert len(table2) == n + assert table2.column("result").to_pylist() == table1.column("result").to_pylist() +``` + +- [ ] **Step 4: Fix `TestIterPacketsDbPhase::test_db_served_results_have_correct_values`** + +The test calls `_make_node(...).as_table()` for both nodes. 
Add `run()` to the first node (which populates DB) — the second node should then hot-load: + +```python + def test_db_served_results_have_correct_values(self, double_pf): + n = 4 + db = InMemoryArrowDatabase() + + node1 = _make_node(double_pf, n=n, db=db) + node1.run() + table1 = node1.as_table() + table2 = _make_node(double_pf, n=n, db=db).as_table() # hot-loads from DB + + assert sorted(table1.column("result").to_pylist()) == sorted( + table2.column("result").to_pylist() + ) +``` + +- [ ] **Step 5: Fix `TestIterPacketsMissingEntriesOnly::test_partial_fill_total_row_count_correct`** + +After the refactor, `iter_packets()` on a fresh node with 2 DB entries and 4 upstream inputs will hot-load and yield 2 (not 4). Add `run()` before `iter_packets()`: + +```python + def test_partial_fill_total_row_count_correct(self, double_pf): + n = 4 + db = InMemoryArrowDatabase() + _fill_node(_make_node(double_pf, n=2, db=db)) + node = _make_node(double_pf, n=n, db=db) + node.run() # computes the 2 missing entries + packets = list(node.iter_packets()) + assert len(packets) == n +``` + +- [ ] **Step 6: Fix `TestIterPacketsMissingEntriesOnly::test_partial_fill_all_values_correct`** + +```python + def test_partial_fill_all_values_correct(self, double_pf): + n = 4 + db = InMemoryArrowDatabase() + _fill_node(_make_node(double_pf, n=2, db=db)) + node = _make_node(double_pf, n=n, db=db) + node.run() # computes the 2 missing entries + table = node.as_table() + assert sorted(table.column("result").to_pylist()) == [0, 2, 4, 6] +``` + +- [ ] **Step 7: Fix `TestFunctionNodeStaleness::test_clear_cache_resets_output_packets`** + +Change `list(node.iter_packets())` to `node.run()`: +```python + def test_clear_cache_resets_output_packets(self, double_pf): + node = _make_node(double_pf, n=3) + node.run() # populate _cached_output_packets + assert len(node._cached_output_packets) == 3 + + node.clear_cache() + assert len(node._cached_output_packets) == 0 + assert node._cached_output_table is None 
+``` + +- [ ] **Step 8: Fix `TestFunctionNodeStaleness::test_clear_cache_produces_same_results_on_re_iteration`** + +Add `node.run()` before first `as_table()`: +```python + def test_clear_cache_produces_same_results_on_re_iteration(self, double_pf): + node = _make_node(double_pf, n=3) + node.run() + table_before = node.as_table() + + node.clear_cache() + table_after = node.as_table() # hot-loads from DB after clear + + assert sorted(table_before.column("result").to_pylist()) == sorted( + table_after.column("result").to_pylist() + ) +``` + +- [ ] **Step 9: Fix `TestFunctionNodeStaleness::test_iter_packets_auto_detects_stale_and_repopulates`** + +Add `node.run()` to populate DB before first `iter_packets()`: +```python + def test_iter_packets_auto_detects_stale_and_repopulates(self, double_pf): + import time + + db = InMemoryArrowDatabase() + input_stream = make_int_stream(n=3) + node = FunctionNode( + function_pod=FunctionPod(packet_function=double_pf), + input_stream=input_stream, + pipeline_database=db, + ) + node.run() # populate DB + first = list(node.iter_packets()) # hot-loads or serves from cache + + time.sleep(0.01) + input_stream._update_modified_time() + assert node.is_stale + + second = list(node.iter_packets()) # detects stale, clears, hot-loads from DB + assert len(second) == len(first) + assert [p["result"] for _, p in second] == [p["result"] for _, p in first] +``` + +- [ ] **Step 10: Fix `TestFunctionNodeStaleness::test_as_table_auto_detects_stale_and_repopulates`** + +Add `node.run()` before first `as_table()`. Read the full test first (around line 444) and add `node.run()` before `table_before = node.as_table()`. + +- [ ] **Step 11: Run the fixed file** + +```bash +python -m pytest tests/test_core/function_pod/test_function_pod_node_stream.py -v 2>&1 | tail -30 +``` + +Expected: All tests pass. 
+ +- [ ] **Step 12: Commit** + +```bash +git add tests/test_core/function_pod/test_function_pod_node_stream.py +git commit -m "test(ENG-379): update test_function_pod_node_stream.py — add run() before compute-dependent assertions" +``` + +--- + +## Task 13: Fix remaining test files + +**Files:** +- Modify: `tests/test_core/function_pod/test_function_pod_node.py` +- Modify: `tests/test_core/function_pod/test_function_node_attach_db.py` +- Modify: `tests/test_data/test_polars_nullability/test_function_node_nullability.py` + +- [ ] **Step 1: Fix `test_function_pod_node.py::TestFunctionNodeStreamInterface`** + +Two tests use `iter_packets()` without prior `run()`: + +```python +class TestFunctionNodeStreamInterface: + @pytest.fixture + def node(self, double_pf) -> FunctionNode: + db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=FunctionPod(packet_function=double_pf), + input_stream=make_int_stream(n=3), + pipeline_database=db, + ) + node.run() # ← add this + return node + + def test_iter_packets_correct_values(self, node): + assert [packet["result"] for _, packet in node.iter_packets()] == [0, 2, 4] + + def test_node_is_stream_protocol(self, node): + assert isinstance(node, StreamProtocol) + + def test_dunder_iter_delegates_to_iter_packets(self, node): + assert len(list(node)) == len(list(node.iter_packets())) + + def test_run_fills_database(self, node): + # node.run() already called in fixture; just verify DB + records = node.get_all_records() + assert records is not None + assert records.num_rows == 3 +``` + +- [ ] **Step 2: Fix `test_function_node_attach_db.py` — four tests** + +**`test_iter_packets_without_database`** (line 46): Add `node.run()`: +```python + def test_iter_packets_without_database(self): + node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream(n=3)) + node.run() + results = list(node.iter_packets()) + assert len(results) == 3 + assert results[0][1]["result"] == 0 +``` + +**`test_attach_databases_clears_caches`** 
(line 77): Use `node.run()` to populate cache: +```python + def test_attach_databases_clears_caches(self): + node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) + node.run() # populate _cached_output_packets + assert len(node._cached_output_packets) > 0 + db = InMemoryArrowDatabase() + node.attach_databases(pipeline_database=db, result_database=db) + assert len(node._cached_output_packets) == 0 +``` + +**`test_iter_packets_after_attach_works`** (line 106): Add `node.run()`: +```python + def test_iter_packets_after_attach_works(self): + node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream(n=2)) + db = InMemoryArrowDatabase() + node.attach_databases(pipeline_database=db, result_database=db) + node.run() + results = list(node.iter_packets()) + assert len(results) == 2 +``` + +**`test_iter_packets_with_database`** (line 135): Add `node.run()`: +```python + def test_iter_packets_with_database(self): + db = InMemoryArrowDatabase() + node = FunctionNode( + function_pod=_make_pod(), + input_stream=_make_stream(n=3), + pipeline_database=db, + result_database=db, + ) + node.run() + results = list(node.iter_packets()) + assert len(results) == 3 +``` + +- [ ] **Step 3: Fix `test_function_node_nullability.py::TestFunctionNodeIterPacketsNullability`** + +The test calls `fn_node._iter_all_from_database()` which is removed. 
Replace with `_load_cached_entries()`: + +```python + def test_iter_packets_from_database_preserves_non_nullable_output(self): + """Packets loaded from DB via _load_cached_entries carry non-nullable output schema.""" + database = InMemoryArrowDatabase() + source = op.sources.DictSource( + [{"id": 1, "x": 7}], + tag_columns=["id"], + ) + + @op.function_pod(output_keys=["result"]) + def add_one(x: int) -> int: + return x + 1 + + pipeline = op.Pipeline("test_iter_packets_nullable", database) + with pipeline: + add_one.pod(source) + + pipeline.run() + + fn_nodes = _get_function_nodes(pipeline) + fn_node = fn_nodes[0] + + # Load from DB via the new helper + loaded = fn_node._load_cached_entries() + packets_seen = list(loaded.values()) + assert len(packets_seen) == 1, "Expected one packet from the database" + + _tag, packet = packets_seen[0] + packet_schema = packet.arrow_schema() + + result_field = packet_schema.field("result") + assert result_field.nullable is False, ( + f"Packet 'result' field should be non-nullable (int return type), " + f"but got nullable={result_field.nullable}. " + "Arrow→Polars→Arrow round-trip in _load_cached_entries dropped nullability." + ) +``` + +Update the docstring at the class level to reflect the method name change. + +- [ ] **Step 4: Run all three fixed test files** + +```bash +python -m pytest \ + tests/test_core/function_pod/test_function_pod_node.py \ + tests/test_core/function_pod/test_function_node_attach_db.py \ + tests/test_data/test_polars_nullability/test_function_node_nullability.py \ + -v 2>&1 | tail -30 +``` + +Expected: All pass. 
+ +- [ ] **Step 5: Commit** + +```bash +git add \ + tests/test_core/function_pod/test_function_pod_node.py \ + tests/test_core/function_pod/test_function_node_attach_db.py \ + tests/test_data/test_polars_nullability/test_function_node_nullability.py +git commit -m "test(ENG-379): update remaining test files for read-only iter_packets semantics" +``` + +--- + +## Task 14: Fix `as_table()` for empty result set + +**Context:** After the `iter_packets()` refactor, `as_table()` on a fresh node with no prior `run()` and an empty DB will enter the `for tag, packet in self.iter_packets():` loop and exit immediately (zero iterations). `self._cached_output_table` is never set, so the `assert self._cached_output_table is not None` line (line ~1389) raises `AssertionError`. New test 7 (`test_as_table_fresh_node_returns_empty_no_compute`) requires `as_table()` to return a valid empty `pa.Table`, not raise. + +**Files:** +- Modify: `src/orcapod/core/nodes/function_node.py` — `as_table()` method (~line 1356) + +- [ ] **Step 1: Patch `as_table()` to handle the empty case** + +Locate the block inside `as_table()` that starts with: +```python + if self._cached_output_table is None: + all_tags = [] + all_packets = [] + tag_schema, packet_schema = None, None + for tag, packet in self.iter_packets(): + ... + + converter = self.data_context.type_converter + ... + self._cached_output_table = arrow_utils.hstack_tables( + all_tags_as_tables, all_packets_as_tables + ) + assert self._cached_output_table is not None, ( + "_cached_output_table should not be None here." + ) +``` + +After the `for` loop but before building `_cached_output_table`, add a short-circuit for the empty case: + +```python + if not all_tags: + # No packets — return an empty table with the node's output schema. + # Build column-typed empty arrays from the output schema so the + # table is usable even with 0 rows. 
+        tag_schema_out, packet_schema_out = self.output_schema()
+        empty_tag = pa.table(
+            {k: pa.array([], type=pa.null())
+             for k in tag_schema_out}
+        ) if tag_schema_out else pa.table({})
+        empty_pkt = pa.table(
+            {k: pa.array([], type=pa.null())
+             for k in packet_schema_out}
+        ) if packet_schema_out else pa.table({})
+        self._cached_output_table = arrow_utils.hstack_tables(
+            empty_tag, empty_pkt
+        )
+```
+
+Then convert the `assert` to a conditional:
+```python
+    if self._cached_output_table is None:
+        # This should not happen in practice, but guard defensively.
+        self._cached_output_table = pa.table({})
+```
+
+> **Note:** The goal is simply that `as_table()` returns a valid `pa.Table` with 0 rows rather than raising. The exact column schema of an empty table is less critical — tests only assert `len(table) == 0` and `isinstance(table, pa.Table)`. A minimal empty table (`pa.table({})`) satisfies both. Use the simplest approach that works.
+
+**Simplest acceptable implementation:**
+```python
+    if not all_tags:
+        self._cached_output_table = pa.table({})
+```
+
+Replace the `assert` with:
+```python
+    if self._cached_output_table is None:
+        self._cached_output_table = pa.table({})
+```
+
+- [ ] **Step 2: Run test 7 to verify it passes**
+
+```bash
+python -m pytest tests/test_core/nodes/test_function_node_iteration.py::TestIterPacketsReadOnly::test_as_table_fresh_node_returns_empty_no_compute -v 2>&1
+```
+
+Expected: PASS.
+
+- [ ] **Step 3: Verify existing `as_table()` tests still pass**
+
+```bash
+python -m pytest tests/test_core/function_pod/test_function_pod_node_stream.py -k "as_table" -v 2>&1
+```
+
+Expected: All pass (tests that required computed results have `run()` added in Task 12).
+ +- [ ] **Step 4: Commit** + +```bash +git add src/orcapod/core/nodes/function_node.py +git commit -m "fix(ENG-379): as_table() returns empty pa.Table when iter_packets() yields nothing" +``` + +--- + +## Task 15: Final full test suite run + +- [ ] **Step 1: Run the full test suite** + +```bash +python -m pytest tests/ -v -q 2>&1 | tail -40 +``` + +Expected: All tests pass. If any failures remain: + +- `test_function_pod_node_stream.py` failures → check if `node.run()` was added in Task 12 +- `as_table()` `AssertionError` → check Task 14 fix +- `test_regression_fixes.py::TestConcurrentFallbackInRunningLoop` → tests `FunctionPodStream` (not `FunctionNode`) — should pass untouched +- Import errors for `_iter_all_from_database` or `_load_all_cached_records` → check Task 11 deletion and Task 13 test fixes +- `_cached_output_packets` key type errors → ensure all accesses use str keys; any code reading `self._cached_output_packets[i]` with an int `i` must be updated + +- [ ] **Step 2: Run only the new iteration tests to confirm all 10 pass** + +```bash +python -m pytest tests/test_core/nodes/test_function_node_iteration.py -v 2>&1 +``` + +Expected: 10/10 PASS. 
+ +- [ ] **Step 3: Final commit** + +```bash +git add tests/test_core/nodes/test_function_node_iteration.py +git commit -m "test(ENG-379): add test_function_node_iteration.py — 10 tests for read-only iteration semantics + +Co-Authored-By: Claude Sonnet 4.6 " +``` + +--- + +## Out of Scope + +- **ENG-381**: Concurrency control flag/argument for `execute()` — tracked separately +- `get_all_records()` DB join — different return type / responsibility; not refactored here +- `OperatorNode`, `SourceNode` — not touched +- `FunctionPodStream` in `function_pod.py` — has its own `iter_packets()` with concurrent path; not changed diff --git a/docs/superpowers/specs/2026-04-07-function-node-refactor-design.md b/docs/superpowers/specs/2026-04-07-function-node-refactor-design.md new file mode 100644 index 00000000..3dce7e00 --- /dev/null +++ b/docs/superpowers/specs/2026-04-07-function-node-refactor-design.md @@ -0,0 +1,315 @@ +# FunctionNode Refactor Design + +**Issue:** ENG-379 +**Date:** 2026-04-07 +**Status:** Approved for implementation + +--- + +## Problem Statement + +`FunctionNode` has two intertwined problems: + +1. **Redundant logic** — the DB join (taginfo + results → join on `PACKET_RECORD_ID` → filter by content hash → reconstruct stream) is duplicated in at least four places: `get_cached_results()`, `iter_packets()` Phase 1, `async_execute()` Phase 1, and `_load_all_cached_records()`. + +2. **Incorrect iteration semantics** — `iter_packets()` currently triggers computation for any input packet not yet in the DB (Phase 2). Iteration should be a strictly read-only operation that yields only already-computed results. 
+ +--- + +## Goals + +- `iter_packets()` is strictly read-only: never triggers computation +- Computation is only triggered by an explicit `run()` call (or the orchestrator calling `execute()` / `async_execute()`) +- If no results exist and no DB records are present, iteration yields nothing +- All duplicated DB join logic is consolidated into a single internal helper +- Every internal method has a clear, non-overlapping purpose +- Tests verify that iteration alone does not cause computation side effects + +--- + +## Design + +### 1. In-Memory Result Store + +`_cached_output_packets` changes key type from `dict[int, tuple]` to `dict[str, tuple[TagProtocol, PacketProtocol]]` keyed by **entry_id** (the output of `compute_pipeline_entry_id(tag, packet)`). + +**Properties:** +- **Add-only between `clear_cache()` calls**: entries are inserted and existing entries are never removed until `clear_cache()` is called. When `_load_cached_entries()` returns entries for `entry_id`s already in the store, `dict.update()` will technically overwrite them — but this is safe and intentional, because any in-memory result and its DB-serialised counterpart for the same `entry_id` are always semantically equivalent. No pre-update existence check is needed. +- **Cleared only** by `clear_cache()` (triggered on staleness detection) or `attach_databases()` +- O(1) membership test: `entry_id in self._cached_output_packets` +- Preserves insertion order (Python 3.7+ dict), so iteration order is stable between calls in the same session + +The derived caches `_cached_output_table` and `_cached_content_hash_column` are invalidated (set to `None`) whenever `_cached_output_packets` is modified. + +### 2. New Internal Helper: `_load_cached_entries()` + +Single replacement for all four duplicated DB join sites: + +```python +def _load_cached_entries( + self, + entry_ids: list[str] | None = None, +) -> dict[str, tuple[TagProtocol, PacketProtocol]]: + """Load entries from the pipeline DB + result DB. 
+ + entry_ids=None — loads all records for this node + entry_ids=[...] — loads only those specific entry IDs + + Returns a dict of entry_id → (tag, packet). + Does NOT mutate _cached_output_packets; the caller merges via update(). + Returns {} if either _cached_function_pod or _pipeline_database is None. + """ +``` + +All callers merge results: `self._cached_output_packets.update(loaded)` — **no `clear()` before `update()`**. + +**Tag key derivation inside `_load_cached_entries()`:** + +Reconstructing `(tag, packet)` pairs from the joined Arrow table requires knowing which columns are tag columns. Two modes: + +- **When `_input_stream` is available** (FULL / READ_ONLY load status): use `self._input_stream.keys()[0]` — the authoritative list of user-facing tag column names. +- **When `_input_stream` is not used** (CACHE_ONLY / deserialized node): derive from `taginfo.column_names` by excluding all known internal prefixes and sentinels: `META_PREFIX`, `SOURCE_PREFIX`, `SYSTEM_TAG_PREFIX`, `PIPELINE_ENTRY_ID_COL` (the local sentinel `"__pipeline_entry_id"`), and `NODE_CONTENT_HASH_COL`. This mirrors the existing logic in `_load_all_cached_records()`. + +**Columns dropped from the joined table before reconstructing packets:** + +Drop: columns starting with `META_PREFIX`, the `PIPELINE_ENTRY_ID_COL` sentinel, and `NODE_CONTENT_HASH_COL`. Source columns (`SOURCE_PREFIX`) are **not** dropped — `ArrowTableStream` requires them for packet identity. This matches the drop lists in the existing `get_cached_results()` (line 764) and `iter_packets()` Phase 1 (line 1233). + +**Filtering to `entry_ids`:** when `entry_ids` is provided, apply a polars filter on `PIPELINE_ENTRY_ID_COL` in the joined table before reconstruction. 
+ +This encapsulates: +- Fetching taginfo from `_pipeline_database` with `PIPELINE_ENTRY_ID_COL` +- Fetching results from `_cached_function_pod._result_database` +- Joining on `PACKET_RECORD_ID` via polars +- Applying `_filter_by_content_hash()` +- Restoring schema nullability +- Deriving tag keys using the strategy above +- Dropping internal columns (per list above) from the joined table +- Reconstructing `(tag, packet)` pairs via `ArrowTableStream` +- Filtering to `entry_ids` when provided + +**`get_all_records()`** also contains an inline DB join but returns a raw `pa.Table` with configurable column filtering. It is **out of scope** for this refactor — its join logic serves a different responsibility and will be addressed separately. + +### 3. `iter_packets()` — Strictly Read-Only + +``` +UNAVAILABLE → raise RuntimeError (no pod, no DB) + +CACHE_ONLY → if _cached_output_packets is empty: + loaded = _load_cached_entries() # load all from DB + _cached_output_packets.update(loaded) + yield from _cached_output_packets.values() + +FULL / READ_ONLY → + if is_stale: clear_cache() + if _cached_output_packets is empty AND DB is attached: + loaded = _load_cached_entries() # hot-load from DB + _cached_output_packets.update(loaded) + yield from _cached_output_packets.values() + # No computation. No input stream traversal. No _process_packet_internal(). +``` + +**Contract:** calling `iter_packets()` on a node where `run()` has never been called and the DB has no records yields nothing. This is correct and intentional. + +**Iteration order stability:** the order of `_cached_output_packets.values()` reflects insertion order. When loaded from DB via `_load_cached_entries()`, the order follows the Polars join output order. When populated by `_process_packet_internal()`, it follows the order packets were processed in `execute()`. 
Two successive calls to `iter_packets()` in the same session yield the same order because the second call serves from the already-populated in-memory store without re-querying. + +### 4. `run()` — User-Facing Computation Entry Point + +`run()` checks `load_status` before delegating, mirroring the guard already present in `iter_packets()`: + +```python +def run(self) -> None: + """Eagerly compute all input packets and persist results.""" + from orcapod.pipeline.serialization import LoadStatus + if self._load_status == LoadStatus.UNAVAILABLE: + raise RuntimeError( + f"FunctionNode {self.label!r} is unavailable: " + "no function pod and no database attached." + ) + if self._load_status == LoadStatus.CACHE_ONLY: + # Upstream is unavailable (CACHE_ONLY); computation requires a live + # input stream. Iteration will serve existing results from DB. + return + self.execute(self._input_stream) +``` + +### 5. `execute()` — Sync Computation (Orchestrator + `run()`) + +`execute()` is **always sequential** — there is no concurrent path. Any async/concurrent +execution is handled exclusively by `async_execute()` (used by the async orchestrator). + +``` +1. Observer: on_node_start (with tag_schema from input_stream) +2. Collect all (tag, packet, entry_id) from input_stream.iter_packets() +3. Hot-load missing entries from DB via get_cached_results() (side-effect only): + get_cached_results(entry_ids) # populates _cached_output_packets +4. 
Per-packet loop — fire on_packet_start before each packet: + for tag, pkt, eid in upstream: + on_packet_start(node_label, tag, pkt) + if eid in _cached_output_packets: # includes None-output entries + tag_out, result = _cached_output_packets[eid] + on_packet_end(..., cached=True) + if result is not None: output.append((tag_out, result)) + else: + try: + _process_packet_internal(tag, pkt) # stores result in _cached_output_packets + except Exception as exc: + log warning; on_packet_crash; re-raise if error_policy=="fail_fast" + else: + on_packet_end(..., cached=False) + if result is not None: output.append((tag_out, result)) +5. on_node_end +6. _update_modified_time() # prevents is_stale from clearing cache on next iter_packets() +7. return output +``` + +**Observer hooks:** `on_packet_start` fires **before** each packet's processing (including +cache hits). `on_packet_end` fires for successes; `on_packet_crash` fires for failures. +`cached=True` for entries already in `_cached_output_packets`; `cached=False` for freshly +computed ones. + +**Note on `_cached_output_packets` membership:** The cache-hit check uses +`_cached_output_packets` directly (not the filtered return value of `get_cached_results()`) +so that entries with `None` outputs (function returned None) are treated as cache hits and +are not recomputed on subsequent `execute()` calls. + +### 6. `async_execute()` — Async Computation (Async Orchestrator) + +Phase 1 (cache pre-load) is simplified to a single `_load_cached_entries()` call: + +```python +cached_by_entry_id: dict[str, tuple] = {} +if self._cached_function_pod is not None: + loaded = self._load_cached_entries() # load all existing records + self._cached_output_packets.update(loaded) + cached_by_entry_id = loaded +``` + +Phase 2 (per-packet: cached or compute) is unchanged in logic: uses `_async_execute_one_packet()` for cache misses. 
The `supports_concurrent_execution` flag is not consulted in `async_execute()` — the TaskGroup already provides per-packet concurrency. + +`_async_execute_cache_only()` is simplified: replace the `_load_all_cached_records()` call with `_load_cached_entries()`, drop the `(tag_keys, data_table)` unpacking, and iterate over the returned dict values directly. + +### 7. `_process_packet_internal()` — Simplified + +- `cache_index: int | None` parameter is **removed** +- Computes `entry_id` internally and stores by it: + ```python + entry_id = self.compute_pipeline_entry_id(tag, packet) + self._cached_output_packets[entry_id] = (tag_out, output_packet) + self._cached_output_table = None + self._cached_content_hash_column = None + ``` +- Lines clearing `_cached_input_iterator` and `_needs_iterator` are **removed** +- All other logic (routing through `_cached_function_pod` vs `_function_pod`, calling `add_pipeline_record()`) is unchanged + +`execute_packet()` delegates to `_process_packet_internal()` without `cache_index`. No change to `execute_packet()`'s signature or callers. After the call, `_cached_output_packets` contains the result keyed by entry_id. Tests asserting `len(node._cached_output_packets) == N` continue to hold. + +`_async_process_packet_internal()` is updated identically. + +### 8. `get_cached_results()` — Simplified + +Kept as a public method with its existing signature. 
Rewritten to delegate to `_load_cached_entries()` with add-only semantics: + +```python +def get_cached_results( + self, entry_ids: list[str] +) -> dict[str, tuple[TagProtocol, PacketProtocol]]: + if not entry_ids or self._cached_function_pod is None: + return {} + missing = [eid for eid in entry_ids if eid not in self._cached_output_packets] + if missing: + loaded = self._load_cached_entries(missing) + self._cached_output_packets.update(loaded) # add-only, no clear() + return { + eid: self._cached_output_packets[eid] + for eid in entry_ids + if eid in self._cached_output_packets + } +``` + +**Behavior change vs current implementation:** the existing implementation calls `self._cached_output_packets.clear()` before repopulating. The new version does not clear. This is intentional: existing in-memory results are preserved. The existing test `test_get_cached_results_populates_internal_cache` manually clears `_cached_output_packets` before calling `get_cached_results()`, so it continues to pass. 
+
+---
+
+## Method Inventory: Before → After
+
+| Method | Verdict | Notes |
+|---|---|---|
+| `_ensure_iterator()` | **REMOVE** | Only served the old compute-while-iterate model |
+| `_iter_packets_sequential()` | **REMOVE** | Computation leaves `iter_packets()`; dissolves |
+| `_iter_packets_concurrent()` | **REMOVE** | Removed outright — `execute()` is always sequential; concurrency lives only in `async_execute()` |
+| `_iter_all_from_database()` | **REMOVE** | `iter_packets()` calls `_load_cached_entries()` directly |
+| `_load_all_cached_records()` | **REMOVE** | Absorbed by `_load_cached_entries()` |
+| `_load_cached_entries()` | **ADD** | Single DB join helper replacing 4 duplicates |
+| `_executor_supports_concurrent()` | **REMOVE** | Only caller was `_iter_packets_concurrent()`; `execute()` is always sequential |
+| `clear_cache()` | **SIMPLIFY** | Remove `_cached_input_iterator` and `_needs_iterator` clearing |
+| `_process_packet_internal()` | **SIMPLIFY** | Remove `cache_index`; store by entry_id; remove iterator field resets |
+| `_async_process_packet_internal()` | **SIMPLIFY** | Same as above |
+| `_async_execute_cache_only()` | **SIMPLIFY** | Use `_load_cached_entries()` instead of `_load_all_cached_records()` |
+| `get_cached_results()` | **SIMPLIFY** | Delegate to `_load_cached_entries()`; drop `clear()` |
+| `execute()` | **REFACTOR** | Selective reload (sequential only); per-packet observer hooks |
+| `async_execute()` | **REFACTOR** | Phase 1 via `_load_cached_entries()` |
+| `run()` | **REWRITE** | Load-status guard + `execute(self._input_stream)` |
+| `iter_packets()` | **REWRITE** | Strictly read-only |
+| `get_all_records()` | **OUT OF SCOPE** | Contains duplicate join but returns `pa.Table`; different responsibility |
+| `execute_packet()` | **KEEP** | Signature unchanged; inherits entry_id keying via `_process_packet_internal()` |
+| `_async_execute_one_packet()` | **KEEP** | Clean async helper used by `async_execute()` |
+| `_filter_by_content_hash()` | **KEEP** | Used inside 
`_load_cached_entries()` | +| `_require_pipeline_database()` | **KEEP** | Guard; still needed | +| `add_pipeline_record()` | **KEEP** | Core provenance responsibility; unchanged | +| `compute_pipeline_entry_id()` | **KEEP** | Now also called inside `_process_packet_internal()` | + +**State removed from `__init__`, `clear_cache()`, and `from_descriptor()`:** `_cached_input_iterator`, `_needs_iterator` + +--- + +## Testing + +### New test file: `tests/test_core/nodes/test_function_node_iteration.py` + +1. `iter_packets()` on a fresh node with no DB and no `run()` call yields nothing +2. `iter_packets()` on a node with DB records (prior session) hot-loads and yields without calling `_process_packet_internal()` (mock assertion) +3. `iter_packets()` does not call `_process_packet_internal()` under any non-compute path (mock assertion covers FULL, READ_ONLY, CACHE_ONLY modes) +4. After `run()`, `iter_packets()` yields from `_cached_output_packets` without an additional DB query +5. `iter_packets()` called twice in the same session returns the same results and the same order; DB is only queried on the first call +6. `_cached_output_packets` is keyed by entry_id strings (not ints) after `run()` +7. `as_table()` on a fresh node with no `run()` and empty DB returns an empty table (0 rows, valid schema) — no computation triggered +8. `run()` on a CACHE_ONLY node is a no-op (returns without error, no computation, no exception) +9. `run()` on an UNAVAILABLE node raises `RuntimeError` +10. `execute()` sequential path: `on_packet_crash` fires per failing packet when `error_policy="continue"`; exception propagates on first failure when `error_policy="fail_fast"` — test is named `test_execute_error_policy_continue_skips_failures` and uses `LocalExecutor` (non-concurrent) to exercise the sequential path + +### Existing tests requiring full rewrite + +The following tests **use `iter_packets()` as the sole computation trigger** and will break after the refactor. 
They must be updated to call `run()` or `execute()` first, then call `iter_packets()` (or inspect `_cached_output_packets`) to verify results: + +- `tests/test_core/function_pod/test_function_pod_node_stream.py`: + - All of `TestFunctionNodeStreamBasic` (`test_iter_packets_yields_correct_count`, `test_iter_packets_correct_values`, `test_iter_is_repeatable`, `test_dunder_iter_delegates_to_iter_packets`) + - `TestIterPacketsMissingEntriesOnly::test_partial_fill_total_row_count_correct` and `test_partial_fill_all_values_correct` — these rely on Phase 2 computing missing entries during iteration; after the refactor only the 2 pre-existing DB entries are returned, not all 4 + - `TestFunctionNodeStaleness::test_clear_cache_resets_output_packets` and `test_iter_packets_auto_detects_stale_and_repopulates` + +- `tests/test_core/function_pod/test_function_pod_node.py`: + - `TestFunctionNodeStreamInterface::test_iter_packets_correct_values` and any other test calling `iter_packets()` without a prior `run()` + +- `tests/test_core/function_pod/test_function_node_attach_db.py`: + - `test_attach_databases_clears_caches` (calls `list(node.iter_packets())` and asserts `len > 0`) + +- `tests/test_data/test_polars_nullability/test_function_node_nullability.py`: + - Any test calling `node._iter_all_from_database()` directly — this method is removed; replace with `node._load_cached_entries()` or test via `iter_packets()` after DB population + +- `tests/test_core/test_regression_fixes.py`: + - `TestConcurrentFallbackInRunningLoop` — tests the removed `_iter_packets_concurrent` behavior; covered by `test_execute_error_policy_continue_skips_failures` in the new test file (sequential `execute()` path, no concurrent gather) + +### Existing tests requiring minor updates (key type change only) + +- `test_function_node_get_cached.py`: verify `_cached_output_packets` key type is `str`; `test_get_cached_results_populates_internal_cache` still passes with add-only semantics +- 
`test_node_execute.py`: size assertions hold; verify keys are entry_id strings +- `test_function_pod_node_stream.py` (tests not listed above): size assertions hold; key type changes +- Any test referencing `_cached_input_iterator`, `_needs_iterator`, `_ensure_iterator`, `_iter_packets_sequential`, `_iter_packets_concurrent`, or `_load_all_cached_records` — update or remove + +--- + +## Out of Scope + +- Concurrent execution in `execute()` (always sequential; async path is `async_execute()`) → any future concurrent `execute()` tracked in **ENG-381** +- Refactoring `get_all_records()` to use `_load_cached_entries()` (different return type / responsibility) +- Refactoring other node types (`OperatorNode`, `SourceNode`) +- Lazy evaluation / deferred computation patterns diff --git a/src/orcapod/core/nodes/function_node.py b/src/orcapod/core/nodes/function_node.py index a360ad0e..650b1a79 100644 --- a/src/orcapod/core/nodes/function_node.py +++ b/src/orcapod/core/nodes/function_node.py @@ -117,13 +117,9 @@ def __init__( self._input_stream = input_stream - # stream-level caching state (iterator acquired lazily on first use) - self._cached_input_iterator: ( - Iterator[tuple[TagProtocol, PacketProtocol]] | None - ) = None - self._needs_iterator = True + # stream-level caching state self._cached_output_packets: dict[ - int, tuple[TagProtocol, PacketProtocol | None] + str, tuple[TagProtocol, PacketProtocol | None] ] = {} self._cached_output_table: pa.Table | None = None self._cached_content_hash_column: pa.Array | None = None @@ -355,8 +351,6 @@ def from_descriptor( node._packet_function = None node._input_stream = None node.tracker_manager = DEFAULT_TRACKER_MANAGER - node._cached_input_iterator = None - node._needs_iterator = True node._cached_output_packets = {} node._cached_output_table = None node._cached_content_hash_column = None @@ -533,16 +527,7 @@ def node_uri(self) -> tuple[str, ...]: # Caching # ------------------------------------------------------------------ - 
def _ensure_iterator(self) -> None: - """Lazily acquire the upstream iterator on first use.""" - if self._needs_iterator: - self._cached_input_iterator = self._input_stream.iter_packets() - self._needs_iterator = False - self._update_modified_time() - def clear_cache(self) -> None: - self._cached_input_iterator = None - self._needs_iterator = True self._cached_output_packets.clear() self._cached_output_table = None self._cached_content_hash_column = None @@ -583,15 +568,19 @@ def execute( ) -> list[tuple[TagProtocol, PacketProtocol]]: """Execute all packets from a stream: compute, persist, and cache. + For each packet: fire ``on_packet_start``, check the in-memory cache + (populated from DB if needed), compute if missing, fire + ``on_packet_end`` or ``on_packet_crash``. + Args: input_stream: The input stream to process. observer: Optional execution observer for hooks. - error_policy: ``"continue"`` (default) skips failed packets; + error_policy: ``"continue"`` skips failed packets; ``"fail_fast"`` re-raises on the first failure. Returns: Materialized list of (tag, output_packet) pairs, excluding - None outputs and failed packets. + ``None`` outputs and failed packets. """ from orcapod.pipeline.observer import NoOpObserver @@ -604,35 +593,41 @@ def execute( tag_schema = input_stream.output_schema(columns={"system_tags": True})[0] ctx_obs.on_node_start(node_label, node_hash, tag_schema=tag_schema) - # Gather entry IDs and check cache - upstream_entries = [ + # Collect upstream entries and resolve entry_ids + upstream_entries: list[tuple[TagProtocol, PacketProtocol, str]] = [ (tag, packet, self.compute_pipeline_entry_id(tag, packet)) for tag, packet in input_stream.iter_packets() ] entry_ids = [eid for _, _, eid in upstream_entries] - cached = self.get_cached_results(entry_ids=entry_ids) + + # Hot-load any already-computed results from DB into _cached_output_packets. 
+ # get_cached_results() is called for its side effect (populating the + # in-memory cache); the returned dict is intentionally discarded here so + # that the per-packet cache-hit check below uses _cached_output_packets + # directly — which includes None-output entries (function returned None) + # and prevents spurious recomputation of already-processed packets. + self.get_cached_results(entry_ids=entry_ids) output: list[tuple[TagProtocol, PacketProtocol]] = [] for tag, packet, entry_id in upstream_entries: ctx_obs.on_packet_start(node_label, tag, packet) - if entry_id in cached: - tag_out, result = cached[entry_id] - ctx_obs.on_packet_end( - node_label, tag, packet, result, cached=True - ) - output.append((tag_out, result)) + if entry_id in self._cached_output_packets: + tag_out, result = self._cached_output_packets[entry_id] + ctx_obs.on_packet_end(node_label, tag, packet, result, cached=True) + if result is not None: + output.append((tag_out, result)) else: - pkt_logger = ctx_obs.create_packet_logger( - tag, packet - ) + pkt_logger = ctx_obs.create_packet_logger(tag, packet) try: tag_out, result = self._process_packet_internal( tag, packet, logger=pkt_logger ) except Exception as exc: logger.warning( - "Packet execution failed in %s: %s", node_label, exc, + "Packet execution failed in %s: %s", + node_label, + exc, exc_info=True, ) ctx_obs.on_packet_crash(node_label, tag, packet, exc) @@ -647,31 +642,26 @@ def execute( output.append((tag_out, result)) ctx_obs.on_node_end(node_label, node_hash) + # Mark this node as freshly computed so subsequent iter_packets() calls + # skip the is_stale check and serve results directly from the in-memory cache. + self._update_modified_time() return output def _process_packet_internal( self, tag: TagProtocol, packet: PacketProtocol, - cache_index: int | None = None, *, logger: PacketExecutionLoggerProtocol | None = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Core compute + persist + cache. 
- Used by ``execute_packet``, ``execute``, and ``iter_packets``. - No input validation is performed — the caller guarantees correctness. - Exceptions propagate to the caller. + Used by ``execute_packet`` and ``execute``. + Stores result in ``_cached_output_packets`` keyed by entry_id. + Exceptions propagate to the caller — no error handling here. Returns: A ``(tag, output_packet)`` 2-tuple. - - Args: - tag: The input tag. - packet: The input packet. - cache_index: Optional explicit index for the internal cache. - When ``None``, auto-assigns at ``len(_cached_output_packets)``. - logger: Optional packet execution logger. """ if self._cached_function_pod is not None: tag_out, output_packet = self._cached_function_pod.process_packet( @@ -695,13 +685,9 @@ def _process_packet_internal( tag, packet, logger=logger ) - # Cache internally and invalidate derived caches - idx = ( - cache_index if cache_index is not None else len(self._cached_output_packets) - ) - self._cached_output_packets[idx] = (tag_out, output_packet) - self._cached_input_iterator = None - self._needs_iterator = False + # Store by entry_id and invalidate derived caches + entry_id = self.compute_pipeline_entry_id(tag, packet) + self._cached_output_packets[entry_id] = (tag_out, output_packet) self._cached_output_table = None self._cached_content_hash_column = None @@ -712,106 +698,50 @@ def get_cached_results( ) -> dict[str, tuple[TagProtocol, PacketProtocol]]: """Retrieve cached results for specific pipeline entry IDs. - Looks up the pipeline DB and result DB, joins them, and filters - to the requested entry IDs. Returns a mapping from entry ID to - (tag, output_packet). + Checks in-memory cache first. Loads only truly missing entries from DB. + Add-only semantics: existing in-memory entries are never cleared or + overwritten (overwrite is safe since in-memory and DB entries for the + same entry_id are always semantically equivalent). Args: entry_ids: Pipeline entry IDs to look up. 
Returns: - Mapping from entry_id to (tag, output_packet) for found entries. + Mapping from entry_id to ``(tag, output_packet)`` for found entries. Empty dict if no DB is attached or no matches found. """ if self._cached_function_pod is None or not entry_ids: return {} - self._require_pipeline_database() - - PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" - - taginfo = self._pipeline_database.get_all_records( - self.node_identity_path, - record_id_column=PIPELINE_ENTRY_ID_COL, - ) - results = self._cached_function_pod._result_database.get_all_records( - self._cached_function_pod.record_path, - record_id_column=constants.PACKET_RECORD_ID, - ) - - if taginfo is None or results is None: - return {} - - taginfo = self._filter_by_content_hash(taginfo) - taginfo_schema = taginfo.schema - results_schema = results.schema - filtered = ( - pl.DataFrame(taginfo) - .join( - pl.DataFrame(results), - on=constants.PACKET_RECORD_ID, - how="inner", - ) - .filter(pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids)) - .to_arrow() - ) - filtered = arrow_utils.restore_schema_nullability(filtered, taginfo_schema, results_schema) - - if filtered.num_rows == 0: - return {} - - tag_keys = self._input_stream.keys()[0] - drop_cols = [ - c - for c in filtered.column_names - if c.startswith(constants.META_PREFIX) - or c == PIPELINE_ENTRY_ID_COL - or c == constants.NODE_CONTENT_HASH_COL - ] - data_table = filtered.drop([c for c in drop_cols if c in filtered.column_names]) - - stream = ArrowTableStream(data_table, tag_columns=tag_keys) - filtered_entry_ids = filtered.column(PIPELINE_ENTRY_ID_COL).to_pylist() - - result_dict: dict[str, tuple[TagProtocol, PacketProtocol]] = {} - for entry_id, (tag, packet) in zip(filtered_entry_ids, stream.iter_packets()): - result_dict[entry_id] = (tag, packet) - - # Populate internal cache with retrieved results (clear first to - # avoid duplicates on repeated orchestrator runs) - self._cached_output_packets.clear() - self._cached_output_table = None - 
self._cached_content_hash_column = None - for entry_id, (tag, packet) in result_dict.items(): - next_idx = len(self._cached_output_packets) - self._cached_output_packets[next_idx] = (tag, packet) - self._cached_input_iterator = None - self._needs_iterator = False - - return result_dict + missing = [eid for eid in entry_ids if eid not in self._cached_output_packets] + if missing: + loaded = self._load_cached_entries(missing) + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + return { + eid: self._cached_output_packets[eid] + for eid in entry_ids + if eid in self._cached_output_packets + and self._cached_output_packets[eid][1] is not None + } async def _async_process_packet_internal( self, tag: TagProtocol, packet: PacketProtocol, - cache_index: int | None = None, *, logger: PacketExecutionLoggerProtocol | None = None, ) -> tuple[TagProtocol, PacketProtocol | None]: """Async counterpart of ``_process_packet_internal``. - Computes via async path, writes pipeline provenance, and caches - internally — no schema validation. Exceptions propagate. + Computes via async path, writes pipeline provenance, caches by entry_id. + Exceptions propagate. Returns: A ``(tag, output_packet)`` 2-tuple. - - Args: - tag: The input tag. - packet: The input packet. - cache_index: Optional explicit index for the internal cache. - When ``None``, auto-assigns at ``len(_cached_output_packets)``. - logger: Optional packet execution logger. 
""" if self._cached_function_pod is not None: tag_out, output_packet = ( @@ -839,13 +769,9 @@ async def _async_process_packet_internal( ) ) - # Cache internally and invalidate derived caches - idx = ( - cache_index if cache_index is not None else len(self._cached_output_packets) - ) - self._cached_output_packets[idx] = (tag_out, output_packet) - self._cached_input_iterator = None - self._needs_iterator = False + # Store by entry_id and invalidate derived caches + entry_id = self.compute_pipeline_entry_id(tag, packet) + self._cached_output_packets[entry_id] = (tag_out, output_packet) self._cached_output_table = None self._cached_content_hash_column = None @@ -1048,18 +974,25 @@ def as_source(self): # Cache-only helpers (PLT-1156) # ------------------------------------------------------------------ - def _load_all_cached_records( + def _load_cached_entries( self, - ) -> "tuple[tuple[str, ...], Any] | None": - """Join pipeline DB and result DB; return (tag_keys, data_table). + entry_ids: list[str] | None = None, + ) -> "dict[str, tuple[TagProtocol, PacketProtocol]]": + """Load (tag, packet) pairs from pipeline DB + result DB. - Returns ``None`` when either database is empty or unavailable. - Does not access ``_input_stream``. - """ - import polars as pl + Args: + entry_ids: If provided, load only these specific entry IDs. + If ``None``, load all records for this node. + Returns: + dict mapping entry_id → (tag, packet). Empty dict when either + database is None, records are empty, or no rows match. + + Does NOT mutate ``_cached_output_packets``. + Callers merge via ``self._cached_output_packets.update(loaded)``. 
+ """ if self._cached_function_pod is None or self._pipeline_database is None: - return None + return {} PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" @@ -1073,39 +1006,46 @@ def _load_all_cached_records( ) if taginfo is None or results is None: - return None + return {} taginfo = self._filter_by_content_hash(taginfo) taginfo_schema = taginfo.schema results_schema = results.schema - joined = ( - pl.DataFrame(taginfo) - .join( - pl.DataFrame(results), - on=constants.PACKET_RECORD_ID, - how="inner", + + joined_df = pl.DataFrame(taginfo).join( + pl.DataFrame(results), + on=constants.PACKET_RECORD_ID, + how="inner", + ) + if entry_ids is not None: + joined_df = joined_df.filter( + pl.col(PIPELINE_ENTRY_ID_COL).is_in(entry_ids) ) - .to_arrow() + joined = joined_df.to_arrow() + joined = arrow_utils.restore_schema_nullability( + joined, taginfo_schema, results_schema ) - joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) if joined.num_rows == 0: - return None + return {} - # Tag keys are the user-facing tag columns from the pipeline DB table. - # Exclude: meta columns (__*), source columns (_source_*), - # system-tag columns (e.g. __tag_*), the entry-ID column, and the - # node content hash column. - tag_keys = tuple( - c - for c in taginfo.column_names - if not c.startswith(constants.META_PREFIX) - and not c.startswith(constants.SOURCE_PREFIX) - and not c.startswith(constants.SYSTEM_TAG_PREFIX) - and c != PIPELINE_ENTRY_ID_COL - and c != constants.NODE_CONTENT_HASH_COL - ) + # Derive tag keys: prefer input_stream when available; fall back to + # taginfo column exclusion for CACHE_ONLY / deserialized nodes. 
+ if self._input_stream is not None: + tag_keys = self._input_stream.keys()[0] + else: + tag_keys = tuple( + c + for c in taginfo.column_names + if not c.startswith(constants.META_PREFIX) + and not c.startswith(constants.SOURCE_PREFIX) + and not c.startswith(constants.SYSTEM_TAG_PREFIX) + and c != PIPELINE_ENTRY_ID_COL + and c != constants.NODE_CONTENT_HASH_COL + ) + # Drop internal columns (SOURCE_PREFIX is kept — ArrowTableStream needs it) + entry_ids_col = joined.column(PIPELINE_ENTRY_ID_COL).to_pylist() drop_cols = [ c for c in joined.column_names @@ -1114,24 +1054,12 @@ def _load_all_cached_records( or c == constants.NODE_CONTENT_HASH_COL ] data_table = joined.drop([c for c in drop_cols if c in joined.column_names]) - return tag_keys, data_table - - def _iter_all_from_database( - self, - ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: - """Yield all cached (tag, packet) pairs from the DB. - - Used in ``CACHE_ONLY`` mode when the upstream is unavailable. - Does not access ``_input_stream``. 
- """ - result = self._load_all_cached_records() - if result is None: - return - tag_keys, data_table = result stream = ArrowTableStream(data_table, tag_columns=tag_keys) - for i, (tag, packet) in enumerate(stream.iter_packets()): - self._cached_output_packets[i] = (tag, packet) - yield tag, packet + + loaded: dict[str, tuple[TagProtocol, PacketProtocol]] = {} + for eid, (tag, packet) in zip(entry_ids_col, stream.iter_packets()): + loaded[eid] = (tag, packet) + return loaded async def _async_execute_cache_only( self, @@ -1153,11 +1081,14 @@ async def _async_execute_cache_only( ctx_obs.on_node_start(node_label, node_hash, tag_schema=None) try: - result = self._load_all_cached_records() - if result is not None: - tag_keys, data_table = result - stream = ArrowTableStream(data_table, tag_columns=tag_keys) - for tag, packet in stream.iter_packets(): + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + for tag, packet in self._cached_output_packets.values(): + if packet is not None: ctx_obs.on_packet_start(node_label, tag, packet) ctx_obs.on_packet_end(node_label, tag, packet, packet, cached=True) await output.send((tag, packet)) @@ -1170,184 +1101,83 @@ async def _async_execute_cache_only( # ------------------------------------------------------------------ def iter_packets(self) -> Iterator[tuple[TagProtocol, PacketProtocol]]: + """Yield all computed (tag, packet) pairs for this node. + + Strictly read-only — never triggers computation. Callers must call + ``run()`` or ``execute()`` first if they want results computed. + + On the first call with an empty in-memory store and a DB attached, + hot-loads all existing records from the DB (one-shot, no recompute). + + Raises: + RuntimeError: If ``load_status`` is UNAVAILABLE. 
+ """ from orcapod.pipeline.serialization import LoadStatus status = self.load_status - if status == LoadStatus.CACHE_ONLY: - yield from self._iter_all_from_database() - return - if status == LoadStatus.UNAVAILABLE: raise RuntimeError( f"FunctionNode {self.label!r} is unavailable: " "no function pod and no database attached." ) + if status == LoadStatus.CACHE_ONLY: + # Upstream unavailable; serve entirely from DB. + if not self._cached_output_packets: + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + yield from ( + (tag, pkt) + for tag, pkt in self._cached_output_packets.values() + if pkt is not None + ) + return + + # FULL / READ_ONLY — in-memory store may be populated from computation + # (via execute/run) or hot-loaded from DB. if self.is_stale: self.clear_cache() - self._ensure_iterator() - - if self._cached_function_pod is not None: - # Two-phase iteration with DB backing - if self._cached_input_iterator is not None: - input_iter = self._cached_input_iterator - # --- Phase 1: yield already-computed results from the databases --- - # Retrieve pipeline records with their entry_ids (record IDs) - # and join with result records to reconstruct (tag, output_packet). 
- PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" - existing_entry_ids: set[str] = set() - - taginfo = self._pipeline_database.get_all_records( - self.node_identity_path, - record_id_column=PIPELINE_ENTRY_ID_COL, - ) - results = self._cached_function_pod._result_database.get_all_records( - self._cached_function_pod.record_path, - record_id_column=constants.PACKET_RECORD_ID, - ) - if taginfo is not None and results is not None: - taginfo = self._filter_by_content_hash(taginfo) - taginfo_schema = taginfo.schema - results_schema = results.schema - joined = ( - pl.DataFrame(taginfo) - .join( - pl.DataFrame(results), - on=constants.PACKET_RECORD_ID, - how="inner", - ) - .to_arrow() - ) - joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) - if joined.num_rows > 0: - tag_keys = self._input_stream.keys()[0] - # Collect pipeline entry_ids for Phase 2 skip check - existing_entry_ids = set( - cast( - list[str], - joined.column(PIPELINE_ENTRY_ID_COL).to_pylist(), - ) - ) - # Drop internal columns before yielding as stream - drop_cols = [ - c - for c in joined.column_names - if c.startswith(constants.META_PREFIX) - or c == PIPELINE_ENTRY_ID_COL - or c == constants.NODE_CONTENT_HASH_COL - ] - data_table = joined.drop( - [c for c in drop_cols if c in joined.column_names] - ) - existing_stream = ArrowTableStream( - data_table, tag_columns=tag_keys - ) - for i, (tag, packet) in enumerate( - existing_stream.iter_packets() - ): - self._cached_output_packets[i] = (tag, packet) - yield tag, packet - - # --- Phase 2: process only missing input packets --- - # Skip inputs whose pipeline entry_id (tag+system_tags+packet_hash) - # already exists in the pipeline database. 
- for tag, packet in input_iter: - entry_id = self.compute_pipeline_entry_id(tag, packet) - if entry_id in existing_entry_ids: - continue - tag, output_packet = self._process_packet_internal(tag, packet) - if output_packet is not None: - yield tag, output_packet - - self._cached_input_iterator = None - else: - # Yield from snapshot of complete cache - for i in range(len(self._cached_output_packets)): - tag, packet = self._cached_output_packets[i] - if packet is not None: - yield tag, packet - else: - # Simple iteration without DB - if self._cached_input_iterator is not None: - if _executor_supports_concurrent(self._packet_function): - yield from self._iter_packets_concurrent( - self._cached_input_iterator - ) - else: - yield from self._iter_packets_sequential( - self._cached_input_iterator - ) - else: - for i in range(len(self._cached_output_packets)): - tag, packet = self._cached_output_packets[i] - if packet is not None: - yield tag, packet - - def _iter_packets_sequential( - self, input_iter: Iterator[tuple[TagProtocol, PacketProtocol]] - ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: - for i, (tag, packet) in enumerate(input_iter): - if i in self._cached_output_packets: - tag, packet = self._cached_output_packets[i] - if packet is not None: - yield tag, packet - else: - tag, output_packet = self._process_packet_internal(tag, packet) - if output_packet is not None: - yield tag, output_packet - self._cached_input_iterator = None - - def _iter_packets_concurrent( - self, - input_iter: Iterator[tuple[TagProtocol, PacketProtocol]], - ) -> Iterator[tuple[TagProtocol, PacketProtocol]]: - """Collect remaining inputs, execute concurrently, and yield results in order.""" - - all_inputs: list[tuple[int, TagProtocol, PacketProtocol]] = [] - to_compute: list[tuple[int, TagProtocol, PacketProtocol]] = [] - for i, (tag, packet) in enumerate(input_iter): - all_inputs.append((i, tag, packet)) - if i not in self._cached_output_packets: - to_compute.append((i, tag, packet)) - 
self._cached_input_iterator = None - - if to_compute: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop is not None: - # Already in event loop — fall back to sequential sync - for i, tag, pkt in to_compute: - self._process_packet_internal(tag, pkt, cache_index=i) - else: - - async def _gather() -> list[tuple[TagProtocol, PacketProtocol | None]]: - return list( - await asyncio.gather( - *[ - self._async_process_packet_internal( - tag, pkt, cache_index=i - ) - for i, tag, pkt in to_compute - ] - ) - ) + if not self._cached_output_packets and self._cached_function_pod is not None: + # Hot-load from DB on the first call when store is empty. + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + + yield from ( + (tag, pkt) + for tag, pkt in self._cached_output_packets.values() + if pkt is not None + ) - asyncio.run(_gather()) + def run(self) -> None: + """Eagerly compute all input packets, filling pipeline and result databases. - # Yield all results in order from internal cache - for idx in sorted(self._cached_output_packets.keys()): - tag, packet = self._cached_output_packets[idx] - if packet is not None: - yield tag, packet + Raises: + RuntimeError: If ``load_status`` is UNAVAILABLE (no pod, no DB). + """ + from orcapod.pipeline.serialization import LoadStatus - def run(self) -> None: - """Eagerly process all input packets, filling the pipeline and result databases.""" - for _ in self.iter_packets(): - pass + if self._load_status == LoadStatus.UNAVAILABLE: + raise RuntimeError( + f"FunctionNode {self.label!r} is unavailable: " + "no function pod and no database attached." + ) + if self._load_status in (LoadStatus.CACHE_ONLY, LoadStatus.READ_ONLY): + # CACHE_ONLY: upstream unavailable; computation requires a live input stream. + # READ_ONLY: function pod is a proxy placeholder — cannot compute. 
+ # Callers should use iter_packets() to serve existing DB results. + return + if self.is_stale: + # Discard any stale in-memory entries before a fresh computation run + # so that rerunning does not mix old cached entries with new results. + self.clear_cache() + self.execute(self._input_stream) # ------------------------------------------------------------------ # as_table @@ -1371,6 +1201,9 @@ def as_table( all_tags.append(tag.as_dict(all_info=True)) all_packets.append(packet.as_dict(all_info=True)) + if not all_tags: + self._cached_output_table = pa.table({}) + converter = self.data_context.type_converter struct_packets = converter.python_dicts_to_struct_dicts(all_packets) @@ -1386,9 +1219,8 @@ def as_table( self._cached_output_table = arrow_utils.hstack_tables( all_tags_as_tables, all_packets_as_tables ) - assert self._cached_output_table is not None, ( - "_cached_output_table should not be None here." - ) + if self._cached_output_table is None: + self._cached_output_table = pa.table({}) column_config = ColumnConfig.handle_config(columns, all_info=all_info) @@ -1517,58 +1349,13 @@ async def async_execute( ctx_obs.on_node_start(node_label, node_hash, tag_schema=tag_schema) if self._cached_function_pod is not None: - # DB-backed async execution: # Phase 1: build cache lookup from pipeline DB - PIPELINE_ENTRY_ID_COL = "__pipeline_entry_id" - cached_by_entry_id: dict[ - str, tuple[TagProtocol, PacketProtocol] - ] = {} - - taginfo = self._pipeline_database.get_all_records( - self.node_identity_path, - record_id_column=PIPELINE_ENTRY_ID_COL, - ) - results = self._cached_function_pod._result_database.get_all_records( - self._cached_function_pod.record_path, - record_id_column=constants.PACKET_RECORD_ID, - ) - - if taginfo is not None and results is not None: - taginfo = self._filter_by_content_hash(taginfo) - taginfo_schema = taginfo.schema - results_schema = results.schema - joined = ( - pl.DataFrame(taginfo) - .join( - pl.DataFrame(results), - 
on=constants.PACKET_RECORD_ID, - how="inner", - ) - .to_arrow() - ) - joined = arrow_utils.restore_schema_nullability(joined, taginfo_schema, results_schema) - if joined.num_rows > 0: - tag_keys = self._input_stream.keys()[0] - entry_ids_col = joined.column( - PIPELINE_ENTRY_ID_COL - ).to_pylist() - drop_cols = [ - c - for c in joined.column_names - if c.startswith(constants.META_PREFIX) - or c == PIPELINE_ENTRY_ID_COL - or c == constants.NODE_CONTENT_HASH_COL - ] - data_table = joined.drop( - [c for c in drop_cols if c in joined.column_names] - ) - existing_stream = ArrowTableStream( - data_table, tag_columns=tag_keys - ) - for eid, (tag_out, pkt_out) in zip( - entry_ids_col, existing_stream.iter_packets() - ): - cached_by_entry_id[eid] = (tag_out, pkt_out) + loaded = self._load_cached_entries() + self._cached_output_packets.update(loaded) + if loaded: + self._cached_output_table = None + self._cached_content_hash_column = None + cached_by_entry_id: dict[str, tuple[TagProtocol, PacketProtocol]] = dict(loaded) # Phase 2: drive output from input channel — cached or compute async def _process_one_db( diff --git a/test-objective/unit/test_nodes.py b/test-objective/unit/test_nodes.py index 69e3921f..e4e39518 100644 --- a/test-objective/unit/test_nodes.py +++ b/test-objective/unit/test_nodes.py @@ -139,7 +139,8 @@ def test_caches_computed_results(self): pipeline_database=pipeline_db, result_database=result_db, ) - # First iteration computes all + # run() computes all; iter_packets() is read-only and serves from cache + node.run() packets = list(node.iter_packets()) assert len(packets) == 3 diff --git a/tests/test_channels/test_node_async_execute.py b/tests/test_channels/test_node_async_execute.py index d926bb50..ef483655 100644 --- a/tests/test_channels/test_node_async_execute.py +++ b/tests/test_channels/test_node_async_execute.py @@ -678,12 +678,13 @@ def test_function_node_sequential_uses_execute_packet(self): # Monkey-patch to verify routing through internal path 
original = node._process_packet_internal - def patched(tag, packet): + def patched(tag, packet, *, logger=None): call_log.append("_process_packet_internal") - return original(tag, packet) + return original(tag, packet, logger=logger) node._process_packet_internal = patched + node.run() # computation now triggered via run(), not iter_packets() results = list(node.iter_packets()) assert len(results) == 3 assert len(call_log) == 3 diff --git a/tests/test_core/function_pod/test_function_node_attach_db.py b/tests/test_core/function_pod/test_function_node_attach_db.py index 07f05add..db41cbc8 100644 --- a/tests/test_core/function_pod/test_function_node_attach_db.py +++ b/tests/test_core/function_pod/test_function_node_attach_db.py @@ -45,6 +45,7 @@ def test_construction_without_database(self): def test_iter_packets_without_database(self): node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream(n=3)) + node.run() results = list(node.iter_packets()) assert len(results) == 3 assert results[0][1]["result"] == 0 @@ -76,7 +77,7 @@ def test_attach_databases_creates_cached_function_pod(self): def test_attach_databases_clears_caches(self): node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream()) - list(node.iter_packets()) # populate cache + node.run() # populate cache assert len(node._cached_output_packets) > 0 db = InMemoryArrowDatabase() node.attach_databases(pipeline_database=db, result_database=db) @@ -107,6 +108,7 @@ def test_iter_packets_after_attach_works(self): node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream(n=2)) db = InMemoryArrowDatabase() node.attach_databases(pipeline_database=db, result_database=db) + node.run() results = list(node.iter_packets()) assert len(results) == 2 @@ -140,5 +142,6 @@ def test_iter_packets_with_database(self): pipeline_database=db, result_database=db, ) + node.run() results = list(node.iter_packets()) assert len(results) == 3 diff --git 
a/tests/test_core/function_pod/test_function_node_caching.py b/tests/test_core/function_pod/test_function_node_caching.py index 2e9ddab8..07aba677 100644 --- a/tests/test_core/function_pod/test_function_node_caching.py +++ b/tests/test_core/function_pod/test_function_node_caching.py @@ -249,6 +249,7 @@ def test_phase2_processes_novel_entry_ids_only(self): [{"id": 0, "x": 10}, {"id": 1, "x": 20}, {"id": 2, "x": 30}] ) node2, _ = _make_node(stream2, db=db) + node2.run() results = list(node2.iter_packets()) # Should yield 3 total: 2 from Phase 1 + 1 from Phase 2 @@ -279,6 +280,7 @@ def test_same_packet_new_tag_triggers_phase2(self): assert node1.node_identity_path == node2.node_identity_path assert node1.node_identity_path[-1].startswith("schema:") + node2.run() results = list(node2.iter_packets()) # Phase 1 finds no records for node2's content_hash → Phase 2 processes the row diff --git a/tests/test_core/function_pod/test_function_pod_node.py b/tests/test_core/function_pod/test_function_pod_node.py index 352425b1..d75ae66c 100644 --- a/tests/test_core/function_pod/test_function_pod_node.py +++ b/tests/test_core/function_pod/test_function_pod_node.py @@ -302,11 +302,13 @@ class TestFunctionNodeStreamInterface: @pytest.fixture def node(self, double_pf) -> FunctionNode: db = InMemoryArrowDatabase() - return FunctionNode( + node = FunctionNode( function_pod=FunctionPod(packet_function=double_pf), input_stream=make_int_stream(n=3), pipeline_database=db, ) + node.run() + return node def test_iter_packets_correct_values(self, node): assert [packet["result"] for _, packet in node.iter_packets()] == [0, 2, 4] diff --git a/tests/test_core/function_pod/test_function_pod_node_stream.py b/tests/test_core/function_pod/test_function_pod_node_stream.py index d60dc8d0..64c24625 100644 --- a/tests/test_core/function_pod/test_function_pod_node_stream.py +++ b/tests/test_core/function_pod/test_function_pod_node_stream.py @@ -61,11 +61,13 @@ class TestFunctionNodeStreamBasic: 
@pytest.fixture def node(self, double_pf) -> FunctionNode: db = InMemoryArrowDatabase() - return FunctionNode( + node = FunctionNode( function_pod=FunctionPod(packet_function=double_pf), input_stream=make_int_stream(n=3), pipeline_database=db, ) + node.run() + return node def test_iter_packets_yields_correct_count(self, node): assert len(list(node.iter_packets())) == 3 @@ -117,6 +119,7 @@ def test_output_schema_has_result_in_packet_schema(self, node): class TestFunctionNodeColumnConfig: def test_as_table_content_hash_column(self, double_pf): node = _make_node(double_pf, n=3) + node.run() table = node.as_table(columns={"content_hash": True}) assert "_content_hash" in table.column_names assert len(table.column("_content_hash")) == 3 @@ -138,6 +141,7 @@ def test_as_table_sort_by_tags(self, double_pf): input_stream=input_stream, pipeline_database=db, ) + node.run() result = node.as_table(columns={"sort_by_tags": True}) ids: list[int] = result.column("id").to_pylist() # type: ignore[assignment] assert ids == sorted(ids) @@ -166,6 +170,7 @@ def test_as_table_returns_cached_results_when_packet_function_inactive( n = 3 db = InMemoryArrowDatabase() node1 = _make_node(double_pf, n=n, db=db) + node1.run() table1 = node1.as_table() assert len(table1) == n @@ -242,8 +247,12 @@ def test_db_served_results_have_correct_values(self, double_pf): n = 4 db = InMemoryArrowDatabase() - table1 = _make_node(double_pf, n=n, db=db).as_table() - table2 = _make_node(double_pf, n=n, db=db).as_table() + node1 = _make_node(double_pf, n=n, db=db) + node1.run() + table1 = node1.as_table() + node2 = _make_node(double_pf, n=n, db=db) + node2.run() + table2 = node2.as_table() assert sorted(table1.column("result").to_pylist()) == sorted( table2.column("result").to_pylist() @@ -300,14 +309,18 @@ def test_partial_fill_total_row_count_correct(self, double_pf): n = 4 db = InMemoryArrowDatabase() _fill_node(_make_node(double_pf, n=2, db=db)) - packets = list(_make_node(double_pf, n=n, 
db=db).iter_packets()) + node = _make_node(double_pf, n=n, db=db) + node.run() + packets = list(node.iter_packets()) assert len(packets) == n def test_partial_fill_all_values_correct(self, double_pf): n = 4 db = InMemoryArrowDatabase() _fill_node(_make_node(double_pf, n=2, db=db)) - table = _make_node(double_pf, n=n, db=db).as_table() + node = _make_node(double_pf, n=n, db=db) + node.run() + table = node.as_table() assert sorted(table.column("result").to_pylist()) == [0, 2, 4, 6] def test_partial_fill_as_table_after_run_on_same_node(self, double_pf): @@ -401,7 +414,7 @@ def test_is_stale_false_after_clear_cache(self, double_pf): def test_clear_cache_resets_output_packets(self, double_pf): node = _make_node(double_pf, n=3) - list(node.iter_packets()) + node.run() assert len(node._cached_output_packets) == 3 node.clear_cache() @@ -410,6 +423,7 @@ def test_clear_cache_resets_output_packets(self, double_pf): def test_clear_cache_produces_same_results_on_re_iteration(self, double_pf): node = _make_node(double_pf, n=3) + node.run() table_before = node.as_table() node.clear_cache() @@ -451,6 +465,7 @@ def test_as_table_auto_detects_stale_and_repopulates(self, double_pf): input_stream=input_stream, pipeline_database=db, ) + node.run() table_before = node.as_table() assert len(table_before) == 3 diff --git a/tests/test_core/function_pod/test_pipeline_hash_integration.py b/tests/test_core/function_pod/test_pipeline_hash_integration.py index e907a15b..35e17b6a 100644 --- a/tests/test_core/function_pod/test_pipeline_hash_integration.py +++ b/tests/test_core/function_pod/test_pipeline_hash_integration.py @@ -499,6 +499,7 @@ def test_shared_db_results_are_correct_values(self, double_pf): input_stream=make_int_stream(n=5), pipeline_database=db, ) + node2.run() results = sorted(cast(int, p["result"]) for _, p in node2.iter_packets()) assert results == [0, 2, 4, 6, 8] diff --git a/tests/test_core/nodes/test_function_node_iteration.py 
b/tests/test_core/nodes/test_function_node_iteration.py new file mode 100644 index 00000000..9bef49eb --- /dev/null +++ b/tests/test_core/nodes/test_function_node_iteration.py @@ -0,0 +1,164 @@ +"""Tests for the refactored FunctionNode iteration semantics. + +After ENG-379: +- iter_packets() is strictly read-only — never triggers computation +- Computation only via run() / execute() / async_execute() +""" +from __future__ import annotations + +from unittest.mock import patch + +import pyarrow as pa +import pytest + +from orcapod.core.function_pod import FunctionPod +from orcapod.core.nodes import FunctionNode +from orcapod.core.packet_function import PythonPacketFunction +from orcapod.core.sources import ArrowTableSource +from orcapod.databases import InMemoryArrowDatabase +from orcapod.core.executors import LocalExecutor + + +def _make_source(n: int = 3) -> ArrowTableSource: + table = pa.table( + { + "id": pa.array(list(range(n)), type=pa.int64()), + "x": pa.array(list(range(n)), type=pa.int64()), + }, + schema=pa.schema( + [ + pa.field("id", pa.int64(), nullable=False), + pa.field("x", pa.int64(), nullable=False), + ] + ), + ) + return ArrowTableSource(table, tag_columns=["id"]) + + +def _make_node(n: int = 3, db: InMemoryArrowDatabase | None = None) -> FunctionNode: + def double(x: int) -> int: + return x * 2 + + pf = PythonPacketFunction(double, output_keys="result") + pod = FunctionPod(pf) + pipeline_db = db if db is not None else InMemoryArrowDatabase() + return FunctionNode(pod, _make_source(n=n), pipeline_database=pipeline_db) + + +class TestIterPacketsReadOnly: + def test_fresh_node_no_db_yields_nothing(self): + """iter_packets() on a fresh node with no run() and empty DB yields nothing.""" + node = _make_node() + assert list(node.iter_packets()) == [] + + def test_iter_does_not_call_process_packet_internal(self): + """iter_packets() never calls _process_packet_internal under any non-compute path.""" + node = _make_node() + with patch.object(node, 
"_process_packet_internal") as mock_proc: + list(node.iter_packets()) + mock_proc.assert_not_called() + + def test_iter_after_db_populated_hot_loads_without_compute(self): + """iter_packets() on a node with DB records hot-loads without _process_packet_internal.""" + db = InMemoryArrowDatabase() + node1 = _make_node(n=3, db=db) + node1.run() # populate DB + + node2 = _make_node(n=3, db=db) + with patch.object(node2, "_process_packet_internal") as mock_proc: + results = list(node2.iter_packets()) + mock_proc.assert_not_called() + assert len(results) == 3 + + def test_after_run_iter_yields_from_cache_no_db_query(self): + """After run(), iter_packets() yields from _cached_output_packets without DB query.""" + node = _make_node() + node.run() + initial_count = len(node._cached_output_packets) + assert initial_count == 3 + + with patch.object(node, "_load_cached_entries") as mock_load: + results = list(node.iter_packets()) + mock_load.assert_not_called() + assert len(results) == 3 + + def test_iter_twice_same_order_db_queried_once(self): + """Two successive iter_packets() calls return same order; DB queried at most once.""" + db = InMemoryArrowDatabase() + node1 = _make_node(n=3, db=db) + node1.run() + + node2 = _make_node(n=3, db=db) + with patch.object(node2, "_load_cached_entries", wraps=node2._load_cached_entries) as mock_load: + first = [(t["id"], p["result"]) for t, p in node2.iter_packets()] + second = [(t["id"], p["result"]) for t, p in node2.iter_packets()] + assert mock_load.call_count <= 1 # at most one DB query + assert first == second + + def test_cached_output_packets_keyed_by_entry_id_strings(self): + """After run(), _cached_output_packets keys are entry_id strings, not ints.""" + node = _make_node() + node.run() + assert len(node._cached_output_packets) == 3 + for key in node._cached_output_packets: + assert isinstance(key, str), f"Expected str key, got {type(key)}: {key!r}" + + def test_as_table_fresh_node_returns_empty_no_compute(self): + """as_table() 
on a fresh node with no run() and empty DB returns empty table.""" + node = _make_node() + with patch.object(node, "_process_packet_internal") as mock_proc: + table = node.as_table() + mock_proc.assert_not_called() + assert isinstance(table, pa.Table) + assert len(table) == 0 + + def test_run_cache_only_is_noop(self): + """run() on a CACHE_ONLY node returns without error and without computation.""" + from orcapod.pipeline.serialization import LoadStatus + + node = _make_node() + node._load_status = LoadStatus.CACHE_ONLY + node._input_stream = None # simulate no upstream + + with patch.object(node, "execute") as mock_exec: + node.run() + mock_exec.assert_not_called() + + def test_run_unavailable_raises(self): + """run() on an UNAVAILABLE node raises RuntimeError.""" + from orcapod.pipeline.serialization import LoadStatus + + node = _make_node() + node._load_status = LoadStatus.UNAVAILABLE + with pytest.raises(RuntimeError, match="unavailable"): + node.run() + + def test_execute_error_policy_continue_skips_failures(self): + """execute() fires on_packet_crash per failing packet and returns successes when error_policy='continue'. + + Uses LocalExecutor (non-concurrent) to test the sequential execute() path. 
+ """ + errors = [] + + def sometimes_fail(x: int) -> int: + if x == 1: + raise ValueError("intentional failure") + return x * 2 + + pf = PythonPacketFunction(sometimes_fail, output_keys="result") + pf.executor = LocalExecutor() # sets executor (LocalExecutor.supports_concurrent_execution is False) + pod = FunctionPod(pf) + db = InMemoryArrowDatabase() + node = FunctionNode(pod, _make_source(n=3), pipeline_database=db) + + from orcapod.pipeline.observer import NoOpObserver + + class CapturingObserver(NoOpObserver): + def on_packet_crash(self, node_label, tag, packet, exc): + errors.append(exc) + + results = node.execute(node._input_stream, observer=CapturingObserver(), error_policy="continue") + assert len(errors) == 1 + assert isinstance(errors[0], ValueError) + # Two non-failing packets should succeed + assert len(results) == 2 diff --git a/tests/test_core/packet_function/test_executor.py b/tests/test_core/packet_function/test_executor.py index 2bae8b1b..3f79de40 100644 --- a/tests/test_core/packet_function/test_executor.py +++ b/tests/test_core/packet_function/test_executor.py @@ -488,6 +488,7 @@ def test_node_iter_uses_executor(self): node = FunctionNode(pod, _make_add_stream()) + node.run() results = list(node.iter_packets()) assert len(results) == 2 assert results[0][1].as_dict()["result"] == 3 @@ -632,7 +633,11 @@ def test_function_pod_stream_uses_async_path(self): assert len(spy.async_calls) == 2 assert len(spy.sync_calls) == 0 - def test_function_node_uses_async_path(self): + def test_function_node_uses_sync_path_via_run(self): + """FunctionNode.run() delegates to execute(), which is always sequential + (synchronous). The async path is only used through async_execute() in the + async pipeline orchestrator. 
Even with a ConcurrentSpyExecutor attached, + run() → execute() → sync executor path.""" from orcapod.core.function_pod import FunctionPod from orcapod.core.nodes import FunctionNode @@ -641,13 +646,15 @@ def test_function_node_uses_async_path(self): pod = FunctionPod(pf) node = FunctionNode(pod, _make_add_stream()) + node.run() results = list(node.iter_packets()) assert len(results) == 2 assert results[0][1].as_dict()["result"] == 3 assert results[1][1].as_dict()["result"] == 7 - assert len(spy.async_calls) == 2 - assert len(spy.sync_calls) == 0 + # execute() is always sequential — sync path used, not async + assert len(spy.sync_calls) == 2 + assert len(spy.async_calls) == 0 def test_non_concurrent_executor_uses_sync_path(self): """SpyExecutor has supports_concurrent_execution=False (default).""" @@ -659,6 +666,7 @@ def test_non_concurrent_executor_uses_sync_path(self): pod = FunctionPod(pf) node = FunctionNode(pod, _make_add_stream()) + node.run() results = list(node.iter_packets()) assert len(results) == 2 @@ -673,6 +681,7 @@ def test_no_executor_uses_sync_path(self): pod = FunctionPod(pf) node = FunctionNode(pod, _make_add_stream()) + node.run() results = list(node.iter_packets()) assert len(results) == 2 diff --git a/tests/test_core/sources/test_derived_source.py b/tests/test_core/sources/test_derived_source.py index 42e2486b..fc5a018e 100644 --- a/tests/test_core/sources/test_derived_source.py +++ b/tests/test_core/sources/test_derived_source.py @@ -173,7 +173,8 @@ class TestDerivedSourceRoundTrip: def test_derived_source_matches_node_output(self): """Data from DerivedSource must exactly match data from FunctionNode.""" node = _make_node(n=5) - # Collect from node directly + node.run() + # Collect from node directly (iter_packets is read-only; run() must be called first) node_results = sorted(cast(int, p["result"]) for _, p in node.iter_packets()) # Now get via DerivedSource @@ -232,6 +233,7 @@ def test_derived_source_can_feed_downstream_node(self): ) # node2 
doubles the already-doubled values: 0*2*2=0, 1*2*2=4, 2*2*2=8 + node2.run() results = sorted(cast(int, p["result"]) for _, p in node2.iter_packets()) assert results == [0, 4, 8] diff --git a/tests/test_data/test_polars_nullability/test_function_node_nullability.py b/tests/test_data/test_polars_nullability/test_function_node_nullability.py index ec845c6a..e717ac91 100644 --- a/tests/test_data/test_polars_nullability/test_function_node_nullability.py +++ b/tests/test_data/test_polars_nullability/test_function_node_nullability.py @@ -105,7 +105,8 @@ def triple(x: int) -> int: class TestFunctionNodeIterPacketsNullability: """FunctionNode.iter_packets must yield packets whose underlying Arrow schema - preserves non-nullable column constraints.""" + preserves non-nullable column constraints. Uses _load_cached_entries to + simulate the CACHE_ONLY path used after save/load.""" def test_iter_packets_from_database_preserves_non_nullable_output(self): """Packets loaded from DB via iter_packets carry non-nullable output schema.""" @@ -128,9 +129,10 @@ def add_one(x: int) -> int: fn_nodes = _get_function_nodes(pipeline) fn_node = fn_nodes[0] - # Force a DB-backed iteration by going through _iter_all_from_database + # Force a DB-backed iteration by going through _load_cached_entries # (simulates the CACHE_ONLY path used after save/load) - packets_seen = list(fn_node._iter_all_from_database()) + loaded = fn_node._load_cached_entries() + packets_seen = list(loaded.values()) assert len(packets_seen) == 1, "Expected one packet from the database" _tag, packet = packets_seen[0] diff --git a/tests/test_pipeline/test_serialization.py b/tests/test_pipeline/test_serialization.py index 68c38eb2..ff3f6e73 100644 --- a/tests/test_pipeline/test_serialization.py +++ b/tests/test_pipeline/test_serialization.py @@ -2238,6 +2238,7 @@ def test_definition_save_load_run_roundtrip(self, tmp_path): assert fn_node.load_status == LoadStatus.FULL # Run the loaded pipeline and compare results + 
loaded.run() loaded_results = sorted( p.as_dict()["result"] for _, p in fn_node.iter_packets() )