From 0d1d50bb804c31c2d906fa7bacc019a0b1bc85ad Mon Sep 17 00:00:00 2001 From: Claude Date: Sun, 5 Apr 2026 14:40:28 +0000 Subject: [PATCH 1/2] Add data contract system and validation node for catalog tables Introduces a typed data contract model (1:1 per catalog table) with column-level rules (not_null, unique, dtype, value_range, allowed_values, regex, custom_expression) and general rules (Flowfile expressions). Each rule provides a to_expression() method returning a Flowfile expression string, keeping the validation engine thin and rule logic self-contained. Includes: - DataContract DB model, repository, service CRUD, and API endpoints - Validation engine shared between catalog contracts and the new data_validation flow node (pass-through with _is_valid/_violations columns) - Auto-generate contract from table profile - mark-validated endpoint for tier-2 validation flows - ContractSummary on CatalogTableOut for catalog UI status badges https://claude.ai/code/session_01RJ6iAh4f3MN7UdzttruKho --- .../flowfile_core/catalog/__init__.py | 4 + .../flowfile_core/catalog/exceptions.py | 16 + .../flowfile_core/catalog/repository.py | 54 +++- .../flowfile_core/catalog/service.py | 174 ++++++++++ .../flowfile_core/configs/node_store/nodes.py | 12 + .../flowfile_core/database/models.py | 35 ++- .../flowfile/flow_data_engine/validation.py | 244 ++++++++++++++ .../flowfile_core/flowfile/flow_graph.py | 48 ++- flowfile_core/flowfile_core/routes/catalog.py | 186 +++++++++++ .../flowfile_core/schemas/catalog_schema.py | 1 + .../flowfile_core/schemas/contract_schema.py | 236 ++++++++++++++ .../flowfile_core/schemas/input_schema.py | 19 +- .../flowfile_core/schemas/schemas.py | 1 + flowfile_core/tests/test_data_contracts.py | 297 ++++++++++++++++++ 14 files changed, 1305 insertions(+), 22 deletions(-) create mode 100644 flowfile_core/flowfile_core/flowfile/flow_data_engine/validation.py create mode 100644 flowfile_core/flowfile_core/schemas/contract_schema.py create mode 100644 flowfile_core/tests/test_data_contracts.py diff --git a/flowfile_core/flowfile_core/catalog/__init__.py b/flowfile_core/flowfile_core/catalog/__init__.py index a0486cf89..3ea5d81c1 100644 --- a/flowfile_core/flowfile_core/catalog/__init__.py +++ b/flowfile_core/flowfile_core/catalog/__init__.py @@ -10,6 +10,8 @@ from .exceptions import ( CatalogError, + ContractExistsError, + ContractNotFoundError, FavoriteNotFoundError, FlowAlreadyRunningError, FlowExistsError, @@ -55,4 +57,6 @@ "TableNotFoundError", "TableExistsError", "TableFavoriteNotFoundError", + "ContractNotFoundError", + "ContractExistsError", ] diff --git a/flowfile_core/flowfile_core/catalog/exceptions.py b/flowfile_core/flowfile_core/catalog/exceptions.py index 460693299..3dfb6b3c1 100644 --- a/flowfile_core/flowfile_core/catalog/exceptions.py +++ b/flowfile_core/flowfile_core/catalog/exceptions.py @@ -196,3 +196,19 @@ def __init__(self, registration_id: int, schedule_type: str): self.registration_id = registration_id self.schedule_type = schedule_type super().__init__(f"A '{schedule_type}' schedule already exists for flow {registration_id}") + + +class ContractNotFoundError(CatalogError): + """Raised when a data contract lookup fails.""" + + def __init__(self, table_id: int): + self.table_id = table_id + super().__init__(f"No data contract found for table_id={table_id}") + + +class ContractExistsError(CatalogError): + """Raised when attempting to create a duplicate data contract.""" + + def __init__(self, table_id: int): + self.table_id = table_id + super().__init__(f"A data contract already exists for table_id={table_id}") diff --git a/flowfile_core/flowfile_core/catalog/repository.py b/flowfile_core/flowfile_core/catalog/repository.py index 87d353afa..6fcba1154 100644 --- a/flowfile_core/flowfile_core/catalog/repository.py +++ b/flowfile_core/flowfile_core/catalog/repository.py @@ -15,6 +15,7 @@ CatalogNamespace, CatalogTable, CatalogTableReadLink, + DataContract, FlowFavorite, FlowFollow, FlowRegistration, @@ -245,6 +246,18 @@ def set_trigger_table_ids(self, schedule_id: int, table_ids: list[int]) -> None: def delete_trigger_table_ids(self, schedule_id: int) -> None: ... + # -- Data Contracts ------------------------------------------------------ + + def get_contract_by_table(self, table_id: int) -> DataContract | None: ... + + def create_contract(self, contract: DataContract) -> DataContract: ... + + def update_contract(self, contract: DataContract) -> DataContract: ... + + def delete_contract(self, table_id: int) -> None: ... + + def list_contracts_for_tables(self, table_ids: list[int]) -> dict[int, DataContract]: ... + # --------------------------------------------------------------------------- # SQLAlchemy implementation @@ -603,9 +616,7 @@ def remove_table_favorite(self, user_id: int, table_id: int) -> None: self._db.commit() def list_table_favorites(self, user_id: int) -> list[TableFavorite]: - return ( - self._db.query(TableFavorite).filter_by(user_id=user_id).order_by(TableFavorite.created_at.desc()).all() - ) + return self._db.query(TableFavorite).filter_by(user_id=user_id).order_by(TableFavorite.created_at.desc()).all() def count_table_favorites(self, user_id: int) -> int: return self._db.query(TableFavorite).filter_by(user_id=user_id).count() @@ -826,12 +837,7 @@ def count_schedules(self) -> int: def list_active_runs(self) -> list[FlowRun]: """Return runs that have not yet ended (ended_at IS NULL).""" - return ( - self._db.query(FlowRun) - .filter(FlowRun.ended_at.is_(None)) - .order_by(FlowRun.started_at.desc()) - .all() - ) + return self._db.query(FlowRun).filter(FlowRun.ended_at.is_(None)).order_by(FlowRun.started_at.desc()).all() def has_active_run(self, registration_id: int) -> bool: """Check if a flow already has an active (unfinished) run.""" @@ -873,9 +879,7 @@ def list_table_trigger_schedules_for_table(self, table_id: int) -> list[FlowSche def get_trigger_table_ids(self, schedule_id: int) -> list[int]: """Return table IDs linked to a table_set_trigger schedule.""" rows = ( - self._db.query(ScheduleTriggerTable.table_id) - .filter(ScheduleTriggerTable.schedule_id == schedule_id) - .all() + self._db.query(ScheduleTriggerTable.table_id).filter(ScheduleTriggerTable.schedule_id == schedule_id).all() ) return [r[0] for r in rows] @@ -890,3 +894,29 @@ def delete_trigger_table_ids(self, schedule_id: int) -> None: """Remove all trigger table links for a schedule.""" self._db.query(ScheduleTriggerTable).filter_by(schedule_id=schedule_id).delete() self._db.commit() + + # -- Data Contracts ------------------------------------------------------ + + def get_contract_by_table(self, table_id: int) -> DataContract | None: + return self._db.query(DataContract).filter_by(table_id=table_id).first() + + def create_contract(self, contract: DataContract) -> DataContract: + self._db.add(contract) + self._db.commit() + self._db.refresh(contract) + return contract + + def update_contract(self, contract: DataContract) -> DataContract: + self._db.commit() + self._db.refresh(contract) + return contract + + def delete_contract(self, table_id: int) -> None: + self._db.query(DataContract).filter_by(table_id=table_id).delete() + self._db.commit() + + def list_contracts_for_tables(self, table_ids: list[int]) -> dict[int, DataContract]: + if not table_ids: + return {} + rows = self._db.query(DataContract).filter(DataContract.table_id.in_(table_ids)).all() + return {c.table_id: c for c in rows} diff --git a/flowfile_core/flowfile_core/catalog/service.py b/flowfile_core/flowfile_core/catalog/service.py index a6aec3c83..f70ce0c31 100644 --- a/flowfile_core/flowfile_core/catalog/service.py +++ b/flowfile_core/flowfile_core/catalog/service.py @@ -27,6 +27,8 @@ table_exists, ) from flowfile_core.catalog.exceptions import ( + ContractExistsError, + ContractNotFoundError, FavoriteNotFoundError, FlowAlreadyRunningError, FlowHasArtifactsError, @@ -47,6 +49,7 @@ from flowfile_core.database.models import ( CatalogNamespace, CatalogTable, + DataContract, FlowFavorite, FlowFollow, FlowRegistration, @@ -1978,3 +1981,174 @@ def get_catalog_stats(self, user_id: int) -> CatalogStats: favorite_tables=fav_tables, active_runs=active_runs, ) + + # ======================================================================== + # Data Contracts + # ======================================================================== + + def get_contract(self, table_id: int) -> DataContract | None: + """Return the contract for *table_id*, or ``None``.""" + return self._repo.get_contract_by_table(table_id) + + def create_contract( + self, + table_id: int, + name: str, + owner_id: int, + definition_json: str, + description: str | None = None, + status: str = "draft", + ) -> DataContract: + """Create a new contract for a catalog table.""" + table = self._repo.get_table(table_id) + if table is None: + raise TableNotFoundError(table_id) + existing = self._repo.get_contract_by_table(table_id) + if existing is not None: + raise ContractExistsError(table_id) + contract = DataContract( + table_id=table_id, + name=name, + description=description, + definition_json=definition_json, + owner_id=owner_id, + status=status, + version=1, + ) + return self._repo.create_contract(contract) + + def update_contract( + self, + table_id: int, + name: str | None = None, + description: str | None = None, + definition_json: str | None = None, + status: str | None = None, + ) -> DataContract: + """Update an existing contract. Bumps the version number.""" + contract = self._repo.get_contract_by_table(table_id) + if contract is None: + raise ContractNotFoundError(table_id) + if name is not None: + contract.name = name + if description is not None: + contract.description = description + if definition_json is not None: + contract.definition_json = definition_json + if status is not None: + contract.status = status + contract.version = (contract.version or 1) + 1 + return self._repo.update_contract(contract) + + def delete_contract(self, table_id: int) -> None: + """Delete the contract attached to *table_id*.""" + contract = self._repo.get_contract_by_table(table_id) + if contract is None: + raise ContractNotFoundError(table_id) + self._repo.delete_contract(table_id) + + def mark_contract_validated( + self, + table_id: int, + passed: bool, + version: int | None = None, + ) -> DataContract: + """Set the validation state on the contract. + + If *version* is not supplied, the current Delta version is read + from the table storage. + """ + contract = self._repo.get_contract_by_table(table_id) + if contract is None: + raise ContractNotFoundError(table_id) + + if version is None: + table = self._repo.get_table(table_id) + if table is not None and is_delta_table(table.file_path): + dt = DeltaTable(table.file_path) + version = dt.version() + + contract.last_validated_version = version + contract.last_validated_at = datetime.now(timezone.utc) + contract.last_validation_passed = passed + return self._repo.update_contract(contract) + + def generate_contract_from_table(self, table_id: int) -> str: + """Profile a catalog table and return a proposed ``DataContractDefinition`` as JSON. + + Uses column metadata (dtype, null count, unique count) to infer rules. + """ + from flowfile_core.schemas.contract_schema import ( + ColumnContract, + DataContractDefinition, + DtypeRule, + NotNullRule, + ) + + table = self._repo.get_table(table_id) + if table is None: + raise TableNotFoundError(table_id) + + columns: list[ColumnContract] = [] + if table.schema_json: + schema = json.loads(table.schema_json) + for col_info in schema: + rules = [] + col_name = col_info.get("name", "") + col_dtype = col_info.get("dtype") + if col_dtype: + rules.append(DtypeRule(expected_dtype=col_dtype)) + rules.append(NotNullRule()) + columns.append(ColumnContract(name=col_name, rules=rules)) + + definition = DataContractDefinition(columns=columns, allow_extra_columns=False) + return definition.model_dump_json() + + def get_contract_summary(self, table_id: int) -> dict | None: + """Build a lightweight contract summary for a catalog table. + + Returns ``None`` if no contract exists. + """ + from flowfile_core.schemas.contract_schema import ContractSummary + + contract = self._repo.get_contract_by_table(table_id) + if contract is None: + return None + + # Count rules + from flowfile_core.schemas.contract_schema import DataContractDefinition + + try: + definition = DataContractDefinition.model_validate_json(contract.definition_json) + except Exception: + definition = DataContractDefinition() + rule_count = sum(len(c.rules) for c in definition.columns) + len(definition.general_rules) + + # Determine status + if contract.status == "draft": + status = "draft" + elif contract.last_validation_passed is False: + status = "failed" + elif contract.last_validated_version is not None: + # Try to compare with current Delta version + table = self._repo.get_table(table_id) + current_version = None + if table is not None and is_delta_table(table.file_path): + try: + dt = DeltaTable(table.file_path) + current_version = dt.version() + except Exception: + pass + if current_version is not None and contract.last_validated_version >= current_version: + status = "validated" + else: + status = "stale" + else: + status = "stale" + + return ContractSummary( + status=status, + last_validated_version=contract.last_validated_version, + current_version=None, # populated by caller if needed + rule_count=rule_count, + ).model_dump() diff --git a/flowfile_core/flowfile_core/configs/node_store/nodes.py b/flowfile_core/flowfile_core/configs/node_store/nodes.py index c339a391a..ae15af5dc 100644 --- a/flowfile_core/flowfile_core/configs/node_store/nodes.py +++ b/flowfile_core/flowfile_core/configs/node_store/nodes.py @@ -248,6 +248,18 @@ def get_all_standard_nodes() -> tuple[list[NodeTemplate], dict[str, NodeTemplate drawer_title="Count Records", drawer_intro="Calculate the total number of rows", ), + NodeTemplate( + name="Data validation", + item="data_validation", + input=1, + output=1, + transform_type="narrow", + node_type="process", + image="data_validation.png", + node_group="quality", + drawer_title="Validate Data", + drawer_intro="Check data against validation rules and add pass/fail columns", + ), NodeTemplate( name="Cross join", item="cross_join", diff --git a/flowfile_core/flowfile_core/database/models.py b/flowfile_core/flowfile_core/database/models.py index cf2b5058b..cf5546533 100644 --- a/flowfile_core/flowfile_core/database/models.py +++ b/flowfile_core/flowfile_core/database/models.py @@ -345,6 +345,39 @@ class CatalogTable(Base): __table_args__ = (UniqueConstraint("name", "namespace_id", name="uq_catalog_table_name_ns"),) +class DataContract(Base): + """A data contract attached to a catalog table (1:1). + + Stores typed validation rules as a JSON blob (serialized + ``DataContractDefinition``). The validation state tracks whether the + table's current Delta version has been validated successfully. + """ + + __tablename__ = "data_contracts" + + id = Column(Integer, primary_key=True, index=True) + table_id = Column(Integer, ForeignKey("catalog_tables.id"), nullable=False, unique=True) + name = Column(String, nullable=False) + description = Column(Text, nullable=True) + + # Serialized DataContractDefinition (JSON) + definition_json = Column(Text, nullable=False) + + # Validation state + last_validated_version = Column(Integer, nullable=True) + last_validated_at = Column(DateTime, nullable=True) + last_validation_passed = Column(Boolean, nullable=True) + + # Metadata + owner_id = Column(Integer, ForeignKey("users.id"), nullable=False) + status = Column(String, nullable=False, default="draft") # "draft" | "active" | "archived" + version = Column(Integer, nullable=False, default=1) + created_at = Column(DateTime, default=func.now(), nullable=False) + updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), nullable=False) + + __table_args__ = (UniqueConstraint("table_id", name="uq_data_contract_table"),) + + class CatalogTableReadLink(Base): """Tracks which registered flows read from which catalog tables.""" @@ -396,5 +429,3 @@ class KafkaConnection(Base): sasl_password = relationship("Secret", foreign_keys=[sasl_password_id], lazy="joined") ssl_key = relationship("Secret", foreign_keys=[ssl_key_id], lazy="joined") - - diff --git a/flowfile_core/flowfile_core/flowfile/flow_data_engine/validation.py b/flowfile_core/flowfile_core/flowfile/flow_data_engine/validation.py new file mode 100644 index 000000000..eb0ea06d7 --- /dev/null +++ b/flowfile_core/flowfile_core/flowfile/flow_data_engine/validation.py @@ -0,0 +1,244 @@ +"""Validation engine for data contracts and the data_validation node. + +Each rule provides its own Flowfile expression via ``to_expression()``. +This module collects those expressions, converts them to Polars expressions +via ``to_expr()`` (``polars_expr_transformer.simple_function_to_expr``), and +evaluates them against a ``LazyFrame``. + +Shared between: +- The ``data_validation`` flow node (appends ``_is_valid`` / ``_violations`` columns) +- The catalog contract validation endpoint (returns a ``ValidationResult``) +""" + +from __future__ import annotations + +import logging + +import polars as pl +from polars_expr_transformer import simple_function_to_expr as to_expr + +from flowfile_core.schemas.contract_schema import ( + DataContractDefinition, + DtypeRule, + RuleResult, + UniqueRule, + ValidationResult, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Expression collection +# --------------------------------------------------------------------------- + + +def _build_column_check(col_name: str, rule) -> tuple[str, pl.Expr | None]: + """Return ``(rule_name, pl.Expr | None)`` for a single column rule. + + Most rules produce a Flowfile expression string that is parsed via + ``to_expr()``. Special cases (``DtypeRule``, ``UniqueRule``) are + handled with native Polars operations. + """ + rule_name = f"{col_name}:{rule.rule_type}" + + # DtypeRule is a schema-level check — cannot be a row-level expression + if isinstance(rule, DtypeRule): + return (rule_name, None) + + # UniqueRule needs native Polars (is_unique is not in polars_expr_transformer) + if isinstance(rule, UniqueRule): + return (rule_name, pl.col(col_name).is_unique()) + + # All other rules produce a Flowfile expression string + ff_expr = rule.to_expression(col_name) + try: + return (rule_name, to_expr(ff_expr)) + except Exception: + logger.warning("Failed to parse Flowfile expression: %s", ff_expr) + return (rule_name, None) + + +def collect_checks(definition: DataContractDefinition) -> list[tuple[str, pl.Expr]]: + """Collect ``(rule_name, pl.Expr)`` pairs from the contract definition. + + Skips rules that cannot be represented as row-level expressions (e.g. + ``DtypeRule``). + """ + checks: list[tuple[str, pl.Expr]] = [] + for col in definition.columns: + for rule in col.rules: + name, expr = _build_column_check(col.name, rule) + if expr is not None: + checks.append((name, expr)) + + for rule in definition.general_rules: + try: + checks.append((rule.name, to_expr(rule.to_expression()))) + except Exception: + logger.warning("Failed to parse general rule expression: %s", rule.expression) + + return checks + + +# --------------------------------------------------------------------------- +# Node-level: append validation columns to a LazyFrame +# --------------------------------------------------------------------------- + + +def apply_validation( + lf: pl.LazyFrame, + definition: DataContractDefinition, + add_validity: bool = True, + add_details: bool = False, +) -> pl.LazyFrame: + """Append ``_is_valid`` and optionally ``_violations`` columns to *lf*.""" + checks = collect_checks(definition) + + if not checks: + if add_validity: + lf = lf.with_columns(pl.lit(True).alias("_is_valid")) + if add_details: + lf = lf.with_columns(pl.lit("").alias("_violations")) + return lf + + if add_validity: + validity_expr = pl.lit(True) + for _, expr in checks: + validity_expr = validity_expr & expr + lf = lf.with_columns(validity_expr.alias("_is_valid")) + + if add_details: + violation_parts = [pl.when(~expr).then(pl.lit(name)).otherwise(pl.lit("")) for name, expr in checks] + if violation_parts: + lf = lf.with_columns(pl.concat_str(violation_parts, separator=",").alias("_violations")) + else: + lf = lf.with_columns(pl.lit("").alias("_violations")) + + return lf + + +# --------------------------------------------------------------------------- +# Contract-level: full validation returning a ValidationResult +# --------------------------------------------------------------------------- + + +def _validate_schema(lf: pl.LazyFrame, definition: DataContractDefinition) -> list[RuleResult]: + """Check column existence, dtype, and extra-column rules against the schema.""" + results: list[RuleResult] = [] + actual_schema = dict(lf.collect_schema()) + + for col in definition.columns: + if col.name not in actual_schema: + results.append( + RuleResult( + rule_name=f"{col.name}:exists", + column=col.name, + passed=False, + message=f"Column '{col.name}' missing", + ) + ) + continue + + # Check DtypeRule(s) on this column + for rule in col.rules: + if isinstance(rule, DtypeRule): + actual_dtype = str(actual_schema[col.name]) + passed = actual_dtype == rule.expected_dtype + results.append( + RuleResult( + rule_name=f"{col.name}:dtype", + column=col.name, + passed=passed, + message=f"Expected {rule.expected_dtype}, got {actual_dtype}" if not passed else "", + ) + ) + + if not definition.allow_extra_columns: + expected_names = {c.name for c in definition.columns} + for name in actual_schema: + if name not in expected_names: + results.append( + RuleResult( + rule_name=f"{name}:unexpected", + column=name, + passed=False, + message=f"Unexpected column '{name}'", + ) + ) + + return results + + +def validate_contract(lf: pl.LazyFrame, definition: DataContractDefinition) -> ValidationResult: + """Run full contract validation: schema checks + row-level rule evaluation.""" + all_results: list[RuleResult] = [] + + # 1. Schema-level checks + all_results.extend(_validate_schema(lf, definition)) + + # 2. Row-level checks + checks = collect_checks(definition) + if checks: + # Build a frame that evaluates each check and counts violations + check_exprs = [expr.alias(name) for name, expr in checks] + try: + evaluated = lf.select([pl.len().alias("_total")] + check_exprs).collect() + total_rows = evaluated["_total"][0] + + for name, _ in checks: + col_data = evaluated[name] + violation_count = int(col_data.not_().sum()) + passed = violation_count == 0 + col_name = name.split(":")[0] if ":" in name else None + all_results.append( + RuleResult( + rule_name=name, + column=col_name, + passed=passed, + violation_count=violation_count, + message=f"{violation_count} violation(s)" if not passed else "", + ) + ) + except Exception as e: + logger.warning("Error evaluating row-level checks: %s", e) + total_rows = 0 + else: + try: + total_rows = lf.select(pl.len()).collect().item() + except Exception: + total_rows = 0 + + # 3. Row count checks + if definition.row_count: + rc = definition.row_count + if rc.min_rows is not None: + passed = total_rows >= rc.min_rows + all_results.append( + RuleResult( + rule_name="table:min_row_count", + passed=passed, + violation_count=0 if passed else 1, + message="" if passed else f"Expected >= {rc.min_rows} rows, got {total_rows}", + ) + ) + if rc.max_rows is not None: + passed = total_rows <= rc.max_rows + all_results.append( + RuleResult( + rule_name="table:max_row_count", + passed=passed, + violation_count=0 if passed else 1, + message="" if passed else f"Expected <= {rc.max_rows} rows, got {total_rows}", + ) + ) + + overall_passed = all(r.passed for r in all_results) + valid_rows = total_rows - max((r.violation_count for r in all_results), default=0) + + return ValidationResult( + passed=overall_passed, + rule_results=all_results, + total_rows=total_rows, + valid_rows=max(valid_rows, 0), + ) diff --git a/flowfile_core/flowfile_core/flowfile/flow_graph.py b/flowfile_core/flowfile_core/flowfile/flow_graph.py index 769f496e1..067db9ff2 100644 --- a/flowfile_core/flowfile_core/flowfile/flow_graph.py +++ b/flowfile_core/flowfile_core/flowfile/flow_graph.py @@ -1280,6 +1280,36 @@ def _func(fl: FlowDataEngine) -> FlowDataEngine: input_node_ids=[node_number_of_records.depending_on_id], ) + def add_data_validation(self, settings: input_schema.NodeDataValidation): + """Adds a validation node that checks data against typed rules. + + Passes data through and appends ``_is_valid`` / ``_violations`` columns. + """ + + def _func(fl: FlowDataEngine) -> FlowDataEngine: + from flowfile_core.flowfile.flow_data_engine.validation import apply_validation + from flowfile_core.schemas.contract_schema import DataContractDefinition + + definition = settings.definition + if not isinstance(definition, DataContractDefinition): + definition = DataContractDefinition.model_validate(definition) + result_lf = apply_validation( + fl.data_frame, + definition, + add_validity=settings.add_validity_column, + add_details=settings.add_violation_details, + ) + return FlowDataEngine(result_lf) + + self.add_node_step( + node_id=settings.node_id, + function=_func, + node_type="data_validation", + renew_schema=True, + setting_input=settings, + input_node_ids=[settings.depending_on_id], + ) + @with_history_capture(HistoryActionType.UPDATE_SETTINGS) def add_polars_code(self, node_polars_code: input_schema.NodePolarsCode): """Adds a node that executes custom Polars code. @@ -2411,8 +2441,11 @@ def _func(): # Store deferred commit info for post-stage commit if kafka_result.messages_consumed > 0: node._on_flow_complete = make_kafka_commit_callback( - kafka_read_settings, kafka_result.new_offsets, node_kafka_source.node_id, - self.flow_logger, _decrypt_fn, + kafka_read_settings, + kafka_result.new_offsets, + node_kafka_source.node_id, + self.flow_logger, + _decrypt_fn, ) else: # Remote execution — offload to worker (worker uses commit=False + sidecar) @@ -2423,8 +2456,11 @@ def _func(): offsets_data = fetch_kafka_offsets(external_kafka_fetcher.file_ref) if offsets_data and offsets_data.get("messages_consumed", 0) > 0: node._on_flow_complete = make_kafka_commit_callback( - kafka_read_settings, offsets_data["new_offsets"], node_kafka_source.node_id, - self.flow_logger, _decrypt_fn, + kafka_read_settings, + offsets_data["new_offsets"], + node_kafka_source.node_id, + self.flow_logger, + _decrypt_fn, ) # The worker DataFrame may have fewer columns than the inferred # schema (e.g. empty topic or starting at "latest"). Align to @@ -3265,9 +3301,7 @@ def _run_post_execution_callbacks( try: callback(success) except Exception as e: - self.flow_logger.error( - f"Post-execution callback failed for node {n.node_id}: {e}" - ) + self.flow_logger.error(f"Post-execution callback failed for node {n.node_id}: {e}") n._on_flow_complete = None def run_graph(self) -> RunInformation | None: diff --git a/flowfile_core/flowfile_core/routes/catalog.py b/flowfile_core/flowfile_core/routes/catalog.py index 6a1998f4e..7d360e7b9 100644 --- a/flowfile_core/flowfile_core/routes/catalog.py +++ b/flowfile_core/flowfile_core/routes/catalog.py @@ -21,6 +21,8 @@ from flowfile_core.auth.jwt import get_current_active_user from flowfile_core.catalog import ( CatalogService, + ContractExistsError, + ContractNotFoundError, FavoriteNotFoundError, FlowAlreadyRunningError, FlowHasArtifactsError, @@ -849,3 +851,187 @@ async def scheduler_stop(current_user=Depends(get_current_active_user)): await scheduler.stop() set_scheduler(None) return {"message": "Scheduler stopped"} + + +# --------------------------------------------------------------------------- +# Data Contracts +# --------------------------------------------------------------------------- + + +@router.get("/tables/{table_id}/contract") +def get_contract( + table_id: int, + service: CatalogService = Depends(get_catalog_service), +): + """Get the data contract for a catalog table.""" + from flowfile_core.schemas.contract_schema import DataContractDefinition, DataContractOut + + contract = service.get_contract(table_id) + if contract is None: + raise HTTPException(404, "No contract found for this table") + definition = DataContractDefinition.model_validate_json(contract.definition_json) + return DataContractOut( + id=contract.id, + table_id=contract.table_id, + name=contract.name, + description=contract.description, + definition=definition, + last_validated_version=contract.last_validated_version, + last_validated_at=contract.last_validated_at, + last_validation_passed=contract.last_validation_passed, + status=contract.status, + version=contract.version, + owner_id=contract.owner_id, + created_at=contract.created_at, + updated_at=contract.updated_at, + ) + + +@router.post("/tables/{table_id}/contract", status_code=201) +def create_contract( + table_id: int, + current_user=Depends(get_current_active_user), + service: CatalogService = Depends(get_catalog_service), +): + """Create a data contract for a catalog table.""" + from flowfile_core.schemas.contract_schema import DataContractCreate, DataContractDefinition, DataContractOut + + body = DataContractCreate(table_id=table_id, name=f"Contract for table {table_id}") + try: + contract = service.create_contract( + table_id=table_id, + name=body.name, + owner_id=current_user.id, + definition_json=body.definition.model_dump_json(), + description=body.description, + status=body.status, + ) + except TableNotFoundError: + raise HTTPException(404, "Catalog table not found") from None + except ContractExistsError: + raise HTTPException(409, "A contract already exists for this table") from None + definition = DataContractDefinition.model_validate_json(contract.definition_json) + return DataContractOut( + id=contract.id, + table_id=contract.table_id, + name=contract.name, + description=contract.description, + definition=definition, + last_validated_version=contract.last_validated_version, + last_validated_at=contract.last_validated_at, + last_validation_passed=contract.last_validation_passed, + status=contract.status, + version=contract.version, + owner_id=contract.owner_id, + created_at=contract.created_at, + updated_at=contract.updated_at, + ) + + +@router.put("/tables/{table_id}/contract") +def update_contract( + table_id: int, + service: CatalogService = Depends(get_catalog_service), +): + """Update the data contract for a catalog table.""" + from flowfile_core.schemas.contract_schema import DataContractDefinition, DataContractOut, DataContractUpdate + + body = DataContractUpdate() + try: + contract = service.update_contract( + table_id=table_id, + name=body.name, + description=body.description, + definition_json=body.definition.model_dump_json() if body.definition else None, + status=body.status, + ) + except ContractNotFoundError: + raise HTTPException(404, "No contract found for this table") from None + definition = DataContractDefinition.model_validate_json(contract.definition_json) + return DataContractOut( + id=contract.id, + table_id=contract.table_id, + name=contract.name, + description=contract.description, + definition=definition, + last_validated_version=contract.last_validated_version, + last_validated_at=contract.last_validated_at, + last_validation_passed=contract.last_validation_passed, + status=contract.status, + version=contract.version, + owner_id=contract.owner_id, + created_at=contract.created_at, + updated_at=contract.updated_at, + ) + + +@router.delete("/tables/{table_id}/contract", status_code=204) +def delete_contract( + table_id: int, + service: CatalogService = Depends(get_catalog_service), +): + """Delete the data contract for a catalog table.""" + try: + service.delete_contract(table_id) + except ContractNotFoundError: + raise HTTPException(404, "No contract found for this table") from None + + +@router.post("/tables/{table_id}/contract/generate") +def generate_contract( + table_id: int, + service: CatalogService = Depends(get_catalog_service), +): + """Auto-generate a contract definition from the table's current data profile.""" + from flowfile_core.schemas.contract_schema import DataContractDefinition + + try: + definition_json = service.generate_contract_from_table(table_id) + except TableNotFoundError: + raise HTTPException(404, "Catalog table not found") from None + return DataContractDefinition.model_validate_json(definition_json) + + +@router.post("/tables/{table_id}/contract/validate") +def validate_contract( + table_id: int, + service: CatalogService = Depends(get_catalog_service), +): + """Run validation of the table against its contract. Returns a ValidationResult.""" + from flowfile_core.flowfile.flow_data_engine.validation import validate_contract as run_validation + from flowfile_core.schemas.contract_schema import DataContractDefinition + + contract = service.get_contract(table_id) + if contract is None: + raise HTTPException(404, "No contract found for this table") + + # Get the raw table record for file_path + table_record = service._repo.get_table(table_id) + if table_record is None: + raise HTTPException(404, "Catalog table not found") + + definition = DataContractDefinition.model_validate_json(contract.definition_json) + + import polars as pl + + try: + lf = pl.scan_delta(table_record.file_path) + except Exception: + lf = pl.scan_parquet(table_record.file_path) + + result = run_validation(lf, definition) + return result + + +@router.put("/tables/{table_id}/contract/mark-validated") +def mark_contract_validated( + table_id: int, + passed: bool = Query(...), + service: CatalogService = Depends(get_catalog_service), +): + """Mark the contract as validated (or failed) at the current Delta version.""" + try: + service.mark_contract_validated(table_id=table_id, passed=passed) + except ContractNotFoundError: + raise HTTPException(404, "No contract found for this table") from None + return {"status": "ok"} diff --git a/flowfile_core/flowfile_core/schemas/catalog_schema.py b/flowfile_core/flowfile_core/schemas/catalog_schema.py index e9f27766d..a49ab9916 100644 --- a/flowfile_core/flowfile_core/schemas/catalog_schema.py +++ b/flowfile_core/flowfile_core/schemas/catalog_schema.py @@ -228,6 +228,7 @@ class CatalogTableOut(BaseModel): source_registration_name: str | None = None source_run_id: int | None = None read_by_flows: list["FlowSummary"] = Field(default_factory=list) + contract: dict | None = None created_at: datetime updated_at: datetime diff --git a/flowfile_core/flowfile_core/schemas/contract_schema.py b/flowfile_core/flowfile_core/schemas/contract_schema.py new file mode 100644 index 000000000..efd09e389 --- /dev/null +++ b/flowfile_core/flowfile_core/schemas/contract_schema.py @@ -0,0 +1,236 @@ +"""Pydantic schemas for the Data Contract system. + +Defines typed rule models for column-level and table-level validation, +the contract definition, and API request/response schemas. + +Each column rule provides a `to_expression(column)` method that returns +a Flowfile expression string. The validation engine collects these +expressions and evaluates them via `to_expr()` (polars_expr_transformer). +""" + +from __future__ import annotations + +from abc import abstractmethod +from datetime import datetime +from typing import Annotated, Literal + +from pydantic import BaseModel, Field + +from flowfile_core.types import DataTypeStr + +# ==================== Column-Level Rule Types ==================== + + +class BaseColumnRule(BaseModel): + """Base for all column-level rules. + + Each subclass must implement ``to_expression(column)`` which returns a + Flowfile expression string for the given column name. + """ + + @abstractmethod + def to_expression(self, column: str) -> str: + """Return a Flowfile expression string for this rule applied to *column*. + + Examples: ``'is_not_empty([age])'``, ``'([age]>=0) & ([age]<=150)'`` + """ + ... + + +class NotNullRule(BaseColumnRule): + rule_type: Literal["not_null"] = "not_null" + + def to_expression(self, column: str) -> str: + return f"is_not_empty([{column}])" + + +class UniqueRule(BaseColumnRule): + rule_type: Literal["unique"] = "unique" + + def to_expression(self, column: str) -> str: + # is_unique is not a standard Flowfile expression function; + # the validation engine handles this rule via Polars directly. + return f"is_unique([{column}])" + + +class DtypeRule(BaseColumnRule): + rule_type: Literal["dtype"] = "dtype" + expected_dtype: DataTypeStr + + def to_expression(self, column: str) -> str: + # dtype checks are schema-level, not row-level. + # The validation engine handles this separately. + return f'dtype([{column}]) = "{self.expected_dtype}"' + + +class ValueRangeRule(BaseColumnRule): + rule_type: Literal["value_range"] = "value_range" + min_value: float | None = None + max_value: float | None = None + + def to_expression(self, column: str) -> str: + parts: list[str] = [] + if self.min_value is not None: + parts.append(f"([{column}]>={self.min_value})") + if self.max_value is not None: + parts.append(f"([{column}]<={self.max_value})") + return " & ".join(parts) if parts else "true" + + +class AllowedValuesRule(BaseColumnRule): + rule_type: Literal["allowed_values"] = "allowed_values" + values: list[str] + + def to_expression(self, column: str) -> str: + if len(self.values) == 1: + return f'[{column}]="{self.values[0]}"' + conditions = [f'([{column}]="{v}")' for v in self.values] + return " | ".join(conditions) + + +class RegexRule(BaseColumnRule): + rule_type: Literal["regex"] = "regex" + pattern: str + + def to_expression(self, column: str) -> str: + return f'contains([{column}], "{self.pattern}")' + + +class CustomExpressionRule(BaseColumnRule): + """Free-form Flowfile expression. + + The user writes the expression directly using Flowfile syntax. + """ + + rule_type: Literal["custom_expression"] = "custom_expression" + expression: str + + def to_expression(self, column: str) -> str: + return self.expression + + +# Discriminated union — Pydantic picks the right model by ``rule_type`` +ColumnRule = Annotated[ + NotNullRule | UniqueRule | DtypeRule | ValueRangeRule | AllowedValuesRule | RegexRule | CustomExpressionRule, + Field(discriminator="rule_type"), +] + + +# ==================== Column Contract ==================== + + +class ColumnContract(BaseModel): + """A column in the contract: its identity plus a list of typed rules.""" + + name: str + rules: list[ColumnRule] = Field(default_factory=list) + + def get_expressions(self) -> list[tuple[str, str]]: + """Return ``(rule_name, flowfile_expression)`` for each rule on this column.""" + return [(f"{self.name}:{rule.rule_type}", rule.to_expression(self.name)) for rule in self.rules] + + +# ==================== General (Table-Level) Rules ==================== + + +class GeneralRule(BaseModel): + """A table-level / cross-column rule written as a Flowfile expression. + + Examples:: + + GeneralRule(name="start_before_end", expression="[start_date]<=[end_date]") + GeneralRule(name="has_contact", expression="is_not_empty([email]) | is_not_empty([phone])") + """ + + name: str + expression: str + description: str | None = None + + def to_expression(self) -> str: + return self.expression + + +class RowCountRule(BaseModel): + """Convenience typed rule for row count bounds.""" + + min_rows: int | None = None + max_rows: int | None = None + + +# ==================== Contract Definition ==================== + + +class DataContractDefinition(BaseModel): + """The full contract definition: columns with rules + general rules.""" + + columns: list[ColumnContract] = Field(default_factory=list) + allow_extra_columns: bool = False + general_rules: list[GeneralRule] = Field(default_factory=list) + row_count: RowCountRule | None = None + + +# ==================== Validation Result Models ==================== + + +class RuleResult(BaseModel): + """Result of evaluating a single validation rule.""" + + rule_name: str + column: str | None = None + passed: bool + violation_count: int = 0 + message: str = "" + + +class ValidationResult(BaseModel): + """Aggregated result of all contract rules.""" + + passed: bool + rule_results: list[RuleResult] = Field(default_factory=list) + total_rows: int = 0 + valid_rows: int = 0 + + +# ==================== API Request / Response Schemas ==================== + + +class DataContractCreate(BaseModel): + table_id: int + name: str + description: str | None = None + definition: DataContractDefinition = Field(default_factory=DataContractDefinition) + status: Literal["draft", "active"] = "draft" + + +class DataContractUpdate(BaseModel): + name: str | None = None + description: str | None = None + definition: DataContractDefinition | None = None + status: Literal["draft", "active", "archived"] | None = None + + +class DataContractOut(BaseModel): + id: int + table_id: int + name: str + description: str | None + definition: DataContractDefinition + last_validated_version: int | None + last_validated_at: datetime | None + last_validation_passed: bool | None + status: str + version: int + owner_id: int + created_at: datetime + updated_at: datetime + + model_config = {"from_attributes": True} + + +class ContractSummary(BaseModel): + """Lightweight contract info attached to CatalogTableOut.""" + + status: Literal["validated", "stale", "failed", "draft", "none"] + last_validated_version: int | None = None + current_version: int | None = None + rule_count: int = 0 diff --git a/flowfile_core/flowfile_core/schemas/input_schema.py b/flowfile_core/flowfile_core/schemas/input_schema.py index 725485fd8..f972b1f25 100644 --- a/flowfile_core/flowfile_core/schemas/input_schema.py +++ b/flowfile_core/flowfile_core/schemas/input_schema.py @@ -15,7 +15,7 @@ model_validator, ) -from flowfile_core.schemas import transform_schema +from flowfile_core.schemas import contract_schema, transform_schema from flowfile_core.schemas.analysis_schemas import graphic_walker_schemas as gs_schemas from flowfile_core.schemas.cloud_storage_schemas import CloudStorageReadSettings, CloudStorageWriteSettings from flowfile_core.schemas.yaml_types import ( @@ -1284,6 +1284,23 @@ def validate_output_names(cls, v: list[str]) -> list[str]: return _validate_output_names(v) +class NodeDataValidation(NodeSingleInput): + """Settings for a validation node that checks data against typed rules. + + Uses the same ``DataContractDefinition`` model as catalog data contracts. + Passes data through and appends ``_is_valid`` / ``_violations`` columns. + """ + + definition: "contract_schema.DataContractDefinition" + add_validity_column: bool = True + add_violation_details: bool = False + + def get_default_description(self) -> str: + n_col = len(self.definition.columns) + n_gen = len(self.definition.general_rules) + return f"Validate {n_col} column(s), {n_gen} rule(s)" + + class UserDefinedNode(NodeMultiInput): """Settings for a node that contains the user defined node information""" diff --git a/flowfile_core/flowfile_core/schemas/schemas.py b/flowfile_core/flowfile_core/schemas/schemas.py index d7531d577..f0c2309d8 100644 --- a/flowfile_core/flowfile_core/schemas/schemas.py +++ b/flowfile_core/flowfile_core/schemas/schemas.py @@ -48,6 +48,7 @@ "external_source": input_schema.NodeExternalSource, "promise": input_schema.NodePromise, "user_defined": input_schema.UserDefinedNode, + "data_validation": input_schema.NodeDataValidation, } diff --git a/flowfile_core/tests/test_data_contracts.py b/flowfile_core/tests/test_data_contracts.py new file mode 100644 index 000000000..489853468 --- /dev/null +++ b/flowfile_core/tests/test_data_contracts.py @@ -0,0 +1,297 @@ +"""Tests for data contract models and validation engine.""" + +import polars as pl +import pytest + +from flowfile_core.schemas.contract_schema import ( + AllowedValuesRule, + ColumnContract, + CustomExpressionRule, + DataContractDefinition, + DtypeRule, + GeneralRule, + NotNullRule, + RegexRule, + RowCountRule, + UniqueRule, + ValidationResult, + ValueRangeRule, +) + + +# --------------------------------------------------------------------------- +# Rule to_expression tests +# --------------------------------------------------------------------------- + + +class TestRuleExpressions: + def test_not_null_expression(self): + rule = NotNullRule() + assert rule.to_expression("age") == "is_not_empty([age])" + + def test_unique_expression(self): + rule = UniqueRule() + assert rule.to_expression("id") == "is_unique([id])" + + def test_dtype_expression(self): + rule = DtypeRule(expected_dtype="Int64") + expr = rule.to_expression("age") + assert "Int64" in expr + assert "[age]" in expr + + def test_value_range_both_bounds(self): + rule = ValueRangeRule(min_value=0, max_value=150) + expr = rule.to_expression("age") + assert "([age]>=0" in expr + assert "([age]<=150" in expr + assert "&" in expr + + def test_value_range_min_only(self): + rule = ValueRangeRule(min_value=0) + expr = rule.to_expression("age") + assert "([age]>=0" in expr + assert "<=" not in expr + + def test_value_range_max_only(self): + rule = ValueRangeRule(max_value=100) + expr = rule.to_expression("count") + assert "([count]<=100" in expr + + def test_value_range_no_bounds(self): + rule = ValueRangeRule() + assert rule.to_expression("x") == "true" + + def test_allowed_values_multiple(self): + rule = AllowedValuesRule(values=["A", "B", "C"]) + expr = rule.to_expression("status") + assert '([status]="A")' in expr + assert '([status]="B")' in expr + assert "|" in expr + + def test_allowed_values_single(self): + rule = AllowedValuesRule(values=["active"]) + expr = rule.to_expression("status") + assert expr == '[status]="active"' + + def test_regex_expression(self): + rule = RegexRule(pattern=r"^\d+$") + expr = rule.to_expression("code") + assert "contains([code]" in expr + + def test_custom_expression(self): + rule = CustomExpressionRule(expression="[age] > 0") + assert rule.to_expression("age") == "[age] > 0" + + +# --------------------------------------------------------------------------- +# ColumnContract tests +# --------------------------------------------------------------------------- + + +class TestColumnContract: + def test_get_expressions(self): + col = ColumnContract( + name="age", + rules=[NotNullRule(), ValueRangeRule(min_value=0, max_value=150)], + ) + expressions = col.get_expressions() + assert len(expressions) == 2 + assert expressions[0][0] == "age:not_null" + assert expressions[1][0] == "age:value_range" + + def test_empty_rules(self): + col = ColumnContract(name="x") + assert col.get_expressions() == [] + + +# --------------------------------------------------------------------------- +# DataContractDefinition serialization tests +# --------------------------------------------------------------------------- + + +class TestContractSerialization: + def test_roundtrip_json(self): + definition = DataContractDefinition( + columns=[ + ColumnContract( + name="age", + rules=[DtypeRule(expected_dtype="Int64"), NotNullRule()], + ), + ColumnContract( + name="name", + rules=[NotNullRule(), RegexRule(pattern="^[A-Z]")], + ), + ], + allow_extra_columns=False, + general_rules=[ + GeneralRule(name="positive_age", expression="[age] > 0"), + ], + row_count=RowCountRule(min_rows=1), + ) + json_str = definition.model_dump_json() + restored = DataContractDefinition.model_validate_json(json_str) + + assert len(restored.columns) == 2 + assert len(restored.columns[0].rules) == 2 + assert isinstance(restored.columns[0].rules[0], DtypeRule) + assert isinstance(restored.columns[0].rules[1], NotNullRule) + assert isinstance(restored.columns[1].rules[1], RegexRule) + assert len(restored.general_rules) == 1 + assert restored.row_count.min_rows == 1 + + def test_discriminated_union(self): + """Pydantic correctly picks rule type from JSON based on rule_type.""" + data = { + "columns": [ + { + "name": "id", + "rules": [ + {"rule_type": "not_null"}, + {"rule_type": "unique"}, + {"rule_type": "value_range", "min_value": 1}, + {"rule_type": "allowed_values", "values": ["a", "b"]}, + {"rule_type": "regex", "pattern": "^[a-z]+$"}, + {"rule_type": "dtype", "expected_dtype": "Int64"}, + {"rule_type": "custom_expression", "expression": "[id] > 0"}, + ], + } + ] + } + definition = DataContractDefinition.model_validate(data) + rules = definition.columns[0].rules + assert isinstance(rules[0], NotNullRule) + assert isinstance(rules[1], UniqueRule) + assert isinstance(rules[2], ValueRangeRule) + assert isinstance(rules[3], AllowedValuesRule) + assert isinstance(rules[4], RegexRule) + assert isinstance(rules[5], DtypeRule) + assert isinstance(rules[6], CustomExpressionRule) + + +# --------------------------------------------------------------------------- +# Validation engine tests +# --------------------------------------------------------------------------- + + +class TestValidationEngine: + @pytest.fixture + def sample_lf(self): + return pl.LazyFrame( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", None, "Diana", "Eve"], + "age": [25, 30, -5, 45, 200], + "status": ["active", "active", "inactive", "active", "unknown"], + } + ) + + def test_apply_validation_adds_is_valid(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import apply_validation + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="name", rules=[NotNullRule()]), + ] + ) + result = apply_validation(sample_lf, definition, add_validity=True) + df = result.collect() + assert "_is_valid" in df.columns + # Row 3 (index 2) has null name + assert df["_is_valid"].to_list() == [True, True, False, True, True] + + def test_apply_validation_value_range(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import apply_validation + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="age", rules=[ValueRangeRule(min_value=0, max_value=150)]), + ] + ) + result = apply_validation(sample_lf, definition, add_validity=True) + df = result.collect() + # Row 3 has age=-5, row 5 has age=200 + valid = df["_is_valid"].to_list() + assert valid[0] is True # age=25 + assert valid[2] is False # age=-5 + assert valid[4] is False # age=200 + + def test_apply_validation_empty_definition(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import apply_validation + + definition = DataContractDefinition() + result = apply_validation(sample_lf, definition, add_validity=True) + df = result.collect() + assert all(df["_is_valid"].to_list()) + + def test_validate_contract_schema_check(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import validate_contract + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[DtypeRule(expected_dtype="Int64")]), + ColumnContract(name="missing_col", rules=[]), + ], + allow_extra_columns=True, + ) + result = validate_contract(sample_lf, definition) + assert isinstance(result, ValidationResult) + # missing_col should fail existence check + missing_results = [r for r in result.rule_results if r.rule_name == "missing_col:exists"] + assert len(missing_results) == 1 + assert missing_results[0].passed is False + + def test_validate_contract_row_count(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import validate_contract + + definition = DataContractDefinition( + row_count=RowCountRule(min_rows=10), + allow_extra_columns=True, + ) + result = validate_contract(sample_lf, definition) + assert result.passed is False + row_check = [r for r in result.rule_results if r.rule_name == "table:min_row_count"] + assert len(row_check) == 1 + assert row_check[0].passed is False + + def test_validate_contract_unique_rule(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import validate_contract + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="status", rules=[UniqueRule()]), + ], + allow_extra_columns=True, + ) + result = validate_contract(sample_lf, definition) + # status has duplicates ("active" appears 3 times) + unique_check = [r for r in result.rule_results if r.rule_name == "status:unique"] + assert len(unique_check) == 1 + assert unique_check[0].passed is False + assert unique_check[0].violation_count > 0 + + def test_validate_contract_extra_columns_rejected(self, sample_lf): + from flowfile_core.flowfile.flow_data_engine.validation import validate_contract + + definition = DataContractDefinition( + columns=[ColumnContract(name="id", rules=[])], + allow_extra_columns=False, + ) + result = validate_contract(sample_lf, definition) + unexpected = [r for r in result.rule_results if "unexpected" in r.rule_name] + # name, age, status are unexpected + assert len(unexpected) == 3 + + def test_validate_contract_all_pass(self): + from flowfile_core.flowfile.flow_data_engine.validation import validate_contract + + lf = pl.LazyFrame({"id": [1, 2, 3], "name": ["a", "b", "c"]}) + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[NotNullRule()]), + ColumnContract(name="name", rules=[NotNullRule()]), + ], + allow_extra_columns=False, + ) + result = validate_contract(lf, definition) + assert result.passed is True + assert result.total_rows == 3 From 01a562453bd9ef3f2d905d715051742d97e0828d Mon Sep 17 00:00:00 2001 From: Claude Date: Mon, 6 Apr 2026 09:01:53 +0000 Subject: [PATCH 2/2] Move downstream triggers to validation path and add integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Validation now auto-marks the contract at the current table version and fires downstream table_trigger schedules on success. Overwrite no longer triggers downstream flows — the contract must pass first. Also fixes self._repo → self.repo bug in service contract methods and improves storage format detection in the validate endpoint. https://claude.ai/code/session_01RJ6iAh4f3MN7UdzttruKho --- .../flowfile_core/catalog/service.py | 44 +- flowfile_core/flowfile_core/routes/catalog.py | 13 +- flowfile_core/tests/test_catalog.py | 72 +- flowfile_core/tests/test_catalog_delta.py | 4 - .../tests/test_data_contracts_integration.py | 797 ++++++++++++++++++ 5 files changed, 839 insertions(+), 91 deletions(-) create mode 100644 flowfile_core/tests/test_data_contracts_integration.py diff --git a/flowfile_core/flowfile_core/catalog/service.py b/flowfile_core/flowfile_core/catalog/service.py index f70ce0c31..3cb05a4e2 100644 --- a/flowfile_core/flowfile_core/catalog/service.py +++ b/flowfile_core/flowfile_core/catalog/service.py @@ -1217,11 +1217,6 @@ def overwrite_table_data( table = self.repo.update_table(table) - try: - self._fire_table_trigger_schedules(table.id, table.updated_at) - except Exception: - logger.exception("Failed to fire push triggers for table %s", table.id) - return self._table_to_out(table) def _fire_table_trigger_schedules(self, table_id: int, table_updated_at: datetime) -> int: @@ -1988,7 +1983,7 @@ def get_catalog_stats(self, user_id: int) -> CatalogStats: def get_contract(self, table_id: int) -> DataContract | None: """Return the contract for *table_id*, or ``None``.""" - return self._repo.get_contract_by_table(table_id) + return self.repo.get_contract_by_table(table_id) def create_contract( self, @@ -2000,10 +1995,10 @@ def create_contract( status: str = "draft", ) -> DataContract: """Create a new contract for a catalog table.""" - table = self._repo.get_table(table_id) + table = self.repo.get_table(table_id) if table is None: raise TableNotFoundError(table_id) - existing = self._repo.get_contract_by_table(table_id) + existing = self.repo.get_contract_by_table(table_id) if existing is not None: raise ContractExistsError(table_id) contract = DataContract( @@ -2015,7 +2010,7 @@ def create_contract( status=status, version=1, ) - return self._repo.create_contract(contract) + return self.repo.create_contract(contract) def update_contract( self, @@ -2026,7 +2021,7 @@ def update_contract( status: str | None = None, ) -> DataContract: """Update an existing contract. Bumps the version number.""" - contract = self._repo.get_contract_by_table(table_id) + contract = self.repo.get_contract_by_table(table_id) if contract is None: raise ContractNotFoundError(table_id) if name is not None: @@ -2038,14 +2033,14 @@ def update_contract( if status is not None: contract.status = status contract.version = (contract.version or 1) + 1 - return self._repo.update_contract(contract) + return self.repo.update_contract(contract) def delete_contract(self, table_id: int) -> None: """Delete the contract attached to *table_id*.""" - contract = self._repo.get_contract_by_table(table_id) + contract = self.repo.get_contract_by_table(table_id) if contract is None: raise ContractNotFoundError(table_id) - self._repo.delete_contract(table_id) + self.repo.delete_contract(table_id) def mark_contract_validated( self, @@ -2058,12 +2053,12 @@ def mark_contract_validated( If *version* is not supplied, the current Delta version is read from the table storage. """ - contract = self._repo.get_contract_by_table(table_id) + contract = self.repo.get_contract_by_table(table_id) if contract is None: raise ContractNotFoundError(table_id) if version is None: - table = self._repo.get_table(table_id) + table = self.repo.get_table(table_id) if table is not None and is_delta_table(table.file_path): dt = DeltaTable(table.file_path) version = dt.version() @@ -2071,7 +2066,18 @@ def mark_contract_validated( contract.last_validated_version = version contract.last_validated_at = datetime.now(timezone.utc) contract.last_validation_passed = passed - return self._repo.update_contract(contract) + contract = self.repo.update_contract(contract) + + # Fire downstream table triggers when validation passes + if passed: + try: + table = self.repo.get_table(table_id) + if table is not None: + self._fire_table_trigger_schedules(table.id, table.updated_at) + except Exception: + logger.exception("Failed to fire push triggers after validation for table %s", table_id) + + return contract def generate_contract_from_table(self, table_id: int) -> str: """Profile a catalog table and return a proposed ``DataContractDefinition`` as JSON. @@ -2085,7 +2091,7 @@ def generate_contract_from_table(self, table_id: int) -> str: NotNullRule, ) - table = self._repo.get_table(table_id) + table = self.repo.get_table(table_id) if table is None: raise TableNotFoundError(table_id) @@ -2111,7 +2117,7 @@ def get_contract_summary(self, table_id: int) -> dict | None: """ from flowfile_core.schemas.contract_schema import ContractSummary - contract = self._repo.get_contract_by_table(table_id) + contract = self.repo.get_contract_by_table(table_id) if contract is None: return None @@ -2131,7 +2137,7 @@ def get_contract_summary(self, table_id: int) -> dict | None: status = "failed" elif contract.last_validated_version is not None: # Try to compare with current Delta version - table = self._repo.get_table(table_id) + table = self.repo.get_table(table_id) current_version = None if table is not None and is_delta_table(table.file_path): try: diff --git a/flowfile_core/flowfile_core/routes/catalog.py b/flowfile_core/flowfile_core/routes/catalog.py index 7d360e7b9..f46ca6040 100644 --- a/flowfile_core/flowfile_core/routes/catalog.py +++ b/flowfile_core/flowfile_core/routes/catalog.py @@ -1006,7 +1006,7 @@ def validate_contract( raise HTTPException(404, "No contract found for this table") # Get the raw table record for file_path - table_record = service._repo.get_table(table_id) + table_record = service.repo.get_table(table_id) if table_record is None: raise HTTPException(404, "Catalog table not found") @@ -1014,12 +1014,19 @@ def validate_contract( import polars as pl - try: + if getattr(table_record, "storage_format", None) == "delta": lf = pl.scan_delta(table_record.file_path) - except Exception: + else: lf = pl.scan_parquet(table_record.file_path) result = run_validation(lf, definition) + + # Auto-mark the contract validated at the current table version + try: + service.mark_contract_validated(table_id=table_id, passed=result.passed) + except Exception: + pass # validation result is still returned even if marking fails + return result diff --git a/flowfile_core/tests/test_catalog.py b/flowfile_core/tests/test_catalog.py index 0b8aee9a1..77aded94b 100644 --- a/flowfile_core/tests/test_catalog.py +++ b/flowfile_core/tests/test_catalog.py @@ -943,7 +943,11 @@ def test_flow_out_includes_tables_produced(self): class TestPushTableTrigger: - """Tests for push-driven table_trigger schedule firing on overwrite_table_data.""" + """Tests verifying that overwrite_table_data does NOT fire downstream triggers. + + Downstream triggers are now fired by contract validation, not by overwrites. + See test_data_contracts_integration.py for validation-based trigger tests. + """ @staticmethod def _setup_table_and_schedule(schedule_type="table_trigger"): @@ -1000,74 +1004,12 @@ def _setup_table_and_schedule(schedule_type="table_trigger"): return schema.id, flow.id, table.id, schedule.id, tmp.name - def test_overwrite_fires_push_trigger(self, monkeypatch): - """Overwriting a table should create a FlowRun for its table_trigger schedule.""" - schema_id, flow_id, table_id, schedule_id, parquet_path = self._setup_table_and_schedule() - - monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) - - with get_db_context() as db: - repo = SQLAlchemyCatalogRepository(db) - svc = CatalogService(repo) - svc.overwrite_table_data(table_id, parquet_path) - - runs = db.query(FlowRun).filter_by(registration_id=flow_id, run_type="scheduled").all() - assert len(runs) == 1 - - def test_push_trigger_skips_active_run(self, monkeypatch): - """If the flow already has an active run, no new FlowRun should be created.""" - from datetime import datetime, timezone - - schema_id, flow_id, table_id, schedule_id, parquet_path = self._setup_table_and_schedule() - - monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) - - # Create an active (unfinished) run - with get_db_context() as db: - active_run = FlowRun( - registration_id=flow_id, - flow_name="triggered_flow", - flow_path="/tmp/triggered.yaml", - user_id=1, - started_at=datetime.now(timezone.utc), - number_of_nodes=0, - run_type="scheduled", - ) - db.add(active_run) - db.commit() - - with get_db_context() as db: - repo = SQLAlchemyCatalogRepository(db) - svc = CatalogService(repo) - svc.overwrite_table_data(table_id, parquet_path) - - runs = db.query(FlowRun).filter_by(registration_id=flow_id, run_type="scheduled").all() - # Only the pre-existing active run, no new one - assert len(runs) == 1 - - def test_push_trigger_updates_schedule_timestamps(self, monkeypatch): - """After firing, the schedule's last_triggered_at and last_trigger_table_updated_at should be set.""" + def test_overwrite_does_not_fire_trigger(self, monkeypatch): + """Overwriting a table should NOT create a FlowRun — triggers are validation-driven.""" schema_id, flow_id, table_id, schedule_id, parquet_path = self._setup_table_and_schedule() monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) - with get_db_context() as db: - repo = SQLAlchemyCatalogRepository(db) - svc = CatalogService(repo) - svc.overwrite_table_data(table_id, parquet_path) - - schedule = db.query(FlowSchedule).get(schedule_id) - assert schedule.last_triggered_at is not None - assert schedule.last_trigger_table_updated_at is not None - - def test_push_trigger_ignores_table_set_triggers(self, monkeypatch): - """table_set_trigger schedules should NOT be fired by the push mechanism.""" - schema_id, flow_id, table_id, schedule_id, parquet_path = self._setup_table_and_schedule( - schedule_type="table_set_trigger" - ) - - monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) - with get_db_context() as db: repo = SQLAlchemyCatalogRepository(db) svc = CatalogService(repo) diff --git a/flowfile_core/tests/test_catalog_delta.py b/flowfile_core/tests/test_catalog_delta.py index 44f686e3c..c2ea36bf2 100644 --- a/flowfile_core/tests/test_catalog_delta.py +++ b/flowfile_core/tests/test_catalog_delta.py @@ -323,9 +323,6 @@ def test_overwrite_with_precomputed_metadata(self, tmp_path, monkeypatch): """When all metadata is provided, overwrite should not read the file.""" _, schema_id = _make_namespace() - # Disable push triggers for simplicity - monkeypatch.setattr(CatalogService, "_fire_table_trigger_schedules", lambda *a, **kw: 0) - delta_dir = tmp_path / "overwrite_delta" delta_dir.mkdir() @@ -369,7 +366,6 @@ def test_overwrite_with_precomputed_metadata(self, tmp_path, monkeypatch): def test_overwrite_updates_storage_format(self, tmp_path, monkeypatch): """overwrite_table_data should update storage_format on the DB record.""" _, schema_id = _make_namespace() - monkeypatch.setattr(CatalogService, "_fire_table_trigger_schedules", lambda *a, **kw: 0) with get_db_context() as db: table = CatalogTable( diff --git a/flowfile_core/tests/test_data_contracts_integration.py b/flowfile_core/tests/test_data_contracts_integration.py new file mode 100644 index 000000000..9cd9e48ca --- /dev/null +++ b/flowfile_core/tests/test_data_contracts_integration.py @@ -0,0 +1,797 @@ +"""Integration tests for data contracts via API and service layer. + +Covers: +- Contract CRUD through the catalog API +- Validation auto-marks the contract at the latest version +- Successful validation fires downstream table_trigger schedules +- Failed validation does NOT fire downstream triggers +- Contract generation from table profile +- Contract summary status computation +- Update endpoint with definition changes +- Overwrite does NOT fire triggers (moved to validation) +""" + +import json +import tempfile + +import polars as pl +import pytest +from fastapi.testclient import TestClient + +from flowfile_core import main +from flowfile_core.catalog import CatalogService +from flowfile_core.catalog.repository import SQLAlchemyCatalogRepository +from flowfile_core.database.connection import get_db_context +from flowfile_core.database.models import ( + CatalogNamespace, + CatalogTable, + CatalogTableReadLink, + DataContract, + FlowFavorite, + FlowFollow, + FlowRegistration, + FlowRun, + FlowSchedule, + ScheduleTriggerTable, +) +from flowfile_core.schemas.contract_schema import ( + ColumnContract, + DataContractDefinition, + NotNullRule, + RowCountRule, + UniqueRule, + ValueRangeRule, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_auth_token() -> str: + with TestClient(main.app) as client: + response = client.post("/auth/token") + return response.json()["access_token"] + + +def _get_test_client() -> TestClient: + token = _get_auth_token() + c = TestClient(main.app) + c.headers = {"Authorization": f"Bearer {token}"} + return c + + +client = _get_test_client() + + +def _cleanup_catalog(): + with get_db_context() as db: + db.query(ScheduleTriggerTable).delete() + db.query(FlowSchedule).delete() + db.query(DataContract).delete() + db.query(CatalogTableReadLink).delete() + db.query(CatalogTable).delete() + db.query(FlowFollow).delete() + db.query(FlowFavorite).delete() + db.query(FlowRun).delete() + db.query(FlowRegistration).delete() + db.query(CatalogNamespace).delete() + db.commit() + + +def _make_namespace() -> tuple[int, int]: + """Create a catalog + schema and return (catalog_id, schema_id).""" + with get_db_context() as db: + cat = CatalogNamespace(name="ContractCat", level=0, owner_id=1) + db.add(cat) + db.commit() + db.refresh(cat) + schema = CatalogNamespace(name="ContractSch", level=1, parent_id=cat.id, owner_id=1) + db.add(schema) + db.commit() + db.refresh(schema) + return cat.id, schema.id + + +def _make_table_with_parquet(schema_id: int, data: dict | None = None) -> tuple[int, str]: + """Create a parquet file, register it in the catalog, return (table_id, path).""" + if data is None: + data = {"id": [1, 2, 3], "name": ["Alice", "Bob", "Charlie"], "age": [25, 30, 35]} + + tmp = tempfile.NamedTemporaryFile(suffix=".parquet", delete=False) + df = pl.DataFrame(data) + df.write_parquet(tmp.name) + + schema_json = json.dumps([{"name": c, "dtype": str(df.dtypes[i])} for i, c in enumerate(df.columns)]) + + with get_db_context() as db: + table = CatalogTable( + name="test_table", + namespace_id=schema_id, + owner_id=1, + file_path=tmp.name, + storage_format="parquet", + schema_json=schema_json, + row_count=len(df), + column_count=len(df.columns), + size_bytes=100, + ) + db.add(table) + db.commit() + db.refresh(table) + return table.id, tmp.name + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def clean_catalog(): + _cleanup_catalog() + yield + _cleanup_catalog() + + +# --------------------------------------------------------------------------- +# Contract CRUD via API +# --------------------------------------------------------------------------- + + +class TestContractCrudApi: + def test_create_and_get_contract(self): + """POST creates a contract, GET retrieves it.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + resp = client.post(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 201 + data = resp.json() + assert data["table_id"] == table_id + assert data["status"] == "draft" + assert data["version"] == 1 + + resp = client.get(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 200 + assert resp.json()["table_id"] == table_id + + def test_create_contract_409_on_duplicate(self): + """Creating a second contract for the same table returns 409.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + resp = client.post(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 201 + + resp = client.post(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 409 + + def test_create_contract_404_for_missing_table(self): + """Creating a contract for a non-existent table returns 404.""" + resp = client.post("/catalog/tables/999999/contract") + assert resp.status_code == 404 + + def test_get_contract_404(self): + """GET on a table with no contract returns 404.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + resp = client.get(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 404 + + def test_delete_contract(self): + """DELETE removes the contract; subsequent GET returns 404.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + client.post(f"/catalog/tables/{table_id}/contract") + resp = client.delete(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 204 + + resp = client.get(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 404 + + def test_delete_contract_404(self): + """DELETE on a non-existent contract returns 404.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + resp = client.delete(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 404 + + def test_update_contract_increments_version(self): + """PUT updates the contract and bumps the version.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + client.post(f"/catalog/tables/{table_id}/contract") + + resp = client.put(f"/catalog/tables/{table_id}/contract") + assert resp.status_code == 200 + data = resp.json() + assert data["version"] == 2 + + +# --------------------------------------------------------------------------- +# Validation + auto-mark +# --------------------------------------------------------------------------- + + +class TestValidationAutoMark: + def test_validate_passes_and_marks_contract(self): + """When all rules pass, the contract is marked as validated.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + # Create contract with matching rules + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[NotNullRule()]), + ColumnContract(name="name", rules=[NotNullRule()]), + ColumnContract(name="age", rules=[NotNullRule(), ValueRangeRule(min_value=0, max_value=150)]), + ], + allow_extra_columns=False, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="test_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.status_code == 200 + result = resp.json() + assert result["passed"] is True + assert result["total_rows"] == 3 + + # Check contract was auto-marked + resp = client.get(f"/catalog/tables/{table_id}/contract") + contract = resp.json() + assert contract["last_validation_passed"] is True + assert contract["last_validated_at"] is not None + + def test_validate_fails_and_marks_contract(self): + """When rules fail, the contract is marked as failed.""" + _, schema_id = _make_namespace() + data = {"id": [1, 2, 3], "name": ["Alice", None, "Charlie"], "age": [25, -5, 35]} + table_id, _ = _make_table_with_parquet(schema_id, data) + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="name", rules=[NotNullRule()]), + ColumnContract(name="age", rules=[ValueRangeRule(min_value=0, max_value=150)]), + ], + allow_extra_columns=True, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="fail_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.status_code == 200 + result = resp.json() + assert result["passed"] is False + + # Check contract was auto-marked as failed + resp = client.get(f"/catalog/tables/{table_id}/contract") + contract = resp.json() + assert contract["last_validation_passed"] is False + + def test_validate_without_contract_returns_404(self): + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.status_code == 404 + + def test_validate_row_count_rule(self): + """Row count rule is enforced during validation.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + definition = DataContractDefinition( + row_count=RowCountRule(min_rows=100), + allow_extra_columns=True, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="row_count_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + result = resp.json() + assert result["passed"] is False + min_row_results = [r for r in result["rule_results"] if r["rule_name"] == "table:min_row_count"] + assert len(min_row_results) == 1 + assert min_row_results[0]["passed"] is False + + def test_validate_unique_rule_via_api(self): + """UniqueRule detects duplicate values through the API.""" + _, schema_id = _make_namespace() + data = {"id": [1, 1, 3], "value": ["a", "b", "c"]} + table_id, _ = _make_table_with_parquet(schema_id, data) + + definition = DataContractDefinition( + columns=[ColumnContract(name="id", rules=[UniqueRule()])], + allow_extra_columns=True, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="unique_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + result = resp.json() + assert result["passed"] is False + unique_results = [r for r in result["rule_results"] if r["rule_name"] == "id:unique"] + assert len(unique_results) == 1 + assert unique_results[0]["violation_count"] == 2 + + def test_validate_extra_columns_rejected(self): + """Extra columns are flagged when allow_extra_columns is False.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + definition = DataContractDefinition( + columns=[ColumnContract(name="id", rules=[])], + allow_extra_columns=False, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="extra_cols_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + result = resp.json() + assert result["passed"] is False + unexpected = [r for r in result["rule_results"] if "unexpected" in r["rule_name"]] + # name, age are unexpected + assert len(unexpected) == 2 + + +# --------------------------------------------------------------------------- +# Validation triggers downstream +# --------------------------------------------------------------------------- + + +class TestValidationTriggersDownstream: + @staticmethod + def _setup_table_with_trigger(schema_id: int, schedule_type="table_trigger"): + """Create a table, flow, and schedule. Returns (table_id, flow_id, schedule_id, path).""" + data = {"id": [1, 2, 3], "name": ["a", "b", "c"]} + table_id, path = _make_table_with_parquet(schema_id, data) + + with get_db_context() as db: + flow = FlowRegistration( + name="downstream_flow", + flow_path="/tmp/downstream.yaml", + namespace_id=schema_id, + owner_id=1, + ) + db.add(flow) + db.commit() + db.refresh(flow) + + schedule = FlowSchedule( + registration_id=flow.id, + owner_id=1, + enabled=True, + schedule_type=schedule_type, + trigger_table_id=table_id if schedule_type == "table_trigger" else None, + ) + db.add(schedule) + db.commit() + db.refresh(schedule) + + return table_id, flow.id, schedule.id, path + + def test_successful_validation_fires_downstream_trigger(self, monkeypatch): + """When validation passes, downstream table_trigger schedules fire.""" + _, schema_id = _make_namespace() + table_id, flow_id, schedule_id, path = self._setup_table_with_trigger(schema_id) + + monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) + + # Create a contract that will pass + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[NotNullRule()]), + ColumnContract(name="name", rules=[NotNullRule()]), + ], + allow_extra_columns=False, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="trigger_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.status_code == 200 + assert resp.json()["passed"] is True + + # Verify downstream flow was triggered + with get_db_context() as db: + runs = db.query(FlowRun).filter_by(registration_id=flow_id, run_type="scheduled").all() + assert len(runs) == 1 + + def test_failed_validation_does_not_fire_downstream(self, monkeypatch): + """When validation fails, downstream triggers should NOT fire.""" + _, schema_id = _make_namespace() + table_id, flow_id, schedule_id, path = self._setup_table_with_trigger(schema_id) + + monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) + + # Create a contract that will fail (require min 100 rows) + definition = DataContractDefinition( + row_count=RowCountRule(min_rows=100), + allow_extra_columns=True, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="failing_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.status_code == 200 + assert resp.json()["passed"] is False + + # Verify no downstream runs were created + with get_db_context() as db: + runs = db.query(FlowRun).filter_by(registration_id=flow_id, run_type="scheduled").all() + assert len(runs) == 0 + + def test_overwrite_does_not_fire_triggers(self, monkeypatch): + """After moving triggers to validation, overwrite_table_data should NOT fire them.""" + _, schema_id = _make_namespace() + table_id, flow_id, schedule_id, path = self._setup_table_with_trigger(schema_id) + + monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.overwrite_table_data(table_id, path) + + runs = db.query(FlowRun).filter_by(registration_id=flow_id, run_type="scheduled").all() + assert len(runs) == 0 + + def test_validation_trigger_skips_active_run(self, monkeypatch): + """If the downstream flow already has an active run, validation should skip it.""" + from datetime import datetime, timezone + + _, schema_id = _make_namespace() + table_id, flow_id, schedule_id, path = self._setup_table_with_trigger(schema_id) + + monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) + + # Create an active (unfinished) run + with get_db_context() as db: + active_run = FlowRun( + registration_id=flow_id, + flow_name="downstream_flow", + flow_path="/tmp/downstream.yaml", + user_id=1, + started_at=datetime.now(timezone.utc), + number_of_nodes=0, + run_type="scheduled", + ) + db.add(active_run) + db.commit() + + # Create a passing contract + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[NotNullRule()]), + ColumnContract(name="name", rules=[NotNullRule()]), + ], + allow_extra_columns=False, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="skip_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.json()["passed"] is True + + # Only the pre-existing active run, no new one + with get_db_context() as db: + runs = db.query(FlowRun).filter_by(registration_id=flow_id, run_type="scheduled").all() + assert len(runs) == 1 + + def test_validation_trigger_updates_schedule_timestamps(self, monkeypatch): + """After validation fires a trigger, schedule timestamps should be updated.""" + _, schema_id = _make_namespace() + table_id, flow_id, schedule_id, path = self._setup_table_with_trigger(schema_id) + + monkeypatch.setattr(CatalogService, "_spawn_flow_subprocess", staticmethod(lambda *a, **kw: None)) + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[NotNullRule()]), + ColumnContract(name="name", rules=[NotNullRule()]), + ], + allow_extra_columns=False, + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="ts_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="active", + ) + + resp = client.post(f"/catalog/tables/{table_id}/contract/validate") + assert resp.json()["passed"] is True + + with get_db_context() as db: + schedule = db.query(FlowSchedule).get(schedule_id) + assert schedule.last_triggered_at is not None + assert schedule.last_trigger_table_updated_at is not None + + +# --------------------------------------------------------------------------- +# Contract generation +# --------------------------------------------------------------------------- + + +class TestContractGeneration: + def test_generate_contract_from_table(self): + """Auto-generates a contract definition based on table schema.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + resp = client.post(f"/catalog/tables/{table_id}/contract/generate") + assert resp.status_code == 200 + definition = resp.json() + assert "columns" in definition + col_names = [c["name"] for c in definition["columns"]] + assert "id" in col_names + assert "name" in col_names + assert "age" in col_names + + def test_generate_contract_for_missing_table(self): + resp = client.post("/catalog/tables/999999/contract/generate") + assert resp.status_code == 404 + + +# --------------------------------------------------------------------------- +# Contract summary +# --------------------------------------------------------------------------- + + +class TestContractSummary: + def test_summary_draft_status(self): + """A newly created contract should have 'draft' summary status.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="draft_contract", + owner_id=1, + definition_json=DataContractDefinition().model_dump_json(), + status="draft", + ) + summary = svc.get_contract_summary(table_id) + + assert summary is not None + assert summary["status"] == "draft" + + def test_summary_failed_status(self): + """A contract with last_validation_passed=False has 'failed' status.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="fail_contract", + owner_id=1, + definition_json=DataContractDefinition().model_dump_json(), + status="active", + ) + svc.mark_contract_validated(table_id=table_id, passed=False) + summary = svc.get_contract_summary(table_id) + + assert summary["status"] == "failed" + + def test_summary_returns_none_without_contract(self): + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + summary = svc.get_contract_summary(table_id) + + assert summary is None + + def test_summary_rule_count(self): + """Summary should count rules across columns and general rules.""" + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + definition = DataContractDefinition( + columns=[ + ColumnContract(name="id", rules=[NotNullRule(), UniqueRule()]), + ColumnContract(name="name", rules=[NotNullRule()]), + ], + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="count_contract", + owner_id=1, + definition_json=definition.model_dump_json(), + status="draft", + ) + summary = svc.get_contract_summary(table_id) + + assert summary["rule_count"] == 3 + + +# --------------------------------------------------------------------------- +# Service-level contract operations +# --------------------------------------------------------------------------- + + +class TestContractServiceOps: + def test_create_and_get_via_service(self): + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + contract = svc.create_contract( + table_id=table_id, + name="svc_contract", + owner_id=1, + definition_json=DataContractDefinition().model_dump_json(), + ) + assert contract.table_id == table_id + assert contract.version == 1 + + fetched = svc.get_contract(table_id) + assert fetched is not None + assert fetched.id == contract.id + + def test_update_contract_via_service(self): + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + definition = DataContractDefinition( + columns=[ColumnContract(name="id", rules=[NotNullRule()])], + ) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="update_me", + owner_id=1, + definition_json=DataContractDefinition().model_dump_json(), + ) + updated = svc.update_contract( + table_id=table_id, + name="updated_name", + definition_json=definition.model_dump_json(), + ) + assert updated.name == "updated_name" + assert updated.version == 2 + + def test_delete_contract_via_service(self): + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="delete_me", + owner_id=1, + definition_json=DataContractDefinition().model_dump_json(), + ) + svc.delete_contract(table_id) + assert svc.get_contract(table_id) is None + + def test_mark_contract_validated_via_service(self): + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + svc.create_contract( + table_id=table_id, + name="mark_me", + owner_id=1, + definition_json=DataContractDefinition().model_dump_json(), + status="active", + ) + contract = svc.mark_contract_validated(table_id=table_id, passed=True, version=5) + assert contract.last_validation_passed is True + assert contract.last_validated_version == 5 + assert contract.last_validated_at is not None + + def test_mark_validated_not_found(self): + from flowfile_core.catalog.exceptions import ContractNotFoundError + + _, schema_id = _make_namespace() + table_id, _ = _make_table_with_parquet(schema_id) + + with get_db_context() as db: + repo = SQLAlchemyCatalogRepository(db) + svc = CatalogService(repo) + with pytest.raises(ContractNotFoundError): + svc.mark_contract_validated(table_id=table_id, passed=True)