Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a335ba0
docs: add FunctionNode refactor design spec (ENG-379)
kurodo3[bot] Apr 7, 2026
e213e5d
docs(ENG-379): add implementation plan for FunctionNode refactor
kurodo3[bot] Apr 7, 2026
99b05eb
test(ENG-379): add failing tests for read-only iter_packets semantics
kurodo3[bot] Apr 7, 2026
67b0520
test(ENG-379): fix unused asyncio import and misleading concurrent do…
kurodo3[bot] Apr 7, 2026
3bda290
test(ENG-379): remove stale asyncio.gather mention from module docstring
kurodo3[bot] Apr 7, 2026
3f6bb33
feat(ENG-379): add _load_cached_entries() single DB join helper
kurodo3[bot] Apr 7, 2026
3398301
refactor(ENG-379): extract entry_ids_col before join.drop() for clarity
kurodo3[bot] Apr 7, 2026
da9ce26
refactor(ENG-379): store _cached_output_packets by entry_id, remove c…
kurodo3[bot] Apr 7, 2026
508ed52
refactor(ENG-379): update _cached_output_packets type annotation to s…
kurodo3[bot] Apr 7, 2026
be767ed
refactor(ENG-379): rewrite run() with load_status guard delegating to…
kurodo3[bot] Apr 7, 2026
8472b98
refactor(ENG-379): rewrite iter_packets() as strictly read-only with …
kurodo3[bot] Apr 7, 2026
6dd21dd
refactor(ENG-379): remove _cached_input_iterator/_needs_iterator iter…
kurodo3[bot] Apr 7, 2026
f4c776b
refactor(ENG-379): revert str annotation to str|int interim, remove d…
kurodo3[bot] Apr 7, 2026
4a1095f
refactor(ENG-379): simplify get_cached_results() using _load_cached_e…
kurodo3[bot] Apr 7, 2026
6e57a3f
refactor(ENG-379): simplify _async_execute_cache_only() using _load_c…
kurodo3[bot] Apr 7, 2026
a773c31
refactor(ENG-379): refactor execute() with selective DB reload and co…
kurodo3[bot] Apr 7, 2026
1ff23f6
refactor(ENG-379): simplify async_execute() Phase 1 using _load_cache…
kurodo3[bot] Apr 7, 2026
cf39ddf
refactor(ENG-379): remove dead methods (iter_sequential/concurrent, i…
kurodo3[bot] Apr 7, 2026
bcf61e8
test(ENG-379): update test_function_pod_node_stream.py — add run() be…
kurodo3[bot] Apr 7, 2026
7c5e71f
test(ENG-379): update remaining test files for read-only iter_packets…
kurodo3[bot] Apr 7, 2026
c6dcf8b
fix(ENG-379): as_table() returns empty pa.Table when iter_packets() y…
kurodo3[bot] Apr 7, 2026
b101c49
test(ENG-379): fix test_function_node_sequential_uses_execute_packet …
kurodo3[bot] Apr 7, 2026
7f97abd
fix(ENG-379): fix execute() on_packet_start ordering and test regress…
kurodo3[bot] Apr 7, 2026
065d7fc
test(ENG-379): fix test_caches_computed_results in test-objective — a…
kurodo3[bot] Apr 7, 2026
d67aca1
fix(ENG-379): address Copilot review — None guard, run() READ_ONLY, g…
kurodo3[bot] Apr 7, 2026
289573f
fix(ENG-379): address second Copilot review — cache-hit check, doc co…
kurodo3[bot] Apr 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,630 changes: 1,630 additions & 0 deletions docs/superpowers/plans/2026-04-07-eng379-function-node-refactor.md

Large diffs are not rendered by default.

315 changes: 315 additions & 0 deletions docs/superpowers/specs/2026-04-07-function-node-refactor-design.md

Large diffs are not rendered by default.

595 changes: 191 additions & 404 deletions src/orcapod/core/nodes/function_node.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion test-objective/unit/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def test_caches_computed_results(self):
pipeline_database=pipeline_db,
result_database=result_db,
)
# First iteration computes all
# run() computes all; iter_packets() is read-only and serves from cache
node.run()
packets = list(node.iter_packets())
assert len(packets) == 3

Expand Down
5 changes: 3 additions & 2 deletions tests/test_channels/test_node_async_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,12 +678,13 @@ def test_function_node_sequential_uses_execute_packet(self):
# Monkey-patch to verify routing through internal path
original = node._process_packet_internal

def patched(tag, packet):
def patched(tag, packet, *, logger=None):
call_log.append("_process_packet_internal")
return original(tag, packet)
return original(tag, packet, logger=logger)

node._process_packet_internal = patched

node.run() # computation now triggered via run(), not iter_packets()
results = list(node.iter_packets())
assert len(results) == 3
assert len(call_log) == 3
Expand Down
5 changes: 4 additions & 1 deletion tests/test_core/function_pod/test_function_node_attach_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_construction_without_database(self):

