diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index bef75991b..2f404997b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -25,6 +25,7 @@ jobs: backend_worker: ${{ steps.filter.outputs.backend_worker }} backend_frame: ${{ steps.filter.outputs.backend_frame }} backend_flowfile: ${{ steps.filter.outputs.backend_flowfile }} + kernel: ${{ steps.filter.outputs.kernel }} frontend: ${{ steps.filter.outputs.frontend }} docs: ${{ steps.filter.outputs.docs }} shared: ${{ steps.filter.outputs.shared }} @@ -46,6 +47,11 @@ jobs: - 'flowfile_frame/**' backend_flowfile: - 'flowfile/**' + kernel: + - 'kernel_runtime/**' + - 'flowfile_core/flowfile_core/kernel/**' + - 'flowfile_core/tests/flowfile/test_kernel_integration.py' + - 'flowfile_core/tests/kernel_fixtures.py' frontend: - 'flowfile_frontend/**' docs: @@ -145,7 +151,7 @@ jobs: needs.detect-changes.outputs.shared == 'true' || needs.detect-changes.outputs.test_workflow == 'true' || github.event.inputs.run_all_tests == 'true' - run: poetry run pytest flowfile_core/tests --disable-warnings $COV_ARGS + run: poetry run pytest flowfile_core/tests -m "not kernel" --disable-warnings $COV_ARGS env: COV_ARGS: ${{ (matrix.os == 'ubuntu-latest' && matrix.python-version == '3.12') && '--cov --cov-append --cov-report=' || '' }} @@ -271,7 +277,7 @@ jobs: needs.detect-changes.outputs.test_workflow == 'true' || github.event.inputs.run_all_tests == 'true' shell: pwsh - run: poetry run pytest flowfile_core/tests --disable-warnings + run: poetry run pytest flowfile_core/tests -m "not kernel" --disable-warnings - name: Run pytest for flowfile_worker if: | @@ -299,6 +305,48 @@ jobs: shell: pwsh run: poetry run pytest flowfile/tests --disable-warnings + # Kernel integration tests - runs in parallel on a separate worker + kernel-tests: + needs: detect-changes + if: | + needs.detect-changes.outputs.kernel == 'true' || + needs.detect-changes.outputs.backend_core == 'true' || + 
needs.detect-changes.outputs.shared == 'true' || + needs.detect-changes.outputs.test_workflow == 'true' || + github.event.inputs.run_all_tests == 'true' + runs-on: ubuntu-latest + timeout-minutes: 15 + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + cache: 'pip' + + - name: Install Poetry + run: | + curl -sSL https://install.python-poetry.org | python - + echo "$HOME/.poetry/bin" >> $GITHUB_PATH + + - name: Install Dependencies + run: | + poetry install --no-interaction --no-ansi --with dev + + - name: Build kernel Docker image + run: | + docker build -t flowfile-kernel -f kernel_runtime/Dockerfile kernel_runtime/ + + - name: Run kernel_runtime unit tests + run: | + pip install -e "kernel_runtime/[test]" + python -m pytest kernel_runtime/tests -v --disable-warnings + + - name: Run kernel integration tests + run: | + poetry run pytest flowfile_core/tests -m kernel -v --disable-warnings + # Frontend web build test - runs when frontend changes or test workflow changes test-web: needs: detect-changes @@ -472,7 +520,7 @@ jobs: # Summary job - always runs to provide status test-summary: - needs: [detect-changes, backend-tests, backend-tests-windows, test-web, electron-tests-macos, electron-tests-windows, docs-test] + needs: [detect-changes, backend-tests, backend-tests-windows, kernel-tests, test-web, electron-tests-macos, electron-tests-windows, docs-test] if: always() runs-on: ubuntu-latest steps: @@ -485,6 +533,7 @@ jobs: echo " - Backend Worker: ${{ needs.detect-changes.outputs.backend_worker }}" echo " - Backend Frame: ${{ needs.detect-changes.outputs.backend_frame }}" echo " - Backend Flowfile: ${{ needs.detect-changes.outputs.backend_flowfile }}" + echo " - Kernel: ${{ needs.detect-changes.outputs.kernel }}" echo " - Frontend: ${{ needs.detect-changes.outputs.frontend }}" echo " - Docs: ${{ needs.detect-changes.outputs.docs }}" echo " - Shared/Dependencies: ${{ 
needs.detect-changes.outputs.shared }}" @@ -493,6 +542,7 @@ jobs: echo "Job results:" echo " - Backend Tests: ${{ needs.backend-tests.result }}" echo " - Backend Tests (Windows): ${{ needs.backend-tests-windows.result }}" + echo " - Kernel Tests: ${{ needs.kernel-tests.result }}" echo " - Web Tests: ${{ needs.test-web.result }}" echo " - Electron Tests (macOS): ${{ needs.electron-tests-macos.result }}" echo " - Electron Tests (Windows): ${{ needs.electron-tests-windows.result }}" @@ -501,6 +551,7 @@ jobs: # Fail if any non-skipped job failed if [[ "${{ needs.backend-tests.result }}" == "failure" ]] || \ [[ "${{ needs.backend-tests-windows.result }}" == "failure" ]] || \ + [[ "${{ needs.kernel-tests.result }}" == "failure" ]] || \ [[ "${{ needs.test-web.result }}" == "failure" ]] || \ [[ "${{ needs.electron-tests-macos.result }}" == "failure" ]] || \ [[ "${{ needs.electron-tests-windows.result }}" == "failure" ]] || \ diff --git a/.gitignore b/.gitignore index 6e586b8cd..fab22f65e 100644 --- a/.gitignore +++ b/.gitignore @@ -67,6 +67,9 @@ htmlcov/ # Docker flowfile_data/ +# Egg info +*.egg-info/ + # Secrets and keys - NEVER commit these master_key.txt *.key diff --git a/flowfile_core/flowfile_core/catalog/__init__.py b/flowfile_core/flowfile_core/catalog/__init__.py new file mode 100644 index 000000000..67c9bf8e4 --- /dev/null +++ b/flowfile_core/flowfile_core/catalog/__init__.py @@ -0,0 +1,44 @@ +"""Flow Catalog service layer. 
+ +Public interface: + +* ``CatalogService`` — business-logic orchestrator +* ``CatalogRepository`` — data-access protocol (for type-hints / mocking) +* ``SQLAlchemyCatalogRepository`` — concrete SQLAlchemy implementation +* Domain exceptions (``CatalogError`` hierarchy) +""" + +from .exceptions import ( + CatalogError, + FavoriteNotFoundError, + FlowExistsError, + FlowNotFoundError, + FollowNotFoundError, + NamespaceExistsError, + NamespaceNotEmptyError, + NamespaceNotFoundError, + NestingLimitError, + NoSnapshotError, + NotAuthorizedError, + RunNotFoundError, +) +from .repository import CatalogRepository, SQLAlchemyCatalogRepository +from .service import CatalogService + +__all__ = [ + "CatalogService", + "CatalogRepository", + "SQLAlchemyCatalogRepository", + "CatalogError", + "NamespaceNotFoundError", + "NamespaceExistsError", + "NestingLimitError", + "NamespaceNotEmptyError", + "FlowNotFoundError", + "FlowExistsError", + "RunNotFoundError", + "NotAuthorizedError", + "FavoriteNotFoundError", + "FollowNotFoundError", + "NoSnapshotError", +] diff --git a/flowfile_core/flowfile_core/catalog/exceptions.py b/flowfile_core/flowfile_core/catalog/exceptions.py new file mode 100644 index 000000000..37d06bc53 --- /dev/null +++ b/flowfile_core/flowfile_core/catalog/exceptions.py @@ -0,0 +1,121 @@ +"""Domain-specific exceptions for the Flow Catalog system. + +These exceptions represent business-rule violations and are raised by the +service layer. Route handlers catch them and translate to appropriate +HTTP responses. 
+""" + + +class CatalogError(Exception): + """Base exception for all catalog domain errors.""" + + +class NamespaceNotFoundError(CatalogError): + """Raised when a namespace lookup fails.""" + + def __init__(self, namespace_id: int | None = None, name: str | None = None): + self.namespace_id = namespace_id + self.name = name + detail = "Namespace not found" + if namespace_id is not None: + detail = f"Namespace with id={namespace_id} not found" + elif name is not None: + detail = f"Namespace '{name}' not found" + super().__init__(detail) + + +class NamespaceExistsError(CatalogError): + """Raised when attempting to create a duplicate namespace.""" + + def __init__(self, name: str, parent_id: int | None = None): + self.name = name + self.parent_id = parent_id + super().__init__( + f"Namespace '{name}' already exists" + + (f" under parent_id={parent_id}" if parent_id is not None else " at root level") + ) + + +class NestingLimitError(CatalogError): + """Raised when attempting to nest namespaces deeper than catalog -> schema.""" + + def __init__(self, parent_id: int, parent_level: int): + self.parent_id = parent_id + self.parent_level = parent_level + super().__init__("Cannot nest deeper than catalog -> schema") + + +class NamespaceNotEmptyError(CatalogError): + """Raised when trying to delete a namespace that still has children or flows.""" + + def __init__(self, namespace_id: int, children: int = 0, flows: int = 0): + self.namespace_id = namespace_id + self.children = children + self.flows = flows + super().__init__("Cannot delete namespace with children or flows") + + +class FlowNotFoundError(CatalogError): + """Raised when a flow registration lookup fails.""" + + def __init__(self, registration_id: int | None = None, name: str | None = None): + self.registration_id = registration_id + self.name = name + detail = "Flow not found" + if registration_id is not None: + detail = f"Flow with id={registration_id} not found" + elif name is not None: + detail = f"Flow '{name}' 
not found" + super().__init__(detail) + + +class FlowExistsError(CatalogError): + """Raised when attempting to create a duplicate flow registration.""" + + def __init__(self, name: str, namespace_id: int | None = None): + self.name = name + self.namespace_id = namespace_id + super().__init__(f"Flow '{name}' already exists in namespace_id={namespace_id}") + + +class RunNotFoundError(CatalogError): + """Raised when a flow run lookup fails.""" + + def __init__(self, run_id: int): + self.run_id = run_id + super().__init__(f"Run with id={run_id} not found") + + +class NotAuthorizedError(CatalogError): + """Raised when a user attempts an action they are not permitted to perform.""" + + def __init__(self, user_id: int, action: str = "perform this action"): + self.user_id = user_id + self.action = action + super().__init__(f"User {user_id} is not authorized to {action}") + + +class FavoriteNotFoundError(CatalogError): + """Raised when a favorite record is not found.""" + + def __init__(self, user_id: int, registration_id: int): + self.user_id = user_id + self.registration_id = registration_id + super().__init__(f"Favorite not found for user={user_id}, flow={registration_id}") + + +class FollowNotFoundError(CatalogError): + """Raised when a follow record is not found.""" + + def __init__(self, user_id: int, registration_id: int): + self.user_id = user_id + self.registration_id = registration_id + super().__init__(f"Follow not found for user={user_id}, flow={registration_id}") + + +class NoSnapshotError(CatalogError): + """Raised when a run has no flow snapshot available.""" + + def __init__(self, run_id: int): + self.run_id = run_id + super().__init__(f"No flow snapshot available for run id={run_id}") diff --git a/flowfile_core/flowfile_core/catalog/repository.py b/flowfile_core/flowfile_core/catalog/repository.py new file mode 100644 index 000000000..e81f6f42e --- /dev/null +++ b/flowfile_core/flowfile_core/catalog/repository.py @@ -0,0 +1,504 @@ +"""Data-access 
abstraction for the Flow Catalog system. + +Defines a ``CatalogRepository`` :pep:`544` Protocol and provides a concrete +``SQLAlchemyCatalogRepository`` implementation backed by SQLAlchemy. +""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from sqlalchemy.orm import Session + +from flowfile_core.database.models import ( + CatalogNamespace, + FlowFavorite, + FlowFollow, + FlowRegistration, + FlowRun, +) + + +# --------------------------------------------------------------------------- +# Repository Protocol +# --------------------------------------------------------------------------- + + +@runtime_checkable +class CatalogRepository(Protocol): + """Abstract interface for catalog data access. + + Any class that satisfies this protocol can be used by ``CatalogService``, + enabling easy unit-testing with mock implementations. + """ + + # -- Namespace operations ------------------------------------------------ + + def get_namespace(self, namespace_id: int) -> CatalogNamespace | None: ... + + def get_namespace_by_name( + self, name: str, parent_id: int | None + ) -> CatalogNamespace | None: ... + + def list_namespaces(self, parent_id: int | None = None) -> list[CatalogNamespace]: ... + + def list_root_namespaces(self) -> list[CatalogNamespace]: ... + + def list_child_namespaces(self, parent_id: int) -> list[CatalogNamespace]: ... + + def create_namespace(self, ns: CatalogNamespace) -> CatalogNamespace: ... + + def update_namespace(self, ns: CatalogNamespace) -> CatalogNamespace: ... + + def delete_namespace(self, namespace_id: int) -> None: ... + + def count_children(self, namespace_id: int) -> int: ... + + # -- Flow registration operations ---------------------------------------- + + def get_flow(self, registration_id: int) -> FlowRegistration | None: ... + + def get_flow_by_name( + self, name: str, namespace_id: int + ) -> FlowRegistration | None: ... 
+ + def get_flow_by_path(self, flow_path: str) -> FlowRegistration | None: ... + + def list_flows( + self, + namespace_id: int | None = None, + owner_id: int | None = None, + ) -> list[FlowRegistration]: ... + + def create_flow(self, reg: FlowRegistration) -> FlowRegistration: ... + + def update_flow(self, reg: FlowRegistration) -> FlowRegistration: ... + + def delete_flow(self, registration_id: int) -> None: ... + + def count_flows_in_namespace(self, namespace_id: int) -> int: ... + + # -- Run operations ------------------------------------------------------ + + def get_run(self, run_id: int) -> FlowRun | None: ... + + def list_runs( + self, + registration_id: int | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[FlowRun]: ... + + def create_run(self, run: FlowRun) -> FlowRun: ... + + def update_run(self, run: FlowRun) -> FlowRun: ... + + def count_runs(self) -> int: ... + + # -- Favorites ----------------------------------------------------------- + + def get_favorite( + self, user_id: int, registration_id: int + ) -> FlowFavorite | None: ... + + def add_favorite(self, fav: FlowFavorite) -> FlowFavorite: ... + + def remove_favorite(self, user_id: int, registration_id: int) -> None: ... + + def list_favorites(self, user_id: int) -> list[FlowFavorite]: ... + + def count_favorites(self, user_id: int) -> int: ... + + # -- Follows ------------------------------------------------------------- + + def get_follow( + self, user_id: int, registration_id: int + ) -> FlowFollow | None: ... + + def add_follow(self, follow: FlowFollow) -> FlowFollow: ... + + def remove_follow(self, user_id: int, registration_id: int) -> None: ... + + def list_follows(self, user_id: int) -> list[FlowFollow]: ... + + # -- Aggregate helpers --------------------------------------------------- + + def count_run_for_flow(self, registration_id: int) -> int: ... + + def last_run_for_flow(self, registration_id: int) -> FlowRun | None: ... 
+ + def count_catalog_namespaces(self) -> int: ... + + def count_all_flows(self) -> int: ... + + # -- Bulk enrichment helpers (for N+1 elimination) ----------------------- + + def bulk_get_favorite_flow_ids( + self, user_id: int, flow_ids: list[int] + ) -> set[int]: ... + + def bulk_get_follow_flow_ids( + self, user_id: int, flow_ids: list[int] + ) -> set[int]: ... + + def bulk_get_run_stats( + self, flow_ids: list[int] + ) -> dict[int, tuple[int, FlowRun | None]]: ... + + +# --------------------------------------------------------------------------- +# SQLAlchemy implementation +# --------------------------------------------------------------------------- + + +class SQLAlchemyCatalogRepository: + """Concrete ``CatalogRepository`` backed by a SQLAlchemy ``Session``.""" + + def __init__(self, db: Session) -> None: + self._db = db + + # -- Namespace operations ------------------------------------------------ + + def get_namespace(self, namespace_id: int) -> CatalogNamespace | None: + return self._db.get(CatalogNamespace, namespace_id) + + def get_namespace_by_name( + self, name: str, parent_id: int | None + ) -> CatalogNamespace | None: + return ( + self._db.query(CatalogNamespace) + .filter_by(name=name, parent_id=parent_id) + .first() + ) + + def list_namespaces(self, parent_id: int | None = None) -> list[CatalogNamespace]: + q = self._db.query(CatalogNamespace) + if parent_id is not None: + q = q.filter(CatalogNamespace.parent_id == parent_id) + else: + q = q.filter(CatalogNamespace.parent_id.is_(None)) + return q.order_by(CatalogNamespace.name).all() + + def list_root_namespaces(self) -> list[CatalogNamespace]: + return ( + self._db.query(CatalogNamespace) + .filter(CatalogNamespace.parent_id.is_(None)) + .order_by(CatalogNamespace.name) + .all() + ) + + def list_child_namespaces(self, parent_id: int) -> list[CatalogNamespace]: + return ( + self._db.query(CatalogNamespace) + .filter_by(parent_id=parent_id) + .order_by(CatalogNamespace.name) + .all() + ) + + def 
create_namespace(self, ns: CatalogNamespace) -> CatalogNamespace: + self._db.add(ns) + self._db.commit() + self._db.refresh(ns) + return ns + + def update_namespace(self, ns: CatalogNamespace) -> CatalogNamespace: + self._db.commit() + self._db.refresh(ns) + return ns + + def delete_namespace(self, namespace_id: int) -> None: + ns = self._db.get(CatalogNamespace, namespace_id) + if ns is not None: + self._db.delete(ns) + self._db.commit() + + def count_children(self, namespace_id: int) -> int: + return ( + self._db.query(CatalogNamespace) + .filter_by(parent_id=namespace_id) + .count() + ) + + # -- Flow registration operations ---------------------------------------- + + def get_flow(self, registration_id: int) -> FlowRegistration | None: + return self._db.get(FlowRegistration, registration_id) + + def get_flow_by_name( + self, name: str, namespace_id: int + ) -> FlowRegistration | None: + return ( + self._db.query(FlowRegistration) + .filter_by(name=name, namespace_id=namespace_id) + .first() + ) + + def get_flow_by_path(self, flow_path: str) -> FlowRegistration | None: + return ( + self._db.query(FlowRegistration) + .filter_by(flow_path=flow_path) + .first() + ) + + def list_flows( + self, + namespace_id: int | None = None, + owner_id: int | None = None, + ) -> list[FlowRegistration]: + q = self._db.query(FlowRegistration) + if namespace_id is not None: + q = q.filter_by(namespace_id=namespace_id) + if owner_id is not None: + q = q.filter_by(owner_id=owner_id) + return q.order_by(FlowRegistration.name).all() + + def create_flow(self, reg: FlowRegistration) -> FlowRegistration: + self._db.add(reg) + self._db.commit() + self._db.refresh(reg) + return reg + + def update_flow(self, reg: FlowRegistration) -> FlowRegistration: + self._db.commit() + self._db.refresh(reg) + return reg + + def delete_flow(self, registration_id: int) -> None: + # Clean up related records first + self._db.query(FlowFavorite).filter_by(registration_id=registration_id).delete() + 
self._db.query(FlowFollow).filter_by(registration_id=registration_id).delete() + flow = self._db.get(FlowRegistration, registration_id) + if flow is not None: + self._db.delete(flow) + self._db.commit() + + def count_flows_in_namespace(self, namespace_id: int) -> int: + return ( + self._db.query(FlowRegistration) + .filter_by(namespace_id=namespace_id) + .count() + ) + + # -- Run operations ------------------------------------------------------ + + def get_run(self, run_id: int) -> FlowRun | None: + return self._db.get(FlowRun, run_id) + + def list_runs( + self, + registration_id: int | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[FlowRun]: + q = self._db.query(FlowRun) + if registration_id is not None: + q = q.filter_by(registration_id=registration_id) + return ( + q.order_by(FlowRun.started_at.desc()) + .offset(offset) + .limit(limit) + .all() + ) + + def create_run(self, run: FlowRun) -> FlowRun: + self._db.add(run) + self._db.commit() + self._db.refresh(run) + return run + + def update_run(self, run: FlowRun) -> FlowRun: + self._db.commit() + self._db.refresh(run) + return run + + def count_runs(self) -> int: + return self._db.query(FlowRun).count() + + # -- Favorites ----------------------------------------------------------- + + def get_favorite( + self, user_id: int, registration_id: int + ) -> FlowFavorite | None: + return ( + self._db.query(FlowFavorite) + .filter_by(user_id=user_id, registration_id=registration_id) + .first() + ) + + def add_favorite(self, fav: FlowFavorite) -> FlowFavorite: + self._db.add(fav) + self._db.commit() + self._db.refresh(fav) + return fav + + def remove_favorite(self, user_id: int, registration_id: int) -> None: + fav = ( + self._db.query(FlowFavorite) + .filter_by(user_id=user_id, registration_id=registration_id) + .first() + ) + if fav is not None: + self._db.delete(fav) + self._db.commit() + + def list_favorites(self, user_id: int) -> list[FlowFavorite]: + return ( + self._db.query(FlowFavorite) + 
.filter_by(user_id=user_id) + .order_by(FlowFavorite.created_at.desc()) + .all() + ) + + def count_favorites(self, user_id: int) -> int: + return ( + self._db.query(FlowFavorite) + .filter_by(user_id=user_id) + .count() + ) + + # -- Follows ------------------------------------------------------------- + + def get_follow( + self, user_id: int, registration_id: int + ) -> FlowFollow | None: + return ( + self._db.query(FlowFollow) + .filter_by(user_id=user_id, registration_id=registration_id) + .first() + ) + + def add_follow(self, follow: FlowFollow) -> FlowFollow: + self._db.add(follow) + self._db.commit() + self._db.refresh(follow) + return follow + + def remove_follow(self, user_id: int, registration_id: int) -> None: + follow = ( + self._db.query(FlowFollow) + .filter_by(user_id=user_id, registration_id=registration_id) + .first() + ) + if follow is not None: + self._db.delete(follow) + self._db.commit() + + def list_follows(self, user_id: int) -> list[FlowFollow]: + return ( + self._db.query(FlowFollow) + .filter_by(user_id=user_id) + .order_by(FlowFollow.created_at.desc()) + .all() + ) + + # -- Aggregate helpers --------------------------------------------------- + + def count_run_for_flow(self, registration_id: int) -> int: + return ( + self._db.query(FlowRun) + .filter_by(registration_id=registration_id) + .count() + ) + + def last_run_for_flow(self, registration_id: int) -> FlowRun | None: + return ( + self._db.query(FlowRun) + .filter_by(registration_id=registration_id) + .order_by(FlowRun.started_at.desc()) + .first() + ) + + def count_catalog_namespaces(self) -> int: + return ( + self._db.query(CatalogNamespace) + .filter_by(level=0) + .count() + ) + + def count_all_flows(self) -> int: + return self._db.query(FlowRegistration).count() + + # -- Bulk enrichment helpers (for N+1 elimination) ----------------------- + + def bulk_get_favorite_flow_ids( + self, user_id: int, flow_ids: list[int] + ) -> set[int]: + """Return the subset of flow_ids that the user 
has favourited.""" + if not flow_ids: + return set() + rows = ( + self._db.query(FlowFavorite.registration_id) + .filter( + FlowFavorite.user_id == user_id, + FlowFavorite.registration_id.in_(flow_ids), + ) + .all() + ) + return {r[0] for r in rows} + + def bulk_get_follow_flow_ids( + self, user_id: int, flow_ids: list[int] + ) -> set[int]: + """Return the subset of flow_ids that the user is following.""" + if not flow_ids: + return set() + rows = ( + self._db.query(FlowFollow.registration_id) + .filter( + FlowFollow.user_id == user_id, + FlowFollow.registration_id.in_(flow_ids), + ) + .all() + ) + return {r[0] for r in rows} + + def bulk_get_run_stats( + self, flow_ids: list[int] + ) -> dict[int, tuple[int, FlowRun | None]]: + """Return run_count and last_run for each flow_id in one query batch. + + Returns a dict: flow_id -> (run_count, last_run_or_none) + """ + if not flow_ids: + return {} + + from sqlalchemy import func + + # Query 1: counts per registration_id + count_rows = ( + self._db.query( + FlowRun.registration_id, + func.count(FlowRun.id).label("cnt"), + ) + .filter(FlowRun.registration_id.in_(flow_ids)) + .group_by(FlowRun.registration_id) + .all() + ) + counts = {r[0]: r[1] for r in count_rows} + + # Query 2: last run per registration_id using a subquery for max started_at + subq = ( + self._db.query( + FlowRun.registration_id, + func.max(FlowRun.started_at).label("max_started"), + ) + .filter(FlowRun.registration_id.in_(flow_ids)) + .group_by(FlowRun.registration_id) + .subquery() + ) + last_runs_rows = ( + self._db.query(FlowRun) + .join( + subq, + (FlowRun.registration_id == subq.c.registration_id) + & (FlowRun.started_at == subq.c.max_started), + ) + .all() + ) + last_runs = {r.registration_id: r for r in last_runs_rows} + + # Build result dict + result: dict[int, tuple[int, FlowRun | None]] = {} + for fid in flow_ids: + result[fid] = (counts.get(fid, 0), last_runs.get(fid)) + return result diff --git 
a/flowfile_core/flowfile_core/catalog/service.py b/flowfile_core/flowfile_core/catalog/service.py new file mode 100644 index 000000000..1cead5bed --- /dev/null +++ b/flowfile_core/flowfile_core/catalog/service.py @@ -0,0 +1,672 @@ +"""Business-logic layer for the Flow Catalog system. + +``CatalogService`` encapsulates all domain rules (validation, authorisation, +enrichment) and delegates persistence to a ``CatalogRepository``. It never +raises ``HTTPException`` — only domain-specific exceptions from +``catalog.exceptions``. +""" + +from __future__ import annotations + +import os +from datetime import datetime, timezone + +from flowfile_core.catalog.exceptions import ( + FavoriteNotFoundError, + FlowNotFoundError, + FollowNotFoundError, + NamespaceExistsError, + NamespaceNotEmptyError, + NamespaceNotFoundError, + NestingLimitError, + NoSnapshotError, + RunNotFoundError, +) +from flowfile_core.catalog.repository import CatalogRepository +from flowfile_core.database.models import ( + CatalogNamespace, + FlowFavorite, + FlowFollow, + FlowRegistration, + FlowRun, +) +from flowfile_core.schemas.catalog_schema import ( + CatalogStats, + FlowRegistrationOut, + FlowRunDetail, + FlowRunOut, + NamespaceTree, +) + + +class CatalogService: + """Coordinates all catalog business logic. + + Parameters + ---------- + repo: + Any object satisfying the ``CatalogRepository`` protocol. + """ + + def __init__(self, repo: CatalogRepository) -> None: + self.repo = repo + + # ------------------------------------------------------------------ # + # Private helpers + # ------------------------------------------------------------------ # + + def _enrich_flow_registration( + self, flow: FlowRegistration, user_id: int + ) -> FlowRegistrationOut: + """Attach favourite/follow flags and run stats to a single registration. + + Note: For bulk operations, prefer ``_bulk_enrich_flows`` to avoid N+1 queries. 
+ """ + is_fav = self.repo.get_favorite(user_id, flow.id) is not None + is_follow = self.repo.get_follow(user_id, flow.id) is not None + run_count = self.repo.count_run_for_flow(flow.id) + last_run = self.repo.last_run_for_flow(flow.id) + return FlowRegistrationOut( + id=flow.id, + name=flow.name, + description=flow.description, + flow_path=flow.flow_path, + namespace_id=flow.namespace_id, + owner_id=flow.owner_id, + created_at=flow.created_at, + updated_at=flow.updated_at, + is_favorite=is_fav, + is_following=is_follow, + run_count=run_count, + last_run_at=last_run.started_at if last_run else None, + last_run_success=last_run.success if last_run else None, + file_exists=os.path.exists(flow.flow_path) if flow.flow_path else False, + ) + + def _bulk_enrich_flows( + self, flows: list[FlowRegistration], user_id: int + ) -> list[FlowRegistrationOut]: + """Enrich multiple flows with favourites, follows, and run stats in bulk. + + Uses 3 queries total instead of 4×N, dramatically improving performance + when listing many flows. 
+ """ + if not flows: + return [] + + flow_ids = [f.id for f in flows] + + # Bulk fetch all enrichment data (3 queries total) + fav_ids = self.repo.bulk_get_favorite_flow_ids(user_id, flow_ids) + follow_ids = self.repo.bulk_get_follow_flow_ids(user_id, flow_ids) + run_stats = self.repo.bulk_get_run_stats(flow_ids) + + result: list[FlowRegistrationOut] = [] + for flow in flows: + run_count, last_run = run_stats.get(flow.id, (0, None)) + result.append( + FlowRegistrationOut( + id=flow.id, + name=flow.name, + description=flow.description, + flow_path=flow.flow_path, + namespace_id=flow.namespace_id, + owner_id=flow.owner_id, + created_at=flow.created_at, + updated_at=flow.updated_at, + is_favorite=flow.id in fav_ids, + is_following=flow.id in follow_ids, + run_count=run_count, + last_run_at=last_run.started_at if last_run else None, + last_run_success=last_run.success if last_run else None, + file_exists=os.path.exists(flow.flow_path) if flow.flow_path else False, + ) + ) + return result + + @staticmethod + def _run_to_out(run: FlowRun) -> FlowRunOut: + return FlowRunOut( + id=run.id, + registration_id=run.registration_id, + flow_name=run.flow_name, + flow_path=run.flow_path, + user_id=run.user_id, + started_at=run.started_at, + ended_at=run.ended_at, + success=run.success, + nodes_completed=run.nodes_completed, + number_of_nodes=run.number_of_nodes, + duration_seconds=run.duration_seconds, + run_type=run.run_type, + has_snapshot=run.flow_snapshot is not None, + ) + + # ------------------------------------------------------------------ # + # Namespace operations + # ------------------------------------------------------------------ # + + def create_namespace( + self, + name: str, + owner_id: int, + parent_id: int | None = None, + description: str | None = None, + ) -> CatalogNamespace: + """Create a catalog (level 0) or schema (level 1) namespace. + + Raises + ------ + NamespaceNotFoundError + If ``parent_id`` is given but doesn't exist. 
+ NestingLimitError + If the parent is already at level 1 (schema). + NamespaceExistsError + If a namespace with the same name already exists under the parent. + """ + level = 0 + if parent_id is not None: + parent = self.repo.get_namespace(parent_id) + if parent is None: + raise NamespaceNotFoundError(namespace_id=parent_id) + if parent.level >= 1: + raise NestingLimitError(parent_id=parent_id, parent_level=parent.level) + level = parent.level + 1 + + existing = self.repo.get_namespace_by_name(name, parent_id) + if existing is not None: + raise NamespaceExistsError(name=name, parent_id=parent_id) + + ns = CatalogNamespace( + name=name, + parent_id=parent_id, + level=level, + description=description, + owner_id=owner_id, + ) + return self.repo.create_namespace(ns) + + def update_namespace( + self, + namespace_id: int, + name: str | None = None, + description: str | None = None, + ) -> CatalogNamespace: + """Update a namespace's name and/or description. + + Raises + ------ + NamespaceNotFoundError + If the namespace doesn't exist. + """ + ns = self.repo.get_namespace(namespace_id) + if ns is None: + raise NamespaceNotFoundError(namespace_id=namespace_id) + if name is not None: + ns.name = name + if description is not None: + ns.description = description + return self.repo.update_namespace(ns) + + def delete_namespace(self, namespace_id: int) -> None: + """Delete a namespace if it has no children or flows. + + Raises + ------ + NamespaceNotFoundError + If the namespace doesn't exist. + NamespaceNotEmptyError + If the namespace has child namespaces or flow registrations. 
+ """ + ns = self.repo.get_namespace(namespace_id) + if ns is None: + raise NamespaceNotFoundError(namespace_id=namespace_id) + children = self.repo.count_children(namespace_id) + flows = self.repo.count_flows_in_namespace(namespace_id) + if children > 0 or flows > 0: + raise NamespaceNotEmptyError( + namespace_id=namespace_id, children=children, flows=flows + ) + self.repo.delete_namespace(namespace_id) + + def get_namespace(self, namespace_id: int) -> CatalogNamespace: + """Retrieve a single namespace by ID. + + Raises + ------ + NamespaceNotFoundError + If the namespace doesn't exist. + """ + ns = self.repo.get_namespace(namespace_id) + if ns is None: + raise NamespaceNotFoundError(namespace_id=namespace_id) + return ns + + def list_namespaces(self, parent_id: int | None = None) -> list[CatalogNamespace]: + """List namespaces, optionally filtered by parent.""" + return self.repo.list_namespaces(parent_id) + + def get_namespace_tree(self, user_id: int) -> list[NamespaceTree]: + """Build the full catalog tree with flows nested under schemas. + + Uses bulk enrichment to avoid N+1 queries when there are many flows. 
+ """ + catalogs = self.repo.list_root_namespaces() + + # Collect all flows first, then bulk-enrich them + all_flows: list[FlowRegistration] = [] + namespace_flow_map: dict[int, list[FlowRegistration]] = {} + + for cat in catalogs: + cat_flows = self.repo.list_flows(namespace_id=cat.id) + namespace_flow_map[cat.id] = cat_flows + all_flows.extend(cat_flows) + + for schema in self.repo.list_child_namespaces(cat.id): + schema_flows = self.repo.list_flows(namespace_id=schema.id) + namespace_flow_map[schema.id] = schema_flows + all_flows.extend(schema_flows) + + # Bulk enrich all flows at once + enriched = self._bulk_enrich_flows(all_flows, user_id) + enriched_map = {e.id: e for e in enriched} + + # Build tree structure + result: list[NamespaceTree] = [] + for cat in catalogs: + schemas = self.repo.list_child_namespaces(cat.id) + children: list[NamespaceTree] = [] + for schema in schemas: + schema_flows = namespace_flow_map.get(schema.id, []) + flow_outs = [enriched_map[f.id] for f in schema_flows if f.id in enriched_map] + children.append( + NamespaceTree( + id=schema.id, + name=schema.name, + parent_id=schema.parent_id, + level=schema.level, + description=schema.description, + owner_id=schema.owner_id, + created_at=schema.created_at, + updated_at=schema.updated_at, + children=[], + flows=flow_outs, + ) + ) + cat_flows = namespace_flow_map.get(cat.id, []) + root_flow_outs = [enriched_map[f.id] for f in cat_flows if f.id in enriched_map] + result.append( + NamespaceTree( + id=cat.id, + name=cat.name, + parent_id=cat.parent_id, + level=cat.level, + description=cat.description, + owner_id=cat.owner_id, + created_at=cat.created_at, + updated_at=cat.updated_at, + children=children, + flows=root_flow_outs, + ) + ) + return result + + def get_default_namespace_id(self) -> int | None: + """Return the ID of the default 'user_flows' schema under 'General'.""" + general = self.repo.get_namespace_by_name("General", parent_id=None) + if general is None: + return None + user_flows = 
self.repo.get_namespace_by_name("user_flows", parent_id=general.id) + if user_flows is None: + return None + return user_flows.id + + # ------------------------------------------------------------------ # + # Flow registration operations + # ------------------------------------------------------------------ # + + def register_flow( + self, + name: str, + flow_path: str, + owner_id: int, + namespace_id: int | None = None, + description: str | None = None, + ) -> FlowRegistrationOut: + """Register a new flow in the catalog. + + Raises + ------ + NamespaceNotFoundError + If ``namespace_id`` is given but doesn't exist. + """ + if namespace_id is not None: + ns = self.repo.get_namespace(namespace_id) + if ns is None: + raise NamespaceNotFoundError(namespace_id=namespace_id) + flow = FlowRegistration( + name=name, + description=description, + flow_path=flow_path, + namespace_id=namespace_id, + owner_id=owner_id, + ) + flow = self.repo.create_flow(flow) + return self._enrich_flow_registration(flow, owner_id) + + def update_flow( + self, + registration_id: int, + requesting_user_id: int, + name: str | None = None, + description: str | None = None, + namespace_id: int | None = None, + ) -> FlowRegistrationOut: + """Update a flow registration. + + Raises + ------ + FlowNotFoundError + If the flow doesn't exist. + """ + flow = self.repo.get_flow(registration_id) + if flow is None: + raise FlowNotFoundError(registration_id=registration_id) + if name is not None: + flow.name = name + if description is not None: + flow.description = description + if namespace_id is not None: + flow.namespace_id = namespace_id + flow = self.repo.update_flow(flow) + return self._enrich_flow_registration(flow, requesting_user_id) + + def delete_flow(self, registration_id: int) -> None: + """Delete a flow and its related favourites/follows. + + Raises + ------ + FlowNotFoundError + If the flow doesn't exist. 
+ """ + flow = self.repo.get_flow(registration_id) + if flow is None: + raise FlowNotFoundError(registration_id=registration_id) + self.repo.delete_flow(registration_id) + + def get_flow(self, registration_id: int, user_id: int) -> FlowRegistrationOut: + """Get an enriched flow registration. + + Raises + ------ + FlowNotFoundError + If the flow doesn't exist. + """ + flow = self.repo.get_flow(registration_id) + if flow is None: + raise FlowNotFoundError(registration_id=registration_id) + return self._enrich_flow_registration(flow, user_id) + + def list_flows( + self, user_id: int, namespace_id: int | None = None + ) -> list[FlowRegistrationOut]: + """List flows, optionally filtered by namespace, enriched with user context. + + Uses bulk enrichment to avoid N+1 queries. + """ + flows = self.repo.list_flows(namespace_id=namespace_id) + return self._bulk_enrich_flows(flows, user_id) + + # ------------------------------------------------------------------ # + # Run operations + # ------------------------------------------------------------------ # + + def list_runs( + self, + registration_id: int | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[FlowRunOut]: + """List run summaries (without snapshots).""" + runs = self.repo.list_runs( + registration_id=registration_id, limit=limit, offset=offset + ) + return [self._run_to_out(r) for r in runs] + + def get_run_detail(self, run_id: int) -> FlowRunDetail: + """Get a single run including the YAML snapshot. + + Raises + ------ + RunNotFoundError + If the run doesn't exist. 
+ """ + run = self.repo.get_run(run_id) + if run is None: + raise RunNotFoundError(run_id=run_id) + return FlowRunDetail( + id=run.id, + registration_id=run.registration_id, + flow_name=run.flow_name, + flow_path=run.flow_path, + user_id=run.user_id, + started_at=run.started_at, + ended_at=run.ended_at, + success=run.success, + nodes_completed=run.nodes_completed, + number_of_nodes=run.number_of_nodes, + duration_seconds=run.duration_seconds, + run_type=run.run_type, + has_snapshot=run.flow_snapshot is not None, + flow_snapshot=run.flow_snapshot, + node_results_json=run.node_results_json, + ) + + def get_run(self, run_id: int) -> FlowRun: + """Get a raw FlowRun model. + + Raises + ------ + RunNotFoundError + If the run doesn't exist. + """ + run = self.repo.get_run(run_id) + if run is None: + raise RunNotFoundError(run_id=run_id) + return run + + def start_run( + self, + registration_id: int | None, + flow_name: str, + flow_path: str | None, + user_id: int, + number_of_nodes: int, + run_type: str = "full_run", + flow_snapshot: str | None = None, + ) -> FlowRun: + """Record a new flow run start.""" + run = FlowRun( + registration_id=registration_id, + flow_name=flow_name, + flow_path=flow_path, + user_id=user_id, + started_at=datetime.now(timezone.utc), + number_of_nodes=number_of_nodes, + run_type=run_type, + flow_snapshot=flow_snapshot, + ) + return self.repo.create_run(run) + + def complete_run( + self, + run_id: int, + success: bool, + nodes_completed: int, + node_results_json: str | None = None, + ) -> FlowRun: + """Mark a run as completed. + + Raises + ------ + RunNotFoundError + If the run doesn't exist. 
+ """ + run = self.repo.get_run(run_id) + if run is None: + raise RunNotFoundError(run_id=run_id) + now = datetime.now(timezone.utc) + run.ended_at = now + run.success = success + run.nodes_completed = nodes_completed + if run.started_at: + run.duration_seconds = (now - run.started_at).total_seconds() + if node_results_json is not None: + run.node_results_json = node_results_json + return self.repo.update_run(run) + + def get_run_snapshot(self, run_id: int) -> str: + """Return the flow snapshot text for a run. + + Raises + ------ + RunNotFoundError + If the run doesn't exist. + NoSnapshotError + If the run has no snapshot. + """ + run = self.repo.get_run(run_id) + if run is None: + raise RunNotFoundError(run_id=run_id) + if not run.flow_snapshot: + raise NoSnapshotError(run_id=run_id) + return run.flow_snapshot + + # ------------------------------------------------------------------ # + # Favorites + # ------------------------------------------------------------------ # + + def add_favorite(self, user_id: int, registration_id: int) -> FlowFavorite: + """Add a flow to user's favourites (idempotent). + + Raises + ------ + FlowNotFoundError + If the flow doesn't exist. + """ + flow = self.repo.get_flow(registration_id) + if flow is None: + raise FlowNotFoundError(registration_id=registration_id) + existing = self.repo.get_favorite(user_id, registration_id) + if existing is not None: + return existing + fav = FlowFavorite(user_id=user_id, registration_id=registration_id) + return self.repo.add_favorite(fav) + + def remove_favorite(self, user_id: int, registration_id: int) -> None: + """Remove a flow from user's favourites. + + Raises + ------ + FavoriteNotFoundError + If the favourite doesn't exist. 
+ """ + existing = self.repo.get_favorite(user_id, registration_id) + if existing is None: + raise FavoriteNotFoundError(user_id=user_id, registration_id=registration_id) + self.repo.remove_favorite(user_id, registration_id) + + def list_favorites(self, user_id: int) -> list[FlowRegistrationOut]: + """List all flows the user has favourited, enriched. + + Uses bulk enrichment to avoid N+1 queries. + """ + favs = self.repo.list_favorites(user_id) + flows: list[FlowRegistration] = [] + for fav in favs: + flow = self.repo.get_flow(fav.registration_id) + if flow is not None: + flows.append(flow) + return self._bulk_enrich_flows(flows, user_id) + + # ------------------------------------------------------------------ # + # Follows + # ------------------------------------------------------------------ # + + def add_follow(self, user_id: int, registration_id: int) -> FlowFollow: + """Follow a flow (idempotent). + + Raises + ------ + FlowNotFoundError + If the flow doesn't exist. + """ + flow = self.repo.get_flow(registration_id) + if flow is None: + raise FlowNotFoundError(registration_id=registration_id) + existing = self.repo.get_follow(user_id, registration_id) + if existing is not None: + return existing + follow = FlowFollow(user_id=user_id, registration_id=registration_id) + return self.repo.add_follow(follow) + + def remove_follow(self, user_id: int, registration_id: int) -> None: + """Unfollow a flow. + + Raises + ------ + FollowNotFoundError + If the follow record doesn't exist. + """ + existing = self.repo.get_follow(user_id, registration_id) + if existing is None: + raise FollowNotFoundError(user_id=user_id, registration_id=registration_id) + self.repo.remove_follow(user_id, registration_id) + + def list_following(self, user_id: int) -> list[FlowRegistrationOut]: + """List all flows the user is following, enriched. + + Uses bulk enrichment to avoid N+1 queries. 
+ """ + follows = self.repo.list_follows(user_id) + flows: list[FlowRegistration] = [] + for follow in follows: + flow = self.repo.get_flow(follow.registration_id) + if flow is not None: + flows.append(flow) + return self._bulk_enrich_flows(flows, user_id) + + # ------------------------------------------------------------------ # + # Dashboard / Stats + # ------------------------------------------------------------------ # + + def get_catalog_stats(self, user_id: int) -> CatalogStats: + """Return an overview of the catalog for the dashboard. + + Uses bulk enrichment for favourite flows to avoid N+1 queries. + """ + total_ns = self.repo.count_catalog_namespaces() + total_flows = self.repo.count_all_flows() + total_runs = self.repo.count_runs() + total_favs = self.repo.count_favorites(user_id) + + recent_runs = self.repo.list_runs(limit=10, offset=0) + recent_out = [self._run_to_out(r) for r in recent_runs] + + # Bulk enrich favourite flows + favs = self.repo.list_favorites(user_id) + flows: list[FlowRegistration] = [] + for fav in favs: + flow = self.repo.get_flow(fav.registration_id) + if flow is not None: + flows.append(flow) + fav_flows = self._bulk_enrich_flows(flows, user_id) + + return CatalogStats( + total_namespaces=total_ns, + total_flows=total_flows, + total_runs=total_runs, + total_favorites=total_favs, + recent_runs=recent_out, + favorite_flows=fav_flows, + ) diff --git a/flowfile_core/flowfile_core/configs/node_store/nodes.py b/flowfile_core/flowfile_core/configs/node_store/nodes.py index ef6e7840d..6ccfe6ac5 100644 --- a/flowfile_core/flowfile_core/configs/node_store/nodes.py +++ b/flowfile_core/flowfile_core/configs/node_store/nodes.py @@ -286,6 +286,20 @@ def get_all_standard_nodes() -> tuple[list[NodeTemplate], dict[str, NodeTemplate drawer_title="Polars Code", drawer_intro="Write custom Polars DataFrame transformations", ), + NodeTemplate( + name="Python Script", + item="python_script", + input=10, + output=1, + transform_type="narrow", + 
image="python_code.svg", + node_group="transform", + multi=True, + can_be_start=True, + node_type="process", + drawer_title="Python Script", + drawer_intro="Execute Python code on an isolated kernel container", + ), NodeTemplate( name="Read from Database", item="database_reader", diff --git a/flowfile_core/flowfile_core/database/models.py b/flowfile_core/flowfile_core/database/models.py index e7cb7438a..ff43168ab 100644 --- a/flowfile_core/flowfile_core/database/models.py +++ b/flowfile_core/flowfile_core/database/models.py @@ -1,3 +1,4 @@ + from sqlalchemy import Boolean, Column, DateTime, Float, ForeignKey, Integer, String, Text, UniqueConstraint from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import func @@ -90,6 +91,17 @@ class CloudStoragePermission(Base): can_list = Column(Boolean, default=True) +class Kernel(Base): + __tablename__ = "kernels" + + id = Column(String, primary_key=True, index=True) + name = Column(String, nullable=False) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) + packages = Column(Text, default="[]") # JSON-serialized list of package names + cpu_cores = Column(Float, default=2.0) + memory_gb = Column(Float, default=4.0) + gpu = Column(Boolean, default=False) + created_at = Column(DateTime, default=func.now(), nullable=False) # ==================== Flow Catalog Models ==================== diff --git a/flowfile_core/flowfile_core/flowfile/artifacts.py b/flowfile_core/flowfile_core/flowfile/artifacts.py new file mode 100644 index 000000000..2b0e81862 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/artifacts.py @@ -0,0 +1,485 @@ +"""Artifact context tracking for the FlowGraph. + +This module provides metadata tracking for Python artifacts that are +published and consumed by ``python_script`` nodes running on kernel +containers. The actual objects remain in kernel memory; this module +only tracks *references* (name, source node, type info, etc.) 
so the +FlowGraph can reason about artifact availability across the DAG. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ArtifactRef: + """Metadata reference to an artifact (not the object itself).""" + + name: str + source_node_id: int + kernel_id: str = "" + type_name: str = "" + module: str = "" + size_bytes: int = 0 + created_at: datetime = field(default_factory=datetime.now) + + def to_dict(self) -> dict[str, Any]: + return { + "name": self.name, + "source_node_id": self.source_node_id, + "kernel_id": self.kernel_id, + "type_name": self.type_name, + "module": self.module, + "size_bytes": self.size_bytes, + "created_at": self.created_at.isoformat(), + } + + +@dataclass +class NodeArtifactState: + """Artifact state for a single node.""" + + published: list[ArtifactRef] = field(default_factory=list) + available: dict[str, ArtifactRef] = field(default_factory=dict) + consumed: list[str] = field(default_factory=list) + deleted: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return { + "published": [r.to_dict() for r in self.published], + "available": {k: v.to_dict() for k, v in self.available.items()}, + "consumed": list(self.consumed), + "deleted": list(self.deleted), + } + + +class ArtifactContext: + """Tracks artifact availability across the flow graph. + + This is a metadata-only tracker. Actual Python objects stay inside + the kernel container's ``ArtifactStore``. + """ + + def __init__(self) -> None: + self._node_states: dict[int, NodeArtifactState] = {} + self._kernel_artifacts: dict[str, dict[str, ArtifactRef]] = {} + # Reverse index: (kernel_id, artifact_name) → set of node_ids that + # published it. Avoids O(N) scan in record_deleted / clear_kernel. 
+ self._publisher_index: dict[tuple[str, str], set[int]] = {} + # Tracks which nodes produced the artifacts that were deleted by each + # node. Used during re-execution to force producers to re-run when + # a consumer that deleted their artifacts needs to re-execute. + # Maps: deleter_node_id → [(kernel_id, artifact_name, publisher_node_id), …] + self._deletion_origins: dict[int, list[tuple[str, str, int]]] = {} + + # ------------------------------------------------------------------ + # Recording + # ------------------------------------------------------------------ + + def record_published( + self, + node_id: int, + kernel_id: str, + artifacts: list[dict[str, Any] | str], + ) -> list[ArtifactRef]: + """Record artifacts published by *node_id*. + + ``artifacts`` may be a list of dicts (with at least a ``"name"`` key) + or a plain list of artifact name strings. + + If an artifact with the same name was already published by this node, + it is replaced (node_id + artifact_name are unique). + + Returns the created :class:`ArtifactRef` objects. 
+ """ + state = self._get_or_create_state(node_id) + refs: list[ArtifactRef] = [] + for item in artifacts: + if isinstance(item, str): + item = {"name": item} + artifact_name = item["name"] + + # Remove any existing artifact with the same name from this node + # to ensure (node_id, artifact_name) uniqueness + state.published = [ + r for r in state.published + if not (r.name == artifact_name and r.kernel_id == kernel_id) + ] + + ref = ArtifactRef( + name=artifact_name, + source_node_id=node_id, + kernel_id=kernel_id, + type_name=item.get("type_name", ""), + module=item.get("module", ""), + size_bytes=item.get("size_bytes", 0), + created_at=datetime.now(timezone.utc), + ) + refs.append(ref) + state.published.append(ref) + + # Update the per-kernel index + kernel_map = self._kernel_artifacts.setdefault(kernel_id, {}) + kernel_map[ref.name] = ref + + # Update the reverse index + key = (kernel_id, ref.name) + self._publisher_index.setdefault(key, set()).add(node_id) + + logger.debug( + "Node %s published %d artifact(s) on kernel '%s': %s", + node_id, + len(refs), + kernel_id, + [r.name for r in refs], + ) + return refs + + def record_consumed(self, node_id: int, artifact_names: list[str]) -> None: + """Record that *node_id* consumed (read) the given artifact names.""" + state = self._get_or_create_state(node_id) + state.consumed.extend(artifact_names) + + def record_deleted( + self, + node_id: int, + kernel_id: str, + artifact_names: list[str], + ) -> None: + """Record that *node_id* deleted the given artifacts from *kernel_id*. + + Removes the artifacts from the kernel index so they are no longer + available to downstream nodes. The original publisher's + ``state.published`` list is **not** modified — it serves as a + permanent record of what the node produced. 
+ """ + state = self._get_or_create_state(node_id) + state.deleted.extend(artifact_names) + + kernel_map = self._kernel_artifacts.get(kernel_id, {}) + for name in artifact_names: + kernel_map.pop(name, None) + # Clean up the reverse index entry but leave published intact + key = (kernel_id, name) + publisher_ids = self._publisher_index.pop(key, set()) + + # Remember which nodes produced these artifacts so we can + # force them to re-run if this deleter node is re-executed. + for pid in publisher_ids: + self._deletion_origins.setdefault(node_id, []).append( + (kernel_id, name, pid) + ) + # NOTE: We do NOT remove from publisher's published list here. + # The published list serves as a permanent historical record + # for visualization (badges showing what the node produced). + + logger.debug( + "Node %s deleted %d artifact(s) on kernel '%s': %s", + node_id, + len(artifact_names), + kernel_id, + artifact_names, + ) + + # ------------------------------------------------------------------ + # Availability computation + # ------------------------------------------------------------------ + + def compute_available( + self, + node_id: int, + kernel_id: str, + upstream_node_ids: list[int], + ) -> dict[str, ArtifactRef]: + """Compute which artifacts are available to *node_id*. + + An artifact is available if it was published by an upstream node + (direct or transitive) that used the **same** ``kernel_id`` and + has **not** been deleted by a later upstream node. + + Upstream nodes are processed in topological order (sorted by node ID). + For each node, deletions are applied first, then publications — so + a later node can delete-then-republish an artifact and the new + version will be available downstream. + + The result is stored on the node's :class:`NodeArtifactState` and + also returned. 
+ """ + available: dict[str, ArtifactRef] = {} + + # Sort by node ID to ensure topological processing order + # (FlowGraph._get_upstream_node_ids returns BFS order which is reversed) + for uid in sorted(upstream_node_ids): + upstream_state = self._node_states.get(uid) + if upstream_state is None: + continue + # First, remove artifacts deleted by this upstream node + for name in upstream_state.deleted: + available.pop(name, None) + # Then, add artifacts published by this upstream node + for ref in upstream_state.published: + if ref.kernel_id == kernel_id: + available[ref.name] = ref + + state = self._get_or_create_state(node_id) + state.available = available + + logger.debug( + "Node %s has %d available artifact(s): %s", + node_id, + len(available), + list(available.keys()), + ) + return available + + # ------------------------------------------------------------------ + # Queries + # ------------------------------------------------------------------ + + def get_published_by_node(self, node_id: int) -> list[ArtifactRef]: + """Return artifacts published by *node_id* (empty list if unknown).""" + state = self._node_states.get(node_id) + if state is None: + return [] + return list(state.published) + + def get_available_for_node(self, node_id: int) -> dict[str, ArtifactRef]: + """Return the availability map for *node_id* (empty dict if unknown).""" + state = self._node_states.get(node_id) + if state is None: + return {} + return dict(state.available) + + def get_kernel_artifacts(self, kernel_id: str) -> dict[str, ArtifactRef]: + """Return all known artifacts for a given kernel.""" + return dict(self._kernel_artifacts.get(kernel_id, {})) + + def get_all_artifacts(self) -> dict[str, ArtifactRef]: + """Return every tracked artifact across all kernels.""" + result: dict[str, ArtifactRef] = {} + for kernel_map in self._kernel_artifacts.values(): + result.update(kernel_map) + return result + + def get_producer_nodes_for_deletions( + self, deleter_node_ids: set[int], + ) -> 
set[int]: + """Return node IDs that produced artifacts deleted by *deleter_node_ids*. + + When a consumer node that previously deleted artifacts needs to + re-execute, the original producer nodes must also re-run so the + artifacts are available again in the kernel's in-memory store. + """ + producers: set[int] = set() + for nid in deleter_node_ids: + for _kernel_id, _name, pub_id in self._deletion_origins.get(nid, []): + producers.add(pub_id) + return producers + + # ------------------------------------------------------------------ + # Clearing + # ------------------------------------------------------------------ + + def clear_kernel(self, kernel_id: str) -> None: + """Remove tracking for a specific kernel. + + Clears the kernel index and availability maps. The ``published`` + lists on node states are preserved as historical records. + """ + # Clean reverse index entries for this kernel + keys_to_remove = [k for k in self._publisher_index if k[0] == kernel_id] + for k in keys_to_remove: + del self._publisher_index[k] + + # Clean deletion origin entries for this kernel + for nid in list(self._deletion_origins): + self._deletion_origins[nid] = [ + entry for entry in self._deletion_origins[nid] + if entry[0] != kernel_id + ] + if not self._deletion_origins[nid]: + del self._deletion_origins[nid] + + self._kernel_artifacts.pop(kernel_id, None) + for state in self._node_states.values(): + state.available = { + k: v for k, v in state.available.items() if v.kernel_id != kernel_id + } + + def clear_all(self) -> None: + """Remove all tracking data.""" + self._node_states.clear() + self._kernel_artifacts.clear() + self._publisher_index.clear() + self._deletion_origins.clear() + + def clear_nodes(self, node_ids: set[int]) -> None: + """Remove tracking data only for the specified *node_ids*. + + Artifacts published by these nodes are removed from kernel + indices and publisher indices. States for other nodes are + left untouched so their artifact metadata is preserved. 
+ """ + for nid in node_ids: + self._deletion_origins.pop(nid, None) + state = self._node_states.pop(nid, None) + if state is None: + continue + for ref in state.published: + # Remove from the kernel artifact index + kernel_map = self._kernel_artifacts.get(ref.kernel_id) + if kernel_map is not None: + # Only remove if this ref is still the current entry + existing = kernel_map.get(ref.name) + if existing is not None and existing.source_node_id == nid: + del kernel_map[ref.name] + # Remove from the reverse publisher index + key = (ref.kernel_id, ref.name) + pub_set = self._publisher_index.get(key) + if pub_set is not None: + pub_set.discard(nid) + if not pub_set: + del self._publisher_index[key] + + logger.debug( + "Cleared artifact metadata for node(s): %s", sorted(node_ids) + ) + + def snapshot_node_states(self) -> dict[int, NodeArtifactState]: + """Return a shallow copy of the current per-node states. + + Useful for saving state before ``clear_all()`` so cached + (skipped) nodes can have their artifact state restored afterwards. + """ + return dict(self._node_states) + + def restore_node_state(self, node_id: int, state: NodeArtifactState) -> None: + """Re-insert a previously-snapshotted node state. + + Rebuilds the kernel index and reverse index entries for every + published artifact in *state*. + """ + self._node_states[node_id] = state + for ref in state.published: + kernel_map = self._kernel_artifacts.setdefault(ref.kernel_id, {}) + kernel_map[ref.name] = ref + key = (ref.kernel_id, ref.name) + self._publisher_index.setdefault(key, set()).add(node_id) + + # ------------------------------------------------------------------ + # Visualisation helpers + # ------------------------------------------------------------------ + + def get_artifact_edges(self) -> list[dict[str, Any]]: + """Build a list of artifact edges for canvas visualisation. + + Each edge connects a publisher node to every consumer node that + consumed one of its artifacts (on the same kernel). 
+ + Returns a list of dicts with keys: + source, target, artifact_name, artifact_type, kernel_id + """ + edges: list[dict[str, Any]] = [] + seen: set[tuple[int, int, str]] = set() + + for nid, state in self._node_states.items(): + if not state.consumed: + continue + for art_name in state.consumed: + # Look up the publisher via the available map first + ref = state.available.get(art_name) + if ref is None: + # Fallback: scan kernel artifacts + for km in self._kernel_artifacts.values(): + if art_name in km: + ref = km[art_name] + break + if ref is None: + continue + key = (ref.source_node_id, nid, art_name) + if key in seen: + continue + seen.add(key) + edges.append({ + "source": ref.source_node_id, + "target": nid, + "artifact_name": art_name, + "artifact_type": ref.type_name, + "kernel_id": ref.kernel_id, + }) + + return edges + + def get_node_summaries(self) -> dict[str, dict[str, Any]]: + """Return per-node artifact summary for badge/tab display. + + Returns a dict keyed by str(node_id) with: + published_count, consumed_count, deleted_count, + published, consumed, deleted, kernel_id + """ + summaries: dict[str, dict[str, Any]] = {} + for nid, state in self._node_states.items(): + if not state.published and not state.consumed and not state.deleted: + continue + kernel_id = "" + if state.published: + kernel_id = state.published[0].kernel_id + summaries[str(nid)] = { + "published_count": len(state.published), + "consumed_count": len(state.consumed), + "deleted_count": len(state.deleted), + "published": [ + { + "name": r.name, + "type_name": r.type_name, + "module": r.module, + } + for r in state.published + ], + "consumed": [ + { + "name": name, + "source_node_id": state.available[name].source_node_id + if name in state.available + else None, + "type_name": state.available[name].type_name + if name in state.available + else "", + } + for name in state.consumed + ], + "deleted": list(state.deleted), + "kernel_id": kernel_id, + } + return summaries + + # 
------------------------------------------------------------------ + # Serialisation + # ------------------------------------------------------------------ + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serialisable summary of the context.""" + return { + "nodes": { + str(nid): state.to_dict() for nid, state in self._node_states.items() + }, + "kernels": { + kid: {name: ref.to_dict() for name, ref in refs.items()} + for kid, refs in self._kernel_artifacts.items() + }, + } + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _get_or_create_state(self, node_id: int) -> NodeArtifactState: + if node_id not in self._node_states: + self._node_states[node_id] = NodeArtifactState() + return self._node_states[node_id] diff --git a/flowfile_core/flowfile_core/flowfile/flow_graph.py b/flowfile_core/flowfile_core/flowfile/flow_graph.py index 323d79e7a..24dc9656c 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_graph.py +++ b/flowfile_core/flowfile_core/flowfile/flow_graph.py @@ -22,11 +22,14 @@ from flowfile_core.configs import logger from flowfile_core.configs.flow_logger import FlowLogger from flowfile_core.configs.node_store import CUSTOM_NODE_STORE +from flowfile_core.configs.settings import SERVER_PORT from flowfile_core.flowfile.analytics.utils import create_graphic_walker_node_from_node_promise +from flowfile_core.flowfile.artifacts import ArtifactContext from flowfile_core.flowfile.database_connection_manager.db_connections import ( get_local_cloud_connection, get_local_database_connection, ) +from flowfile_core.flowfile.filter_expressions import build_filter_expression from flowfile_core.flowfile.flow_data_engine.cloud_storage_reader import CloudStorageReader from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine, execute_polars_code from flowfile_core.flowfile.flow_data_engine.flow_file_column.main 
import FlowfileColumn, cast_str_to_polars_type @@ -60,10 +63,10 @@ from flowfile_core.flowfile.sources.external_sources.sql_source import models as sql_models from flowfile_core.flowfile.sources.external_sources.sql_source import utils as sql_utils from flowfile_core.flowfile.sources.external_sources.sql_source.sql_source import BaseSqlSource, SqlSource -from flowfile_core.flowfile.filter_expressions import build_filter_expression from flowfile_core.flowfile.util.calculate_layout import calculate_layered_layout from flowfile_core.flowfile.util.execution_orderer import ExecutionPlan, ExecutionStage, compute_execution_plan from flowfile_core.flowfile.utils import snake_case_to_camel_case +from flowfile_core.kernel import ExecuteRequest, get_kernel_manager from flowfile_core.schemas import input_schema, schemas, transform_schema from flowfile_core.schemas.cloud_storage_schemas import ( AuthMethod, @@ -356,6 +359,7 @@ def __init__( self.cache_results = cache_results self.__name__ = name if name else "flow_" + str(id(self)) self.depends_on = {} + self.artifact_context = ArtifactContext() # Initialize history manager for undo/redo support from flowfile_core.flowfile.history_manager import HistoryManager @@ -1116,6 +1120,110 @@ def _func(*flowfile_tables: FlowDataEngine) -> FlowDataEngine: node = self.get_node(node_id=node_polars_code.node_id) node.results.errors = str(e) + @with_history_capture(HistoryActionType.UPDATE_SETTINGS) + def add_python_script(self, node_python_script: input_schema.NodePythonScript): + """Adds a node that executes Python code on a kernel container.""" + + def _func(*flowfile_tables: FlowDataEngine) -> FlowDataEngine: + + kernel_id = node_python_script.python_script_input.kernel_id + code = node_python_script.python_script_input.code + + if not kernel_id: + raise ValueError("No kernel selected for python_script node") + + manager = get_kernel_manager() + + node_id = node_python_script.node_id + flow_id = self.flow_id + node_logger = 
self.flow_logger.get_node_logger(node_id) + + # Compute available artifacts before execution + upstream_ids = self._get_upstream_node_ids(node_id) + self.artifact_context.compute_available( + node_id=node_id, + kernel_id=kernel_id, + upstream_node_ids=upstream_ids, + ) + + shared_base = manager.shared_volume_path + input_dir = os.path.join(shared_base, str(flow_id), str(node_id), "inputs") + output_dir = os.path.join(shared_base, str(flow_id), str(node_id), "outputs") + + os.makedirs(input_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + + # Write inputs to parquet — supports N inputs under "main" + input_paths: dict[str, list[str]] = {} + main_paths: list[str] = [] + for idx, ft in enumerate(flowfile_tables): + filename = f"main_{idx}.parquet" + local_path = os.path.join(input_dir, filename) + ft.data_frame.collect().write_parquet(local_path) + # Ensure the file is fully flushed to disk before the kernel reads it + # This prevents "File must end with PAR1" errors from race conditions + with open(local_path, "rb") as f: + os.fsync(f.fileno()) + main_paths.append(f"/shared/{flow_id}/{node_id}/inputs/{filename}") + input_paths["main"] = main_paths + + # Build the callback URL so the kernel can stream logs in real time + log_callback_url = f"http://host.docker.internal:{SERVER_PORT}/raw_logs" + + # Execute on kernel (synchronous — no async boundary issues) + request = ExecuteRequest( + node_id=node_id, + code=code, + input_paths=input_paths, + output_dir=f"/shared/{flow_id}/{node_id}/outputs", + flow_id=flow_id, + log_callback_url=log_callback_url, + ) + result = manager.execute_sync(kernel_id, request, self.flow_logger) + + # Forward captured stdout/stderr to the flow logger + if result.stdout: + for line in result.stdout.strip().splitlines(): + node_logger.info(f"[stdout] {line}") + if result.stderr: + for line in result.stderr.strip().splitlines(): + node_logger.warning(f"[stderr] {line}") + + if not result.success: + raise RuntimeError(f"Kernel 
execution failed: {result.error}") + + # Record published artifacts after successful execution + if result.artifacts_published: + self.artifact_context.record_published( + node_id=node_id, + kernel_id=kernel_id, + artifacts=[{"name": n} for n in result.artifacts_published], + ) + + # Record deleted artifacts after successful execution + if result.artifacts_deleted: + self.artifact_context.record_deleted( + node_id=node_id, + kernel_id=kernel_id, + artifact_names=result.artifacts_deleted, + ) + + # Read output + output_path = os.path.join(output_dir, "main.parquet") + if os.path.exists(output_path): + return FlowDataEngine(pl.scan_parquet(output_path)) + + # No output published, pass through first input + return flowfile_tables[0] if flowfile_tables else FlowDataEngine(pl.LazyFrame()) + + self.add_node_step( + node_id=node_python_script.node_id, + function=_func, + node_type="python_script", + setting_input=node_python_script, + input_node_ids=node_python_script.depending_on_ids, + ) + def add_dependency_on_polars_lazy_frame(self, lazy_frame: pl.LazyFrame, node_id: int): """Adds a special node that directly injects a Polars LazyFrame into the graph. @@ -2292,6 +2400,90 @@ def trigger_fetch_node(self, node_id: int) -> RunInformation | None: finally: self.flow_settings.is_running = False + # ------------------------------------------------------------------ + # Artifact helpers + # ------------------------------------------------------------------ + + def _get_upstream_node_ids(self, node_id: int) -> list[int]: + """Get all upstream node IDs (direct and transitive) for *node_id*. + + Traverses the ``all_inputs`` links recursively and returns a + deduplicated list in breadth-first order. 
+ """ + node = self.get_node(node_id) + if node is None: + return [] + + visited: set[int] = set() + result: list[int] = [] + queue = list(node.all_inputs) + while queue: + current = queue.pop(0) + cid = current.node_id + if cid in visited: + continue + visited.add(cid) + result.append(cid) + queue.extend(current.all_inputs) + return result + + def _get_required_kernel_ids(self) -> set[str]: + """Return the set of kernel IDs used by ``python_script`` nodes.""" + kernel_ids: set[str] = set() + for node in self.nodes: + if node.node_type == "python_script" and node.setting_input is not None: + kid = getattr( + getattr(node.setting_input, "python_script_input", None), + "kernel_id", + None, + ) + if kid: + kernel_ids.add(kid) + return kernel_ids + + def _compute_rerun_python_script_node_ids( + self, plan_skip_ids: set[str | int], + ) -> set[int]: + """Return node IDs for ``python_script`` nodes that will re-execute. + + A python_script node will re-execute (and thus needs its old + artifacts cleared) when: + + * It is NOT in the execution-plan skip set, **and** + * Its execution state indicates it has NOT already run with the + current setup (i.e. its cache is stale or it never ran). + """ + rerun: set[int] = set() + for node in self.nodes: + if node.node_type != "python_script": + continue + if node.node_id in plan_skip_ids: + continue + if not node._execution_state.has_run_with_current_setup: + rerun.add(node.node_id) + return rerun + + def _group_rerun_nodes_by_kernel( + self, rerun_node_ids: set[int], + ) -> dict[str, set[int]]: + """Group *rerun_node_ids* by their kernel ID. + + Returns a mapping ``kernel_id → {node_id, …}``. 
+ """ + kernel_nodes: dict[str, set[int]] = {} + for node in self.nodes: + if node.node_id not in rerun_node_ids: + continue + if node.node_type == "python_script" and node.setting_input is not None: + kid = getattr( + getattr(node.setting_input, "python_script_input", None), + "kernel_id", + None, + ) + if kid: + kernel_nodes.setdefault(kid, set()).add(node.node_id) + return kernel_nodes + def _execute_single_node( self, node: FlowNode, @@ -2366,10 +2558,61 @@ def run_graph(self) -> RunInformation | None: self.flow_settings.is_canceled = False self.flow_logger.clear_log_file() self.flow_logger.info("Starting to run flowfile flow...") + execution_plan = compute_execution_plan( nodes=self.nodes, flow_starts=self._flow_starts + self.get_implicit_starter_nodes() ) + # Selectively clear artifacts only for nodes that will re-run. + # Nodes that are up-to-date keep their artifacts in both the + # metadata tracker AND the kernel's in-memory store so that + # downstream nodes can still read them. + plan_skip_ids: set[str | int] = {n.node_id for n in execution_plan.skip_nodes} + rerun_node_ids = self._compute_rerun_python_script_node_ids(plan_skip_ids) + + # Expand re-run set: if a re-running node previously deleted + # artifacts, the original producer nodes must also re-run so + # those artifacts are available again in the kernel store. + while True: + deleted_producers = self.artifact_context.get_producer_nodes_for_deletions( + rerun_node_ids, + ) + new_ids = deleted_producers - rerun_node_ids + if not new_ids: + break + rerun_node_ids |= new_ids + + # Force producer nodes (added due to artifact deletions) to + # actually re-execute by marking their execution state stale. + for nid in rerun_node_ids: + node = self.get_node(nid) + if node is not None and node._execution_state.has_run_with_current_setup: + node._execution_state.has_run_with_current_setup = False + + # Also purge stale metadata for nodes not in this graph + # (e.g. 
injected externally or left over from removed nodes). + graph_node_ids = set(self._node_db.keys()) + stale_node_ids = { + nid for nid in self.artifact_context._node_states + if nid not in graph_node_ids + } + nodes_to_clear = rerun_node_ids | stale_node_ids + if nodes_to_clear: + self.artifact_context.clear_nodes(nodes_to_clear) + + if rerun_node_ids: + # Clear the actual kernel-side artifacts for re-running nodes + kernel_node_map = self._group_rerun_nodes_by_kernel(rerun_node_ids) + for kid, node_ids_for_kernel in kernel_node_map.items(): + try: + manager = get_kernel_manager() + manager.clear_node_artifacts_sync(kid, list(node_ids_for_kernel), flow_id=self.flow_id, flow_logger=self.flow_logger) + except Exception: + logger.debug( + "Could not clear node artifacts for kernel '%s', nodes %s", + kid, sorted(node_ids_for_kernel), + ) + self.latest_run_info = self.create_initial_run_information( execution_plan.node_count, "full_run" ) @@ -2379,7 +2622,7 @@ def run_graph(self) -> RunInformation | None: performance_mode = self.flow_settings.execution_mode == "Performance" run_info_lock = threading.Lock() - skip_node_ids: set[str | int] = {n.node_id for n in execution_plan.skip_nodes} + skip_node_ids: set[str | int] = plan_skip_ids for stage in execution_plan.stages: if self.flow_settings.is_canceled: diff --git a/flowfile_core/flowfile_core/kernel/__init__.py b/flowfile_core/flowfile_core/kernel/__init__.py new file mode 100644 index 000000000..6600f4e3a --- /dev/null +++ b/flowfile_core/flowfile_core/kernel/__init__.py @@ -0,0 +1,52 @@ +from flowfile_core.kernel.manager import KernelManager +from flowfile_core.kernel.models import ( + ArtifactIdentifier, + ArtifactPersistenceInfo, + CleanupRequest, + CleanupResult, + ClearNodeArtifactsRequest, + ClearNodeArtifactsResult, + DisplayOutput, + DockerStatus, + ExecuteRequest, + ExecuteResult, + KernelConfig, + KernelInfo, + KernelState, + RecoveryMode, + RecoveryStatus, +) +from flowfile_core.kernel.routes import router + 
+__all__ = [ + "KernelManager", + "ArtifactIdentifier", + "ArtifactPersistenceInfo", + "CleanupRequest", + "CleanupResult", + "ClearNodeArtifactsRequest", + "ClearNodeArtifactsResult", + "DisplayOutput", + "DockerStatus", + "KernelConfig", + "KernelInfo", + "KernelState", + "ExecuteRequest", + "ExecuteResult", + "RecoveryMode", + "RecoveryStatus", + "router", + "get_kernel_manager", +] + +_manager: KernelManager | None = None + + +def get_kernel_manager() -> KernelManager: + global _manager + if _manager is None: + from shared.storage_config import storage + + shared_path = str(storage.temp_directory / "kernel_shared") + _manager = KernelManager(shared_volume_path=shared_path) + return _manager diff --git a/flowfile_core/flowfile_core/kernel/manager.py b/flowfile_core/flowfile_core/kernel/manager.py new file mode 100644 index 000000000..60838a965 --- /dev/null +++ b/flowfile_core/flowfile_core/kernel/manager.py @@ -0,0 +1,612 @@ +import asyncio +import logging +import socket +import time + +import docker +import httpx + +from flowfile_core.configs.flow_logger import FlowLogger +from flowfile_core.kernel.models import ( + ArtifactPersistenceInfo, + CleanupRequest, + CleanupResult, + ClearNodeArtifactsResult, + ExecuteRequest, + ExecuteResult, + KernelConfig, + KernelInfo, + KernelState, + RecoveryStatus, +) +from shared.storage_config import storage + +logger = logging.getLogger(__name__) + +_KERNEL_IMAGE = "flowfile-kernel" +_BASE_PORT = 19000 +_PORT_RANGE = 1000 # 19000-19999 +_HEALTH_TIMEOUT = 120 +_HEALTH_POLL_INTERVAL = 2 + + +def _is_port_available(port: int) -> bool: + """Check whether a TCP port is free on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("127.0.0.1", port)) + return True + except OSError: + return False + + +class KernelManager: + def __init__(self, shared_volume_path: str | None = None): + self._docker = docker.from_env() + self._kernels: dict[str, KernelInfo] = {} + self._kernel_owners: 
dict[str, int] = {} # kernel_id -> user_id + self._shared_volume = shared_volume_path or str(storage.cache_directory) + self._restore_kernels_from_db() + self._reclaim_running_containers() + + @property + def shared_volume_path(self) -> str: + return self._shared_volume + + # ------------------------------------------------------------------ + # Database persistence helpers + # ------------------------------------------------------------------ + + def _restore_kernels_from_db(self) -> None: + """Load persisted kernel configs from the database on startup.""" + try: + from flowfile_core.database.connection import get_db_context + from flowfile_core.kernel.persistence import get_all_kernels + + with get_db_context() as db: + for config, user_id in get_all_kernels(db): + if config.id in self._kernels: + continue + kernel = KernelInfo( + id=config.id, + name=config.name, + state=KernelState.STOPPED, + packages=config.packages, + memory_gb=config.memory_gb, + cpu_cores=config.cpu_cores, + gpu=config.gpu, + ) + self._kernels[config.id] = kernel + self._kernel_owners[config.id] = user_id + logger.info("Restored kernel '%s' for user %d from database", config.id, user_id) + except Exception as exc: + logger.warning("Could not restore kernels from database: %s", exc) + + def _persist_kernel(self, kernel: KernelInfo, user_id: int) -> None: + """Save a kernel record to the database.""" + try: + from flowfile_core.database.connection import get_db_context + from flowfile_core.kernel.persistence import save_kernel + + with get_db_context() as db: + save_kernel(db, kernel, user_id) + except Exception as exc: + logger.warning("Could not persist kernel '%s': %s", kernel.id, exc) + + def _remove_kernel_from_db(self, kernel_id: str) -> None: + """Remove a kernel record from the database.""" + try: + from flowfile_core.database.connection import get_db_context + from flowfile_core.kernel.persistence import delete_kernel + + with get_db_context() as db: + delete_kernel(db, kernel_id) + 
except Exception as exc: + logger.warning("Could not remove kernel '%s' from database: %s", kernel_id, exc) + + # ------------------------------------------------------------------ + # Port allocation + # ------------------------------------------------------------------ + + def _reclaim_running_containers(self) -> None: + """Discover running flowfile-kernel containers and reclaim their ports.""" + try: + containers = self._docker.containers.list( + filters={"name": "flowfile-kernel-", "status": "running"} + ) + except (docker.errors.APIError, docker.errors.DockerException) as exc: + logger.warning("Could not list running containers: %s", exc) + return + + for container in containers: + name = container.name + if not name.startswith("flowfile-kernel-"): + continue + kernel_id = name[len("flowfile-kernel-"):] + + # Determine which host port is mapped + port = None + try: + bindings = container.attrs["NetworkSettings"]["Ports"].get("9999/tcp") + if bindings: + port = int(bindings[0]["HostPort"]) + except (KeyError, IndexError, TypeError, ValueError): + pass + + if port is not None and kernel_id in self._kernels: + # Kernel was restored from DB — update with runtime info + self._kernels[kernel_id].container_id = container.id + self._kernels[kernel_id].port = port + self._kernels[kernel_id].state = KernelState.IDLE + logger.info( + "Reclaimed running kernel '%s' on port %d (container %s)", + kernel_id, port, container.short_id, + ) + elif port is not None and kernel_id not in self._kernels: + # Orphan container with no DB record — stop it + logger.warning( + "Found orphan kernel container '%s' with no database record, stopping it", + kernel_id, + ) + try: + container.stop(timeout=10) + container.remove(force=True) + except Exception as exc: + logger.warning("Error stopping orphan container '%s': %s", kernel_id, exc) + + def _allocate_port(self) -> int: + """Find the next available port in the kernel port range.""" + used_ports = {k.port for k in self._kernels.values() 
if k.port is not None} + for port in range(_BASE_PORT, _BASE_PORT + _PORT_RANGE): + if port not in used_ports and _is_port_available(port): + return port + raise RuntimeError( + f"No available ports in range {_BASE_PORT}-{_BASE_PORT + _PORT_RANGE - 1}" + ) + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + async def create_kernel(self, config: KernelConfig, user_id: int) -> KernelInfo: + if config.id in self._kernels: + raise ValueError(f"Kernel '{config.id}' already exists") + + port = self._allocate_port() + kernel = KernelInfo( + id=config.id, + name=config.name, + state=KernelState.STOPPED, + port=port, + packages=config.packages, + memory_gb=config.memory_gb, + cpu_cores=config.cpu_cores, + gpu=config.gpu, + health_timeout=config.health_timeout, + persistence_enabled=config.persistence_enabled, + recovery_mode=config.recovery_mode, + ) + self._kernels[config.id] = kernel + self._kernel_owners[config.id] = user_id + self._persist_kernel(kernel, user_id) + logger.info("Created kernel '%s' on port %d for user %d", config.id, port, user_id) + return kernel + + async def start_kernel(self, kernel_id: str) -> KernelInfo: + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state == KernelState.IDLE: + return kernel + + # Verify the kernel image exists before attempting to start + try: + self._docker.images.get(_KERNEL_IMAGE) + except docker.errors.ImageNotFound: + kernel.state = KernelState.ERROR + kernel.error_message = ( + f"Docker image '{_KERNEL_IMAGE}' not found. " + "Please build or pull the kernel image before starting a kernel." + ) + raise RuntimeError(kernel.error_message) + + # Allocate a port if the kernel doesn't have one yet (e.g. 
restored from DB) + if kernel.port is None: + kernel.port = self._allocate_port() + + kernel.state = KernelState.STARTING + kernel.error_message = None + + try: + packages_str = " ".join(kernel.packages) + run_kwargs: dict = { + "detach": True, + "name": f"flowfile-kernel-{kernel_id}", + "ports": {"9999/tcp": kernel.port}, + "volumes": {self._shared_volume: {"bind": "/shared", "mode": "rw"}}, + "environment": { + "KERNEL_PACKAGES": packages_str, + "KERNEL_ID": kernel_id, + "PERSISTENCE_ENABLED": "true" if kernel.persistence_enabled else "false", + "PERSISTENCE_PATH": "/shared/artifacts", + "RECOVERY_MODE": kernel.recovery_mode.value, + }, + "mem_limit": f"{kernel.memory_gb}g", + "nano_cpus": int(kernel.cpu_cores * 1e9), + "extra_hosts": {"host.docker.internal": "host-gateway"}, + } + container = self._docker.containers.run(_KERNEL_IMAGE, **run_kwargs) + kernel.container_id = container.id + await self._wait_for_healthy(kernel_id, timeout=kernel.health_timeout) + kernel.state = KernelState.IDLE + logger.info("Kernel '%s' is idle (container %s)", kernel_id, container.short_id) + except (docker.errors.DockerException, httpx.HTTPError, TimeoutError, OSError) as exc: + kernel.state = KernelState.ERROR + kernel.error_message = str(exc) + logger.error("Failed to start kernel '%s': %s", kernel_id, exc) + self._cleanup_container(kernel_id) + raise + + return kernel + + def start_kernel_sync(self, kernel_id: str, flow_logger: FlowLogger | None = None) -> KernelInfo: + """Synchronous version of start_kernel() for use from non-async code.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state == KernelState.IDLE: + return kernel + + try: + self._docker.images.get(_KERNEL_IMAGE) + except docker.errors.ImageNotFound: + kernel.state = KernelState.ERROR + kernel.error_message = ( + f"Docker image '{_KERNEL_IMAGE}' not found. " + "Please build or pull the kernel image before starting a kernel." + ) + flow_logger.error(f"Docker image '{_KERNEL_IMAGE}' not found. 
" + "Please build or pull the kernel image before starting a kernel.") if flow_logger else None + raise RuntimeError(kernel.error_message) + + if kernel.port is None: + kernel.port = self._allocate_port() + + kernel.state = KernelState.STARTING + kernel.error_message = None + + try: + packages_str = " ".join(kernel.packages) + run_kwargs: dict = { + "detach": True, + "name": f"flowfile-kernel-{kernel_id}", + "ports": {"9999/tcp": kernel.port}, + "volumes": {self._shared_volume: {"bind": "/shared", "mode": "rw"}}, + "environment": { + "KERNEL_PACKAGES": packages_str, + "KERNEL_ID": kernel_id, + "PERSISTENCE_ENABLED": "true" if kernel.persistence_enabled else "false", + "PERSISTENCE_PATH": "/shared/artifacts", + "RECOVERY_MODE": kernel.recovery_mode.value, + }, + "mem_limit": f"{kernel.memory_gb}g", + "nano_cpus": int(kernel.cpu_cores * 1e9), + "extra_hosts": {"host.docker.internal": "host-gateway"}, + } + container = self._docker.containers.run(_KERNEL_IMAGE, **run_kwargs) + kernel.container_id = container.id + self._wait_for_healthy_sync(kernel_id, timeout=kernel.health_timeout) + kernel.state = KernelState.IDLE + flow_logger.info(f"Kernel {kernel_id} is idle (container {container.short_id})") if flow_logger else None + except (docker.errors.DockerException, httpx.HTTPError, TimeoutError, OSError) as exc: + kernel.state = KernelState.ERROR + kernel.error_message = str(exc) + flow_logger.error(f"Failed to start kernel {kernel_id}: {exc}") if flow_logger else None + self._cleanup_container(kernel_id) + raise + flow_logger.info(f"Kernel {kernel_id} started (container {container.short_id})") if flow_logger else None + return kernel + + async def stop_kernel(self, kernel_id: str) -> None: + kernel = self._get_kernel_or_raise(kernel_id) + self._cleanup_container(kernel_id) + kernel.state = KernelState.STOPPED + kernel.container_id = None + logger.info("Stopped kernel '%s'", kernel_id) + + async def delete_kernel(self, kernel_id: str) -> None: + kernel = 
self._get_kernel_or_raise(kernel_id) + if kernel.state in (KernelState.IDLE, KernelState.EXECUTING): + await self.stop_kernel(kernel_id) + del self._kernels[kernel_id] + self._kernel_owners.pop(kernel_id, None) + self._remove_kernel_from_db(kernel_id) + logger.info("Deleted kernel '%s'", kernel_id) + + def shutdown_all(self) -> None: + """Stop and remove all running kernel containers. Called on core shutdown.""" + kernel_ids = list(self._kernels.keys()) + for kernel_id in kernel_ids: + kernel = self._kernels.get(kernel_id) + if kernel and kernel.state in (KernelState.IDLE, KernelState.EXECUTING, KernelState.STARTING): + logger.info("Shutting down kernel '%s'", kernel_id) + self._cleanup_container(kernel_id) + kernel.state = KernelState.STOPPED + kernel.container_id = None + logger.info("All kernels have been shut down") + + # ------------------------------------------------------------------ + # Execution + # ------------------------------------------------------------------ + + async def execute(self, kernel_id: str, request: ExecuteRequest) -> ExecuteResult: + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + await self._ensure_running(kernel_id) + + kernel.state = KernelState.EXECUTING + try: + url = f"http://localhost:{kernel.port}/execute" + async with httpx.AsyncClient(timeout=httpx.Timeout(300.0)) as client: + response = await client.post(url, json=request.model_dump()) + response.raise_for_status() + return ExecuteResult(**response.json()) + finally: + # Only return to IDLE if we haven't been stopped/errored in the meantime + if kernel.state == KernelState.EXECUTING: + kernel.state = KernelState.IDLE + + def execute_sync(self, kernel_id: str, request: ExecuteRequest, + flow_logger: FlowLogger | None = None) -> ExecuteResult: + """Synchronous wrapper around execute() for use from non-async code.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, 
KernelState.EXECUTING): + self._ensure_running_sync(kernel_id, flow_logger=flow_logger) + + kernel.state = KernelState.EXECUTING + try: + url = f"http://localhost:{kernel.port}/execute" + with httpx.Client(timeout=httpx.Timeout(300.0)) as client: + response = client.post(url, json=request.model_dump()) + response.raise_for_status() + return ExecuteResult(**response.json()) + finally: + if kernel.state == KernelState.EXECUTING: + kernel.state = KernelState.IDLE + + async def clear_artifacts(self, kernel_id: str) -> None: + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + await self._ensure_running(kernel_id) + + url = f"http://localhost:{kernel.port}/clear" + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.post(url) + response.raise_for_status() + + def clear_artifacts_sync(self, kernel_id: str) -> None: + """Synchronous wrapper around clear_artifacts() for use from non-async code.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + self._ensure_running_sync(kernel_id) + + url = f"http://localhost:{kernel.port}/clear" + with httpx.Client(timeout=httpx.Timeout(30.0)) as client: + response = client.post(url) + response.raise_for_status() + + async def clear_node_artifacts( + self, kernel_id: str, node_ids: list[int], flow_id: int | None = None, + ) -> ClearNodeArtifactsResult: + """Clear only artifacts published by the given node IDs.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + await self._ensure_running(kernel_id) + + url = f"http://localhost:{kernel.port}/clear_node_artifacts" + payload: dict = {"node_ids": node_ids} + if flow_id is not None: + payload["flow_id"] = flow_id + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.post(url, json=payload) + 
response.raise_for_status() + return ClearNodeArtifactsResult(**response.json()) + + def clear_node_artifacts_sync( + self, kernel_id: str, node_ids: list[int], flow_id: int | None = None, + flow_logger: FlowLogger | None = None, + ) -> ClearNodeArtifactsResult: + """Synchronous wrapper for clearing artifacts by node IDs.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + self._ensure_running_sync(kernel_id, flow_logger=flow_logger) + + url = f"http://localhost:{kernel.port}/clear_node_artifacts" + payload: dict = {"node_ids": node_ids} + if flow_id is not None: + payload["flow_id"] = flow_id + with httpx.Client(timeout=httpx.Timeout(30.0)) as client: + response = client.post(url, json=payload) + response.raise_for_status() + return ClearNodeArtifactsResult(**response.json()) + + async def get_node_artifacts(self, kernel_id: str, node_id: int) -> dict: + """Get artifacts published by a specific node.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + await self._ensure_running(kernel_id) + + url = f"http://localhost:{kernel.port}/artifacts/node/{node_id}" + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + + # ------------------------------------------------------------------ + # Artifact Persistence & Recovery + # ------------------------------------------------------------------ + + async def recover_artifacts(self, kernel_id: str) -> RecoveryStatus: + """Trigger manual artifact recovery on a running kernel.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + raise RuntimeError(f"Kernel '{kernel_id}' is not running (state: {kernel.state})") + + url = f"http://localhost:{kernel.port}/recover" + async with 
httpx.AsyncClient(timeout=httpx.Timeout(120.0)) as client: + response = await client.post(url) + response.raise_for_status() + return RecoveryStatus(**response.json()) + + async def get_recovery_status(self, kernel_id: str) -> RecoveryStatus: + """Get the current recovery status of a kernel.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + raise RuntimeError(f"Kernel '{kernel_id}' is not running (state: {kernel.state})") + + url = f"http://localhost:{kernel.port}/recovery-status" + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.get(url) + response.raise_for_status() + return RecoveryStatus(**response.json()) + + async def cleanup_artifacts(self, kernel_id: str, request: CleanupRequest) -> CleanupResult: + """Clean up old persisted artifacts on a kernel.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + raise RuntimeError(f"Kernel '{kernel_id}' is not running (state: {kernel.state})") + + url = f"http://localhost:{kernel.port}/cleanup" + async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client: + response = await client.post(url, json=request.model_dump()) + response.raise_for_status() + return CleanupResult(**response.json()) + + async def get_persistence_info(self, kernel_id: str) -> ArtifactPersistenceInfo: + """Get persistence configuration and stats for a kernel.""" + kernel = self._get_kernel_or_raise(kernel_id) + if kernel.state not in (KernelState.IDLE, KernelState.EXECUTING): + raise RuntimeError(f"Kernel '{kernel_id}' is not running (state: {kernel.state})") + + url = f"http://localhost:{kernel.port}/persistence" + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.get(url) + response.raise_for_status() + return ArtifactPersistenceInfo(**response.json()) + + # 
# (kernel manager — read-only queries, then internal lifecycle helpers)
    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    async def list_kernels(self, user_id: int | None = None) -> list[KernelInfo]:
        """Return all kernels, optionally filtered to those owned by ``user_id``."""
        if user_id is not None:
            return [
                k for kid, k in self._kernels.items()
                if self._kernel_owners.get(kid) == user_id
            ]
        return list(self._kernels.values())

    async def get_kernel(self, kernel_id: str) -> KernelInfo | None:
        # Returns None (rather than raising) for an unknown id.
        return self._kernels.get(kernel_id)

    def get_kernel_owner(self, kernel_id: str) -> int | None:
        # User id recorded for the kernel, or None if not tracked.
        return self._kernel_owners.get(kernel_id)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _get_kernel_or_raise(self, kernel_id: str) -> KernelInfo:
        # Shared lookup; KeyError is translated to HTTP 404 by the routes layer.
        kernel = self._kernels.get(kernel_id)
        if kernel is None:
            raise KeyError(f"Kernel '{kernel_id}' not found")
        return kernel

    async def _ensure_running(self, kernel_id: str) -> None:
        """Restart the kernel if it is STOPPED or ERROR, then wait until IDLE."""
        kernel = self._get_kernel_or_raise(kernel_id)
        if kernel.state in (KernelState.IDLE, KernelState.EXECUTING):
            return
        if kernel.state in (KernelState.STOPPED, KernelState.ERROR):
            logger.info(
                "Kernel '%s' is %s, attempting automatic restart...",
                kernel_id, kernel.state.value,
            )
            # Drop the dead container (best-effort) before starting a fresh one.
            self._cleanup_container(kernel_id)
            kernel.container_id = None
            await self.start_kernel(kernel_id)
            return
        # STARTING — wait for it to finish
        if kernel.state == KernelState.STARTING:
            logger.info("Kernel '%s' is starting, waiting for it to become ready...", kernel_id)
            await self._wait_for_healthy(kernel_id)
            kernel.state = KernelState.IDLE

    def _ensure_running_sync(self, kernel_id: str, flow_logger: FlowLogger | None = None) -> None:
        """Synchronous version of _ensure_running.

        Mirrors the async variant; additionally echoes restart messages to the
        per-flow logger when one is provided.
        """
        kernel = self._get_kernel_or_raise(kernel_id)
        if kernel.state in (KernelState.IDLE, KernelState.EXECUTING):
            return
        if kernel.state in (KernelState.STOPPED, KernelState.ERROR):
            msg = f"Kernel '{kernel_id}' is {kernel.state.value}, attempting automatic restart..."
            logger.info(msg)
            if flow_logger:
                flow_logger.info(msg)
            self._cleanup_container(kernel_id)
            kernel.container_id = None
            self.start_kernel_sync(kernel_id, flow_logger=flow_logger)
            return
        # STARTING — wait for it to finish
        if kernel.state == KernelState.STARTING:
            logger.info("Kernel '%s' is starting, waiting for it to become ready...", kernel_id)
            self._wait_for_healthy_sync(kernel_id)
            kernel.state = KernelState.IDLE

    def _cleanup_container(self, kernel_id: str) -> None:
        # Best-effort stop + remove of the Docker container backing a kernel.
        kernel = self._kernels.get(kernel_id)
        if kernel is None or kernel.container_id is None:
            return
        try:
            container = self._docker.containers.get(kernel.container_id)
            container.stop(timeout=10)
            container.remove(force=True)
        except docker.errors.NotFound:
            # Container already gone — nothing to clean up.
            pass
        except (docker.errors.APIError, docker.errors.DockerException) as exc:
            logger.warning("Error cleaning up container for kernel '%s': %s", kernel_id, exc)

    async def _wait_for_healthy(self, kernel_id: str, timeout: int = _HEALTH_TIMEOUT) -> None:
        """Poll the kernel's /health endpoint until it answers 200 or ``timeout`` expires.

        On success, records the version reported by the kernel image.
        Raises TimeoutError when the deadline passes without a healthy response.
        """
        kernel = self._get_kernel_or_raise(kernel_id)
        url = f"http://localhost:{kernel.port}/health"
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while loop.time() < deadline:
            try:
                async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
                    response = await client.get(url)
                    if response.status_code == 200:
                        data = response.json()
                        kernel.kernel_version = data.get("version")
                        return
            except (httpx.HTTPError, OSError) as exc:
                # Expected while the container is still booting; keep polling.
                logger.debug("Health poll for kernel '%s' failed: %s", kernel_id, exc)
            await asyncio.sleep(_HEALTH_POLL_INTERVAL)

        raise TimeoutError(f"Kernel '{kernel_id}' did not become healthy within {timeout}s")

    def _wait_for_healthy_sync(self, kernel_id: str, timeout: int = _HEALTH_TIMEOUT) -> None:
        """Synchronous
version of _wait_for_healthy."""
        kernel = self._get_kernel_or_raise(kernel_id)
        url = f"http://localhost:{kernel.port}/health"
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            try:
                with httpx.Client(timeout=httpx.Timeout(5.0)) as client:
                    response = client.get(url)
                    if response.status_code == 200:
                        data = response.json()
                        # Record the runtime version reported by the kernel image.
                        kernel.kernel_version = data.get("version")
                        return
            except (httpx.HTTPError, OSError) as exc:
                logger.debug("Health poll for kernel '%s' failed: %s", kernel_id, exc)
            time.sleep(_HEALTH_POLL_INTERVAL)

        raise TimeoutError(f"Kernel '{kernel_id}' did not become healthy within {timeout}s")


# --- file: flowfile_core/flowfile_core/kernel/models.py (new file) ---

from datetime import datetime, timezone
from enum import Enum

from pydantic import BaseModel, Field


class KernelState(str, Enum):
    """Lifecycle states of a kernel container."""
    STOPPED = "stopped"
    STARTING = "starting"
    IDLE = "idle"
    EXECUTING = "executing"
    ERROR = "error"


class RecoveryMode(str, Enum):
    """How persisted artifacts are restored when a kernel starts."""
    LAZY = "lazy"
    EAGER = "eager"
    CLEAR = "clear"  # Clears all persisted artifacts on startup (destructive)


class KernelConfig(BaseModel):
    """User-supplied configuration for creating a kernel."""
    id: str
    name: str
    packages: list[str] = Field(default_factory=list)
    cpu_cores: float = 2.0
    memory_gb: float = 4.0
    gpu: bool = False
    health_timeout: int = 120
    # Persistence configuration
    persistence_enabled: bool = True
    recovery_mode: RecoveryMode = RecoveryMode.LAZY


class KernelInfo(BaseModel):
    """Runtime view of a kernel: configuration plus ephemeral container state."""
    id: str
    name: str
    state: KernelState = KernelState.STOPPED
    container_id: str | None = None
    port: int | None = None
    packages: list[str] = Field(default_factory=list)
    memory_gb: float = 4.0
    cpu_cores: float = 2.0
    gpu: bool = False
    health_timeout: int = 120
    # Creation time in UTC; factory wrapped in a lambda so it is evaluated per instance.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    error_message: str | None = None
    kernel_version: str | None = None
    # Persistence configuration
    persistence_enabled: bool = True
    recovery_mode: RecoveryMode = RecoveryMode.LAZY


class DockerStatus(BaseModel):
    """Whether Docker is reachable and the kernel image is present."""
    available: bool
    image_available: bool
    error: str | None = None


class ExecuteRequest(BaseModel):
    """Request to execute code on a kernel for a specific flow node."""
    node_id: int
    code: str
    input_paths: dict[str, list[str]] = Field(default_factory=dict)
    output_dir: str = ""
    flow_id: int = 0
    log_callback_url: str = ""
    interactive: bool = False  # When True, auto-display last expression


class ClearNodeArtifactsRequest(BaseModel):
    """Request to selectively clear artifacts owned by specific node IDs."""
    node_ids: list[int]
    flow_id: int | None = None


class ClearNodeArtifactsResult(BaseModel):
    """Result of a selective artifact clear operation."""
    status: str = "cleared"
    removed: list[str] = Field(default_factory=list)


class DisplayOutput(BaseModel):
    """A single display output from code execution."""
    mime_type: str  # "image/png", "text/html", "text/plain"
    data: str  # base64 for images, raw HTML for text/html, plain text otherwise
    title: str = ""


class ExecuteResult(BaseModel):
    """Outcome of an execute call: outputs, artifact deltas, captured streams."""
    success: bool
    output_paths: list[str] = Field(default_factory=list)
    artifacts_published: list[str] = Field(default_factory=list)
    artifacts_deleted: list[str] = Field(default_factory=list)
    display_outputs: list[DisplayOutput] = Field(default_factory=list)
    stdout: str = ""
    stderr: str = ""
    error: str | None = None
    execution_time_ms: float = 0.0


# ---------------------------------------------------------------------------
# Artifact Persistence & Recovery models
# ---------------------------------------------------------------------------


class RecoveryStatus(BaseModel):
    """Progress of artifact recovery on a kernel."""
    status: str  # "pending", "recovering", "completed", "error", "disabled"
    mode: str | None = None
    recovered: list[str] = Field(default_factory=list)
    indexed: int | None = None
    errors: list[str] = Field(default_factory=list)


class ArtifactIdentifier(BaseModel):
    """Identifies a specific artifact by flow_id and name."""
    flow_id: int
    name: str


class CleanupRequest(BaseModel):
    """Request to clean up old persisted artifacts."""
    max_age_hours: float | None = None
    artifact_names: list[ArtifactIdentifier] | None = Field(
        default=None,
        description="List of specific artifacts to delete",
    )


class CleanupResult(BaseModel):
    """Outcome of a cleanup call."""
    status: str
    removed_count: int = 0


class ArtifactPersistenceInfo(BaseModel):
    """Persistence configuration and stats for a kernel."""
    enabled: bool
    recovery_mode: str = "lazy"
    kernel_id: str | None = None
    persistence_path: str | None = None
    persisted_count: int = 0
    in_memory_count: int = 0
    disk_usage_bytes: int = 0
    artifacts: dict = Field(default_factory=dict)


# --- file: flowfile_core/flowfile_core/kernel/persistence.py (new file) ---

"""Database persistence for kernel configurations.

Kernels are persisted so they survive core process restarts. Only the
configuration is stored (id, name, packages, resource limits, user ownership).
Runtime state (container_id, port, state) is ephemeral and reconstructed at
startup by reclaiming running Docker containers.
+""" + +import json +import logging + +from sqlalchemy.orm import Session + +from flowfile_core.database import models as db_models +from flowfile_core.kernel.models import KernelConfig, KernelInfo + +logger = logging.getLogger(__name__) + + +def save_kernel(db: Session, kernel: KernelInfo, user_id: int) -> None: + """Insert or update a kernel record in the database.""" + existing = db.query(db_models.Kernel).filter(db_models.Kernel.id == kernel.id).first() + if existing: + existing.name = kernel.name + existing.packages = json.dumps(kernel.packages) + existing.cpu_cores = kernel.cpu_cores + existing.memory_gb = kernel.memory_gb + existing.gpu = kernel.gpu + existing.user_id = user_id + else: + record = db_models.Kernel( + id=kernel.id, + name=kernel.name, + user_id=user_id, + packages=json.dumps(kernel.packages), + cpu_cores=kernel.cpu_cores, + memory_gb=kernel.memory_gb, + gpu=kernel.gpu, + ) + db.add(record) + db.commit() + + +def delete_kernel(db: Session, kernel_id: str) -> None: + """Remove a kernel record from the database.""" + db.query(db_models.Kernel).filter(db_models.Kernel.id == kernel_id).delete() + db.commit() + + +def get_kernels_for_user(db: Session, user_id: int) -> list[KernelConfig]: + """Return all persisted kernel configs belonging to a user.""" + rows = db.query(db_models.Kernel).filter(db_models.Kernel.user_id == user_id).all() + return [_row_to_config(row) for row in rows] + + +def get_all_kernels(db: Session) -> list[tuple[KernelConfig, int]]: + """Return all persisted kernels as (config, user_id) tuples.""" + rows = db.query(db_models.Kernel).all() + return [(_row_to_config(row), row.user_id) for row in rows] + + +def _row_to_config(row: db_models.Kernel) -> KernelConfig: + packages = json.loads(row.packages) if row.packages else [] + return KernelConfig( + id=row.id, + name=row.name, + packages=packages, + cpu_cores=row.cpu_cores, + memory_gb=row.memory_gb, + gpu=row.gpu, + ) diff --git a/flowfile_core/flowfile_core/kernel/routes.py 
b/flowfile_core/flowfile_core/kernel/routes.py new file mode 100644 index 000000000..55624cd58 --- /dev/null +++ b/flowfile_core/flowfile_core/kernel/routes.py @@ -0,0 +1,308 @@ +import logging + +from fastapi import APIRouter, Depends, HTTPException + +from flowfile_core.auth.jwt import get_current_active_user +from flowfile_core.kernel.models import ( + ArtifactPersistenceInfo, + CleanupRequest, + CleanupResult, + ClearNodeArtifactsRequest, + ClearNodeArtifactsResult, + DockerStatus, + ExecuteRequest, + ExecuteResult, + KernelConfig, + KernelInfo, + RecoveryStatus, +) + +logger = logging.getLogger(__name__) + + +def _get_manager(): + from flowfile_core.kernel import get_kernel_manager + + try: + return get_kernel_manager() + except Exception as exc: + logger.error("Kernel manager unavailable: %s", exc) + raise HTTPException( + status_code=503, + detail="Docker is not available. Please ensure Docker is installed and running.", + ) + + +router = APIRouter(prefix="/kernels", dependencies=[Depends(get_current_active_user)]) + + +@router.get("/", response_model=list[KernelInfo]) +async def list_kernels(current_user=Depends(get_current_active_user)): + return await _get_manager().list_kernels(user_id=current_user.id) + + +@router.post("/", response_model=KernelInfo) +async def create_kernel(config: KernelConfig, current_user=Depends(get_current_active_user)): + try: + return await _get_manager().create_kernel(config, user_id=current_user.id) + except ValueError as exc: + raise HTTPException(status_code=409, detail=str(exc)) + + +@router.get("/docker-status", response_model=DockerStatus) +async def docker_status(): + """Check if Docker is reachable and the kernel image is available.""" + import docker as _docker + + try: + client = _docker.from_env() + client.ping() + except Exception as exc: + return DockerStatus(available=False, image_available=False, error=str(exc)) + + from flowfile_core.kernel.manager import _KERNEL_IMAGE + + try: + client.images.get(_KERNEL_IMAGE) 
+ image_available = True + except _docker.errors.ImageNotFound: + image_available = False + except Exception: + image_available = False + + return DockerStatus(available=True, image_available=image_available) + + +@router.get("/{kernel_id}", response_model=KernelInfo) +async def get_kernel(kernel_id: str, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + return kernel + + +@router.delete("/{kernel_id}") +async def delete_kernel(kernel_id: str, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + await manager.delete_kernel(kernel_id) + return {"status": "deleted", "kernel_id": kernel_id} + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) + + +@router.post("/{kernel_id}/start", response_model=KernelInfo) +async def start_kernel(kernel_id: str, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await manager.start_kernel(kernel_id) + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + +@router.post("/{kernel_id}/stop") +async def 
stop_kernel(kernel_id: str, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + await manager.stop_kernel(kernel_id) + return {"status": "stopped", "kernel_id": kernel_id} + except KeyError as exc: + raise HTTPException(status_code=404, detail=str(exc)) + + +@router.post("/{kernel_id}/execute", response_model=ExecuteResult) +async def execute_code(kernel_id: str, request: ExecuteRequest, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await manager.execute(kernel_id, request) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.post("/{kernel_id}/execute_cell", response_model=ExecuteResult) +async def execute_cell(kernel_id: str, request: ExecuteRequest, current_user=Depends(get_current_active_user)): + """Execute a single notebook cell interactively. + + Same as /execute but sets interactive=True to enable auto-display of the last expression. 
+ """ + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + # Force interactive mode for cell execution + request.interactive = True + return await manager.execute(kernel_id, request) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/{kernel_id}/artifacts") +async def get_artifacts(kernel_id: str, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + if kernel.state.value not in ("idle", "executing"): + raise HTTPException(status_code=400, detail=f"Kernel '{kernel_id}' is not running") + + try: + import httpx + + url = f"http://localhost:{kernel.port}/artifacts" + async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + +@router.post("/{kernel_id}/clear") +async def clear_artifacts(kernel_id: str, current_user=Depends(get_current_active_user)): + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + await manager.clear_artifacts(kernel_id) + return {"status": "cleared", 
"kernel_id": kernel_id} + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.post("/{kernel_id}/clear_node_artifacts", response_model=ClearNodeArtifactsResult) +async def clear_node_artifacts( + kernel_id: str, + request: ClearNodeArtifactsRequest, + current_user=Depends(get_current_active_user), +): + """Clear only artifacts published by specific node IDs.""" + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await manager.clear_node_artifacts(kernel_id, request.node_ids, flow_id=request.flow_id) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/{kernel_id}/artifacts/node/{node_id}") +async def get_node_artifacts( + kernel_id: str, + node_id: int, + current_user=Depends(get_current_active_user), +): + """Get artifacts published by a specific node.""" + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + if kernel.state.value not in ("idle", "executing"): + raise HTTPException(status_code=400, detail=f"Kernel '{kernel_id}' is not running") + try: + return await manager.get_node_artifacts(kernel_id, node_id) + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + +# --------------------------------------------------------------------------- +# Artifact Persistence & Recovery endpoints +# --------------------------------------------------------------------------- + 
+@router.post("/{kernel_id}/recover", response_model=RecoveryStatus) +async def recover_artifacts(kernel_id: str, current_user=Depends(get_current_active_user)): + """Trigger manual artifact recovery from persisted storage.""" + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await manager.recover_artifacts(kernel_id) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/{kernel_id}/recovery-status", response_model=RecoveryStatus) +async def get_recovery_status(kernel_id: str, current_user=Depends(get_current_active_user)): + """Get the current artifact recovery status.""" + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await manager.get_recovery_status(kernel_id) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.post("/{kernel_id}/cleanup", response_model=CleanupResult) +async def cleanup_artifacts( + kernel_id: str, + request: CleanupRequest, + current_user=Depends(get_current_active_user), +): + """Clean up old persisted artifacts.""" + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await 
manager.cleanup_artifacts(kernel_id, request) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) + + +@router.get("/{kernel_id}/persistence", response_model=ArtifactPersistenceInfo) +async def get_persistence_info(kernel_id: str, current_user=Depends(get_current_active_user)): + """Get persistence configuration and stats for a kernel.""" + manager = _get_manager() + kernel = await manager.get_kernel(kernel_id) + if kernel is None: + raise HTTPException(status_code=404, detail=f"Kernel '{kernel_id}' not found") + if manager.get_kernel_owner(kernel_id) != current_user.id: + raise HTTPException(status_code=403, detail="Not authorized to access this kernel") + try: + return await manager.get_persistence_info(kernel_id) + except RuntimeError as exc: + raise HTTPException(status_code=400, detail=str(exc)) diff --git a/flowfile_core/flowfile_core/main.py b/flowfile_core/flowfile_core/main.py index 3f921a83b..1ae220e3f 100644 --- a/flowfile_core/flowfile_core/main.py +++ b/flowfile_core/flowfile_core/main.py @@ -15,6 +15,7 @@ WORKER_PORT, WORKER_URL, ) +from flowfile_core.kernel import router as kernel_router from flowfile_core.routes.auth import router as auth_router from flowfile_core.routes.catalog import router as catalog_router from flowfile_core.routes.cloud_connections import router as cloud_connections_router @@ -39,8 +40,8 @@ async def shutdown_handler(app: FastAPI): """Handles the graceful startup and shutdown of the FastAPI application. - This context manager ensures that resources, such as log files, are cleaned - up properly when the application is terminated. + This context manager ensures that resources, such as log files and kernel + containers, are cleaned up properly when the application is terminated. 
""" print("Starting core application...") try: @@ -48,10 +49,22 @@ async def shutdown_handler(app: FastAPI): finally: print("Shutting down core application...") print("Cleaning up core service resources...") + _shutdown_kernels() clear_all_flow_logs() await asyncio.sleep(0.1) # Give a moment for cleanup +def _shutdown_kernels(): + """Stop all running kernel containers during shutdown.""" + try: + from flowfile_core.kernel import get_kernel_manager + + manager = get_kernel_manager() + manager.shutdown_all() + except Exception as exc: + print(f"Error shutting down kernels: {exc}") + + # Initialize FastAPI with metadata app = FastAPI( title="Flowfile Backend", @@ -89,6 +102,7 @@ async def shutdown_handler(app: FastAPI): app.include_router(secrets_router, prefix="/secrets", tags=["secrets"]) app.include_router(cloud_connections_router, prefix="/cloud_connections", tags=["cloud_connections"]) app.include_router(user_defined_components_router, prefix="/user_defined_components", tags=["user_defined_components"]) +app.include_router(kernel_router, tags=["kernels"]) @app.post("/shutdown") diff --git a/flowfile_core/flowfile_core/routes/catalog.py b/flowfile_core/flowfile_core/routes/catalog.py index 8ceb87455..c6c5cd623 100644 --- a/flowfile_core/flowfile_core/routes/catalog.py +++ b/flowfile_core/flowfile_core/routes/catalog.py @@ -5,10 +5,12 @@ - Flow registration (persistent flow metadata) - Run history with versioned snapshots - Favorites and follows + +This module is a thin HTTP adapter: it delegates all business logic to +``CatalogService`` and translates domain exceptions into HTTP responses. 
""" import json -import os from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Query @@ -17,14 +19,20 @@ from flowfile_core import flow_file_handler from flowfile_core.auth.jwt import get_current_active_user -from flowfile_core.database.connection import get_db -from flowfile_core.database.models import ( - CatalogNamespace, - FlowFavorite, - FlowFollow, - FlowRegistration, - FlowRun, +from flowfile_core.catalog import ( + CatalogService, + FavoriteNotFoundError, + FlowNotFoundError, + FollowNotFoundError, + NamespaceExistsError, + NamespaceNotEmptyError, + NamespaceNotFoundError, + NestingLimitError, + NoSnapshotError, + RunNotFoundError, + SQLAlchemyCatalogRepository, ) +from flowfile_core.database.connection import get_db from flowfile_core.schemas.catalog_schema import ( CatalogStats, FavoriteOut, @@ -48,44 +56,14 @@ # --------------------------------------------------------------------------- -# Helpers +# Dependency injection # --------------------------------------------------------------------------- -def _enrich_flow( - flow: FlowRegistration, - db: Session, - user_id: int, -) -> FlowRegistrationOut: - """Attach favourite/follow flags and run stats to a FlowRegistration row.""" - is_fav = db.query(FlowFavorite).filter_by( - user_id=user_id, registration_id=flow.id - ).first() is not None - is_follow = db.query(FlowFollow).filter_by( - user_id=user_id, registration_id=flow.id - ).first() is not None - run_count = db.query(FlowRun).filter_by(registration_id=flow.id).count() - last_run = ( - db.query(FlowRun) - .filter_by(registration_id=flow.id) - .order_by(FlowRun.started_at.desc()) - .first() - ) - return FlowRegistrationOut( - id=flow.id, - name=flow.name, - description=flow.description, - flow_path=flow.flow_path, - namespace_id=flow.namespace_id, - owner_id=flow.owner_id, - created_at=flow.created_at, - updated_at=flow.updated_at, - is_favorite=is_fav, - is_following=is_follow, - run_count=run_count, - 
last_run_at=last_run.started_at if last_run else None, - last_run_success=last_run.success if last_run else None, - file_exists=os.path.exists(flow.flow_path) if flow.flow_path else False, - ) + +def get_catalog_service(db: Session = Depends(get_db)) -> CatalogService: + """FastAPI dependency that provides a configured ``CatalogService``.""" + repo = SQLAlchemyCatalogRepository(db) + return CatalogService(repo) # --------------------------------------------------------------------------- @@ -96,155 +74,70 @@ def _enrich_flow( @router.get("/namespaces", response_model=list[NamespaceOut]) def list_namespaces( parent_id: int | None = None, - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): """List namespaces, optionally filtered by parent.""" - q = db.query(CatalogNamespace) - if parent_id is not None: - q = q.filter(CatalogNamespace.parent_id == parent_id) - else: - q = q.filter(CatalogNamespace.parent_id.is_(None)) - return q.order_by(CatalogNamespace.name).all() + return service.list_namespaces(parent_id) @router.post("/namespaces", response_model=NamespaceOut, status_code=201) def create_namespace( body: NamespaceCreate, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): """Create a catalog (level 0) or schema (level 1) namespace.""" - level = 0 - if body.parent_id is not None: - parent = db.get(CatalogNamespace, body.parent_id) - if parent is None: - raise HTTPException(404, "Parent namespace not found") - if parent.level >= 1: - raise HTTPException(422, "Cannot nest deeper than catalog -> schema") - level = parent.level + 1 - - existing = ( - db.query(CatalogNamespace) - .filter_by(name=body.name, parent_id=body.parent_id) - .first() - ) - if existing: + try: + return service.create_namespace( + name=body.name, + owner_id=current_user.id, + parent_id=body.parent_id, + description=body.description, + ) + except 
NamespaceNotFoundError: + raise HTTPException(404, "Parent namespace not found") + except NamespaceExistsError: raise HTTPException(409, "Namespace with this name already exists at this level") - - ns = CatalogNamespace( - name=body.name, - parent_id=body.parent_id, - level=level, - description=body.description, - owner_id=current_user.id, - ) - db.add(ns) - db.commit() - db.refresh(ns) - return ns + except NestingLimitError: + raise HTTPException(422, "Cannot nest deeper than catalog -> schema") @router.put("/namespaces/{namespace_id}", response_model=NamespaceOut) def update_namespace( namespace_id: int, body: NamespaceUpdate, - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - ns = db.get(CatalogNamespace, namespace_id) - if ns is None: + try: + return service.update_namespace( + namespace_id=namespace_id, + name=body.name, + description=body.description, + ) + except NamespaceNotFoundError: raise HTTPException(404, "Namespace not found") - if body.name is not None: - ns.name = body.name - if body.description is not None: - ns.description = body.description - db.commit() - db.refresh(ns) - return ns @router.delete("/namespaces/{namespace_id}", status_code=204) def delete_namespace( namespace_id: int, - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - ns = db.get(CatalogNamespace, namespace_id) - if ns is None: + try: + service.delete_namespace(namespace_id) + except NamespaceNotFoundError: raise HTTPException(404, "Namespace not found") - # Prevent deletion if children or flows exist - children = db.query(CatalogNamespace).filter_by(parent_id=namespace_id).count() - flows = db.query(FlowRegistration).filter_by(namespace_id=namespace_id).count() - if children > 0 or flows > 0: + except NamespaceNotEmptyError: raise HTTPException(422, "Cannot delete namespace with children or flows") - db.delete(ns) - db.commit() @router.get("/namespaces/tree", 
response_model=list[NamespaceTree]) def get_namespace_tree( current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): """Return the full catalog tree with flows nested under schemas.""" - catalogs = ( - db.query(CatalogNamespace) - .filter(CatalogNamespace.parent_id.is_(None)) - .order_by(CatalogNamespace.name) - .all() - ) - result = [] - for cat in catalogs: - schemas_db = ( - db.query(CatalogNamespace) - .filter_by(parent_id=cat.id) - .order_by(CatalogNamespace.name) - .all() - ) - children = [] - for schema in schemas_db: - flows_db = ( - db.query(FlowRegistration) - .filter_by(namespace_id=schema.id) - .order_by(FlowRegistration.name) - .all() - ) - flow_outs = [_enrich_flow(f, db, current_user.id) for f in flows_db] - children.append( - NamespaceTree( - id=schema.id, - name=schema.name, - parent_id=schema.parent_id, - level=schema.level, - description=schema.description, - owner_id=schema.owner_id, - created_at=schema.created_at, - updated_at=schema.updated_at, - children=[], - flows=flow_outs, - ) - ) - # Also include flows directly under catalog (unschema'd) - root_flows_db = ( - db.query(FlowRegistration) - .filter_by(namespace_id=cat.id) - .order_by(FlowRegistration.name) - .all() - ) - root_flows = [_enrich_flow(f, db, current_user.id) for f in root_flows_db] - result.append( - NamespaceTree( - id=cat.id, - name=cat.name, - parent_id=cat.parent_id, - level=cat.level, - description=cat.description, - owner_id=cat.owner_id, - created_at=cat.created_at, - updated_at=cat.updated_at, - children=children, - flows=root_flows, - ) - ) - return result + return service.get_namespace_tree(user_id=current_user.id) # --------------------------------------------------------------------------- @@ -254,18 +147,10 @@ def get_namespace_tree( @router.get("/default-namespace-id") def get_default_namespace_id( - db: Session = Depends(get_db), + service: CatalogService = 
Depends(get_catalog_service), ): """Return the ID of the default 'user_flows' schema under 'General'.""" - general = db.query(CatalogNamespace).filter_by(name="General", parent_id=None).first() - if general is None: - return None - user_flows = db.query(CatalogNamespace).filter_by( - name="user_flows", parent_id=general.id - ).first() - if user_flows is None: - return None - return user_flows.id + return service.get_default_namespace_id() # --------------------------------------------------------------------------- @@ -277,48 +162,39 @@ def get_default_namespace_id( def list_flows( namespace_id: int | None = None, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - q = db.query(FlowRegistration) - if namespace_id is not None: - q = q.filter_by(namespace_id=namespace_id) - flows = q.order_by(FlowRegistration.name).all() - return [_enrich_flow(f, db, current_user.id) for f in flows] + return service.list_flows(user_id=current_user.id, namespace_id=namespace_id) @router.post("/flows", response_model=FlowRegistrationOut, status_code=201) def register_flow( body: FlowRegistrationCreate, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - if body.namespace_id is not None: - ns = db.get(CatalogNamespace, body.namespace_id) - if ns is None: - raise HTTPException(404, "Namespace not found") - flow = FlowRegistration( - name=body.name, - description=body.description, - flow_path=body.flow_path, - namespace_id=body.namespace_id, - owner_id=current_user.id, - ) - db.add(flow) - db.commit() - db.refresh(flow) - return _enrich_flow(flow, db, current_user.id) + try: + return service.register_flow( + name=body.name, + flow_path=body.flow_path, + owner_id=current_user.id, + namespace_id=body.namespace_id, + description=body.description, + ) + except NamespaceNotFoundError: + raise HTTPException(404, 
"Namespace not found") @router.get("/flows/{flow_id}", response_model=FlowRegistrationOut) def get_flow( flow_id: int, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - flow = db.get(FlowRegistration, flow_id) - if flow is None: + try: + return service.get_flow(registration_id=flow_id, user_id=current_user.id) + except FlowNotFoundError: raise HTTPException(404, "Flow not found") - return _enrich_flow(flow, db, current_user.id) @router.put("/flows/{flow_id}", response_model=FlowRegistrationOut) @@ -326,35 +202,29 @@ def update_flow( flow_id: int, body: FlowRegistrationUpdate, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - flow = db.get(FlowRegistration, flow_id) - if flow is None: + try: + return service.update_flow( + registration_id=flow_id, + requesting_user_id=current_user.id, + name=body.name, + description=body.description, + namespace_id=body.namespace_id, + ) + except FlowNotFoundError: raise HTTPException(404, "Flow not found") - if body.name is not None: - flow.name = body.name - if body.description is not None: - flow.description = body.description - if body.namespace_id is not None: - flow.namespace_id = body.namespace_id - db.commit() - db.refresh(flow) - return _enrich_flow(flow, db, current_user.id) @router.delete("/flows/{flow_id}", status_code=204) def delete_flow( flow_id: int, - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - flow = db.get(FlowRegistration, flow_id) - if flow is None: + try: + service.delete_flow(registration_id=flow_id) + except FlowNotFoundError: raise HTTPException(404, "Flow not found") - # Clean up related records - db.query(FlowFavorite).filter_by(registration_id=flow_id).delete() - db.query(FlowFollow).filter_by(registration_id=flow_id).delete() - db.delete(flow) - db.commit() # 
--------------------------------------------------------------------------- @@ -367,63 +237,23 @@ def list_runs( registration_id: int | None = None, limit: int = Query(50, ge=1, le=500), offset: int = Query(0, ge=0), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - q = db.query(FlowRun) - if registration_id is not None: - q = q.filter_by(registration_id=registration_id) - runs = ( - q.order_by(FlowRun.started_at.desc()) - .offset(offset) - .limit(limit) - .all() + return service.list_runs( + registration_id=registration_id, limit=limit, offset=offset ) - return [ - FlowRunOut( - id=r.id, - registration_id=r.registration_id, - flow_name=r.flow_name, - flow_path=r.flow_path, - user_id=r.user_id, - started_at=r.started_at, - ended_at=r.ended_at, - success=r.success, - nodes_completed=r.nodes_completed, - number_of_nodes=r.number_of_nodes, - duration_seconds=r.duration_seconds, - run_type=r.run_type, - has_snapshot=r.flow_snapshot is not None, - ) - for r in runs - ] @router.get("/runs/{run_id}", response_model=FlowRunDetail) def get_run_detail( run_id: int, - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): """Get a single run including the YAML snapshot of the flow version that ran.""" - run = db.get(FlowRun, run_id) - if run is None: + try: + return service.get_run_detail(run_id) + except RunNotFoundError: raise HTTPException(404, "Run not found") - return FlowRunDetail( - id=run.id, - registration_id=run.registration_id, - flow_name=run.flow_name, - flow_path=run.flow_path, - user_id=run.user_id, - started_at=run.started_at, - ended_at=run.ended_at, - success=run.success, - nodes_completed=run.nodes_completed, - number_of_nodes=run.number_of_nodes, - duration_seconds=run.duration_seconds, - run_type=run.run_type, - has_snapshot=run.flow_snapshot is not None, - flow_snapshot=run.flow_snapshot, - node_results_json=run.node_results_json, - ) # 
--------------------------------------------------------------------------- @@ -435,17 +265,17 @@ def get_run_detail( def open_run_snapshot( run_id: int, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): """Write the run's flow snapshot to a temp file and import it into the designer.""" - run = db.get(FlowRun, run_id) - if run is None: + try: + snapshot_data = service.get_run_snapshot(run_id) + except RunNotFoundError: raise HTTPException(404, "Run not found") - if not run.flow_snapshot: + except NoSnapshotError: raise HTTPException(422, "No flow snapshot available for this run") # Determine file extension based on content - snapshot_data = run.flow_snapshot try: json.loads(snapshot_data) suffix = ".json" @@ -473,56 +303,35 @@ def open_run_snapshot( @router.get("/favorites", response_model=list[FlowRegistrationOut]) def list_favorites( current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - favs = ( - db.query(FlowFavorite) - .filter_by(user_id=current_user.id) - .order_by(FlowFavorite.created_at.desc()) - .all() - ) - result = [] - for fav in favs: - flow = db.get(FlowRegistration, fav.registration_id) - if flow: - result.append(_enrich_flow(flow, db, current_user.id)) - return result + return service.list_favorites(user_id=current_user.id) @router.post("/flows/{flow_id}/favorite", response_model=FavoriteOut, status_code=201) def add_favorite( flow_id: int, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - flow = db.get(FlowRegistration, flow_id) - if flow is None: + try: + return service.add_favorite( + user_id=current_user.id, registration_id=flow_id + ) + except FlowNotFoundError: raise HTTPException(404, "Flow not found") - existing = db.query(FlowFavorite).filter_by( - user_id=current_user.id, 
registration_id=flow_id - ).first() - if existing: - return existing - fav = FlowFavorite(user_id=current_user.id, registration_id=flow_id) - db.add(fav) - db.commit() - db.refresh(fav) - return fav @router.delete("/flows/{flow_id}/favorite", status_code=204) def remove_favorite( flow_id: int, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - fav = db.query(FlowFavorite).filter_by( - user_id=current_user.id, registration_id=flow_id - ).first() - if fav is None: + try: + service.remove_favorite(user_id=current_user.id, registration_id=flow_id) + except FavoriteNotFoundError: raise HTTPException(404, "Favorite not found") - db.delete(fav) - db.commit() # --------------------------------------------------------------------------- @@ -533,56 +342,35 @@ def remove_favorite( @router.get("/following", response_model=list[FlowRegistrationOut]) def list_following( current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - follows = ( - db.query(FlowFollow) - .filter_by(user_id=current_user.id) - .order_by(FlowFollow.created_at.desc()) - .all() - ) - result = [] - for follow in follows: - flow = db.get(FlowRegistration, follow.registration_id) - if flow: - result.append(_enrich_flow(flow, db, current_user.id)) - return result + return service.list_following(user_id=current_user.id) @router.post("/flows/{flow_id}/follow", response_model=FollowOut, status_code=201) def add_follow( flow_id: int, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - flow = db.get(FlowRegistration, flow_id) - if flow is None: + try: + return service.add_follow( + user_id=current_user.id, registration_id=flow_id + ) + except FlowNotFoundError: raise HTTPException(404, "Flow not found") - existing = db.query(FlowFollow).filter_by( - 
user_id=current_user.id, registration_id=flow_id - ).first() - if existing: - return existing - follow = FlowFollow(user_id=current_user.id, registration_id=flow_id) - db.add(follow) - db.commit() - db.refresh(follow) - return follow @router.delete("/flows/{flow_id}/follow", status_code=204) def remove_follow( flow_id: int, current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - follow = db.query(FlowFollow).filter_by( - user_id=current_user.id, registration_id=flow_id - ).first() - if follow is None: + try: + service.remove_follow(user_id=current_user.id, registration_id=flow_id) + except FollowNotFoundError: raise HTTPException(404, "Follow not found") - db.delete(follow) - db.commit() # --------------------------------------------------------------------------- @@ -593,50 +381,6 @@ def remove_follow( @router.get("/stats", response_model=CatalogStats) def get_catalog_stats( current_user=Depends(get_current_active_user), - db: Session = Depends(get_db), + service: CatalogService = Depends(get_catalog_service), ): - total_ns = db.query(CatalogNamespace).filter_by(level=0).count() - total_flows = db.query(FlowRegistration).count() - total_runs = db.query(FlowRun).count() - total_favs = db.query(FlowFavorite).filter_by(user_id=current_user.id).count() - recent = ( - db.query(FlowRun) - .order_by(FlowRun.started_at.desc()) - .limit(10) - .all() - ) - recent_out = [ - FlowRunOut( - id=r.id, - registration_id=r.registration_id, - flow_name=r.flow_name, - flow_path=r.flow_path, - user_id=r.user_id, - started_at=r.started_at, - ended_at=r.ended_at, - success=r.success, - nodes_completed=r.nodes_completed, - number_of_nodes=r.number_of_nodes, - duration_seconds=r.duration_seconds, - run_type=r.run_type, - has_snapshot=r.flow_snapshot is not None, - ) - for r in recent - ] - fav_ids = [ - f.registration_id - for f in db.query(FlowFavorite).filter_by(user_id=current_user.id).all() - ] - 
fav_flows = [] - for fid in fav_ids: - flow = db.get(FlowRegistration, fid) - if flow: - fav_flows.append(_enrich_flow(flow, db, current_user.id)) - return CatalogStats( - total_namespaces=total_ns, - total_flows=total_flows, - total_runs=total_runs, - total_favorites=total_favs, - recent_runs=recent_out, - favorite_flows=fav_flows, - ) + return service.get_catalog_stats(user_id=current_user.id) diff --git a/flowfile_core/flowfile_core/routes/logs.py b/flowfile_core/flowfile_core/routes/logs.py index 0d8a5de6a..2b3d78fed 100644 --- a/flowfile_core/flowfile_core/routes/logs.py +++ b/flowfile_core/flowfile_core/routes/logs.py @@ -45,17 +45,17 @@ async def add_log(flow_id: int, log_message: str): @router.post("/raw_logs", tags=["flow_logging"]) async def add_raw_log(raw_log_input: schemas.RawLogInput): """Adds a log message to the log file for a given flow_id.""" - logger.info("Adding raw logs") flow = flow_file_handler.get_flow(raw_log_input.flowfile_flow_id) if not flow: raise HTTPException(status_code=404, detail="Flow not found") - flow.flow_logger.get_log_filepath() flow_logger = flow.flow_logger - flow_logger.get_log_filepath() + node_id = raw_log_input.node_id if raw_log_input.node_id is not None else -1 if raw_log_input.log_type == "INFO": - flow_logger.info(raw_log_input.log_message, extra=raw_log_input.extra) + flow_logger.info(raw_log_input.log_message, extra=raw_log_input.extra, node_id=node_id) + elif raw_log_input.log_type == "WARNING": + flow_logger.warning(raw_log_input.log_message, extra=raw_log_input.extra, node_id=node_id) elif raw_log_input.log_type == "ERROR": - flow_logger.error(raw_log_input.log_message, extra=raw_log_input.extra) + flow_logger.error(raw_log_input.log_message, extra=raw_log_input.extra, node_id=node_id) return {"message": "Log added successfully"} diff --git a/flowfile_core/flowfile_core/routes/routes.py b/flowfile_core/flowfile_core/routes/routes.py index 0b9ace110..7c6c8145a 100644 --- 
a/flowfile_core/flowfile_core/routes/routes.py +++ b/flowfile_core/flowfile_core/routes/routes.py @@ -69,22 +69,28 @@ def get_node_model(setting_name_ref: str): def _auto_register_flow(flow_path: str, name: str, user_id: int | None) -> None: - """Register a flow in the default catalog namespace (General > user_flows) if it exists.""" + """Register a flow in the default catalog namespace (General > user_flows) if it exists. + + Failures are logged at info level since users may wonder why some flows + don't appear in the catalog. + """ if user_id is None or flow_path is None: return try: with get_db_context() as db: general = db.query(CatalogNamespace).filter_by(name="General", parent_id=None).first() if general is None: + logger.info("Auto-registration skipped: 'General' catalog namespace not found") return user_flows = db.query(CatalogNamespace).filter_by( name="user_flows", parent_id=general.id ).first() if user_flows is None: + logger.info("Auto-registration skipped: 'user_flows' schema not found under 'General'") return existing = db.query(FlowRegistration).filter_by(flow_path=flow_path).first() if existing: - return + return # Already registered, silent success reg = FlowRegistration( name=name or Path(flow_path).stem, flow_path=flow_path, @@ -93,8 +99,9 @@ def _auto_register_flow(flow_path: str, name: str, user_id: int | None) -> None: ) db.add(reg) db.commit() + logger.info(f"Auto-registered flow '{reg.name}' in default namespace") except Exception: - logger.debug("Auto-registration in default namespace failed (non-critical)", exc_info=True) + logger.info(f"Auto-registration failed for '{flow_path}' (non-critical)", exc_info=True) @router.post("/upload/") @@ -238,27 +245,38 @@ async def trigger_fetch_node_data(flow_id: int, node_id: int, background_tasks: def _run_and_track(flow, user_id: int | None): - """Wrapper that runs a flow and persists the run record to the database.""" + """Wrapper that runs a flow and persists the run record to the database. 
+ + This runs in a BackgroundTask. If DB persistence fails, the run still + completed but won't appear in the run history. Failures are logged at + ERROR level so they're visible in logs. + """ + flow_name = getattr(flow.flow_settings, "name", None) or getattr(flow, "__name__", "unknown") + run_info = flow.run_graph() if run_info is None: + logger.error(f"Flow '{flow_name}' returned no run_info - run tracking skipped") return # Persist run record + tracking_succeeded = False try: - # Build snapshot + # Build snapshot (non-critical if fails) + snapshot_yaml = None try: snapshot_data = flow.get_flowfile_data() snapshot_yaml = snapshot_data.model_dump_json() - except Exception: - snapshot_yaml = None + except Exception as snap_err: + logger.warning(f"Flow '{flow_name}': snapshot serialization failed: {snap_err}") - # Serialise node results + # Serialise node results (non-critical if fails) + node_results = None try: node_results = json.dumps( [nr.model_dump(mode="json") for nr in (run_info.node_step_result or [])], ) - except Exception: - node_results = None + except Exception as node_err: + logger.warning(f"Flow '{flow_name}': node results serialization failed: {node_err}") duration = None if run_info.start_time and run_info.end_time: @@ -275,7 +293,7 @@ def _run_and_track(flow, user_id: int | None): db_run = FlowRun( registration_id=reg_id, - flow_name=flow.flow_settings.name or flow.__name__, + flow_name=flow_name, flow_path=flow_path, user_id=user_id if user_id is not None else 0, started_at=run_info.start_time, @@ -290,8 +308,25 @@ def _run_and_track(flow, user_id: int | None): ) db.add(db_run) db.commit() + tracking_succeeded = True + logger.info( + f"Flow '{flow_name}' run tracked: success={run_info.success}, " + f"nodes={run_info.nodes_completed}/{run_info.number_of_nodes}, " + f"duration={duration:.2f}s" if duration else f"duration=N/A" + ) except Exception as exc: - logger.warning(f"Failed to persist flow run record: {exc}") + logger.error( + f"Failed to 
persist run record for flow '{flow_name}'. " + f"The flow {'succeeded' if run_info.success else 'failed'} but won't appear in run history. " + f"Error: {exc}", + exc_info=True, + ) + + if not tracking_succeeded: + logger.error( + f"Run tracking failed for flow '{flow_name}'. " + "Check database connectivity and FlowRun table schema." + ) @router.post('/flow/run/', tags=['editor']) @@ -1007,6 +1042,24 @@ def get_vue_flow_data(flow_id: int) -> schemas.VueFlowInput: return data +@router.get('/flow/artifacts', tags=['editor']) +def get_flow_artifacts(flow_id: int): + """Returns artifact visualization data for the canvas. + + Includes per-node artifact summaries (for badges/tooltips) and + artifact edges (for dashed-line connections between publisher and + consumer nodes). + """ + flow = flow_file_handler.get_flow(flow_id) + if flow is None: + raise HTTPException(404, 'Could not find the flow') + ctx = flow.artifact_context + return { + "nodes": ctx.get_node_summaries(), + "edges": ctx.get_artifact_edges(), + } + + @router.get('/analysis_data/graphic_walker_input', tags=['analysis'], response_model=input_schema.NodeExploreData) def get_graphic_walker_input(flow_id: int, node_id: int): """Gets the data and configuration for the Graphic Walker data exploration tool.""" diff --git a/flowfile_core/flowfile_core/schemas/input_schema.py b/flowfile_core/flowfile_core/schemas/input_schema.py index b1fecc288..46fbae118 100644 --- a/flowfile_core/flowfile_core/schemas/input_schema.py +++ b/flowfile_core/flowfile_core/schemas/input_schema.py @@ -887,6 +887,19 @@ class NodePolarsCode(NodeMultiInput): polars_code_input: transform_schema.PolarsCodeInput +class PythonScriptInput(BaseModel): + """Settings for Python code execution on a kernel.""" + + code: str = "" + kernel_id: str | None = None + + +class NodePythonScript(NodeMultiInput): + """Node that executes Python code on a kernel container.""" + + python_script_input: PythonScriptInput = PythonScriptInput() + + class 
UserDefinedNode(NodeMultiInput): """Settings for a node that contains the user defined node information""" diff --git a/flowfile_core/flowfile_core/schemas/schemas.py b/flowfile_core/flowfile_core/schemas/schemas.py index 9458e14a7..ed3dcb256 100644 --- a/flowfile_core/flowfile_core/schemas/schemas.py +++ b/flowfile_core/flowfile_core/schemas/schemas.py @@ -28,6 +28,7 @@ "unpivot": input_schema.NodeUnpivot, "text_to_rows": input_schema.NodeTextToRows, "graph_solver": input_schema.NodeGraphSolver, + "python_script": input_schema.NodePythonScript, "polars_code": input_schema.NodePolarsCode, "join": input_schema.NodeJoin, "cross_join": input_schema.NodeCrossJoin, @@ -174,13 +175,15 @@ class RawLogInput(BaseModel): Attributes: flowfile_flow_id (int): The ID of the flow that generated the log. log_message (str): The content of the log message. - log_type (Literal["INFO", "ERROR"]): The type of log. + log_type (Literal["INFO", "WARNING", "ERROR"]): The type of log. + node_id (int | None): Optional node ID to attribute the log to. extra (Optional[dict]): Extra context data for the log. 
""" flowfile_flow_id: int log_message: str - log_type: Literal["INFO", "ERROR"] + log_type: Literal["INFO", "WARNING", "ERROR"] + node_id: int | None = None extra: dict | None = None diff --git a/flowfile_core/tests/conftest.py b/flowfile_core/tests/conftest.py index 99b1754e0..d4f993f03 100644 --- a/flowfile_core/tests/conftest.py +++ b/flowfile_core/tests/conftest.py @@ -28,6 +28,7 @@ def _patched_hashpw(password, salt): from test_utils.postgres import fixtures as pg_fixtures from tests.flowfile_core_test_utils import is_docker_available +from tests.kernel_fixtures import managed_kernel def is_port_in_use(port, host='localhost'): @@ -263,3 +264,21 @@ def postgres_db(): if not db_info: pytest.fail("PostgreSQL container could not be started") yield db_info + + +@pytest.fixture(scope="session") +def kernel_manager(): + """ + Pytest fixture that builds the flowfile-kernel Docker image, creates a + KernelManager, starts a test kernel, and tears everything down afterwards. + + Yields a (KernelManager, kernel_id) tuple. 
+ """ + if not is_docker_available(): + pytest.skip("Docker is not available, skipping kernel tests") + + try: + with managed_kernel() as ctx: + yield ctx + except Exception as exc: + pytest.skip(f"Kernel container could not be started: {exc}") diff --git a/flowfile_core/tests/flowfile/test_artifact_context.py b/flowfile_core/tests/flowfile/test_artifact_context.py new file mode 100644 index 000000000..f5193637f --- /dev/null +++ b/flowfile_core/tests/flowfile/test_artifact_context.py @@ -0,0 +1,538 @@ +"""Unit tests for flowfile_core.flowfile.artifacts.""" + +from datetime import datetime + +import pytest + +from flowfile_core.flowfile.artifacts import ArtifactContext, ArtifactRef, NodeArtifactState + + +# --------------------------------------------------------------------------- +# ArtifactRef +# --------------------------------------------------------------------------- + + +class TestArtifactRef: + def test_create_ref(self): + ref = ArtifactRef(name="model", source_node_id=1, kernel_id="k1") + assert ref.name == "model" + assert ref.source_node_id == 1 + assert ref.kernel_id == "k1" + assert isinstance(ref.created_at, datetime) + + def test_refs_are_hashable(self): + """Frozen dataclass instances can be used in sets / as dict keys.""" + ref = ArtifactRef(name="model", source_node_id=1) + assert hash(ref) is not None + s = {ref} + assert ref in s + + def test_refs_equality(self): + ts = datetime(2025, 1, 1) + a = ArtifactRef(name="x", source_node_id=1, created_at=ts) + b = ArtifactRef(name="x", source_node_id=1, created_at=ts) + assert a == b + + def test_to_dict(self): + ref = ArtifactRef( + name="model", + source_node_id=1, + kernel_id="k1", + type_name="RandomForest", + module="sklearn.ensemble", + size_bytes=1024, + ) + d = ref.to_dict() + assert d["name"] == "model" + assert d["source_node_id"] == 1 + assert d["kernel_id"] == "k1" + assert d["type_name"] == "RandomForest" + assert d["module"] == "sklearn.ensemble" + assert d["size_bytes"] == 1024 + assert 
"created_at" in d + + +# --------------------------------------------------------------------------- +# NodeArtifactState +# --------------------------------------------------------------------------- + + +class TestNodeArtifactState: + def test_defaults(self): + state = NodeArtifactState() + assert state.published == [] + assert state.available == {} + assert state.consumed == [] + + def test_to_dict(self): + ref = ArtifactRef(name="m", source_node_id=1, kernel_id="k") + state = NodeArtifactState(published=[ref], available={"m": ref}, consumed=["m"]) + d = state.to_dict() + assert len(d["published"]) == 1 + assert "m" in d["available"] + assert d["consumed"] == ["m"] + + +# --------------------------------------------------------------------------- +# ArtifactContext — Recording +# --------------------------------------------------------------------------- + + +class TestArtifactContextRecording: + def test_record_published_with_dict(self): + ctx = ArtifactContext() + refs = ctx.record_published( + node_id=1, + kernel_id="k1", + artifacts=[{"name": "model", "type_name": "RF"}], + ) + assert len(refs) == 1 + assert refs[0].name == "model" + assert refs[0].type_name == "RF" + assert refs[0].source_node_id == 1 + assert refs[0].kernel_id == "k1" + + def test_record_published_with_string_list(self): + ctx = ArtifactContext() + refs = ctx.record_published(node_id=2, kernel_id="k1", artifacts=["a", "b"]) + assert len(refs) == 2 + assert refs[0].name == "a" + assert refs[1].name == "b" + + def test_record_published_multiple_nodes(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k1", ["encoder"]) + assert len(ctx.get_published_by_node(1)) == 1 + assert len(ctx.get_published_by_node(2)) == 1 + + def test_record_published_updates_kernel_artifacts(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ka = ctx.get_kernel_artifacts("k1") + assert "model" in ka + assert ka["model"].source_node_id == 1 
+ + def test_record_published_overwrites_same_name_same_node(self): + """Publishing the same artifact name from the same node should overwrite, + not create duplicates. This ensures (node_id, artifact_name) uniqueness.""" + ctx = ArtifactContext() + # First publish + refs1 = ctx.record_published( + node_id=1, + kernel_id="k1", + artifacts=[{"name": "model", "type_name": "RF"}], + ) + assert len(ctx.get_published_by_node(1)) == 1 + assert ctx.get_published_by_node(1)[0].type_name == "RF" + + # Second publish of same artifact name from same node - should overwrite + refs2 = ctx.record_published( + node_id=1, + kernel_id="k1", + artifacts=[{"name": "model", "type_name": "XGBoost"}], + ) + # Should still only have 1 artifact, not 2 + assert len(ctx.get_published_by_node(1)) == 1 + assert ctx.get_published_by_node(1)[0].type_name == "XGBoost" + # Kernel artifacts should also be updated + ka = ctx.get_kernel_artifacts("k1") + assert ka["model"].type_name == "XGBoost" + + def test_record_published_allows_same_name_different_nodes(self): + """Different nodes can publish artifacts with the same name.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", [{"name": "model", "type_name": "RF"}]) + ctx.record_published(2, "k1", [{"name": "model", "type_name": "XGBoost"}]) + # Both nodes should have their own published list + assert len(ctx.get_published_by_node(1)) == 1 + assert len(ctx.get_published_by_node(2)) == 1 + # Kernel artifacts should have the latest (from node 2) + ka = ctx.get_kernel_artifacts("k1") + assert ka["model"].source_node_id == 2 + + def test_record_published_allows_same_name_different_kernels(self): + """Same node can publish same artifact name to different kernels.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", [{"name": "model", "type_name": "RF"}]) + ctx.record_published(1, "k2", [{"name": "model", "type_name": "XGBoost"}]) + # Node 1 should have 2 published artifacts (one per kernel) + assert len(ctx.get_published_by_node(1)) == 2 
+ # Each kernel should have its own version + assert ctx.get_kernel_artifacts("k1")["model"].type_name == "RF" + assert ctx.get_kernel_artifacts("k2")["model"].type_name == "XGBoost" + + def test_record_consumed(self): + ctx = ArtifactContext() + ctx.record_consumed(5, ["model", "scaler"]) + state = ctx._node_states[5] + assert state.consumed == ["model", "scaler"] + + +# --------------------------------------------------------------------------- +# ArtifactContext — Availability +# --------------------------------------------------------------------------- + + +class TestArtifactContextAvailability: + def test_compute_available_from_direct_upstream(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + avail = ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + assert "model" in avail + assert avail["model"].source_node_id == 1 + + def test_compute_available_transitive(self): + """Node 3 should see artifacts from node 1 via node 2.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + # Node 2 doesn't publish anything + # Node 3 lists both 1 and 2 as upstream + avail = ctx.compute_available(node_id=3, kernel_id="k1", upstream_node_ids=[1, 2]) + assert "model" in avail + + def test_compute_available_different_kernels_isolated(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + avail = ctx.compute_available(node_id=2, kernel_id="k2", upstream_node_ids=[1]) + assert avail == {} + + def test_compute_available_same_kernel_visible(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + avail = ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + assert "model" in avail + + def test_compute_available_stores_on_node_state(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + assert "model" in ctx.get_available_for_node(2) + + def 
test_compute_available_no_upstream_returns_empty(self): + ctx = ArtifactContext() + avail = ctx.compute_available(node_id=1, kernel_id="k1", upstream_node_ids=[]) + assert avail == {} + + def test_compute_available_multiple_artifacts(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model", "scaler"]) + ctx.record_published(2, "k1", ["encoder"]) + avail = ctx.compute_available(node_id=3, kernel_id="k1", upstream_node_ids=[1, 2]) + assert set(avail.keys()) == {"model", "scaler", "encoder"} + + def test_compute_available_overwrites_previous(self): + """Re-computing availability replaces old data.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + # Re-compute with no upstream + ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[]) + assert ctx.get_available_for_node(2) == {} + + +# --------------------------------------------------------------------------- +# ArtifactContext — Deletion tracking +# --------------------------------------------------------------------------- + + +class TestArtifactContextDeletion: + def test_record_deleted_removes_from_kernel_index(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + assert ctx.get_kernel_artifacts("k1") == {} + + def test_record_deleted_preserves_publisher_published_list(self): + """Deletion does NOT remove from publisher's published list (historical record).""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model", "scaler"]) + ctx.record_deleted(2, "k1", ["model"]) + # Publisher's published list is preserved as historical record + published = ctx.get_published_by_node(1) + names = [r.name for r in published] + assert "model" in names # Still there as historical record + assert "scaler" in names + # The deleting node has it tracked in its deleted list + state = ctx._node_states[2] + assert "model" in 
state.deleted + + def test_record_deleted_tracks_on_node_state(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + state = ctx._node_states[2] + assert "model" in state.deleted + + def test_deleted_artifact_not_available_downstream(self): + """If node 2 deletes an artifact published by node 1, + node 3 should not see it as available.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + avail = ctx.compute_available(node_id=3, kernel_id="k1", upstream_node_ids=[1, 2]) + assert "model" not in avail + + def test_delete_and_republish_flow(self): + """Node 1 publishes, node 2 deletes, node 3 re-publishes, + node 4 should see the new version.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + ctx.record_published(3, "k1", ["model"]) + avail = ctx.compute_available(node_id=4, kernel_id="k1", upstream_node_ids=[1, 2, 3]) + assert "model" in avail + assert avail["model"].source_node_id == 3 + + +# --------------------------------------------------------------------------- +# ArtifactContext — Clearing +# --------------------------------------------------------------------------- + + +class TestArtifactContextClearing: + def test_clear_kernel_removes_only_that_kernel(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k2", ["encoder"]) + ctx.clear_kernel("k1") + assert ctx.get_kernel_artifacts("k1") == {} + assert "encoder" in ctx.get_kernel_artifacts("k2") + + def test_clear_kernel_preserves_published_lists(self): + """clear_kernel removes from kernel index but preserves published (historical record).""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(1, "k2", ["encoder"]) + ctx.clear_kernel("k1") + # Published list is preserved as historical record + published = ctx.get_published_by_node(1) 
+ names = [r.name for r in published] + assert "model" in names # Still there as historical record + assert "encoder" in names + # But the kernel index is cleared + assert ctx.get_kernel_artifacts("k1") == {} + + def test_clear_kernel_removes_from_available(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + ctx.clear_kernel("k1") + assert ctx.get_available_for_node(2) == {} + + def test_clear_all_removes_everything(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k2", ["encoder"]) + ctx.compute_available(node_id=3, kernel_id="k1", upstream_node_ids=[1]) + ctx.clear_all() + assert ctx.get_published_by_node(1) == [] + assert ctx.get_published_by_node(2) == [] + assert ctx.get_available_for_node(3) == {} + assert ctx.get_kernel_artifacts("k1") == {} + assert ctx.get_kernel_artifacts("k2") == {} + assert ctx.get_all_artifacts() == {} + + +# --------------------------------------------------------------------------- +# ArtifactContext — Selective node clearing +# --------------------------------------------------------------------------- + + +class TestArtifactContextClearNodes: + def test_clear_nodes_removes_only_target(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k1", ["encoder"]) + ctx.clear_nodes({1}) + assert ctx.get_published_by_node(1) == [] + assert len(ctx.get_published_by_node(2)) == 1 + assert ctx.get_kernel_artifacts("k1") == {"encoder": ctx.get_published_by_node(2)[0]} + + def test_clear_nodes_preserves_other_node_metadata(self): + """Clearing node 2 should leave node 1's artifacts intact.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k1", ["scaler"]) + ctx.clear_nodes({2}) + published_1 = ctx.get_published_by_node(1) + assert len(published_1) == 1 + assert published_1[0].name == "model" + ka = 
ctx.get_kernel_artifacts("k1") + assert "model" in ka + assert "scaler" not in ka + + def test_clear_nodes_empty_set(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.clear_nodes(set()) + assert len(ctx.get_published_by_node(1)) == 1 + + def test_clear_nodes_nonexistent(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.clear_nodes({99}) # Should not raise + assert len(ctx.get_published_by_node(1)) == 1 + + def test_clear_nodes_allows_re_record(self): + """After clearing, the node can re-record new artifacts.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.clear_nodes({1}) + ctx.record_published(1, "k1", ["model_v2"]) + published = ctx.get_published_by_node(1) + assert len(published) == 1 + assert published[0].name == "model_v2" + + def test_clear_nodes_updates_publisher_index(self): + """Publisher index should be cleaned up when a node is cleared.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.clear_nodes({1}) + # After clearing, the artifact should not show up as available + avail = ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + assert avail == {} + + def test_clear_nodes_preserves_upstream_for_downstream(self): + """Simulates debug mode: node 1 is skipped (not cleared), + node 2 is re-running (cleared). 
Node 3 should still see node 1's artifact.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k1", ["predictions"]) + # Clear only node 2 (it will re-run) + ctx.clear_nodes({2}) + # Node 3 should still see "model" from node 1 + avail = ctx.compute_available(node_id=3, kernel_id="k1", upstream_node_ids=[1, 2]) + assert "model" in avail + assert "predictions" not in avail + + +# --------------------------------------------------------------------------- +# ArtifactContext — Queries +# --------------------------------------------------------------------------- + + +class TestArtifactContextQueries: + def test_get_published_by_node_returns_empty_for_unknown(self): + ctx = ArtifactContext() + assert ctx.get_published_by_node(999) == [] + + def test_get_available_for_node_returns_empty_for_unknown(self): + ctx = ArtifactContext() + assert ctx.get_available_for_node(999) == {} + + def test_get_kernel_artifacts(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["a", "b"]) + ka = ctx.get_kernel_artifacts("k1") + assert set(ka.keys()) == {"a", "b"} + + def test_get_kernel_artifacts_empty(self): + ctx = ArtifactContext() + assert ctx.get_kernel_artifacts("nonexistent") == {} + + def test_get_all_artifacts(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k2", ["encoder"]) + all_arts = ctx.get_all_artifacts() + assert set(all_arts.keys()) == {"model", "encoder"} + + def test_get_all_artifacts_empty(self): + ctx = ArtifactContext() + assert ctx.get_all_artifacts() == {} + + +# --------------------------------------------------------------------------- +# ArtifactContext — Serialisation +# --------------------------------------------------------------------------- + + +class TestArtifactContextSerialization: + def test_to_dict_structure(self): + ctx = ArtifactContext() + ctx.record_published(1, "k1", [{"name": "model", "type_name": "RF"}]) + 
ctx.compute_available(node_id=2, kernel_id="k1", upstream_node_ids=[1]) + d = ctx.to_dict() + assert "nodes" in d + assert "kernels" in d + assert "1" in d["nodes"] + assert "2" in d["nodes"] + assert "k1" in d["kernels"] + assert "model" in d["kernels"]["k1"] + + def test_to_dict_empty_context(self): + ctx = ArtifactContext() + d = ctx.to_dict() + assert d == {"nodes": {}, "kernels": {}} + + def test_to_dict_is_json_serialisable(self): + import json + + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + d = ctx.to_dict() + # Should not raise + serialised = json.dumps(d) + assert isinstance(serialised, str) + + +# --------------------------------------------------------------------------- +# ArtifactContext — Deletion origin tracking +# --------------------------------------------------------------------------- + + +class TestArtifactContextDeletionOrigins: + def test_get_producer_nodes_for_deletions_basic(self): + """Deleting an artifact tracks the original publisher.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + producers = ctx.get_producer_nodes_for_deletions({2}) + assert producers == {1} + + def test_get_producer_nodes_for_deletions_no_deletions(self): + """Nodes without deletions return an empty set.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + producers = ctx.get_producer_nodes_for_deletions({1}) + assert producers == set() + + def test_get_producer_nodes_for_deletions_multiple_artifacts(self): + """Deleting multiple artifacts from different producers.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k1", ["scaler"]) + ctx.record_deleted(3, "k1", ["model", "scaler"]) + producers = ctx.get_producer_nodes_for_deletions({3}) + assert producers == {1, 2} + + def test_clear_nodes_removes_deletion_origins(self): + """Clearing a deleter node also clears its deletion origins.""" + ctx = 
ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + ctx.clear_nodes({2}) + producers = ctx.get_producer_nodes_for_deletions({2}) + assert producers == set() + + def test_clear_all_removes_deletion_origins(self): + """clear_all removes all deletion origin tracking.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_deleted(2, "k1", ["model"]) + ctx.clear_all() + producers = ctx.get_producer_nodes_for_deletions({2}) + assert producers == set() + + def test_clear_kernel_removes_deletion_origins(self): + """clear_kernel removes deletion origins for that kernel only.""" + ctx = ArtifactContext() + ctx.record_published(1, "k1", ["model"]) + ctx.record_published(2, "k2", ["encoder"]) + ctx.record_deleted(3, "k1", ["model"]) + ctx.record_deleted(3, "k2", ["encoder"]) + ctx.clear_kernel("k1") + producers = ctx.get_producer_nodes_for_deletions({3}) + # Only the k2 producer should remain + assert producers == {2} diff --git a/flowfile_core/tests/flowfile/test_artifact_persistence_integration.py b/flowfile_core/tests/flowfile/test_artifact_persistence_integration.py new file mode 100644 index 000000000..f4b6a9400 --- /dev/null +++ b/flowfile_core/tests/flowfile/test_artifact_persistence_integration.py @@ -0,0 +1,421 @@ +""" +Unit-level integration tests for artifact persistence models and +KernelManager proxy methods. + +These tests do NOT require Docker — they use mocked HTTP responses. 
+""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from flowfile_core.kernel.models import ( + ArtifactPersistenceInfo, + CleanupRequest, + CleanupResult, + RecoveryMode, + RecoveryStatus, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run(coro): + """Run an async coroutine from sync test code.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +def _make_manager(kernel_id: str = "test-kernel", port: int = 19000): + """Create a KernelManager with a single IDLE kernel, patching Docker.""" + from flowfile_core.kernel.models import KernelInfo, KernelState + + kernel = KernelInfo( + id=kernel_id, + name="Test Kernel", + state=KernelState.IDLE, + port=port, + container_id="fake-container-id", + ) + + with patch("flowfile_core.kernel.manager.docker"): + from flowfile_core.kernel.manager import KernelManager + + with patch.object(KernelManager, "_restore_kernels_from_db"): + with patch.object(KernelManager, "_reclaim_running_containers"): + manager = KernelManager(shared_volume_path="/tmp/test_shared") + + manager._kernels[kernel_id] = kernel + manager._kernel_owners[kernel_id] = 1 + return manager + + +# --------------------------------------------------------------------------- +# Model tests +# --------------------------------------------------------------------------- + + +class TestPersistenceModels: + """Tests for the new persistence-related Pydantic models.""" + + def test_recovery_mode_enum_values(self): + assert RecoveryMode.LAZY == "lazy" + assert RecoveryMode.EAGER == "eager" + assert RecoveryMode.CLEAR == "clear" + + def test_recovery_mode_from_string(self): + assert RecoveryMode("lazy") == RecoveryMode.LAZY + assert RecoveryMode("eager") == RecoveryMode.EAGER + assert RecoveryMode("clear") == RecoveryMode.CLEAR + + 
def test_recovery_status_defaults(self): + status = RecoveryStatus(status="pending") + assert status.status == "pending" + assert status.mode is None + assert status.recovered == [] + assert status.errors == [] + assert status.indexed is None + + def test_recovery_status_full(self): + status = RecoveryStatus( + status="completed", + mode="eager", + recovered=["model", "encoder"], + errors=[], + ) + assert len(status.recovered) == 2 + + def test_cleanup_request_empty(self): + req = CleanupRequest() + assert req.max_age_hours is None + assert req.artifact_names is None + + def test_cleanup_request_with_age(self): + req = CleanupRequest(max_age_hours=24.0) + assert req.max_age_hours == 24.0 + + def test_cleanup_request_with_names(self): + req = CleanupRequest(artifact_names=[{"flow_id": 0, "name": "model"}]) + assert len(req.artifact_names) == 1 + + def test_cleanup_result(self): + result = CleanupResult(status="cleaned", removed_count=5) + assert result.removed_count == 5 + + def test_persistence_info_disabled(self): + info = ArtifactPersistenceInfo(enabled=False) + assert info.enabled is False + assert info.persisted_count == 0 + assert info.disk_usage_bytes == 0 + + def test_persistence_info_enabled(self): + info = ArtifactPersistenceInfo( + enabled=True, + recovery_mode="lazy", + kernel_id="my-kernel", + persistence_path="/shared/artifacts/my-kernel", + persisted_count=3, + in_memory_count=2, + disk_usage_bytes=1024000, + artifacts={"model": {"persisted": True, "in_memory": True}}, + ) + assert info.persisted_count == 3 + assert info.artifacts["model"]["persisted"] is True + + def test_persistence_info_serialization(self): + info = ArtifactPersistenceInfo( + enabled=True, + kernel_id="k1", + persisted_count=1, + ) + d = info.model_dump() + assert d["enabled"] is True + assert d["kernel_id"] == "k1" + # Should round-trip through JSON + info2 = ArtifactPersistenceInfo(**d) + assert info2 == info + + +# 
--------------------------------------------------------------------------- +# KernelManager proxy method tests (mocked HTTP) +# --------------------------------------------------------------------------- + + +class TestKernelManagerRecoverArtifacts: + """Tests for KernelManager.recover_artifacts() proxy method.""" + + def test_recover_artifacts_success(self): + manager = _make_manager() + response_data = { + "status": "completed", + "mode": "manual", + "recovered": ["model", "encoder"], + "errors": [], + } + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = _run(manager.recover_artifacts("test-kernel")) + + assert isinstance(result, RecoveryStatus) + assert result.status == "completed" + assert result.recovered == ["model", "encoder"] + + def test_recover_artifacts_kernel_not_running(self): + manager = _make_manager() + manager._kernels["test-kernel"].state = MagicMock(value="stopped") + # Set state to STOPPED + from flowfile_core.kernel.models import KernelState + manager._kernels["test-kernel"].state = KernelState.STOPPED + + with pytest.raises(RuntimeError, match="not running"): + _run(manager.recover_artifacts("test-kernel")) + + def test_recover_artifacts_kernel_not_found(self): + manager = _make_manager() + with pytest.raises(KeyError, match="not found"): + _run(manager.recover_artifacts("nonexistent")) + + +class TestKernelManagerRecoveryStatus: + """Tests for KernelManager.get_recovery_status() proxy method.""" + + def test_get_recovery_status(self): + manager = _make_manager() + response_data = { + "status": "completed", + "mode": "lazy", + 
"indexed": 5, + "recovered": [], + "errors": [], + } + + mock_response = MagicMock() + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = _run(manager.get_recovery_status("test-kernel")) + + assert isinstance(result, RecoveryStatus) + assert result.status == "completed" + assert result.indexed == 5 + + +class TestKernelManagerCleanupArtifacts: + """Tests for KernelManager.cleanup_artifacts() proxy method.""" + + def test_cleanup_by_age(self): + manager = _make_manager() + response_data = {"status": "cleaned", "removed_count": 3} + + mock_response = MagicMock() + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + request = CleanupRequest(max_age_hours=24) + result = _run(manager.cleanup_artifacts("test-kernel", request)) + + assert isinstance(result, CleanupResult) + assert result.removed_count == 3 + + def test_cleanup_by_name(self): + manager = _make_manager() + response_data = {"status": "cleaned", "removed_count": 1} + + mock_response = MagicMock() + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = 
AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + request = CleanupRequest( + artifact_names=[{"flow_id": 0, "name": "old_model"}], + ) + result = _run(manager.cleanup_artifacts("test-kernel", request)) + + assert result.removed_count == 1 + + +class TestKernelManagerPersistenceInfo: + """Tests for KernelManager.get_persistence_info() proxy method.""" + + def test_get_persistence_info_enabled(self): + manager = _make_manager() + response_data = { + "enabled": True, + "recovery_mode": "lazy", + "kernel_id": "test-kernel", + "persistence_path": "/shared/artifacts/test-kernel", + "persisted_count": 2, + "in_memory_count": 2, + "disk_usage_bytes": 51200, + "artifacts": { + "model": {"flow_id": 0, "persisted": True, "in_memory": True}, + "encoder": {"flow_id": 0, "persisted": True, "in_memory": False}, + }, + } + + mock_response = MagicMock() + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = _run(manager.get_persistence_info("test-kernel")) + + assert isinstance(result, ArtifactPersistenceInfo) + assert result.enabled is True + assert result.persisted_count == 2 + assert result.disk_usage_bytes == 51200 + assert "model" in result.artifacts + assert "encoder" in result.artifacts + + def test_get_persistence_info_disabled(self): + manager = _make_manager() + response_data = { + "enabled": False, + "recovery_mode": "lazy", + "persisted_count": 0, + "disk_usage_bytes": 0, + } + + mock_response = MagicMock() + mock_response.json.return_value = response_data + mock_response.raise_for_status = MagicMock() + + with patch("httpx.AsyncClient") as mock_client_cls: + mock_client = AsyncMock() + 
mock_client.get.return_value = mock_response + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client_cls.return_value = mock_client + + result = _run(manager.get_persistence_info("test-kernel")) + + assert result.enabled is False + assert result.persisted_count == 0 + + +# --------------------------------------------------------------------------- +# Docker environment variable injection tests +# --------------------------------------------------------------------------- + + +class TestKernelStartupEnvironment: + """Verify that persistence env vars are injected when starting a kernel.""" + + def test_start_kernel_passes_persistence_env_vars(self): + """start_kernel should pass KERNEL_ID, PERSISTENCE_ENABLED, etc.""" + from flowfile_core.kernel.models import KernelConfig, KernelState + + with patch("flowfile_core.kernel.manager.docker") as mock_docker: + from flowfile_core.kernel.manager import KernelManager + + with patch.object(KernelManager, "_restore_kernels_from_db"): + with patch.object(KernelManager, "_reclaim_running_containers"): + manager = KernelManager(shared_volume_path="/tmp/test") + + # Create a kernel + config = KernelConfig(id="env-test", name="Env Test") + _run(manager.create_kernel(config, user_id=1)) + + # Mock the Docker image check and container run + mock_docker.from_env.return_value.images.get.return_value = MagicMock() + mock_container = MagicMock() + mock_container.id = "fake-id" + mock_docker.from_env.return_value.containers.run.return_value = mock_container + + # Mock health check + with patch.object(manager, "_wait_for_healthy", new_callable=AsyncMock): + _run(manager.start_kernel("env-test")) + + # Verify containers.run was called with persistence env vars + call_args = mock_docker.from_env.return_value.containers.run.call_args + environment = call_args[1]["environment"] + + assert environment["KERNEL_ID"] == "env-test" + assert 
environment["PERSISTENCE_ENABLED"] == "true" + assert environment["PERSISTENCE_PATH"] == "/shared/artifacts" + assert environment["RECOVERY_MODE"] == "lazy" + + def test_start_kernel_uses_per_kernel_persistence_config(self): + """Persistence env vars should be taken from kernel config, not hardcoded.""" + from flowfile_core.kernel.models import KernelConfig, KernelState + + with patch("flowfile_core.kernel.manager.docker") as mock_docker: + from flowfile_core.kernel.manager import KernelManager + + with patch.object(KernelManager, "_restore_kernels_from_db"): + with patch.object(KernelManager, "_reclaim_running_containers"): + manager = KernelManager(shared_volume_path="/tmp/test") + + # Create a kernel with custom persistence settings + config = KernelConfig( + id="custom-persist", + name="Custom Persistence", + persistence_enabled=False, + recovery_mode=RecoveryMode.EAGER, + ) + _run(manager.create_kernel(config, user_id=1)) + + # Verify the kernel info has the persistence settings + kernel = manager._kernels["custom-persist"] + assert kernel.persistence_enabled is False + assert kernel.recovery_mode == RecoveryMode.EAGER + + # Mock Docker and start the kernel + mock_docker.from_env.return_value.images.get.return_value = MagicMock() + mock_container = MagicMock() + mock_container.id = "fake-id" + mock_docker.from_env.return_value.containers.run.return_value = mock_container + + with patch.object(manager, "_wait_for_healthy", new_callable=AsyncMock): + _run(manager.start_kernel("custom-persist")) + + # Verify containers.run received custom persistence settings + call_args = mock_docker.from_env.return_value.containers.run.call_args + environment = call_args[1]["environment"] + + assert environment["PERSISTENCE_ENABLED"] == "false" + assert environment["RECOVERY_MODE"] == "eager" diff --git a/flowfile_core/tests/flowfile/test_flowfile.py b/flowfile_core/tests/flowfile/test_flowfile.py index f1a386b14..f489cf6fd 100644 --- 
a/flowfile_core/tests/flowfile/test_flowfile.py +++ b/flowfile_core/tests/flowfile/test_flowfile.py @@ -1750,3 +1750,107 @@ def test_fetch_before_run_debug(): assert len(example_data_after_run) > 0, "There should be data after fetch operation" + +# --------------------------------------------------------------------------- +# FlowGraph — ArtifactContext integration +# --------------------------------------------------------------------------- + + +class TestFlowGraphArtifactContext: + """Tests for ArtifactContext integration on FlowGraph.""" + + def test_flowgraph_has_artifact_context(self): + """FlowGraph initializes with an ArtifactContext.""" + from flowfile_core.flowfile.artifacts import ArtifactContext + + graph = create_graph() + assert hasattr(graph, "artifact_context") + assert isinstance(graph.artifact_context, ArtifactContext) + + def test_get_upstream_node_ids_direct(self): + """Returns direct upstream dependencies.""" + data = [{"a": 1}] + graph = create_graph() + add_manual_input(graph, data, node_id=1) + # Add node 2 depending on node 1 + node_promise = input_schema.NodePromise(flow_id=1, node_id=2, node_type="sample") + graph.add_node_promise(node_promise) + graph.add_sample(input_schema.NodeSample(flow_id=1, node_id=2, depending_on_id=1)) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + upstream = graph._get_upstream_node_ids(2) + assert 1 in upstream + + def test_get_upstream_node_ids_transitive(self): + """Returns transitive upstream dependencies (1 -> 2 -> 3).""" + data = [{"a": 1}] + graph = create_graph() + add_manual_input(graph, data, node_id=1) + + # Node 2 depends on 1 + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="sample") + graph.add_node_promise(node_promise_2) + graph.add_sample(input_schema.NodeSample(flow_id=1, node_id=2, depending_on_id=1)) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # Node 3 depends on 2 + node_promise_3 
= input_schema.NodePromise(flow_id=1, node_id=3, node_type="sample") + graph.add_node_promise(node_promise_3) + graph.add_sample(input_schema.NodeSample(flow_id=1, node_id=3, depending_on_id=2)) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + upstream = graph._get_upstream_node_ids(3) + assert 1 in upstream + assert 2 in upstream + + def test_get_upstream_node_ids_unknown_returns_empty(self): + """Unknown node returns empty list.""" + graph = create_graph() + assert graph._get_upstream_node_ids(999) == [] + + def test_get_required_kernel_ids_no_python_nodes(self): + """Returns empty set when no python_script nodes exist.""" + data = [{"a": 1}] + graph = create_graph() + add_manual_input(graph, data, node_id=1) + assert graph._get_required_kernel_ids() == set() + + def test_get_required_kernel_ids_with_python_nodes(self): + """Returns kernel IDs from python_script nodes.""" + data = [{"a": 1}] + graph = create_graph() + add_manual_input(graph, data, node_id=1) + + node_promise = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise) + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, + node_id=2, + depending_on_id=1, + python_script_input=input_schema.PythonScriptInput( + code='print("hi")', + kernel_id="ml_kernel", + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + assert "ml_kernel" in graph._get_required_kernel_ids() + + def test_run_graph_clears_artifact_context(self): + """Artifact context is cleared at flow start.""" + data = [{"a": 1}] + graph = create_graph() + add_manual_input(graph, data, node_id=1) + + # Pre-populate artifact_context + graph.artifact_context.record_published(99, "test", [{"name": "old"}]) + assert len(graph.artifact_context.get_published_by_node(99)) == 1 + + # Run graph + graph.run_graph() + + # Context should be cleared + assert 
graph.artifact_context.get_published_by_node(99) == [] + diff --git a/flowfile_core/tests/flowfile/test_kernel_integration.py b/flowfile_core/tests/flowfile/test_kernel_integration.py new file mode 100644 index 000000000..c551beb98 --- /dev/null +++ b/flowfile_core/tests/flowfile/test_kernel_integration.py @@ -0,0 +1,1829 @@ +""" +Integration tests for the Docker-based kernel system. + +These tests require Docker to be available. The ``kernel_manager`` fixture +(session-scoped, defined in conftest.py) builds the flowfile-kernel image, +starts a container, and tears it down after all tests in this module finish. +""" + +import asyncio +import os +from pathlib import Path +from typing import Literal + +import polars as pl +import pytest + +from flowfile_core.flowfile.flow_data_engine.flow_data_engine import FlowDataEngine +from flowfile_core.flowfile.flow_graph import FlowGraph, RunInformation, add_connection +from flowfile_core.flowfile.handler import FlowfileHandler +from flowfile_core.kernel.manager import KernelManager +from flowfile_core.kernel.models import ExecuteRequest, ExecuteResult +from flowfile_core.schemas import input_schema, schemas + +pytestmark = pytest.mark.kernel + + +# --------------------------------------------------------------------------- +# Helpers (same pattern as test_flowfile.py) +# --------------------------------------------------------------------------- + + +def _run(coro): + """Run an async coroutine from sync test code.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +def _create_graph( + flow_id: int = 1, + execution_mode: Literal["Development", "Performance"] = "Development", + execution_location: Literal["local", "remote"] | None = "remote", +) -> FlowGraph: + handler = FlowfileHandler() + handler.register_flow( + schemas.FlowSettings( + flow_id=flow_id, + name="kernel_test_flow", + path=".", + execution_mode=execution_mode, + execution_location=execution_location, 
+ ) + ) + return handler.get_flow(flow_id) + + +def _handle_run_info(run_info: RunInformation): + if not run_info.success: + errors = "errors:" + for step in run_info.node_step_result: + if not step.success: + errors += f"\n node_id:{step.node_id}, error: {step.error}" + raise ValueError(f"Graph should run successfully:\n{errors}") + + +# --------------------------------------------------------------------------- +# Tests — kernel runtime (direct manager interaction) +# --------------------------------------------------------------------------- + + +class TestKernelRuntime: + """Tests that exercise the kernel container directly via KernelManager.""" + + def test_health_check(self, kernel_manager: tuple[KernelManager, str]): + """Kernel container responds to health checks.""" + manager, kernel_id = kernel_manager + info = _run(manager.get_kernel(kernel_id)) + assert info is not None + assert info.state.value == "idle" + + def test_execute_print(self, kernel_manager: tuple[KernelManager, str]): + """Simple print() produces stdout.""" + manager, kernel_id = kernel_manager + result: ExecuteResult = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=1, + code='print("hello from kernel")', + input_paths={}, + output_dir="/shared/test_print", + ), + ) + ) + assert result.success + assert "hello from kernel" in result.stdout + assert result.error is None + + def test_execute_syntax_error(self, kernel_manager: tuple[KernelManager, str]): + """Syntax errors are captured, not raised.""" + manager, kernel_id = kernel_manager + result: ExecuteResult = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=2, + code="def broken(", + input_paths={}, + output_dir="/shared/test_syntax_err", + ), + ) + ) + assert not result.success + assert result.error is not None + + def test_publish_and_list_artifacts(self, kernel_manager: tuple[KernelManager, str]): + """publish_artifact stores an object; list_artifacts returns metadata.""" + manager, kernel_id = 
kernel_manager + + # Clear any leftover artifacts from previous tests + _run(manager.clear_artifacts(kernel_id)) + + result: ExecuteResult = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=3, + code='flowfile.publish_artifact("my_dict", {"a": 1, "b": 2})', + input_paths={}, + output_dir="/shared/test_artifact", + ), + ) + ) + assert result.success + assert "my_dict" in result.artifacts_published + + def test_read_and_write_parquet(self, kernel_manager: tuple[KernelManager, str]): + """Kernel can read input parquet and write output parquet.""" + manager, kernel_id = kernel_manager + shared = manager.shared_volume_path + + # Prepare input parquet + input_dir = os.path.join(shared, "test_rw", "inputs") + output_dir = os.path.join(shared, "test_rw", "outputs") + os.makedirs(input_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + + df_in = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + df_in.write_parquet(os.path.join(input_dir, "main.parquet")) + + code = """ +import polars as pl +df = flowfile.read_input() +df = df.with_columns((pl.col("x") * pl.col("y")).alias("product")) +flowfile.publish_output(df) +""" + + result: ExecuteResult = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=4, + code=code, + input_paths={"main": ["/shared/test_rw/inputs/main.parquet"]}, + output_dir="/shared/test_rw/outputs", + ), + ) + ) + assert result.success, f"Kernel execution failed: {result.error}" + assert len(result.output_paths) > 0 + + # Verify output + out_path = os.path.join(output_dir, "main.parquet") + assert os.path.exists(out_path), f"Expected output parquet at {out_path}" + df_out = pl.read_parquet(out_path) + assert "product" in df_out.columns + assert df_out["product"].to_list() == [10, 40, 90] + + def test_multiple_inputs(self, kernel_manager: tuple[KernelManager, str]): + """Kernel can read multiple named inputs.""" + manager, kernel_id = kernel_manager + shared = manager.shared_volume_path + + input_dir = 
os.path.join(shared, "test_multi", "inputs") + output_dir = os.path.join(shared, "test_multi", "outputs") + os.makedirs(input_dir, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) + + pl.DataFrame({"id": [1, 2], "name": ["a", "b"]}).write_parquet( + os.path.join(input_dir, "left.parquet") + ) + pl.DataFrame({"id": [1, 2], "score": [90, 80]}).write_parquet( + os.path.join(input_dir, "right.parquet") + ) + + code = """ +inputs = flowfile.read_inputs() +left = inputs["left"].collect() +right = inputs["right"].collect() +merged = left.join(right, on="id") +flowfile.publish_output(merged) +""" + result = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=5, + code=code, + input_paths={ + "left": ["/shared/test_multi/inputs/left.parquet"], + "right": ["/shared/test_multi/inputs/right.parquet"], + }, + output_dir="/shared/test_multi/outputs", + ), + ) + ) + assert result.success, f"Kernel execution failed: {result.error}" + + df_out = pl.read_parquet(os.path.join(output_dir, "main.parquet")) + assert set(df_out.columns) == {"id", "name", "score"} + assert len(df_out) == 2 + + def test_stderr_captured(self, kernel_manager: tuple[KernelManager, str]): + """Writes to stderr are captured.""" + manager, kernel_id = kernel_manager + result = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=6, + code='import sys; sys.stderr.write("warn\\n")', + input_paths={}, + output_dir="/shared/test_stderr", + ), + ) + ) + assert result.success + assert "warn" in result.stderr + + def test_execution_time_tracked(self, kernel_manager: tuple[KernelManager, str]): + """execution_time_ms is populated.""" + manager, kernel_id = kernel_manager + result = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=7, + code="x = sum(range(100000))", + input_paths={}, + output_dir="/shared/test_timing", + ), + ) + ) + assert result.success + assert result.execution_time_ms > 0 + + +# 
--------------------------------------------------------------------------- +# Tests — python_script node in FlowGraph +# --------------------------------------------------------------------------- + + +class TestPythonScriptNode: + """ + Tests that wire up the python_script node type inside a FlowGraph and + run the graph end-to-end against a real kernel container. + """ + + def test_python_script_passthrough(self, kernel_manager: tuple[KernelManager, str]): + """ + python_script node reads input, passes it through, and writes output. + """ + manager, kernel_id = kernel_manager + # Patch the singleton so flow_graph picks up *this* manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: manual input + data = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, + node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: python_script + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +df = flowfile.read_input() +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, + node_id=2, + depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code, + kernel_id=kernel_id, + ), + ) + ) + + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + result = graph.get_node(2).get_resulting_data() + assert result is not None + df = result.data_frame + if hasattr(df, "collect"): + df = df.collect() + assert len(df) == 2 + assert set(df.columns) >= {"name", "age"} + + 
finally: + _kernel_mod._manager = _prev + + def test_python_script_transform(self, kernel_manager: tuple[KernelManager, str]): + """ + python_script node transforms data (adds a column). + """ + manager, kernel_id = kernel_manager + + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}, {"val": 2}, {"val": 3}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, + node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +import polars as pl +df = flowfile.read_input().collect() +df = df.with_columns((pl.col("val") * 10).alias("val_x10")) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, + node_id=2, + depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code, + kernel_id=kernel_id, + ), + ) + ) + + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + result = graph.get_node(2).get_resulting_data() + assert result is not None + df = result.data_frame + if hasattr(df, "collect"): + df = df.collect() + assert "val_x10" in df.columns + assert df["val_x10"].to_list() == [10, 20, 30] + + finally: + _kernel_mod._manager = _prev + + def test_python_script_no_kernel_raises(self): + """ + If no kernel_id is set, the node should raise at execution time. 
+ """ + graph = _create_graph() + + data = [{"a": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, + node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, + node_id=2, + depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code='print("hi")', + kernel_id=None, # intentionally no kernel + ), + ) + ) + + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + run_info = graph.run_graph() + # Should fail because no kernel is selected + assert not run_info.success + + +# --------------------------------------------------------------------------- +# Tests — ArtifactContext integration (requires real kernel container) +# --------------------------------------------------------------------------- + + +class TestArtifactContextIntegration: + """Integration tests verifying ArtifactContext works with real kernel execution.""" + + def test_published_artifacts_recorded_in_context(self, kernel_manager: tuple[KernelManager, str]): + """After execution, published artifacts appear in artifact_context.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, 
node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +df = flowfile.read_input() +flowfile.publish_artifact("my_model", {"accuracy": 0.95}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + published = graph.artifact_context.get_published_by_node(2) + assert len(published) >= 1 + names = [r.name for r in published] + assert "my_model" in names + finally: + _kernel_mod._manager = _prev + + def test_available_artifacts_computed_before_execution(self, kernel_manager: tuple[KernelManager, str]): + """Downstream nodes have correct available artifacts.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: publishes artifact + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + code_publish = """ +df = flowfile.read_input() +flowfile.publish_artifact("trained_model", {"type": "RF"}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code_publish, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, 
input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # Node 3: reads artifact (downstream of node 2) + node_promise_3 = input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script") + graph.add_node_promise(node_promise_3) + code_consume = """ +df = flowfile.read_input() +model = flowfile.read_artifact("trained_model") +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=code_consume, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + # Node 3 should have "trained_model" available + available = graph.artifact_context.get_available_for_node(3) + assert "trained_model" in available + + finally: + _kernel_mod._manager = _prev + + def test_artifacts_cleared_between_runs(self, kernel_manager: tuple[KernelManager, str]): + """Running flow twice doesn't leak artifacts from first run.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +df = flowfile.read_input() +flowfile.publish_artifact("run_artifact", [1, 2, 3]) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + 
python_script_input=input_schema.PythonScriptInput( + code=code, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # First run + run_info = graph.run_graph() + _handle_run_info(run_info) + assert len(graph.artifact_context.get_published_by_node(2)) >= 1 + + # Second run — context should be cleared at start then repopulated + run_info2 = graph.run_graph() + _handle_run_info(run_info2) + + # Should still have the artifact from this run, but no leftover state + published = graph.artifact_context.get_published_by_node(2) + names = [r.name for r in published] + assert "run_artifact" in names + # Verify it's exactly one entry (not duplicated from first run) + assert names.count("run_artifact") == 1 + + finally: + _kernel_mod._manager = _prev + + def test_multiple_artifacts_from_single_node(self, kernel_manager: tuple[KernelManager, str]): + """Node publishing multiple artifacts records all of them.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +df = flowfile.read_input() +flowfile.publish_artifact("model", {"type": "classifier"}) +flowfile.publish_artifact("encoder", {"type": "label_encoder"}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code, kernel_id=kernel_id, + 
), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + published = graph.artifact_context.get_published_by_node(2) + names = {r.name for r in published} + assert "model" in names + assert "encoder" in names + + finally: + _kernel_mod._manager = _prev + + def test_artifact_context_to_dict_after_run(self, kernel_manager: tuple[KernelManager, str]): + """to_dict() returns valid structure after flow execution.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +df = flowfile.read_input() +flowfile.publish_artifact("ctx_model", {"version": 1}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + d = graph.artifact_context.to_dict() + assert "nodes" in d + assert "kernels" in d + # Should have at least node 2 in nodes + assert "2" in d["nodes"] + # Kernel should be tracked + assert kernel_id in d["kernels"] + + finally: + _kernel_mod._manager = _prev + + def test_train_model_and_apply(self, kernel_manager: tuple[KernelManager, str]): + """Train a numpy 
linear-regression model in node 2, apply it in node 3.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: input data with features and target + data = [ + {"x1": 1.0, "x2": 2.0, "y": 5.0}, + {"x1": 2.0, "x2": 3.0, "y": 8.0}, + {"x1": 3.0, "x2": 4.0, "y": 11.0}, + {"x1": 4.0, "x2": 5.0, "y": 14.0}, + ] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: train model (least-squares fit) and publish as artifact + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + train_code = """ +import numpy as np +import polars as pl + +df = flowfile.read_input().collect() +X = np.column_stack([df["x1"].to_numpy(), df["x2"].to_numpy(), np.ones(len(df))]) +y_vals = df["y"].to_numpy() +coeffs = np.linalg.lstsq(X, y_vals, rcond=None)[0] +flowfile.publish_artifact("linear_model", {"coefficients": coeffs.tolist()}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=train_code, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # Node 3: load model and apply predictions + node_promise_3 = input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script") + graph.add_node_promise(node_promise_3) + apply_code = """ +import numpy as np +import polars as pl + +df = flowfile.read_input().collect() +model = flowfile.read_artifact("linear_model") +coeffs = np.array(model["coefficients"]) +X = 
np.column_stack([df["x1"].to_numpy(), df["x2"].to_numpy(), np.ones(len(df))]) +predictions = X @ coeffs +result = df.with_columns(pl.Series("predicted_y", predictions)) +flowfile.publish_output(result) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=apply_code, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + # Verify model was published and tracked + published = graph.artifact_context.get_published_by_node(2) + assert any(r.name == "linear_model" for r in published) + + # Verify node 3 had the model available + available = graph.artifact_context.get_available_for_node(3) + assert "linear_model" in available + + # Verify predictions were produced + node_3 = graph.get_node(3) + result_df = node_3.get_resulting_data().data_frame.collect() + assert "predicted_y" in result_df.columns + # The predictions should be close to the actual y values + preds = result_df["predicted_y"].to_list() + actuals = result_df["y"].to_list() + for pred, actual in zip(preds, actuals): + assert abs(pred - actual) < 0.01, f"Prediction {pred} too far from {actual}" + + finally: + _kernel_mod._manager = _prev + + def test_publish_delete_republish_access(self, kernel_manager: tuple[KernelManager, str]): + """ + Flow: node_a publishes model -> node_b uses & deletes model -> + node_c publishes new model -> node_d accesses new model. 
+ """ + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: input data + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2 (node_a): publish artifact_model v1 + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + code_a = """ +df = flowfile.read_input() +flowfile.publish_artifact("artifact_model", {"version": 1, "weights": [0.5]}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code_a, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # Node 3 (node_b): read artifact_model, use it, then delete it + node_promise_3 = input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script") + graph.add_node_promise(node_promise_3) + code_b = """ +df = flowfile.read_input() +model = flowfile.read_artifact("artifact_model") +assert model["version"] == 1, f"Expected v1, got {model}" +flowfile.delete_artifact("artifact_model") +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=code_b, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + # Node 4 (node_c): publish new artifact_model v2 + node_promise_4 = input_schema.NodePromise(flow_id=1, node_id=4, 
node_type="python_script") + graph.add_node_promise(node_promise_4) + code_c = """ +df = flowfile.read_input() +flowfile.publish_artifact("artifact_model", {"version": 2, "weights": [0.9]}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=4, depending_on_ids=[3], + python_script_input=input_schema.PythonScriptInput( + code=code_c, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(3, 4)) + + # Node 5 (node_d): read artifact_model — should get v2 + node_promise_5 = input_schema.NodePromise(flow_id=1, node_id=5, node_type="python_script") + graph.add_node_promise(node_promise_5) + code_d = """ +df = flowfile.read_input() +model = flowfile.read_artifact("artifact_model") +assert model["version"] == 2, f"Expected v2, got {model}" +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=5, depending_on_ids=[4], + python_script_input=input_schema.PythonScriptInput( + code=code_d, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(4, 5)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + # Verify artifact context tracks the flow correctly + # Node 4 re-published artifact_model + published_4 = graph.artifact_context.get_published_by_node(4) + assert any(r.name == "artifact_model" for r in published_4) + + # Node 5 should see artifact_model as available (from node 4) + available_5 = graph.artifact_context.get_available_for_node(5) + assert "artifact_model" in available_5 + assert available_5["artifact_model"].source_node_id == 4 + + finally: + _kernel_mod._manager = _prev + + def test_duplicate_publish_fails(self, kernel_manager: tuple[KernelManager, str]): + """Publishing an artifact with the same name without deleting first should fail.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as 
_kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + data = [{"val": 1}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: publishes artifact + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + code_publish = """ +df = flowfile.read_input() +flowfile.publish_artifact("model", "v1") +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code_publish, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # Node 3: tries to publish same name without deleting — should fail + node_promise_3 = input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script") + graph.add_node_promise(node_promise_3) + code_dup = """ +df = flowfile.read_input() +flowfile.publish_artifact("model", "v2") +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=code_dup, kernel_id=kernel_id, + ), + ) + ) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + run_info = graph.run_graph() + + # Node 3 should have failed + node_3_result = next( + r for r in run_info.node_step_result if r.node_id == 3 + ) + assert node_3_result.success is False + assert "already exists" in node_3_result.error + + finally: + _kernel_mod._manager = _prev + + def test_multi_input_python_script(self, kernel_manager: tuple[KernelManager, str]): 
+ """python_script node receives data from multiple input nodes and unions them.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: first input dataset + data_a = [{"id": 1, "value": "alpha"}, {"id": 2, "value": "beta"}] + node_promise_1 = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise_1) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data_a), + ) + ) + + # Node 2: second input dataset (same schema, different rows) + data_b = [{"id": 3, "value": "gamma"}, {"id": 4, "value": "delta"}] + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="manual_input") + graph.add_node_promise(node_promise_2) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=2, + raw_data_format=input_schema.RawData.from_pylist(data_b), + ) + ) + + # Node 3: python_script that reads all inputs (union) and outputs the result + node_promise_3 = input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script") + graph.add_node_promise(node_promise_3) + + code = """ +import polars as pl +df = flowfile.read_input().collect() +# Should contain all 4 rows from both inputs +assert len(df) == 4, f"Expected 4 rows, got {len(df)}" +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[1, 2], + python_script_input=input_schema.PythonScriptInput( + code=code, kernel_id=kernel_id, + ), + ) + ) + + # Connect both inputs to node 3 + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 3)) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + # Verify the output 
contains all rows from both inputs + result = graph.get_node(3).get_resulting_data() + assert result is not None + df = result.data_frame + if hasattr(df, "collect"): + df = df.collect() + assert len(df) == 4 + assert set(df.columns) >= {"id", "value"} + ids = sorted(df["id"].to_list()) + assert ids == [1, 2, 3, 4] + + finally: + _kernel_mod._manager = _prev + + def test_multi_input_read_inputs_named(self, kernel_manager: tuple[KernelManager, str]): + """python_script node uses read_inputs() to access multiple named inputs individually.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: users dataset + users = [{"user_id": 1, "name": "Alice"}, {"user_id": 2, "name": "Bob"}] + node_promise_1 = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise_1) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(users), + ) + ) + + # Node 2: scores dataset + scores = [{"user_id": 1, "score": 95}, {"user_id": 2, "score": 87}] + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="manual_input") + graph.add_node_promise(node_promise_2) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=2, + raw_data_format=input_schema.RawData.from_pylist(scores), + ) + ) + + # Node 3: python_script that reads first input and passes it through + # Since all inputs go under "main", read_first gets just the first + node_promise_3 = input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script") + graph.add_node_promise(node_promise_3) + + code = """ +import polars as pl +df = flowfile.read_first().collect() +# read_first should return only the first input (2 rows, not 4) +assert len(df) == 2, f"Expected 2 rows from read_first, got {len(df)}" 
+flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[1, 2], + python_script_input=input_schema.PythonScriptInput( + code=code, kernel_id=kernel_id, + ), + ) + ) + + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 3)) + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(2, 3)) + + run_info = graph.run_graph() + _handle_run_info(run_info) + + result = graph.get_node(3).get_resulting_data() + assert result is not None + df = result.data_frame + if hasattr(df, "collect"): + df = df.collect() + # read_first returns only the first input's data + assert len(df) == 2 + + finally: + _kernel_mod._manager = _prev + + +# --------------------------------------------------------------------------- +# Tests — debug mode artifact persistence +# --------------------------------------------------------------------------- + + +class TestDebugModeArtifactPersistence: + """Integration tests verifying that artifacts survive re-runs in debug + (Development) mode when the producing node is skipped (up-to-date) but + a downstream consumer node needs to re-execute. + + This reproduces the exact scenario from the bug report: + 1. First run: Node 2 publishes 'linear_model', Node 3 reads it — OK. + 2. User changes Node 3's code. + 3. Second run: Node 2 is up-to-date → skipped, Node 3 re-runs → + must still be able to read 'linear_model' from kernel memory. 
+ """ + + def test_artifact_survives_when_producer_skipped( + self, kernel_manager: tuple[KernelManager, str], + ): + """Core scenario: producer skipped, consumer re-runs, artifact accessible.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: input data + data = [ + {"x1": 1.0, "x2": 2.0, "y": 5.0}, + {"x1": 2.0, "x2": 3.0, "y": 8.0}, + {"x1": 3.0, "x2": 4.0, "y": 11.0}, + {"x1": 4.0, "x2": 5.0, "y": 14.0}, + ] + node_promise_1 = input_schema.NodePromise( + flow_id=1, node_id=1, node_type="manual_input", + ) + graph.add_node_promise(node_promise_1) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: train model and publish as artifact + node_promise_2 = input_schema.NodePromise( + flow_id=1, node_id=2, node_type="python_script", + ) + graph.add_node_promise(node_promise_2) + train_code = """ +import numpy as np +import polars as pl + +df = flowfile.read_input().collect() +X = np.column_stack([df["x1"].to_numpy(), df["x2"].to_numpy(), np.ones(len(df))]) +y_vals = df["y"].to_numpy() +coeffs = np.linalg.lstsq(X, y_vals, rcond=None)[0] +flowfile.publish_artifact("linear_model", {"coefficients": coeffs.tolist()}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=train_code, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(1, 2), + ) + + # Node 3: read model artifact and produce predictions + node_promise_3 = input_schema.NodePromise( + flow_id=1, node_id=3, node_type="python_script", + ) + graph.add_node_promise(node_promise_3) + apply_code_v1 = """ +import numpy as np +import polars as pl + +df = 
flowfile.read_input().collect() +model = flowfile.read_artifact("linear_model") +coeffs = np.array(model["coefficients"]) +X = np.column_stack([df["x1"].to_numpy(), df["x2"].to_numpy(), np.ones(len(df))]) +predictions = X @ coeffs +result = df.with_columns(pl.Series("predicted_y", predictions)) +flowfile.publish_output(result) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=apply_code_v1, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(2, 3), + ) + + # ---- First run: everything executes ---- + run_info_1 = graph.run_graph() + _handle_run_info(run_info_1) + + # Verify artifact was published and predictions were produced + published = graph.artifact_context.get_published_by_node(2) + assert any(r.name == "linear_model" for r in published) + node_3_df = graph.get_node(3).get_resulting_data().data_frame.collect() + assert "predicted_y" in node_3_df.columns + + # ---- Change Node 3's code (simulates user editing the consumer) ---- + # The new code still reads the same artifact but adds an extra column. 
+ apply_code_v2 = """ +import numpy as np +import polars as pl + +df = flowfile.read_input().collect() +model = flowfile.read_artifact("linear_model") +coeffs = np.array(model["coefficients"]) +X = np.column_stack([df["x1"].to_numpy(), df["x2"].to_numpy(), np.ones(len(df))]) +predictions = X @ coeffs +residuals = df["y"].to_numpy() - predictions +result = df.with_columns( + pl.Series("predicted_y", predictions), + pl.Series("residual", residuals), +) +flowfile.publish_output(result) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=apply_code_v2, kernel_id=kernel_id, + ), + ) + ) + + # Verify the execution state before second run: + # Node 2 (producer) should still be up-to-date + node_2 = graph.get_node(2) + assert node_2._execution_state.has_run_with_current_setup, ( + "Producer node should be up-to-date (will be skipped)" + ) + # Node 3 (consumer) should need re-execution + node_3 = graph.get_node(3) + assert not node_3._execution_state.has_run_with_current_setup, ( + "Consumer node should be invalidated (will re-run)" + ) + + # ---- Second run: Node 2 is skipped, Node 3 re-runs ---- + # This is the critical test: Node 3 must still be able to + # read "linear_model" from kernel memory even though Node 2 + # did not re-execute. 
+ run_info_2 = graph.run_graph() + _handle_run_info(run_info_2) + + # Verify the producer's artifact metadata is still tracked + published_after = graph.artifact_context.get_published_by_node(2) + assert any(r.name == "linear_model" for r in published_after), ( + "Producer's artifact metadata should be preserved when skipped" + ) + + # Verify the consumer ran with the new code (has residual column) + node_3_df_v2 = graph.get_node(3).get_resulting_data().data_frame.collect() + assert "predicted_y" in node_3_df_v2.columns + assert "residual" in node_3_df_v2.columns, ( + "Consumer should have run with updated code" + ) + # Residuals should be near-zero for this perfect linear fit + for r in node_3_df_v2["residual"].to_list(): + assert abs(r) < 0.01 + + finally: + _kernel_mod._manager = _prev + + def test_multiple_artifacts_survive_selective_clear( + self, kernel_manager: tuple[KernelManager, str], + ): + """Multiple artifacts from a skipped producer survive when only + the consumer is re-run.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: input data + data = [{"val": 10}, {"val": 20}, {"val": 30}] + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input"), + ) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: publish two artifacts (model + scaler) + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script"), + ) + producer_code = """ +df = flowfile.read_input() +flowfile.publish_artifact("model", {"type": "linear", "coeff": 2.0}) +flowfile.publish_artifact("scaler", {"mean": 20.0, "std": 10.0}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, 
depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=producer_code, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(1, 2), + ) + + # Node 3: read both artifacts + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script"), + ) + consumer_code_v1 = """ +import polars as pl +df = flowfile.read_input().collect() +model = flowfile.read_artifact("model") +scaler = flowfile.read_artifact("scaler") +result = df.with_columns( + (pl.col("val") * model["coeff"]).alias("scaled"), +) +flowfile.publish_output(result) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=consumer_code_v1, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(2, 3), + ) + + # First run + _handle_run_info(graph.run_graph()) + + # Change the consumer's code — also use the scaler now + consumer_code_v2 = """ +import polars as pl +df = flowfile.read_input().collect() +model = flowfile.read_artifact("model") +scaler = flowfile.read_artifact("scaler") +normalized = (pl.col("val") - scaler["mean"]) / scaler["std"] +result = df.with_columns( + (pl.col("val") * model["coeff"]).alias("scaled"), + normalized.alias("normalized"), +) +flowfile.publish_output(result) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=consumer_code_v2, kernel_id=kernel_id, + ), + ) + ) + + # Second run — producer skipped, consumer re-runs + _handle_run_info(graph.run_graph()) + + # Both artifacts should still be accessible + published = graph.artifact_context.get_published_by_node(2) + names = {r.name for r in published} + assert "model" in names + assert "scaler" in names + + # Consumer should 
have the new column + df_out = graph.get_node(3).get_resulting_data().data_frame.collect() + assert "scaled" in df_out.columns + assert "normalized" in df_out.columns + + finally: + _kernel_mod._manager = _prev + + def test_rerun_producer_clears_old_artifacts( + self, kernel_manager: tuple[KernelManager, str], + ): + """When the producer itself is changed and re-runs, its old + artifacts are properly cleared before re-execution.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: input + data = [{"val": 1}] + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input"), + ) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: publish artifact v1 + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script"), + ) + code_v1 = """ +df = flowfile.read_input() +flowfile.publish_artifact("model", {"version": 1}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code_v1, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(1, 2), + ) + + # Node 3: read artifact + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script"), + ) + consumer_code = """ +df = flowfile.read_input() +model = flowfile.read_artifact("model") +print(f"model version: {model['version']}") +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=consumer_code, kernel_id=kernel_id, + 
), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(2, 3), + ) + + # First run + _handle_run_info(graph.run_graph()) + + published = graph.artifact_context.get_published_by_node(2) + assert any(r.name == "model" for r in published) + + # Change the PRODUCER (Node 2) — publish v2 of the artifact + code_v2 = """ +df = flowfile.read_input() +flowfile.publish_artifact("model", {"version": 2}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code_v2, kernel_id=kernel_id, + ), + ) + ) + + # Both Node 2 and Node 3 should need re-execution + # (Node 3 because its upstream changed via evaluate_nodes) + assert not graph.get_node(2)._execution_state.has_run_with_current_setup + assert not graph.get_node(3)._execution_state.has_run_with_current_setup + + # Second run — both re-execute; old "model" must be cleared + # before Node 2 re-publishes, otherwise publish would fail + # with "already exists". 
+ _handle_run_info(graph.run_graph()) + + # Artifact should be the new version + published_v2 = graph.artifact_context.get_published_by_node(2) + assert any(r.name == "model" for r in published_v2) + + finally: + _kernel_mod._manager = _prev + + def test_deleted_artifact_producer_reruns_on_consumer_change( + self, kernel_manager: tuple[KernelManager, str], + ): + """When a consumer that deleted an artifact is re-run, the + producer must also re-run so the artifact is available again.""" + manager, kernel_id = kernel_manager + import flowfile_core.kernel as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + graph = _create_graph() + + # Node 1: input data + data = [{"x1": 1, "x2": 2, "y": 5}, {"x1": 3, "x2": 4, "y": 11}] + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input"), + ) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + # Node 2: publish artifact + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script"), + ) + producer_code = """ +df = flowfile.read_input() +flowfile.publish_artifact("linear_model", {"coefficients": [1.0, 2.0, 3.0]}) +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=2, depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=producer_code, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(1, 2), + ) + + # Node 3: read artifact, use it, then delete it + graph.add_node_promise( + input_schema.NodePromise(flow_id=1, node_id=3, node_type="python_script"), + ) + consumer_code_v1 = """ +import polars as pl +df = flowfile.read_input().collect() +model = flowfile.read_artifact("linear_model") +coeffs = model["coefficients"] +result = 
df.with_columns(pl.lit(coeffs[0]).alias("c0")) +flowfile.publish_output(result) +flowfile.delete_artifact("linear_model") +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=consumer_code_v1, kernel_id=kernel_id, + ), + ) + ) + add_connection( + graph, + input_schema.NodeConnection.create_from_simple_input(2, 3), + ) + + # First run — everything works + _handle_run_info(graph.run_graph()) + + # Verify node 3 produced output + df_out = graph.get_node(3).get_resulting_data().data_frame.collect() + assert "c0" in df_out.columns + + # Change the consumer's code (node 3) — still deletes the artifact + consumer_code_v2 = """ +import polars as pl +df = flowfile.read_input().collect() +model = flowfile.read_artifact("linear_model") +coeffs = model["coefficients"] +result = df.with_columns( + pl.lit(coeffs[0]).alias("c0"), + pl.lit(coeffs[1]).alias("c1"), +) +flowfile.publish_output(result) +flowfile.delete_artifact("linear_model") +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, node_id=3, depending_on_ids=[2], + python_script_input=input_schema.PythonScriptInput( + code=consumer_code_v2, kernel_id=kernel_id, + ), + ) + ) + + # Second run — consumer re-runs; producer must also re-run + # because the artifact was deleted on the first run. 
+ _handle_run_info(graph.run_graph()) + + # Consumer should have the new columns + df_out2 = graph.get_node(3).get_resulting_data().data_frame.collect() + assert "c0" in df_out2.columns + assert "c1" in df_out2.columns + + finally: + _kernel_mod._manager = _prev + + +# --------------------------------------------------------------------------- +# Tests — auto-restart stopped/errored kernels +# --------------------------------------------------------------------------- + + +class TestKernelAutoRestart: + """Tests verifying that stopped/errored kernels auto-restart on execution.""" + + def test_execute_sync_restarts_stopped_kernel(self, kernel_manager: tuple[KernelManager, str]): + """execute_sync auto-restarts a STOPPED kernel instead of raising.""" + from flowfile_core.kernel.models import KernelState + + manager, kernel_id = kernel_manager + + # Stop the kernel + _run(manager.stop_kernel(kernel_id)) + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.STOPPED + + # execute_sync should auto-restart and succeed + result = manager.execute_sync( + kernel_id, + ExecuteRequest( + node_id=100, + code='print("restarted!")', + input_paths={}, + output_dir="/shared/test_restart", + ), + ) + assert result.success + assert "restarted!" 
in result.stdout + + # Kernel should be IDLE again + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.IDLE + + def test_execute_async_restarts_stopped_kernel(self, kernel_manager: tuple[KernelManager, str]): + """async execute() auto-restarts a STOPPED kernel instead of raising.""" + from flowfile_core.kernel.models import KernelState + + manager, kernel_id = kernel_manager + + # Stop the kernel + _run(manager.stop_kernel(kernel_id)) + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.STOPPED + + # execute should auto-restart and succeed + result = _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=101, + code='print("async restarted!")', + input_paths={}, + output_dir="/shared/test_restart_async", + ), + ) + ) + assert result.success + assert "async restarted!" in result.stdout + + # Kernel should be IDLE again + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.IDLE + + def test_clear_node_artifacts_restarts_stopped_kernel(self, kernel_manager: tuple[KernelManager, str]): + """clear_node_artifacts_sync auto-restarts a STOPPED kernel.""" + from flowfile_core.kernel.models import KernelState + + manager, kernel_id = kernel_manager + + # Stop the kernel + _run(manager.stop_kernel(kernel_id)) + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.STOPPED + + # clear_node_artifacts_sync should auto-restart and succeed + result = manager.clear_node_artifacts_sync(kernel_id, node_ids=[1, 2, 3]) + assert result is not None + + # Kernel should be IDLE again + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.IDLE + + def test_python_script_node_with_stopped_kernel(self, kernel_manager: tuple[KernelManager, str]): + """python_script node execution auto-restarts a stopped kernel.""" + from flowfile_core.kernel.models import KernelState + + manager, kernel_id = kernel_manager + import flowfile_core.kernel 
as _kernel_mod + + _prev = _kernel_mod._manager + _kernel_mod._manager = manager + + try: + # Stop the kernel first + _run(manager.stop_kernel(kernel_id)) + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.STOPPED + + # Create a flow with a python_script node + graph = _create_graph() + + data = [{"val": 42}] + node_promise = input_schema.NodePromise(flow_id=1, node_id=1, node_type="manual_input") + graph.add_node_promise(node_promise) + graph.add_manual_input( + input_schema.NodeManualInput( + flow_id=1, + node_id=1, + raw_data_format=input_schema.RawData.from_pylist(data), + ) + ) + + node_promise_2 = input_schema.NodePromise(flow_id=1, node_id=2, node_type="python_script") + graph.add_node_promise(node_promise_2) + + code = """ +df = flowfile.read_input() +flowfile.publish_output(df) +""" + graph.add_python_script( + input_schema.NodePythonScript( + flow_id=1, + node_id=2, + depending_on_ids=[1], + python_script_input=input_schema.PythonScriptInput( + code=code, + kernel_id=kernel_id, + ), + ) + ) + + add_connection(graph, input_schema.NodeConnection.create_from_simple_input(1, 2)) + + # Run the graph — kernel should auto-restart + run_info = graph.run_graph() + _handle_run_info(run_info) + + # Verify execution succeeded + result = graph.get_node(2).get_resulting_data() + assert result is not None + df = result.data_frame + if hasattr(df, "collect"): + df = df.collect() + assert len(df) == 1 + assert df["val"].to_list() == [42] + + # Kernel should be IDLE + kernel = _run(manager.get_kernel(kernel_id)) + assert kernel.state == KernelState.IDLE + + finally: + _kernel_mod._manager = _prev diff --git a/flowfile_core/tests/flowfile/test_kernel_persistence_integration.py b/flowfile_core/tests/flowfile/test_kernel_persistence_integration.py new file mode 100644 index 000000000..93699fa76 --- /dev/null +++ b/flowfile_core/tests/flowfile/test_kernel_persistence_integration.py @@ -0,0 +1,442 @@ +""" +Docker-based integration tests for 
artifact persistence and recovery. + +These tests require Docker to be available and are marked with +``@pytest.mark.kernel``. The ``kernel_manager`` fixture (session-scoped, +defined in conftest.py) builds the flowfile-kernel image, starts a +container, and tears it down after all tests finish. + +The tests exercise the full persistence lifecycle: + - Artifacts automatically persisted on publish + - Persistence status visible via API + - Recovery after clearing in-memory state + - Cleanup of old artifacts + - Lazy loading from disk +""" + +import asyncio +import time + +import httpx +import pytest + +from flowfile_core.kernel.manager import KernelManager +from flowfile_core.kernel.models import CleanupRequest, ExecuteRequest, ExecuteResult + +pytestmark = pytest.mark.kernel + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run(coro): + """Run an async coroutine from sync test code.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +def _execute(manager: KernelManager, kernel_id: str, code: str, node_id: int = 1) -> ExecuteResult: + """Execute code on the kernel and return the result.""" + return _run( + manager.execute( + kernel_id, + ExecuteRequest( + node_id=node_id, + code=code, + input_paths={}, + output_dir=f"/shared/test_persist/{node_id}", + ), + ) + ) + + +def _get_json(port: int, path: str) -> dict: + """GET a JSON endpoint on the kernel runtime.""" + with httpx.Client(timeout=httpx.Timeout(30.0)) as client: + response = client.get(f"http://localhost:{port}{path}") + response.raise_for_status() + return response.json() + + +def _post_json(port: int, path: str, json: dict | None = None) -> dict: + """POST to a JSON endpoint on the kernel runtime.""" + with httpx.Client(timeout=httpx.Timeout(30.0)) as client: + response = 
client.post(f"http://localhost:{port}{path}", json=json or {}) + response.raise_for_status() + return response.json() + + +# --------------------------------------------------------------------------- +# Tests — persistence basics +# --------------------------------------------------------------------------- + + +class TestArtifactPersistenceBasics: + """Verify that artifacts are automatically persisted when published.""" + + def test_published_artifact_is_persisted(self, kernel_manager: tuple[KernelManager, str]): + """Publishing an artifact should automatically persist it to disk.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + # Clear any leftover state + _run(manager.clear_artifacts(kernel_id)) + + # Publish an artifact + result = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("persist_test", {"weights": [1, 2, 3]})', + node_id=100, + ) + assert result.success + assert "persist_test" in result.artifacts_published + + # Check persistence info + persistence = _get_json(kernel.port, "/persistence") + assert persistence["enabled"] is True + assert persistence["persisted_count"] >= 1 + assert "persist_test" in persistence["artifacts"] + assert persistence["artifacts"]["persist_test"]["persisted"] is True + + def test_persistence_metadata_in_artifact_list(self, kernel_manager: tuple[KernelManager, str]): + """The /artifacts endpoint should include persistence status.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _run(manager.clear_artifacts(kernel_id)) + + _execute( + manager, kernel_id, + 'flowfile.publish_artifact("meta_test", [1, 2, 3])', + node_id=101, + ) + + artifacts = _get_json(kernel.port, "/artifacts") + assert "meta_test" in artifacts + assert artifacts["meta_test"]["persisted"] is True + + def test_disk_usage_reported(self, kernel_manager: tuple[KernelManager, str]): + """Persistence info should report non-zero disk usage after publishing.""" + manager, 
kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _run(manager.clear_artifacts(kernel_id)) + + _execute( + manager, kernel_id, + 'flowfile.publish_artifact("big_item", list(range(10000)))', + node_id=102, + ) + + persistence = _get_json(kernel.port, "/persistence") + assert persistence["disk_usage_bytes"] > 0 + + +class TestHealthAndRecoveryStatus: + """Verify health and recovery status endpoints include persistence info.""" + + def test_health_includes_persistence(self, kernel_manager: tuple[KernelManager, str]): + """The /health endpoint should indicate persistence status.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + health = _get_json(kernel.port, "/health") + assert "persistence" in health + assert health["persistence"] == "enabled" + assert "recovery_mode" in health + + def test_recovery_status_available(self, kernel_manager: tuple[KernelManager, str]): + """The /recovery-status endpoint should return valid status.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + status = _get_json(kernel.port, "/recovery-status") + assert "status" in status + assert status["status"] in ("completed", "pending", "disabled") + + +# --------------------------------------------------------------------------- +# Tests — manual recovery +# --------------------------------------------------------------------------- + + +class TestManualRecovery: + """Test manual artifact recovery via /recover endpoint.""" + + def test_recover_loads_persisted_artifacts(self, kernel_manager: tuple[KernelManager, str]): + """After clearing in-memory state, /recover restores from disk.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + # Start fresh + _run(manager.clear_artifacts(kernel_id)) + + # Publish two artifacts + result1 = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("model_a", {"type": "linear"})', + node_id=200, + ) + 
assert result1.success + + result2 = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("model_b", {"type": "tree"})', + node_id=201, + ) + assert result2.success + + # Verify both are persisted + persistence = _get_json(kernel.port, "/persistence") + assert persistence["persisted_count"] >= 2 + + # Clear in-memory state only (use the /clear endpoint which also clears disk) + # Instead, we'll verify recovery by checking the recover endpoint reports them + # Since the artifacts are already in memory and on disk, recovery should + # report them as already loaded (0 newly recovered). + recovery = _post_json(kernel.port, "/recover") + assert recovery["status"] == "completed" + # They're already in memory, so recovered list may be empty + # (recover_all skips artifacts already in memory) + + def test_recovery_status_after_manual_trigger(self, kernel_manager: tuple[KernelManager, str]): + """Recovery status should reflect manual recovery completion.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _post_json(kernel.port, "/recover") + + status = _get_json(kernel.port, "/recovery-status") + assert status["status"] == "completed" + assert status["mode"] == "manual" + + def test_artifact_accessible_after_publish_and_recover( + self, kernel_manager: tuple[KernelManager, str], + ): + """Artifact published by node A should be readable by node B after recovery.""" + manager, kernel_id = kernel_manager + + _run(manager.clear_artifacts(kernel_id)) + + # Node 300 publishes + r1 = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("shared_model", {"accuracy": 0.95})', + node_id=300, + ) + assert r1.success + + # Node 301 reads it + r2 = _execute( + manager, kernel_id, + """ +model = flowfile.read_artifact("shared_model") +assert model["accuracy"] == 0.95, f"Expected 0.95, got {model}" +print(f"model accuracy: {model['accuracy']}") +""", + node_id=301, + ) + assert r2.success, f"Read artifact failed: {r2.error}" + assert 
"0.95" in r2.stdout + + +# --------------------------------------------------------------------------- +# Tests — cleanup +# --------------------------------------------------------------------------- + + +class TestArtifactCleanup: + """Test artifact cleanup via /cleanup endpoint.""" + + def test_cleanup_specific_artifacts(self, kernel_manager: tuple[KernelManager, str]): + """Cleanup by name should remove specific artifacts from disk.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _run(manager.clear_artifacts(kernel_id)) + + # Publish two artifacts + _execute( + manager, kernel_id, + 'flowfile.publish_artifact("keep_me", 42)', + node_id=400, + ) + _execute( + manager, kernel_id, + 'flowfile.publish_artifact("delete_me", 99)', + node_id=401, + ) + + # Cleanup only "delete_me" + cleanup_result = _post_json(kernel.port, "/cleanup", { + "artifact_names": [{"flow_id": 0, "name": "delete_me"}], + }) + assert cleanup_result["status"] == "cleaned" + assert cleanup_result["removed_count"] == 1 + + def test_cleanup_by_age_keeps_recent(self, kernel_manager: tuple[KernelManager, str]): + """Cleanup with max_age_hours should not remove recently published artifacts.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _run(manager.clear_artifacts(kernel_id)) + + _execute( + manager, kernel_id, + 'flowfile.publish_artifact("recent_item", "fresh")', + node_id=410, + ) + + # Cleanup with 24h threshold — recent artifacts should survive + cleanup_result = _post_json(kernel.port, "/cleanup", { + "max_age_hours": 24, + }) + assert cleanup_result["removed_count"] == 0 + + def test_clear_all_removes_from_disk(self, kernel_manager: tuple[KernelManager, str]): + """POST /clear should remove artifacts from both memory and disk.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _run(manager.clear_artifacts(kernel_id)) + + _execute( + manager, kernel_id, + 
'flowfile.publish_artifact("doomed", 123)', + node_id=420, + ) + + # Verify it's persisted + persistence_before = _get_json(kernel.port, "/persistence") + assert persistence_before["persisted_count"] >= 1 + + # Clear all + _post_json(kernel.port, "/clear") + + # Verify disk is clean too + persistence_after = _get_json(kernel.port, "/persistence") + assert persistence_after["persisted_count"] == 0 + + +# --------------------------------------------------------------------------- +# Tests — persistence through KernelManager proxy +# --------------------------------------------------------------------------- + + +class TestKernelManagerPersistenceProxy: + """Test the persistence proxy methods on KernelManager.""" + + def test_manager_recover_artifacts(self, kernel_manager: tuple[KernelManager, str]): + """KernelManager.recover_artifacts() returns RecoveryStatus.""" + manager, kernel_id = kernel_manager + result = _run(manager.recover_artifacts(kernel_id)) + assert result.status in ("completed", "disabled") + + def test_manager_get_recovery_status(self, kernel_manager: tuple[KernelManager, str]): + """KernelManager.get_recovery_status() returns RecoveryStatus.""" + manager, kernel_id = kernel_manager + result = _run(manager.get_recovery_status(kernel_id)) + assert result.status in ("completed", "pending", "disabled") + + def test_manager_cleanup_artifacts(self, kernel_manager: tuple[KernelManager, str]): + """KernelManager.cleanup_artifacts() returns CleanupResult.""" + manager, kernel_id = kernel_manager + request = CleanupRequest(max_age_hours=24) + result = _run(manager.cleanup_artifacts(kernel_id, request)) + assert result.status in ("cleaned", "disabled") + + def test_manager_get_persistence_info(self, kernel_manager: tuple[KernelManager, str]): + """KernelManager.get_persistence_info() returns ArtifactPersistenceInfo.""" + manager, kernel_id = kernel_manager + result = _run(manager.get_persistence_info(kernel_id)) + assert result.enabled is True + assert 
result.recovery_mode in ("lazy", "eager", "none") + + +# --------------------------------------------------------------------------- +# Tests — persistence survives node re-execution +# --------------------------------------------------------------------------- + + +class TestPersistenceThroughReexecution: + """Verify that persisted artifacts survive node re-execution cycles.""" + + def test_reexecution_preserves_other_nodes_artifacts( + self, kernel_manager: tuple[KernelManager, str], + ): + """Re-executing node B should not affect node A's persisted artifacts.""" + manager, kernel_id = kernel_manager + kernel = _run(manager.get_kernel(kernel_id)) + + _run(manager.clear_artifacts(kernel_id)) + + # Node 500 publishes "stable_model" + r1 = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("stable_model", {"v": 1})', + node_id=500, + ) + assert r1.success + + # Node 501 publishes "temp_model" + r2 = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("temp_model", {"v": 1})', + node_id=501, + ) + assert r2.success + + # Re-execute node 501 (clears its own artifacts, publishes new) + r3 = _execute( + manager, kernel_id, + 'flowfile.publish_artifact("temp_model", {"v": 2})', + node_id=501, + ) + assert r3.success + + # "stable_model" from node 500 should still be on disk + persistence = _get_json(kernel.port, "/persistence") + assert "stable_model" in persistence["artifacts"] + assert persistence["artifacts"]["stable_model"]["persisted"] is True + + def test_persisted_artifact_readable_after_reexecution( + self, kernel_manager: tuple[KernelManager, str], + ): + """After re-executing a node, previously persisted artifacts from other nodes + should still be readable.""" + manager, kernel_id = kernel_manager + + _run(manager.clear_artifacts(kernel_id)) + + # Publish model + _execute( + manager, kernel_id, + 'flowfile.publish_artifact("durable_model", {"accuracy": 0.99})', + node_id=510, + ) + + # Different node re-executes multiple times + for i in 
range(3): + _execute( + manager, kernel_id, + f'flowfile.publish_artifact("ephemeral_{i}", {i})', + node_id=511 + i, + ) + + # Verify durable_model is still readable + r = _execute( + manager, kernel_id, + """ +model = flowfile.read_artifact("durable_model") +assert model["accuracy"] == 0.99 +print("durable model OK") +""", + node_id=520, + ) + assert r.success, f"Failed to read durable_model: {r.error}" + assert "durable model OK" in r.stdout diff --git a/flowfile_core/tests/kernel_fixtures.py b/flowfile_core/tests/kernel_fixtures.py new file mode 100644 index 000000000..a686891b6 --- /dev/null +++ b/flowfile_core/tests/kernel_fixtures.py @@ -0,0 +1,127 @@ +""" +Kernel test fixtures. + +Provides utilities to build the flowfile-kernel Docker image, +create a KernelManager, start/stop kernels, and clean up. +""" + +import asyncio +import logging +import os +import subprocess +import tempfile +from collections.abc import Generator +from contextlib import contextmanager +from pathlib import Path + +logger = logging.getLogger("kernel_fixture") + +KERNEL_IMAGE = "flowfile-kernel" +KERNEL_TEST_ID = "integration-test" +KERNEL_CONTAINER_NAME = f"flowfile-kernel-{KERNEL_TEST_ID}" + +_REPO_ROOT = Path(__file__).resolve().parent.parent.parent + + +def _build_kernel_image() -> bool: + """Build the flowfile-kernel Docker image from kernel_runtime/.""" + dockerfile = _REPO_ROOT / "kernel_runtime" / "Dockerfile" + context = _REPO_ROOT / "kernel_runtime" + + if not dockerfile.exists(): + logger.error("Dockerfile not found at %s", dockerfile) + return False + + logger.info("Building Docker image '%s' ...", KERNEL_IMAGE) + try: + subprocess.run( + ["docker", "build", "-t", KERNEL_IMAGE, "-f", str(dockerfile), str(context)], + check=True, + capture_output=True, + text=True, + timeout=300, + ) + logger.info("Docker image '%s' built successfully", KERNEL_IMAGE) + return True + except subprocess.CalledProcessError as exc: + logger.error("Failed to build Docker image: %s\nstdout: 
%s\nstderr: %s", exc, exc.stdout, exc.stderr) + return False + except subprocess.TimeoutExpired: + logger.error("Docker build timed out") + return False + + +def _remove_container(name: str) -> None: + """Force-remove a container by name (ignore errors if it doesn't exist).""" + subprocess.run( + ["docker", "rm", "-f", name], + capture_output=True, + check=False, + ) + + +@contextmanager +def managed_kernel( + packages: list[str] | None = None, +) -> Generator[tuple, None, None]: + """ + Context manager that: + 1. Builds the flowfile-kernel Docker image + 2. Creates a KernelManager with a temp shared volume + 3. Creates and starts a kernel + 4. Yields (manager, kernel_id) + 5. Stops + deletes the kernel and cleans up + + Usage:: + + with managed_kernel(packages=["scikit-learn"]) as (manager, kernel_id): + result = await manager.execute(kernel_id, request) + """ + from flowfile_core.kernel.manager import KernelManager + from flowfile_core.kernel.models import KernelConfig + + # 1 — Build image + if not _build_kernel_image(): + raise RuntimeError("Could not build flowfile-kernel Docker image") + + # 2 — Ensure stale container is removed + _remove_container(KERNEL_CONTAINER_NAME) + + # 3 — Temp shared volume + shared_dir = tempfile.mkdtemp(prefix="kernel_test_shared_") + + manager = KernelManager(shared_volume_path=shared_dir) + kernel_id = KERNEL_TEST_ID + + try: + # 4 — Create + start + loop = asyncio.new_event_loop() + config = KernelConfig( + id=kernel_id, + name="Integration Test Kernel", + packages=packages or [], + ) + loop.run_until_complete(manager.create_kernel(config, user_id=1)) + loop.run_until_complete(manager.start_kernel(kernel_id)) + + yield manager, kernel_id + + finally: + # 5 — Tear down + try: + loop.run_until_complete(manager.stop_kernel(kernel_id)) + except Exception as exc: + logger.warning("Error stopping kernel during teardown: %s", exc) + try: + loop.run_until_complete(manager.delete_kernel(kernel_id)) + except Exception as exc: + 
logger.warning("Error deleting kernel during teardown: %s", exc) + loop.close() + + # Belt-and-suspenders: force-remove the container + _remove_container(KERNEL_CONTAINER_NAME) + + # Clean up shared dir + import shutil + + shutil.rmtree(shared_dir, ignore_errors=True) diff --git a/flowfile_frontend/package-lock.json b/flowfile_frontend/package-lock.json index 74331c405..9218aac16 100644 --- a/flowfile_frontend/package-lock.json +++ b/flowfile_frontend/package-lock.json @@ -1,12 +1,12 @@ { "name": "Flowfile", - "version": "0.5.6", + "version": "0.6.2", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "Flowfile", - "version": "0.5.6", + "version": "0.6.2", "dependencies": { "@ag-grid-community/client-side-row-model": "^31.1.1", "@ag-grid-community/core": "^31.1.1", diff --git a/flowfile_frontend/src/renderer/app/api/flow.api.ts b/flowfile_frontend/src/renderer/app/api/flow.api.ts index c6d4b8b47..53accc957 100644 --- a/flowfile_frontend/src/renderer/app/api/flow.api.ts +++ b/flowfile_frontend/src/renderer/app/api/flow.api.ts @@ -11,6 +11,7 @@ import type { HistoryState, UndoRedoResult, OperationResponse, + FlowArtifactData, } from "../types"; export class FlowApi { @@ -318,4 +319,19 @@ export class FlowApi { }); return response.data; } + + // ============================================================================ + // Artifact Operations + // ============================================================================ + + /** + * Get artifact visualization data for a flow (badges, edges) + */ + static async getArtifacts(flowId: number): Promise { + const response = await axios.get("/flow/artifacts", { + params: { flow_id: flowId }, + headers: { accept: "application/json" }, + }); + return response.data; + } } diff --git a/flowfile_frontend/src/renderer/app/api/kernel.api.ts b/flowfile_frontend/src/renderer/app/api/kernel.api.ts new file mode 100644 index 000000000..914682f83 --- /dev/null +++ 
b/flowfile_frontend/src/renderer/app/api/kernel.api.ts @@ -0,0 +1,113 @@ +import axios from "../services/axios.config"; +import type { + DockerStatus, + ExecuteCellRequest, + ExecuteResult, + KernelConfig, + KernelInfo, +} from "../types"; + +const API_BASE_URL = "/kernels"; + +export class KernelApi { + static async getAll(): Promise { + try { + const response = await axios.get(`${API_BASE_URL}/`); + return response.data; + } catch (error) { + console.error("API Error: Failed to load kernels:", error); + const errorMsg = (error as any).response?.data?.detail || "Failed to load kernels"; + throw new Error(errorMsg); + } + } + + static async get(kernelId: string): Promise { + try { + const response = await axios.get( + `${API_BASE_URL}/${encodeURIComponent(kernelId)}`, + ); + return response.data; + } catch (error) { + console.error("API Error: Failed to get kernel:", error); + throw error; + } + } + + static async create(config: KernelConfig): Promise { + try { + const response = await axios.post(`${API_BASE_URL}/`, config); + return response.data; + } catch (error) { + console.error("API Error: Failed to create kernel:", error); + const errorMsg = (error as any).response?.data?.detail || "Failed to create kernel"; + throw new Error(errorMsg); + } + } + + static async delete(kernelId: string): Promise { + try { + await axios.delete(`${API_BASE_URL}/${encodeURIComponent(kernelId)}`); + } catch (error) { + console.error("API Error: Failed to delete kernel:", error); + throw error; + } + } + + static async start(kernelId: string): Promise { + try { + const response = await axios.post( + `${API_BASE_URL}/${encodeURIComponent(kernelId)}/start`, + ); + return response.data; + } catch (error) { + console.error("API Error: Failed to start kernel:", error); + const errorMsg = (error as any).response?.data?.detail || "Failed to start kernel"; + throw new Error(errorMsg); + } + } + + static async stop(kernelId: string): Promise { + try { + await 
axios.post(`${API_BASE_URL}/${encodeURIComponent(kernelId)}/stop`); + } catch (error) { + console.error("API Error: Failed to stop kernel:", error); + throw error; + } + } + + static async getArtifacts(kernelId: string): Promise> { + try { + const response = await axios.get>( + `${API_BASE_URL}/${encodeURIComponent(kernelId)}/artifacts`, + ); + return response.data; + } catch (error) { + console.error("API Error: Failed to get artifacts:", error); + return {}; + } + } + + static async getDockerStatus(): Promise { + try { + const response = await axios.get(`${API_BASE_URL}/docker-status`); + return response.data; + } catch (error) { + console.error("API Error: Failed to check Docker status:", error); + return { available: false, image_available: false, error: "Failed to reach server" }; + } + } + + static async executeCell(kernelId: string, request: ExecuteCellRequest): Promise { + try { + const response = await axios.post( + `${API_BASE_URL}/${encodeURIComponent(kernelId)}/execute_cell`, + request, + ); + return response.data; + } catch (error) { + console.error("API Error: Failed to execute cell:", error); + const errorMsg = (error as any).response?.data?.detail || "Failed to execute cell"; + throw new Error(errorMsg); + } + } +} diff --git a/flowfile_frontend/src/renderer/app/components/layout/Sidebar/NavigationRoutes.ts b/flowfile_frontend/src/renderer/app/components/layout/Sidebar/NavigationRoutes.ts index 7d050569f..c0d2bfee1 100644 --- a/flowfile_frontend/src/renderer/app/components/layout/Sidebar/NavigationRoutes.ts +++ b/flowfile_frontend/src/renderer/app/components/layout/Sidebar/NavigationRoutes.ts @@ -56,6 +56,13 @@ export default { icon: "fa-solid fa-key", }, }, + { + name: "kernelManager", + displayName: "menu.kernelManager", + meta: { + icon: "fa-solid fa-server", + }, + }, { name: "nodeDesigner", displayName: "menu.nodeDesigner", diff --git a/flowfile_frontend/src/renderer/app/components/nodes/ArtifactBadge.vue 
b/flowfile_frontend/src/renderer/app/components/nodes/ArtifactBadge.vue new file mode 100644 index 000000000..99403417b --- /dev/null +++ b/flowfile_frontend/src/renderer/app/components/nodes/ArtifactBadge.vue @@ -0,0 +1,247 @@ + + + + + diff --git a/flowfile_frontend/src/renderer/app/components/nodes/NodeWrapper.vue b/flowfile_frontend/src/renderer/app/components/nodes/NodeWrapper.vue index 663143d49..8f6931ccb 100644 --- a/flowfile_frontend/src/renderer/app/components/nodes/NodeWrapper.vue +++ b/flowfile_frontend/src/renderer/app/components/nodes/NodeWrapper.vue @@ -57,6 +57,9 @@ /> + + +
+ + + + + + diff --git a/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/PythonScript.vue b/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/PythonScript.vue new file mode 100644 index 000000000..a0b6daa10 --- /dev/null +++ b/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/PythonScript.vue @@ -0,0 +1,578 @@ + + + + + diff --git a/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/flowfileCompletions.ts b/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/flowfileCompletions.ts new file mode 100644 index 000000000..d6b12ef01 --- /dev/null +++ b/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/flowfileCompletions.ts @@ -0,0 +1,90 @@ +export const flowfileCompletionVals = [ + // flowfile module + { + label: "flowfile", + type: "variable", + info: "FlowFile API module for data I/O and artifacts", + }, + + // Data I/O functions + { + label: "read_input", + type: "function", + info: "Read input DataFrame. Optional name parameter for named inputs.", + detail: "flowfile.read_input(name?)", + apply: "read_input()", + }, + { + label: "read_inputs", + type: "function", + info: "Read all inputs as a dict of DataFrames.", + detail: "flowfile.read_inputs()", + apply: "read_inputs()", + }, + { + label: "publish_output", + type: "function", + info: "Write output DataFrame. 
Optional name parameter for named outputs.", + detail: "flowfile.publish_output(df, name?)", + apply: "publish_output(df)", + }, + + // Display function + { + label: "display", + type: "function", + info: "Display a rich object (matplotlib figure, plotly figure, PIL image, HTML string) in the output panel.", + detail: "flowfile.display(obj, title?)", + apply: "display(obj)", + }, + + // Artifact functions + { + label: "publish_artifact", + type: "function", + info: "Store a Python object as a named artifact in kernel memory.", + detail: 'flowfile.publish_artifact("name", obj)', + apply: 'publish_artifact("name", obj)', + }, + { + label: "read_artifact", + type: "function", + info: "Retrieve a Python object from a named artifact.", + detail: 'flowfile.read_artifact("name")', + apply: 'read_artifact("name")', + }, + { + label: "delete_artifact", + type: "function", + info: "Remove a named artifact from kernel memory.", + detail: 'flowfile.delete_artifact("name")', + apply: 'delete_artifact("name")', + }, + { + label: "list_artifacts", + type: "function", + info: "List all artifacts available in the kernel.", + detail: "flowfile.list_artifacts()", + apply: "list_artifacts()", + }, + + // Polars basics (also useful in python_script context) + { label: "pl", type: "variable", info: "Polars main module" }, + { label: "col", type: "function", info: "Polars column selector" }, + { label: "lit", type: "function", info: "Polars literal value" }, + + // Common Polars operations + { label: "select", type: "method", info: "Select columns" }, + { label: "filter", type: "method", info: "Filter rows" }, + { label: "group_by", type: "method", info: "Group by columns" }, + { label: "with_columns", type: "method", info: "Add/modify columns" }, + { label: "join", type: "method", info: "Join operations" }, + { label: "sort", type: "method", info: "Sort DataFrame" }, + { label: "collect", type: "method", info: "Collect LazyFrame to DataFrame" }, + + // Basic Python + { label: "print", 
type: "function" }, + { label: "len", type: "function" }, + { label: "range", type: "function" }, + { label: "import", type: "keyword" }, +]; diff --git a/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/utils.ts b/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/utils.ts new file mode 100644 index 000000000..046bd685c --- /dev/null +++ b/flowfile_frontend/src/renderer/app/components/nodes/node-types/elements/pythonScript/utils.ts @@ -0,0 +1,33 @@ +import type { NodePythonScript, PythonScriptInput } from "../../../../../types/node.types"; + +export const DEFAULT_PYTHON_SCRIPT_CODE = `import polars as pl + +# Read input data +df = flowfile.read_input() + +# Your transformation here +# df = df.filter(pl.col("column") > 0) + +# Publish output +flowfile.publish_output(df) +`; + +export const createPythonScriptNode = ( + flowId: number, + nodeId: number, +): NodePythonScript => { + const pythonScriptInput: PythonScriptInput = { + code: DEFAULT_PYTHON_SCRIPT_CODE, + kernel_id: null, + }; + + return { + flow_id: flowId, + node_id: nodeId, + pos_x: 0, + pos_y: 0, + depending_on_ids: null, + python_script_input: pythonScriptInput, + cache_results: false, + }; +}; diff --git a/flowfile_frontend/src/renderer/app/features/designer/assets/icons/python_code.svg b/flowfile_frontend/src/renderer/app/features/designer/assets/icons/python_code.svg new file mode 100644 index 000000000..3a4e65470 --- /dev/null +++ b/flowfile_frontend/src/renderer/app/features/designer/assets/icons/python_code.svg @@ -0,0 +1,15 @@ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/flowfile_frontend/src/renderer/app/features/designer/dataPreview.vue b/flowfile_frontend/src/renderer/app/features/designer/dataPreview.vue index 0c4f461bc..993db09a8 100644 --- a/flowfile_frontend/src/renderer/app/features/designer/dataPreview.vue +++ b/flowfile_frontend/src/renderer/app/features/designer/dataPreview.vue @@ -6,46 
+6,141 @@
- -
-

- Displayed data might be outdated. - -

+ +
+
- - - -
-

Step has not stored any data yet. Click here to trigger a run for this node

- + +
+ +
+

+ Displayed data might be outdated. + +

+ +
+ + + + +
+

Step has not stored any data yet. Click here to trigger a run for this node

+ +
+
+ + +
+ + +
+
Published
+ + + + + + + + + + + + + + + +
NameTypeModule
{{ art.name }}{{ art.type_name || "-" }}{{ art.module || "-" }}
+
+ +
+
Consumed
+ + + + + + + + + + + + + + + +
NameTypeSource Node
{{ art.name }}{{ art.type_name || "-" }}{{ art.source_node_id != null ? `Node ${art.source_node_id}` : "-" }}
+
+ +
+
Deleted
+ + + + + + + + + + + +
Name
{{ name }}
+
+ +
+ No artifacts recorded for this node. +
+ + diff --git a/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelCard.vue b/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelCard.vue new file mode 100644 index 000000000..86bafae99 --- /dev/null +++ b/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelCard.vue @@ -0,0 +1,247 @@ + + + + + diff --git a/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelManagerView.vue b/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelManagerView.vue new file mode 100644 index 000000000..258ccd312 --- /dev/null +++ b/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelManagerView.vue @@ -0,0 +1,227 @@ + + + + + diff --git a/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelStatusBadge.vue b/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelStatusBadge.vue new file mode 100644 index 000000000..f238c1cab --- /dev/null +++ b/flowfile_frontend/src/renderer/app/views/KernelManagerView/KernelStatusBadge.vue @@ -0,0 +1,68 @@ + + + + + diff --git a/flowfile_frontend/src/renderer/app/views/KernelManagerView/useKernelManager.ts b/flowfile_frontend/src/renderer/app/views/KernelManagerView/useKernelManager.ts new file mode 100644 index 000000000..86efb5697 --- /dev/null +++ b/flowfile_frontend/src/renderer/app/views/KernelManagerView/useKernelManager.ts @@ -0,0 +1,118 @@ +import { ref, onMounted, onUnmounted } from "vue"; +import type { Ref } from "vue"; +import { KernelApi } from "../../api/kernel.api"; +import type { DockerStatus, KernelInfo, KernelConfig } from "../../types"; + +const POLL_INTERVAL_MS = 5000; + +export function useKernelManager() { + const kernels: Ref = ref([]); + const isLoading = ref(true); + const errorMessage: Ref = ref(null); + const dockerStatus: Ref = ref(null); + const actionInProgress: Ref> = ref({}); + let pollTimer: ReturnType | null = null; + + const checkDockerStatus = async () => { + dockerStatus.value = await KernelApi.getDockerStatus(); + 
}; + + const loadKernels = async () => { + try { + kernels.value = await KernelApi.getAll(); + errorMessage.value = null; + } catch (error: any) { + console.error("Failed to load kernels:", error); + errorMessage.value = error.message || "Failed to load kernels"; + throw error; + } finally { + isLoading.value = false; + } + }; + + const createKernel = async (config: KernelConfig): Promise => { + const kernel = await KernelApi.create(config); + await loadKernels(); + return kernel; + }; + + const startKernel = async (kernelId: string) => { + actionInProgress.value[kernelId] = true; + try { + await KernelApi.start(kernelId); + await loadKernels(); + } finally { + actionInProgress.value[kernelId] = false; + } + }; + + const stopKernel = async (kernelId: string) => { + actionInProgress.value[kernelId] = true; + try { + await KernelApi.stop(kernelId); + await loadKernels(); + } finally { + actionInProgress.value[kernelId] = false; + } + }; + + const deleteKernel = async (kernelId: string) => { + actionInProgress.value[kernelId] = true; + try { + await KernelApi.delete(kernelId); + await loadKernels(); + } finally { + delete actionInProgress.value[kernelId]; + } + }; + + const isActionInProgress = (kernelId: string): boolean => { + return !!actionInProgress.value[kernelId]; + }; + + const startPolling = () => { + stopPolling(); + pollTimer = setInterval(async () => { + try { + kernels.value = await KernelApi.getAll(); + } catch { + // Silently ignore poll errors to avoid spamming the user + } + }, POLL_INTERVAL_MS); + }; + + const stopPolling = () => { + if (pollTimer !== null) { + clearInterval(pollTimer); + pollTimer = null; + } + }; + + onMounted(async () => { + await checkDockerStatus(); + try { + await loadKernels(); + } catch { + // Error already captured in errorMessage + } + startPolling(); + }); + + onUnmounted(() => { + stopPolling(); + }); + + return { + kernels, + isLoading, + errorMessage, + dockerStatus, + actionInProgress, + loadKernels, + createKernel, + 
startKernel, + stopKernel, + deleteKernel, + isActionInProgress, + }; +} diff --git a/flowfile_worker/flowfile_worker/models.py b/flowfile_worker/flowfile_worker/models.py index 658b6e936..d72b2d3a3 100644 --- a/flowfile_worker/flowfile_worker/models.py +++ b/flowfile_worker/flowfile_worker/models.py @@ -140,5 +140,6 @@ def __hash__(self): class RawLogInput(BaseModel): flowfile_flow_id: int log_message: str - log_type: Literal["INFO", "ERROR"] + log_type: Literal["INFO", "WARNING", "ERROR"] + node_id: int | None = None extra: dict | None = None diff --git a/kernel_runtime/Dockerfile b/kernel_runtime/Dockerfile new file mode 100644 index 000000000..1c6bffa9a --- /dev/null +++ b/kernel_runtime/Dockerfile @@ -0,0 +1,24 @@ +FROM python:3.12-slim + +WORKDIR /app + +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl build-essential && rm -rf /var/lib/apt/lists/* + +RUN pip install --no-cache-dir \ + "polars>=1.0.0" "pyarrow>=14.0.0" "numpy>=1.24.0" \ + "fastapi>=0.115.0" "uvicorn>=0.32.0" "httpx>=0.24.0" \ + "cloudpickle>=3.0.0" + +COPY kernel_runtime /app/kernel_runtime +COPY entrypoint.sh /app/entrypoint.sh +RUN chmod +x /app/entrypoint.sh + +ENV KERNEL_PACKAGES="" +VOLUME ["/shared"] +EXPOSE 9999 + +HEALTHCHECK --interval=10s --timeout=5s --start-period=30s \ + CMD curl -f http://localhost:9999/health || exit 1 + +ENTRYPOINT ["/app/entrypoint.sh"] diff --git a/kernel_runtime/README.md b/kernel_runtime/README.md new file mode 100644 index 000000000..617620829 --- /dev/null +++ b/kernel_runtime/README.md @@ -0,0 +1,295 @@ +# Kernel Runtime + +A FastAPI-based Python code execution kernel that runs in isolated Docker containers. It executes arbitrary Python code with built-in support for Polars DataFrames, artifact storage, and multi-flow isolation. 
+ +## Overview + +The kernel runtime provides: +- Isolated Python code execution via REST API +- Built-in `flowfile` module for data I/O and artifact management +- Parquet-based data exchange using Polars LazyFrames +- Thread-safe in-memory artifact storage +- Multi-flow support with artifact isolation +- Automatic stdout/stderr capture + +## Building the Docker Image + +### Standard Build + +```bash +cd kernel_runtime +docker build -t kernel_runtime:latest . +``` + +### Build with Custom Tag + +```bash +docker build -t flowfile/kernel_runtime:v0.2.0 . +``` + +## Running the Container + +### Basic Run + +```bash +docker run -p 9999:9999 kernel_runtime:latest +``` + +### With Shared Volume for Data Exchange + +```bash +docker run -p 9999:9999 -v /path/to/data:/shared kernel_runtime:latest +``` + +### With Additional Python Packages + +The `KERNEL_PACKAGES` environment variable allows installing additional packages at container startup: + +```bash +docker run -p 9999:9999 \ + -e KERNEL_PACKAGES="scikit-learn pandas matplotlib" \ + kernel_runtime:latest +``` + +### Full Example with All Options + +```bash +docker run -d \ + --name flowfile-kernel \ + -p 9999:9999 \ + -v /path/to/data:/shared \ + -e KERNEL_PACKAGES="scikit-learn xgboost" \ + kernel_runtime:latest +``` + +## API Endpoints + +### Health Check + +```bash +curl http://localhost:9999/health +``` + +Response: +```json +{ + "status": "healthy", + "version": "0.2.0", + "artifact_count": 0 +} +``` + +### Execute Code + +```bash +curl -X POST http://localhost:9999/execute \ + -H "Content-Type: application/json" \ + -d '{ + "node_id": "node_1", + "code": "import polars as pl\ndf = flowfile.read_input()\nresult = df.collect()\nflowfile.publish_output(result)", + "input_paths": {"main": ["/shared/input.parquet"]}, + "output_dir": "/shared/output", + "flow_id": 1 + }' +``` + +Response: +```json +{ + "success": true, + "output_paths": ["/shared/output/output_0.parquet"], + "published_artifacts": [], + 
"deleted_artifacts": [], + "stdout": "", + "stderr": "", + "execution_time_ms": 150 +} +``` + +### List Artifacts + +```bash +# All artifacts +curl http://localhost:9999/artifacts + +# Artifacts for a specific flow +curl http://localhost:9999/artifacts?flow_id=1 + +# Artifacts for a specific node +curl http://localhost:9999/artifacts/node/node_1?flow_id=1 +``` + +### Clear Artifacts + +```bash +# Clear all artifacts +curl -X POST http://localhost:9999/clear + +# Clear artifacts for a specific flow +curl -X POST http://localhost:9999/clear?flow_id=1 + +# Clear artifacts by node IDs +curl -X POST http://localhost:9999/clear_node_artifacts \ + -H "Content-Type: application/json" \ + -d '{"node_ids": ["node_1", "node_2"], "flow_id": 1}' +``` + +## Using the `flowfile` Module + +When code is executed, the `flowfile` module is automatically injected into the namespace. Here's how to use it: + +### Reading Input Data + +```python +# Read the main input as a LazyFrame +df = flowfile.read_input() + +# Read a named input +df = flowfile.read_input(name="customers") + +# Read only the first file of an input +df = flowfile.read_first(name="main") + +# Read all inputs as a dictionary +inputs = flowfile.read_inputs() +# Returns: {"main": LazyFrame, "customers": LazyFrame, ...} +``` + +### Writing Output Data + +```python +# Publish a DataFrame or LazyFrame +result = df.collect() +flowfile.publish_output(result) + +# Publish with a custom name +flowfile.publish_output(result, name="cleaned_data") +``` + +### Artifact Management + +Artifacts allow you to store Python objects in memory for use across executions: + +```python +# Store an artifact +model = train_model(data) +flowfile.publish_artifact("trained_model", model) + +# Retrieve an artifact +model = flowfile.read_artifact("trained_model") + +# List all artifacts in current flow +artifacts = flowfile.list_artifacts() + +# Delete an artifact +flowfile.delete_artifact("trained_model") +``` + +### Logging + +```python +# General 
logging +flowfile.log("Processing started", level="INFO") + +# Convenience methods +flowfile.log_info("Step 1 complete") +flowfile.log_warning("Missing values detected") +flowfile.log_error("Failed to process record") +``` + +## Complete Example + +```python +import polars as pl + +# Read input data +df = flowfile.read_input() + +# Transform the data +result = ( + df + .filter(pl.col("status") == "active") + .group_by("category") + .agg(pl.col("amount").sum().alias("total")) + .collect() +) + +flowfile.log_info(f"Processed {result.height} categories") + +# Store intermediate result as artifact +flowfile.publish_artifact("category_totals", result) + +# Write output +flowfile.publish_output(result) +``` + +## Pre-installed Packages + +The Docker image comes with these packages pre-installed: + +- `polars>=1.0.0` - Fast DataFrame library +- `pyarrow>=14.0.0` - Columnar data format support +- `numpy>=1.24.0` - Numerical computing +- `fastapi>=0.115.0` - API framework +- `uvicorn>=0.32.0` - ASGI server +- `httpx>=0.24.0` - HTTP client +- `cloudpickle>=3.0.0` - Object serialization for artifact persistence + +## Development + +### Local Setup + +```bash +cd kernel_runtime +pip install -e ".[test]" +``` + +### Running Tests + +```bash +pytest tests/ -v +``` + +### Running Locally (without Docker) + +```bash +uvicorn kernel_runtime.main:app --host 0.0.0.0 --port 9999 +``` + +## Architecture + +``` +kernel_runtime/ +├── Dockerfile # Container definition +├── entrypoint.sh # Container startup script +├── pyproject.toml # Project configuration +├── kernel_runtime/ +│ ├── main.py # FastAPI application and endpoints +│ ├── flowfile_client.py # The flowfile module for code execution +│ └── artifact_store.py # Thread-safe artifact storage +└── tests/ # Test suite +``` + +### Key Design Decisions + +1. **Flow Isolation**: Multiple flows can share a container without conflicts. Artifacts are keyed by `(flow_id, name)`. + +2. **Automatic Cleanup**: When a node re-executes, its previous artifacts are automatically cleared. + +3. 
**Lazy Evaluation**: Input data is read as Polars LazyFrames for efficient processing. + +4. **Context Isolation**: Each execution request has its own isolated context using Python's `contextvars`. + +## Configuration + +| Environment Variable | Description | Default | +|---------------------|-------------|---------| +| `KERNEL_PACKAGES` | Additional pip packages to install at startup | None | + +## Health Check + +The container includes a health check that verifies the `/health` endpoint responds: + +```dockerfile +HEALTHCHECK --interval=10s --timeout=5s --start-period=30s \ + CMD curl -f http://localhost:9999/health || exit 1 +``` diff --git a/kernel_runtime/entrypoint.sh b/kernel_runtime/entrypoint.sh new file mode 100755 index 000000000..70da434fe --- /dev/null +++ b/kernel_runtime/entrypoint.sh @@ -0,0 +1,9 @@ +#!/bin/bash +set -e + +if [ -n "$KERNEL_PACKAGES" ]; then + echo "Installing packages: $KERNEL_PACKAGES" + pip install --no-cache-dir $KERNEL_PACKAGES +fi + +exec uvicorn kernel_runtime.main:app --host 0.0.0.0 --port 9999 diff --git a/kernel_runtime/kernel_runtime/__init__.py b/kernel_runtime/kernel_runtime/__init__.py new file mode 100644 index 000000000..49f34f498 --- /dev/null +++ b/kernel_runtime/kernel_runtime/__init__.py @@ -0,0 +1 @@ +__version__ = "0.2.0" \ No newline at end of file diff --git a/kernel_runtime/kernel_runtime/artifact_persistence.py b/kernel_runtime/kernel_runtime/artifact_persistence.py new file mode 100644 index 000000000..0570d931d --- /dev/null +++ b/kernel_runtime/kernel_runtime/artifact_persistence.py @@ -0,0 +1,260 @@ +"""Disk-backed persistence layer for kernel artifacts. + +Uses ``cloudpickle`` for serialisation — it handles lambdas, closures, +sklearn models, torch modules, and most arbitrary Python objects out of +the box. 
Each artifact is stored as a pair of files: + + {base_path}/{flow_id}/{artifact_name}/data.artifact # cloudpickle bytes + {base_path}/{flow_id}/{artifact_name}/meta.json # JSON metadata + +A SHA-256 checksum is written into the metadata so corruption can be +detected on load. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import re +import shutil +import time +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any + +import cloudpickle + +logger = logging.getLogger(__name__) + + +class RecoveryMode(str, Enum): + LAZY = "lazy" + EAGER = "eager" + CLEAR = "clear" # Clears all persisted artifacts on startup + + @classmethod + def _missing_(cls, value: object) -> "RecoveryMode | None": + """Handle 'none' as an alias for 'clear' for backwards compatibility.""" + if isinstance(value, str) and value.lower() == "none": + logger.warning( + "RECOVERY_MODE='none' is deprecated, use 'clear' instead. " + "This will delete ALL persisted artifacts on startup." + ) + return cls.CLEAR + return None + + +def _safe_dirname(name: str) -> str: + """Convert an artifact name to a filesystem-safe directory name. + + Strips leading dots to prevent hidden directories. + """ + # First replace unsafe characters + safe = re.sub(r"[^\w\-.]", "_", name) + # Strip leading dots to prevent hidden directories + return safe.lstrip(".") + + +def _sha256(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + +class ArtifactPersistence: + """Saves and loads artifacts to/from local disk using cloudpickle. + + Parameters + ---------- + base_path: + Root directory for persisted artifacts (e.g. ``/shared/artifacts/{kernel_id}``). 
+ """ + + def __init__(self, base_path: str | Path) -> None: + self._base = Path(base_path) + self._base.mkdir(parents=True, exist_ok=True) + + # ------------------------------------------------------------------ + # Paths + # ------------------------------------------------------------------ + + def _artifact_dir(self, flow_id: int, name: str) -> Path: + return self._base / str(flow_id) / _safe_dirname(name) + + def _data_path(self, flow_id: int, name: str) -> Path: + return self._artifact_dir(flow_id, name) / "data.artifact" + + def _meta_path(self, flow_id: int, name: str) -> Path: + return self._artifact_dir(flow_id, name) / "meta.json" + + # ------------------------------------------------------------------ + # Save / Load / Delete + # ------------------------------------------------------------------ + + # Fields that should be persisted to meta.json (whitelist approach) + _PERSISTABLE_FIELDS = frozenset([ + "name", "type_name", "module", "node_id", "flow_id", + "created_at", "size_bytes", + ]) + + def save(self, name: str, obj: Any, metadata: dict[str, Any], flow_id: int = 0) -> None: + """Persist *obj* to disk alongside its *metadata*. + + Only JSON-serializable fields from ``_PERSISTABLE_FIELDS`` are written + to meta.json. This whitelist approach prevents accidentally persisting + non-serializable objects. 
+ """ + artifact_dir = self._artifact_dir(flow_id, name) + artifact_dir.mkdir(parents=True, exist_ok=True) + + data = cloudpickle.dumps(obj) + checksum = _sha256(data) + + data_path = self._data_path(flow_id, name) + data_path.write_bytes(data) + + # Explicitly select only the fields we want to persist (whitelist) + meta = { + k: v for k, v in metadata.items() + if k in self._PERSISTABLE_FIELDS + } + meta["checksum"] = checksum + meta["persisted_at"] = datetime.now(timezone.utc).isoformat() + meta["data_size_bytes"] = len(data) + + self._meta_path(flow_id, name).write_text(json.dumps(meta, indent=2)) + logger.debug("Persisted artifact '%s' (flow_id=%d, %d bytes)", name, flow_id, len(data)) + + def load(self, name: str, flow_id: int = 0) -> Any: + """Load an artifact from disk. Raises ``FileNotFoundError`` if + the artifact has not been persisted or ``ValueError`` on + checksum mismatch. + """ + data_path = self._data_path(flow_id, name) + meta_path = self._meta_path(flow_id, name) + + if not data_path.exists(): + raise FileNotFoundError(f"No persisted artifact '{name}' for flow_id={flow_id}") + + data = data_path.read_bytes() + + if meta_path.exists(): + meta = json.loads(meta_path.read_text()) + expected = meta.get("checksum") + if expected and _sha256(data) != expected: + raise ValueError( + f"Checksum mismatch for artifact '{name}' — the persisted file may be corrupt" + ) + + return cloudpickle.loads(data) + + def load_metadata(self, name: str, flow_id: int = 0) -> dict[str, Any] | None: + """Load only the JSON metadata for a persisted artifact.""" + meta_path = self._meta_path(flow_id, name) + if not meta_path.exists(): + return None + return json.loads(meta_path.read_text()) + + def delete(self, name: str, flow_id: int = 0) -> None: + """Remove a persisted artifact from disk.""" + artifact_dir = self._artifact_dir(flow_id, name) + if artifact_dir.exists(): + shutil.rmtree(artifact_dir) + logger.debug("Deleted persisted artifact '%s' (flow_id=%d)", name, 
flow_id) + + def clear(self, flow_id: int | None = None) -> None: + """Remove all persisted artifacts, optionally scoped to *flow_id*.""" + if flow_id is not None: + flow_dir = self._base / str(flow_id) + if flow_dir.exists(): + shutil.rmtree(flow_dir) + logger.debug("Cleared persisted artifacts for flow_id=%d", flow_id) + else: + for child in self._base.iterdir(): + if child.is_dir(): + shutil.rmtree(child) + logger.debug("Cleared all persisted artifacts") + + # ------------------------------------------------------------------ + # Index / Discovery + # ------------------------------------------------------------------ + + def list_persisted(self, flow_id: int | None = None) -> dict[tuple[int, str], dict[str, Any]]: + """Scan disk and return ``{(flow_id, name): metadata}`` for all + persisted artifacts. + """ + result: dict[tuple[int, str], dict[str, Any]] = {} + flow_dirs = ( + [self._base / str(flow_id)] if flow_id is not None + else [d for d in self._base.iterdir() if d.is_dir()] + ) + for flow_dir in flow_dirs: + if not flow_dir.exists(): + continue + try: + fid = int(flow_dir.name) + except ValueError: + continue + for artifact_dir in flow_dir.iterdir(): + if not artifact_dir.is_dir(): + continue + meta_path = artifact_dir / "meta.json" + if not meta_path.exists(): + continue + try: + meta = json.loads(meta_path.read_text()) + name = meta.get("name", artifact_dir.name) + result[(fid, name)] = meta + except (json.JSONDecodeError, OSError) as exc: + logger.warning("Skipping corrupt metadata in %s: %s", meta_path, exc) + return result + + def disk_usage_bytes(self) -> int: + """Return total bytes used by all persisted artifacts.""" + total = 0 + for path in self._base.rglob("*"): + if path.is_file(): + total += path.stat().st_size + return total + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def cleanup( + self, + max_age_hours: float | None = None, + 
names: list[tuple[int, str]] | None = None, + ) -> int: + """Remove old or specific persisted artifacts. + + Parameters + ---------- + max_age_hours: + If set, delete artifacts persisted more than this many hours ago. + names: + If set, delete these specific ``(flow_id, name)`` pairs. + + Returns the number of artifacts removed. + """ + removed = 0 + + if names: + for flow_id, name in names: + self.delete(name, flow_id=flow_id) + removed += 1 + + if max_age_hours is not None: + cutoff = time.time() - (max_age_hours * 3600) + for (fid, name), meta in self.list_persisted().items(): + persisted_at = meta.get("persisted_at") + if persisted_at: + try: + ts = datetime.fromisoformat(persisted_at).timestamp() + if ts < cutoff: + self.delete(name, flow_id=fid) + removed += 1 + except (ValueError, OSError): + pass + + return removed diff --git a/kernel_runtime/kernel_runtime/artifact_store.py b/kernel_runtime/kernel_runtime/artifact_store.py new file mode 100644 index 000000000..3d2fbbdbb --- /dev/null +++ b/kernel_runtime/kernel_runtime/artifact_store.py @@ -0,0 +1,403 @@ +from __future__ import annotations + +import logging +import sys +import threading +from datetime import datetime, timezone +from typing import Any + +logger = logging.getLogger(__name__) + + +class ArtifactStore: + """Thread-safe in-memory store for Python artifacts produced during kernel execution. + + Artifacts are scoped by ``flow_id`` so that multiple flows sharing the + same kernel container cannot collide on artifact names. + + When an :class:`~kernel_runtime.artifact_persistence.ArtifactPersistence` + backend is attached, artifacts are automatically saved to disk on + ``publish()`` and removed on ``delete()`` / ``clear()``. In *lazy* + recovery mode, ``get()`` transparently loads from disk when the + artifact is not yet in memory. + + .. note:: **Tech Debt / Future Improvement** + + Currently stores the entire object in memory via ``self._artifacts``. 
+ For very large artifacts (e.g., ML models >1GB), this causes memory + pressure and potential OOM. A future improvement would be to: + + 1. Implement a spill-to-disk mechanism (e.g., pickle to temp file when + size exceeds threshold, keep only metadata in memory). + 2. Or integrate with an external object store (S3, MinIO) for truly + large artifacts, storing only a reference here. + 3. For blob uploads, consider a streaming/chunked approach rather than + reading the entire file into memory before storage. + + See: https://github.com/Edwardvaneechoud/Flowfile/issues/XXX (placeholder) + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + # Keyed by (flow_id, name) so each flow has its own namespace. + self._artifacts: dict[tuple[int, str], dict[str, Any]] = {} + + # Optional persistence backend — set via ``enable_persistence()``. + self._persistence: Any | None = None # ArtifactPersistence + # Index of artifacts known to be on disk but not yet loaded. + # Only used in lazy-recovery mode. + self._lazy_index: dict[tuple[int, str], dict[str, Any]] = {} + + # Per-key locks for lazy loading to avoid blocking the global lock + # during potentially slow I/O operations. 
+ self._loading_locks: dict[tuple[int, str], threading.Lock] = {} + self._loading_locks_lock = threading.Lock() # protects _loading_locks dict + + # Track keys currently being persisted to handle race conditions + self._persist_pending: set[tuple[int, str]] = set() + + # ------------------------------------------------------------------ + # Persistence integration + # ------------------------------------------------------------------ + + def _get_loading_lock(self, key: tuple[int, str]) -> threading.Lock: + """Get or create a per-key lock for lazy loading.""" + with self._loading_locks_lock: + if key not in self._loading_locks: + self._loading_locks[key] = threading.Lock() + return self._loading_locks[key] + + def _cleanup_loading_lock(self, key: tuple[int, str]) -> None: + """Remove a per-key lock after loading is complete.""" + with self._loading_locks_lock: + self._loading_locks.pop(key, None) + + def enable_persistence(self, persistence: Any) -> None: + """Attach a persistence backend to this store. + + Parameters + ---------- + persistence: + An :class:`~kernel_runtime.artifact_persistence.ArtifactPersistence` + instance. + """ + self._persistence = persistence + + def recover_all(self) -> list[str]: + """Eagerly load **all** persisted artifacts into memory. + + Returns the names of recovered artifacts. 
+ """ + if self._persistence is None: + return [] + + recovered: list[str] = [] + for (flow_id, name), meta in self._persistence.list_persisted().items(): + key = (flow_id, name) + if key in self._artifacts: + continue # already in memory + try: + obj = self._persistence.load(name, flow_id=flow_id) + with self._lock: + self._artifacts[key] = { + "object": obj, + "name": name, + "type_name": meta.get("type_name", type(obj).__name__), + "module": meta.get("module", type(obj).__module__), + "node_id": meta.get("node_id", -1), + "flow_id": flow_id, + "created_at": meta.get("created_at", datetime.now(timezone.utc).isoformat()), + "size_bytes": meta.get("size_bytes", sys.getsizeof(obj)), + "persisted": True, + "recovered": True, + } + recovered.append(name) + logger.info("Recovered artifact '%s' (flow_id=%d)", name, flow_id) + except Exception as exc: + logger.warning("Failed to recover artifact '%s' (flow_id=%d): %s", name, flow_id, exc) + return recovered + + def build_lazy_index(self) -> int: + """Scan persisted artifacts and build the lazy-load index. + + Returns the number of artifacts indexed. + """ + if self._persistence is None: + return 0 + persisted = self._persistence.list_persisted() + with self._lock: + for key, meta in persisted.items(): + if key not in self._artifacts: + self._lazy_index[key] = meta + return len(self._lazy_index) + + def _try_lazy_load(self, key: tuple[int, str]) -> bool: + """Attempt to load an artifact from disk into memory (lazy mode). + + Uses a two-phase approach to avoid holding the global lock during + potentially slow I/O operations: + 1. Under global lock: check if in lazy_index, grab metadata, release + 2. Under per-key lock: do the actual disk I/O + 3. Under global lock: store result in _artifacts + + Returns True if the artifact was loaded. 
+ """ + if self._persistence is None: + return False + + # Phase 1: Check lazy index under global lock + with self._lock: + if key in self._artifacts: + return True # Already loaded (maybe by another thread) + if key not in self._lazy_index: + return False + meta = self._lazy_index.get(key) + if meta is None: + return False + + # Phase 2: Do I/O under per-key lock (not global lock) + loading_lock = self._get_loading_lock(key) + with loading_lock: + # Double-check after acquiring per-key lock + with self._lock: + if key in self._artifacts: + self._cleanup_loading_lock(key) + return True + if key not in self._lazy_index: + self._cleanup_loading_lock(key) + return False + meta = self._lazy_index.pop(key) + + # Do the actual I/O outside any lock + flow_id, name = key + try: + obj = self._persistence.load(name, flow_id=flow_id) + except Exception as exc: + logger.warning("Failed to lazy-load artifact '%s' (flow_id=%d): %s", name, flow_id, exc) + # Put metadata back in lazy_index so we can retry + with self._lock: + if key not in self._artifacts: + self._lazy_index[key] = meta + self._cleanup_loading_lock(key) + return False + + # Phase 3: Store result under global lock + with self._lock: + self._artifacts[key] = { + "object": obj, + "name": name, + "type_name": meta.get("type_name", type(obj).__name__), + "module": meta.get("module", type(obj).__module__), + "node_id": meta.get("node_id", -1), + "flow_id": flow_id, + "created_at": meta.get("created_at", datetime.now(timezone.utc).isoformat()), + "size_bytes": meta.get("size_bytes", sys.getsizeof(obj)), + "persisted": True, + "recovered": True, + } + logger.info("Lazy-loaded artifact '%s' (flow_id=%d)", name, flow_id) + self._cleanup_loading_lock(key) + return True + + # ------------------------------------------------------------------ + # Core operations + # ------------------------------------------------------------------ + + def publish(self, name: str, obj: Any, node_id: int, flow_id: int = 0) -> None: + key = 
(flow_id, name) + with self._lock: + if key in self._artifacts: + raise ValueError( + f"Artifact '{name}' already exists (published by node " + f"{self._artifacts[key]['node_id']}). " + f"Delete it first with flowfile.delete_artifact('{name}') " + f"before publishing a new one with the same name." + ) + metadata = { + "object": obj, + "name": name, + "type_name": type(obj).__name__, + "module": type(obj).__module__, + "node_id": node_id, + "flow_id": flow_id, + "created_at": datetime.now(timezone.utc).isoformat(), + "size_bytes": sys.getsizeof(obj), + "persisted": False, # Will be set True after successful persist + "persist_pending": self._persistence is not None, + } + self._artifacts[key] = metadata + + # Remove from lazy index if present (we now have it in memory) + self._lazy_index.pop(key, None) + + # Track that persistence is in progress + if self._persistence is not None: + self._persist_pending.add(key) + + # Persist to disk outside the lock (I/O can be slow) + if self._persistence is not None: + try: + self._persistence.save(name, obj, metadata, flow_id=flow_id) + # Mark as successfully persisted + with self._lock: + if key in self._artifacts: + self._artifacts[key]["persisted"] = True + self._artifacts[key]["persist_pending"] = False + self._persist_pending.discard(key) + except Exception as exc: + logger.warning("Failed to persist artifact '%s': %s", name, exc) + with self._lock: + if key in self._artifacts: + self._artifacts[key]["persisted"] = False + self._artifacts[key]["persist_pending"] = False + self._persist_pending.discard(key) + + def delete(self, name: str, flow_id: int = 0) -> None: + key = (flow_id, name) + with self._lock: + if key not in self._artifacts and key not in self._lazy_index: + raise KeyError(f"Artifact '{name}' not found") + self._artifacts.pop(key, None) + self._lazy_index.pop(key, None) + + if self._persistence is not None: + try: + self._persistence.delete(name, flow_id=flow_id) + except Exception as exc: + 
logger.warning("Failed to delete persisted artifact '%s': %s", name, exc) + + def get(self, name: str, flow_id: int = 0) -> Any: + key = (flow_id, name) + # First check in-memory (fast path) + with self._lock: + if key in self._artifacts: + return self._artifacts[key]["object"] + # Check if it's in lazy index before attempting load + in_lazy_index = key in self._lazy_index + if not in_lazy_index: + raise KeyError(f"Artifact '{name}' not found") + + # Attempt lazy load from disk (releases global lock during I/O) + if self._try_lazy_load(key): + with self._lock: + if key in self._artifacts: + return self._artifacts[key]["object"] + + # If we get here, the artifact was in lazy_index but failed to load + raise KeyError( + f"Artifact '{name}' exists on disk but failed to load. " + "Check logs for details." + ) + + def list_all(self, flow_id: int | None = None) -> dict[str, dict[str, Any]]: + """Return metadata for all artifacts, optionally filtered by *flow_id*. + + Includes both in-memory artifacts and artifacts known to be + persisted on disk (lazy index). 
+ """ + with self._lock: + result: dict[str, dict[str, Any]] = {} + # In-memory artifacts + for (_fid, _name), meta in self._artifacts.items(): + if flow_id is None or _fid == flow_id: + result[meta["name"]] = {k: v for k, v in meta.items() if k != "object"} + # Lazy-indexed (on disk, not yet loaded) + for (_fid, _name), meta in self._lazy_index.items(): + if flow_id is None or _fid == flow_id: + name = meta.get("name", _name) + if name not in result: + entry = dict(meta) + entry["persisted"] = True + entry["in_memory"] = False + result[name] = entry + return result + + def clear(self, flow_id: int | None = None) -> None: + """Clear all artifacts, or only those belonging to *flow_id*.""" + with self._lock: + if flow_id is None: + self._artifacts.clear() + self._lazy_index.clear() + else: + to_remove = [ + key for key in self._artifacts if key[0] == flow_id + ] + for key in to_remove: + del self._artifacts[key] + lazy_remove = [ + key for key in self._lazy_index if key[0] == flow_id + ] + for key in lazy_remove: + del self._lazy_index[key] + + if self._persistence is not None: + try: + self._persistence.clear(flow_id=flow_id) + except Exception as exc: + logger.warning("Failed to clear persisted artifacts: %s", exc) + + def clear_by_node_ids( + self, node_ids: set[int], flow_id: int | None = None, + ) -> list[str]: + """Remove all artifacts published by the given *node_ids*. + + When *flow_id* is provided, only artifacts in that flow are + considered. Returns the names of deleted artifacts. 
+ """ + # Initialize before lock to ensure they're defined even if lock raises + to_remove: list[tuple[int, str]] = [] + lazy_remove: list[tuple[int, str]] = [] + removed_names: list[str] = [] + + with self._lock: + to_remove = [ + key + for key, meta in self._artifacts.items() + if meta["node_id"] in node_ids + and (flow_id is None or key[0] == flow_id) + ] + removed_names = [self._artifacts[key]["name"] for key in to_remove] + for key in to_remove: + del self._artifacts[key] + # Also clear from lazy index + lazy_remove = [ + key + for key, meta in self._lazy_index.items() + if meta.get("node_id") in node_ids + and (flow_id is None or key[0] == flow_id) + ] + for key in lazy_remove: + name = self._lazy_index[key].get("name", key[1]) + if name not in removed_names: + removed_names.append(name) + del self._lazy_index[key] + + # Also remove from disk + if self._persistence is not None: + for key in to_remove + lazy_remove: + fid, name = key + try: + self._persistence.delete(name, flow_id=fid) + except Exception as exc: + logger.warning("Failed to delete persisted artifact '%s': %s", name, exc) + + return removed_names + + def list_by_node_id( + self, node_id: int, flow_id: int | None = None, + ) -> dict[str, dict[str, Any]]: + """Return metadata for artifacts published by *node_id*.""" + with self._lock: + result: dict[str, dict[str, Any]] = {} + for (_fid, _name), meta in self._artifacts.items(): + if meta["node_id"] == node_id and (flow_id is None or _fid == flow_id): + result[meta["name"]] = {k: v for k, v in meta.items() if k != "object"} + for (_fid, _name), meta in self._lazy_index.items(): + if meta.get("node_id") == node_id and (flow_id is None or _fid == flow_id): + name = meta.get("name", _name) + if name not in result: + entry = dict(meta) + entry["persisted"] = True + entry["in_memory"] = False + result[name] = entry + return result diff --git a/kernel_runtime/kernel_runtime/flowfile_client.py b/kernel_runtime/kernel_runtime/flowfile_client.py new file 
mode 100644 index 000000000..f7e4717fd --- /dev/null +++ b/kernel_runtime/kernel_runtime/flowfile_client.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import base64 +import contextvars +import io +import os +import re +from pathlib import Path +from typing import Any, Literal + +import httpx +import polars as pl + +from kernel_runtime.artifact_store import ArtifactStore + +_context: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("flowfile_context") + +# Reusable HTTP client for log callbacks (created per execution context) +_log_client: contextvars.ContextVar[httpx.Client | None] = contextvars.ContextVar( + "flowfile_log_client", default=None +) + +# Display outputs collector (reset at start of each execution) +_displays: contextvars.ContextVar[list[dict[str, str]]] = contextvars.ContextVar( + "flowfile_displays", default=[] +) + + +def _set_context( + node_id: int, + input_paths: dict[str, list[str]], + output_dir: str, + artifact_store: ArtifactStore, + flow_id: int = 0, + log_callback_url: str = "", +) -> None: + _context.set({ + "node_id": node_id, + "input_paths": input_paths, + "output_dir": output_dir, + "artifact_store": artifact_store, + "flow_id": flow_id, + "log_callback_url": log_callback_url, + }) + # Create a reusable HTTP client for log callbacks + if log_callback_url: + _log_client.set(httpx.Client(timeout=httpx.Timeout(5.0))) + else: + _log_client.set(None) + + +def _clear_context() -> None: + client = _log_client.get(None) + if client is not None: + try: + client.close() + except Exception: + pass + _log_client.set(None) + _context.set({}) + _displays.set([]) + + +def _get_context_value(key: str) -> Any: + ctx = _context.get({}) + if key not in ctx: + raise RuntimeError(f"flowfile context not initialized (missing '{key}'). This API is only available during /execute.") + return ctx[key] + + +def read_input(name: str = "main") -> pl.LazyFrame: + """Read all input files for *name* and return them as a single LazyFrame. 
+ + When multiple paths are registered under the same name (e.g. a union + of several upstream nodes), all files are scanned and concatenated + automatically by Polars. + """ + input_paths: dict[str, list[str]] = _get_context_value("input_paths") + if name not in input_paths: + available = list(input_paths.keys()) + raise KeyError(f"Input '{name}' not found. Available inputs: {available}") + paths = input_paths[name] + if len(paths) == 1: + return pl.scan_parquet(paths[0]) + return pl.scan_parquet(paths) + + +def read_first(name: str = "main") -> pl.LazyFrame: + """Read only the first input file for *name*. + + This is a convenience shortcut equivalent to scanning + ``input_paths[name][0]``. + """ + input_paths: dict[str, list[str]] = _get_context_value("input_paths") + if name not in input_paths: + available = list(input_paths.keys()) + raise KeyError(f"Input '{name}' not found. Available inputs: {available}") + return pl.scan_parquet(input_paths[name][0]) + + +def read_inputs() -> dict[str, pl.LazyFrame]: + """Read all named inputs, returning a dict of LazyFrames. + + Each entry concatenates all paths registered under that name. 
+ """ + input_paths: dict[str, list[str]] = _get_context_value("input_paths") + result: dict[str, pl.LazyFrame] = {} + for name, paths in input_paths.items(): + if len(paths) == 1: + result[name] = pl.scan_parquet(paths[0]) + else: + result[name] = pl.scan_parquet(paths) + return result + + +def publish_output(df: pl.LazyFrame | pl.DataFrame, name: str = "main") -> None: + output_dir = _get_context_value("output_dir") + os.makedirs(output_dir, exist_ok=True) + output_path = Path(output_dir) / f"{name}.parquet" + if isinstance(df, pl.LazyFrame): + df = df.collect() + df.write_parquet(str(output_path)) + # Ensure the file is fully flushed to disk before the host reads it + # This prevents "File must end with PAR1" errors from race conditions + with open(output_path, "rb") as f: + os.fsync(f.fileno()) + + +def publish_artifact(name: str, obj: Any) -> None: + store: ArtifactStore = _get_context_value("artifact_store") + node_id: int = _get_context_value("node_id") + flow_id: int = _get_context_value("flow_id") + store.publish(name, obj, node_id, flow_id=flow_id) + + +def read_artifact(name: str) -> Any: + store: ArtifactStore = _get_context_value("artifact_store") + flow_id: int = _get_context_value("flow_id") + return store.get(name, flow_id=flow_id) + + +def delete_artifact(name: str) -> None: + store: ArtifactStore = _get_context_value("artifact_store") + flow_id: int = _get_context_value("flow_id") + store.delete(name, flow_id=flow_id) + + +def list_artifacts() -> dict: + store: ArtifactStore = _get_context_value("artifact_store") + flow_id: int = _get_context_value("flow_id") + return store.list_all(flow_id=flow_id) + + +# ===== Logging APIs ===== + +def log(message: str, level: Literal["INFO", "WARNING", "ERROR"] = "INFO") -> None: + """Send a log message to the FlowFile log viewer. + + The message appears in the frontend log stream in real time. + + Args: + message: The log message text. 
+ level: Log severity — ``"INFO"`` (default), ``"WARNING"``, or ``"ERROR"``. + """ + flow_id: int = _get_context_value("flow_id") + node_id: int = _get_context_value("node_id") + callback_url: str = _get_context_value("log_callback_url") + if not callback_url: + # No callback configured — fall back to printing so the message + # still shows up in captured stdout. + print(f"[{level}] {message}") # noqa: T201 + return + + client = _log_client.get(None) + if client is None: + print(f"[{level}] {message}") # noqa: T201 + return + + payload = { + "flowfile_flow_id": flow_id, + "node_id": node_id, + "log_message": message, + "log_type": level, + } + try: + client.post(callback_url, json=payload) + except Exception: + # Best-effort — don't let logging failures break user code. + pass + + +def log_info(message: str) -> None: + """Convenience wrapper: ``flowfile.log(message, level="INFO")``.""" + log(message, level="INFO") + + +def log_warning(message: str) -> None: + """Convenience wrapper: ``flowfile.log(message, level="WARNING")``.""" + log(message, level="WARNING") + + +def log_error(message: str) -> None: + """Convenience wrapper: ``flowfile.log(message, level="ERROR")``.""" + log(message, level="ERROR") + + +# ===== Display APIs ===== + +def _is_matplotlib_figure(obj: Any) -> bool: + """Check if obj is a matplotlib Figure (without requiring matplotlib).""" + try: + import matplotlib.figure + return isinstance(obj, matplotlib.figure.Figure) + except ImportError: + return False + + +def _is_plotly_figure(obj: Any) -> bool: + """Check if obj is a plotly Figure (without requiring plotly).""" + try: + import plotly.graph_objects as go + return isinstance(obj, go.Figure) + except ImportError: + return False + + +def _is_pil_image(obj: Any) -> bool: + """Check if obj is a PIL Image (without requiring PIL).""" + try: + from PIL import Image + return isinstance(obj, Image.Image) + except ImportError: + return False + + +# Regex to detect HTML tags: , , ,
, etc. +_HTML_TAG_RE = re.compile(r"<[a-zA-Z/][^>]*>") + + +def _is_html_string(obj: Any) -> bool: + """Check if obj is a string that looks like HTML. + + Uses a regex to detect actual HTML tags like ,
,
, etc. + This avoids false positives from strings like "x < 10 and y > 5". + """ + if not isinstance(obj, str): + return False + return bool(_HTML_TAG_RE.search(obj)) + + +def _reset_displays() -> None: + """Clear the display outputs list. Called at start of each execution.""" + _displays.set([]) + + +def _get_displays() -> list[dict[str, str]]: + """Return the current list of display outputs.""" + return _displays.get([]) + + +def display(obj: Any, title: str = "") -> None: + """Display a rich object in the output panel. + + Supported object types: + - matplotlib.figure.Figure: Rendered as PNG image + - plotly.graph_objects.Figure: Rendered as interactive HTML + - PIL.Image.Image: Rendered as PNG image + - str containing HTML tags: Rendered as HTML + - Anything else: Converted to string and displayed as plain text + + Args: + obj: The object to display. + title: Optional title for the display output. + """ + displays = _displays.get([]) + + if _is_matplotlib_figure(obj): + # Render matplotlib figure to PNG + buf = io.BytesIO() + obj.savefig(buf, format="png", dpi=150, bbox_inches="tight") + buf.seek(0) + data = base64.b64encode(buf.read()).decode("ascii") + displays.append({ + "mime_type": "image/png", + "data": data, + "title": title, + }) + elif _is_plotly_figure(obj): + # Render plotly figure to HTML + html = obj.to_html(include_plotlyjs="cdn", full_html=False) + displays.append({ + "mime_type": "text/html", + "data": html, + "title": title, + }) + elif _is_pil_image(obj): + # Render PIL image to PNG + buf = io.BytesIO() + obj.save(buf, format="PNG") + buf.seek(0) + data = base64.b64encode(buf.read()).decode("ascii") + displays.append({ + "mime_type": "image/png", + "data": data, + "title": title, + }) + elif _is_html_string(obj): + # Store HTML string directly + displays.append({ + "mime_type": "text/html", + "data": obj, + "title": title, + }) + else: + # Fall back to plain text + displays.append({ + "mime_type": "text/plain", + "data": str(obj), + "title": 
title, + }) + + _displays.set(displays) diff --git a/kernel_runtime/kernel_runtime/main.py b/kernel_runtime/kernel_runtime/main.py new file mode 100644 index 000000000..dd49e1e73 --- /dev/null +++ b/kernel_runtime/kernel_runtime/main.py @@ -0,0 +1,469 @@ +import ast +import contextlib +import io +import logging +import os +import time +from collections.abc import AsyncIterator +from pathlib import Path + +from fastapi import FastAPI, Query +from pydantic import BaseModel, Field + +from kernel_runtime import __version__, flowfile_client +from kernel_runtime.artifact_persistence import ArtifactPersistence, RecoveryMode +from kernel_runtime.artifact_store import ArtifactStore + +logger = logging.getLogger(__name__) + +artifact_store = ArtifactStore() + +# --------------------------------------------------------------------------- +# Persistence setup (driven by environment variables) +# --------------------------------------------------------------------------- +_persistence: ArtifactPersistence | None = None +_recovery_mode = RecoveryMode.LAZY +_recovery_status: dict = {"status": "pending", "recovered": [], "errors": []} +_kernel_id: str = "default" +_persistence_path: str = "/shared/artifacts" + + +def _setup_persistence() -> None: + """Initialize persistence from environment variables. + + Environment variables are read at call time (not import time) so tests + can set them before creating the TestClient. 
+ """ + global _persistence, _recovery_mode, _recovery_status, _kernel_id, _persistence_path + + persistence_enabled = os.environ.get("PERSISTENCE_ENABLED", "true").lower() in ("1", "true", "yes") + _persistence_path = os.environ.get("PERSISTENCE_PATH", "/shared/artifacts") + _kernel_id = os.environ.get("KERNEL_ID", "default") + recovery_mode_env = os.environ.get("RECOVERY_MODE", "lazy").lower() + # Cleanup artifacts older than this many hours on startup (0 = disabled) + cleanup_age_hours = float(os.environ.get("PERSISTENCE_CLEANUP_HOURS", "24")) + + if not persistence_enabled: + _recovery_status = {"status": "disabled", "recovered": [], "errors": []} + logger.info("Artifact persistence is disabled") + return + + base_path = Path(_persistence_path) / _kernel_id + _persistence = ArtifactPersistence(base_path) + artifact_store.enable_persistence(_persistence) + + # Cleanup stale artifacts before recovery + if cleanup_age_hours > 0: + try: + removed = _persistence.cleanup(max_age_hours=cleanup_age_hours) + if removed > 0: + logger.info( + "Startup cleanup: removed %d artifacts older than %.1f hours", + removed, cleanup_age_hours + ) + except Exception as exc: + logger.warning("Startup cleanup failed (continuing anyway): %s", exc) + + try: + _recovery_mode = RecoveryMode(recovery_mode_env) + except ValueError: + _recovery_mode = RecoveryMode.LAZY + + if _recovery_mode == RecoveryMode.EAGER: + _recovery_status = {"status": "recovering", "recovered": [], "errors": []} + try: + recovered = artifact_store.recover_all() + _recovery_status = { + "status": "completed", + "mode": "eager", + "recovered": recovered, + "errors": [], + } + logger.info("Eager recovery complete: %d artifacts restored", len(recovered)) + except Exception as exc: + _recovery_status = { + "status": "error", + "mode": "eager", + "recovered": [], + "errors": [str(exc)], + } + logger.error("Eager recovery failed: %s", exc) + + elif _recovery_mode == RecoveryMode.LAZY: + count = 
artifact_store.build_lazy_index() + _recovery_status = { + "status": "completed", + "mode": "lazy", + "indexed": count, + "recovered": [], + "errors": [], + } + logger.info("Lazy recovery index built: %d artifacts available on disk", count) + + elif _recovery_mode == RecoveryMode.CLEAR: + logger.warning( + "RECOVERY_MODE=clear: Deleting ALL persisted artifacts. " + "This is destructive and cannot be undone." + ) + _persistence.clear() + _recovery_status = { + "status": "completed", + "mode": "clear", + "recovered": [], + "errors": [], + } + logger.info("Recovery mode=clear: cleared all persisted artifacts") + + +@contextlib.asynccontextmanager +async def _lifespan(app: FastAPI) -> AsyncIterator[None]: + _setup_persistence() + yield + + +app = FastAPI(title="FlowFile Kernel Runtime", version=__version__, lifespan=_lifespan) + + +# --------------------------------------------------------------------------- +# Request / Response models +# --------------------------------------------------------------------------- + +# Matplotlib setup code to auto-capture plt.show() calls +_MATPLOTLIB_SETUP = """\ +try: + import matplotlib as _mpl + _mpl.use('Agg') + import matplotlib.pyplot as _plt + _original_show = _plt.show + def _flowfile_show(*args, **kwargs): + import matplotlib.pyplot as __plt + for _fig_num in __plt.get_fignums(): + flowfile.display(__plt.figure(_fig_num)) + __plt.close('all') + _plt.show = _flowfile_show +except ImportError: + pass +""" + + +def _maybe_wrap_last_expression(code: str) -> str: + """If the last statement is a bare expression, wrap it in flowfile.display(). + + This provides Jupyter-like behavior where the result of the last expression + is automatically displayed. 
+ """ + try: + tree = ast.parse(code) + except SyntaxError: + return code + if not tree.body: + return code + last = tree.body[-1] + if not isinstance(last, ast.Expr): + return code + + # Don't wrap if the expression is None, a string literal, or already a call to display/print + if isinstance(last.value, ast.Constant) and last.value.value is None: + return code + if isinstance(last.value, ast.Call): + # Check if it's already a print or display call + func = last.value.func + if isinstance(func, ast.Name) and func.id in ("print", "display"): + return code + if isinstance(func, ast.Attribute) and func.attr in ("print", "display"): + return code + + # Use ast.get_source_segment for robust source extraction (Python 3.8+) + last_expr_text = ast.get_source_segment(code, last) + if last_expr_text is None: + # Fallback if get_source_segment fails + return code + + # Build the new code with the last expression wrapped + lines = code.split('\n') + prefix = '\n'.join(lines[:last.lineno - 1]) + if prefix: + prefix += '\n' + return prefix + f'flowfile.display({last_expr_text})\n' + + +class ExecuteRequest(BaseModel): + node_id: int + code: str + input_paths: dict[str, list[str]] = {} + output_dir: str = "" + flow_id: int = 0 + log_callback_url: str = "" + interactive: bool = False # When True, auto-display last expression + + +class ClearNodeArtifactsRequest(BaseModel): + node_ids: list[int] + flow_id: int | None = None + + +class DisplayOutput(BaseModel): + """A single display output from code execution.""" + mime_type: str # "image/png", "text/html", "text/plain" + data: str # base64 for images, raw HTML for text/html, plain text otherwise + title: str = "" + + +class ExecuteResponse(BaseModel): + success: bool + output_paths: list[str] = [] + artifacts_published: list[str] = [] + artifacts_deleted: list[str] = [] + display_outputs: list[DisplayOutput] = [] + stdout: str = "" + stderr: str = "" + error: str | None = None + execution_time_ms: float = 0.0 + + +class 
ArtifactIdentifier(BaseModel): + """Identifies a specific artifact by flow_id and name.""" + flow_id: int + name: str + + +class CleanupRequest(BaseModel): + max_age_hours: float | None = None + artifact_names: list[ArtifactIdentifier] | None = Field( + default=None, + description="List of specific artifacts to delete", + ) + + +# --------------------------------------------------------------------------- +# Existing endpoints +# --------------------------------------------------------------------------- + +@app.post("/execute", response_model=ExecuteResponse) +async def execute(request: ExecuteRequest): + start = time.perf_counter() + stdout_buf = io.StringIO() + stderr_buf = io.StringIO() + + output_dir = request.output_dir + if output_dir: + os.makedirs(output_dir, exist_ok=True) + + # Clear any artifacts this node previously published so re-execution + # doesn't fail with "already exists". + artifact_store.clear_by_node_ids({request.node_id}, flow_id=request.flow_id) + + artifacts_before = set(artifact_store.list_all(flow_id=request.flow_id).keys()) + + try: + flowfile_client._set_context( + node_id=request.node_id, + input_paths=request.input_paths, + output_dir=output_dir, + artifact_store=artifact_store, + flow_id=request.flow_id, + log_callback_url=request.log_callback_url, + ) + + # Reset display outputs for this execution + flowfile_client._reset_displays() + + # Prepare execution namespace with flowfile module + exec_globals = {"flowfile": flowfile_client} + + with contextlib.redirect_stdout(stdout_buf), contextlib.redirect_stderr(stderr_buf): + # Execute matplotlib setup to patch plt.show() + exec(_MATPLOTLIB_SETUP, exec_globals) # noqa: S102 + + # Prepare user code - optionally wrap last expression for interactive mode + user_code = request.code + if request.interactive: + user_code = _maybe_wrap_last_expression(user_code) + + # Execute user code + exec(user_code, exec_globals) # noqa: S102 + + # Collect display outputs + display_outputs = [ + 
DisplayOutput(**d) for d in flowfile_client._get_displays() + ] + + # Collect output parquet files + output_paths: list[str] = [] + if output_dir and Path(output_dir).exists(): + output_paths = [ + str(p) for p in sorted(Path(output_dir).glob("*.parquet")) + ] + + artifacts_after = set(artifact_store.list_all(flow_id=request.flow_id).keys()) + new_artifacts = sorted(artifacts_after - artifacts_before) + deleted_artifacts = sorted(artifacts_before - artifacts_after) + + elapsed = (time.perf_counter() - start) * 1000 + return ExecuteResponse( + success=True, + output_paths=output_paths, + artifacts_published=new_artifacts, + artifacts_deleted=deleted_artifacts, + display_outputs=display_outputs, + stdout=stdout_buf.getvalue(), + stderr=stderr_buf.getvalue(), + execution_time_ms=elapsed, + ) + except Exception as exc: + # Still collect any display outputs that were generated before the error + display_outputs = [ + DisplayOutput(**d) for d in flowfile_client._get_displays() + ] + elapsed = (time.perf_counter() - start) * 1000 + return ExecuteResponse( + success=False, + display_outputs=display_outputs, + stdout=stdout_buf.getvalue(), + stderr=stderr_buf.getvalue(), + error=f"{type(exc).__name__}: {exc}", + execution_time_ms=elapsed, + ) + finally: + flowfile_client._clear_context() + + +@app.post("/clear") +async def clear_artifacts(flow_id: int | None = Query(default=None)): + """Clear all artifacts, or only those belonging to a specific flow.""" + artifact_store.clear(flow_id=flow_id) + return {"status": "cleared"} + + +@app.post("/clear_node_artifacts") +async def clear_node_artifacts(request: ClearNodeArtifactsRequest): + """Clear only artifacts published by the specified node IDs.""" + removed = artifact_store.clear_by_node_ids( + set(request.node_ids), flow_id=request.flow_id, + ) + return {"status": "cleared", "removed": removed} + + +@app.get("/artifacts") +async def list_artifacts(flow_id: int | None = Query(default=None)): + """List all artifacts, optionally 
filtered by flow_id.""" + return artifact_store.list_all(flow_id=flow_id) + + +@app.get("/artifacts/node/{node_id}") +async def list_node_artifacts( + node_id: int, flow_id: int | None = Query(default=None), +): + """List artifacts published by a specific node.""" + return artifact_store.list_by_node_id(node_id, flow_id=flow_id) + + +# --------------------------------------------------------------------------- +# Persistence & Recovery endpoints +# --------------------------------------------------------------------------- + +@app.post("/recover") +async def recover_artifacts(): + """Trigger manual artifact recovery from disk.""" + global _recovery_status + + if _persistence is None: + return {"status": "disabled", "message": "Persistence is not enabled"} + + _recovery_status = {"status": "recovering", "recovered": [], "errors": []} + try: + recovered = artifact_store.recover_all() + _recovery_status = { + "status": "completed", + "mode": "manual", + "recovered": recovered, + "errors": [], + } + return _recovery_status + except Exception as exc: + _recovery_status = { + "status": "error", + "mode": "manual", + "recovered": [], + "errors": [str(exc)], + } + return _recovery_status + + +@app.get("/recovery-status") +async def recovery_status(): + """Return the current recovery status.""" + return _recovery_status + + +@app.post("/cleanup") +async def cleanup_artifacts(request: CleanupRequest): + """Clean up old or specific persisted artifacts.""" + if _persistence is None: + return {"status": "disabled", "removed_count": 0} + + names = None + if request.artifact_names: + names = [(item.flow_id, item.name) for item in request.artifact_names] + + removed_count = _persistence.cleanup( + max_age_hours=request.max_age_hours, + names=names, + ) + # Rebuild lazy index after cleanup + artifact_store.build_lazy_index() + return {"status": "cleaned", "removed_count": removed_count} + + +@app.get("/persistence") +async def persistence_info(): + """Return persistence 
configuration and stats.""" + if _persistence is None: + return { + "enabled": False, + "recovery_mode": _recovery_mode.value, + "persisted_count": 0, + "disk_usage_bytes": 0, + } + + persisted = _persistence.list_persisted() + in_memory = artifact_store.list_all() + + # Build per-artifact status + artifact_status = {} + for (fid, name), meta in persisted.items(): + artifact_status[name] = { + "flow_id": fid, + "persisted": True, + "in_memory": name in in_memory and in_memory[name].get("in_memory", True) is not False, + } + for name, meta in in_memory.items(): + if name not in artifact_status: + artifact_status[name] = { + "flow_id": meta.get("flow_id", 0), + "persisted": meta.get("persisted", False), + "in_memory": True, + } + + return { + "enabled": True, + "recovery_mode": _recovery_mode.value, + "kernel_id": _kernel_id, + "persistence_path": str(Path(_persistence_path) / _kernel_id), + "persisted_count": len(persisted), + "in_memory_count": len([a for a in in_memory.values() if a.get("in_memory", True) is not False]), + "disk_usage_bytes": _persistence.disk_usage_bytes(), + "artifacts": artifact_status, + } + + +@app.get("/health") +async def health(): + persistence_status = "enabled" if _persistence is not None else "disabled" + return { + "status": "healthy", + "version": __version__, + "artifact_count": len(artifact_store.list_all()), + "persistence": persistence_status, + "recovery_mode": _recovery_mode.value, + } diff --git a/kernel_runtime/pyproject.toml b/kernel_runtime/pyproject.toml new file mode 100644 index 000000000..fa73f3e93 --- /dev/null +++ b/kernel_runtime/pyproject.toml @@ -0,0 +1,23 @@ +[project] +name = "kernel_runtime" +version = "0.1.0" +description = "FlowFile kernel runtime - executes Python code in isolated Docker containers" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.115.0", + "uvicorn>=0.32.0", + "polars>=1.0.0", + "pyarrow>=14.0.0", + "httpx>=0.24.0", + "cloudpickle>=3.0.0", +] + +[project.optional-dependencies] 
+test = [ + "pytest>=7.0.0", + "httpx>=0.24.0", +] + +[build-system] +requires = ["setuptools>=68.0"] +build-backend = "setuptools.build_meta" diff --git a/kernel_runtime/tests/__init__.py b/kernel_runtime/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kernel_runtime/tests/conftest.py b/kernel_runtime/tests/conftest.py new file mode 100644 index 000000000..a8c8bf09e --- /dev/null +++ b/kernel_runtime/tests/conftest.py @@ -0,0 +1,94 @@ +"""Shared fixtures for kernel_runtime tests.""" + +import os +import tempfile +from collections.abc import Generator +from pathlib import Path + +import pytest +from fastapi.testclient import TestClient + +from kernel_runtime.artifact_persistence import ArtifactPersistence +from kernel_runtime.artifact_store import ArtifactStore +from kernel_runtime.main import app, artifact_store + + +@pytest.fixture() +def store() -> ArtifactStore: + """Fresh ArtifactStore for each test.""" + return ArtifactStore() + + +@pytest.fixture(autouse=True) +def _clear_global_state(): + """Reset the global artifact_store and persistence state between tests.""" + from kernel_runtime import main + from kernel_runtime.artifact_persistence import RecoveryMode + + artifact_store.clear() + # Reset persistence state + main._persistence = None + main._recovery_mode = RecoveryMode.LAZY + main._recovery_status = {"status": "pending", "recovered": [], "errors": []} + main._kernel_id = "default" + main._persistence_path = "/shared/artifacts" + # Detach persistence from artifact store + artifact_store._persistence = None + artifact_store._lazy_index.clear() + artifact_store._loading_locks.clear() + artifact_store._persist_pending.clear() + + yield + + artifact_store.clear() + main._persistence = None + main._recovery_mode = RecoveryMode.LAZY + main._recovery_status = {"status": "pending", "recovered": [], "errors": []} + main._kernel_id = "default" + main._persistence_path = "/shared/artifacts" + artifact_store._persistence = None + 
artifact_store._lazy_index.clear() + artifact_store._loading_locks.clear() + artifact_store._persist_pending.clear() + + +@pytest.fixture() +def client(tmp_path: Path) -> Generator[TestClient, None, None]: + """FastAPI TestClient for the kernel runtime app. + + Sets PERSISTENCE_PATH to a temp directory so persistence tests work + in CI environments without /shared. + """ + # Set env vars before TestClient triggers lifespan + old_path = os.environ.get("PERSISTENCE_PATH") + os.environ["PERSISTENCE_PATH"] = str(tmp_path / "artifacts") + + with TestClient(app) as c: + yield c + + # Restore original env var + if old_path is None: + os.environ.pop("PERSISTENCE_PATH", None) + else: + os.environ["PERSISTENCE_PATH"] = old_path + + +@pytest.fixture() +def tmp_dir() -> Generator[Path, None, None]: + """Temporary directory cleaned up after each test.""" + with tempfile.TemporaryDirectory(prefix="kernel_test_") as d: + yield Path(d) + + +@pytest.fixture() +def persistence(tmp_dir: Path) -> ArtifactPersistence: + """Fresh ArtifactPersistence backed by a temporary directory.""" + return ArtifactPersistence(tmp_dir / "artifacts") + + +@pytest.fixture() +def store_with_persistence(persistence: ArtifactPersistence) -> ArtifactStore: + """ArtifactStore with persistence enabled.""" + s = ArtifactStore() + s.enable_persistence(persistence) + return s diff --git a/kernel_runtime/tests/test_artifact_persistence.py b/kernel_runtime/tests/test_artifact_persistence.py new file mode 100644 index 000000000..ab1bad8c7 --- /dev/null +++ b/kernel_runtime/tests/test_artifact_persistence.py @@ -0,0 +1,248 @@ +"""Tests for kernel_runtime.artifact_persistence.""" + +import json +import time + +import pytest + +from kernel_runtime.artifact_persistence import ArtifactPersistence, RecoveryMode, _safe_dirname + + +class TestSafeDirname: + def test_simple_name(self): + assert _safe_dirname("model") == "model" + + def test_with_spaces(self): + assert _safe_dirname("my model") == "my_model" + + def 
test_with_special_chars(self): + assert _safe_dirname("model/v1:latest") == "model_v1_latest" + + def test_with_dots_and_dashes(self): + assert _safe_dirname("model-v1.0") == "model-v1.0" + + +class TestSaveAndLoad: + def test_save_and_load_dict(self, persistence: ArtifactPersistence): + obj = {"weights": [1.0, 2.0, 3.0], "bias": 0.5} + metadata = {"name": "model", "node_id": 1, "type_name": "dict", "module": "builtins"} + + persistence.save("model", obj, metadata, flow_id=0) + loaded = persistence.load("model", flow_id=0) + + assert loaded == obj + + def test_save_and_load_list(self, persistence: ArtifactPersistence): + obj = [1, 2, 3, "hello"] + metadata = {"name": "data", "node_id": 2, "type_name": "list", "module": "builtins"} + + persistence.save("data", obj, metadata, flow_id=0) + loaded = persistence.load("data", flow_id=0) + + assert loaded == obj + + def test_save_and_load_none(self, persistence: ArtifactPersistence): + metadata = {"name": "nothing", "node_id": 1, "type_name": "NoneType", "module": "builtins"} + + persistence.save("nothing", None, metadata, flow_id=0) + loaded = persistence.load("nothing", flow_id=0) + + assert loaded is None + + def test_save_and_load_lambda(self, persistence: ArtifactPersistence): + """cloudpickle handles lambdas that standard pickle cannot.""" + fn = lambda x: x * 2 # noqa: E731 + metadata = {"name": "fn", "node_id": 1, "type_name": "function", "module": "__main__"} + + persistence.save("fn", fn, metadata, flow_id=0) + loaded = persistence.load("fn", flow_id=0) + + assert loaded(5) == 10 + + def test_save_and_load_custom_class(self, persistence: ArtifactPersistence): + class MyModel: + def __init__(self, w): + self.w = w + + def predict(self, x): + return x * self.w + + obj = MyModel(3.0) + metadata = {"name": "custom", "node_id": 1, "type_name": "MyModel", "module": "__main__"} + + persistence.save("custom", obj, metadata, flow_id=0) + loaded = persistence.load("custom", flow_id=0) + + assert loaded.predict(4) == 12.0 
+ + def test_load_nonexistent_raises(self, persistence: ArtifactPersistence): + with pytest.raises(FileNotFoundError, match="No persisted artifact"): + persistence.load("nonexistent", flow_id=0) + + def test_metadata_written(self, persistence: ArtifactPersistence): + metadata = {"name": "item", "node_id": 5, "type_name": "int", "module": "builtins"} + persistence.save("item", 42, metadata, flow_id=0) + + loaded_meta = persistence.load_metadata("item", flow_id=0) + assert loaded_meta is not None + assert loaded_meta["name"] == "item" + assert loaded_meta["node_id"] == 5 + assert "checksum" in loaded_meta + assert "persisted_at" in loaded_meta + assert "data_size_bytes" in loaded_meta + + def test_checksum_validation(self, persistence: ArtifactPersistence): + metadata = {"name": "item", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("item", 42, metadata, flow_id=0) + + # Corrupt the data file + data_path = persistence._data_path(0, "item") + data_path.write_bytes(b"corrupted data") + + with pytest.raises(ValueError, match="Checksum mismatch"): + persistence.load("item", flow_id=0) + + +class TestFlowIsolation: + def test_same_name_different_flows(self, persistence: ArtifactPersistence): + meta1 = {"name": "model", "node_id": 1, "type_name": "str", "module": "builtins"} + meta2 = {"name": "model", "node_id": 2, "type_name": "str", "module": "builtins"} + + persistence.save("model", "flow1_model", meta1, flow_id=1) + persistence.save("model", "flow2_model", meta2, flow_id=2) + + assert persistence.load("model", flow_id=1) == "flow1_model" + assert persistence.load("model", flow_id=2) == "flow2_model" + + def test_delete_scoped_to_flow(self, persistence: ArtifactPersistence): + meta = {"name": "model", "node_id": 1, "type_name": "str", "module": "builtins"} + persistence.save("model", "v1", meta, flow_id=1) + persistence.save("model", "v2", meta, flow_id=2) + + persistence.delete("model", flow_id=1) + + with pytest.raises(FileNotFoundError): 
+ persistence.load("model", flow_id=1) + assert persistence.load("model", flow_id=2) == "v2" + + +class TestDelete: + def test_delete_removes_files(self, persistence: ArtifactPersistence): + meta = {"name": "temp", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("temp", 42, meta, flow_id=0) + persistence.delete("temp", flow_id=0) + + with pytest.raises(FileNotFoundError): + persistence.load("temp", flow_id=0) + + def test_delete_nonexistent_is_safe(self, persistence: ArtifactPersistence): + # Should not raise + persistence.delete("nonexistent", flow_id=0) + + +class TestClear: + def test_clear_all(self, persistence: ArtifactPersistence): + meta = {"name": "a", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("a", 1, meta, flow_id=1) + persistence.save("b", 2, {**meta, "name": "b"}, flow_id=2) + + persistence.clear() + + assert persistence.list_persisted() == {} + + def test_clear_by_flow_id(self, persistence: ArtifactPersistence): + meta = {"name": "a", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("a", 1, meta, flow_id=1) + persistence.save("b", 2, {**meta, "name": "b"}, flow_id=2) + + persistence.clear(flow_id=1) + + persisted = persistence.list_persisted() + assert len(persisted) == 1 + assert (2, "b") in persisted + + +class TestListPersisted: + def test_empty(self, persistence: ArtifactPersistence): + assert persistence.list_persisted() == {} + + def test_lists_all(self, persistence: ArtifactPersistence): + meta = {"name": "a", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("a", 1, meta, flow_id=1) + persistence.save("b", 2, {**meta, "name": "b"}, flow_id=2) + + persisted = persistence.list_persisted() + assert len(persisted) == 2 + assert (1, "a") in persisted + assert (2, "b") in persisted + + def test_filter_by_flow_id(self, persistence: ArtifactPersistence): + meta = {"name": "a", "node_id": 1, "type_name": "int", "module": "builtins"} + 
persistence.save("a", 1, meta, flow_id=1) + persistence.save("b", 2, {**meta, "name": "b"}, flow_id=2) + + persisted = persistence.list_persisted(flow_id=1) + assert len(persisted) == 1 + assert (1, "a") in persisted + + +class TestDiskUsage: + def test_disk_usage_increases(self, persistence: ArtifactPersistence): + assert persistence.disk_usage_bytes() == 0 + + meta = {"name": "big", "node_id": 1, "type_name": "bytes", "module": "builtins"} + persistence.save("big", b"x" * 10000, meta, flow_id=0) + + assert persistence.disk_usage_bytes() > 10000 + + +class TestCleanup: + def test_cleanup_by_age(self, persistence: ArtifactPersistence): + meta = {"name": "old", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("old", 1, meta, flow_id=0) + + # Manually backdate the persisted_at in metadata + meta_path = persistence._meta_path(0, "old") + meta_data = json.loads(meta_path.read_text()) + meta_data["persisted_at"] = "2020-01-01T00:00:00+00:00" + meta_path.write_text(json.dumps(meta_data)) + + removed = persistence.cleanup(max_age_hours=1) + assert removed == 1 + assert persistence.list_persisted() == {} + + def test_cleanup_by_name(self, persistence: ArtifactPersistence): + meta = {"name": "a", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("a", 1, meta, flow_id=0) + persistence.save("b", 2, {**meta, "name": "b"}, flow_id=0) + + removed = persistence.cleanup(names=[(0, "a")]) + assert removed == 1 + + persisted = persistence.list_persisted() + assert len(persisted) == 1 + assert (0, "b") in persisted + + def test_cleanup_keeps_recent(self, persistence: ArtifactPersistence): + meta = {"name": "recent", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("recent", 1, meta, flow_id=0) + + removed = persistence.cleanup(max_age_hours=24) + assert removed == 0 + assert len(persistence.list_persisted()) == 1 + + +class TestRecoveryMode: + def test_enum_values(self): + assert RecoveryMode.LAZY == "lazy" 
+ assert RecoveryMode.EAGER == "eager" + assert RecoveryMode.CLEAR == "clear" + + def test_from_string(self): + assert RecoveryMode("lazy") == RecoveryMode.LAZY + assert RecoveryMode("eager") == RecoveryMode.EAGER + assert RecoveryMode("clear") == RecoveryMode.CLEAR + + def test_none_backwards_compatibility(self): + """'none' is accepted for backwards compatibility but maps to CLEAR.""" + assert RecoveryMode("none") == RecoveryMode.CLEAR diff --git a/kernel_runtime/tests/test_artifact_store.py b/kernel_runtime/tests/test_artifact_store.py new file mode 100644 index 000000000..c138aa4b3 --- /dev/null +++ b/kernel_runtime/tests/test_artifact_store.py @@ -0,0 +1,256 @@ +"""Tests for kernel_runtime.artifact_store.""" + +import threading + +import pytest + +from kernel_runtime.artifact_store import ArtifactStore + + +class TestPublishAndGet: + def test_publish_and_retrieve(self, store: ArtifactStore): + store.publish("my_obj", {"a": 1}, node_id=1) + assert store.get("my_obj") == {"a": 1} + + def test_publish_duplicate_raises(self, store: ArtifactStore): + store.publish("key", "first", node_id=1) + with pytest.raises(ValueError, match="already exists"): + store.publish("key", "second", node_id=2) + + def test_publish_after_delete_succeeds(self, store: ArtifactStore): + store.publish("key", "first", node_id=1) + store.delete("key") + store.publish("key", "second", node_id=2) + assert store.get("key") == "second" + + def test_get_missing_raises(self, store: ArtifactStore): + with pytest.raises(KeyError, match="not found"): + store.get("nonexistent") + + def test_publish_various_types(self, store: ArtifactStore): + store.publish("int_val", 42, node_id=1) + store.publish("list_val", [1, 2, 3], node_id=1) + store.publish("none_val", None, node_id=1) + assert store.get("int_val") == 42 + assert store.get("list_val") == [1, 2, 3] + assert store.get("none_val") is None + + +class TestListAll: + def test_empty_store(self, store: ArtifactStore): + assert store.list_all() == {} + + 
def test_list_excludes_object(self, store: ArtifactStore): + store.publish("item", {"secret": "data"}, node_id=5) + listing = store.list_all() + assert "item" in listing + assert "object" not in listing["item"] + + def test_list_metadata_fields(self, store: ArtifactStore): + store.publish("item", [1, 2], node_id=3) + meta = store.list_all()["item"] + assert meta["name"] == "item" + assert meta["type_name"] == "list" + assert meta["module"] == "builtins" + assert meta["node_id"] == 3 + assert "created_at" in meta + assert "size_bytes" in meta + + def test_list_multiple_items(self, store: ArtifactStore): + store.publish("a", 1, node_id=1) + store.publish("b", 2, node_id=2) + listing = store.list_all() + assert set(listing.keys()) == {"a", "b"} + + +class TestClear: + def test_clear_empties_store(self, store: ArtifactStore): + store.publish("x", 1, node_id=1) + store.publish("y", 2, node_id=1) + store.clear() + assert store.list_all() == {} + + def test_clear_then_get_raises(self, store: ArtifactStore): + store.publish("x", 1, node_id=1) + store.clear() + with pytest.raises(KeyError): + store.get("x") + + def test_clear_idempotent(self, store: ArtifactStore): + store.clear() + store.clear() + assert store.list_all() == {} + + +class TestDelete: + def test_delete_removes_artifact(self, store: ArtifactStore): + store.publish("model", {"w": [1, 2]}, node_id=1) + store.delete("model") + assert "model" not in store.list_all() + + def test_delete_missing_raises(self, store: ArtifactStore): + with pytest.raises(KeyError, match="not found"): + store.delete("nonexistent") + + def test_delete_then_get_raises(self, store: ArtifactStore): + store.publish("tmp", 42, node_id=1) + store.delete("tmp") + with pytest.raises(KeyError, match="not found"): + store.get("tmp") + + def test_delete_only_target(self, store: ArtifactStore): + store.publish("keep", 1, node_id=1) + store.publish("remove", 2, node_id=1) + store.delete("remove") + assert store.get("keep") == 1 + assert 
set(store.list_all().keys()) == {"keep"} + + +class TestClearByNodeIds: + def test_clear_by_node_ids_removes_only_target(self, store: ArtifactStore): + store.publish("a", 1, node_id=1) + store.publish("b", 2, node_id=2) + store.publish("c", 3, node_id=1) + removed = store.clear_by_node_ids({1}) + assert sorted(removed) == ["a", "c"] + assert "b" in store.list_all() + assert "a" not in store.list_all() + assert "c" not in store.list_all() + + def test_clear_by_node_ids_empty_set(self, store: ArtifactStore): + store.publish("x", 1, node_id=1) + removed = store.clear_by_node_ids(set()) + assert removed == [] + assert "x" in store.list_all() + + def test_clear_by_node_ids_nonexistent(self, store: ArtifactStore): + store.publish("x", 1, node_id=1) + removed = store.clear_by_node_ids({99}) + assert removed == [] + assert "x" in store.list_all() + + def test_clear_by_node_ids_multiple(self, store: ArtifactStore): + store.publish("a", 1, node_id=1) + store.publish("b", 2, node_id=2) + store.publish("c", 3, node_id=3) + removed = store.clear_by_node_ids({1, 3}) + assert sorted(removed) == ["a", "c"] + assert set(store.list_all().keys()) == {"b"} + + def test_clear_allows_republish(self, store: ArtifactStore): + """After clearing a node's artifacts, re-publishing with the same name works.""" + store.publish("model", {"v": 1}, node_id=5) + store.clear_by_node_ids({5}) + store.publish("model", {"v": 2}, node_id=5) + assert store.get("model") == {"v": 2} + + +class TestListByNodeId: + def test_list_by_node_id(self, store: ArtifactStore): + store.publish("a", 1, node_id=1) + store.publish("b", 2, node_id=2) + store.publish("c", 3, node_id=1) + listing = store.list_by_node_id(1) + assert set(listing.keys()) == {"a", "c"} + + def test_list_by_node_id_empty(self, store: ArtifactStore): + assert store.list_by_node_id(99) == {} + + def test_list_by_node_id_excludes_object(self, store: ArtifactStore): + store.publish("x", {"secret": "data"}, node_id=1) + listing = 
store.list_by_node_id(1) + assert "object" not in listing["x"] + + +class TestFlowIsolation: + """Artifacts with the same name in different flows are independent.""" + + def test_same_name_different_flows(self, store: ArtifactStore): + store.publish("model", "flow1_model", node_id=1, flow_id=1) + store.publish("model", "flow2_model", node_id=2, flow_id=2) + assert store.get("model", flow_id=1) == "flow1_model" + assert store.get("model", flow_id=2) == "flow2_model" + + def test_delete_scoped_to_flow(self, store: ArtifactStore): + store.publish("model", "v1", node_id=1, flow_id=1) + store.publish("model", "v2", node_id=2, flow_id=2) + store.delete("model", flow_id=1) + # flow 2's artifact is untouched + assert store.get("model", flow_id=2) == "v2" + with pytest.raises(KeyError): + store.get("model", flow_id=1) + + def test_list_all_filtered_by_flow(self, store: ArtifactStore): + store.publish("a", 1, node_id=1, flow_id=1) + store.publish("b", 2, node_id=2, flow_id=2) + store.publish("c", 3, node_id=1, flow_id=1) + assert set(store.list_all(flow_id=1).keys()) == {"a", "c"} + assert set(store.list_all(flow_id=2).keys()) == {"b"} + + def test_list_all_unfiltered_returns_everything(self, store: ArtifactStore): + store.publish("a", 1, node_id=1, flow_id=1) + store.publish("b", 2, node_id=2, flow_id=2) + assert set(store.list_all().keys()) == {"a", "b"} + + def test_clear_scoped_to_flow(self, store: ArtifactStore): + store.publish("a", 1, node_id=1, flow_id=1) + store.publish("b", 2, node_id=2, flow_id=2) + store.clear(flow_id=1) + with pytest.raises(KeyError): + store.get("a", flow_id=1) + assert store.get("b", flow_id=2) == 2 + + def test_clear_all_clears_every_flow(self, store: ArtifactStore): + store.publish("a", 1, node_id=1, flow_id=1) + store.publish("b", 2, node_id=2, flow_id=2) + store.clear() + assert store.list_all() == {} + + def test_clear_by_node_ids_scoped_to_flow(self, store: ArtifactStore): + """Same node_id in different flows — only the targeted flow is 
cleared.""" + store.publish("model", "f1", node_id=5, flow_id=1) + store.publish("model", "f2", node_id=5, flow_id=2) + removed = store.clear_by_node_ids({5}, flow_id=1) + assert removed == ["model"] + # flow 2's artifact survives + assert store.get("model", flow_id=2) == "f2" + with pytest.raises(KeyError): + store.get("model", flow_id=1) + + def test_list_by_node_id_scoped_to_flow(self, store: ArtifactStore): + store.publish("a", 1, node_id=5, flow_id=1) + store.publish("b", 2, node_id=5, flow_id=2) + assert set(store.list_by_node_id(5, flow_id=1).keys()) == {"a"} + assert set(store.list_by_node_id(5, flow_id=2).keys()) == {"b"} + # Unfiltered returns both + assert set(store.list_by_node_id(5).keys()) == {"a", "b"} + + def test_metadata_includes_flow_id(self, store: ArtifactStore): + store.publish("item", 42, node_id=1, flow_id=7) + meta = store.list_all(flow_id=7)["item"] + assert meta["flow_id"] == 7 + + +class TestThreadSafety: + def test_concurrent_publishes(self, store: ArtifactStore): + errors = [] + + def publish_range(start: int, count: int): + try: + for i in range(start, start + count): + store.publish(f"item_{i}", i, node_id=i) + except Exception as exc: + errors.append(exc) + + threads = [ + threading.Thread(target=publish_range, args=(i * 100, 100)) + for i in range(4) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors + listing = store.list_all() + assert len(listing) == 400 diff --git a/kernel_runtime/tests/test_artifact_store_persistence.py b/kernel_runtime/tests/test_artifact_store_persistence.py new file mode 100644 index 000000000..f1c163cdb --- /dev/null +++ b/kernel_runtime/tests/test_artifact_store_persistence.py @@ -0,0 +1,203 @@ +"""Tests for ArtifactStore persistence integration.""" + +import pytest + +from kernel_runtime.artifact_persistence import ArtifactPersistence +from kernel_runtime.artifact_store import ArtifactStore + + +class TestPersistenceOnPublish: + """Publishing an artifact should 
automatically persist to disk.""" + + def test_publish_persists_to_disk(self, store_with_persistence, persistence): + store_with_persistence.publish("model", {"w": [1, 2]}, node_id=1, flow_id=0) + + # Verify it's on disk + persisted = persistence.list_persisted() + assert (0, "model") in persisted + + # Verify the data is correct + loaded = persistence.load("model", flow_id=0) + assert loaded == {"w": [1, 2]} + + def test_publish_sets_persisted_flag(self, store_with_persistence): + store_with_persistence.publish("item", 42, node_id=1, flow_id=0) + + meta = store_with_persistence.list_all() + assert meta["item"]["persisted"] is True + + def test_delete_removes_from_disk(self, store_with_persistence, persistence): + store_with_persistence.publish("temp", 42, node_id=1, flow_id=0) + store_with_persistence.delete("temp", flow_id=0) + + assert persistence.list_persisted() == {} + + def test_clear_removes_from_disk(self, store_with_persistence, persistence): + store_with_persistence.publish("a", 1, node_id=1, flow_id=1) + store_with_persistence.publish("b", 2, node_id=2, flow_id=2) + store_with_persistence.clear() + + assert persistence.list_persisted() == {} + + def test_clear_by_flow_removes_from_disk(self, store_with_persistence, persistence): + store_with_persistence.publish("a", 1, node_id=1, flow_id=1) + store_with_persistence.publish("b", 2, node_id=2, flow_id=2) + store_with_persistence.clear(flow_id=1) + + persisted = persistence.list_persisted() + assert (1, "a") not in persisted + assert (2, "b") in persisted + + def test_clear_by_node_ids_removes_from_disk(self, store_with_persistence, persistence): + store_with_persistence.publish("a", 1, node_id=1, flow_id=0) + store_with_persistence.publish("b", 2, node_id=2, flow_id=0) + store_with_persistence.clear_by_node_ids({1}, flow_id=0) + + persisted = persistence.list_persisted() + assert (0, "a") not in persisted + assert (0, "b") in persisted + + +class TestLazyRecovery: + """Lazy loading: artifacts on disk are 
loaded into memory on first access.""" + + def test_lazy_index_built(self, persistence): + # Pre-populate disk + meta = {"name": "model", "node_id": 1, "type_name": "dict", "module": "builtins"} + persistence.save("model", {"w": 1}, meta, flow_id=0) + + # Create a fresh store with persistence + store = ArtifactStore() + store.enable_persistence(persistence) + count = store.build_lazy_index() + + assert count == 1 + + def test_lazy_load_on_get(self, persistence): + # Pre-populate disk + meta = {"name": "model", "node_id": 1, "type_name": "dict", "module": "builtins"} + persistence.save("model", {"w": 42}, meta, flow_id=0) + + # Create a fresh store with persistence + lazy index + store = ArtifactStore() + store.enable_persistence(persistence) + store.build_lazy_index() + + # The artifact should not be in memory yet + listing = store.list_all() + assert "model" in listing + assert listing["model"].get("in_memory") is False + + # Accessing it should trigger lazy load + obj = store.get("model", flow_id=0) + assert obj == {"w": 42} + + # Now it should be in memory + listing = store.list_all() + assert "model" in listing + # No more in_memory=False flag + + def test_lazy_load_preserves_metadata(self, persistence): + meta = {"name": "model", "node_id": 5, "type_name": "dict", "module": "builtins", + "created_at": "2024-01-01T00:00:00+00:00", "size_bytes": 100} + persistence.save("model", {"w": 1}, meta, flow_id=3) + + store = ArtifactStore() + store.enable_persistence(persistence) + store.build_lazy_index() + + # Trigger lazy load + store.get("model", flow_id=3) + + listing = store.list_all(flow_id=3) + assert listing["model"]["node_id"] == 5 + assert listing["model"]["flow_id"] == 3 + assert listing["model"]["recovered"] is True + + def test_lazy_list_includes_disk_artifacts(self, persistence): + meta = {"name": "model", "node_id": 1, "type_name": "dict", "module": "builtins"} + persistence.save("model", {"w": 1}, meta, flow_id=0) + + store = ArtifactStore() + 
store.enable_persistence(persistence) + store.build_lazy_index() + + # Publish an in-memory artifact + store.publish("other", 42, node_id=2, flow_id=0) + + listing = store.list_all(flow_id=0) + assert "model" in listing # from disk + assert "other" in listing # from memory + + def test_publish_removes_from_lazy_index(self, persistence): + meta = {"name": "model", "node_id": 1, "type_name": "dict", "module": "builtins"} + persistence.save("model", {"w": 1}, meta, flow_id=0) + + store = ArtifactStore() + store.enable_persistence(persistence) + store.build_lazy_index() + + # Delete (which should clear from lazy index) then republish + store.delete("model", flow_id=0) + store.publish("model", {"w": 2}, node_id=3, flow_id=0) + + assert store.get("model", flow_id=0) == {"w": 2} + + +class TestEagerRecovery: + """Eager recovery: all persisted artifacts loaded into memory at once.""" + + def test_recover_all(self, persistence): + meta1 = {"name": "a", "node_id": 1, "type_name": "int", "module": "builtins"} + meta2 = {"name": "b", "node_id": 2, "type_name": "str", "module": "builtins"} + persistence.save("a", 42, meta1, flow_id=0) + persistence.save("b", "hello", meta2, flow_id=1) + + store = ArtifactStore() + store.enable_persistence(persistence) + recovered = store.recover_all() + + assert sorted(recovered) == ["a", "b"] + assert store.get("a", flow_id=0) == 42 + assert store.get("b", flow_id=1) == "hello" + + def test_recover_skips_already_in_memory(self, persistence): + meta = {"name": "model", "node_id": 1, "type_name": "int", "module": "builtins"} + persistence.save("model", 42, meta, flow_id=0) + + store = ArtifactStore() + store.enable_persistence(persistence) + store.publish("model", 99, node_id=1, flow_id=0) + + recovered = store.recover_all() + assert recovered == [] # already in memory + assert store.get("model", flow_id=0) == 99 # original value preserved + + def test_recover_marks_recovered(self, persistence): + meta = {"name": "model", "node_id": 1, 
"type_name": "dict", "module": "builtins"} + persistence.save("model", {"w": 1}, meta, flow_id=0) + + store = ArtifactStore() + store.enable_persistence(persistence) + store.recover_all() + + listing = store.list_all() + assert listing["model"]["recovered"] is True + assert listing["model"]["persisted"] is True + + +class TestNoPersistence: + """When no persistence backend is attached, store behaves exactly as before.""" + + def test_no_persistence_publish_get(self): + store = ArtifactStore() + store.publish("item", 42, node_id=1) + assert store.get("item") == 42 + + def test_recover_all_returns_empty(self): + store = ArtifactStore() + assert store.recover_all() == [] + + def test_build_lazy_index_returns_zero(self): + store = ArtifactStore() + assert store.build_lazy_index() == 0 diff --git a/kernel_runtime/tests/test_flowfile_client.py b/kernel_runtime/tests/test_flowfile_client.py new file mode 100644 index 000000000..554f949af --- /dev/null +++ b/kernel_runtime/tests/test_flowfile_client.py @@ -0,0 +1,365 @@ +"""Tests for kernel_runtime.flowfile_client.""" + +from pathlib import Path + +import polars as pl +import pytest + +from kernel_runtime.artifact_store import ArtifactStore +from kernel_runtime import flowfile_client + + +@pytest.fixture(autouse=True) +def _reset_context(): + """Ensure context is cleared before and after each test.""" + flowfile_client._clear_context() + yield + flowfile_client._clear_context() + + +@pytest.fixture() +def ctx(tmp_dir: Path) -> dict: + """Set up a standard context and return its parameters.""" + store = ArtifactStore() + input_dir = tmp_dir / "inputs" + output_dir = tmp_dir / "outputs" + input_dir.mkdir() + output_dir.mkdir() + + # Write a default input parquet + df = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + main_path = input_dir / "main.parquet" + df.write_parquet(str(main_path)) + + flowfile_client._set_context( + node_id=1, + input_paths={"main": [str(main_path)]}, + output_dir=str(output_dir), + 
artifact_store=store, + ) + return { + "store": store, + "input_dir": input_dir, + "output_dir": output_dir, + "main_path": main_path, + } + + +class TestContextManagement: + def test_missing_context_raises(self): + with pytest.raises(RuntimeError, match="context not initialized"): + flowfile_client.read_input() + + def test_set_and_clear(self, tmp_dir: Path): + store = ArtifactStore() + flowfile_client._set_context( + node_id=1, + input_paths={}, + output_dir=str(tmp_dir), + artifact_store=store, + ) + # Should not raise + flowfile_client._get_context_value("node_id") + + flowfile_client._clear_context() + with pytest.raises(RuntimeError): + flowfile_client._get_context_value("node_id") + + +class TestReadInput: + def test_read_main_input(self, ctx: dict): + lf = flowfile_client.read_input() + assert isinstance(lf, pl.LazyFrame) + df = lf.collect() + assert set(df.columns) == {"x", "y"} + assert len(df) == 3 + + def test_read_named_input(self, ctx: dict): + lf = flowfile_client.read_input("main") + df = lf.collect() + assert df["x"].to_list() == [1, 2, 3] + + def test_read_missing_input_raises(self, ctx: dict): + with pytest.raises(KeyError, match="not found"): + flowfile_client.read_input("nonexistent") + + def test_read_inputs_returns_dict(self, ctx: dict): + inputs = flowfile_client.read_inputs() + assert isinstance(inputs, dict) + assert "main" in inputs + assert isinstance(inputs["main"], pl.LazyFrame) + + +class TestReadMultipleInputs: + def test_multiple_named_inputs(self, tmp_dir: Path): + store = ArtifactStore() + input_dir = tmp_dir / "inputs" + input_dir.mkdir(exist_ok=True) + + left_path = input_dir / "left.parquet" + right_path = input_dir / "right.parquet" + pl.DataFrame({"id": [1, 2]}).write_parquet(str(left_path)) + pl.DataFrame({"id": [3, 4]}).write_parquet(str(right_path)) + + flowfile_client._set_context( + node_id=2, + input_paths={"left": [str(left_path)], "right": [str(right_path)]}, + output_dir=str(tmp_dir / "outputs"), + 
artifact_store=store, + ) + + inputs = flowfile_client.read_inputs() + assert set(inputs.keys()) == {"left", "right"} + assert inputs["left"].collect()["id"].to_list() == [1, 2] + assert inputs["right"].collect()["id"].to_list() == [3, 4] + + def test_read_input_concatenates_multiple_main_paths(self, tmp_dir: Path): + """When 'main' has multiple paths, read_input returns a union of all.""" + store = ArtifactStore() + input_dir = tmp_dir / "inputs" + input_dir.mkdir(exist_ok=True) + + path_a = input_dir / "main_0.parquet" + path_b = input_dir / "main_1.parquet" + pl.DataFrame({"val": [1, 2]}).write_parquet(str(path_a)) + pl.DataFrame({"val": [3, 4]}).write_parquet(str(path_b)) + + flowfile_client._set_context( + node_id=3, + input_paths={"main": [str(path_a), str(path_b)]}, + output_dir=str(tmp_dir / "outputs"), + artifact_store=store, + ) + + df = flowfile_client.read_input().collect() + assert sorted(df["val"].to_list()) == [1, 2, 3, 4] + + def test_read_first_returns_only_first(self, tmp_dir: Path): + """read_first returns only the first file, not the union.""" + store = ArtifactStore() + input_dir = tmp_dir / "inputs" + input_dir.mkdir(exist_ok=True) + + path_a = input_dir / "main_0.parquet" + path_b = input_dir / "main_1.parquet" + pl.DataFrame({"val": [1, 2]}).write_parquet(str(path_a)) + pl.DataFrame({"val": [3, 4]}).write_parquet(str(path_b)) + + flowfile_client._set_context( + node_id=4, + input_paths={"main": [str(path_a), str(path_b)]}, + output_dir=str(tmp_dir / "outputs"), + artifact_store=store, + ) + + df = flowfile_client.read_first().collect() + assert df["val"].to_list() == [1, 2] + + def test_read_first_missing_name_raises(self, ctx: dict): + with pytest.raises(KeyError, match="not found"): + flowfile_client.read_first("nonexistent") + + def test_read_inputs_with_multiple_main_paths(self, tmp_dir: Path): + """read_inputs should concatenate paths per name.""" + store = ArtifactStore() + input_dir = tmp_dir / "inputs" + 
input_dir.mkdir(exist_ok=True) + + path_0 = input_dir / "main_0.parquet" + path_1 = input_dir / "main_1.parquet" + path_2 = input_dir / "main_2.parquet" + pl.DataFrame({"x": [1]}).write_parquet(str(path_0)) + pl.DataFrame({"x": [2]}).write_parquet(str(path_1)) + pl.DataFrame({"x": [3]}).write_parquet(str(path_2)) + + flowfile_client._set_context( + node_id=5, + input_paths={"main": [str(path_0), str(path_1), str(path_2)]}, + output_dir=str(tmp_dir / "outputs"), + artifact_store=store, + ) + + inputs = flowfile_client.read_inputs() + df = inputs["main"].collect() + assert sorted(df["x"].to_list()) == [1, 2, 3] + + +class TestPublishOutput: + def test_publish_dataframe(self, ctx: dict): + df = pl.DataFrame({"a": [1, 2]}) + flowfile_client.publish_output(df) + out = Path(ctx["output_dir"]) / "main.parquet" + assert out.exists() + result = pl.read_parquet(str(out)) + assert result["a"].to_list() == [1, 2] + + def test_publish_lazyframe(self, ctx: dict): + lf = pl.LazyFrame({"b": [10, 20]}) + flowfile_client.publish_output(lf) + out = Path(ctx["output_dir"]) / "main.parquet" + assert out.exists() + result = pl.read_parquet(str(out)) + assert result["b"].to_list() == [10, 20] + + def test_publish_named_output(self, ctx: dict): + df = pl.DataFrame({"c": [5]}) + flowfile_client.publish_output(df, name="custom") + out = Path(ctx["output_dir"]) / "custom.parquet" + assert out.exists() + + def test_publish_creates_output_dir(self, tmp_dir: Path): + store = ArtifactStore() + new_output = tmp_dir / "new" / "nested" + flowfile_client._set_context( + node_id=1, + input_paths={}, + output_dir=str(new_output), + artifact_store=store, + ) + df = pl.DataFrame({"v": [1]}) + flowfile_client.publish_output(df) + assert (new_output / "main.parquet").exists() + + +class TestArtifacts: + def test_publish_and_read_artifact(self, ctx: dict): + flowfile_client.publish_artifact("my_dict", {"key": "value"}) + result = flowfile_client.read_artifact("my_dict") + assert result == {"key": "value"} 
+ + def test_list_artifacts(self, ctx: dict): + flowfile_client.publish_artifact("a", 1) + flowfile_client.publish_artifact("b", [2, 3]) + listing = flowfile_client.list_artifacts() + assert set(listing.keys()) == {"a", "b"} + + def test_read_missing_artifact_raises(self, ctx: dict): + with pytest.raises(KeyError, match="not found"): + flowfile_client.read_artifact("missing") + + def test_publish_duplicate_artifact_raises(self, ctx: dict): + flowfile_client.publish_artifact("model", {"v": 1}) + with pytest.raises(ValueError, match="already exists"): + flowfile_client.publish_artifact("model", {"v": 2}) + + def test_delete_artifact(self, ctx: dict): + flowfile_client.publish_artifact("temp", 42) + flowfile_client.delete_artifact("temp") + with pytest.raises(KeyError, match="not found"): + flowfile_client.read_artifact("temp") + + def test_delete_missing_artifact_raises(self, ctx: dict): + with pytest.raises(KeyError, match="not found"): + flowfile_client.delete_artifact("nonexistent") + + def test_delete_then_republish(self, ctx: dict): + flowfile_client.publish_artifact("model", "v1") + flowfile_client.delete_artifact("model") + flowfile_client.publish_artifact("model", "v2") + assert flowfile_client.read_artifact("model") == "v2" + + +class TestDisplay: + def test_reset_displays(self): + flowfile_client._reset_displays() + assert flowfile_client._get_displays() == [] + + def test_display_plain_text(self): + flowfile_client._reset_displays() + flowfile_client.display("hello world") + displays = flowfile_client._get_displays() + assert len(displays) == 1 + assert displays[0]["mime_type"] == "text/plain" + assert displays[0]["data"] == "hello world" + assert displays[0]["title"] == "" + + def test_display_with_title(self): + flowfile_client._reset_displays() + flowfile_client.display("some data", title="My Title") + displays = flowfile_client._get_displays() + assert len(displays) == 1 + assert displays[0]["title"] == "My Title" + + def 
test_display_html_string(self): + flowfile_client._reset_displays() + html = "bold text" + flowfile_client.display(html) + displays = flowfile_client._get_displays() + assert len(displays) == 1 + assert displays[0]["mime_type"] == "text/html" + assert displays[0]["data"] == html + + def test_display_complex_html(self): + flowfile_client._reset_displays() + html = '

Hello

' + flowfile_client.display(html) + displays = flowfile_client._get_displays() + assert len(displays) == 1 + assert displays[0]["mime_type"] == "text/html" + + def test_display_multiple_outputs(self): + flowfile_client._reset_displays() + flowfile_client.display("first") + flowfile_client.display("second") + flowfile_client.display("third") + displays = flowfile_client._get_displays() + assert len(displays) == 3 + assert displays[0]["data"] == "first" + assert displays[1]["data"] == "second" + assert displays[2]["data"] == "third" + + def test_display_number_as_plain_text(self): + flowfile_client._reset_displays() + flowfile_client.display(42) + displays = flowfile_client._get_displays() + assert len(displays) == 1 + assert displays[0]["mime_type"] == "text/plain" + assert displays[0]["data"] == "42" + + def test_display_dict_as_plain_text(self): + flowfile_client._reset_displays() + flowfile_client.display({"key": "value"}) + displays = flowfile_client._get_displays() + assert len(displays) == 1 + assert displays[0]["mime_type"] == "text/plain" + assert "key" in displays[0]["data"] + + def test_get_displays_returns_copy(self): + """Ensure _get_displays returns the actual list that can be cleared.""" + flowfile_client._reset_displays() + flowfile_client.display("test") + displays1 = flowfile_client._get_displays() + assert len(displays1) == 1 + flowfile_client._reset_displays() + displays2 = flowfile_client._get_displays() + assert len(displays2) == 0 + + +class TestDisplayTypeDetection: + def test_is_html_string_true(self): + assert flowfile_client._is_html_string("test") is True + assert flowfile_client._is_html_string("
") is True + assert flowfile_client._is_html_string("Hello world!") is True + + def test_is_html_string_false(self): + assert flowfile_client._is_html_string("plain text") is False + assert flowfile_client._is_html_string("just text with math: 5 < 10") is False # only < + assert flowfile_client._is_html_string("x < 10 and y > 5") is False # comparison, not HTML + assert flowfile_client._is_html_string("a < b > c") is False # not actual HTML tags + assert flowfile_client._is_html_string(123) is False + assert flowfile_client._is_html_string(None) is False + + def test_is_matplotlib_figure_without_import(self): + """Without matplotlib installed, should return False.""" + result = flowfile_client._is_matplotlib_figure("not a figure") + assert result is False + + def test_is_plotly_figure_without_import(self): + """Without plotly installed, should return False.""" + result = flowfile_client._is_plotly_figure("not a figure") + assert result is False + + def test_is_pil_image_without_import(self): + """Without PIL installed, should return False.""" + result = flowfile_client._is_pil_image("not an image") + assert result is False diff --git a/kernel_runtime/tests/test_main.py b/kernel_runtime/tests/test_main.py new file mode 100644 index 000000000..84617ea55 --- /dev/null +++ b/kernel_runtime/tests/test_main.py @@ -0,0 +1,1047 @@ +"""Tests for kernel_runtime.main (FastAPI endpoints).""" + +import os +from pathlib import Path + +import polars as pl +import pytest +from fastapi.testclient import TestClient + + +class TestHealthEndpoint: + def test_health_returns_200(self, client: TestClient): + resp = client.get("/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "healthy" + assert data["artifact_count"] == 0 + + +class TestExecuteEndpoint: + def test_simple_print(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 1, + "code": 'print("hello")', + "input_paths": {}, + "output_dir": "", + }, + ) + 
assert resp.status_code == 200 + data = resp.json() + assert data["success"] is True + assert "hello" in data["stdout"] + assert data["error"] is None + + def test_syntax_error(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 2, + "code": "def broken(", + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is False + assert data["error"] is not None + assert "SyntaxError" in data["error"] + + def test_runtime_error(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 3, + "code": "1 / 0", + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is False + assert "ZeroDivisionError" in data["error"] + + def test_stderr_captured(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 4, + "code": 'import sys; sys.stderr.write("warning\\n")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert "warning" in data["stderr"] + + def test_execution_time_tracked(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 5, + "code": "x = sum(range(1000))", + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert data["execution_time_ms"] > 0 + + def test_flowfile_module_available(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 6, + "code": "print(type(flowfile).__name__)", + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert "module" in data["stdout"] + + +class TestExecuteWithParquet: + def test_read_and_write_parquet(self, client: TestClient, tmp_dir: Path): + input_dir = tmp_dir / "inputs" + output_dir = tmp_dir / "outputs" + input_dir.mkdir() + output_dir.mkdir() + + df_in = pl.DataFrame({"x": [1, 2, 3], "y": [10, 20, 30]}) + input_path = input_dir / 
"main.parquet" + df_in.write_parquet(str(input_path)) + + code = ( + "import polars as pl\n" + "df = flowfile.read_input()\n" + "df = df.collect().with_columns((pl.col('x') * pl.col('y')).alias('product'))\n" + "flowfile.publish_output(df)\n" + ) + + resp = client.post( + "/execute", + json={ + "node_id": 10, + "code": code, + "input_paths": {"main": [str(input_path)]}, + "output_dir": str(output_dir), + }, + ) + data = resp.json() + assert data["success"] is True, f"Execution failed: {data['error']}" + assert len(data["output_paths"]) > 0 + + out_path = output_dir / "main.parquet" + assert out_path.exists() + df_out = pl.read_parquet(str(out_path)) + assert "product" in df_out.columns + assert df_out["product"].to_list() == [10, 40, 90] + + def test_multiple_inputs(self, client: TestClient, tmp_dir: Path): + input_dir = tmp_dir / "inputs" + output_dir = tmp_dir / "outputs" + input_dir.mkdir() + output_dir.mkdir() + + pl.DataFrame({"id": [1, 2], "name": ["a", "b"]}).write_parquet( + str(input_dir / "left.parquet") + ) + pl.DataFrame({"id": [1, 2], "score": [90, 80]}).write_parquet( + str(input_dir / "right.parquet") + ) + + code = ( + "inputs = flowfile.read_inputs()\n" + "left = inputs['left'].collect()\n" + "right = inputs['right'].collect()\n" + "merged = left.join(right, on='id')\n" + "flowfile.publish_output(merged)\n" + ) + + resp = client.post( + "/execute", + json={ + "node_id": 11, + "code": code, + "input_paths": { + "left": [str(input_dir / "left.parquet")], + "right": [str(input_dir / "right.parquet")], + }, + "output_dir": str(output_dir), + }, + ) + data = resp.json() + assert data["success"] is True, f"Execution failed: {data['error']}" + + df_out = pl.read_parquet(str(output_dir / "main.parquet")) + assert set(df_out.columns) == {"id", "name", "score"} + assert len(df_out) == 2 + + def test_multi_main_inputs_union(self, client: TestClient, tmp_dir: Path): + """Multiple paths under 'main' are concatenated (union) by read_input.""" + input_dir = 
tmp_dir / "inputs" + output_dir = tmp_dir / "outputs" + input_dir.mkdir() + output_dir.mkdir() + + pl.DataFrame({"v": [1, 2]}).write_parquet(str(input_dir / "main_0.parquet")) + pl.DataFrame({"v": [3, 4]}).write_parquet(str(input_dir / "main_1.parquet")) + + code = ( + "df = flowfile.read_input().collect()\n" + "flowfile.publish_output(df)\n" + ) + + resp = client.post( + "/execute", + json={ + "node_id": 13, + "code": code, + "input_paths": { + "main": [ + str(input_dir / "main_0.parquet"), + str(input_dir / "main_1.parquet"), + ], + }, + "output_dir": str(output_dir), + }, + ) + data = resp.json() + assert data["success"] is True, f"Execution failed: {data['error']}" + + df_out = pl.read_parquet(str(output_dir / "main.parquet")) + assert sorted(df_out["v"].to_list()) == [1, 2, 3, 4] + + def test_read_first_via_execute(self, client: TestClient, tmp_dir: Path): + """read_first returns only the first input file.""" + input_dir = tmp_dir / "inputs" + output_dir = tmp_dir / "outputs" + input_dir.mkdir() + output_dir.mkdir() + + pl.DataFrame({"v": [10, 20]}).write_parquet(str(input_dir / "a.parquet")) + pl.DataFrame({"v": [30, 40]}).write_parquet(str(input_dir / "b.parquet")) + + code = ( + "df = flowfile.read_first().collect()\n" + "flowfile.publish_output(df)\n" + ) + + resp = client.post( + "/execute", + json={ + "node_id": 14, + "code": code, + "input_paths": { + "main": [ + str(input_dir / "a.parquet"), + str(input_dir / "b.parquet"), + ], + }, + "output_dir": str(output_dir), + }, + ) + data = resp.json() + assert data["success"] is True, f"Execution failed: {data['error']}" + + df_out = pl.read_parquet(str(output_dir / "main.parquet")) + assert df_out["v"].to_list() == [10, 20] + + def test_publish_lazyframe_output(self, client: TestClient, tmp_dir: Path): + input_dir = tmp_dir / "inputs" + output_dir = tmp_dir / "outputs" + input_dir.mkdir() + output_dir.mkdir() + + pl.DataFrame({"v": [10, 20]}).write_parquet(str(input_dir / "main.parquet")) + + code = ( + "lf 
= flowfile.read_input()\n" + "flowfile.publish_output(lf)\n" + ) + + resp = client.post( + "/execute", + json={ + "node_id": 12, + "code": code, + "input_paths": {"main": [str(input_dir / "main.parquet")]}, + "output_dir": str(output_dir), + }, + ) + data = resp.json() + assert data["success"] is True + df_out = pl.read_parquet(str(output_dir / "main.parquet")) + assert df_out["v"].to_list() == [10, 20] + + +class TestArtifactEndpoints: + def test_publish_artifact_via_execute(self, client: TestClient): + resp = client.post( + "/execute", + json={ + "node_id": 20, + "code": 'flowfile.publish_artifact("my_dict", {"a": 1})', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert "my_dict" in data["artifacts_published"] + + def test_list_artifacts(self, client: TestClient): + # Publish via execute + client.post( + "/execute", + json={ + "node_id": 21, + "code": ( + 'flowfile.publish_artifact("item_a", [1, 2])\n' + 'flowfile.publish_artifact("item_b", "hello")\n' + ), + "input_paths": {}, + "output_dir": "", + }, + ) + + resp = client.get("/artifacts") + assert resp.status_code == 200 + data = resp.json() + assert "item_a" in data + assert "item_b" in data + # The object itself should not be in the listing + assert "object" not in data["item_a"] + + def test_clear_artifacts(self, client: TestClient): + client.post( + "/execute", + json={ + "node_id": 22, + "code": 'flowfile.publish_artifact("tmp", 42)', + "input_paths": {}, + "output_dir": "", + }, + ) + + resp = client.post("/clear") + assert resp.status_code == 200 + assert resp.json()["status"] == "cleared" + + resp = client.get("/artifacts") + assert resp.json() == {} + + def test_health_shows_artifact_count(self, client: TestClient): + client.post( + "/execute", + json={ + "node_id": 23, + "code": 'flowfile.publish_artifact("x", 1)', + "input_paths": {}, + "output_dir": "", + }, + ) + resp = client.get("/health") + assert resp.json()["artifact_count"] == 1 + 
+ def test_duplicate_publish_fails(self, client: TestClient): + """Publishing an artifact with the same name twice should fail.""" + resp = client.post( + "/execute", + json={ + "node_id": 24, + "code": 'flowfile.publish_artifact("model", 1)', + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp.json()["success"] is True + + resp2 = client.post( + "/execute", + json={ + "node_id": 25, + "code": 'flowfile.publish_artifact("model", 2)', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp2.json() + assert data["success"] is False + assert "already exists" in data["error"] + + def test_delete_artifact_via_execute(self, client: TestClient): + """delete_artifact removes from the store and appears in artifacts_deleted.""" + client.post( + "/execute", + json={ + "node_id": 26, + "code": 'flowfile.publish_artifact("temp", 99)', + "input_paths": {}, + "output_dir": "", + }, + ) + resp = client.post( + "/execute", + json={ + "node_id": 27, + "code": 'flowfile.delete_artifact("temp")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert "temp" in data["artifacts_deleted"] + + # Verify artifact is gone + resp_list = client.get("/artifacts") + assert "temp" not in resp_list.json() + + def test_same_node_reexecution_clears_own_artifacts(self, client: TestClient): + """Re-executing the same node auto-clears its previous artifacts.""" + resp1 = client.post( + "/execute", + json={ + "node_id": 24, + "code": 'flowfile.publish_artifact("model", "v1")', + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp1.json()["success"] is True + assert "model" in resp1.json()["artifacts_published"] + + # Same node re-executes — should NOT fail with "already exists" + resp2 = client.post( + "/execute", + json={ + "node_id": 24, + "code": 'flowfile.publish_artifact("model", "v2")', + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp2.json()["success"] is True + assert "model" in 
resp2.json()["artifacts_published"] + + # Verify we get v2 + resp3 = client.post( + "/execute", + json={ + "node_id": 99, + "code": 'v = flowfile.read_artifact("model"); print(v)', + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp3.json()["success"] is True + assert "v2" in resp3.json()["stdout"] + + def test_delete_then_republish_via_execute(self, client: TestClient): + """After deleting, a new artifact with the same name can be published.""" + client.post( + "/execute", + json={ + "node_id": 28, + "code": 'flowfile.publish_artifact("model", "v1")', + "input_paths": {}, + "output_dir": "", + }, + ) + resp = client.post( + "/execute", + json={ + "node_id": 29, + "code": ( + 'flowfile.delete_artifact("model")\n' + 'flowfile.publish_artifact("model", "v2")\n' + ), + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + # The artifact was deleted and re-published in the same call. + # Since the final state has "model" which didn't exist before the + # first publish in this request, it depends on whether it was in + # artifacts_before. Since it existed before this execute call, + # and still exists after, it's neither new nor deleted from the + # perspective of this single call. But the name was re-published + # so it shouldn't appear in artifacts_deleted. + # Let's just verify the artifact exists and has the new value. 
+ resp_read = client.post( + "/execute", + json={ + "node_id": 30, + "code": ( + 'v = flowfile.read_artifact("model")\n' + 'print(v)\n' + ), + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp_read.json()["success"] is True + assert "v2" in resp_read.json()["stdout"] + + +class TestClearNodeArtifactsEndpoint: + def test_clear_node_artifacts_selective(self, client: TestClient): + """Only artifacts from specified node IDs should be removed.""" + # Publish artifacts from two different nodes + client.post( + "/execute", + json={ + "node_id": 40, + "code": 'flowfile.publish_artifact("model", {"v": 1})', + "input_paths": {}, + "output_dir": "", + }, + ) + client.post( + "/execute", + json={ + "node_id": 41, + "code": 'flowfile.publish_artifact("scaler", {"v": 2})', + "input_paths": {}, + "output_dir": "", + }, + ) + + # Clear only node 40's artifacts + resp = client.post("/clear_node_artifacts", json={"node_ids": [40]}) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "cleared" + assert "model" in data["removed"] + + # "scaler" from node 41 should still exist + artifacts = client.get("/artifacts").json() + assert "model" not in artifacts + assert "scaler" in artifacts + + def test_clear_node_artifacts_empty_list(self, client: TestClient): + """Passing empty list should not remove anything.""" + client.post( + "/execute", + json={ + "node_id": 42, + "code": 'flowfile.publish_artifact("keep_me", 42)', + "input_paths": {}, + "output_dir": "", + }, + ) + resp = client.post("/clear_node_artifacts", json={"node_ids": []}) + assert resp.status_code == 200 + assert resp.json()["removed"] == [] + assert "keep_me" in client.get("/artifacts").json() + + def test_clear_node_artifacts_allows_republish(self, client: TestClient): + """After clearing, the same artifact name can be re-published.""" + client.post( + "/execute", + json={ + "node_id": 43, + "code": 'flowfile.publish_artifact("reuse", "v1")', + "input_paths": {}, + "output_dir": 
"", + }, + ) + client.post("/clear_node_artifacts", json={"node_ids": [43]}) + resp = client.post( + "/execute", + json={ + "node_id": 43, + "code": 'flowfile.publish_artifact("reuse", "v2")', + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp.json()["success"] is True + + +class TestNodeArtifactsEndpoint: + def test_list_node_artifacts(self, client: TestClient): + """Should return only artifacts for the specified node.""" + client.post( + "/execute", + json={ + "node_id": 50, + "code": ( + 'flowfile.publish_artifact("a", 1)\n' + 'flowfile.publish_artifact("b", 2)\n' + ), + "input_paths": {}, + "output_dir": "", + }, + ) + client.post( + "/execute", + json={ + "node_id": 51, + "code": 'flowfile.publish_artifact("c", 3)', + "input_paths": {}, + "output_dir": "", + }, + ) + + resp = client.get("/artifacts/node/50") + assert resp.status_code == 200 + data = resp.json() + assert set(data.keys()) == {"a", "b"} + + resp2 = client.get("/artifacts/node/51") + assert set(resp2.json().keys()) == {"c"} + + def test_list_node_artifacts_empty(self, client: TestClient): + resp = client.get("/artifacts/node/999") + assert resp.status_code == 200 + assert resp.json() == {} + + +class TestDisplayOutputs: + def test_display_outputs_empty_by_default(self, client: TestClient): + """Execute code without displays should return empty display_outputs.""" + resp = client.post( + "/execute", + json={ + "node_id": 60, + "code": 'print("hello")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert data["display_outputs"] == [] + + def test_display_output_explicit(self, client: TestClient): + """Execute flowfile.display() should return a display output.""" + resp = client.post( + "/execute", + json={ + "node_id": 61, + "code": 'flowfile.display("hello")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert len(data["display_outputs"]) == 1 + assert 
data["display_outputs"][0]["mime_type"] == "text/plain" + assert data["display_outputs"][0]["data"] == "hello" + + def test_display_output_html(self, client: TestClient): + """Execute flowfile.display() with HTML should return HTML mime type.""" + resp = client.post( + "/execute", + json={ + "node_id": 62, + "code": 'flowfile.display("bold")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert len(data["display_outputs"]) == 1 + assert data["display_outputs"][0]["mime_type"] == "text/html" + assert data["display_outputs"][0]["data"] == "bold" + + def test_display_output_with_title(self, client: TestClient): + """Display with title should preserve the title.""" + resp = client.post( + "/execute", + json={ + "node_id": 63, + "code": 'flowfile.display("data", title="My Chart")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert len(data["display_outputs"]) == 1 + assert data["display_outputs"][0]["title"] == "My Chart" + + def test_multiple_display_outputs(self, client: TestClient): + """Multiple display calls should return multiple outputs.""" + resp = client.post( + "/execute", + json={ + "node_id": 64, + "code": ( + 'flowfile.display("first")\n' + 'flowfile.display("second")\n' + 'flowfile.display("third")\n' + ), + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert len(data["display_outputs"]) == 3 + assert data["display_outputs"][0]["data"] == "first" + assert data["display_outputs"][1]["data"] == "second" + assert data["display_outputs"][2]["data"] == "third" + + def test_display_outputs_cleared_between_executions(self, client: TestClient): + """Display outputs should not persist between execution calls.""" + # First execution + client.post( + "/execute", + json={ + "node_id": 65, + "code": 'flowfile.display("from first call")', + "input_paths": {}, + "output_dir": "", + }, + ) + 
+ # Second execution should not include first call's displays + resp = client.post( + "/execute", + json={ + "node_id": 66, + "code": 'flowfile.display("from second call")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert len(data["display_outputs"]) == 1 + assert data["display_outputs"][0]["data"] == "from second call" + + def test_display_output_on_error_still_collected(self, client: TestClient): + """Display outputs generated before an error should still be returned.""" + resp = client.post( + "/execute", + json={ + "node_id": 67, + "code": ( + 'flowfile.display("before error")\n' + 'raise ValueError("oops")\n' + ), + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is False + assert "ValueError" in data["error"] + assert len(data["display_outputs"]) == 1 + assert data["display_outputs"][0]["data"] == "before error" + + def test_interactive_mode_auto_display_last_expression(self, client: TestClient): + """Interactive mode should auto-display the last expression.""" + resp = client.post( + "/execute", + json={ + "node_id": 68, + "code": "1 + 2 + 3", + "input_paths": {}, + "output_dir": "", + "interactive": True, + }, + ) + data = resp.json() + assert data["success"] is True + assert len(data["display_outputs"]) == 1 + assert data["display_outputs"][0]["data"] == "6" + + def test_non_interactive_mode_no_auto_display(self, client: TestClient): + """Non-interactive mode should not auto-display the last expression.""" + resp = client.post( + "/execute", + json={ + "node_id": 69, + "code": "1 + 2 + 3", + "input_paths": {}, + "output_dir": "", + "interactive": False, + }, + ) + data = resp.json() + assert data["success"] is True + assert data["display_outputs"] == [] + + def test_interactive_mode_with_print_no_double_display(self, client: TestClient): + """Print statements should not trigger auto-display.""" + resp = client.post( + "/execute", + json={ + 
"node_id": 70, + "code": 'print("hello")', + "input_paths": {}, + "output_dir": "", + "interactive": True, + }, + ) + data = resp.json() + assert data["success"] is True + # print doesn't return a value worth displaying + assert data["display_outputs"] == [] + + +class TestContextCleanup: + def test_context_cleared_after_success(self, client: TestClient): + """After a successful /execute, the flowfile context should be cleared.""" + client.post( + "/execute", + json={ + "node_id": 30, + "code": "x = 1", + "input_paths": {}, + "output_dir": "", + }, + ) + # A second call that tries to use context should still work + # (context is re-set for each request) + resp = client.post( + "/execute", + json={ + "node_id": 31, + "code": 'print("ok")', + "input_paths": {}, + "output_dir": "", + }, + ) + assert resp.json()["success"] is True + + def test_context_cleared_after_error(self, client: TestClient): + """After a failed /execute, the flowfile context should still be cleared.""" + client.post( + "/execute", + json={ + "node_id": 32, + "code": "raise ValueError('boom')", + "input_paths": {}, + "output_dir": "", + }, + ) + resp = client.post( + "/execute", + json={ + "node_id": 33, + "code": 'print("still works")', + "input_paths": {}, + "output_dir": "", + }, + ) + data = resp.json() + assert data["success"] is True + assert "still works" in data["stdout"] + + +class TestFlowIsolation: + """Artifacts published by different flows don't interfere with each other.""" + + def test_same_artifact_name_different_flows(self, client: TestClient): + """Two flows can each publish an artifact called 'model' independently.""" + resp1 = client.post( + "/execute", + json={ + "node_id": 1, + "code": 'flowfile.publish_artifact("model", "flow1_model")', + "input_paths": {}, + "output_dir": "", + "flow_id": 1, + }, + ) + assert resp1.json()["success"] is True + + resp2 = client.post( + "/execute", + json={ + "node_id": 1, + "code": 'flowfile.publish_artifact("model", "flow2_model")', + 
"input_paths": {}, + "output_dir": "", + "flow_id": 2, + }, + ) + assert resp2.json()["success"] is True + + # Each flow reads its own artifact + resp_read1 = client.post( + "/execute", + json={ + "node_id": 99, + "code": 'v = flowfile.read_artifact("model"); print(v)', + "input_paths": {}, + "output_dir": "", + "flow_id": 1, + }, + ) + assert resp_read1.json()["success"] is True + assert "flow1_model" in resp_read1.json()["stdout"] + + resp_read2 = client.post( + "/execute", + json={ + "node_id": 99, + "code": 'v = flowfile.read_artifact("model"); print(v)', + "input_paths": {}, + "output_dir": "", + "flow_id": 2, + }, + ) + assert resp_read2.json()["success"] is True + assert "flow2_model" in resp_read2.json()["stdout"] + + def test_flow_cannot_read_other_flows_artifact(self, client: TestClient): + """Flow 1 publishes 'secret'; flow 2 should not see it.""" + client.post( + "/execute", + json={ + "node_id": 1, + "code": 'flowfile.publish_artifact("secret", "hidden")', + "input_paths": {}, + "output_dir": "", + "flow_id": 1, + }, + ) + + resp = client.post( + "/execute", + json={ + "node_id": 2, + "code": 'flowfile.read_artifact("secret")', + "input_paths": {}, + "output_dir": "", + "flow_id": 2, + }, + ) + data = resp.json() + assert data["success"] is False + assert "not found" in data["error"] + + def test_reexecution_only_clears_own_flow(self, client: TestClient): + """Re-executing a node in flow 1 doesn't clear flow 2's artifacts.""" + # Flow 1, node 5 publishes "model" + client.post( + "/execute", + json={ + "node_id": 5, + "code": 'flowfile.publish_artifact("model", "f1v1")', + "input_paths": {}, + "output_dir": "", + "flow_id": 1, + }, + ) + # Flow 2, node 5 publishes "model" + client.post( + "/execute", + json={ + "node_id": 5, + "code": 'flowfile.publish_artifact("model", "f2v1")', + "input_paths": {}, + "output_dir": "", + "flow_id": 2, + }, + ) + + # Re-execute node 5 in flow 1 — auto-clear only affects flow 1 + resp = client.post( + "/execute", + 
json={ + "node_id": 5, + "code": 'flowfile.publish_artifact("model", "f1v2")', + "input_paths": {}, + "output_dir": "", + "flow_id": 1, + }, + ) + assert resp.json()["success"] is True + + # Flow 2's artifact should be untouched + resp_f2 = client.post( + "/execute", + json={ + "node_id": 99, + "code": 'v = flowfile.read_artifact("model"); print(v)', + "input_paths": {}, + "output_dir": "", + "flow_id": 2, + }, + ) + assert resp_f2.json()["success"] is True + assert "f2v1" in resp_f2.json()["stdout"] + + def test_list_artifacts_filtered_by_flow(self, client: TestClient): + """GET /artifacts?flow_id=X returns only that flow's artifacts.""" + client.post( + "/execute", + json={ + "node_id": 1, + "code": 'flowfile.publish_artifact("a", 1)', + "input_paths": {}, + "output_dir": "", + "flow_id": 10, + }, + ) + client.post( + "/execute", + json={ + "node_id": 2, + "code": 'flowfile.publish_artifact("b", 2)', + "input_paths": {}, + "output_dir": "", + "flow_id": 20, + }, + ) + + resp10 = client.get("/artifacts", params={"flow_id": 10}) + assert set(resp10.json().keys()) == {"a"} + + resp20 = client.get("/artifacts", params={"flow_id": 20}) + assert set(resp20.json().keys()) == {"b"} + + # No filter returns both + resp_all = client.get("/artifacts") + assert set(resp_all.json().keys()) == {"a", "b"} + + def test_clear_node_artifacts_scoped_to_flow(self, client: TestClient): + """POST /clear_node_artifacts with flow_id only clears that flow.""" + client.post( + "/execute", + json={ + "node_id": 5, + "code": 'flowfile.publish_artifact("model", "f1")', + "input_paths": {}, + "output_dir": "", + "flow_id": 1, + }, + ) + client.post( + "/execute", + json={ + "node_id": 5, + "code": 'flowfile.publish_artifact("model", "f2")', + "input_paths": {}, + "output_dir": "", + "flow_id": 2, + }, + ) + + resp = client.post( + "/clear_node_artifacts", + json={"node_ids": [5], "flow_id": 1}, + ) + assert resp.json()["status"] == "cleared" + assert "model" in resp.json()["removed"] + + # 
Flow 2's artifact survives + artifacts_f2 = client.get("/artifacts", params={"flow_id": 2}).json() + assert "model" in artifacts_f2 diff --git a/kernel_runtime/tests/test_persistence_api.py b/kernel_runtime/tests/test_persistence_api.py new file mode 100644 index 000000000..7b57863a4 --- /dev/null +++ b/kernel_runtime/tests/test_persistence_api.py @@ -0,0 +1,53 @@ +"""Tests for the persistence-related API endpoints in the kernel runtime.""" + +import os +from unittest.mock import patch + +import pytest +from fastapi.testclient import TestClient + + +class TestHealthEndpoint: + def test_health_includes_persistence_info(self, client: TestClient): + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert "persistence" in data + assert "recovery_mode" in data + + +class TestRecoveryStatusEndpoint: + def test_recovery_status(self, client: TestClient): + response = client.get("/recovery-status") + assert response.status_code == 200 + data = response.json() + assert "status" in data + + +class TestPersistenceEndpoint: + def test_persistence_info(self, client: TestClient): + response = client.get("/persistence") + assert response.status_code == 200 + data = response.json() + assert "enabled" in data + assert "recovery_mode" in data + + +class TestRecoverEndpoint: + def test_recover_returns_status(self, client: TestClient): + response = client.post("/recover") + assert response.status_code == 200 + data = response.json() + assert "status" in data + + +class TestCleanupEndpoint: + def test_cleanup_with_max_age(self, client: TestClient): + response = client.post("/cleanup", json={"max_age_hours": 24}) + assert response.status_code == 200 + data = response.json() + assert "status" in data + + def test_cleanup_with_empty_request(self, client: TestClient): + response = client.post("/cleanup", json={}) + assert response.status_code == 200 diff --git a/poetry.lock b/poetry.lock index 311677d12..c328c90e4 100644 --- a/poetry.lock +++ 
b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiobotocore" @@ -2849,10 +2849,8 @@ files = [ {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c47676e5b485393f069b4d7a811267d3168ce46f988fa602658b8bb901e9e64d"}, {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:a28d8c01a7b27a1e3265b11250ba7557e5f72b5ee9e5f3a2fa8d2949c29bf5d2"}, {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5f3f2732cf504a1aa9e9609d02f79bea1067d99edf844ab92c247bbca143303b"}, - {file = "psycopg2_binary-2.9.11-cp310-cp310-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:865f9945ed1b3950d968ec4690ce68c55019d79e4497366d36e090327ce7db14"}, {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:91537a8df2bde69b1c1db01d6d944c831ca793952e4f57892600e96cee95f2cd"}, {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:4dca1f356a67ecb68c81a7bc7809f1569ad9e152ce7fd02c2f2036862ca9f66b"}, - {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:0da4de5c1ac69d94ed4364b6cbe7190c1a70d325f112ba783d83f8440285f152"}, {file = "psycopg2_binary-2.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:37d8412565a7267f7d79e29ab66876e55cb5e8e7b3bbf94f8206f6795f8f7e7e"}, {file = "psycopg2_binary-2.9.11-cp310-cp310-win_amd64.whl", hash = "sha256:c665f01ec8ab273a61c62beeb8cce3014c214429ced8a308ca1fc410ecac3a39"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0e8480afd62362d0a6a27dd09e4ca2def6fa50ed3a4e7c09165266106b2ffa10"}, @@ -2860,10 +2858,8 @@ files = [ {file = 
"psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee"}, - {file = "psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94"}, - {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908"}, {file = "psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4"}, @@ -2871,10 +2867,8 @@ files = [ {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = 
"sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db"}, - {file = "psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a"}, - {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d"}, {file = "psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c"}, @@ -2882,10 +2876,8 @@ files = [ {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3"}, - {file = 
"psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c"}, - {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1"}, {file = "psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa"}, {file = "psycopg2_binary-2.9.11-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:92e3b669236327083a2e33ccfa0d320dd01b9803b3e14dd986a4fc54aa00f4e1"}, @@ -2893,10 +2885,8 @@ files = [ {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b52a3f9bb540a3e4ec0f6ba6d31339727b2950c9772850d6545b7eae0b9d7c5"}, {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:db4fd476874ccfdbb630a54426964959e58da4c61c9feba73e6094d51303d7d8"}, {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47f212c1d3be608a12937cc131bd85502954398aaa1320cb4c14421a0ffccf4c"}, - {file = "psycopg2_binary-2.9.11-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e35b7abae2b0adab776add56111df1735ccc71406e56203515e228a8dc07089f"}, {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fcf21be3ce5f5659daefd2b3b3b6e4727b028221ddc94e6c1523425579664747"}, {file = 
"psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:9bd81e64e8de111237737b29d68039b9c813bdf520156af36d26819c9a979e5f"}, - {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:32770a4d666fbdafab017086655bcddab791d7cb260a16679cc5a7338b64343b"}, {file = "psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3cb3a676873d7506825221045bd70e0427c905b9c8ee8d6acd70cfcbd6e576d"}, {file = "psycopg2_binary-2.9.11-cp314-cp314-win_amd64.whl", hash = "sha256:4012c9c954dfaccd28f94e84ab9f94e12df76b4afb22331b1f0d3154893a6316"}, {file = "psycopg2_binary-2.9.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:20e7fb94e20b03dcc783f76c0865f9da39559dcc0c28dd1a3fce0d01902a6b9c"}, @@ -2904,10 +2894,8 @@ files = [ {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9d3a9edcfbe77a3ed4bc72836d466dfce4174beb79eda79ea155cc77237ed9e8"}, {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:44fc5c2b8fa871ce7f0023f619f1349a0aa03a0857f2c96fbc01c657dcbbdb49"}, {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9c55460033867b4622cda1b6872edf445809535144152e5d14941ef591980edf"}, - {file = "psycopg2_binary-2.9.11-cp39-cp39-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:2d11098a83cca92deaeaed3d58cfd150d49b3b06ee0d0852be466bf87596899e"}, {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:691c807d94aecfbc76a14e1408847d59ff5b5906a04a23e12a89007672b9e819"}, {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:8b81627b691f29c4c30a8f322546ad039c40c328373b11dff7490a3e1b517855"}, - {file = "psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_riscv64.whl", hash = "sha256:b637d6d941209e8d96a072d7977238eea128046effbf37d1d8b2c0764750017d"}, {file = 
"psycopg2_binary-2.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:41360b01c140c2a03d346cec3280cf8a71aa07d94f3b1509fa0161c366af66b4"}, {file = "psycopg2_binary-2.9.11-cp39-cp39-win_amd64.whl", hash = "sha256:875039274f8a2361e5207857899706da840768e2a775bf8c65e82f60b197df02"}, ] @@ -3950,13 +3938,13 @@ files = [ [[package]] name = "tqdm" -version = "4.67.1" +version = "4.67.2" description = "Fast, Extensible Progress Meter" optional = false python-versions = ">=3.7" files = [ - {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, - {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, + {file = "tqdm-4.67.2-py3-none-any.whl", hash = "sha256:9a12abcbbff58b6036b2167d9d3853042b9d436fe7330f06ae047867f2f8e0a7"}, + {file = "tqdm-4.67.2.tar.gz", hash = "sha256:649aac53964b2cb8dec76a14b405a4c0d13612cb8933aae547dd144eacc99653"}, ] [package.dependencies] @@ -4405,4 +4393,4 @@ propcache = ">=0.2.1" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "0368b7bb3231134e2c9d78e4d79c30da6abd199094c6e69c1a02102188509de8" +content-hash = "b9627d3d6426127ba47aea057bd8e6878ef7cd1f96d4bae0171ebe69f60b94ff" diff --git a/pyproject.toml b/pyproject.toml index a046d5da0..5d4f42579 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ pyiceberg = {extras = ["hadoop"], version = "^0.9.1"} boto3 = ">=1.38.40,<1.38.47" cryptography = "^45.0.5" httpx = "^0.28.1" +docker = ">=7.0.0" tqdm = "^4.67.1" s3fs = "^2025.7.0" pl-fuzzy-frame-match = ">=0.4.0" @@ -102,7 +103,8 @@ build-backend = "poetry.core.masonry.api" [tool.pytest.ini_options] markers = [ "worker: Tests for the flowfile_worker package", - "core: Tests for the flowfile_core package" + "core: Tests for the flowfile_core package", + "kernel: Integration tests requiring Docker kernel containers" ] [tool.coverage.run]