Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flowfile_core/flowfile_core/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from .exceptions import (
CatalogError,
ContractExistsError,
ContractNotFoundError,
FavoriteNotFoundError,
FlowAlreadyRunningError,
FlowExistsError,
Expand Down Expand Up @@ -55,4 +57,6 @@
"TableNotFoundError",
"TableExistsError",
"TableFavoriteNotFoundError",
"ContractNotFoundError",
"ContractExistsError",
]
16 changes: 16 additions & 0 deletions flowfile_core/flowfile_core/catalog/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,19 @@ def __init__(self, registration_id: int, schedule_type: str):
self.registration_id = registration_id
self.schedule_type = schedule_type
super().__init__(f"A '{schedule_type}' schedule already exists for flow {registration_id}")


class ContractNotFoundError(CatalogError):
    """Raised when no data contract exists for the requested table."""

    def __init__(self, table_id: int):
        # Keep the offending table id available to handlers for reporting.
        self.table_id = table_id
        message = f"No data contract found for table_id={table_id}"
        super().__init__(message)


class ContractExistsError(CatalogError):
    """Raised when a table already has a contract and another is created."""

    def __init__(self, table_id: int):
        # Keep the offending table id available to handlers for reporting.
        self.table_id = table_id
        message = f"A data contract already exists for table_id={table_id}"
        super().__init__(message)
54 changes: 42 additions & 12 deletions flowfile_core/flowfile_core/catalog/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CatalogNamespace,
CatalogTable,
CatalogTableReadLink,
DataContract,
FlowFavorite,
FlowFollow,
FlowRegistration,
Expand Down Expand Up @@ -245,6 +246,18 @@ def set_trigger_table_ids(self, schedule_id: int, table_ids: list[int]) -> None:

def delete_trigger_table_ids(self, schedule_id: int) -> None: ...

# -- Data Contracts ------------------------------------------------------

def get_contract_by_table(self, table_id: int) -> DataContract | None: ...

def create_contract(self, contract: DataContract) -> DataContract: ...

def update_contract(self, contract: DataContract) -> DataContract: ...

def delete_contract(self, table_id: int) -> None: ...

def list_contracts_for_tables(self, table_ids: list[int]) -> dict[int, DataContract]: ...


# ---------------------------------------------------------------------------
# SQLAlchemy implementation
Expand Down Expand Up @@ -603,9 +616,7 @@ def remove_table_favorite(self, user_id: int, table_id: int) -> None:
self._db.commit()

def list_table_favorites(self, user_id: int) -> list[TableFavorite]:
return (
self._db.query(TableFavorite).filter_by(user_id=user_id).order_by(TableFavorite.created_at.desc()).all()
)
return self._db.query(TableFavorite).filter_by(user_id=user_id).order_by(TableFavorite.created_at.desc()).all()

def count_table_favorites(self, user_id: int) -> int:
return self._db.query(TableFavorite).filter_by(user_id=user_id).count()
Expand Down Expand Up @@ -826,12 +837,7 @@ def count_schedules(self) -> int:

def list_active_runs(self) -> list[FlowRun]:
"""Return runs that have not yet ended (ended_at IS NULL)."""
return (
self._db.query(FlowRun)
.filter(FlowRun.ended_at.is_(None))
.order_by(FlowRun.started_at.desc())
.all()
)
return self._db.query(FlowRun).filter(FlowRun.ended_at.is_(None)).order_by(FlowRun.started_at.desc()).all()