def test_iter_packets_without_database(self):
node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream(n=3))
node.run()
results = list(node.iter_packets())
assert len(results) == 3
assert results[0][1]["result"] == 0
Expand Down Expand Up @@ -76,7 +77,7 @@ def test_attach_databases_creates_cached_function_pod(self):

def test_attach_databases_clears_caches(self):
node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream())
list(node.iter_packets()) # populate cache
node.run() # populate cache
assert len(node._cached_output_packets) > 0
db = InMemoryArrowDatabase()
node.attach_databases(pipeline_database=db, result_database=db)
Expand Down Expand Up @@ -107,6 +108,7 @@ def test_iter_packets_after_attach_works(self):
node = FunctionNode(function_pod=_make_pod(), input_stream=_make_stream(n=2))
db = InMemoryArrowDatabase()
node.attach_databases(pipeline_database=db, result_database=db)
node.run()
results = list(node.iter_packets())
assert len(results) == 2

Expand Down Expand Up @@ -140,5 +142,6 @@ def test_iter_packets_with_database(self):
pipeline_database=db,
result_database=db,
)
node.run()
results = list(node.iter_packets())
assert len(results) == 3
2 changes: 2 additions & 0 deletions tests/test_core/function_pod/test_function_node_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def test_phase2_processes_novel_entry_ids_only(self):
[{"id": 0, "x": 10}, {"id": 1, "x": 20}, {"id": 2, "x": 30}]
)
node2, _ = _make_node(stream2, db=db)
node2.run()
results = list(node2.iter_packets())

# Should yield 3 total: 2 from Phase 1 + 1 from Phase 2
Expand Down Expand Up @@ -279,6 +280,7 @@ def test_same_packet_new_tag_triggers_phase2(self):
assert node1.node_identity_path == node2.node_identity_path
assert node1.node_identity_path[-1].startswith("schema:")

node2.run()
results = list(node2.iter_packets())

