diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..6e035ff --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,17 @@ +{ + "permissions": { + "allow": [ + "Bash(git add:*)", + "Bash(git commit:*)", + "Bash(git push:*)", + "Bash(python:*)", + "Bash(make lint:*)", + "Bash(pip install:*)", + "Bash(ruff check:*)", + "Bash(ruff format:*)", + "Bash(python3:*)", + "Bash(pip3 install:*)", + "Bash(source:*)" + ] + } +} diff --git a/.env.example b/.env.example index 449f921..32380cb 100644 --- a/.env.example +++ b/.env.example @@ -15,7 +15,7 @@ CORS_ORIGINS=http://localhost:3000,http://localhost:8000 METADATA_DB_HOST=postgres METADATA_DB_PORT=5432 METADATA_DB_USER=observakit -METADATA_DB_PASSWORD=changeme +METADATA_DB_PASSWORD=your_secure_password_here METADATA_DB_NAME=observakit # --- Target Warehouse (the warehouse you're monitoring) --- @@ -24,7 +24,7 @@ WAREHOUSE_TYPE=postgres WAREHOUSE_HOST=host.docker.internal WAREHOUSE_PORT=5432 WAREHOUSE_USER=your_user -WAREHOUSE_PASSWORD=your_password +WAREHOUSE_PASSWORD=your_warehouse_password WAREHOUSE_DB=your_database WAREHOUSE_SCHEMA=public @@ -44,7 +44,7 @@ WAREHOUSE_SCHEMA=public # --- Airflow --- AIRFLOW_BASE_URL=http://localhost:8080 AIRFLOW_USERNAME=admin -AIRFLOW_PASSWORD=admin +AIRFLOW_PASSWORD=your_airflow_password # --- Prefect (uncomment if using Prefect) --- # PREFECT_API_URL=http://localhost:4200/api diff --git a/alembic/env.py b/alembic/env.py index ded1230..3cc8a6e 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -50,9 +50,7 @@ def run_migrations_online() -> None: ) with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata - ) + context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() diff --git a/alembic/versions/81c7c80e0f60_add_project_and_apikey_models.py 
b/alembic/versions/81c7c80e0f60_add_project_and_apikey_models.py index 02a1410..60ae941 100644 --- a/alembic/versions/81c7c80e0f60_add_project_and_apikey_models.py +++ b/alembic/versions/81c7c80e0f60_add_project_and_apikey_models.py @@ -5,15 +5,16 @@ Create Date: 2026-04-05 13:43:13.652164 """ + from typing import Sequence, Union -from alembic import op import sqlalchemy as sa -from sqlalchemy.dialects import postgresql + +from alembic import op # revision identifiers, used by Alembic. -revision: str = '81c7c80e0f60' -down_revision: Union[str, Sequence[str], None] = '002_biginteger_numeric_columns' +revision: str = "81c7c80e0f60" +down_revision: Union[str, Sequence[str], None] = "002_biginteger_numeric_columns" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -21,51 +22,57 @@ def upgrade() -> None: """Upgrade schema.""" # ### commands auto generated by Alembic - please adjust! ### - op.create_table('projects', - sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), - sa.Column('name', sa.String(length=100), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('name') + op.create_table( + "projects", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name"), ) - op.create_table('api_keys', - sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), - sa.Column('project_id', sa.Integer(), nullable=False), - sa.Column('hashed_key', sa.String(length=255), nullable=False), - sa.Column('role', sa.String(length=20), nullable=False), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), - sa.Column('is_active', sa.Boolean(), nullable=True), - sa.ForeignKeyConstraint(['project_id'], 
['projects.id'], ), - sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('hashed_key') + op.create_table( + "api_keys", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("hashed_key", sa.String(length=255), nullable=False), + sa.Column("role", sa.String(length=20), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_active", sa.Boolean(), nullable=True), + sa.ForeignKeyConstraint( + ["project_id"], + ["projects.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("hashed_key"), ) # Insert an initial default project and admin key import hashlib import secrets - import os - + # Generate a random 32-character hex key initial_key = secrets.token_hex(16) hashed_key = hashlib.sha256(initial_key.encode()).hexdigest() - + # Insert project op.execute("INSERT INTO projects (name, created_at) VALUES ('default', CURRENT_TIMESTAMP)") # Insert API Key (project 1 is the default since it's autoincrementing and the only row we inserted) - op.execute(f"INSERT INTO api_keys (project_id, hashed_key, role, is_active, created_at) VALUES (1, '{hashed_key}', 'admin', true, CURRENT_TIMESTAMP)") - + op.execute( + f"INSERT INTO api_keys (project_id, hashed_key, role, is_active, created_at) VALUES (1, '{hashed_key}', 'admin', true, CURRENT_TIMESTAMP)" + ) + # Print the key to the console so the user gets it during init/upgrade! - print("\\n" + "="*60) + print("\\n" + "=" * 60) print("πŸš€ OBSERVAKIT RBAC ENABLED") print(f"πŸ”‘ Your initial Admin API Key is: {initial_key}") print("Store this safely! It will not be shown again.") - print("="*60 + "\\n") + print("=" * 60 + "\\n") # ### end Alembic commands ### def downgrade() -> None: """Downgrade schema.""" # ### commands auto generated by Alembic - please adjust! 
### - op.drop_table('api_keys') - op.drop_table('projects') + op.drop_table("api_keys") + op.drop_table("projects") # ### end Alembic commands ### diff --git a/alerts/base.py b/alerts/base.py index 5f72a96..5bfa30e 100644 --- a/alerts/base.py +++ b/alerts/base.py @@ -10,8 +10,14 @@ class AlertDispatcher(ABC): """Abstract base class for alert dispatchers.""" @abstractmethod - def send(self, message: str, subject: str = None, alert_type: str = None, - table_name: str = None, **kwargs) -> bool: + def send( + self, + message: str, + subject: str = None, + alert_type: str = None, + table_name: str = None, + **kwargs, + ) -> bool: """ Send an alert message. Returns True if successful, False otherwise. @@ -23,21 +29,27 @@ def get_alert_dispatcher(channel: str, **kwargs) -> AlertDispatcher: """Factory: return the appropriate alert dispatcher.""" if channel == "slack": from alerts.slack import SlackDispatcher + return SlackDispatcher(**kwargs) elif channel == "email": from alerts.email import EmailDispatcher + return EmailDispatcher(**kwargs) elif channel == "discord": from alerts.discord import DiscordDispatcher + return DiscordDispatcher(**kwargs) elif channel == "webhook": from alerts.webhook import WebhookDispatcher + return WebhookDispatcher(**kwargs) elif channel == "teams": from alerts.teams import TeamsDispatcher + return TeamsDispatcher(**kwargs) elif channel == "pagerduty": from alerts.pagerduty import PagerDutyDispatcher + return PagerDutyDispatcher(**kwargs) else: raise ValueError( @@ -46,7 +58,14 @@ def get_alert_dispatcher(channel: str, **kwargs) -> AlertDispatcher: ) -def dispatch_alert(alert_type: str, message: str, table_name: str = None, subject: str = None, db=None, severity: str = "fail"): +def dispatch_alert( + alert_type: str, + message: str, + table_name: str = None, + subject: str = None, + db=None, + severity: str = "fail", +): """ Dispatch an alert using routing rules from kit.yml. 
Uses load_config() so that ${VAR:-default} env vars are properly expanded. @@ -90,7 +109,9 @@ def dispatch_alert(alert_type: str, message: str, table_name: str = None, subjec kwargs = {k: v for k, v in rule.items() if k not in ["match", "channel"]} try: dispatcher = get_alert_dispatcher(channel, **kwargs) - if dispatcher.send(formatted_message, subject, alert_type=alert_type, table_name=table_name): + if dispatcher.send( + formatted_message, subject, alert_type=alert_type, table_name=table_name + ): dispatched = True used_channel = channel except Exception as e: @@ -101,21 +122,26 @@ def dispatch_alert(alert_type: str, message: str, table_name: str = None, subjec default_channel = config.get("alerts", {}).get("default_channel", "slack") try: dispatcher = get_alert_dispatcher(default_channel) - if dispatcher.send(formatted_message, subject, alert_type=alert_type, table_name=table_name): + if dispatcher.send( + formatted_message, subject, alert_type=alert_type, table_name=table_name + ): dispatched = True used_channel = default_channel except Exception as e: - logging.getLogger(__name__).error(f"Failed to send default alert via {default_channel}: {e}") + logging.getLogger(__name__).error( + f"Failed to send default alert via {default_channel}: {e}" + ) if dispatched and db: from backend.models import AlertLog + try: log = AlertLog( alert_type=alert_type, channel=used_channel, table_name=table_name, message=formatted_message, - success=True + success=True, ) db.add(log) db.commit() @@ -133,12 +159,17 @@ def is_alert_suppressed(db, table_name: str) -> bool: from backend.models import CheckSuppression - suppression = db.query(CheckSuppression).filter( - CheckSuppression.table_name == table_name, - CheckSuppression.suppressed_until >= datetime.now(timezone.utc), - ).first() + suppression = ( + db.query(CheckSuppression) + .filter( + CheckSuppression.table_name == table_name, + CheckSuppression.suppressed_until >= datetime.now(timezone.utc), + ) + .first() + ) if 
suppression: import logging + logging.getLogger(__name__).info( f"Alert suppressed for {table_name} until {suppression.suppressed_until} " f"β€” reason: {suppression.reason}" @@ -156,16 +187,20 @@ def is_alert_deduped(db, table_name: str, alert_type: str, window_minutes: int = from backend.models import AlertLog - recent = db.query(AlertLog).filter( - AlertLog.table_name == table_name, - AlertLog.alert_type == alert_type, - AlertLog.sent_at >= datetime.now(timezone.utc) - timedelta(minutes=window_minutes), - ).first() + recent = ( + db.query(AlertLog) + .filter( + AlertLog.table_name == table_name, + AlertLog.alert_type == alert_type, + AlertLog.sent_at >= datetime.now(timezone.utc) - timedelta(minutes=window_minutes), + ) + .first() + ) if recent: import logging + logging.getLogger(__name__).info( - f"Skipping duplicate {alert_type} alert for {table_name} " - f"(last sent {recent.sent_at})" + f"Skipping duplicate {alert_type} alert for {table_name} (last sent {recent.sent_at})" ) return True return False diff --git a/alerts/discord.py b/alerts/discord.py index 1091155..2189785 100644 --- a/alerts/discord.py +++ b/alerts/discord.py @@ -25,10 +25,10 @@ logger = logging.getLogger(__name__) # Discord message colour codes -COLOUR_OK = 0x57F287 # green -COLOUR_WARN = 0xFEE75C # yellow -COLOUR_FAIL = 0xED4245 # red -COLOUR_INFO = 0x5865F2 # blurple (default) +COLOUR_OK = 0x57F287 # green +COLOUR_WARN = 0xFEE75C # yellow +COLOUR_FAIL = 0xED4245 # red +COLOUR_INFO = 0x5865F2 # blurple (default) # Alert type β†’ colour mapping _ALERT_COLOURS = { @@ -47,7 +47,7 @@ class DiscordDispatcher(AlertDispatcher): def __init__(self, **kwargs): self._webhook_url = os.getenv("DISCORD_WEBHOOK_URL", "") - self._mention = os.getenv("DISCORD_MENTION", "") # e.g. "@here" or "<@&ROLE_ID>" + self._mention = os.getenv("DISCORD_MENTION", "") # e.g. 
"@here" or "<@&ROLE_ID>" def send(self, message: str, subject: str = None, alert_type: str = None, **kwargs) -> bool: if not self._webhook_url: @@ -94,4 +94,5 @@ def send(self, message: str, subject: str = None, alert_type: str = None, **kwar def _utc_now_iso() -> str: from datetime import datetime, timezone + return datetime.now(timezone.utc).isoformat() diff --git a/alerts/pagerduty.py b/alerts/pagerduty.py index 1b7555f..6f176c7 100644 --- a/alerts/pagerduty.py +++ b/alerts/pagerduty.py @@ -48,7 +48,9 @@ class PagerDutyDispatcher(AlertDispatcher): """Sends alerts using the PagerDuty Events API v2.""" def __init__(self, **kwargs): - self._routing_key = kwargs.get("pagerduty_routing_key") or os.getenv("PAGERDUTY_ROUTING_KEY", "") + self._routing_key = kwargs.get("pagerduty_routing_key") or os.getenv( + "PAGERDUTY_ROUTING_KEY", "" + ) # Allow overriding severity map via env var (JSON string) custom_map = os.getenv("PAGERDUTY_SEVERITY_MAP") if custom_map: @@ -103,9 +105,7 @@ def send( ) return True else: - logger.error( - "PagerDuty API returned %d: %s", resp.status_code, resp.text - ) + logger.error("PagerDuty API returned %d: %s", resp.status_code, resp.text) return False except Exception as exc: logger.error("Failed to send PagerDuty alert: %s", exc) diff --git a/alerts/slack.py b/alerts/slack.py index 81d46af..e0f832e 100644 --- a/alerts/slack.py +++ b/alerts/slack.py @@ -25,10 +25,10 @@ logger = logging.getLogger(__name__) _SEVERITY_COLOUR = { - "fail": "#d63031", # red - "warn": "#fdcb6e", # yellow - "info": "#74b9ff", # blue - "ok": "#00b894", # green + "fail": "#d63031", # red + "warn": "#fdcb6e", # yellow + "info": "#74b9ff", # blue + "ok": "#00b894", # green } @@ -50,7 +50,9 @@ def send( blocks: list = None, **kwargs, ) -> bool: - if not self._webhook_url or self._webhook_url.startswith("https://hooks.slack.com/services/YOUR"): + if not self._webhook_url or self._webhook_url.startswith( + "https://hooks.slack.com/services/YOUR" + ): logger.warning("Slack 
webhook URL not configured β€” skipping alert") return False @@ -145,19 +147,24 @@ def _post_with_retry(self, payload: dict, max_attempts: int = 3) -> bool: return True if resp.status_code == 429: - retry_after = int(resp.headers.get("Retry-After", 2 ** attempt)) + retry_after = int(resp.headers.get("Retry-After", 2**attempt)) logger.warning( "Slack rate limited (429) β€” retrying in %ds (attempt %d/%d)", - retry_after, attempt, max_attempts, + retry_after, + attempt, + max_attempts, ) time.sleep(retry_after) continue if resp.status_code >= 500: - backoff = 2 ** attempt + backoff = 2**attempt logger.warning( "Slack 5xx (%d) β€” retrying in %ds (attempt %d/%d)", - resp.status_code, backoff, attempt, max_attempts, + resp.status_code, + backoff, + attempt, + max_attempts, ) time.sleep(backoff) continue @@ -167,10 +174,13 @@ def _post_with_retry(self, payload: dict, max_attempts: int = 3) -> bool: return False except Exception as exc: - backoff = 2 ** attempt + backoff = 2**attempt logger.warning( "Slack request error: %s β€” retrying in %ds (attempt %d/%d)", - exc, backoff, attempt, max_attempts, + exc, + backoff, + attempt, + max_attempts, ) if attempt < max_attempts: time.sleep(backoff) diff --git a/alerts/teams.py b/alerts/teams.py index 2d5018c..56cdfe4 100644 --- a/alerts/teams.py +++ b/alerts/teams.py @@ -72,7 +72,9 @@ def send( { "type": "Action.OpenUrl", "title": "View in ObservaKit", - "url": os.getenv("OBSERVAKIT_DASHBOARD_URL", "http://localhost:8000/ui"), + "url": os.getenv( + "OBSERVAKIT_DASHBOARD_URL", "http://localhost:8000/ui" + ), } ] if os.getenv("OBSERVAKIT_DASHBOARD_URL") diff --git a/alerts/webhook.py b/alerts/webhook.py index c380ec2..c0faaa2 100644 --- a/alerts/webhook.py +++ b/alerts/webhook.py @@ -75,7 +75,7 @@ def __init__(self, **kwargs): self._severity_map = kwargs.get("webhook_severity_map", {}) # Extra headers to pass (e.g. 
Authorization for internal services) self._extra_headers: dict = {} - auth_header = os.getenv("WEBHOOK_AUTH_HEADER", "") # e.g. "Bearer my-token" + auth_header = os.getenv("WEBHOOK_AUTH_HEADER", "") # e.g. "Bearer my-token" if auth_header: self._extra_headers["Authorization"] = auth_header @@ -91,10 +91,7 @@ def send( logger.warning("Webhook URL not configured β€” skipping webhook alert") return False - severity = ( - self._severity_map.get(alert_type) - or _SEVERITY_MAP.get(alert_type, "info") - ) + severity = self._severity_map.get(alert_type) or _SEVERITY_MAP.get(alert_type, "info") payload = { "source": "observakit", diff --git a/backend/auth.py b/backend/auth.py index 3d3b26e..b1799e8 100644 --- a/backend/auth.py +++ b/backend/auth.py @@ -22,7 +22,7 @@ async def verify_api_key( request: Request, api_key: Optional[str] = Security(API_KEY_HEADER), - db: Session = Depends(get_db) + db: Session = Depends(get_db), ): """ Verify the API key: @@ -47,7 +47,9 @@ async def verify_api_key( # Dev mode fallback if not expected_legacy_key and db.query(ApiKey).count() == 0: - logger.warning("No API keys found and OBSERVAKIT_API_KEY is not set. Allowing access (DEV MODE).") + logger.warning( + "No API keys found and OBSERVAKIT_API_KEY is not set. Allowing access (DEV MODE)." 
+ ) request.state.user_role = "super_admin" request.state.project_id = None return None diff --git a/backend/main.py b/backend/main.py index 4990508..083ada9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -45,6 +45,7 @@ async def lifespan(app: FastAPI): except Exception: logger.exception("Failed to run Alembic migrations β€” falling back to create_all.") from backend.models import Base, engine + Base.metadata.create_all(bind=engine) logger.info( "\n" @@ -115,9 +116,21 @@ async def lifespan(app: FastAPI): tags=["FinOps"], dependencies=[Depends(verify_api_key)], ) -app.include_router(profiling.router, prefix="/profiling", tags=["Column Profiling"], dependencies=[Depends(verify_api_key)]) -app.include_router(webhooks.router, prefix="/webhooks", tags=["Webhooks"], dependencies=[Depends(verify_api_key)]) -app.include_router(suppressions.router, prefix="/suppress", tags=["Suppressions"], dependencies=[Depends(verify_api_key)]) +app.include_router( + profiling.router, + prefix="/profiling", + tags=["Column Profiling"], + dependencies=[Depends(verify_api_key)], +) +app.include_router( + webhooks.router, prefix="/webhooks", tags=["Webhooks"], dependencies=[Depends(verify_api_key)] +) +app.include_router( + suppressions.router, + prefix="/suppress", + tags=["Suppressions"], + dependencies=[Depends(verify_api_key)], +) app.include_router( distribution.router, prefix="/distribution", @@ -144,6 +157,7 @@ async def serve_ui(full_path: str): return FileResponse(str(index)) return {"detail": "Dashboard not built. 
Run: make ui-build"} + @app.get("/", tags=["Health"]) async def root(): return { @@ -233,34 +247,38 @@ async def get_status(): # Freshness: latest status per table freshness_subq = ( db.query( - FreshnessRecord.table_name, - sqlfunc.max(FreshnessRecord.checked_at).label("latest") + FreshnessRecord.table_name, sqlfunc.max(FreshnessRecord.checked_at).label("latest") ) .filter(FreshnessRecord.checked_at >= cutoff_24h) .group_by(FreshnessRecord.table_name) .subquery() ) - freshness_latest = db.query(FreshnessRecord).join( - freshness_subq, - (FreshnessRecord.table_name == freshness_subq.c.table_name) & - (FreshnessRecord.checked_at == freshness_subq.c.latest) - ).all() + freshness_latest = ( + db.query(FreshnessRecord) + .join( + freshness_subq, + (FreshnessRecord.table_name == freshness_subq.c.table_name) + & (FreshnessRecord.checked_at == freshness_subq.c.latest), + ) + .all() + ) # Volume: latest anomaly status per table volume_subq = ( - db.query( - VolumeRecord.table_name, - sqlfunc.max(VolumeRecord.recorded_at).label("latest") - ) + db.query(VolumeRecord.table_name, sqlfunc.max(VolumeRecord.recorded_at).label("latest")) .filter(VolumeRecord.recorded_at >= cutoff_24h) .group_by(VolumeRecord.table_name) .subquery() ) - volume_latest = db.query(VolumeRecord).join( - volume_subq, - (VolumeRecord.table_name == volume_subq.c.table_name) & - (VolumeRecord.recorded_at == volume_subq.c.latest) - ).all() + volume_latest = ( + db.query(VolumeRecord) + .join( + volume_subq, + (VolumeRecord.table_name == volume_subq.c.table_name) + & (VolumeRecord.recorded_at == volume_subq.c.latest), + ) + .all() + ) # Quality: pass/fail counts per table in last 24h quality_rows = ( @@ -291,21 +309,48 @@ async def get_status(): tables: dict = {} for f in freshness_latest: - t = tables.setdefault(f.table_name, {"name": f.table_name, "freshness": "ok", "volume": "ok", - "quality": "ok", "schema": "ok", "last_checked": None}) + t = tables.setdefault( + f.table_name, + { + "name": f.table_name, 
+ "freshness": "ok", + "volume": "ok", + "quality": "ok", + "schema": "ok", + "last_checked": None, + }, + ) t["freshness"] = f.status # ok | warn | fail t["last_checked"] = f.checked_at.isoformat() for v in volume_latest: - t = tables.setdefault(v.table_name, {"name": v.table_name, "freshness": "ok", "volume": "ok", - "quality": "ok", "schema": "ok", "last_checked": None}) + t = tables.setdefault( + v.table_name, + { + "name": v.table_name, + "freshness": "ok", + "volume": "ok", + "quality": "ok", + "schema": "ok", + "last_checked": None, + }, + ) t["volume"] = "fail" if v.is_anomaly else "ok" if not t["last_checked"] or v.recorded_at.isoformat() > t["last_checked"]: t["last_checked"] = v.recorded_at.isoformat() for q in quality_rows: - t = tables.setdefault(q.table_name, {"name": q.table_name, "freshness": "ok", "volume": "ok", - "quality": "ok", "schema": "ok", "last_checked": None}) + t = tables.setdefault( + q.table_name, + { + "name": q.table_name, + "freshness": "ok", + "volume": "ok", + "quality": "ok", + "schema": "ok", + "last_checked": None, + }, + ) passed = int(q.passed or 0) total = int(q.total or 0) rate = (passed / total) if total > 0 else 1.0 @@ -315,8 +360,17 @@ async def get_status(): t["last_checked"] = q.latest.isoformat() for s in schema_rows: - t = tables.setdefault(s.table_name, {"name": s.table_name, "freshness": "ok", "volume": "ok", - "quality": "ok", "schema": "ok", "last_checked": None}) + t = tables.setdefault( + s.table_name, + { + "name": s.table_name, + "freshness": "ok", + "volume": "ok", + "quality": "ok", + "schema": "ok", + "last_checked": None, + }, + ) t["schema"] = "fail" if int(s.drifts) > 0 else "ok" if not t["last_checked"] or s.latest.isoformat() > t["last_checked"]: t["last_checked"] = s.latest.isoformat() @@ -340,9 +394,9 @@ def _worst(row): } # Active suppressions - active_suppressions = db.query(CheckSuppression).filter( - CheckSuppression.suppressed_until >= now - ).count() + active_suppressions = ( + 
db.query(CheckSuppression).filter(CheckSuppression.suppressed_until >= now).count() + ) # Last run timestamps per pillar (global) last_freshness = db.query(sqlfunc.max(FreshnessRecord.checked_at)).scalar() @@ -356,7 +410,9 @@ def _worst(row): "summary": summary, "tables": table_list, "pillars": { - "freshness": {"last_checked": last_freshness.isoformat() if last_freshness else None}, + "freshness": { + "last_checked": last_freshness.isoformat() if last_freshness else None + }, "volume": {"last_checked": last_volume.isoformat() if last_volume else None}, "quality": {"last_checked": last_quality.isoformat() if last_quality else None}, "schema": {"last_detected": last_schema.isoformat() if last_schema else None}, @@ -367,7 +423,6 @@ def _worst(row): db.close() - @app.get("/scheduler/jobs", tags=["Scheduler"], dependencies=[Depends(verify_api_key)]) async def scheduler_jobs(): """ diff --git a/backend/models.py b/backend/models.py index 212ce7f..083f158 100644 --- a/backend/models.py +++ b/backend/models.py @@ -51,11 +51,11 @@ def get_db(): db.close() - # ============================================================================= # Models # ============================================================================= + class Project(Base): """A logical grouping of checks, integrations, and alerts.""" @@ -106,7 +106,9 @@ class VolumeRecord(Base): id = Column(Integer, primary_key=True, index=True) table_name = Column(String(255), nullable=False, index=True) dag_id = Column(String(255), nullable=True) - row_count = Column(BigInteger, nullable=False) # BigInteger supports tables with billions of rows + row_count = Column( + BigInteger, nullable=False + ) # BigInteger supports tables with billions of rows rolling_avg = Column(Float, nullable=True) deviation_pct = Column(Float, nullable=True) is_anomaly = Column(Boolean, default=False) @@ -123,7 +125,9 @@ class CheckResult(Base): table_name = Column(String(255), nullable=False, index=True) check_type = Column(String(100), 
nullable=False) # soda | great_expectations | custom_sql passed = Column(Boolean, nullable=False) - metric_value = Column(Numeric(precision=20, scale=6), nullable=True) # Numeric avoids float precision loss + metric_value = Column( + Numeric(precision=20, scale=6), nullable=True + ) # Numeric avoids float precision loss details = Column(Text, nullable=True) executed_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), nullable=False) @@ -236,8 +240,8 @@ class DistributionDrift(Base): column_name = Column(String(255), nullable=False) # null_pct_change | value_share_shift | mean_shift drift_type = Column(String(100), nullable=False) - previous_value = Column(Text, nullable=True) # Human-readable description of old state - current_value = Column(Text, nullable=True) # Human-readable description of new state + previous_value = Column(Text, nullable=True) # Human-readable description of old state + current_value = Column(Text, nullable=True) # Human-readable description of new state change_magnitude = Column(Float, nullable=True) # Fraction (0-1); multiply by 100 for % detected_at = Column(DateTime, default=lambda: datetime.now(timezone.utc), nullable=False) diff --git a/backend/requirements.txt b/backend/requirements.txt index 029c77c..c70c1f0 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,5 +9,6 @@ httpx>=0.27.0 pyyaml>=6.0.1 python-dotenv>=1.0.1 alembic>=1.13.0 +tenacity>=8.2.0 soda-core-postgres>=3.0.0 great-expectations>=0.18.0 diff --git a/backend/routers/checks.py b/backend/routers/checks.py index ede48c4..af96826 100644 --- a/backend/routers/checks.py +++ b/backend/routers/checks.py @@ -74,9 +74,7 @@ def _safe_eval_assertion(expression: str, result_value) -> bool: f"Assertion must be a simple comparison (e.g. 
'result == 0'), got: {expression!r}" ) if len(node.ops) != 1 or len(node.comparators) != 1: - raise ValueError( - f"Only single comparisons are supported, got: {expression!r}" - ) + raise ValueError(f"Only single comparisons are supported, got: {expression!r}") left = node.left op = node.ops[0] @@ -84,9 +82,7 @@ def _safe_eval_assertion(expression: str, result_value) -> bool: # Left side must be the name 'result' if not (isinstance(left, ast.Name) and left.id == "result"): - raise ValueError( - f"Left side of assertion must be 'result', got: {ast.dump(left)!r}" - ) + raise ValueError(f"Left side of assertion must be 'result', got: {ast.dump(left)!r}") # Right side must be a numeric constant (int or float) if isinstance(right, ast.Constant) and isinstance(right.value, (int, float)): @@ -145,10 +141,12 @@ def get_check_trends( """ cutoff = datetime.now(timezone.utc) - timedelta(days=days) - results = db.query(CheckResult).filter( - CheckResult.table_name == table_name, - CheckResult.executed_at >= cutoff - ).order_by(CheckResult.executed_at.asc()).all() + results = ( + db.query(CheckResult) + .filter(CheckResult.table_name == table_name, CheckResult.executed_at >= cutoff) + .order_by(CheckResult.executed_at.asc()) + .all() + ) if not results: return {"table": table_name, "message": "No data for selected period"} @@ -186,7 +184,7 @@ def get_check_trends( "period_days": days, "overall_pass_rate": round(pass_rate, 2), "current_failure_streak": current_streak, - "history": history + "history": history, } @@ -229,7 +227,7 @@ def run_quality_checks(dry_run: bool = False, db: Session = Depends(get_db)): passed=res["passed"], metric_value=res.get("metric_value"), details=res.get("details"), - executed_at=datetime.now(timezone.utc) + executed_at=datetime.now(timezone.utc), ) db.add(record) @@ -267,18 +265,17 @@ def run_quality_checks(dry_run: bool = False, db: Session = Depends(get_db)): table_name=table, check_type="custom_sql", passed=passed, - 
metric_value=float(result_value) if isinstance(result_value, (int, float)) else None, + metric_value=float(result_value) + if isinstance(result_value, (int, float)) + else None, details=f"Query: {query.strip()[:100]}... | Result: {result_value}", - executed_at=datetime.now(timezone.utc) + executed_at=datetime.now(timezone.utc), ) db.add(record) - all_results.append({ - "check_name": name, - "table_name": table, - "passed": passed, - "engine": "custom_sql" - }) + all_results.append( + {"check_name": name, "table_name": table, "passed": passed, "engine": "custom_sql"} + ) if not passed and not dry_run: # Trigger lineage-aware alert @@ -291,7 +288,7 @@ def run_quality_checks(dry_run: bool = False, db: Session = Depends(get_db)): subject=f"❌ Quality Check Failed: {name}", message=f"Table: {table}\nCheck: {name}\nResult: {result_value}\nAssertion: {assertion}{impact_msg}", db=db, - severity="fail" + severity="fail", ) except Exception as e: @@ -304,7 +301,12 @@ def run_quality_checks(dry_run: bool = False, db: Session = Depends(get_db)): if not dry_run: db.commit() - return {"engine": engine_name, "checks_run": len(all_results), "results": all_results, "dry_run": dry_run} + return { + "engine": engine_name, + "checks_run": len(all_results), + "results": all_results, + "dry_run": dry_run, + } def _run_soda_check(check_file: str, connector) -> list[dict]: @@ -332,13 +334,27 @@ def _run_soda_check(check_file: str, connector) -> list[dict]: return _parse_soda_json_output(result.stdout, result.returncode, check_file) except FileNotFoundError: - logger.error("'soda' CLI not found. Install soda-core-postgres with: pip install soda-core-postgres") - return [{"check_name": "soda_init", "table_name": "unknown", "passed": False, - "details": "soda CLI not found β€” install soda-core-postgres"}] + logger.error( + "'soda' CLI not found. 
Install soda-core-postgres with: pip install soda-core-postgres" + ) + return [ + { + "check_name": "soda_init", + "table_name": "unknown", + "passed": False, + "details": "soda CLI not found β€” install soda-core-postgres", + } + ] except subprocess.TimeoutExpired: logger.error(f"Soda scan timed out for {check_file}") - return [{"check_name": "soda_timeout", "table_name": "unknown", "passed": False, - "details": f"Scan timed out after 300s: {check_file}"}] + return [ + { + "check_name": "soda_timeout", + "table_name": "unknown", + "passed": False, + "details": f"Scan timed out after 300s: {check_file}", + } + ] finally: os.unlink(cfg_path) @@ -348,12 +364,22 @@ def _parse_soda_json_output(stdout: str, returncode: int, check_file: str) -> li results = [] try: # Soda prints one JSON object per line in some versions, or a single array - lines = [line.strip() for line in stdout.splitlines() if line.strip().startswith("{") or line.strip().startswith("[")] + lines = [ + line.strip() + for line in stdout.splitlines() + if line.strip().startswith("{") or line.strip().startswith("[") + ] if not lines: # No JSON output β€” treat as failed scan logger.warning(f"No JSON output from soda scan of {check_file} (rc={returncode})") - return [{"check_name": "soda_scan", "table_name": "unknown", - "passed": returncode == 0, "details": "No structured output from soda"}] + return [ + { + "check_name": "soda_scan", + "table_name": "unknown", + "passed": returncode == 0, + "details": "No structured output from soda", + } + ] raw = json.loads("\n".join(lines)) if len(lines) == 1 else json.loads(lines[0]) @@ -361,17 +387,27 @@ def _parse_soda_json_output(stdout: str, returncode: int, check_file: str) -> li checks = raw.get("checks", []) if isinstance(raw, dict) else raw for check in checks: outcome = check.get("outcome", "fail").lower() - results.append({ - "check_name": check.get("name", "unnamed"), - "table_name": check.get("table", "unknown"), - "passed": outcome == "pass", - 
"metric_value": check.get("measured_value"), - "details": check.get("definition", ""), - }) + results.append( + { + "check_name": check.get("name", "unnamed"), + "table_name": check.get("table", "unknown"), + "passed": outcome == "pass", + "metric_value": check.get("measured_value"), + "details": check.get("definition", ""), + } + ) except (json.JSONDecodeError, KeyError) as e: - logger.warning(f"Could not parse Soda JSON output: {e} β€” treating as {'pass' if returncode == 0 else 'fail'}") - results.append({"check_name": "soda_scan", "table_name": "unknown", - "passed": returncode == 0, "details": stdout[:500]}) + logger.warning( + f"Could not parse Soda JSON output: {e} β€” treating as {'pass' if returncode == 0 else 'fail'}" + ) + results.append( + { + "check_name": "soda_scan", + "table_name": "unknown", + "passed": returncode == 0, + "details": stdout[:500], + } + ) return results @@ -386,15 +422,17 @@ def _run_gx_check(check_file: str, connector) -> list[dict]: f"Great Expectations engine selected for {check_file} but is not yet implemented. " "Switch to engine: soda in kit.yml or implement GX execution logic." ) - return [{ - "check_name": "gx_not_implemented", - "table_name": "unknown", - "passed": False, - "details": ( - "Great Expectations execution is not implemented. " - "Use engine: soda in kit.yml or add a GX runner." - ), - }] + return [ + { + "check_name": "gx_not_implemented", + "table_name": "unknown", + "passed": False, + "details": ( + "Great Expectations execution is not implemented. " + "Use engine: soda in kit.yml or add a GX runner." 
+ ), + } + ] @router.post("/volume") @@ -494,7 +532,12 @@ def run_volume_checks(db: Session = Depends(get_db), connector=None): # Alert on anomaly if is_anomaly: _trigger_volume_alert( - table_name, current_count, rolling_avg, deviation, table_cfg.get("alert", "slack"), db + table_name, + current_count, + rolling_avg, + deviation, + table_cfg.get("alert", "slack"), + db, ) except Exception as e: @@ -505,7 +548,9 @@ def run_volume_checks(db: Session = Depends(get_db), connector=None): return {"checked": len(results), "results": results} -def _trigger_volume_alert(table: str, count: int, avg: float, deviation: float, channel: str, db: Session): +def _trigger_volume_alert( + table: str, count: int, avg: float, deviation: float, channel: str, db: Session +): """Dispatch a volume anomaly alert.""" downstream = get_lineage_impact(table) impact_msg = f"\n⚠️ Downstream impact: {', '.join(downstream)}" if downstream else "" @@ -525,7 +570,7 @@ def _trigger_volume_alert(table: str, count: int, avg: float, deviation: float, subject=f"πŸ”΄ Volume Anomaly: {table}", message=message, db=db, - severity="warn" + severity="warn", ) @@ -587,8 +632,12 @@ def run_consistency_checks(connector, db: Session) -> list[dict]: col_b = check_cfg["column_b"] tolerance_pct = float(check_cfg.get("tolerance_pct", 0.0)) - sum_a_rows = connector.execute_query(f"SELECT COALESCE(SUM({col_a}), 0) as total FROM {table_a}") - sum_b_rows = connector.execute_query(f"SELECT COALESCE(SUM({col_b}), 0) as total FROM {table_b}") + sum_a_rows = connector.execute_query( + f"SELECT COALESCE(SUM({col_a}), 0) as total FROM {table_a}" + ) + sum_b_rows = connector.execute_query( + f"SELECT COALESCE(SUM({col_b}), 0) as total FROM {table_b}" + ) sum_a = float(sum_a_rows[0]["total"]) if sum_a_rows else 0.0 sum_b = float(sum_b_rows[0]["total"]) if sum_b_rows else 0.0 max_val = max(abs(sum_a), abs(sum_b), 1.0) @@ -612,7 +661,9 @@ def run_consistency_checks(connector, db: Session) -> list[dict]: 
executed_at=datetime.now(timezone.utc), ) db.add(record) - results.append({"check_name": name, "check_type": check_type, "passed": passed, "details": details}) + results.append( + {"check_name": name, "check_type": check_type, "passed": passed, "details": details} + ) if not passed: dispatch_alert( @@ -621,7 +672,7 @@ def run_consistency_checks(connector, db: Session) -> list[dict]: subject=f"❌ Consistency Check Failed: {name}", message=f"Consistency check failed: {name}\n{details}", db=db, - severity="fail" + severity="fail", ) except Exception as e: diff --git a/backend/routers/contracts.py b/backend/routers/contracts.py index d276192..bfa6805 100644 --- a/backend/routers/contracts.py +++ b/backend/routers/contracts.py @@ -66,6 +66,7 @@ from alerts.base import dispatch_alert from backend.models import ContractValidationResult, get_db from backend.routers.checks import _safe_eval_assertion +from backend.security import is_safe_identifier, is_safe_table_reference logger = logging.getLogger(__name__) @@ -78,6 +79,7 @@ # API Endpoints # --------------------------------------------------------------------------- + @router.post("/validate") def validate_contracts( contract_id: Optional[str] = None, @@ -92,6 +94,7 @@ def validate_contracts( return {"message": f"No contract files found in {CONTRACTS_DIR}"} from connectors.base import get_warehouse_connector + connector = get_warehouse_connector() all_results = [] @@ -116,20 +119,47 @@ def validate_contracts( violations = [] + # --- Validate table reference before use --- + if not is_safe_table_reference(table): + logger.error(f"Contract {cid}: invalid table reference '{table}' β€” skipping") + violations.append( + { + "rule": "schema_fetch", + "passed": False, + "detail": f"Invalid table reference: {table}", + } + ) + continue + # --- Column presence and type checks --- try: live_schema = {col["name"]: col for col in connector.get_schema(table)} except Exception as e: logger.error(f"Could not fetch schema for {table}: 
{e}") - violations.append({ - "rule": "schema_fetch", - "passed": False, - "detail": f"Could not connect to table: {e}", - }) + violations.append( + { + "rule": "schema_fetch", + "passed": False, + "detail": f"Could not connect to table: {e}", + } + ) live_schema = {} for col_spec in contract.get("columns", []): col_name = col_spec["name"] + # Validate column name before use in SQL + if not col_name or not is_safe_identifier(col_name): + logger.error( + f"Contract {cid}: invalid column name '{col_name}' β€” skipping column checks" + ) + violations.append( + { + "rule": f"column_check:{col_name}", + "passed": False, + "detail": f"Invalid column name: {col_name}", + } + ) + continue expected_type = col_spec.get("type", "").lower() nullable = col_spec.get("nullable", True) unique = col_spec.get("unique", False) @@ -139,25 +169,31 @@ def validate_contracts( # Column existence if col_name not in live_schema: - violations.append({ - "rule": f"column_exists:{col_name}", - "passed": False, - "detail": f"Column '{col_name}' is missing from {table}", - }) + violations.append( + { + "rule": f"column_exists:{col_name}", + "passed": False, + "detail": f"Column '{col_name}' is missing from {table}", + } + ) continue live_col = live_schema[col_name] # Type check (case-insensitive prefix match β€” e.g. 
"integer" matches "integer4") - if expected_type and not live_col.get("type", "").lower().startswith(expected_type.replace(" ", "_")): - violations.append({ - "rule": f"column_type:{col_name}", - "passed": False, - "detail": ( - f"Column '{col_name}' expected type '{expected_type}', " - f"got '{live_col.get('type')}'" - ), - }) + if expected_type and not live_col.get("type", "").lower().startswith( + expected_type.replace(" ", "_") + ): + violations.append( + { + "rule": f"column_type:{col_name}", + "passed": False, + "detail": ( + f"Column '{col_name}' expected type '{expected_type}', " + f"got '{live_col.get('type')}'" + ), + } + ) # Nullable constraint if not nullable: @@ -166,13 +202,17 @@ def validate_contracts( f"SELECT COUNT(*) AS cnt FROM {table} WHERE {col_name} IS NULL" ) null_count = int(null_result[0]["cnt"]) if null_result else 0 - violations.append({ - "rule": f"not_null:{col_name}", - "passed": null_count == 0, - "detail": f"{null_count} null values found in '{col_name}' (must be 0)", - }) + violations.append( + { + "rule": f"not_null:{col_name}", + "passed": null_count == 0, + "detail": f"{null_count} null values found in '{col_name}' (must be 0)", + } + ) except Exception as e: - violations.append({"rule": f"not_null:{col_name}", "passed": False, "detail": str(e)}) + violations.append( + {"rule": f"not_null:{col_name}", "passed": False, "detail": str(e)} + ) # Uniqueness constraint if unique: @@ -186,13 +226,17 @@ def validate_contracts( """ ) dup_count = int(dup_result[0]["cnt"]) if dup_result else 0 - violations.append({ - "rule": f"unique:{col_name}", - "passed": dup_count == 0, - "detail": f"{dup_count} duplicate values found in '{col_name}'", - }) + violations.append( + { + "rule": f"unique:{col_name}", + "passed": dup_count == 0, + "detail": f"{dup_count} duplicate values found in '{col_name}'", + } + ) except Exception as e: - violations.append({"rule": f"unique:{col_name}", "passed": False, "detail": str(e)}) + violations.append( + {"rule": 
f"unique:{col_name}", "passed": False, "detail": str(e)} + ) # Allowed values check if allowed_values: @@ -209,16 +253,20 @@ def validate_contracts( """ ) invalid_count = int(invalid_result[0]["cnt"]) if invalid_result else 0 - violations.append({ - "rule": f"allowed_values:{col_name}", - "passed": invalid_count == 0, - "detail": ( - f"{invalid_count} rows have values outside allowed set " - f"{allowed_values} in '{col_name}'" - ), - }) + violations.append( + { + "rule": f"allowed_values:{col_name}", + "passed": invalid_count == 0, + "detail": ( + f"{invalid_count} rows have values outside allowed set " + f"{allowed_values} in '{col_name}'" + ), + } + ) except Exception as e: - violations.append({"rule": f"allowed_values:{col_name}", "passed": False, "detail": str(e)}) + violations.append( + {"rule": f"allowed_values:{col_name}", "passed": False, "detail": str(e)} + ) # Min / max range if min_val is not None: @@ -227,13 +275,17 @@ def validate_contracts( f"SELECT COUNT(*) AS cnt FROM {table} WHERE {col_name} < {min_val}" ) below_count = int(below_result[0]["cnt"]) if below_result else 0 - violations.append({ - "rule": f"min_value:{col_name}", - "passed": below_count == 0, - "detail": f"{below_count} rows have {col_name} < {min_val}", - }) + violations.append( + { + "rule": f"min_value:{col_name}", + "passed": below_count == 0, + "detail": f"{below_count} rows have {col_name} < {min_val}", + } + ) except Exception as e: - violations.append({"rule": f"min_value:{col_name}", "passed": False, "detail": str(e)}) + violations.append( + {"rule": f"min_value:{col_name}", "passed": False, "detail": str(e)} + ) if max_val is not None: try: @@ -241,13 +293,17 @@ def validate_contracts( f"SELECT COUNT(*) AS cnt FROM {table} WHERE {col_name} > {max_val}" ) above_count = int(above_result[0]["cnt"]) if above_result else 0 - violations.append({ - "rule": f"max_value:{col_name}", - "passed": above_count == 0, - "detail": f"{above_count} rows have {col_name} > {max_val}", - }) + 
violations.append( + { + "rule": f"max_value:{col_name}", + "passed": above_count == 0, + "detail": f"{above_count} rows have {col_name} > {max_val}", + } + ) except Exception as e: - violations.append({"rule": f"max_value:{col_name}", "passed": False, "detail": str(e)}) + violations.append( + {"rule": f"max_value:{col_name}", "passed": False, "detail": str(e)} + ) # --- Volume check --- volume_spec = contract.get("volume", {}) @@ -256,19 +312,36 @@ def validate_contracts( count_result = connector.execute_query(f"SELECT COUNT(*) AS cnt FROM {table}") row_count = int(count_result[0]["cnt"]) if count_result else 0 min_rows = int(volume_spec["min_rows"]) - violations.append({ - "rule": "min_rows", - "passed": row_count >= min_rows, - "detail": f"Row count {row_count:,} (required β‰₯ {min_rows:,})", - }) + violations.append( + { + "rule": "min_rows", + "passed": row_count >= min_rows, + "detail": f"Row count {row_count:,} (required β‰₯ {min_rows:,})", + } + ) except Exception as e: violations.append({"rule": "min_rows", "passed": False, "detail": str(e)}) # --- Custom SQL rules --- + # Only allow SELECT queries for custom rules (no DDL, DML, or dangerous statements) + allowed_start_patterns = ("select", "with") # Case-insensitive, lowercase for comparison for rule_spec in contract.get("rules", []): rule_name = rule_spec.get("name", "custom_rule") - sql = rule_spec.get("sql", "") + sql = rule_spec.get("sql", "").strip() assertion = rule_spec.get("assert", "result == 0") + + # Validate SQL starts with SELECT or WITH + sql_lower = sql.lower() + if not any(sql_lower.startswith(pattern) for pattern in allowed_start_patterns): + violations.append( + { + "rule": rule_name, + "passed": False, + "detail": f"Custom SQL rules must start with SELECT or WITH; got: {sql[:50]}", + } + ) + continue + try: rule_result = connector.execute_query(sql) result_value = 0 @@ -280,11 +353,13 @@ def validate_contracts( except Exception: passed = False - violations.append({ - "rule": rule_name, 
- "passed": passed, - "detail": f"SQL returned {result_value}; assertion '{assertion}' β†’ {passed}", - }) + violations.append( + { + "rule": rule_name, + "passed": passed, + "detail": f"SQL returned {result_value}; assertion '{assertion}' β†’ {passed}", + } + ) except Exception as e: violations.append({"rule": rule_name, "passed": False, "detail": str(e)}) @@ -317,18 +392,20 @@ def validate_contracts( f"Failed rules: {', '.join(failed_rules)}" ), db=db, - severity="fail" + severity="fail", ) - all_results.append({ - "contract_id": cid, - "version": contract.get("version"), - "table": table, - "passed": contract_passed, - "total_rules": total, - "passed_rules": passed_count, - "violations": [v for v in violations if not v["passed"]], - }) + all_results.append( + { + "contract_id": cid, + "version": contract.get("version"), + "table": table, + "passed": contract_passed, + "total_rules": total, + "passed_rules": passed_count, + "violations": [v for v in violations if not v["passed"]], + } + ) db.commit() return {"contracts_validated": len(all_results), "results": all_results} @@ -342,9 +419,8 @@ def get_contract_results( db: Session = Depends(get_db), ): """Return recent contract validation results.""" - query = ( - db.query(ContractValidationResult) - .order_by(ContractValidationResult.validated_at.desc()) + query = db.query(ContractValidationResult).order_by( + ContractValidationResult.validated_at.desc() ) if contract_id: query = query.filter(ContractValidationResult.contract_id == contract_id) diff --git a/backend/routers/distribution.py b/backend/routers/distribution.py index f5e8c10..41d3677 100644 --- a/backend/routers/distribution.py +++ b/backend/routers/distribution.py @@ -43,14 +43,15 @@ ["table", "column"], ) -TOP_N = 20 # How many top values to track for categoricals -BUCKET_COUNT = 10 # Histogram buckets for numerics +TOP_N = 20 # How many top values to track for categoricals +BUCKET_COUNT = 10 # Histogram buckets for numerics # 
--------------------------------------------------------------------------- # API Endpoints # --------------------------------------------------------------------------- + @router.post("/snapshot") def take_distribution_snapshot( table_name: Optional[str] = None, @@ -77,6 +78,7 @@ def take_distribution_snapshot( return {"message": "No tables configured for distribution monitoring"} from connectors.base import get_warehouse_connector + connector = get_warehouse_connector() results = [] @@ -152,16 +154,18 @@ def take_distribution_snapshot( f"Current: {drift_result.get('current_value')}" ), db=db, - severity="warn" + severity="warn", ) - results.append({ - "table": tname, - "column": col_name, - "type": col_type, - "snapshot_taken": True, - "drift": drift_result, - }) + results.append( + { + "table": tname, + "column": col_name, + "type": col_type, + "snapshot_taken": True, + "drift": drift_result, + } + ) except Exception as e: logger.error(f"Distribution snapshot failed for {tname}.{col_name}: {e}") @@ -238,6 +242,7 @@ def get_distribution_history( # Private helpers # --------------------------------------------------------------------------- + def _snapshot_categorical(connector, table: str, column: str, top_n: int) -> dict: """ Return top-N value distribution for a categorical column. 
@@ -354,12 +359,14 @@ def _snapshot_numeric(connector, table: str, column: str) -> dict: WHERE {column} >= {low} AND {column} {op} {high} """ ) - histogram.append({ - "bucket": i + 1, - "range_low": round(low, 4), - "range_high": round(high, 4), - "count": int(count_result[0]["cnt"]) if count_result else 0, - }) + histogram.append( + { + "bucket": i + 1, + "range_low": round(low, 4), + "range_high": round(high, 4), + "count": int(count_result[0]["cnt"]) if count_result else 0, + } + ) return { "total_rows": total_rows, diff --git a/backend/routers/finops.py b/backend/routers/finops.py index 68214be..6e65411 100644 --- a/backend/routers/finops.py +++ b/backend/routers/finops.py @@ -18,11 +18,10 @@ router = APIRouter() finops_cost_gauge = Gauge( - "observakit_finops_costs", - "Compute cost (credits/bytes) over the last N days", - ["warehouse"] + "observakit_finops_costs", "Compute cost (credits/bytes) over the last N days", ["warehouse"] ) + @router.post("/poll", dependencies=[Depends(verify_api_key)]) def poll_finops_costs(days: int = 7, db: Session = Depends(get_db)): """ @@ -43,11 +42,7 @@ def poll_finops_costs(days: int = 7, db: Session = Depends(get_db)): cost = connector.get_compute_costs(days=days) finops_cost_gauge.labels(warehouse=warehouse_type).set(cost) - return { - "warehouse": warehouse_type, - "cost_tracked": cost, - "period_days": days - } + return {"warehouse": warehouse_type, "cost_tracked": cost, "period_days": days} except Exception as e: logger.error(f"Failed to poll FinOps costs: {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/routers/freshness.py b/backend/routers/freshness.py index 7a3b36a..687db74 100644 --- a/backend/routers/freshness.py +++ b/backend/routers/freshness.py @@ -173,7 +173,9 @@ def _parse_duration(duration_str: str) -> float: def _trigger_alert(table: str, lag_seconds: float, status: str, channel: str, db: Session): """Dispatch a freshness alert.""" - lag_str = f"{lag_seconds / 3600:.1f} hours" if 
lag_seconds is not None else "unknown (no data found)" + lag_str = ( + f"{lag_seconds / 3600:.1f} hours" if lag_seconds is not None else "unknown (no data found)" + ) message = ( f"Freshness Alert: {table}\n" f" Lag: {lag_str}\n" @@ -187,5 +189,5 @@ def _trigger_alert(table: str, lag_seconds: float, status: str, channel: str, db subject=f"{'πŸ”΄' if status == 'fail' else '🟑'} Freshness: {table} is {status}", message=message, db=db, - severity=status + severity=status, ) diff --git a/backend/routers/profiling.py b/backend/routers/profiling.py index 0d1c1b8..c950df6 100644 --- a/backend/routers/profiling.py +++ b/backend/routers/profiling.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from backend.models import ColumnProfile, get_db +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import get_warehouse_connector logger = logging.getLogger(__name__) @@ -22,6 +23,9 @@ def run_profiling(table_name: str, db: Session = Depends(get_db)): """ Run column-level profiling for a specific table. """ + if not is_safe_table_reference(table_name): + raise HTTPException(status_code=400, detail=f"Invalid table reference: {table_name}") + connector = get_warehouse_connector() schema = connector.get_schema(table_name) @@ -35,9 +39,15 @@ def run_profiling(table_name: str, db: Session = Depends(get_db)): profiles = [] for col in schema: col_name = col["name"] + # Validate column name before using it in SQL + if not is_safe_identifier(col_name): + logger.error(f"Invalid column name '{col_name}' in table {table_name} - skipping") + continue col_type = col["type"].lower() - is_numeric = any(t in col_type for t in ["int", "decimal", "numeric", "float", "real", "double"]) + is_numeric = any( + t in col_type for t in ["int", "decimal", "numeric", "float", "real", "double"] + ) # Use portable SQL instead of PostgreSQL-only FILTER clause. # CAST(... AS CHAR) works on MySQL; CAST(... 
AS VARCHAR) on others; @@ -58,24 +68,30 @@ def run_profiling(table_name: str, db: Session = Depends(get_db)): results = connector.execute_query(stats_query) if results: res = results[0] - null_count = int(res["null_count"] or 0) + null_count = int(res.get("null_count", 0) or 0) profile = ColumnProfile( table_name=table_name, column_name=col_name, null_count=null_count, null_pct=(null_count / row_count) if row_count > 0 else 0, - distinct_count=int(res["distinct_count"]), - min_value=res["min_val"], - max_value=res["max_val"], - mean_value=float(res["mean_val"]) if res["mean_val"] is not None else None, - profiled_at=datetime.now(timezone.utc) + distinct_count=int(res.get("distinct_count", 0)), + min_value=res.get("min_val"), + max_value=res.get("max_val"), + mean_value=float(res.get("mean_val")) + if res.get("mean_val") is not None + else None, + profiled_at=datetime.now(timezone.utc), ) db.add(profile) - profiles.append({ - "column": col_name, - "null_pct": round((null_count / row_count) * 100, 2) if row_count > 0 else 0, - "distinct_count": int(res["distinct_count"]) - }) + profiles.append( + { + "column": col_name, + "null_pct": round((null_count / row_count) * 100, 2) + if row_count > 0 + else 0, + "distinct_count": int(res.get("distinct_count", 0)), + } + ) except Exception as e: logger.error(f"Failed to profile column {col_name} in {table_name}: {e}") @@ -87,17 +103,21 @@ def run_profiling(table_name: str, db: Session = Depends(get_db)): def get_latest_profile(table_name: str, db: Session = Depends(get_db)): """Get the most recent profile for a table.""" # Find the latest profiling run timestamp - latest_run = db.query(ColumnProfile.profiled_at).filter( - ColumnProfile.table_name == table_name - ).order_by(ColumnProfile.profiled_at.desc()).first() + latest_run = ( + db.query(ColumnProfile.profiled_at) + .filter(ColumnProfile.table_name == table_name) + .order_by(ColumnProfile.profiled_at.desc()) + .first() + ) if not latest_run: raise 
HTTPException(status_code=404, detail="No profiles found for this table") - records = db.query(ColumnProfile).filter( - ColumnProfile.table_name == table_name, - ColumnProfile.profiled_at == latest_run[0] - ).all() + records = ( + db.query(ColumnProfile) + .filter(ColumnProfile.table_name == table_name, ColumnProfile.profiled_at == latest_run[0]) + .all() + ) return [ { diff --git a/backend/routers/schema_diff.py b/backend/routers/schema_diff.py index 7846317..a21f014 100644 --- a/backend/routers/schema_diff.py +++ b/backend/routers/schema_diff.py @@ -38,6 +38,7 @@ def take_snapshot(db: Session = Depends(get_db)): for table_name in tables: try: from connectors.base import get_warehouse_connector + connector = get_warehouse_connector() current_columns = connector.get_schema(table_name) @@ -61,12 +62,14 @@ def take_snapshot(db: Session = Depends(get_db)): if prev_snapshot: diffs = _compute_diff(table_name, prev_snapshot.columns_json, current_columns, db) - results.append({ - "table": table_name, - "column_count": len(current_columns), - "changes": len(diffs), - "diffs": diffs, - }) + results.append( + { + "table": table_name, + "column_count": len(current_columns), + "changes": len(diffs), + "diffs": diffs, + } + ) # Alert if changes detected if diffs and schema_config.get("alert"): @@ -154,11 +157,13 @@ def _compute_diff(table_name: str, old_columns: list, new_columns: list, db: Ses new_value=None, ) db.add(diff) - diffs.append({ - "change_type": "removed", - "column": name, - "old_type": old_map[name].get("type"), - }) + diffs.append( + { + "change_type": "removed", + "column": name, + "old_type": old_map[name].get("type"), + } + ) # Columns added for name in new_map: @@ -171,11 +176,13 @@ def _compute_diff(table_name: str, old_columns: list, new_columns: list, db: Ses new_value=new_map[name].get("type", ""), ) db.add(diff) - diffs.append({ - "change_type": "added", - "column": name, - "new_type": new_map[name].get("type"), - }) + diffs.append( + { + "change_type": 
"added", + "column": name, + "new_type": new_map[name].get("type"), + } + ) # Type changes for name in old_map: @@ -191,12 +198,14 @@ def _compute_diff(table_name: str, old_columns: list, new_columns: list, db: Ses new_value=new_type, ) db.add(diff) - diffs.append({ - "change_type": "type_changed", - "column": name, - "old_type": old_type, - "new_type": new_type, - }) + diffs.append( + { + "change_type": "type_changed", + "column": name, + "old_type": old_type, + "new_type": new_type, + } + ) return diffs @@ -225,5 +234,5 @@ def _trigger_schema_alert(table: str, diffs: list, channel: str, db: Session): subject=f"⚠️ Schema Drift: {table}", message=message, db=db, - severity="warn" + severity="warn", ) diff --git a/backend/routers/webhooks.py b/backend/routers/webhooks.py index 222d7a0..baacca7 100644 --- a/backend/routers/webhooks.py +++ b/backend/routers/webhooks.py @@ -30,7 +30,10 @@ def trigger_test_alert(): if success: return {"status": "success", "message": "Test alert dispatched."} else: - return {"status": "error", "message": "Failed to dispatch test alert. Check SLACK_WEBHOOK_URL."} + return { + "status": "error", + "message": "Failed to dispatch test alert. 
Check SLACK_WEBHOOK_URL.", + } @router.get("/airflow") @@ -178,4 +181,3 @@ def _parse_datetime(dt_str: Optional[str]) -> Optional[datetime]: return datetime.fromisoformat(dt_str.replace("Z", "+00:00")) except (ValueError, AttributeError): return None - diff --git a/backend/scheduler.py b/backend/scheduler.py index 0067066..097b56f 100644 --- a/backend/scheduler.py +++ b/backend/scheduler.py @@ -61,7 +61,9 @@ def _advisory_lock(db, job_id: str): {"key": lock_key}, ).scalar() if not result: - logger.debug("Advisory lock busy for job=%s β€” skipping (another replica is running it)", job_id) + logger.debug( + "Advisory lock busy for job=%s β€” skipping (another replica is running it)", job_id + ) yield False return yield True @@ -86,6 +88,7 @@ def _run_job(pillar: str, job_fn, job_id: str | None = None): _job_id = job_id or pillar from backend.models import SessionLocal + db = SessionLocal() try: with _advisory_lock(db, _job_id) as acquired: @@ -94,7 +97,8 @@ def _run_job(pillar: str, job_fn, job_id: str | None = None): logger.info( '{"event":"job_start","run_id":"%s","pillar":"%s"}', - run_id, pillar, + run_id, + pillar, ) t0 = time.monotonic() try: @@ -104,13 +108,18 @@ def _run_job(pillar: str, job_fn, job_id: str | None = None): status = "error" logger.error( '{"event":"job_error","run_id":"%s","pillar":"%s","error":"%s"}', - run_id, pillar, exc, + run_id, + pillar, + exc, ) finally: duration_ms = int((time.monotonic() - t0) * 1000) logger.info( '{"event":"job_end","run_id":"%s","pillar":"%s","duration_ms":%d,"status":"%s"}', - run_id, pillar, duration_ms, status, + run_id, + pillar, + duration_ms, + status, ) finally: db.close() @@ -120,8 +129,10 @@ def _run_job(pillar: str, job_fn, job_id: str | None = None): # Job implementations β€” each receives an open DB session # --------------------------------------------------------------------------- + def _run_freshness_checks(): """Trigger freshness checks for all configured tables.""" + def _job(db): from 
backend.routers.freshness import poll_freshness from connectors.base import get_warehouse_connector @@ -137,6 +148,7 @@ def _job(db): def _run_volume_checks(): """Trigger volume anomaly checks.""" + def _job(db): from backend.routers.checks import run_volume_checks from connectors.base import get_warehouse_connector @@ -152,8 +164,10 @@ def _job(db): def _run_schema_checks(): """Trigger schema drift detection.""" + def _job(db): from backend.routers.schema_diff import take_snapshot + take_snapshot(db=db) _run_job("schema", _job) @@ -161,8 +175,10 @@ def _job(db): def _run_quality_checks(): """Trigger quality checks.""" + def _job(db): from backend.routers.checks import run_quality_checks + run_quality_checks(db=db) _run_job("quality", _job) @@ -170,8 +186,10 @@ def _job(db): def _run_finops_checks(): """Trigger FinOps cost checks.""" + def _job(db): from backend.routers.finops import poll_finops_costs + # Defaulting to 7 days for the scheduled check poll_finops_costs(days=7, db=db) @@ -180,8 +198,10 @@ def _job(db): def _run_dbt_watcher(): """Poll dbt project's target/run_results.json and ingest if newer than last seen.""" + def _job(db): from dbt_integration.watcher import poll_dbt_artifacts + result = poll_dbt_artifacts(db=db) if result["status"] == "error": raise RuntimeError(result.get("error", "unknown dbt watcher error")) @@ -194,6 +214,7 @@ def _job(db): # Scheduler lifecycle # --------------------------------------------------------------------------- + def start_scheduler(): """Start the APScheduler background scheduler.""" global _scheduler @@ -294,10 +315,12 @@ def get_scheduler_jobs() -> list[dict]: jobs = [] for job in _scheduler.get_jobs(): next_run = job.next_run_time - jobs.append({ - "id": job.id, - "name": job.name, - "next_run": next_run.isoformat() if next_run else None, - "trigger": str(job.trigger), - }) + jobs.append( + { + "id": job.id, + "name": job.name, + "next_run": next_run.isoformat() if next_run else None, + "trigger": 
str(job.trigger), + } + ) return jobs diff --git a/backend/security.py b/backend/security.py new file mode 100644 index 0000000..17fe67d --- /dev/null +++ b/backend/security.py @@ -0,0 +1,418 @@ +""" +Security utilities for ObservaKit. + +Provides safe identifier validation and SQL injection prevention. +""" + +import re + +# Maximum length for table/column identifiers to prevent abuse +MAX_IDENTIFIER_LENGTH = 64 + +# List of reserved SQL keywords that cannot be used as identifiers +RESERVED_KEYWORDS = { + "SELECT", + "FROM", + "WHERE", + "INSERT", + "UPDATE", + "DELETE", + "DROP", + "CREATE", + "ALTER", + "TRUNCATE", + "UNION", + "JOIN", + "INNER", + "LEFT", + "RIGHT", + "FULL", + "ON", + "AND", + "OR", + "NOT", + "IN", + "LIKE", + "BETWEEN", + "EXISTS", + "DISTINCT", + "GROUP", + "ORDER", + "HAVING", + "LIMIT", + "OFFSET", + "VALUES", + "INTO", + "VALUES", + "SET", + "CASE", + "WHEN", + "ELSE", + "END", + "AS", + "IS", + "NULL", + "TRUE", + "FALSE", + "COUNT", + "SUM", + "AVG", + "MAX", + "MIN", + "CAST", + "CONVERT", + "COALESCE", + "NULLIF", + "DATE", + "TIME", + "TIMESTAMP", + "INTERVAL", + "YEAR", + "MONTH", + "DAY", + "HOUR", + "MINUTE", + "SECOND", + "WITH", + "RECURSIVE", + "VIEW", + "TABLE", + "COLUMN", + "INDEX", + "SEQUENCE", + "TRIGGER", + "FUNCTION", + "PROCEDURE", + "DATABASE", + "SCHEMA", + "ROLE", + "GRANT", + "REVOKE", + "COMMIT", + "ROLLBACK", + "SAVEPOINT", + "TRANSACTION", + "BEGIN", + "WORK", + "LOCK", + "TABLES", + "SYNONYM", + "TYPE", + "DOMAIN", + "CAST", + "ALL", + "ANY", + "SOME", + "EXCEPT", + "INTERSECT", + "MINUS", + "MAX", + "MIN", + "SUM", + "AVG", + "COUNT", + "TOP", + "FETCH", + "NEXT", + "FIRST", + "ONLY", + "SPLIT_PART", + "ARRAY", + "JSON", + "JSONB", + "XML", + "TEXT", + "VARCHAR", + "CHAR", + "INTEGER", + "BIGINT", + "SMALLINT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DOUBLE", + "BOOLEAN", + "DATE", + "DATETIME", + "TIMESTAMP", + "TIME", + "INTERVAL", + "YEAR", + "MONTH", + "DAY", + "HOUR", + "MINUTE", + "SECOND", + 
"AUTO_INCREMENT", + "DEFAULT", + "PRIMARY", + "KEY", + "FOREIGN", + "REFERENCE", + "CHECK", + "CONSTRAINT", + "UNIQUE", + "NOT", + "NULL", + "ASC", + "DESC", + "NULLS", + "FIRST", + "LAST", + "EXCLUDE", + "PARTITION", + "OVER", + "RANGE", + "ROWS", + "GROUPS", + "UNBOUNDED", + "PRECEDING", + "FOLLOWING", + "EXTRACT", + "DATE_PART", + "EXTRACT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "SESSION_USER", + "SYSTEM_USER", + "LOCALTIME", + "LOCALTIMESTAMP", + "VERSION", + "REPLACE", + "REPEAT", + "REPLACE", + "IF", + "THEN", + "ELSE", + "ENDIF", + "ENDIF", + "LOOP", + "ENDLOOP", + "WHILE", + "ENDWHILE", + "FOR", + "ENDFOR", + "RETURN", + "OUTER", + "CROSS", + "NATURAL", + "USING", + "NATURAL", + "LEFT", + "RIGHT", + "FULL", + "INNER", + "OUTER", + "CROSS", + "APPLY", + "PIVOT", + "UNPIVOT", + "LATERAL", + "TABLESAMPLE", + "BERNOULLI", + "SYSTEM", + "PERCENT", + "REPEATABLE", + "ONLY", + "DEFAULT", + "CASCADE", + "LOCAL", + "SESSION", + "TRANSACTION", + "CONCURRENTLY", + "DEFERRABLE", + "INITIALLY", + "IMMEDIATE", + "DEFERRED", + "SET", + "LOCAL", + "GLOBAL", + "TEMP", + "TEMPORARY", + "UNLOGGED", + "EXTERNAL", + "CATALOG", + "PRESERVE", + "RESTART", + "CONTINUE", + "NESTED", + "LEVEL", + "READ", + "WRITE", + "COMMIT", + "ROLLBACK", + "SAVEPOINT", + "CONNECT", + "TERMINATE", + "PURGE", + "RELEASE", + "PREPARE", + "EXECUTE", + "DEALLOCATE", + "DESCRIBE", + "EXPLAIN", + "ANALYZE", + "VERBOSE", + "COSTS", + "BUFFERS", + "TIMINGS", + "FORMAT", + "TEXT", + "XML", + "JSON", + "YAML", + "TREE", +} + + +def is_safe_identifier(name: str) -> bool: + """ + Validate that a table or column name is a safe SQL identifier. 
+ + Checks: + - Length is within bounds + - Only contains alphanumeric characters and underscores + - Does not contain SQL keywords orReserved words + - Does not contain path traversal or special characters + + Args: + name: The identifier to validate + + Returns: + True if the identifier is safe, False otherwise + """ + if not name or not isinstance(name, str): + return False + + # Trim whitespace + name = name.strip() + + # Check length + if len(name) == 0 or len(name) > MAX_IDENTIFIER_LENGTH: + return False + + # Check for path traversal or special characters + if ".." in name or "/" in name or "\\" in name or ";" in name or "'" in name: + return False + + # Check for SQL injection patterns + if re.search(r"[\r\n]", name): + return False + + # Only allow alphanumeric and underscore + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name): + return False + + # Check against reserved keywords (case-insensitive) + if name.upper() in RESERVED_KEYWORDS: + return False + + return True + + +def is_safe_table_reference(table: str) -> bool: + """ + Validate a table reference (may include catalog.schema.table or schema.table). + + Args: + table: The table reference string + + Returns: + True if all identifier parts are safe, False otherwise + """ + if not table or not isinstance(table, str): + return False + + table = table.strip() + + # Split on dots to handle multi-part names + parts = table.split(".") + + for part in parts: + part = part.strip() + if not is_safe_identifier(part): + return False + + return True + + +def get_qualified_table_name( + table: str, default_catalog: str = None, default_schema: str = None +) -> str: + """ + Validate and normalize a table reference, adding catalog/schema if needed. 
+ + Args: + table: The table reference (may be just table, or schema.table, or catalog.schema.table) + default_catalog: Default catalog if not specified + default_schema: Default schema if not specified + + Returns: + The fully qualified table name if valid + + Raises: + ValueError: If the table reference is invalid + """ + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") + + parts = table.split(".") + + if len(parts) == 1: + # Just table name - add default schema + if default_schema: + return f"{default_schema}.{parts[0]}" + return table + elif len(parts) == 2: + # schema.table or catalog.schema + if default_catalog: + # Assume first part is schema if we have a default catalog + return f"{default_catalog}.{parts[0]}.{parts[1]}" + return table + elif len(parts) == 3: + # catalog.schema.table - fully qualified + return table + else: + raise ValueError(f"Invalid table reference format: {table}") + + +def safe_quote_identifier(name: str) -> str: + """ + Safely quote a SQL identifier for use in a query. + + Note: This should only be used AFTER validation with is_safe_identifier(). + + For PostgreSQL: uses double quotes + For other databases, this may need adjustment + + Args: + name: The identifier to quote + + Returns: + The quoted identifier + """ + if not is_safe_identifier(name): + raise ValueError(f"Cannot quote unsafe identifier: {name}") + + # Escape double quotes within the identifier + escaped = name.replace('"', '""') + return f'"{escaped}"' + + +def safe_quote_table(table: str) -> str: + """ + Safely quote a table name for use in a query. 
+ + Args: + table: The table reference + + Returns: + The fully qualified and quoted table name + """ + if not is_safe_table_reference(table): + raise ValueError(f"Cannot quote unsafe table reference: {table}") + + parts = table.split(".") + return ".".join(safe_quote_identifier(p) for p in parts) diff --git a/cli/main.py b/cli/main.py index 365740a..a5e5863 100644 --- a/cli/main.py +++ b/cli/main.py @@ -28,6 +28,7 @@ # Helpers # --------------------------------------------------------------------------- + def _api_url() -> str: return os.getenv("OBSERVAKIT_API_URL", "http://localhost:8000") @@ -76,6 +77,7 @@ def _status_icon(s: str) -> str: # Sub-commands # --------------------------------------------------------------------------- + def cmd_status(args) -> int: """Print a health summary of all monitored tables.""" output_json = getattr(args, "output", None) == "json" @@ -92,9 +94,7 @@ def cmd_status(args) -> int: print(json.dumps(data, indent=2)) tables = data.get("tables", []) any_fail = any( - t.get(p) == "fail" - for t in tables - for p in ("freshness", "volume", "quality", "schema") + t.get(p) == "fail" for t in tables for p in ("freshness", "volume", "quality", "schema") ) return 1 if any_fail else 0 @@ -102,7 +102,9 @@ def cmd_status(args) -> int: tables = data.get("tables", []) print(f"\nπŸ”­ ObservaKit Status (window: {data.get('window_hours', 24)}h)") - print(f" Healthy: {summary.get('healthy', 0)} Warn: {summary.get('warn', 0)} Fail: {summary.get('fail', 0)}\n") + print( + f" Healthy: {summary.get('healthy', 0)} Warn: {summary.get('warn', 0)} Fail: {summary.get('fail', 0)}\n" + ) if not tables: print(" No monitored tables found in the last 24 hours.") @@ -231,6 +233,7 @@ def cmd_validate_config(args) -> int: # 1. 
Load and parse kit.yml try: from config.loader import load_config + config = load_config(config_path) except FileNotFoundError: print(f"❌ Config file not found: {config_path}") @@ -266,7 +269,9 @@ def cmd_validate_config(args) -> int: default_channel = alerts_cfg.get("default_channel") supported_channels = {"slack", "email", "discord", "webhook", "teams", "pagerduty"} if default_channel and default_channel not in supported_channels: - errors.append(f"alerts.default_channel '{default_channel}' is not supported (supported: {', '.join(sorted(supported_channels))})") + errors.append( + f"alerts.default_channel '{default_channel}' is not supported (supported: {', '.join(sorted(supported_channels))})" + ) for i, rule in enumerate(alerts_cfg.get("routing", [])): if not rule.get("channel"): @@ -278,6 +283,7 @@ def cmd_validate_config(args) -> int: import os as _os import yaml + contract_files = _glob.glob(f"{contracts_dir}/*.yml") if _os.path.isdir(contracts_dir) else [] contract_errors = 0 for cf in contract_files: @@ -303,7 +309,9 @@ def cmd_validate_config(args) -> int: sections = list(config.keys()) print(f"βœ… Config valid β€” sections: {', '.join(sections)}") if contract_files: - print(f" Contracts validated: {len(contract_files) - contract_errors}/{len(contract_files)}") + print( + f" Contracts validated: {len(contract_files) - contract_errors}/{len(contract_files)}" + ) return 0 @@ -328,7 +336,9 @@ def cmd_diff(args) -> int: for d in diffs: old_v = d.get("old_value", "") or "" new_v = d.get("new_value", "") or "" - print(f" {d.get('table_name', ''):<35} {d.get('change_type', ''):<18} {d.get('column_name', ''):<25} {old_v:<20} {new_v}") + print( + f" {d.get('table_name', ''):<35} {d.get('change_type', ''):<18} {d.get('column_name', ''):<25} {old_v:<20} {new_v}" + ) print(f"\n {len(diffs)} change(s) detected.") return 1 # exit 1 when drift found β€” useful for CI gates @@ -351,7 +361,9 @@ def cmd_init(args) -> int: print("\nQ: Do you have a Slack Webhook URL for 
alerts? (Leave blank to skip)") slack_url = input("> ").strip() - print("\nQ: What is the main table you want to monitor for data freshness? (e.g. public.orders)") + print( + "\nQ: What is the main table you want to monitor for data freshness? (e.g. public.orders)" + ) table_name = input("> [public.orders]: ").strip() or "public.orders" yaml_content = f"""# ObservaKit Configuration @@ -398,34 +410,38 @@ def cmd_init(args) -> int: # Entry point # --------------------------------------------------------------------------- + def main(): parser = argparse.ArgumentParser( prog="observakit", description="ObservaKit CLI β€” data observability for small teams", ) parser.add_argument( - "--url", default=None, - help="ObservaKit API URL (default: $OBSERVAKIT_API_URL or http://localhost:8000)" - ) - parser.add_argument( - "--api-key", default=None, - help="API key (default: $OBSERVAKIT_API_KEY)" + "--url", + default=None, + help="ObservaKit API URL (default: $OBSERVAKIT_API_URL or http://localhost:8000)", ) + parser.add_argument("--api-key", default=None, help="API key (default: $OBSERVAKIT_API_KEY)") parser.add_argument( - "--config", default=None, - help="Path to kit.yml (default: $OBSERVAKIT_CONFIG or config/kit.yml)" + "--config", + default=None, + help="Path to kit.yml (default: $OBSERVAKIT_CONFIG or config/kit.yml)", ) sub = parser.add_subparsers(dest="command", required=True) # status p_status = sub.add_parser("status", help="Print health status of all monitored tables") - p_status.add_argument("--output", choices=["json"], default=None, help="Output format (json for scripting/CI)") + p_status.add_argument( + "--output", choices=["json"], default=None, help="Output format (json for scripting/CI)" + ) # check p_check = sub.add_parser("check", help="Run quality checks") p_check.add_argument("--dry-run", action="store_true", help="Preview without writing results") - p_check.add_argument("--output", choices=["json"], default=None, help="Output format (json for 
scripting/CI)") + p_check.add_argument( + "--output", choices=["json"], default=None, help="Output format (json for scripting/CI)" + ) # profile p_profile = sub.add_parser("profile", help="Run column profiling for a table") @@ -434,7 +450,9 @@ def main(): # suppress p_suppress = sub.add_parser("suppress", help="Suppress alerts for a table") p_suppress.add_argument("table", help="Table name") - p_suppress.add_argument("--minutes", type=int, default=60, help="Duration in minutes (default: 60)") + p_suppress.add_argument( + "--minutes", type=int, default=60, help="Duration in minutes (default: 60)" + ) p_suppress.add_argument("--reason", default=None, help="Reason for suppression") # init @@ -444,9 +462,15 @@ def main(): sub.add_parser("test-alert", help="Fire a test alert to configured channels") # validate-config - p_validate = sub.add_parser("validate-config", help="Dry-run parse kit.yml without connecting to warehouse") - p_validate.add_argument("--config", dest="config", default=None, - help="Path to kit.yml to validate (overrides --config at top level)") + p_validate = sub.add_parser( + "validate-config", help="Dry-run parse kit.yml without connecting to warehouse" + ) + p_validate.add_argument( + "--config", + dest="config", + default=None, + help="Path to kit.yml to validate (overrides --config at top level)", + ) # diff p_diff = sub.add_parser("diff", help="Show schema changes vs last saved snapshot") diff --git a/config/loader.py b/config/loader.py index ff2cfaf..670e5cf 100644 --- a/config/loader.py +++ b/config/loader.py @@ -16,10 +16,12 @@ def _expand_env_vars(value: Any) -> Any: """Recursively expand ${VAR:-default} patterns in strings.""" if isinstance(value, str): + def replacer(match: re.Match) -> str: var_name = match.group(1) default_val = match.group(2) if match.group(2) is not None else "" return os.getenv(var_name, default_val) + return _ENV_VAR_PATTERN.sub(replacer, value) elif isinstance(value, dict): return {k: _expand_env_vars(v) for k, v in 
value.items()} @@ -37,6 +39,11 @@ def load_config(path: str = "config/kit.yml") -> dict: slack_url = config["alerts"]["slack"]["webhook_url"] # β†’ reads SLACK_WEBHOOK_URL from environment """ - with open(path, "r") as f: - raw = yaml.safe_load(f) + try: + with open(path, "r") as f: + raw = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Config file not found: {path}") + except yaml.YAMLError as e: + raise ValueError(f"Invalid YAML in config file {path}: {e}") return _expand_env_vars(raw) diff --git a/connectors/base.py b/connectors/base.py index ed4fa86..1fe4234 100644 --- a/connectors/base.py +++ b/connectors/base.py @@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) + def resilient_query(): """ Decorator for adding exponential backoff retries to warehouse queries. @@ -24,7 +25,7 @@ def resilient_query(): reraise=True, before_sleep=lambda retry_state: logger.warning( f"Transient warehouse error encountered. Retrying (attempt {retry_state.attempt_number})..." 
- ) + ), ) @@ -114,27 +115,35 @@ def get_warehouse_connector() -> WarehouseConnector: if warehouse_type == "postgres": from connectors.postgres import PostgresConnector + return PostgresConnector() elif warehouse_type == "bigquery": from connectors.bigquery import BigQueryConnector + return BigQueryConnector() elif warehouse_type == "snowflake": from connectors.snowflake import SnowflakeConnector + return SnowflakeConnector() elif warehouse_type in ("mysql", "mariadb"): from connectors.mysql import MySQLConnector + return MySQLConnector() elif warehouse_type == "redshift": from connectors.redshift import RedshiftConnector + return RedshiftConnector() elif warehouse_type == "duckdb": from connectors.duckdb import DuckDBConnector + return DuckDBConnector() elif warehouse_type == "databricks": from connectors.databricks import DatabricksConnector + return DatabricksConnector() elif warehouse_type == "trino": from connectors.trino import TrinoConnector + return TrinoConnector() else: raise ValueError( @@ -157,9 +166,11 @@ def get_orchestrator_connector() -> OrchestratorConnector: if orch_type == "airflow": from connectors.airflow import AirflowConnector + return AirflowConnector() elif orch_type == "prefect": from connectors.prefect import PrefectConnector + return PrefectConnector() else: raise ValueError(f"Unsupported orchestrator type: {orch_type}") diff --git a/connectors/bigquery.py b/connectors/bigquery.py index c368bd5..481299f 100644 --- a/connectors/bigquery.py +++ b/connectors/bigquery.py @@ -8,6 +8,7 @@ from datetime import datetime from typing import Optional +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger = logging.getLogger(__name__) @@ -50,8 +51,11 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: """Get the max value of a timestamp column.""" + if not is_safe_table_reference(table) or not 
is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") client = self.connect() full_table = f"{self._project}.{self._dataset}.{table.split('.')[-1]}" + # Backtick-escape the column name (table is validated by is_safe_table_reference) query = f"SELECT MAX(`{column}`) as max_ts FROM `{full_table}`" try: @@ -65,6 +69,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: """Get the current row count of a table.""" + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") client = self.connect() full_table = f"{self._project}.{self._dataset}.{table.split('.')[-1]}" @@ -82,8 +88,14 @@ def get_row_count(self, table: str) -> int: @resilient_query() def get_schema(self, table: str) -> list[dict]: """Get schema from INFORMATION_SCHEMA.COLUMNS.""" + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") client = self.connect() table_name = table.split(".")[-1] + # Validate table name (already done by is_safe_table_reference on full reference) + # But also validate the extracted table name + if not is_safe_identifier(table_name): + raise ValueError(f"Invalid table name: {table_name}") query = f""" SELECT column_name AS name, @@ -125,6 +137,8 @@ def execute_query(self, query: str, params: dict = None) -> list[dict]: @resilient_query() def get_compute_costs(self, days: int = 7) -> float: """Get total bytes billed over the last N days from INFORMATION_SCHEMA.JOBS.""" + # Ensure days is an integer to prevent SQL injection + days = int(days) client = self.connect() # Note: region must be specified; defaulting to region-us for demo purposes. 
query = f""" @@ -150,7 +164,7 @@ def get_soda_config(self) -> dict: "project_id": self._project, "dataset": self._dataset, "account_info_json_path": self._credentials_path, - } + }, } } diff --git a/connectors/databricks.py b/connectors/databricks.py index 9d61a15..9169d0c 100644 --- a/connectors/databricks.py +++ b/connectors/databricks.py @@ -20,6 +20,7 @@ from datetime import datetime from typing import Optional +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger = logging.getLogger(__name__) @@ -88,6 +89,8 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: + if not is_safe_table_reference(table) or not is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") conn = self.connect() try: with conn.cursor() as cur: @@ -105,6 +108,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: with conn.cursor() as cur: @@ -149,6 +154,9 @@ def get_schema(self, table: str) -> list[dict]: if not rows: # Fallback: DESCRIBE TABLE (works for older metastore tables) + # Validate the table reference before using it in DESCRIBE + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") with conn.cursor() as cur: cur.execute(f"DESCRIBE TABLE {table}") rows = cur.fetchall() diff --git a/connectors/duckdb.py b/connectors/duckdb.py index 39a75a0..41d720f 100644 --- a/connectors/duckdb.py +++ b/connectors/duckdb.py @@ -18,6 +18,7 @@ from datetime import datetime from typing import Optional +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger 
= logging.getLogger(__name__) @@ -62,6 +63,8 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: + if not is_safe_table_reference(table) or not is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") conn = self.connect() try: result = conn.execute(f"SELECT MAX({column}) FROM {table}").fetchone() @@ -78,6 +81,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: result = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() @@ -92,6 +97,8 @@ def get_schema(self, table: str) -> list[dict]: Return column metadata for a DuckDB table using PRAGMA table_info(). Falls back to information_schema when the table is not a base table. """ + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: rows = conn.execute(f"PRAGMA table_info('{table}')").fetchall() @@ -108,6 +115,9 @@ def get_schema(self, table: str) -> list[dict]: ] # Fallback: information_schema (works for views / virtual tables) schema, tbl = ("main", table) if "." not in table else table.split(".", 1) + # Validate schema and table names + if not is_safe_identifier(schema) or not is_safe_identifier(tbl): + raise ValueError(f"Invalid identifier in table reference: {table}") rows = conn.execute( """ SELECT column_name, data_type, is_nullable, ordinal_position diff --git a/connectors/mysql.py b/connectors/mysql.py index 5dcc1c7..83851f6 100644 --- a/connectors/mysql.py +++ b/connectors/mysql.py @@ -41,9 +41,7 @@ def connect(self): import pymysql import pymysql.cursors except ImportError: - raise RuntimeError( - "PyMySQL is not installed. 
Run: pip install 'observakit[mysql]'" - ) + raise RuntimeError("PyMySQL is not installed. Run: pip install 'observakit[mysql]'") if self._conn is None or not self._conn.open: self._conn = pymysql.connect(**self._config) diff --git a/connectors/postgres.py b/connectors/postgres.py index 5b8cbc3..f405887 100644 --- a/connectors/postgres.py +++ b/connectors/postgres.py @@ -11,6 +11,7 @@ import psycopg2 import psycopg2.extras +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger = logging.getLogger(__name__) @@ -42,13 +43,12 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: - """Get the max value of a timestamp column.""" + if not is_safe_table_reference(table) or not is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") conn = self.connect() try: with conn.cursor() as cur: - cur.execute( - f"SELECT MAX({column}) FROM {table}" - ) + cur.execute(f"SELECT MAX({column}) FROM {table}") result = cur.fetchone() return result[0] if result and result[0] else None except Exception as e: @@ -58,7 +58,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: - """Get the current row count of a table.""" + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: with conn.cursor() as cur: @@ -146,4 +147,3 @@ def get_gx_config(self) -> dict: }, }, } - diff --git a/connectors/redshift.py b/connectors/redshift.py index c90a573..ce3136d 100644 --- a/connectors/redshift.py +++ b/connectors/redshift.py @@ -16,6 +16,7 @@ from datetime import datetime from typing import Optional +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger = 
logging.getLogger(__name__) @@ -73,6 +74,8 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: + if not is_safe_table_reference(table) or not is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") conn = self.connect() try: with conn.cursor() as cur: @@ -86,6 +89,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: with conn.cursor() as cur: @@ -127,10 +132,7 @@ def get_schema(self, table: str) -> list[dict]: ) rows = cur.fetchall() columns = cur.description - return [ - dict(zip([col[0] for col in columns], row)) - for row in rows - ] + return [dict(zip([col[0] for col in columns], row)) for row in rows] except Exception: # Fallback to information_schema for clusters without SVV_COLUMNS access try: @@ -179,6 +181,8 @@ def get_query_bytes_scanned(self, hours: int = 24) -> list[dict]: FinOps helper: return total bytes scanned per user/query label in the last N hours. Uses STL_SCAN which is Redshift-specific. 
""" + # Ensure hours is an integer to prevent SQL injection + hours = int(hours) query = f""" SELECT userid, diff --git a/connectors/snowflake.py b/connectors/snowflake.py index d96a98b..fb1584e 100644 --- a/connectors/snowflake.py +++ b/connectors/snowflake.py @@ -8,6 +8,7 @@ from datetime import datetime from typing import Optional +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger = logging.getLogger(__name__) @@ -48,7 +49,8 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: - """Get the max value of a timestamp column.""" + if not is_safe_table_reference(table) or not is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") conn = self.connect() try: cur = conn.cursor() @@ -63,7 +65,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: - """Get the current row count of a table.""" + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: cur = conn.cursor() @@ -99,10 +102,7 @@ def get_schema(self, table: str) -> list[dict]: (schema_name.upper(), table_name.upper()), ) columns = cur.description - return [ - dict(zip([col[0].lower() for col in columns], row)) - for row in cur.fetchall() - ] + return [dict(zip([col[0].lower() for col in columns], row)) for row in cur.fetchall()] except Exception as e: logger.error(f"Snowflake error getting schema for {table}: {e}") self.close() @@ -116,10 +116,7 @@ def execute_query(self, query: str, params: dict = None) -> list[dict]: cur = conn.cursor() cur.execute(query, params) columns = cur.description - return [ - dict(zip([col[0].lower() for col in columns], row)) - for row in cur.fetchall() - ] + return [dict(zip([col[0].lower() for col in columns], row)) for 
row in cur.fetchall()] except Exception as e: logger.error(f"Snowflake error executing query: {e}") self.close() @@ -128,6 +125,8 @@ def execute_query(self, query: str, params: dict = None) -> list[dict]: @resilient_query() def get_compute_costs(self, days: int = 7) -> float: """Get compute credits used over the last N days.""" + # Ensure days is an integer to prevent SQL injection + days = int(days) conn = self.connect() query = f""" SELECT SUM(CREDITS_USED) as total_credits @@ -155,7 +154,7 @@ def get_soda_config(self) -> dict: "account": self._config["account"], "database": self._config["database"], "warehouse": self._config["warehouse"], - } + }, } } diff --git a/connectors/trino.py b/connectors/trino.py index 0261b5d..020244e 100644 --- a/connectors/trino.py +++ b/connectors/trino.py @@ -22,6 +22,7 @@ from datetime import datetime from typing import Optional +from backend.security import is_safe_identifier, is_safe_table_reference from connectors.base import WarehouseConnector, resilient_query logger = logging.getLogger(__name__) @@ -92,6 +93,8 @@ def close(self): @resilient_query() def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: + if not is_safe_table_reference(table) or not is_safe_identifier(column): + raise ValueError(f"Invalid table/column reference: table={table}, column={column}") conn = self.connect() try: with conn.cursor() as cur: @@ -109,6 +112,8 @@ def get_max_timestamp(self, table: str, column: str) -> Optional[datetime]: @resilient_query() def get_row_count(self, table: str) -> int: + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: with conn.cursor() as cur: @@ -125,6 +130,8 @@ def get_schema(self, table: str) -> list[dict]: Return column metadata from information_schema.columns. Trino uses three-part naming: catalog.schema.table. 
""" + if not is_safe_table_reference(table): + raise ValueError(f"Invalid table reference: {table}") conn = self.connect() try: parts = table.split(".") @@ -138,6 +145,14 @@ def get_schema(self, table: str) -> list[dict]: schema = self._schema tbl = table + # Validate catalog, schema, and table names individually + if ( + not is_safe_identifier(catalog) + or not is_safe_identifier(schema) + or not is_safe_identifier(tbl) + ): + raise ValueError(f"Invalid identifier in table reference: {table}") + with conn.cursor() as cur: cur.execute( f""" diff --git a/dbt_integration/parse_artifacts.py b/dbt_integration/parse_artifacts.py index 19a809d..0237530 100644 --- a/dbt_integration/parse_artifacts.py +++ b/dbt_integration/parse_artifacts.py @@ -11,18 +11,21 @@ from backend.models import CheckResult, PipelineRun except ImportError: import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from backend.models import CheckResult, PipelineRun # Configure logging -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) # Build DATABASE_URL from the same env vars the backend uses, so this # script works both from the host and inside the Docker network. 
_db_user = os.getenv("METADATA_DB_USER", "observakit") _db_pass = os.getenv("METADATA_DB_PASSWORD", "changeme") -_db_host = os.getenv("METADATA_DB_HOST", "postgres") # Docker service name, not localhost +_db_host = os.getenv("METADATA_DB_HOST", "postgres") # Docker service name, not localhost _db_port = os.getenv("METADATA_DB_PORT", "5432") _db_name = os.getenv("METADATA_DB_NAME", "observakit") DATABASE_URL = os.getenv( @@ -33,6 +36,7 @@ engine = create_engine(DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + def parse_run_results(run_results_path: str, manifest_path: str): """Parses dbt run_results.json and manifest.json into the DB.""" @@ -46,10 +50,10 @@ def parse_run_results(run_results_path: str, manifest_path: str): db = SessionLocal() try: - with open(run_results_path, 'r') as f: + with open(run_results_path, "r") as f: run_results = json.load(f) - with open(manifest_path, 'r') as f: + with open(manifest_path, "r") as f: manifest = json.load(f) invocation_id = run_results.get("metadata", {}).get("invocation_id") @@ -70,7 +74,9 @@ def parse_run_results(run_results_path: str, manifest_path: str): node_info = manifest.get("nodes", {}).get(unique_id, {}) if not node_info: # Might be a test or macro, check broader dicts if needed - node_info = manifest.get("sources", {}).get(unique_id, manifest.get("metrics", {}).get(unique_id, {})) + node_info = manifest.get("sources", {}).get( + unique_id, manifest.get("metrics", {}).get(unique_id, {}) + ) resource_type = node_info.get("resource_type", "unknown") @@ -82,7 +88,7 @@ def parse_run_results(run_results_path: str, manifest_path: str): run_id=f"dbt_{invocation_id}_{unique_id}", state="success" if status in ["success", "pass"] else "failed", duration_seconds=execution_time, - end_time=timestamp + end_time=timestamp, ) db.add(run) @@ -96,8 +102,13 @@ def parse_run_results(run_results_path: str, manifest_path: str): check_name=str(unique_id), check_type="dbt_test", passed=True if 
status in ["pass", "success"] else False, - details=json.dumps({"error_message": str(result.get("message", "")), "execution_time": execution_time}), - executed_at=timestamp + details=json.dumps( + { + "error_message": str(result.get("message", "")), + "execution_time": execution_time, + } + ), + executed_at=timestamp, ) db.add(check) @@ -110,8 +121,10 @@ def parse_run_results(run_results_path: str, manifest_path: str): finally: db.close() + if __name__ == "__main__": import argparse + parser = argparse.ArgumentParser(description="Parse dbt artifacts into ObservaKit metadata DB.") parser.add_argument("--run-results", type=str, required=True, help="Path to run_results.json") parser.add_argument("--manifest", type=str, required=True, help="Path to manifest.json") diff --git a/scripts/generate_mock_data.py b/scripts/generate_mock_data.py index c1c48ba..096f455 100644 --- a/scripts/generate_mock_data.py +++ b/scripts/generate_mock_data.py @@ -56,7 +56,7 @@ f"{os.getenv('METADATA_DB_USER', 'observakit')}:" f"{os.getenv('METADATA_DB_PASSWORD', 'changeme')}@" f"{os.getenv('METADATA_DB_HOST', 'localhost')}:" - f"{os.getenv('METADATA_DB_PORT', '5433')}/" # 5433 = host-side port in docker-compose + f"{os.getenv('METADATA_DB_PORT', '5433')}/" # 5433 = host-side port in docker-compose f"{os.getenv('METADATA_DB_NAME', 'observakit')}" ) @@ -72,7 +72,7 @@ ] # Freshness SLA thresholds (seconds) -FRESHNESS_WARN_THRESHOLD = 3600 # 1 h +FRESHNESS_WARN_THRESHOLD = 3600 # 1 h FRESHNESS_FAIL_THRESHOLD = 14400 # 4 h @@ -105,7 +105,8 @@ def generate_warehouse_data(): with engine.begin() as conn: # orders β€” freshness + volume + quality anomalies conn.execute(text("DROP TABLE IF EXISTS public.orders CASCADE;")) - conn.execute(text(""" + conn.execute( + text(""" CREATE TABLE public.orders ( order_id VARCHAR(50), customer_id VARCHAR(50), @@ -113,7 +114,8 @@ def generate_warehouse_data(): status VARCHAR(20), updated_at TIMESTAMP ); - """)) + """) + ) now = datetime.now(timezone.utc) values = 
[] for i in range(1000): @@ -122,17 +124,24 @@ def generate_warehouse_data(): f"('ORD-{i}','CUST-{i % 100}',{random.uniform(10.0, 500.0):.2f},'completed','{updated.isoformat()}')" ) # Inject quality anomalies - values[-2] = f"(NULL,'CUST-X',999.99,'completed','{(now-timedelta(minutes=180)).isoformat()}')" - values[-1] = f"('ORD-999','CUST-Y',450.00,'processing','{(now-timedelta(minutes=180)).isoformat()}')" - conn.execute(text( - "INSERT INTO public.orders (order_id,customer_id,amount,status,updated_at) VALUES " - + ",".join(values) - )) + values[-2] = ( + f"(NULL,'CUST-X',999.99,'completed','{(now - timedelta(minutes=180)).isoformat()}')" + ) + values[-1] = ( + f"('ORD-999','CUST-Y',450.00,'processing','{(now - timedelta(minutes=180)).isoformat()}')" + ) + conn.execute( + text( + "INSERT INTO public.orders (order_id,customer_id,amount,status,updated_at) VALUES " + + ",".join(values) + ) + ) print(" βœ“ public.orders (1000 rows, freshness+volume+quality anomalies injected)") # customers β€” healthy table conn.execute(text("DROP TABLE IF EXISTS public.customers CASCADE;")) - conn.execute(text(""" + conn.execute( + text(""" CREATE TABLE public.customers ( customer_id VARCHAR(50) PRIMARY KEY, email VARCHAR(255), @@ -140,7 +149,8 @@ def generate_warehouse_data(): tier VARCHAR(20), created_at TIMESTAMP ); - """)) + """) + ) c_vals = [] for i in range(500): updated = now - timedelta(minutes=random.randint(10, 30)) @@ -149,10 +159,12 @@ def generate_warehouse_data(): c_vals.append( f"('CUST-{i}','user{i}@example.com','{country}','{tier}','{updated.isoformat()}')" ) - conn.execute(text( - "INSERT INTO public.customers (customer_id,email,country,tier,created_at) VALUES " - + ",".join(c_vals) - )) + conn.execute( + text( + "INSERT INTO public.customers (customer_id,email,country,tier,created_at) VALUES " + + ",".join(c_vals) + ) + ) print(" βœ“ public.customers (500 rows, healthy)") except Exception as exc: print(f" ⚠ Could not create warehouse tables ({exc}); skipping.") @@ 
-175,7 +187,9 @@ def generate_metadata(): VolumeRecord, ) - engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if DB_TYPE == "sqlite" else {}) + engine = create_engine( + DATABASE_URL, connect_args={"check_same_thread": False} if DB_TYPE == "sqlite" else {} + ) Base.metadata.create_all(engine) Session = sessionmaker(bind=engine) db = Session() @@ -185,8 +199,14 @@ def generate_metadata(): # ---- Truncate all demo tables ---------------------------------------- print(" Clearing existing demo data…") for model in [ - VolumeRecord, FreshnessRecord, CheckResult, PipelineRun, - SchemaDiff, SchemaSnapshot, ColumnProfile, CheckSuppression, + VolumeRecord, + FreshnessRecord, + CheckResult, + PipelineRun, + SchemaDiff, + SchemaSnapshot, + ColumnProfile, + CheckSuppression, ]: db.query(model).delete() db.commit() @@ -207,43 +227,47 @@ def generate_metadata(): else: lag = random.uniform(60, 1500) # healthy - db.add(FreshnessRecord( - table_name=table, - timestamp_column="updated_at", - last_updated_at=ts - timedelta(seconds=lag), - lag_seconds=lag, - status=_status_for_lag(lag), - checked_at=ts, - )) + db.add( + FreshnessRecord( + table_name=table, + timestamp_column="updated_at", + last_updated_at=ts - timedelta(seconds=lag), + lag_seconds=lag, + status=_status_for_lag(lag), + checked_at=ts, + ) + ) # Add the current (most-recent) freshness check for each table current_lags = { - "public.orders": random.uniform(16200, 18000), # fail - "public.customers": random.uniform(60, 600), # ok - "public.products": random.uniform(3700, 5400), # warn - "analytics.daily_revenue": random.uniform(3700, 7000), # warn - "analytics.user_events": random.uniform(60, 900), # ok + "public.orders": random.uniform(16200, 18000), # fail + "public.customers": random.uniform(60, 600), # ok + "public.products": random.uniform(3700, 5400), # warn + "analytics.daily_revenue": random.uniform(3700, 7000), # warn + "analytics.user_events": random.uniform(60, 900), # ok } for 
table, lag in current_lags.items(): - db.add(FreshnessRecord( - table_name=table, - timestamp_column="updated_at", - last_updated_at=now - timedelta(seconds=lag), - lag_seconds=lag, - status=_status_for_lag(lag), - checked_at=now - timedelta(minutes=2), - )) + db.add( + FreshnessRecord( + table_name=table, + timestamp_column="updated_at", + last_updated_at=now - timedelta(seconds=lag), + lag_seconds=lag, + status=_status_for_lag(lag), + checked_at=now - timedelta(minutes=2), + ) + ) # ========================================================================= # 2. VOLUME β€” 7 days of row counts + anomaly today on public.orders # ========================================================================= print(" πŸ“ˆ Generating volume records…") volume_baselines = { - "public.orders": (1700, 50), - "public.customers": (500, 10), - "public.products": (3200, 80), - "analytics.daily_revenue": (90, 5), - "analytics.user_events": (45000, 1200), + "public.orders": (1700, 50), + "public.customers": (500, 10), + "public.products": (3200, 80), + "analytics.daily_revenue": (90, 5), + "analytics.user_events": (45000, 1200), } for day in range(7, 0, -1): ts = now - timedelta(days=day) @@ -259,15 +283,17 @@ def generate_metadata(): rolling_avg = float(baseline) deviation_pct = abs(row_count - baseline) / baseline - db.add(VolumeRecord( - table_name=table, - dag_id=f"{table.replace('.', '_')}_etl", - row_count=row_count, - rolling_avg=rolling_avg, - deviation_pct=deviation_pct, - is_anomaly=is_anomaly, - recorded_at=ts, - )) + db.add( + VolumeRecord( + table_name=table, + dag_id=f"{table.replace('.', '_')}_etl", + row_count=row_count, + rolling_avg=rolling_avg, + deviation_pct=deviation_pct, + is_anomaly=is_anomaly, + recorded_at=ts, + ) + ) # ========================================================================= # 3. 
QUALITY CHECKS β€” 7 days, realistic pass rates, failures today @@ -275,30 +301,35 @@ def generate_metadata(): print(" βœ… Generating quality check records…") checks_def = { "public.orders": [ - ("Orders table must not be empty", "soda", True, "row_count > 0"), - ("order_id must not be null", "soda", False, "3 null values found in last run"), - ("order_id must be unique", "soda", False, "2 duplicate order_ids detected"), - ("amount must be non-negative", "great_expectations", True, None), - ("status must be in allowed values", "custom_sql", True, None), + ("Orders table must not be empty", "soda", True, "row_count > 0"), + ("order_id must not be null", "soda", False, "3 null values found in last run"), + ("order_id must be unique", "soda", False, "2 duplicate order_ids detected"), + ("amount must be non-negative", "great_expectations", True, None), + ("status must be in allowed values", "custom_sql", True, None), ], "public.customers": [ - ("customer_id must not be null", "soda", True, None), - ("email must not be null", "soda", True, None), - ("country must be 2-char ISO code", "great_expectations", True, None), + ("customer_id must not be null", "soda", True, None), + ("email must not be null", "soda", True, None), + ("country must be 2-char ISO code", "great_expectations", True, None), ], "public.products": [ - ("product_id must not be null", "soda", True, None), - ("price must be positive", "soda", True, None), - ("category must not be null", "great_expectations", False, "12 rows with null category"), + ("product_id must not be null", "soda", True, None), + ("price must be positive", "soda", True, None), + ( + "category must not be null", + "great_expectations", + False, + "12 rows with null category", + ), ], "analytics.daily_revenue": [ - ("revenue must be non-negative", "soda", True, None), - ("No duplicate date entries", "custom_sql", True, None), + ("revenue must be non-negative", "soda", True, None), + ("No duplicate date entries", "custom_sql", True, 
None), ], "analytics.user_events": [ - ("user_id must not be null", "soda", True, None), - ("event_type must not be null", "soda", True, None), - ("timestamp must not be null", "great_expectations", True, None), + ("user_id must not be null", "soda", True, None), + ("event_type must not be null", "soda", True, None), + ("timestamp must not be null", "great_expectations", True, None), ], } @@ -315,15 +346,17 @@ def generate_metadata(): passed = random.random() > 0.03 details = "Transient failure" if not passed else None - db.add(CheckResult( - check_name=check_name, - table_name=table, - check_type=check_type, - passed=passed, - metric_value=0.0 if passed else 1.0, - details=details, - executed_at=ts + timedelta(minutes=random.randint(0, 55)), - )) + db.add( + CheckResult( + check_name=check_name, + table_name=table, + check_type=check_type, + passed=passed, + metric_value=0.0 if passed else 1.0, + details=details, + executed_at=ts + timedelta(minutes=random.randint(0, 55)), + ) + ) # ========================================================================= # 4. 
SCHEMA DRIFT β€” 3 realistic changes across two tables @@ -332,49 +365,58 @@ def generate_metadata(): # Snapshot for public.orders orders_columns_v1 = [ - {"name": "order_id", "type": "varchar(50)", "nullable": False, "ordinal": 1}, - {"name": "customer_id", "type": "varchar(50)", "nullable": True, "ordinal": 2}, - {"name": "amount", "type": "decimal(10,2)", "nullable": True, "ordinal": 3}, - {"name": "status", "type": "varchar(20)", "nullable": True, "ordinal": 4}, - {"name": "updated_at", "type": "timestamp", "nullable": True, "ordinal": 5}, + {"name": "order_id", "type": "varchar(50)", "nullable": False, "ordinal": 1}, + {"name": "customer_id", "type": "varchar(50)", "nullable": True, "ordinal": 2}, + {"name": "amount", "type": "decimal(10,2)", "nullable": True, "ordinal": 3}, + {"name": "status", "type": "varchar(20)", "nullable": True, "ordinal": 4}, + {"name": "updated_at", "type": "timestamp", "nullable": True, "ordinal": 5}, ] - db.add(SchemaSnapshot( - table_name="public.orders", - columns_json=orders_columns_v1, - snapshot_at=now - timedelta(days=7), - )) - db.add(SchemaSnapshot( - table_name="public.orders", - columns_json=orders_columns_v1 + [ - {"name": "discount_code", "type": "varchar(50)", "nullable": True, "ordinal": 6} - ], - snapshot_at=now - timedelta(days=3), - )) - - db.add(SchemaDiff( - table_name="public.orders", - change_type="added", - column_name="discount_code", - old_value=None, - new_value="varchar(50)", - detected_at=now - timedelta(days=3), - )) - db.add(SchemaDiff( - table_name="public.orders", - change_type="type_changed", - column_name="amount", - old_value="decimal(10,2)", - new_value="decimal(14,4)", - detected_at=now - timedelta(days=1), - )) - db.add(SchemaDiff( - table_name="public.customers", - change_type="removed", - column_name="legacy_segment", - old_value="text", - new_value=None, - detected_at=now - timedelta(hours=6), - )) + db.add( + SchemaSnapshot( + table_name="public.orders", + columns_json=orders_columns_v1, + 
snapshot_at=now - timedelta(days=7), + ) + ) + db.add( + SchemaSnapshot( + table_name="public.orders", + columns_json=orders_columns_v1 + + [{"name": "discount_code", "type": "varchar(50)", "nullable": True, "ordinal": 6}], + snapshot_at=now - timedelta(days=3), + ) + ) + + db.add( + SchemaDiff( + table_name="public.orders", + change_type="added", + column_name="discount_code", + old_value=None, + new_value="varchar(50)", + detected_at=now - timedelta(days=3), + ) + ) + db.add( + SchemaDiff( + table_name="public.orders", + change_type="type_changed", + column_name="amount", + old_value="decimal(10,2)", + new_value="decimal(14,4)", + detected_at=now - timedelta(days=1), + ) + ) + db.add( + SchemaDiff( + table_name="public.customers", + change_type="removed", + column_name="legacy_segment", + old_value="text", + new_value=None, + detected_at=now - timedelta(hours=6), + ) + ) # ========================================================================= # 5. PIPELINE RUNS (Alerts tab) β€” 50 Airflow run events, realistic cadence @@ -396,16 +438,18 @@ def generate_metadata(): else: state = "failed" if random.random() < 0.08 else "success" duration = random.uniform(120, 1800) - db.add(PipelineRun( - orchestrator="airflow", - dag_id=dag, - run_id=f"run_{run_time.strftime('%Y%m%d%H%M')}_{i}", - state=state, - start_time=run_time, - end_time=run_time + timedelta(seconds=duration), - duration_seconds=duration, - recorded_at=run_time + timedelta(seconds=duration + 5), - )) + db.add( + PipelineRun( + orchestrator="airflow", + dag_id=dag, + run_id=f"run_{run_time.strftime('%Y%m%d%H%M')}_{i}", + state=state, + start_time=run_time, + end_time=run_time + timedelta(seconds=duration), + duration_seconds=duration, + recorded_at=run_time + timedelta(seconds=duration + 5), + ) + ) # ========================================================================= # 6. 
COLUMN PROFILING β€” public.orders + public.customers @@ -414,57 +458,63 @@ def generate_metadata(): profiled_at = now - timedelta(hours=1) orders_profiles = [ - ("order_id", 1000, 0.003, 999, "ORD-0", "ORD-999", None), - ("customer_id", 1000, 0.000, 100, "CUST-0", "CUST-99", None), - ("amount", 1000, 0.000, 892, "10.12", "499.87", 255.34), - ("status", 1000, 0.000, 2, "completed","processing",None), - ("updated_at", 1000, 0.000, 987, "2024-01-09", "2024-01-15", None), + ("order_id", 1000, 0.003, 999, "ORD-0", "ORD-999", None), + ("customer_id", 1000, 0.000, 100, "CUST-0", "CUST-99", None), + ("amount", 1000, 0.000, 892, "10.12", "499.87", 255.34), + ("status", 1000, 0.000, 2, "completed", "processing", None), + ("updated_at", 1000, 0.000, 987, "2024-01-09", "2024-01-15", None), ] for col, total, null_pct, distinct, min_v, max_v, mean_v in orders_profiles: null_count = int(total * null_pct) - db.add(ColumnProfile( - table_name="public.orders", - column_name=col, - null_count=null_count, - null_pct=null_pct, - distinct_count=distinct, - min_value=min_v, - max_value=max_v, - mean_value=mean_v, - profiled_at=profiled_at, - )) + db.add( + ColumnProfile( + table_name="public.orders", + column_name=col, + null_count=null_count, + null_pct=null_pct, + distinct_count=distinct, + min_value=min_v, + max_value=max_v, + mean_value=mean_v, + profiled_at=profiled_at, + ) + ) customers_profiles = [ - ("customer_id", 500, 0.000, 500, "CUST-0", "CUST-99", None), - ("email", 500, 0.000, 500, "user0@example.com", "user99@example.com", None), - ("country", 500, 0.000, 5, "CA", "US", None), - ("tier", 500, 0.000, 3, "enterprise", "pro", None), - ("created_at", 500, 0.000, 498, "2024-01-08", "2024-01-15", None), + ("customer_id", 500, 0.000, 500, "CUST-0", "CUST-99", None), + ("email", 500, 0.000, 500, "user0@example.com", "user99@example.com", None), + ("country", 500, 0.000, 5, "CA", "US", None), + ("tier", 500, 0.000, 3, "enterprise", "pro", None), + ("created_at", 500, 0.000, 498, 
"2024-01-08", "2024-01-15", None), ] for col, total, null_pct, distinct, min_v, max_v, mean_v in customers_profiles: null_count = int(total * null_pct) - db.add(ColumnProfile( - table_name="public.customers", - column_name=col, - null_count=null_count, - null_pct=null_pct, - distinct_count=distinct, - min_value=min_v, - max_value=max_v, - mean_value=mean_v, - profiled_at=profiled_at, - )) + db.add( + ColumnProfile( + table_name="public.customers", + column_name=col, + null_count=null_count, + null_pct=null_pct, + distinct_count=distinct, + min_value=min_v, + max_value=max_v, + mean_value=mean_v, + profiled_at=profiled_at, + ) + ) # ========================================================================= # 7. SUPPRESSION β€” one active maintenance window # ========================================================================= print(" πŸ”• Adding active suppression for analytics.daily_revenue…") - db.add(CheckSuppression( - table_name="analytics.daily_revenue", - suppressed_until=now + timedelta(hours=4), - reason="Planned warehouse maintenance window β€” batch backfill in progress", - created_at=now - timedelta(hours=1), - )) + db.add( + CheckSuppression( + table_name="analytics.daily_revenue", + suppressed_until=now + timedelta(hours=4), + reason="Planned warehouse maintenance window β€” batch backfill in progress", + created_at=now - timedelta(hours=1), + ) + ) db.commit() db.close() diff --git a/tests/test_alerts.py b/tests/test_alerts.py index b53cf8a..05384e0 100644 --- a/tests/test_alerts.py +++ b/tests/test_alerts.py @@ -12,17 +12,16 @@ import pytest from alerts.base import dispatch_alert, get_alert_dispatcher, is_alert_suppressed -from backend.models import AlertLog, CheckSuppression - +from backend.models import CheckSuppression # --------------------------------------------------------------------------- # Dispatcher unit tests β€” Slack # --------------------------------------------------------------------------- + class TestSlackDispatcher: def 
test_slack_sends_payload_on_success(self, monkeypatch): """SlackDispatcher.send() should POST to the webhook URL.""" - import os monkeypatch.setenv("SLACK_WEBHOOK_URL", "https://hooks.slack.com/test") from alerts.slack import SlackDispatcher @@ -75,6 +74,7 @@ def test_slack_returns_false_when_not_configured(self): # Dispatcher factory # --------------------------------------------------------------------------- + class TestDispatcherFactory: def test_factory_returns_slack_dispatcher(self): from alerts.slack import SlackDispatcher @@ -103,6 +103,7 @@ def test_factory_returns_pagerduty_dispatcher(self): # dispatch_alert routing β€” integration-style with mocked config # --------------------------------------------------------------------------- + class TestDispatchAlertRouting: @patch("config.loader.load_config") def test_routes_to_default_channel_when_no_rules(self, mock_load_config, monkeypatch): @@ -151,6 +152,7 @@ def test_routing_rule_matches_alert_type(self, mock_load_config, monkeypatch): # Suppression integration with dispatch # --------------------------------------------------------------------------- + class TestSuppressionWithAlerts: def test_alert_not_fired_when_table_suppressed(self, db_session): """ @@ -159,11 +161,13 @@ def test_alert_not_fired_when_table_suppressed(self, db_session): This test verifies the suppression check returns True. 
""" future = datetime.now(timezone.utc) + timedelta(hours=1) - db_session.add(CheckSuppression( - table_name="public.orders", - suppressed_until=future, - reason="planned maintenance", - )) + db_session.add( + CheckSuppression( + table_name="public.orders", + suppressed_until=future, + reason="planned maintenance", + ) + ) db_session.commit() suppressed = is_alert_suppressed(db_session, "public.orders") @@ -178,6 +182,7 @@ def test_alert_fired_when_no_suppression(self, db_session): # Teams dispatcher # --------------------------------------------------------------------------- + class TestTeamsDispatcher: def test_teams_sends_adaptive_card_on_success(self, monkeypatch): monkeypatch.setenv("TEAMS_WEBHOOK_URL", "https://outlook.office.com/webhook/test") @@ -223,6 +228,7 @@ def test_teams_returns_false_on_http_error(self, monkeypatch): # PagerDuty dispatcher # --------------------------------------------------------------------------- + class TestPagerDutyDispatcher: def test_pagerduty_triggers_event_on_success(self, monkeypatch): monkeypatch.setenv("PAGERDUTY_ROUTING_KEY", "abc123key") @@ -295,6 +301,7 @@ def test_pagerduty_warn_maps_to_warning_severity(self, monkeypatch): # Slack Block Kit payload structure # --------------------------------------------------------------------------- + class TestSlackBlockKit: def test_slack_payload_uses_attachments_with_color(self, monkeypatch): monkeypatch.setenv("SLACK_WEBHOOK_URL", "https://hooks.slack.com/test") diff --git a/tests/test_api_auth.py b/tests/test_api_auth.py index bf1323a..73555bc 100644 --- a/tests/test_api_auth.py +++ b/tests/test_api_auth.py @@ -12,7 +12,7 @@ """ import os -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest from fastapi.testclient import TestClient @@ -24,20 +24,17 @@ os.environ["METADATA_DB_TYPE"] = "sqlite" os.environ["DATABASE_URL"] = "sqlite:///:memory:" +from sqlalchemy.pool import StaticPool + from backend.main import app # noqa: E402 from 
backend.models import Base, get_db # noqa: E402 - -from sqlalchemy.pool import StaticPool - # --------------------------------------------------------------------------- # Override the DB dependency to use an in-memory SQLite DB # --------------------------------------------------------------------------- _test_engine = create_engine( - "sqlite:///:memory:", - connect_args={"check_same_thread": False}, - poolclass=StaticPool + "sqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool ) Base.metadata.create_all(bind=_test_engine) _TestSession = sessionmaker(bind=_test_engine) @@ -57,10 +54,12 @@ def override_get_db(): @pytest.fixture(scope="module") def client(): # Patch out the scheduler and warehouse connector so no real connections happen - with patch("backend.main.start_scheduler"), \ - patch("backend.main.shutdown_scheduler"), \ - patch("backend.main.command.upgrade"), \ - patch("connectors.base.get_warehouse_connector"): + with ( + patch("backend.main.start_scheduler"), + patch("backend.main.shutdown_scheduler"), + patch("backend.main.command.upgrade"), + patch("connectors.base.get_warehouse_connector"), + ): with TestClient(app, raise_server_exceptions=False) as c: yield c @@ -69,6 +68,7 @@ def client(): # Auth tests # --------------------------------------------------------------------------- + class TestAPIKeyAuth: def test_missing_api_key_returns_403(self, client): response = client.get("/freshness/") diff --git a/tests/test_connectors.py b/tests/test_connectors.py index 4f51bde..65259fe 100644 --- a/tests/test_connectors.py +++ b/tests/test_connectors.py @@ -16,18 +16,21 @@ class TestWarehouseConnectorFactory: def test_postgres_connector(self): connector = get_warehouse_connector() from connectors.postgres import PostgresConnector + assert isinstance(connector, PostgresConnector) @patch.dict("os.environ", {"WAREHOUSE_TYPE": "bigquery"}) def test_bigquery_connector(self): connector = get_warehouse_connector() from 
connectors.bigquery import BigQueryConnector + assert isinstance(connector, BigQueryConnector) @patch.dict("os.environ", {"WAREHOUSE_TYPE": "snowflake"}) def test_snowflake_connector(self): connector = get_warehouse_connector() from connectors.snowflake import SnowflakeConnector + assert isinstance(connector, SnowflakeConnector) @patch.dict("os.environ", {"WAREHOUSE_TYPE": "unsupported"}) diff --git a/tests/test_contracts.py b/tests/test_contracts.py index 0e8acc8..1e863dc 100644 --- a/tests/test_contracts.py +++ b/tests/test_contracts.py @@ -13,11 +13,11 @@ from backend.routers.checks import _safe_eval_assertion - # --------------------------------------------------------------------------- # _safe_eval_assertion β€” core assertion evaluator # --------------------------------------------------------------------------- + class TestSafeEvalAssertion: def test_result_equals_zero_passes(self): assert _safe_eval_assertion("result == 0", 0) is True @@ -57,6 +57,7 @@ def test_empty_assertion_raises_value_error(self): # Contract YAML loading β€” test that the parser handles well-formed files # --------------------------------------------------------------------------- + class TestContractYAMLLoading: def _write_contract(self, tmpdir: str, content: dict) -> str: path = os.path.join(tmpdir, "test_contract.yml") @@ -73,8 +74,12 @@ def test_well_formed_contract_loads(self): "owner": "data-eng@example.com", "columns": [ {"name": "id", "type": "integer", "nullable": False, "unique": True}, - {"name": "status", "type": "varchar", "nullable": False, - "allowed_values": ["pending", "confirmed", "shipped"]}, + { + "name": "status", + "type": "varchar", + "nullable": False, + "allowed_values": ["pending", "confirmed", "shipped"], + }, {"name": "amount", "type": "numeric", "nullable": False, "min": 0}, ], "rules": [ @@ -131,6 +136,7 @@ def test_malformed_yaml_raises_on_load(self): # allowed_values check logic (pure Python simulation) # 
--------------------------------------------------------------------------- + class TestAllowedValuesLogicSimulation: """ Simulate the allowed-values rule logic without hitting a real DB. diff --git a/tests/test_dbt_parser.py b/tests/test_dbt_parser.py index e200035..0a12721 100644 --- a/tests/test_dbt_parser.py +++ b/tests/test_dbt_parser.py @@ -8,23 +8,16 @@ def test_dbt_parser_success(tmp_path): # Create dummy dbt artifacts run_results_data = { - "metadata": { - "invocation_id": "test_inv_123", - "generated_at": "2023-10-27T10:00:00Z" - }, + "metadata": {"invocation_id": "test_inv_123", "generated_at": "2023-10-27T10:00:00Z"}, "results": [ - { - "unique_id": "model.my_project.my_model", - "status": "success", - "execution_time": 1.5 - }, + {"unique_id": "model.my_project.my_model", "status": "success", "execution_time": 1.5}, { "unique_id": "test.my_project.not_null_my_model_id", "status": "fail", "execution_time": 0.5, - "message": "Got 5 results, expected 0." - } - ] + "message": "Got 5 results, expected 0.", + }, + ], } manifest_data = { @@ -32,13 +25,13 @@ def test_dbt_parser_success(tmp_path): "model.my_project.my_model": { "resource_type": "model", "name": "my_model", - "alias": "my_model_table" + "alias": "my_model_table", }, "test.my_project.not_null_my_model_id": { "resource_type": "test", "column_name": "id", - "attached_node": "model.my_project.my_model" - } + "attached_node": "model.my_project.my_model", + }, } } diff --git a/tests/test_distribution_drift.py b/tests/test_distribution_drift.py index 075d706..dba9ed4 100644 --- a/tests/test_distribution_drift.py +++ b/tests/test_distribution_drift.py @@ -10,11 +10,11 @@ from backend.routers.distribution import _detect_drift - # --------------------------------------------------------------------------- # Null-% drift detection # --------------------------------------------------------------------------- + class TestNullPctDrift: def test_null_pct_below_threshold_no_drift(self): prev = {"null_pct": 
0.02, "top_values": []} @@ -49,6 +49,7 @@ def test_null_pct_decrease_also_detected(self): # Categorical value-share shift # --------------------------------------------------------------------------- + class TestCategoricalDrift: def _make_cat(self, values: dict, null_pct=0.0) -> dict: """Helper: build a categorical snapshot dict.""" @@ -87,6 +88,7 @@ def test_value_disappearing_triggers_drift(self): # Numeric mean-shift detection # --------------------------------------------------------------------------- + class TestNumericDrift: def _make_num(self, mean, max_val=1000.0, null_pct=0.0) -> dict: return { diff --git a/tests/test_finops.py b/tests/test_finops.py index de30345..dd8c50a 100644 --- a/tests/test_finops.py +++ b/tests/test_finops.py @@ -10,6 +10,7 @@ def client(): return TestClient(app) + @patch("connectors.base.get_warehouse_connector") def test_finops_poll(mock_get_connector, client): # Mock the warehouse connector and its get_compute_costs method @@ -17,7 +18,9 @@ def test_finops_poll(mock_get_connector, client): mock_connector.get_compute_costs.return_value = 150.75 # Mock the environment variable for warehouse type to something other than postgres - with patch.dict("os.environ", {"WAREHOUSE_TYPE": "snowflake", "OBSERVAKIT_API_KEY": "observakit123"}): + with patch.dict( + "os.environ", {"WAREHOUSE_TYPE": "snowflake", "OBSERVAKIT_API_KEY": "observakit123"} + ): response = client.post("/finops/poll?days=7", headers={"X-API-Key": "observakit123"}) assert response.status_code == 200 @@ -29,10 +32,13 @@ def test_finops_poll(mock_get_connector, client): # Verify the connector method was called with correct days mock_connector.get_compute_costs.assert_called_once_with(days=7) + @patch("connectors.base.get_warehouse_connector") def test_finops_poll_postgres_skip(mock_get_connector, client): # If the warehouse is postgres, it should skip tracking - with patch.dict("os.environ", {"WAREHOUSE_TYPE": "postgres", "OBSERVAKIT_API_KEY": "observakit123"}): + with 
patch.dict( + "os.environ", {"WAREHOUSE_TYPE": "postgres", "OBSERVAKIT_API_KEY": "observakit123"} + ): response = client.post("/finops/poll?days=7", headers={"X-API-Key": "observakit123"}) assert response.status_code == 200 diff --git a/tests/test_schema_diff.py b/tests/test_schema_diff.py index cd414a4..e79a526 100644 --- a/tests/test_schema_diff.py +++ b/tests/test_schema_diff.py @@ -2,7 +2,6 @@ ObservaKit β€” Schema Drift Detection Tests """ - from backend.models import SchemaDiff, SchemaSnapshot from backend.routers.schema_diff import _compute_diff diff --git a/tests/test_stage3_production.py b/tests/test_stage3_production.py index 4a59efa..d26965f 100644 --- a/tests/test_stage3_production.py +++ b/tests/test_stage3_production.py @@ -4,33 +4,37 @@ 2. Resilient Query Retries (Tenacity). """ -import pytest +from datetime import datetime, timedelta, timezone from unittest.mock import MagicMock, patch -from datetime import datetime, timezone, timedelta + from alerts.base import dispatch_alert -from connectors.base import resilient_query from backend.models import AlertLog, CheckSuppression +from connectors.base import resilient_query + class TestStage3Resiliency: - def test_dispatch_alert_logs_to_db(self, db_session, monkeypatch): """Verify that dispatch_alert creates an AlertLog record when db is provided.""" # Mock dispatcher to return True (success) mock_dispatcher = MagicMock() mock_dispatcher.send.return_value = True - - monkeypatch.setattr("alerts.base.get_alert_dispatcher", lambda *args, **kwargs: mock_dispatcher) + + monkeypatch.setattr( + "alerts.base.get_alert_dispatcher", lambda *args, **kwargs: mock_dispatcher + ) # Mock config to avoid file lookups - monkeypatch.setattr("config.loader.load_config", lambda *args: {"alerts": {"default_channel": "slack"}}) - + monkeypatch.setattr( + "config.loader.load_config", lambda *args: {"alerts": {"default_channel": "slack"}} + ) + dispatch_alert( alert_type="test_alert", message="Hello World", 
table_name="public.users", db=db_session, - severity="fail" + severity="fail", ) - + # Verify log entry exists log = db_session.query(AlertLog).filter(AlertLog.table_name == "public.users").first() assert log is not None @@ -42,18 +46,22 @@ def test_dispatch_alert_deduplication(self, db_session, monkeypatch): """Verify that dispatch_alert honors the 1-hour deduplication window.""" mock_dispatcher = MagicMock() mock_dispatcher.send.return_value = True - monkeypatch.setattr("alerts.base.get_alert_dispatcher", lambda *args, **kwargs: mock_dispatcher) - monkeypatch.setattr("config.loader.load_config", lambda *args: {"alerts": {"default_channel": "slack"}}) - + monkeypatch.setattr( + "alerts.base.get_alert_dispatcher", lambda *args, **kwargs: mock_dispatcher + ) + monkeypatch.setattr( + "config.loader.load_config", lambda *args: {"alerts": {"default_channel": "slack"}} + ) + # 1. Dispatch first alert dispatch_alert("quality", "First fail", "public.orders", db=db_session) assert mock_dispatcher.send.call_count == 1 - + # 2. Dispatch second alert immediately (same type/table) dispatch_alert("quality", "Second fail", "public.orders", db=db_session) # Should NOT have sent again assert mock_dispatcher.send.call_count == 1 - + # 3. 
Dispatch alert for different table dispatch_alert("quality", "Other fail", "public.users", db=db_session) assert mock_dispatcher.send.call_count == 2 @@ -62,26 +70,32 @@ def test_dispatch_alert_suppression(self, db_session, monkeypatch): """Verify that dispatch_alert honors active CheckSuppression windows.""" mock_dispatcher = MagicMock() mock_dispatcher.send.return_value = True - monkeypatch.setattr("alerts.base.get_alert_dispatcher", lambda *args, **kwargs: mock_dispatcher) - monkeypatch.setattr("config.loader.load_config", lambda *args: {"alerts": {"default_channel": "slack"}}) - + monkeypatch.setattr( + "alerts.base.get_alert_dispatcher", lambda *args, **kwargs: mock_dispatcher + ) + monkeypatch.setattr( + "config.loader.load_config", lambda *args: {"alerts": {"default_channel": "slack"}} + ) + # Add suppression - db_session.add(CheckSuppression( - table_name="public.orders", - suppressed_until=datetime.now(timezone.utc) + timedelta(hours=1), - reason="testing" - )) + db_session.add( + CheckSuppression( + table_name="public.orders", + suppressed_until=datetime.now(timezone.utc) + timedelta(hours=1), + reason="testing", + ) + ) db_session.commit() - + dispatch_alert("quality", "Fail", "public.orders", db=db_session) # Should NOT have sent assert mock_dispatcher.send.call_count == 0 def test_resilient_query_retries(self): """Verify that the resilient_query decorator retries on failure.""" - + call_count = 0 - + @resilient_query() def unstable_function(): nonlocal call_count @@ -89,10 +103,10 @@ def unstable_function(): if call_count < 3: raise RuntimeError("Transient Error") return "Success" - + # We need to monkeypatch the wait to make the test fast with patch("tenacity.nap.time.sleep", return_value=None): result = unstable_function() - + assert result == "Success" assert call_count == 3 diff --git a/tests/test_suppressions.py b/tests/test_suppressions.py index 5e91cbd..8985fa9 100644 --- a/tests/test_suppressions.py +++ b/tests/test_suppressions.py @@ -7,8 +7,6 
@@ from datetime import datetime, timedelta, timezone -import pytest - from alerts.base import is_alert_deduped, is_alert_suppressed from backend.models import AlertLog, CheckSuppression diff --git a/tests/test_volume.py b/tests/test_volume.py index 3c232e1..91622ec 100644 --- a/tests/test_volume.py +++ b/tests/test_volume.py @@ -67,11 +67,7 @@ def test_volume_history_ordering(self, db_session): db_session.add(record) db_session.commit() - records = ( - db_session.query(VolumeRecord) - .order_by(VolumeRecord.recorded_at.desc()) - .all() - ) + records = db_session.query(VolumeRecord).order_by(VolumeRecord.recorded_at.desc()).all() assert len(records) == 5 # Most recent should have the lowest row count offset assert records[0].row_count == 1000