def has_active_run(self, registration_id: int) -> bool:
"""Check if a flow already has an active (unfinished) run."""
Expand Down Expand Up @@ -873,9 +879,7 @@ def list_table_trigger_schedules_for_table(self, table_id: int) -> list[FlowSche
def get_trigger_table_ids(self, schedule_id: int) -> list[int]:
"""Return table IDs linked to a table_set_trigger schedule."""
rows = (
self._db.query(ScheduleTriggerTable.table_id)
.filter(ScheduleTriggerTable.schedule_id == schedule_id)
.all()
self._db.query(ScheduleTriggerTable.table_id).filter(ScheduleTriggerTable.schedule_id == schedule_id).all()
)
return [r[0] for r in rows]

Expand All @@ -890,3 +894,29 @@ def delete_trigger_table_ids(self, schedule_id: int) -> None:
"""Remove all trigger table links for a schedule."""
self._db.query(ScheduleTriggerTable).filter_by(schedule_id=schedule_id).delete()
self._db.commit()

# -- Data Contracts ------------------------------------------------------

def get_contract_by_table(self, table_id: int) -> DataContract | None:
    """Fetch the contract attached to *table_id*, or ``None`` if absent."""
    query = self._db.query(DataContract).filter_by(table_id=table_id)
    return query.first()

def create_contract(self, contract: DataContract) -> DataContract:
    """Persist a new contract row and return it.

    The instance is refreshed after commit so database-generated fields
    (primary key, server-side defaults) are populated on the returned object.
    """
    session = self._db
    session.add(contract)
    session.commit()
    session.refresh(contract)
    return contract

def update_contract(self, contract: DataContract) -> DataContract:
    """Commit pending changes on *contract* and return the refreshed row.

    Callers mutate an ORM instance and pass it here. ``add`` is a no-op for
    an instance already attached to this session, but it makes the update
    robust when a detached instance is passed — without it the commit would
    silently persist nothing and ``refresh`` would fail.
    """
    self._db.add(contract)
    self._db.commit()
    self._db.refresh(contract)
    return contract

def delete_contract(self, table_id: int) -> None:
    """Remove the contract row for *table_id* (no-op when none exists)."""
    query = self._db.query(DataContract).filter_by(table_id=table_id)
    query.delete()
    self._db.commit()

def list_contracts_for_tables(self, table_ids: list[int]) -> dict[int, DataContract]:
    """Map ``table_id -> contract`` for every table in *table_ids* that has one."""
    if not table_ids:
        # Short-circuit to avoid issuing a degenerate ``IN ()`` query.
        return {}
    contracts = (
        self._db.query(DataContract)
        .filter(DataContract.table_id.in_(table_ids))
        .all()
    )
    return {contract.table_id: contract for contract in contracts}
190 changes: 185 additions & 5 deletions flowfile_core/flowfile_core/catalog/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
table_exists,
)
from flowfile_core.catalog.exceptions import (
ContractExistsError,
ContractNotFoundError,
FavoriteNotFoundError,
FlowAlreadyRunningError,
FlowHasArtifactsError,
Expand All @@ -47,6 +49,7 @@
from flowfile_core.database.models import (
CatalogNamespace,
CatalogTable,
DataContract,
FlowFavorite,
FlowFollow,
FlowRegistration,
Expand Down Expand Up @@ -1214,11 +1217,6 @@ def overwrite_table_data(

table = self.repo.update_table(table)

try:
self._fire_table_trigger_schedules(table.id, table.updated_at)
except Exception:
logger.exception("Failed to fire push triggers for table %s", table.id)

return self._table_to_out(table)

def _fire_table_trigger_schedules(self, table_id: int, table_updated_at: datetime) -> int:
Expand Down Expand Up @@ -1978,3 +1976,185 @@ def get_catalog_stats(self, user_id: int) -> CatalogStats:
favorite_tables=fav_tables,
active_runs=active_runs,
)

# ========================================================================
# Data Contracts
# ========================================================================

def get_contract(self, table_id: int) -> DataContract | None:
    """Look up the data contract for *table_id*; ``None`` when none exists."""
    contract = self.repo.get_contract_by_table(table_id)
    return contract

def create_contract(
    self,
    table_id: int,
    name: str,
    owner_id: int,
    definition_json: str,
    description: str | None = None,
    status: str = "draft",
) -> DataContract:
    """Attach a brand-new data contract to catalog table *table_id*.

    The contract starts at version 1; subsequent updates bump the version.

    Raises:
        TableNotFoundError: if no catalog table with *table_id* exists.
        ContractExistsError: if the table already has a contract
            (at most one contract per table).
    """
    if self.repo.get_table(table_id) is None:
        raise TableNotFoundError(table_id)
    if self.repo.get_contract_by_table(table_id) is not None:
        raise ContractExistsError(table_id)

    new_contract = DataContract(
        table_id=table_id,
        name=name,
        description=description,
        definition_json=definition_json,
        owner_id=owner_id,
        status=status,
        version=1,
    )
    return self.repo.create_contract(new_contract)

def update_contract(
    self,
    table_id: int,
    name: str | None = None,
    description: str | None = None,
    definition_json: str | None = None,
    status: str | None = None,
) -> DataContract:
    """Apply a partial update to the contract of *table_id*.

    Only non-``None`` arguments are applied. The version number is bumped
    on every call (a stored ``None`` version counts as 1).

    Raises:
        ContractNotFoundError: if the table has no contract.
    """
    contract = self.repo.get_contract_by_table(table_id)
    if contract is None:
        raise ContractNotFoundError(table_id)

    requested_changes = {
        "name": name,
        "description": description,
        "definition_json": definition_json,
        "status": status,
    }
    for field_name, value in requested_changes.items():
        if value is not None:
            setattr(contract, field_name, value)

    contract.version = (contract.version or 1) + 1
    return self.repo.update_contract(contract)

def delete_contract(self, table_id: int) -> None:
    """Remove the contract attached to *table_id*.

    Raises:
        ContractNotFoundError: if the table has no contract to delete.
    """
    if self.repo.get_contract_by_table(table_id) is None:
        raise ContractNotFoundError(table_id)
    self.repo.delete_contract(table_id)

def mark_contract_validated(
    self,
    table_id: int,
    passed: bool,
    version: int | None = None,
) -> DataContract:
    """Record the outcome of a contract validation run.

    Stores the validated version, a UTC timestamp, and the pass/fail flag
    on the contract. When *version* is omitted, it is read from the table's
    Delta storage; for non-Delta tables it stays ``None``. After a passing
    validation, downstream table-trigger schedules are fired best-effort
    (failures are logged, never raised).

    Raises:
        ContractNotFoundError: if the table has no contract.
    """
    contract = self.repo.get_contract_by_table(table_id)
    if contract is None:
        raise ContractNotFoundError(table_id)

    if version is None:
        table = self.repo.get_table(table_id)
        if table is not None and is_delta_table(table.file_path):
            version = DeltaTable(table.file_path).version()

    contract.last_validated_version = version
    contract.last_validated_at = datetime.now(timezone.utc)
    contract.last_validation_passed = passed
    contract = self.repo.update_contract(contract)

    if passed:
        # Best-effort: trigger firing must never break the validation record.
        try:
            table = self.repo.get_table(table_id)
            if table is not None:
                self._fire_table_trigger_schedules(table.id, table.updated_at)
        except Exception:
            logger.exception("Failed to fire push triggers after validation for table %s", table_id)

    return contract

def generate_contract_from_table(self, table_id: int) -> str:
    """Propose a ``DataContractDefinition`` for *table_id*, returned as JSON.

    One ``ColumnContract`` is built per column in the table's stored schema:
    a dtype rule (when the schema entry records a dtype) plus a not-null
    rule, which is always added. Extra columns are disallowed in the
    proposal.

    Raises:
        TableNotFoundError: if the table does not exist.
    """
    # NOTE(review): function-scoped import mirrors the surrounding code's
    # convention — presumably to avoid a circular import; keep it local.
    from flowfile_core.schemas.contract_schema import (
        ColumnContract,
        DataContractDefinition,
        DtypeRule,
        NotNullRule,
    )

    table = self.repo.get_table(table_id)
    if table is None:
        raise TableNotFoundError(table_id)

    column_contracts: list[ColumnContract] = []
    if table.schema_json:
        for entry in json.loads(table.schema_json):
            column_rules = []
            dtype = entry.get("dtype")
            if dtype:
                column_rules.append(DtypeRule(expected_dtype=dtype))
            # Every column is proposed as not-null; the draft can be edited.
            column_rules.append(NotNullRule())
            column_contracts.append(
                ColumnContract(name=entry.get("name", ""), rules=column_rules)
            )

    proposal = DataContractDefinition(columns=column_contracts, allow_extra_columns=False)
    return proposal.model_dump_json()

def get_contract_summary(self, table_id: int) -> dict | None:
    """Summarise the contract state of *table_id* for catalog listings.

    Returns a serialised ``ContractSummary`` dict, or ``None`` when the
    table has no contract. Status resolution, in order:

    * ``draft``     -- the contract itself is in draft status
    * ``failed``    -- the last validation explicitly failed
    * ``validated`` -- validated at (or after) the current Delta version
    * ``stale``     -- anything else (never validated, non-Delta storage,
      unreadable Delta log, or validated against an older version)
    """
    from flowfile_core.schemas.contract_schema import (
        ContractSummary,
        DataContractDefinition,
    )

    contract = self.repo.get_contract_by_table(table_id)
    if contract is None:
        return None

    # A malformed stored definition degrades to an empty one (0 rules).
    try:
        definition = DataContractDefinition.model_validate_json(contract.definition_json)
    except Exception:
        definition = DataContractDefinition()
    rule_count = len(definition.general_rules) + sum(
        len(column.rules) for column in definition.columns
    )

    if contract.status == "draft":
        status = "draft"
    elif contract.last_validation_passed is False:
        status = "failed"
    elif contract.last_validated_version is None:
        status = "stale"
    else:
        # Compare the validated version against the current Delta version;
        # any failure to read it is treated as stale.
        current_version = None
        table = self.repo.get_table(table_id)
        if table is not None and is_delta_table(table.file_path):
            try:
                current_version = DeltaTable(table.file_path).version()
            except Exception:
                current_version = None
        if current_version is not None and contract.last_validated_version >= current_version:
            status = "validated"
        else:
            status = "stale"

    return ContractSummary(
        status=status,
        last_validated_version=contract.last_validated_version,
        current_version=None,  # populated by caller if needed
        rule_count=rule_count,
    ).model_dump()
12 changes: 12 additions & 0 deletions flowfile_core/flowfile_core/configs/node_store/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,18 @@ def get_all_standard_nodes() -> tuple[list[NodeTemplate], dict[str, NodeTemplate
drawer_title="Count Records",
drawer_intro="Calculate the total number of rows",
),
NodeTemplate(
name="Data validation",
item="data_validation",
input=1,
output=1,
transform_type="narrow",
node_type="process",
image="data_validation.png",
node_group="quality",
drawer_title="Validate Data",
drawer_intro="Check data against validation rules and add pass/fail columns",
),
NodeTemplate(
name="Cross join",
item="cross_join",
Expand Down
Loading
Loading