# Phase 1 finds no records for node2's content_hash → Phase 2 processes the row
Expand Down
4 changes: 3 additions & 1 deletion tests/test_core/function_pod/test_function_pod_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,11 +302,13 @@ class TestFunctionNodeStreamInterface:
@pytest.fixture
def node(self, double_pf) -> FunctionNode:
db = InMemoryArrowDatabase()
return FunctionNode(
node = FunctionNode(
function_pod=FunctionPod(packet_function=double_pf),
input_stream=make_int_stream(n=3),
pipeline_database=db,
)
node.run()
return node

def test_iter_packets_correct_values(self, node):
assert [packet["result"] for _, packet in node.iter_packets()] == [0, 2, 4]
Expand Down
27 changes: 21 additions & 6 deletions tests/test_core/function_pod/test_function_pod_node_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,13 @@ class TestFunctionNodeStreamBasic:
@pytest.fixture
def node(self, double_pf) -> FunctionNode:
db = InMemoryArrowDatabase()
return FunctionNode(
node = FunctionNode(
function_pod=FunctionPod(packet_function=double_pf),
input_stream=make_int_stream(n=3),
pipeline_database=db,
)
node.run()
return node

def test_iter_packets_yields_correct_count(self, node):
assert len(list(node.iter_packets())) == 3
Expand Down Expand Up @@ -117,6 +119,7 @@ def test_output_schema_has_result_in_packet_schema(self, node):
class TestFunctionNodeColumnConfig:
def test_as_table_content_hash_column(self, double_pf):
node = _make_node(double_pf, n=3)
node.run()
table = node.as_table(columns={"content_hash": True})
assert "_content_hash" in table.column_names
assert len(table.column("_content_hash")) == 3
Expand All @@ -138,6 +141,7 @@ def test_as_table_sort_by_tags(self, double_pf):
input_stream=input_stream,
pipeline_database=db,
)
node.run()
result = node.as_table(columns={"sort_by_tags": True})
ids: list[int] = result.column("id").to_pylist() # type: ignore[assignment]
assert ids == sorted(ids)
Expand Down Expand Up @@ -166,6 +170,7 @@ def test_as_table_returns_cached_results_when_packet_function_inactive(
n = 3
db = InMemoryArrowDatabase()
node1 = _make_node(double_pf, n=n, db=db)
node1.run()
table1 = node1.as_table()
assert len(table1) == n

Expand Down Expand Up @@ -242,8 +247,12 @@ def test_db_served_results_have_correct_values(self, double_pf):
n = 4
db = InMemoryArrowDatabase()

table1 = _make_node(double_pf, n=n, db=db).as_table()
table2 = _make_node(double_pf, n=n, db=db).as_table()
node1 = _make_node(double_pf, n=n, db=db)
node1.run()
table1 = node1.as_table()
node2 = _make_node(double_pf, n=n, db=db)
node2.run()
table2 = node2.as_table()

assert sorted(table1.column("result").to_pylist()) == sorted(
table2.column("result").to_pylist()
Expand Down Expand Up @@ -300,14 +309,18 @@ def test_partial_fill_total_row_count_correct(self, double_pf):
n = 4
db = InMemoryArrowDatabase()
_fill_node(_make_node(double_pf, n=2, db=db))
packets = list(_make_node(double_pf, n=n, db=db).iter_packets())
node = _make_node(double_pf, n=n, db=db)
node.run()
packets = list(node.iter_packets())
assert len(packets) == n

def test_partial_fill_all_values_correct(self, double_pf):
n = 4
db = InMemoryArrowDatabase()
_fill_node(_make_node(double_pf, n=2, db=db))
table = _make_node(double_pf, n=n, db=db).as_table()
node = _make_node(double_pf, n=n, db=db)
node.run()
table = node.as_table()
assert sorted(table.column("result").to_pylist()) == [0, 2, 4, 6]

def test_partial_fill_as_table_after_run_on_same_node(self, double_pf):
Expand Down Expand Up @@ -401,7 +414,7 @@ def test_is_stale_false_after_clear_cache(self, double_pf):

def test_clear_cache_resets_output_packets(self, double_pf):
node = _make_node(double_pf, n=3)
list(node.iter_packets())
node.run()
assert len(node._cached_output_packets) == 3

node.clear_cache()
Expand All @@ -410,6 +423,7 @@ def test_clear_cache_resets_output_packets(self, double_pf):

def test_clear_cache_produces_same_results_on_re_iteration(self, double_pf):
node = _make_node(double_pf, n=3)
node.run()
table_before = node.as_table()

node.clear_cache()
Expand Down Expand Up @@ -451,6 +465,7 @@ def test_as_table_auto_detects_stale_and_repopulates(self, double_pf):
input_stream=input_stream,
pipeline_database=db,
)
node.run()
table_before = node.as_table()
assert len(table_before) == 3

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def test_shared_db_results_are_correct_values(self, double_pf):
input_stream=make_int_stream(n=5),
pipeline_database=db,
)
node2.run()
results = sorted(cast(int, p["result"]) for _, p in node2.iter_packets())
assert results == [0, 2, 4, 6, 8]

Expand Down
164 changes: 164 additions & 0 deletions tests/test_core/nodes/test_function_node_iteration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Tests for the refactored FunctionNode iteration semantics.

After ENG-379:
- iter_packets() is strictly read-only — never triggers computation
- Computation only via run() / execute() / async_execute()
"""
from __future__ import annotations

from unittest.mock import patch

import pyarrow as pa
import pytest

from orcapod.core.function_pod import FunctionPod
from orcapod.core.nodes import FunctionNode
from orcapod.core.packet_function import PythonPacketFunction
from orcapod.core.sources import ArrowTableSource
from orcapod.databases import InMemoryArrowDatabase
from orcapod.core.executors import LocalExecutor


def _make_source(n: int = 3) -> ArrowTableSource:
    """Build an ArrowTableSource with ``n`` rows of (id, x) pairs, tagged by ``id``.

    Both columns are declared non-nullable via an explicit schema: PyArrow only
    exposes nullability on ``pa.field`` (there is no ``pa.array(..., nullable=...)``),
    so the ``schema=`` argument is the standard idiom for enforcing it.
    """
    table = pa.table(
        {
            "id": pa.array(list(range(n)), type=pa.int64()),
            "x": pa.array(list(range(n)), type=pa.int64()),
        },
        schema=pa.schema(
            [
                pa.field("id", pa.int64(), nullable=False),
                pa.field("x", pa.int64(), nullable=False),
            ]
        ),
    )
    # Tag packets by the "id" column so tests can correlate inputs and outputs.
    return ArrowTableSource(table, tag_columns=["id"])


def _make_node(n: int = 3, db: InMemoryArrowDatabase | None = None) -> FunctionNode:
    """Create a FunctionNode that doubles ``x`` over an ``n``-row source.

    When ``db`` is None a fresh in-memory pipeline database is used, so each
    node is isolated unless tests deliberately share one.
    """
    def double(x: int) -> int:
        return x * 2

    if db is None:
        database = InMemoryArrowDatabase()
    else:
        database = db
    packet_fn = PythonPacketFunction(double, output_keys="result")
    function_pod = FunctionPod(packet_fn)
    return FunctionNode(function_pod, _make_source(n=n), pipeline_database=database)


class TestIterPacketsReadOnly:
    """Verify the ENG-379 contract: iter_packets() never computes; run() does."""

    def test_fresh_node_no_db_yields_nothing(self):
        """iter_packets() on a fresh node with no run() and empty DB yields nothing."""
        node = _make_node()
        assert list(node.iter_packets()) == []

    def test_iter_does_not_call_process_packet_internal(self):
        """iter_packets() never calls _process_packet_internal under any non-compute path."""
        node = _make_node()
        with patch.object(node, "_process_packet_internal") as mock_proc:
            list(node.iter_packets())
        mock_proc.assert_not_called()

    def test_iter_after_db_populated_hot_loads_without_compute(self):
        """iter_packets() on a node with DB records hot-loads without _process_packet_internal."""
        db = InMemoryArrowDatabase()
        node1 = _make_node(n=3, db=db)
        node1.run()  # populate DB

        # node2 shares the DB, so its results should come purely from storage.
        node2 = _make_node(n=3, db=db)
        with patch.object(node2, "_process_packet_internal") as mock_proc:
            results = list(node2.iter_packets())
        mock_proc.assert_not_called()
        assert len(results) == 3

    def test_after_run_iter_yields_from_cache_no_db_query(self):
        """After run(), iter_packets() yields from _cached_output_packets without DB query."""
        node = _make_node()
        node.run()
        initial_count = len(node._cached_output_packets)
        assert initial_count == 3

        # Patching _load_cached_entries proves the in-memory cache is used directly.
        with patch.object(node, "_load_cached_entries") as mock_load:
            results = list(node.iter_packets())
        mock_load.assert_not_called()
        assert len(results) == 3

    def test_iter_twice_same_order_db_queried_once(self):
        """Two successive iter_packets() calls return same order; DB queried at most once."""
        db = InMemoryArrowDatabase()
        node1 = _make_node(n=3, db=db)
        node1.run()

        node2 = _make_node(n=3, db=db)
        # wraps= keeps real behavior while counting calls.
        with patch.object(node2, "_load_cached_entries", wraps=node2._load_cached_entries) as mock_load:
            first = [(t["id"], p["result"]) for t, p in node2.iter_packets()]
            second = [(t["id"], p["result"]) for t, p in node2.iter_packets()]
        assert mock_load.call_count <= 1  # at most one DB query
        assert first == second

    def test_cached_output_packets_keyed_by_entry_id_strings(self):
        """After run(), _cached_output_packets keys are entry_id strings, not ints."""
        node = _make_node()
        node.run()
        assert len(node._cached_output_packets) == 3
        for key in node._cached_output_packets:
            assert isinstance(key, str), f"Expected str key, got {type(key)}: {key!r}"

    def test_as_table_fresh_node_returns_empty_no_compute(self):
        """as_table() on a fresh node with no run() and empty DB returns empty table."""
        node = _make_node()
        with patch.object(node, "_process_packet_internal") as mock_proc:
            table = node.as_table()
        mock_proc.assert_not_called()
        assert isinstance(table, pa.Table)
        assert len(table) == 0

    def test_run_cache_only_is_noop(self):
        """run() on a CACHE_ONLY node returns without error and without computation."""
        from orcapod.pipeline.serialization import LoadStatus

        node = _make_node()
        # Force the load-status guard path directly; no public setter is exercised here.
        node._load_status = LoadStatus.CACHE_ONLY
        node._input_stream = None  # simulate no upstream

        with patch.object(node, "execute") as mock_exec:
            node.run()
        mock_exec.assert_not_called()

    def test_run_unavailable_raises(self):
        """run() on an UNAVAILABLE node raises RuntimeError."""
        from orcapod.pipeline.serialization import LoadStatus

        node = _make_node()
        node._load_status = LoadStatus.UNAVAILABLE
        with pytest.raises(RuntimeError, match="unavailable"):
            node.run()

    def test_execute_error_policy_continue_skips_failures(self):
        """execute() fires on_packet_crash per failing packet and returns successes when error_policy='continue'.

        Uses LocalExecutor (non-concurrent) to test the sequential execute() path.
        """
        errors = []

        def sometimes_fail(x: int) -> int:
            if x == 1:
                raise ValueError("intentional failure")
            return x * 2

        pf = PythonPacketFunction(sometimes_fail, output_keys="result")
        pf.executor = LocalExecutor()  # sets executor (LocalExecutor.supports_concurrent_execution is False)
        pod = FunctionPod(pf)
        db = InMemoryArrowDatabase()
        node = FunctionNode(pod, _make_source(n=3), pipeline_database=db)

        from orcapod.pipeline.observer import NoOpObserver

        class CapturingObserver(NoOpObserver):
            # Capture each crash exception so the test can assert on count and type.
            def on_packet_crash(self, node_label, tag, packet, exc):
                errors.append(exc)

        results = node.execute(node._input_stream, observer=CapturingObserver(), error_policy="continue")
        assert len(errors) == 1
        assert isinstance(errors[0], ValueError)
        # Two non-failing packets should succeed
        assert len(results) == 2
Loading
Loading