diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 884b8f5..6c1e5bf 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -31,6 +31,7 @@ Python style Patterns - `__init__` methods must not raise exceptions; defer validation and connection to first use (lazy init) - Writers: inherit from `Writer(ABC)`, implement `write(topic, message) -> (bool, str|None)` and `check_health() -> (bool, str)` +- PostgreSQL: `WriterPostgres` and `ReaderPostgres` cache a single connection per instance - Route dispatch via `ROUTE_MAP` dict mapping routes to handler functions in `event_gate_lambda.py` and `event_stats_lambda.py` - Separate business logic from environment access (env vars, file I/O, network calls) - No duplicate validation; centralize parsing in one layer where practical @@ -50,3 +51,6 @@ Testing Quality gates (run after changes, fix only if below threshold) - Run all quality gates at once: `make qa` - Once a quality gate passes, do not re-run it in different scenarios + +Git workflow +- Do NOT create git commits; committing is the developer's responsibility diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 68e0c70..e32ab77 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -14,6 +14,10 @@ updates: commit-message: prefix: "chore" include: "scope" + groups: + github-actions: + patterns: + - "*" - package-ecosystem: "pip" directory: "/" @@ -31,3 +35,7 @@ updates: include: "scope" allow: - dependency-type: "direct" + groups: + python-dependencies: + patterns: + - "*" diff --git a/requirements.txt b/requirements.txt index 0e92380..23cdd46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,9 @@ cryptography==46.0.7 jsonschema==4.25.1 PyJWT==2.12.1 requests==2.32.5 +aiosql==15.0 boto3==1.42.91 +botocore==1.42.91 confluent-kafka==2.14.0 moto[s3,secretsmanager,events]==5.1.22 testcontainers==4.14.2 diff --git a/src/handlers/handler_topic.py b/src/handlers/handler_topic.py index a0e35f0..7eae543 100644 --- a/src/handlers/handler_topic.py +++ b/src/handlers/handler_topic.py @@ -29,6 +29,7 @@ from src.handlers.handler_token import HandlerToken from src.utils.conf_path import CONF_DIR from src.utils.config_loader import load_access_config +from src.utils.constants import TOPIC_DLCHANGE, TOPIC_RUNS, TOPIC_TEST from src.utils.utils import build_error_response from src.writers.writer import Writer @@ -69,11 +70,11 @@ def with_load_topic_schemas(self) -> "HandlerTopic": logger.debug("Loading topic schemas from %s.", topic_schemas_dir) with open(os.path.join(topic_schemas_dir, "runs.json"), "r", encoding="utf-8") as file: - self.topics["public.cps.za.runs"] = json.load(file) + self.topics[TOPIC_RUNS] = json.load(file) with open(os.path.join(topic_schemas_dir, "dlchange.json"), "r", encoding="utf-8") as file: - self.topics["public.cps.za.dlchange"] = json.load(file) + self.topics[TOPIC_DLCHANGE] = json.load(file) with open(os.path.join(topic_schemas_dir, "test.json"), "r", encoding="utf-8") as file: - self.topics["public.cps.za.test"] = json.load(file) + self.topics[TOPIC_TEST] = json.load(file) logger.debug("Loaded topic schemas successfully.") return self diff --git a/src/readers/reader_postgres.py b/src/readers/reader_postgres.py index 7c5de0c..c1a488c 100644 --- a/src/readers/reader_postgres.py +++ b/src/readers/reader_postgres.py @@ -17,69 +17,57 @@ """Postgres reader for run/job statistics.""" import logging -import os import time +from dataclasses import dataclass from datetime import datetime, timezone +from functools import cached_property +from pathlib import Path from typing import Any +import aiosql from botocore.exceptions import BotoCoreError, ClientError from src.utils.constants import ( POSTGRES_DEFAULT_LIMIT, POSTGRES_DEFAULT_WINDOW_MS, POSTGRES_MAX_LIMIT, + POSTGRES_STATEMENT_TIMEOUT_MS, + REQUIRED_CONNECTION_FIELDS, ) -from src.utils.utils import load_postgres_config - -try: - import psycopg2 - from psycopg2 import Error as PsycopgError - from psycopg2 import sql as psycopg2_sql -except ImportError: - psycopg2 = None # type: ignore - psycopg2_sql = None # type: ignore - - class PsycopgError(Exception): # type: ignore - """Shim psycopg2 error base when psycopg2 is not installed.""" - +from src.utils.postgres_base import PsycopgError, PostgresBase logger = logging.getLogger(__name__) -_RUNS_SQL_BASE = ( - "SELECT r.event_id, r.job_ref, r.tenant_id, r.source_app," - " r.source_app_version, r.environment," - " r.timestamp_start AS run_timestamp_start," - " r.timestamp_end AS run_timestamp_end," - " j.internal_id, j.country, j.catalog_id, j.status," - " j.timestamp_start, j.timestamp_end, j.message, j.additional_info" - " FROM public_cps_za_runs_jobs j" - " INNER JOIN public_cps_za_runs r ON j.event_id = r.event_id" - " WHERE r.timestamp_start >= %s AND r.timestamp_start <= %s" -) +_SQL_DIR = Path(__file__).parent / "sql" + -_RUNS_SQL_TAIL = " ORDER BY j.internal_id DESC LIMIT %s" +@dataclass(frozen=True) +class ReaderQueries: + """Typed holder for reader SQL query strings loaded via aiosql.""" -_RUNS_SQL = _RUNS_SQL_BASE + _RUNS_SQL_TAIL -_RUNS_SQL_WITH_CURSOR = _RUNS_SQL_BASE + " AND j.internal_id < %s" + _RUNS_SQL_TAIL + get_stats: str + get_stats_with_cursor: str -class ReaderPostgres: +class ReaderPostgres(PostgresBase): """Read-only Postgres accessor for run/job statistics.""" def __init__(self) -> None: - self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") - self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") - self._db_config: dict[str, Any] | None = None + super().__init__() logger.debug("Initialized PostgreSQL reader.") - def _load_db_config(self) -> dict[str, Any]: - """Load database config from AWS Secrets Manager if not already loaded.""" - if self._db_config is None: - self._db_config = load_postgres_config(self._secret_name, self._secret_region) - config = self._db_config - if config is None: - raise RuntimeError("Failed to load database configuration.") - return config + def _connect_options(self) -> str | None: + """Set statement timeout and read-only mode for reader connections.""" + return f"-c statement_timeout={POSTGRES_STATEMENT_TIMEOUT_MS}" " -c default_transaction_read_only=on" + + @cached_property + def _queries(self) -> ReaderQueries: + """Load SQL queries from the `sql/` directory via aiosql.""" + queries = aiosql.from_path(_SQL_DIR, "psycopg2") + return ReaderQueries( + get_stats=queries.get_stats.sql, + get_stats_with_cursor=queries.get_stats_with_cursor.sql, + ) def read_stats( self, @@ -102,42 +90,25 @@ def read_stats( Raises: RuntimeError: On database connectivity or query errors. """ - db_config = self._load_db_config() - required_keys = ("database", "host", "user", "password", "port") - missing_keys = [key for key in required_keys if not db_config.get(key)] - if missing_keys: - raise RuntimeError(f"PostgreSQL config missing: {', '.join(missing_keys)}.") - if psycopg2 is None: - raise RuntimeError("psycopg2 is not available.") + config = self._pg_config + if not config.get("database"): + raise RuntimeError("PostgreSQL config missing: database.") + if not all(config.get(field) for field in REQUIRED_CONNECTION_FIELDS): + missing = [field for field in REQUIRED_CONNECTION_FIELDS if not config.get(field)] + raise RuntimeError(f"PostgreSQL config missing: {', '.join(missing)}.") limit = max(1, min(limit, POSTGRES_MAX_LIMIT)) now_ms = int(time.time() * 1000) ts_start = timestamp_start if timestamp_start is not None else (now_ms - POSTGRES_DEFAULT_WINDOW_MS) ts_end = timestamp_end if timestamp_end is not None else now_ms - params: list[Any] = [ts_start, ts_end] - if cursor is not None: - params.append(cursor) - query = psycopg2_sql.SQL(_RUNS_SQL_WITH_CURSOR) - else: - query = psycopg2_sql.SQL(_RUNS_SQL) - params.append(limit + 1) - try: - with psycopg2.connect( # type: ignore[attr-defined] - database=db_config["database"], - host=db_config["host"], - user=db_config["user"], - password=db_config["password"], - port=db_config["port"], - options="-c statement_timeout=30000 -c default_transaction_read_only=on", - ) as connection: - with connection.cursor() as db_cursor: - db_cursor.execute(query, params) - col_names = [desc[0] for desc in db_cursor.description] # type: ignore[union-attr] - raw_rows = db_cursor.fetchall() + col_names, raw_rows = self._execute_with_retry( + lambda conn: self._run_stats_query(conn, ts_start, ts_end, cursor, limit) + ) except PsycopgError as exc: - raise RuntimeError(f"Database query failed: {exc}") from exc + self._close_connection() + raise RuntimeError(f"Database query error: {exc}") from exc rows = [dict(zip(col_names, row, strict=True)) for row in raw_rows] @@ -160,6 +131,41 @@ def read_stats( logger.debug("Stats query returned %d rows.", len(rows)) return rows, pagination + def _run_stats_query( + self, + connection: Any, + ts_start: int, + ts_end: int, + cursor: int | None, + limit: int, + ) -> tuple[list[str], list[tuple[Any, ...]]]: + """Execute the stats SQL query and return column names and raw rows.""" + try: + with connection.cursor() as db_cursor: + if cursor is not None: + db_cursor.execute( + self._queries.get_stats_with_cursor, + {"ts_start": ts_start, "ts_end": ts_end, "cursor_id": cursor, "lim": limit + 1}, + ) + else: + db_cursor.execute( + self._queries.get_stats, + {"ts_start": ts_start, "ts_end": ts_end, "lim": limit + 1}, + ) + if db_cursor.description is None: + raise RuntimeError("Stats query returned no result description.") + col_names = [desc[0] for desc in db_cursor.description] + raw_rows = db_cursor.fetchall() + finally: + # Rollback closes the implicit transaction opened by the SELECT, + # leaving the cached connection in a clean idle state for reuse. + try: + connection.rollback() + except PsycopgError: + logger.debug("Failed to close the implicit transaction. Closing cached connection.", exc_info=True) + self._close_connection() + return col_names, raw_rows + @staticmethod def _format_row(row: dict[str, Any]) -> dict[str, Any]: """Add computed columns to a result row. @@ -225,14 +231,14 @@ def check_health(self) -> tuple[bool, str]: return False, "postgres secret not configured" try: - db_config = self._load_db_config() + pg_config = self._pg_config except (BotoCoreError, ClientError, RuntimeError, ValueError, KeyError) as err: return False, str(err) - if not db_config.get("database"): + if not pg_config.get("database"): return False, "database not configured" - missing = [f for f in ("host", "user", "password", "port") if not db_config.get(f)] + missing = [field for field in REQUIRED_CONNECTION_FIELDS if not pg_config.get(field)] if missing: return False, f"{missing[0]} not configured" diff --git a/src/readers/sql/stats.sql b/src/readers/sql/stats.sql new file mode 100644 index 0000000..2f40afd --- /dev/null +++ b/src/readers/sql/stats.sql @@ -0,0 +1,28 @@ +-- name: get_stats(ts_start, ts_end, lim) +-- Get run/job statistics with keyset pagination. +SELECT r.event_id, r.job_ref, r.tenant_id, r.source_app, + r.source_app_version, r.environment, + r.timestamp_start AS run_timestamp_start, + r.timestamp_end AS run_timestamp_end, + j.internal_id, j.country, j.catalog_id, j.status, + j.timestamp_start, j.timestamp_end, j.message, j.additional_info + FROM public_cps_za_runs_jobs j + INNER JOIN public_cps_za_runs r ON j.event_id = r.event_id + WHERE r.timestamp_start >= :ts_start AND r.timestamp_start <= :ts_end + ORDER BY j.internal_id DESC + LIMIT :lim; + +-- name: get_stats_with_cursor(ts_start, ts_end, cursor_id, lim) +-- Get run/job statistics with cursor-based keyset pagination. +SELECT r.event_id, r.job_ref, r.tenant_id, r.source_app, + r.source_app_version, r.environment, + r.timestamp_start AS run_timestamp_start, + r.timestamp_end AS run_timestamp_end, + j.internal_id, j.country, j.catalog_id, j.status, + j.timestamp_start, j.timestamp_end, j.message, j.additional_info + FROM public_cps_za_runs_jobs j + INNER JOIN public_cps_za_runs r ON j.event_id = r.event_id + WHERE r.timestamp_start >= :ts_start AND r.timestamp_start <= :ts_end + AND j.internal_id < :cursor_id + ORDER BY j.internal_id DESC + LIMIT :lim; diff --git a/src/utils/config_loader.py b/src/utils/config_loader.py index a233b4a..71cdca7 100644 --- a/src/utils/config_loader.py +++ b/src/utils/config_loader.py @@ -23,6 +23,8 @@ from boto3.resources.base import ServiceResource +from src.utils.constants import TOPIC_DLCHANGE, TOPIC_RUNS, TOPIC_TEST + logger = logging.getLogger(__name__) @@ -76,9 +78,9 @@ def load_topic_names(conf_dir: str) -> list[str]: List of topic name strings. """ filename_to_topic = { - "runs.json": "public.cps.za.runs", - "dlchange.json": "public.cps.za.dlchange", - "test.json": "public.cps.za.test", + "runs.json": TOPIC_RUNS, + "dlchange.json": TOPIC_DLCHANGE, + "test.json": TOPIC_TEST, } schemas_dir = os.path.join(conf_dir, "topic_schemas") topics: list[str] = [] diff --git a/src/utils/constants.py b/src/utils/constants.py index fa29fce..2260105 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -16,7 +16,7 @@ """Constants and enums used across the project.""" -from typing import Any +from typing import TypedDict # Configuration keys TOKEN_PROVIDER_URL_KEY = "token_provider_url" @@ -24,16 +24,35 @@ TOKEN_PUBLIC_KEYS_URL_KEY = "token_public_keys_url" SSL_CA_BUNDLE_KEY = "ssl_ca_bundle" +# Postgres connection +POSTGRES_CONNECT_TIMEOUT_SECONDS = 5 +POSTGRES_STATEMENT_TIMEOUT_MS = 30000 +POSTGRES_MAX_RETRIES = 2 +REQUIRED_CONNECTION_FIELDS = ("host", "user", "password", "port") + # Postgres stats defaults POSTGRES_DEFAULT_LIMIT = 50 POSTGRES_MAX_LIMIT = 1000 POSTGRES_DEFAULT_WINDOW_MS = 7 * 24 * 60 * 60 * 1000 # 7 days in milliseconds -SUPPORTED_TOPICS: list[str] = ["public.cps.za.runs"] +# Topic name constants +TOPIC_RUNS = "public.cps.za.runs" +TOPIC_DLCHANGE = "public.cps.za.dlchange" +TOPIC_TEST = "public.cps.za.test" + +SUPPORTED_TOPICS: list[str] = [TOPIC_RUNS] + + +class TopicTableConfig(TypedDict, total=False): + """Structure describing a topic's PostgreSQL table mapping.""" + + main: str + jobs: str + columns: dict[str, list[str]] + -# Maps topic names to their PostgreSQL table(s) -TOPIC_TABLE_MAP: dict[str, dict[str, Any]] = { - "public.cps.za.runs": { +TOPIC_TABLE_MAP: dict[str, TopicTableConfig] = { + TOPIC_RUNS: { "main": "public_cps_za_runs", "jobs": "public_cps_za_runs_jobs", "columns": { @@ -60,7 +79,7 @@ ], }, }, - "public.cps.za.dlchange": { + TOPIC_DLCHANGE: { "main": "public_cps_za_dlchange", "columns": { "main": [ @@ -80,7 +99,7 @@ ], }, }, - "public.cps.za.test": { + TOPIC_TEST: { "main": "public_cps_za_test", "columns": { "main": [ diff --git a/src/utils/postgres_base.py b/src/utils/postgres_base.py new file mode 100644 index 0000000..b37340e --- /dev/null +++ b/src/utils/postgres_base.py @@ -0,0 +1,164 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Shared base class for PostgreSQL reader and writer.""" + +import logging +import os +from collections.abc import Callable +from functools import cached_property +from typing import Any, TypedDict + +from src.utils.constants import POSTGRES_CONNECT_TIMEOUT_SECONDS, POSTGRES_MAX_RETRIES +from src.utils.utils import load_postgres_config + +try: + import psycopg2 + from psycopg2 import Error as PsycopgError + from psycopg2 import OperationalError +except ImportError: + psycopg2 = None # type: ignore + + class PsycopgError(Exception): # type: ignore + """Shim psycopg2 error base when psycopg2 is not installed.""" + + class OperationalError(PsycopgError): # type: ignore + """Shim psycopg2 OperationalError when psycopg2 is not installed.""" + + +logger = logging.getLogger(__name__) + + +class PostgresConfig(TypedDict): + """PostgreSQL connection configuration.""" + + database: str + host: str + user: str + password: str + port: int + + +def _build_postgres_config(aws_secret: dict[str, Any]) -> PostgresConfig: + """Validate and build a `PostgresConfig` from an AWS Secrets Manager dict. + Args: + aws_secret: Dictionary loaded from AWS Secrets Manager. + Returns: + A validated `PostgresConfig`. + Raises: + ValueError: If `database` is present but other required fields are + missing or invalid. + """ + database = str(aws_secret.get("database", "")) + + if database: + required_fields = ("host", "user", "password", "port") + missing_fields = [field for field in required_fields if not aws_secret.get(field)] + if missing_fields: + raise ValueError(f"Missing PostgreSQL secret fields: {', '.join(missing_fields)}") + + try: + port = int(aws_secret["port"]) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid PostgreSQL port: {aws_secret.get('port')}") from exc + else: + port = int(aws_secret.get("port", 0)) + + return PostgresConfig( + database=database, + host=str(aws_secret.get("host", "")), + user=str(aws_secret.get("user", "")), + password=str(aws_secret.get("password", "")), + port=port, + ) + + +class PostgresBase: + """Shared base for PostgreSQL reader and writer.""" + + def __init__(self) -> None: + self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") + self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") + # Any because psycopg2.extensions.connection is unavailable when psycopg2 is not installed. + self._connection: Any | None = None + + @cached_property + def _pg_config(self) -> PostgresConfig: + """Load database config from AWS Secrets Manager on first access.""" + aws_secret = load_postgres_config(self._secret_name, self._secret_region) + return _build_postgres_config(aws_secret) + + def _connect_options(self) -> str | None: + """Return psycopg2 connection `options` string. + Reader overrides this to inject connection-level settings. + """ + return None + + def _get_connection(self) -> Any: + """Return a cached database connection, creating one if needed.""" + if self._connection is not None and not self._connection.closed: + return self._connection + if psycopg2 is None: + raise RuntimeError("psycopg2 is not installed.") + pg_config = self._pg_config + connect_kwargs: dict[str, str | int] = { + "database": pg_config["database"], + "host": pg_config["host"], + "user": pg_config["user"], + "password": pg_config["password"], + "port": pg_config["port"], + "connect_timeout": POSTGRES_CONNECT_TIMEOUT_SECONDS, + } + options = self._connect_options() + if options: + connect_kwargs["options"] = options + self._connection = psycopg2.connect(**connect_kwargs) + logger.debug("New PostgreSQL connection established.") + return self._connection + + def _close_connection(self) -> None: + """Close and discard the cached connection.""" + conn_to_close = self._connection + self._connection = None + if conn_to_close is not None: + try: + conn_to_close.close() + except (PsycopgError, OSError): + logger.debug("Failed to close PostgreSQL connection.") + + def _execute_with_retry[T](self, operation: Callable[..., T], *, retry: bool = True) -> T: + """Run `operation(connection)` with one retry on `OperationalError`. + Args: + operation: Callable receiving a psycopg2 connection. + retry: Whether to retry on `OperationalError`. Disable for + non-idempotent operations (e.g. writes) to avoid duplicates. + Returns: + Whatever `operation` returns on success. + Raises: + RuntimeError: If the retry is also exhausted. + """ + max_attempts = POSTGRES_MAX_RETRIES if retry else 1 + last_exc: OperationalError | None = None + for attempt in range(max_attempts): + try: + connection = self._get_connection() + return operation(connection) + except OperationalError as exc: + last_exc = exc + self._close_connection() + if attempt < max_attempts - 1: + logger.warning("PostgreSQL connection lost, reconnecting.") + raise RuntimeError(f"Database connection failed after {max_attempts} attempts: {last_exc}") from last_exc diff --git a/src/utils/utils.py b/src/utils/utils.py index 8d80cfe..e19d494 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -95,6 +95,6 @@ def load_postgres_config(secret_name: str, secret_region: str) -> dict[str, Any] aws_secrets = boto3.Session().client(service_name="secretsmanager", region_name=secret_region) postgres_secret = aws_secrets.get_secret_value(SecretId=secret_name)["SecretString"] - config: dict[str, Any] = json.loads(postgres_secret) + aws_pg_secret: dict[str, Any] = json.loads(postgres_secret) logger.debug("Loaded PostgreSQL config from Secrets Manager.") - return config + return aws_pg_secret diff --git a/src/writers/sql/inserts.sql b/src/writers/sql/inserts.sql new file mode 100644 index 0000000..f6db68d --- /dev/null +++ b/src/writers/sql/inserts.sql @@ -0,0 +1,37 @@ +-- name: insert_dlchange(event_id, tenant_id, source_app, source_app_version, environment, timestamp_event, country, catalog_id, operation, location, format, format_options, additional_info)! +-- Insert a dlchange-style event row. +INSERT INTO public_cps_za_dlchange + (event_id, tenant_id, source_app, source_app_version, environment, + timestamp_event, country, catalog_id, operation, + "location", "format", format_options, additional_info) +VALUES + (:event_id, :tenant_id, :source_app, :source_app_version, :environment, + :timestamp_event, :country, :catalog_id, :operation, + :location, :format, :format_options, :additional_info); + +-- name: insert_run(event_id, job_ref, tenant_id, source_app, source_app_version, environment, timestamp_start, timestamp_end)! +-- Insert a run event row. +INSERT INTO public_cps_za_runs + (event_id, job_ref, tenant_id, source_app, source_app_version, + environment, timestamp_start, timestamp_end) +VALUES + (:event_id, :job_ref, :tenant_id, :source_app, :source_app_version, + :environment, :timestamp_start, :timestamp_end); + +-- name: insert_run_job(event_id, country, catalog_id, status, timestamp_start, timestamp_end, message, additional_info)! +-- Insert a run job row. +INSERT INTO public_cps_za_runs_jobs + (event_id, country, catalog_id, status, + timestamp_start, timestamp_end, message, additional_info) +VALUES + (:event_id, :country, :catalog_id, :status, + :timestamp_start, :timestamp_end, :message, :additional_info); + +-- name: insert_test(event_id, tenant_id, source_app, environment, timestamp_event, additional_info)! +-- Insert a test event row. +INSERT INTO public_cps_za_test + (event_id, tenant_id, source_app, environment, + timestamp_event, additional_info) +VALUES + (:event_id, :tenant_id, :source_app, :environment, + :timestamp_event, :additional_info); diff --git a/src/writers/writer_postgres.py b/src/writers/writer_postgres.py index c0533b1..e062361 100644 --- a/src/writers/writer_postgres.py +++ b/src/writers/writer_postgres.py @@ -18,230 +18,136 @@ import json import logging -import os +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path from typing import Any +import aiosql from botocore.exceptions import BotoCoreError, ClientError -from src.utils.constants import TOPIC_TABLE_MAP +from src.utils.constants import REQUIRED_CONNECTION_FIELDS, TOPIC_DLCHANGE, TOPIC_RUNS, TOPIC_TABLE_MAP, TOPIC_TEST +from src.utils.postgres_base import PsycopgError, PostgresBase +import src.utils.postgres_base as _pb from src.utils.trace_logging import log_payload_at_trace -from src.utils.utils import load_postgres_config from src.writers.writer import Writer -try: - import psycopg2 - from psycopg2 import Error as PsycopgError - from psycopg2 import sql as psycopg2_sql -except ImportError: - psycopg2 = None # type: ignore - psycopg2_sql = None # type: ignore +logger = logging.getLogger(__name__) - class PsycopgError(Exception): # type: ignore - """Shim psycopg2 error base when psycopg2 is not installed.""" +_SQL_DIR = Path(__file__).parent / "sql" -logger = logging.getLogger(__name__) +@dataclass(frozen=True) +class WriterQueries: + """Typed holder for writer SQL query strings loaded via aiosql.""" + insert_dlchange: str + insert_run: str + insert_run_job: str + insert_test: str -class WriterPostgres(Writer): - """Postgres writer for storing events in PostgreSQL database. - Database credentials are loaded lazily from AWS Secrets Manager on first use. - """ + +class WriterPostgres(Writer, PostgresBase): + """Postgres writer for storing events in PostgreSQL database.""" def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) - self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") - self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") - self._db_config: dict[str, Any | None] | None = None + Writer.__init__(self, config) + PostgresBase.__init__(self) logger.debug("Initialized PostgreSQL writer.") - def _load_db_config(self) -> dict[str, Any]: - """Load database config from AWS Secrets Manager if not already loaded.""" - if self._db_config is None: - self._db_config = load_postgres_config(self._secret_name, self._secret_region) - return self._db_config # type: ignore[return-value] + @cached_property + def _queries(self) -> WriterQueries: + """Load SQL queries from the `sql/` directory via aiosql.""" + queries = aiosql.from_path(_SQL_DIR, "psycopg2") + return WriterQueries( + insert_dlchange=queries.insert_dlchange.sql, + insert_run=queries.insert_run.sql, + insert_run_job=queries.insert_run_job.sql, + insert_test=queries.insert_test.sql, + ) - def _postgres_edla_write(self, cursor: Any, table: str, message: dict[str, Any]) -> None: + def _insert_dlchange(self, cursor: Any, message: dict[str, Any]) -> None: """Insert a dlchange style event row. Args: cursor: Database cursor. - table: Target table name. message: Event payload. """ - logger.debug("Sending to Postgres - %s.", table) - query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - tenant_id, - source_app, - source_app_version, - environment, - timestamp_event, - country, - catalog_id, - operation, - "location", - "format", - format_options, - additional_info - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table)) + logger.debug("Sending to Postgres - dlchange.") cursor.execute( - query, - ( - message["event_id"], - message["tenant_id"], - message["source_app"], - message["source_app_version"], - message["environment"], - message["timestamp_event"], - message.get("country", ""), - message["catalog_id"], - message["operation"], - message.get("location"), - message["format"], - (json.dumps(message.get("format_options")) if "format_options" in message else None), - (json.dumps(message.get("additional_info")) if "additional_info" in message else None), - ), + self._queries.insert_dlchange, + { + "event_id": message["event_id"], + "tenant_id": message["tenant_id"], + "source_app": message["source_app"], + "source_app_version": message["source_app_version"], + "environment": message["environment"], + "timestamp_event": message["timestamp_event"], + "country": message.get("country", ""), + "catalog_id": message["catalog_id"], + "operation": message["operation"], + "location": message.get("location"), + "format": message["format"], + "format_options": (json.dumps(message.get("format_options")) if "format_options" in message else None), + "additional_info": ( + json.dumps(message.get("additional_info")) if "additional_info" in message else None + ), + }, ) - def _postgres_run_write(self, cursor: Any, table_runs: str, table_jobs: str, message: dict[str, Any]) -> None: + def _insert_run(self, cursor: Any, message: dict[str, Any]) -> None: """Insert a run event row plus related job rows. Args: cursor: Database cursor. - table_runs: Runs table name. - table_jobs: Jobs table name. message: Event payload (includes jobs array). """ - logger.debug("Sending to Postgres - %s and %s.", table_runs, table_jobs) - runs_query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - job_ref, - tenant_id, - source_app, - source_app_version, - environment, - timestamp_start, - timestamp_end - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table_runs)) + logger.debug("Sending to Postgres - runs.") cursor.execute( - runs_query, - ( - message["event_id"], - message["job_ref"], - message["tenant_id"], - message["source_app"], - message["source_app_version"], - message["environment"], - message["timestamp_start"], - message["timestamp_end"], - ), + self._queries.insert_run, + { + "event_id": message["event_id"], + "job_ref": message["job_ref"], + "tenant_id": message["tenant_id"], + "source_app": message["source_app"], + "source_app_version": message["source_app_version"], + "environment": message["environment"], + "timestamp_start": message["timestamp_start"], + "timestamp_end": message["timestamp_end"], + }, ) - - jobs_query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - country, - catalog_id, - status, - timestamp_start, - timestamp_end, - message, - additional_info - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table_jobs)) for job in message["jobs"]: cursor.execute( - jobs_query, - ( - message["event_id"], - job.get("country", ""), - job["catalog_id"], - job["status"], - job["timestamp_start"], - job["timestamp_end"], - job.get("message"), - (json.dumps(job.get("additional_info")) if "additional_info" in job else None), - ), + self._queries.insert_run_job, + { + "event_id": message["event_id"], + "country": job.get("country", ""), + "catalog_id": job["catalog_id"], + "status": job["status"], + "timestamp_start": job["timestamp_start"], + "timestamp_end": job["timestamp_end"], + "message": job.get("message"), + "additional_info": (json.dumps(job.get("additional_info")) if "additional_info" in job else None), + }, ) - def _postgres_test_write(self, cursor: Any, table: str, message: dict[str, Any]) -> None: + def _insert_test(self, cursor: Any, message: dict[str, Any]) -> None: """Insert a test topic row. Args: cursor: Database cursor. - table: Target table name. message: Event payload. """ - logger.debug("Sending to Postgres - %s.", table) - query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - tenant_id, - source_app, - environment, - timestamp_event, - additional_info - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table)) + logger.debug("Sending to Postgres - test.") cursor.execute( - query, - ( - message["event_id"], - message["tenant_id"], - message["source_app"], - message["environment"], - message["timestamp"], - (json.dumps(message.get("additional_info")) if "additional_info" in message else None), - ), + self._queries.insert_test, + { + "event_id": message["event_id"], + "tenant_id": message["tenant_id"], + "source_app": message["source_app"], + "environment": message["environment"], + "timestamp_event": message["timestamp"], + "additional_info": ( + json.dumps(message.get("additional_info")) if "additional_info" in message else None + ), + }, ) def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | None]: @@ -253,19 +159,19 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N Tuple of (success: bool, error_message: str | None). """ try: - db_config = self._load_db_config() + pg_config = self._pg_config - if not db_config.get("database"): + if not pg_config.get("database"): logger.debug("No Postgres - skipping Postgres writer.") return True, None - missing = [f for f in ("host", "user", "password", "port") if not db_config.get(f)] + missing = [field for field in REQUIRED_CONNECTION_FIELDS if not pg_config.get(field)] if missing: msg = f"PostgreSQL connection field '{missing[0]}' not configured." logger.error(msg) return False, msg - if psycopg2 is None: + if not self._is_psycopg2_available(): logger.debug("psycopg2 not available - skipping actual Postgres write.") return True, None @@ -276,51 +182,49 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N logger.error(msg) return False, msg - table_info = TOPIC_TABLE_MAP[topic_name] - - with psycopg2.connect( # type: ignore[attr-defined] - database=db_config["database"], - host=db_config["host"], - user=db_config["user"], - password=db_config["password"], - port=db_config["port"], - ) as connection: - with connection.cursor() as cursor: - if topic_name == "public.cps.za.dlchange": - self._postgres_edla_write(cursor, table_info["main"], message) - elif topic_name == "public.cps.za.runs": - self._postgres_run_write(cursor, table_info["main"], table_info["jobs"], message) - elif topic_name == "public.cps.za.test": - self._postgres_test_write(cursor, table_info["main"], message) - - connection.commit() + self._execute_with_retry(lambda conn: self._write_topic(conn, topic_name, message), retry=False) except (RuntimeError, PsycopgError, BotoCoreError, ClientError, ValueError, KeyError) as e: - err_msg = f"The Postgres writer failed with unknown error: {str(e)}" - logger.exception(err_msg) + self._close_connection() + err_msg = f"The Postgres writer failed with unknown error: {e!s}" + logger.exception("The Postgres writer failed with unknown error: %s.", e) return False, err_msg return True, None + def _write_topic(self, connection: Any, topic_name: str, message: dict[str, Any]) -> None: + """Execute the insert for the given topic inside a transaction.""" + with connection.cursor() as cursor: + if topic_name == TOPIC_DLCHANGE: + self._insert_dlchange(cursor, message) + elif topic_name == TOPIC_RUNS: + self._insert_run(cursor, message) + elif topic_name == TOPIC_TEST: + self._insert_test(cursor, message) + connection.commit() + + @staticmethod + def _is_psycopg2_available() -> bool: + """Check whether psycopg2 is importable.""" + return _pb.psycopg2 is not None + def check_health(self) -> tuple[bool, str]: """Check PostgreSQL writer health. Returns: Tuple of (is_healthy: bool, message: str). """ - # Checking if Postgres intentionally disabled if not self._secret_name or not self._secret_region: return True, "not configured" try: - db_config = self._load_db_config() + pg_config = self._pg_config logger.debug("PostgreSQL config loaded during health check.") except (BotoCoreError, ClientError, ValueError, KeyError) as err: return False, str(err) - # Validate database configuration fields - if not db_config.get("database"): + if not pg_config.get("database"): return True, "database not configured" - missing_fields = [field for field in ("host", "user", "password", "port") if not db_config.get(field)] + missing_fields = [field for field in REQUIRED_CONNECTION_FIELDS if not pg_config.get(field)] if missing_fields: return False, f"{missing_fields[0]} not configured" diff --git a/tests/integration/test_connection_reuse.py b/tests/integration/test_connection_reuse.py new file mode 100644 index 0000000..86da228 --- /dev/null +++ b/tests/integration/test_connection_reuse.py @@ -0,0 +1,99 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import time +import uuid + +import pytest + +from tests.integration.conftest import EventGateTestClient, EventStatsTestClient + + +def _make_test_event() -> dict: + """Build a minimal runs event payload.""" + now_ms = int(time.time() * 1000) + return { + "event_id": str(uuid.uuid4()), + "job_ref": f"conn-reuse-{uuid.uuid4().hex[:8]}", + "tenant_id": "CONN_REUSE_TEST", + "source_app": "integration-conn-reuse", + "source_app_version": "1.0.0", + "environment": "test", + "timestamp_start": now_ms - 60000, + "timestamp_end": now_ms, + "jobs": [ + { + "catalog_id": "db.schema.conn_reuse_table", + "status": "succeeded", + "timestamp_start": now_ms - 60000, + "timestamp_end": now_ms, + } + ], + } + + +class TestWriterConnectionReuse: + """Verify that WriterPostgres reuses connections across invocations.""" + + @pytest.fixture(scope="class", autouse=True) + def seed_events(self, eventgate_client: EventGateTestClient, valid_token: str) -> None: + """Post events so the writer connection is established.""" + for _ in range(2): + event = _make_test_event() + response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) + assert 202 == response["statusCode"] + + def test_writer_reuses_connection_across_writes( + self, seed_events: None, eventgate_client: EventGateTestClient, valid_token: str + ) -> None: + """Test that subsequent writes reuse the same cached connection.""" + from src.event_gate_lambda import writers + + writer = writers["postgres"] + conn_before = writer._connection + assert conn_before is not None + assert 0 == conn_before.closed + + event = _make_test_event() + response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) + assert 202 == response["statusCode"] + assert conn_before is writer._connection + + +class TestReaderConnectionReuse: + """Verify that ReaderPostgres reuses connections across invocations.""" + + @pytest.fixture(scope="class", autouse=True) + def seed_events(self, eventgate_client: EventGateTestClient, valid_token: str) -> None: + """Seed events so stats queries return data.""" + for _ in range(2): + event = _make_test_event() + response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) + assert 202 == response["statusCode"] + + def test_reader_reuses_connection_across_reads(self, seed_events: None, stats_client: EventStatsTestClient) -> None: + """Test that successive queries reuse the same cached connection.""" + from src.event_stats_lambda import reader_postgres + + stats_client.post_stats("public.cps.za.runs", {}) + + conn_after_first = reader_postgres._connection + assert conn_after_first is not None + assert 0 == conn_after_first.closed + + stats_client.post_stats("public.cps.za.runs", {}) + assert conn_after_first is reader_postgres._connection diff --git a/tests/unit/readers/test_reader_postgres.py b/tests/unit/readers/test_reader_postgres.py index 124145e..c1aa958 100644 --- a/tests/unit/readers/test_reader_postgres.py +++ b/tests/unit/readers/test_reader_postgres.py @@ -22,6 +22,26 @@ import pytest from src.readers.reader_postgres import ReaderPostgres +import src.utils.postgres_base as pb + +_STATS_DESCRIPTION = [ + ("event_id",), + ("job_ref",), + ("tenant_id",), + ("source_app",), + ("source_app_version",), + ("environment",), + ("run_timestamp_start",), + ("run_timestamp_end",), + ("internal_id",), + ("country",), + ("catalog_id",), + ("status",), + ("timestamp_start",), + ("timestamp_end",), + ("message",), + ("additional_info",), +] @pytest.fixture @@ -49,17 +69,17 @@ def reader(mock_env: None) -> ReaderPostgres: return ReaderPostgres() -class TestLoadDbConfig: - """Tests for lazy database configuration loading.""" +class TestDbConfig: + """Tests for lazy database configuration loading via cached_property.""" def test_loads_config_from_secrets_manager(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: - """Test that _load_db_config loads from Secrets Manager.""" + """Test that _pg_config loads from Secrets Manager.""" mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} with patch("boto3.Session") as mock_session: mock_session.return_value.client.return_value = mock_client - result = reader._load_db_config() + result = reader._pg_config assert "eventgate" == result["database"] assert "localhost" == result["host"] @@ -71,8 +91,8 @@ def test_caches_config_after_first_load(self, reader: ReaderPostgres, pg_secret: with patch("boto3.Session") as mock_session: mock_session.return_value.client.return_value = mock_client - reader._load_db_config() - reader._load_db_config() + _ = reader._pg_config + _ = reader._pg_config assert 1 == mock_client.get_secret_value.call_count @@ -82,7 +102,7 @@ def test_returns_empty_db_when_no_env_vars(self, monkeypatch: pytest.MonkeyPatch monkeypatch.setenv("POSTGRES_SECRET_REGION", "") reader = ReaderPostgres() - result = reader._load_db_config() + result = reader._pg_config assert "" == result["database"] @@ -94,6 +114,7 @@ def _make_mock_connection(description: list[tuple[str, ...]], rows: list[tuple[A mock_cursor.fetchall.return_value = rows mock_conn = MagicMock() + mock_conn.closed = 0 mock_conn.__enter__ = MagicMock(return_value=mock_conn) mock_conn.__exit__ = MagicMock(return_value=False) mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) @@ -106,24 +127,6 @@ class TestReadStats: def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: """Test that read_stats returns rows and pagination info.""" - description = [ - ("event_id",), - ("job_ref",), - ("tenant_id",), - ("source_app",), - ("source_app_version",), - ("environment",), - ("run_timestamp_start",), - ("run_timestamp_end",), - ("internal_id",), - ("country",), - ("catalog_id",), - ("status",), - ("timestamp_start",), - ("timestamp_end",), - ("message",), - ("additional_info",), - ] rows = [ ( "ev1", @@ -162,11 +165,11 @@ def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: di None, ), ] - mock_conn = _make_mock_connection(description, rows) + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, rows) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -184,34 +187,16 @@ def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: di def test_has_more_when_extra_row_returned(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: """Test that has_more is True when more rows than limit exist.""" - description = [ - ("event_id",), - ("job_ref",), - ("tenant_id",), - ("source_app",), - ("source_app_version",), - ("environment",), - ("run_timestamp_start",), - ("run_timestamp_end",), - ("internal_id",), - ("country",), - ("catalog_id",), - ("status",), - ("timestamp_start",), - ("timestamp_end",), - ("message",), - ("additional_info",), - ] rows = [ ("ev1", "r", "T", "a", "1", "t", 0, 0, 3, "ZA", "c", "s", 0, 0, None, None), ("ev2", "r", "T", "a", "1", "t", 0, 0, 2, "ZA", "c", "s", 0, 0, None, None), ("ev3", "r", "T", "a", "1", "t", 0, 0, 1, "ZA", "c", "s", 0, 0, None, None), ] - mock_conn = _make_mock_connection(description, rows) + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, rows) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -243,36 +228,18 @@ def test_missing_connection_field_raises_runtime_error( with ( patch("boto3.Session") as mock_session, - pytest.raises(RuntimeError, match="host, user, password"), + pytest.raises(ValueError, match="Missing PostgreSQL secret fields"), ): mock_session.return_value.client.return_value = mock_client reader.read_stats() def test_cursor_filters_by_internal_id(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: - """Test that passing cursor appends internal_id condition.""" - description = [ - ("event_id",), - ("job_ref",), - ("tenant_id",), - ("source_app",), - ("source_app_version",), - ("environment",), - ("run_timestamp_start",), - ("run_timestamp_end",), - ("internal_id",), - ("country",), - ("catalog_id",), - ("status",), - ("timestamp_start",), - ("timestamp_end",), - ("message",), - ("additional_info",), - ] - mock_conn = _make_mock_connection(description, []) + """Test that passing cursor uses the cursor query variant.""" + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, []) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -284,9 +251,8 @@ def test_cursor_filters_by_internal_id(self, reader: ReaderPostgres, pg_secret: executed_sql = mock_conn.cursor.return_value.__enter__.return_value.execute.call_args[0][0] executed_params = mock_conn.cursor.return_value.__enter__.return_value.execute.call_args[0][1] - sql_str = executed_sql.as_string(None) if hasattr(executed_sql, "as_string") else executed_sql - assert "j.internal_id < %s" in sql_str - assert 100 in executed_params + assert "j.internal_id <" in executed_sql + assert 100 == executed_params["cursor_id"] class TestFormatRow: @@ -383,9 +349,106 @@ def test_healthy_when_config_valid(self, reader: ReaderPostgres, pg_secret: dict assert "ok" == message def test_unhealthy_when_load_raises_runtime_error(self, reader: ReaderPostgres) -> None: - """Test returns unhealthy when _load_db_config raises RuntimeError.""" - with patch.object(reader, "_load_db_config", side_effect=RuntimeError("Failed to load.")): + """Test returns unhealthy when _pg_config raises RuntimeError.""" + with patch.object( + type(reader), + "_pg_config", + new_callable=lambda: property(lambda self: (_ for _ in ()).throw(RuntimeError("Failed to load."))), + ): healthy, message = reader.check_health() assert False is healthy assert "Failed to load." == message + + +class TestConnectionReuse: + """Tests for PostgreSQL connection reuse error handling.""" + + def test_reconnects_on_closed_connection(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: + """Test that a closed connection triggers reconnection.""" + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, []) + + with ( + patch("boto3.Session") as mock_session, + patch.object(pb, "psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.return_value = mock_conn + + reader.read_stats(limit=10) + assert 1 == mock_pg.connect.call_count + + mock_conn.closed = 2 + + reader.read_stats(limit=10) + assert 2 == mock_pg.connect.call_count + + def test_retries_on_operational_error(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: + """Test that OperationalError triggers retry with fresh connection.""" + fail_conn = _make_mock_connection(_STATS_DESCRIPTION, []) + fail_cursor = fail_conn.cursor.return_value.__enter__.return_value + fail_cursor.execute.side_effect = pb.OperationalError("connection reset") + + ok_conn = _make_mock_connection(_STATS_DESCRIPTION, []) + + with ( + patch("boto3.Session") as mock_session, + patch.object(pb, "psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.side_effect = [fail_conn, ok_conn] + + rows, _pagination = reader.read_stats(limit=10) + + assert 2 == mock_pg.connect.call_count + assert [] == rows + + def test_raises_after_retry_exhausted(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: + """Test that OperationalError on both attempts raises RuntimeError.""" + fail_conn = MagicMock() + fail_conn.closed = 0 + fail_cursor = MagicMock() + fail_cursor.execute.side_effect = pb.OperationalError("connection reset") + fail_conn.cursor.return_value.__enter__ = MagicMock(return_value=fail_cursor) + fail_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch("boto3.Session") as mock_session, + patch.object(pb, "psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.return_value = fail_conn + + with pytest.raises(RuntimeError, match="Database connection failed after"): + reader.read_stats(limit=10) + + def test_discards_connection_on_non_operational_error( + self, reader: ReaderPostgres, pg_secret: dict[str, Any] + ) -> None: + """Test that a non-OperationalError PsycopgError discards the connection.""" + fail_conn = MagicMock() + fail_conn.closed = 0 + fail_cursor = MagicMock() + fail_cursor.execute.side_effect = pb.PsycopgError("integrity error") + fail_conn.cursor.return_value.__enter__ = MagicMock(return_value=fail_cursor) + fail_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch("boto3.Session") as mock_session, + patch.object(pb, "psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.return_value = fail_conn + + with pytest.raises(RuntimeError, match="Database query error"): + reader.read_stats(limit=10) + + assert reader._connection is None diff --git a/tests/unit/utils/test_postgres_base.py b/tests/unit/utils/test_postgres_base.py new file mode 100644 index 0000000..4c07176 --- /dev/null +++ b/tests/unit/utils/test_postgres_base.py @@ -0,0 +1,251 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from unittest.mock import MagicMock, patch + +import pytest + +import src.utils.postgres_base as pb +from src.utils.postgres_base import PostgresBase, _build_postgres_config + + +class _ConcreteBase(PostgresBase): + """Minimal concrete subclass for testing PostgresBase directly.""" + + +# _build_postgres_config + + +def test_build_postgres_config_full(): + raw = {"database": "mydb", "host": "localhost", "user": "admin", "password": "secret", "port": "5432"} + result = _build_postgres_config(raw) + assert "mydb" == result["database"] + assert "localhost" == result["host"] + assert "admin" == result["user"] + assert "secret" == result["password"] + assert 5432 == result["port"] + + +def test_build_postgres_config_defaults_for_empty_database(): + result = _build_postgres_config({}) + assert "" == result["database"] + assert "" == result["host"] + assert "" == result["user"] + assert "" == result["password"] + assert 0 == result["port"] + + +def test_build_postgres_config_rejects_missing_fields_when_database_set(): + with pytest.raises(ValueError, match="Missing PostgreSQL secret fields"): + _build_postgres_config({"database": "mydb"}) + + +def test_build_postgres_config_rejects_invalid_port(): + with pytest.raises(ValueError, match="Invalid PostgreSQL port"): + _build_postgres_config({"database": "mydb", "host": "h", "user": "u", "password": "p", "port": "abc"}) + + +# PostgresBase.__init__ + + +def test_init_reads_env_vars(monkeypatch): + monkeypatch.setenv("POSTGRES_SECRET_NAME", "my-secret") + monkeypatch.setenv("POSTGRES_SECRET_REGION", "eu-west-1") + instance = _ConcreteBase() + assert "my-secret" == instance._secret_name + assert "eu-west-1" == instance._secret_region + assert instance._connection is None + + +def test_init_defaults_when_env_vars_absent(monkeypatch): + monkeypatch.delenv("POSTGRES_SECRET_NAME", raising=False) + monkeypatch.delenv("POSTGRES_SECRET_REGION", raising=False) + instance = _ConcreteBase() + assert "" == instance._secret_name + assert "" == instance._secret_region + + +# _pg_config + + +def test_pg_config_is_cached(): + base = _ConcreteBase() + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret) as mock_load: + _ = base._pg_config + _ = base._pg_config + assert 1 == mock_load.call_count + + +def test_pg_config_builds_correct_values(): + base = _ConcreteBase() + secret = {"database": "testdb", "host": "dbhost", "user": "dbuser", "password": "dbpass", "port": 5433} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + config = base._pg_config + assert "testdb" == config["database"] + assert 5433 == config["port"] + + +# _connect_options + + +def test_connect_options_returns_none(): + assert _ConcreteBase()._connect_options() is None + + +# _get_connection + + +def test_get_connection_returns_cached_open_connection(): + base = _ConcreteBase() + mock_conn = MagicMock(closed=0) + base._connection = mock_conn + assert mock_conn is base._get_connection() + + +def test_get_connection_raises_when_psycopg2_not_installed(monkeypatch): + monkeypatch.setattr(pb, "psycopg2", None) + with pytest.raises(RuntimeError, match="psycopg2 is not installed"): + _ConcreteBase()._get_connection() + + +def test_get_connection_creates_new_when_closed(monkeypatch): + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = MagicMock(closed=1) + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + result = base._get_connection() + assert mock_conn is result + + +def test_get_connection_passes_options_from_subclass(monkeypatch): + class _BaseWithOptions(PostgresBase): + def _connect_options(self) -> str | None: + return "-c statement_timeout=5000" + + captured = {} + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.side_effect = lambda **kw: captured.update(kw) or MagicMock(closed=0) + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + _BaseWithOptions()._get_connection() + assert "-c statement_timeout=5000" == captured["options"] + + +# _close_connection + + +def test_close_connection_clears_reference_and_calls_close(): + base = _ConcreteBase() + mock_conn = MagicMock() + base._connection = mock_conn + base._close_connection() + mock_conn.close.assert_called_once() + assert base._connection is None + + +def test_close_connection_is_noop_when_no_connection(): + base = _ConcreteBase() + base._close_connection() + assert base._connection is None + + +def test_close_connection_silences_psycopg_error(): + base = _ConcreteBase() + mock_conn = MagicMock() + mock_conn.close.side_effect = pb.PsycopgError("gone") + base._connection = mock_conn + base._close_connection() + assert base._connection is None + + +def test_close_connection_silences_os_error(): + base = _ConcreteBase() + mock_conn = MagicMock() + mock_conn.close.side_effect = OSError("socket closed") + base._connection = mock_conn + base._close_connection() + assert base._connection is None + + +# _execute_with_retry + + +def test_execute_with_retry_returns_operation_result(): + base = _ConcreteBase() + base._connection = MagicMock(closed=0) + assert "ok" == base._execute_with_retry(lambda conn: "ok") + + +def test_execute_with_retry_reconnects_on_first_failure(monkeypatch): + calls = [] + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = mock_conn + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + + def op(conn): + calls.append(len(calls)) + if len(calls) == 1: + raise pb.OperationalError("timeout") + return "recovered" + + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + result = base._execute_with_retry(op) + + assert "recovered" == result + assert 2 == len(calls) + + +def test_execute_with_retry_raises_after_all_attempts_fail(monkeypatch): + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = mock_conn + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + with pytest.raises(RuntimeError, match="Database connection failed after"): + base._execute_with_retry(lambda conn: (_ for _ in ()).throw(pb.OperationalError("down"))) + + +def test_execute_with_retry_no_retry_fails_on_first_attempt(monkeypatch): + calls = [] + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = mock_conn + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + + def op(conn): + calls.append(1) + raise pb.OperationalError("timeout") + + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + with pytest.raises(RuntimeError, match="Database connection failed after 1 attempts"): + base._execute_with_retry(op, retry=False) + + assert 1 == len(calls) diff --git a/tests/unit/utils/test_trace_logging.py b/tests/unit/utils/test_trace_logging.py index ecaefbf..cda7ad6 100644 --- a/tests/unit/utils/test_trace_logging.py +++ b/tests/unit/utils/test_trace_logging.py @@ -18,6 +18,7 @@ from src.utils.logging_levels import TRACE_LEVEL from src.utils.trace_logging import log_payload_at_trace +import src.utils.postgres_base as postgres_base import src.writers.writer_eventbridge as writer_eventbridge import src.writers.writer_kafka as writer_kafka import src.writers.writer_postgres as writer_postgres @@ -93,6 +94,9 @@ def cursor(self): def commit(self): pass + def close(self): + pass + def __enter__(self): return self @@ -103,14 +107,14 @@ class DummyPsycopg2: def connect(self, **kwargs): return DummyConnection() - monkeypatch.setattr(writer_postgres, "psycopg2", DummyPsycopg2()) + monkeypatch.setattr(postgres_base, "psycopg2", DummyPsycopg2()) # Set trace level on the module's logger writer_postgres.logger.setLevel(TRACE_LEVEL) caplog.set_level(TRACE_LEVEL) writer = writer_postgres.WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + writer._pg_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} message = {"event_id": "e", "tenant_id": "t", "source_app": "a", "environment": "dev", "timestamp": 1} ok, err = writer.write("public.cps.za.test", message) diff --git a/tests/unit/writers/test_writer_postgres.py b/tests/unit/writers/test_writer_postgres.py index 2aa8a95..f5910f8 100644 --- a/tests/unit/writers/test_writer_postgres.py +++ b/tests/unit/writers/test_writer_postgres.py @@ -21,35 +21,58 @@ import pytest from src.writers.writer_postgres import WriterPostgres -import src.writers.writer_postgres as wp +import src.utils.postgres_base as pb import src.utils.utils as secrets_mod class MockCursor: - def __init__(self): - self.executions = [] - - @staticmethod - def _sql_to_str(sql) -> str: - """Render a psycopg2 SQL Composable without a real connection.""" - if hasattr(sql, "_seq"): - return "".join(MockCursor._sql_to_str(part) for part in sql._seq) - if hasattr(sql, "_wrapped"): - wrapped = sql._wrapped - if isinstance(wrapped, str): - return wrapped - return ".".join(f'"{s}"' for s in wrapped) - return str(sql) + def __init__(self, store=None): + self.executions = store if store is not None else [] def execute(self, sql, params): - sql_str = self._sql_to_str(sql) if hasattr(sql, "_seq") or hasattr(sql, "_wrapped") else str(sql) - self.executions.append((sql_str.strip(), params)) + self.executions.append((sql, params)) + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + +class DummyConnection: + def __init__(self, store): + self.commit_called = False + self.store = store + self.closed = 0 + + def cursor(self): + return MockCursor(self.store) -# --- Insert helpers --- + def commit(self): + self.commit_called = True + def close(self): + pass -def test_postgres_edla_write_with_optional_fields(): + +class DummyPsycopg: + def __init__(self, store): + self.store = store + self.connect_count = 0 + + def connect(self, **kwargs): + self.connect_count += 1 + return DummyConnection(self.store) + + +@pytest.fixture +def reset_env(): + yield + os.environ.pop("POSTGRES_SECRET_NAME", None) + os.environ.pop("POSTGRES_SECRET_REGION", None) + + +def test_insert_dlchange_with_optional_fields(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -67,19 +90,18 @@ def test_postgres_edla_write_with_optional_fields(): "format_options": {"compression": "snappy"}, "additional_info": {"foo": "bar"}, } - writer._postgres_edla_write(cur, "table_a", message) - assert len(cur.executions) == 1 + writer._insert_dlchange(cur, message) + assert 1 == len(cur.executions) _sql, params = cur.executions[0] - assert len(params) == 13 - assert params[0] == "e1" - assert params[6] == "za" - assert params[9] == "s3://bucket/path" - assert params[10] == "parquet" - assert json.loads(params[11]) == {"compression": "snappy"} - assert json.loads(params[12]) == {"foo": "bar"} + assert "e1" == params["event_id"] + assert "za" == params["country"] + assert "s3://bucket/path" == params["location"] + assert "parquet" == params["format"] + assert {"compression": "snappy"} == json.loads(params["format_options"]) + assert {"foo": "bar"} == json.loads(params["additional_info"]) -def test_postgres_edla_write_missing_optional(): +def test_insert_dlchange_missing_optional(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -93,16 +115,16 @@ def test_postgres_edla_write_missing_optional(): "operation": "overwrite", "format": "delta", } - writer._postgres_edla_write(cur, "table_a", message) + writer._insert_dlchange(cur, message) _sql, params = cur.executions[0] - assert params[6] == "" - assert params[9] is None - assert params[10] == "delta" - assert params[11] is None - assert params[12] is None + assert "" == params["country"] + assert params["location"] is None + assert "delta" == params["format"] + assert params["format_options"] is None + assert params["additional_info"] is None -def test_postgres_run_write(): +def test_insert_run(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -127,28 +149,24 @@ def test_postgres_run_write(): }, ], } - writer._postgres_run_write(cur, "runs_table", "jobs_table", message) - assert len(cur.executions) == 3 + writer._insert_run(cur, message) + assert 3 == len(cur.executions) - # Check run insert - run_sql, run_params = cur.executions[0] - assert "source_app_version" in run_sql - assert run_params[3] == "runapp" + _run_sql, run_params = cur.executions[0] + assert "runapp" == run_params["source_app"] - # Check first job _job1_sql, job1_params = cur.executions[1] - assert job1_params[1] == "" - assert job1_params[2] == "c1" + assert "" == job1_params["country"] + assert "c1" == job1_params["catalog_id"] - # Check second job _job2_sql, job2_params = cur.executions[2] - assert job2_params[1] == "bw" - assert job2_params[2] == "c2" - assert job2_params[6] == "err" - assert json.loads(job2_params[7]) == {"k": "v"} + assert "bw" == job2_params["country"] + assert "c2" == job2_params["catalog_id"] + assert "err" == job2_params["message"] + assert {"k": "v"} == json.loads(job2_params["additional_info"]) -def test_postgres_test_write(): +def test_insert_test(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -159,77 +177,28 @@ def test_postgres_test_write(): "timestamp": 999, "additional_info": {"a": 1}, } - writer._postgres_test_write(cur, "table_test", message) - assert len(cur.executions) == 1 + writer._insert_test(cur, message) + assert 1 == len(cur.executions) _sql, params = cur.executions[0] - assert params[0] == "t1" - assert params[1] == "tenant-x" - assert json.loads(params[5]) == {"a": 1} - - -# --- write() behavioral paths --- - - -@pytest.fixture -def reset_env(): - yield - os.environ.pop("POSTGRES_SECRET_NAME", None) - os.environ.pop("POSTGRES_SECRET_REGION", None) - - -class DummyCursor: - def __init__(self, store): - self.store = store - - def execute(self, sql, params): - self.store.append((sql, params)) + assert "t1" == params["event_id"] + assert "tenant-x" == params["tenant_id"] + assert {"a": 1} == json.loads(params["additional_info"]) - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -class DummyConnection: - def __init__(self, store): - self.commit_called = False - self.store = store - - def cursor(self): - return DummyCursor(self.store) - - def commit(self): - self.commit_called = True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -class DummyPsycopg: - def __init__(self, store): - self.store = store - def connect(self, **kwargs): - return DummyConnection(self.store) - - -# --- write() --- - - -def test_write_skips_when_no_database(reset_env): +def test_write_skips_when_no_database(reset_env, monkeypatch): writer = WriterPostgres({}) - writer._db_config = {"database": ""} + monkeypatch.setattr(type(writer), "_pg_config", property(lambda self: {"database": ""})) ok, err = writer.write("public.cps.za.test", {}) assert ok and err is None -def test_write_fails_when_connection_field_missing(reset_env): +def test_write_fails_when_connection_field_missing(reset_env, monkeypatch): writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "", "user": "u", "password": "p", "port": 5432}), + ) ok, err = writer.write("public.cps.za.test", {}) assert not ok assert "host" in err and "not configured" in err @@ -237,29 +206,41 @@ def test_write_fails_when_connection_field_missing(reset_env): def test_write_skips_when_psycopg2_missing(reset_env, monkeypatch): writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} - monkeypatch.setattr(wp, "psycopg2", None) + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) + monkeypatch.setattr(pb, "psycopg2", None) ok, err = writer.write("public.cps.za.test", {}) assert ok and err is None def test_write_unknown_topic_returns_false(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) ok, err = writer.write("public.cps.za.unknown", {}) assert not ok and "Unknown topic" in err def test_write_success_known_topic(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) - assert ok and err is None and len(store) == 1 + assert ok and err is None and 1 == len(store) def test_write_exception_returns_false(reset_env, monkeypatch): @@ -267,9 +248,13 @@ class FailingPsycopg: def connect(self, **kwargs): raise RuntimeError("boom") - monkeypatch.setattr(wp, "psycopg2", FailingPsycopg()) + monkeypatch.setattr(pb, "psycopg2", FailingPsycopg()) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) ok, err = writer.write("public.cps.za.test", {}) assert not ok and "failed with unknown error" in err @@ -286,17 +271,20 @@ def client(self, service_name, region_name): monkeypatch.setattr(secrets_mod.boto3, "Session", lambda: MockSession()) writer = WriterPostgres({}) - assert writer._db_config is None - # Trigger lazy load via check_health + assert "_pg_config" not in writer.__dict__ writer.check_health() - assert writer._db_config == secret_dict + assert "db" == writer._pg_config["database"] def test_write_dlchange_success(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) message = { "event_id": "e1", "tenant_id": "t", @@ -309,14 +297,18 @@ def test_write_dlchange_success(reset_env, monkeypatch): "format": "parquet", } ok, err = writer.write("public.cps.za.dlchange", message) - assert ok and err is None and len(store) == 1 + assert ok and err is None and 1 == len(store) def test_write_runs_success(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) message = { "event_id": "r1", "job_ref": "job", @@ -329,17 +321,13 @@ def test_write_runs_success(reset_env, monkeypatch): "jobs": [{"catalog_id": "c", "status": "ok", "timestamp_start": 1, "timestamp_end": 2}], } ok, err = writer.write("public.cps.za.runs", message) - assert ok and err is None and len(store) == 2 # run + job insert - - -# --- check_health() --- + assert ok and err is None and 2 == len(store) # run + job insert def test_check_health_not_configured(): - # No secret env vars set, so it's "not configured" writer = WriterPostgres({}) healthy, msg = writer.check_health() - assert healthy and msg == "not configured" + assert healthy and "not configured" == msg def test_check_health_success(reset_env, monkeypatch): @@ -355,7 +343,7 @@ def client(self, service_name, region_name): monkeypatch.setattr(secrets_mod.boto3, "Session", MockSession) writer = WriterPostgres({}) healthy, msg = writer.check_health() - assert healthy and msg == "ok" + assert healthy and "ok" == msg def test_check_health_missing_host(reset_env, monkeypatch): @@ -371,27 +359,193 @@ def client(self, service_name, region_name): monkeypatch.setattr(secrets_mod.boto3, "Session", MockSession) writer = WriterPostgres({}) healthy, msg = writer.check_health() - assert not healthy and "host not configured" in msg + assert not healthy and "Missing PostgreSQL secret fields" in msg -def test_check_health_database_not_configured(): +def test_check_health_database_not_configured(monkeypatch): """check_health returns (True, 'database not configured') when database field is empty.""" writer = WriterPostgres({}) writer._secret_name = "mysecret" writer._secret_region = "eu-west-1" - writer._db_config = {"database": ""} + monkeypatch.setattr(type(writer), "_pg_config", property(lambda self: {"database": ""})) healthy, msg = writer.check_health() assert healthy assert "database not configured" == msg def test_check_health_load_config_exception(mocker): - """check_health returns (False, error) when _load_db_config raises.""" + """check_health returns (False, error) when _pg_config raises.""" writer = WriterPostgres({}) writer._secret_name = "mysecret" writer._secret_region = "eu-west-1" - mocker.patch.object(writer, "_load_db_config", side_effect=ValueError("secret fetch failed")) + mocker.patch.object( + type(writer), + "_pg_config", + new_callable=lambda: property(lambda self: (_ for _ in ()).throw(ValueError("secret fetch failed"))), + ) healthy, msg = writer.check_health() assert not healthy assert "secret fetch failed" == msg + + +def test_write_reconnects_on_closed_connection(reset_env, monkeypatch): + store = [] + psycopg = DummyPsycopg(store) + monkeypatch.setattr(pb, "psycopg2", psycopg) + writer = WriterPostgres({}) + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + writer.write("public.cps.za.test", message) + assert 1 == psycopg.connect_count + + writer._connection.closed = 2 + + writer.write("public.cps.za.test", message) + assert 2 == psycopg.connect_count + + +def test_write_does_not_retry_on_operational_error(reset_env, monkeypatch): + store = [] + fail_flag = [True] + + class RetryCursor: + def execute(self, sql, params): + if fail_flag[0]: + fail_flag[0] = False + raise pb.OperationalError("connection reset") + store.append((sql, params)) + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + class RetryConnection: + def __init__(self): + self.closed = 0 + + def cursor(self): + return RetryCursor() + + def commit(self): + pass + + def close(self): + pass + + class RetryPsycopg: + def __init__(self): + self.connect_count = 0 + + def connect(self, **kwargs): + self.connect_count += 1 + return RetryConnection() + + psycopg = RetryPsycopg() + monkeypatch.setattr(pb, "psycopg2", psycopg) + writer = WriterPostgres({}) + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + ok, err = writer.write("public.cps.za.test", message) + + assert not ok + assert "failed with unknown error" in err + assert 1 == psycopg.connect_count + + +def test_write_fails_after_retry_exhausted(reset_env, monkeypatch): + class AlwaysFailCursor: + def execute(self, sql, params): + raise pb.OperationalError("connection reset") + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + class AlwaysFailConnection: + def __init__(self): + self.closed = 0 + + def cursor(self): + return AlwaysFailCursor() + + def commit(self): + pass + + def close(self): + pass + + class AlwaysFailPsycopg: + def connect(self, **kwargs): + return AlwaysFailConnection() + + monkeypatch.setattr(pb, "psycopg2", AlwaysFailPsycopg()) + writer = WriterPostgres({}) + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + ok, err = writer.write("public.cps.za.test", message) + + assert not ok + assert "failed with unknown error" in err + + +def test_write_discards_connection_on_non_operational_error(reset_env, monkeypatch): + class FailCursor: + def execute(self, sql, params): + raise pb.PsycopgError("integrity error") + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + class FailConnection: + def __init__(self): + self.closed = 0 + + def cursor(self): + return FailCursor() + + def commit(self): + pass + + def close(self): + pass + + class FailPsycopg: + def connect(self, **kwargs): + return FailConnection() + + monkeypatch.setattr(pb, "psycopg2", FailPsycopg()) + writer = WriterPostgres({}) + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), + ) + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + ok, _ = writer.write("public.cps.za.test", message) + + assert not ok + assert writer._connection is None