From 452008dfab07d74ce93cb4f944aee4b992569e1e Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Mon, 13 Apr 2026 15:14:24 +0200 Subject: [PATCH 1/7] PostgreSQL connection pooling --- .github/copilot-instructions.md | 1 + .github/dependabot.yml | 8 + src/readers/reader_postgres.py | 50 ++++-- src/writers/writer_postgres.py | 54 +++++-- tests/integration/test_connection_reuse.py | 99 ++++++++++++ tests/unit/readers/test_reader_postgres.py | 174 ++++++++++++++------- tests/unit/writers/test_writer_postgres.py | 146 +++++++++++++++++ 7 files changed, 446 insertions(+), 86 deletions(-) create mode 100644 tests/integration/test_connection_reuse.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 884b8f5..96ffbe5 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -31,6 +31,7 @@ Python style Patterns - `__init__` methods must not raise exceptions; defer validation and connection to first use (lazy init) - Writers: inherit from `Writer(ABC)`, implement `write(topic, message) -> (bool, str|None)` and `check_health() -> (bool, str)` +- PostgreSQL: `WriterPostgres` and `ReaderPostgres` cache a single connection per instance - Route dispatch via `ROUTE_MAP` dict mapping routes to handler functions in `event_gate_lambda.py` and `event_stats_lambda.py` - Separate business logic from environment access (env vars, file I/O, network calls) - No duplicate validation; centralize parsing in one layer where practical diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 68e0c70..e32ab77 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -14,6 +14,10 @@ updates: commit-message: prefix: "chore" include: "scope" + groups: + github-actions: + patterns: + - "*" - package-ecosystem: "pip" directory: "/" @@ -31,3 +35,7 @@ updates: include: "scope" allow: - dependency-type: "direct" + groups: + python-dependencies: + patterns: + - "*" diff --git a/src/readers/reader_postgres.py b/src/readers/reader_postgres.py index 7c5de0c..dd7129f 100644 --- a/src/readers/reader_postgres.py +++ b/src/readers/reader_postgres.py @@ -34,6 +34,7 @@ try: import psycopg2 from psycopg2 import Error as PsycopgError + from psycopg2 import OperationalError from psycopg2 import sql as psycopg2_sql except ImportError: psycopg2 = None # type: ignore @@ -42,6 +43,9 @@ class PsycopgError(Exception): # type: ignore """Shim psycopg2 error base when psycopg2 is not installed.""" + class OperationalError(PsycopgError): # type: ignore + """Shim psycopg2 OperationalError when psycopg2 is not installed.""" + logger = logging.getLogger(__name__) @@ -70,6 +74,7 @@ def __init__(self) -> None: self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") self._db_config: dict[str, Any] | None = None + self._connection: Any | None = None logger.debug("Initialized PostgreSQL reader.") def _load_db_config(self) -> dict[str, Any]: @@ -81,6 +86,22 @@ def _load_db_config(self) -> dict[str, Any]: raise RuntimeError("Failed to load database configuration.") return config + def _get_connection(self) -> Any: + """Return a cached database connection, creating one if needed.""" + if self._connection is not None and not self._connection.closed: + return self._connection + db_config = self._load_db_config() + self._connection = psycopg2.connect( # type: ignore[attr-defined] + database=db_config["database"], + host=db_config["host"], + user=db_config["user"], + password=db_config["password"], + port=db_config["port"], + options="-c statement_timeout=30000 -c default_transaction_read_only=on", + ) + logger.debug("New PostgreSQL reader connection established.") + return self._connection + def read_stats( self, timestamp_start: int | None = None, @@ -124,20 +145,23 @@ def read_stats( params.append(limit + 1) try: - with psycopg2.connect( # type: ignore[attr-defined] - database=db_config["database"], - host=db_config["host"], - user=db_config["user"], - password=db_config["password"], - port=db_config["port"], - options="-c statement_timeout=30000 -c default_transaction_read_only=on", - ) as connection: - with connection.cursor() as db_cursor: - db_cursor.execute(query, params) - col_names = [desc[0] for desc in db_cursor.description] # type: ignore[union-attr] - raw_rows = db_cursor.fetchall() + for attempt in range(2): + try: + connection = self._get_connection() + with connection.cursor() as db_cursor: + db_cursor.execute(query, params) + col_names = [desc[0] for desc in db_cursor.description] # type: ignore[union-attr] + raw_rows = db_cursor.fetchall() + connection.rollback() + break + except OperationalError as exc: + self._connection = None + if attempt > 0: + raise RuntimeError(f"Database connection failed after retry: {exc}") from exc + logger.warning("PostgreSQL connection lost, reconnecting.") except PsycopgError as exc: - raise RuntimeError(f"Database query failed: {exc}") from exc + self._connection = None + raise RuntimeError(f"Database query error: {exc}") from exc rows = [dict(zip(col_names, row, strict=True)) for row in raw_rows] diff --git a/src/writers/writer_postgres.py b/src/writers/writer_postgres.py index c0533b1..45f79b6 100644 --- a/src/writers/writer_postgres.py +++ b/src/writers/writer_postgres.py @@ -31,6 +31,7 @@ try: import psycopg2 from psycopg2 import Error as PsycopgError + from psycopg2 import OperationalError from psycopg2 import sql as psycopg2_sql except ImportError: psycopg2 = None # type: ignore @@ -39,6 +40,9 @@ class PsycopgError(Exception): # type: ignore """Shim psycopg2 error base when psycopg2 is not installed.""" + class OperationalError(PsycopgError): # type: ignore + """Shim psycopg2 OperationalError when psycopg2 is not installed.""" + logger = logging.getLogger(__name__) @@ -53,6 +57,7 @@ def __init__(self, config: dict[str, Any]) -> None: self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") self._db_config: dict[str, Any | None] | None = None + self._connection: Any | None = None logger.debug("Initialized PostgreSQL writer.") def _load_db_config(self) -> dict[str, Any]: @@ -61,6 +66,21 @@ def _load_db_config(self) -> dict[str, Any]: self._db_config = load_postgres_config(self._secret_name, self._secret_region) return self._db_config # type: ignore[return-value] + def _get_connection(self) -> Any: + """Return a cached database connection, creating one if needed.""" + if self._connection is not None and not self._connection.closed: + return self._connection + db_config = self._load_db_config() + self._connection = psycopg2.connect( # type: ignore[attr-defined] + database=db_config["database"], + host=db_config["host"], + user=db_config["user"], + password=db_config["password"], + port=db_config["port"], + ) + logger.debug("New PostgreSQL writer connection established.") + return self._connection + def _postgres_edla_write(self, cursor: Any, table: str, message: dict[str, Any]) -> None: """Insert a dlchange style event row. Args: @@ -278,23 +298,25 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N table_info = TOPIC_TABLE_MAP[topic_name] - with psycopg2.connect( # type: ignore[attr-defined] - database=db_config["database"], - host=db_config["host"], - user=db_config["user"], - password=db_config["password"], - port=db_config["port"], - ) as connection: - with connection.cursor() as cursor: - if topic_name == "public.cps.za.dlchange": - self._postgres_edla_write(cursor, table_info["main"], message) - elif topic_name == "public.cps.za.runs": - self._postgres_run_write(cursor, table_info["main"], table_info["jobs"], message) - elif topic_name == "public.cps.za.test": - self._postgres_test_write(cursor, table_info["main"], message) - - connection.commit() + for attempt in range(2): + try: + connection = self._get_connection() + with connection.cursor() as cursor: + if topic_name == "public.cps.za.dlchange": + self._postgres_edla_write(cursor, table_info["main"], message) + elif topic_name == "public.cps.za.runs": + self._postgres_run_write(cursor, table_info["main"], table_info["jobs"], message) + elif topic_name == "public.cps.za.test": + self._postgres_test_write(cursor, table_info["main"], message) + connection.commit() + break + except OperationalError: + self._connection = None + if attempt > 0: + raise + logger.warning("PostgreSQL connection lost, reconnecting.") except (RuntimeError, PsycopgError, BotoCoreError, ClientError, ValueError, KeyError) as e: + self._connection = None err_msg = f"The Postgres writer failed with unknown error: {str(e)}" logger.exception(err_msg) return False, err_msg diff --git a/tests/integration/test_connection_reuse.py b/tests/integration/test_connection_reuse.py new file mode 100644 index 0000000..86da228 --- /dev/null +++ b/tests/integration/test_connection_reuse.py @@ -0,0 +1,99 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import time +import uuid + +import pytest + +from tests.integration.conftest import EventGateTestClient, EventStatsTestClient + + +def _make_test_event() -> dict: + """Build a minimal runs event payload.""" + now_ms = int(time.time() * 1000) + return { + "event_id": str(uuid.uuid4()), + "job_ref": f"conn-reuse-{uuid.uuid4().hex[:8]}", + "tenant_id": "CONN_REUSE_TEST", + "source_app": "integration-conn-reuse", + "source_app_version": "1.0.0", + "environment": "test", + "timestamp_start": now_ms - 60000, + "timestamp_end": now_ms, + "jobs": [ + { + "catalog_id": "db.schema.conn_reuse_table", + "status": "succeeded", + "timestamp_start": now_ms - 60000, + "timestamp_end": now_ms, + } + ], + } + + +class TestWriterConnectionReuse: + """Verify that WriterPostgres reuses connections across invocations.""" + + @pytest.fixture(scope="class", autouse=True) + def seed_events(self, eventgate_client: EventGateTestClient, valid_token: str) -> None: + """Post events so the writer connection is established.""" + for _ in range(2): + event = _make_test_event() + response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) + assert 202 == response["statusCode"] + + def test_writer_reuses_connection_across_writes( + self, seed_events: None, eventgate_client: EventGateTestClient, valid_token: str + ) -> None: + """Test that subsequent writes reuse the same cached connection.""" + from src.event_gate_lambda import writers + + writer = writers["postgres"] + conn_before = writer._connection + assert conn_before is not None + assert 0 == conn_before.closed + + event = _make_test_event() + response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) + assert 202 == response["statusCode"] + assert conn_before is writer._connection + + +class TestReaderConnectionReuse: + """Verify that ReaderPostgres reuses connections across invocations.""" + + @pytest.fixture(scope="class", autouse=True) + def seed_events(self, eventgate_client: EventGateTestClient, valid_token: str) -> None: + """Seed events so stats queries return data.""" + for _ in range(2): + event = _make_test_event() + response = eventgate_client.post_event("public.cps.za.runs", event, token=valid_token) + assert 202 == response["statusCode"] + + def test_reader_reuses_connection_across_reads(self, seed_events: None, stats_client: EventStatsTestClient) -> None: + """Test that successive queries reuse the same cached connection.""" + from src.event_stats_lambda import reader_postgres + + stats_client.post_stats("public.cps.za.runs", {}) + + conn_after_first = reader_postgres._connection + assert conn_after_first is not None + assert 0 == conn_after_first.closed + + stats_client.post_stats("public.cps.za.runs", {}) + assert conn_after_first is reader_postgres._connection diff --git a/tests/unit/readers/test_reader_postgres.py b/tests/unit/readers/test_reader_postgres.py index 124145e..bab7c28 100644 --- a/tests/unit/readers/test_reader_postgres.py +++ b/tests/unit/readers/test_reader_postgres.py @@ -22,6 +22,26 @@ import pytest from src.readers.reader_postgres import ReaderPostgres +import src.readers.reader_postgres as rp + +_STATS_DESCRIPTION = [ + ("event_id",), + ("job_ref",), + ("tenant_id",), + ("source_app",), + ("source_app_version",), + ("environment",), + ("run_timestamp_start",), + ("run_timestamp_end",), + ("internal_id",), + ("country",), + ("catalog_id",), + ("status",), + ("timestamp_start",), + ("timestamp_end",), + ("message",), + ("additional_info",), +] @pytest.fixture @@ -94,6 +114,7 @@ def _make_mock_connection(description: list[tuple[str, ...]], rows: list[tuple[A mock_cursor.fetchall.return_value = rows mock_conn = MagicMock() + mock_conn.closed = 0 mock_conn.__enter__ = MagicMock(return_value=mock_conn) mock_conn.__exit__ = MagicMock(return_value=False) mock_conn.cursor.return_value.__enter__ = MagicMock(return_value=mock_cursor) @@ -106,24 +127,6 @@ class TestReadStats: def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: """Test that read_stats returns rows and pagination info.""" - description = [ - ("event_id",), - ("job_ref",), - ("tenant_id",), - ("source_app",), - ("source_app_version",), - ("environment",), - ("run_timestamp_start",), - ("run_timestamp_end",), - ("internal_id",), - ("country",), - ("catalog_id",), - ("status",), - ("timestamp_start",), - ("timestamp_end",), - ("message",), - ("additional_info",), - ] rows = [ ( "ev1", @@ -162,7 +165,7 @@ def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: di None, ), ] - mock_conn = _make_mock_connection(description, rows) + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, rows) with ( patch("boto3.Session") as mock_session, @@ -184,30 +187,12 @@ def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: di def test_has_more_when_extra_row_returned(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: """Test that has_more is True when more rows than limit exist.""" - description = [ - ("event_id",), - ("job_ref",), - ("tenant_id",), - ("source_app",), - ("source_app_version",), - ("environment",), - ("run_timestamp_start",), - ("run_timestamp_end",), - ("internal_id",), - ("country",), - ("catalog_id",), - ("status",), - ("timestamp_start",), - ("timestamp_end",), - ("message",), - ("additional_info",), - ] rows = [ ("ev1", "r", "T", "a", "1", "t", 0, 0, 3, "ZA", "c", "s", 0, 0, None, None), ("ev2", "r", "T", "a", "1", "t", 0, 0, 2, "ZA", "c", "s", 0, 0, None, None), ("ev3", "r", "T", "a", "1", "t", 0, 0, 1, "ZA", "c", "s", 0, 0, None, None), ] - mock_conn = _make_mock_connection(description, rows) + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, rows) with ( patch("boto3.Session") as mock_session, @@ -250,25 +235,7 @@ def test_missing_connection_field_raises_runtime_error( def test_cursor_filters_by_internal_id(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: """Test that passing cursor appends internal_id condition.""" - description = [ - ("event_id",), - ("job_ref",), - ("tenant_id",), - ("source_app",), - ("source_app_version",), - ("environment",), - ("run_timestamp_start",), - ("run_timestamp_end",), - ("internal_id",), - ("country",), - ("catalog_id",), - ("status",), - ("timestamp_start",), - ("timestamp_end",), - ("message",), - ("additional_info",), - ] - mock_conn = _make_mock_connection(description, []) + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, []) with ( patch("boto3.Session") as mock_session, @@ -389,3 +356,96 @@ def test_unhealthy_when_load_raises_runtime_error(self, reader: ReaderPostgres) assert False is healthy assert "Failed to load." == message + + +class TestConnectionReuse: + """Tests for PostgreSQL connection reuse error handling.""" + + def test_reconnects_on_closed_connection(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: + """Test that a closed connection triggers reconnection.""" + mock_conn = _make_mock_connection(_STATS_DESCRIPTION, []) + + with ( + patch("boto3.Session") as mock_session, + patch("src.readers.reader_postgres.psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.return_value = mock_conn + + reader.read_stats(limit=10) + assert 1 == mock_pg.connect.call_count + + mock_conn.closed = 2 + + reader.read_stats(limit=10) + assert 2 == mock_pg.connect.call_count + + def test_retries_on_operational_error(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: + """Test that OperationalError triggers retry with fresh connection.""" + fail_conn = _make_mock_connection(_STATS_DESCRIPTION, []) + fail_cursor = fail_conn.cursor.return_value.__enter__.return_value + fail_cursor.execute.side_effect = rp.OperationalError("connection reset") + + ok_conn = _make_mock_connection(_STATS_DESCRIPTION, []) + + with ( + patch("boto3.Session") as mock_session, + patch("src.readers.reader_postgres.psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.side_effect = [fail_conn, ok_conn] + + rows, pagination = reader.read_stats(limit=10) + + assert 2 == mock_pg.connect.call_count + assert [] == rows + + def test_raises_after_retry_exhausted(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: + """Test that OperationalError on both attempts raises RuntimeError.""" + fail_conn = MagicMock() + fail_conn.closed = 0 + fail_cursor = MagicMock() + fail_cursor.execute.side_effect = rp.OperationalError("connection reset") + fail_conn.cursor.return_value.__enter__ = MagicMock(return_value=fail_cursor) + fail_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch("boto3.Session") as mock_session, + patch("src.readers.reader_postgres.psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.return_value = fail_conn + + with pytest.raises(RuntimeError, match="Database connection failed after retry"): + reader.read_stats(limit=10) + + def test_discards_connection_on_non_operational_error( + self, reader: ReaderPostgres, pg_secret: dict[str, Any] + ) -> None: + """Test that a non-OperationalError PsycopgError discards the connection.""" + fail_conn = MagicMock() + fail_conn.closed = 0 + fail_cursor = MagicMock() + fail_cursor.execute.side_effect = rp.PsycopgError("integrity error") + fail_conn.cursor.return_value.__enter__ = MagicMock(return_value=fail_cursor) + fail_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) + + with ( + patch("boto3.Session") as mock_session, + patch("src.readers.reader_postgres.psycopg2") as mock_pg, + ): + mock_client = MagicMock() + mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} + mock_session.return_value.client.return_value = mock_client + mock_pg.connect.return_value = fail_conn + + with pytest.raises(RuntimeError, match="Database query error"): + reader.read_stats(limit=10) + + assert reader._connection is None diff --git a/tests/unit/writers/test_writer_postgres.py b/tests/unit/writers/test_writer_postgres.py index 2aa8a95..27a1a03 100644 --- a/tests/unit/writers/test_writer_postgres.py +++ b/tests/unit/writers/test_writer_postgres.py @@ -195,6 +195,7 @@ class DummyConnection: def __init__(self, store): self.commit_called = False self.store = store + self.closed = 0 def cursor(self): return DummyCursor(self.store) @@ -212,8 +213,10 @@ def __exit__(self, exc_type, exc, tb): class DummyPsycopg: def __init__(self, store): self.store = store + self.connect_count = 0 def connect(self, **kwargs): + self.connect_count += 1 return DummyConnection(self.store) @@ -395,3 +398,146 @@ def test_check_health_load_config_exception(mocker): healthy, msg = writer.check_health() assert not healthy assert "secret fetch failed" == msg + + +# --- connection reuse --- + + +def test_write_reconnects_on_closed_connection(reset_env, monkeypatch): + store = [] + psycopg = DummyPsycopg(store) + monkeypatch.setattr(wp, "psycopg2", psycopg) + writer = WriterPostgres({}) + writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + writer.write("public.cps.za.test", message) + assert 1 == psycopg.connect_count + + writer._connection.closed = 2 + + writer.write("public.cps.za.test", message) + assert 2 == psycopg.connect_count + + +def test_write_retries_on_operational_error(reset_env, monkeypatch): + store = [] + fail_flag = [True] + + class RetryCursor: + def __init__(self): + pass + + def execute(self, sql, params): + if fail_flag[0]: + fail_flag[0] = False + raise wp.OperationalError("connection reset") + store.append((sql, params)) + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + class RetryConnection: + def __init__(self): + self.closed = 0 + + def cursor(self): + return RetryCursor() + + def commit(self): + pass + + class RetryPsycopg: + def __init__(self): + self.connect_count = 0 + + def connect(self, **kwargs): + self.connect_count += 1 + return RetryConnection() + + psycopg = RetryPsycopg() + monkeypatch.setattr(wp, "psycopg2", psycopg) + writer = WriterPostgres({}) + writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + ok, err = writer.write("public.cps.za.test", message) + + assert ok and err is None + assert 2 == psycopg.connect_count + assert 1 == len(store) + + +def test_write_fails_after_retry_exhausted(reset_env, monkeypatch): + class AlwaysFailCursor: + def execute(self, sql, params): + raise wp.OperationalError("connection reset") + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + class AlwaysFailConnection: + def __init__(self): + self.closed = 0 + + def cursor(self): + return AlwaysFailCursor() + + def commit(self): + pass + + class AlwaysFailPsycopg: + def connect(self, **kwargs): + return AlwaysFailConnection() + + monkeypatch.setattr(wp, "psycopg2", AlwaysFailPsycopg()) + writer = WriterPostgres({}) + writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + ok, err = writer.write("public.cps.za.test", message) + + assert not ok + assert "failed with unknown error" in err + + +def test_write_discards_connection_on_non_operational_error(reset_env, monkeypatch): + class FailCursor: + def execute(self, sql, params): + raise wp.PsycopgError("integrity error") + + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + class FailConnection: + def __init__(self): + self.closed = 0 + + def cursor(self): + return FailCursor() + + def commit(self): + pass + + class FailPsycopg: + def connect(self, **kwargs): + return FailConnection() + + monkeypatch.setattr(wp, "psycopg2", FailPsycopg()) + writer = WriterPostgres({}) + writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} + + ok, _ = writer.write("public.cps.za.test", message) + + assert not ok + assert writer._connection is None From 731d2592581242e7a389d238d52e27e976dbd07e Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 21 Apr 2026 09:32:21 +0200 Subject: [PATCH 2/7] Updating constants.py --- src/handlers/handler_topic.py | 7 ++++--- src/utils/config_loader.py | 8 +++++--- src/utils/constants.py | 30 +++++++++++++++++++++++------- 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/handlers/handler_topic.py b/src/handlers/handler_topic.py index a0e35f0..7eae543 100644 --- a/src/handlers/handler_topic.py +++ b/src/handlers/handler_topic.py @@ -29,6 +29,7 @@ from src.handlers.handler_token import HandlerToken from src.utils.conf_path import CONF_DIR from src.utils.config_loader import load_access_config +from src.utils.constants import TOPIC_DLCHANGE, TOPIC_RUNS, TOPIC_TEST from src.utils.utils import build_error_response from src.writers.writer import Writer @@ -69,11 +70,11 @@ def with_load_topic_schemas(self) -> "HandlerTopic": logger.debug("Loading topic schemas from %s.", topic_schemas_dir) with open(os.path.join(topic_schemas_dir, "runs.json"), "r", encoding="utf-8") as file: - self.topics["public.cps.za.runs"] = json.load(file) + self.topics[TOPIC_RUNS] = json.load(file) with open(os.path.join(topic_schemas_dir, "dlchange.json"), "r", encoding="utf-8") as file: - self.topics["public.cps.za.dlchange"] = json.load(file) + self.topics[TOPIC_DLCHANGE] = json.load(file) with open(os.path.join(topic_schemas_dir, "test.json"), "r", encoding="utf-8") as file: - self.topics["public.cps.za.test"] = json.load(file) + self.topics[TOPIC_TEST] = json.load(file) logger.debug("Loaded topic schemas successfully.") return self diff --git a/src/utils/config_loader.py b/src/utils/config_loader.py index a233b4a..71cdca7 100644 --- a/src/utils/config_loader.py +++ b/src/utils/config_loader.py @@ -23,6 +23,8 @@ from boto3.resources.base import ServiceResource +from src.utils.constants import TOPIC_DLCHANGE, TOPIC_RUNS, TOPIC_TEST + logger = logging.getLogger(__name__) @@ -76,9 +78,9 @@ def load_topic_names(conf_dir: str) -> list[str]: List of topic name strings. """ filename_to_topic = { - "runs.json": "public.cps.za.runs", - "dlchange.json": "public.cps.za.dlchange", - "test.json": "public.cps.za.test", + "runs.json": TOPIC_RUNS, + "dlchange.json": TOPIC_DLCHANGE, + "test.json": TOPIC_TEST, } schemas_dir = os.path.join(conf_dir, "topic_schemas") topics: list[str] = [] diff --git a/src/utils/constants.py b/src/utils/constants.py index fa29fce..5d94c4b 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -16,7 +16,7 @@ """Constants and enums used across the project.""" -from typing import Any +from typing import TypedDict # Configuration keys TOKEN_PROVIDER_URL_KEY = "token_provider_url" @@ -24,16 +24,32 @@ TOKEN_PUBLIC_KEYS_URL_KEY = "token_public_keys_url" SSL_CA_BUNDLE_KEY = "ssl_ca_bundle" +# Postgres connection +POSTGRES_STATEMENT_TIMEOUT_MS = 30000 +POSTGRES_MAX_RETRIES = 2 + # Postgres stats defaults POSTGRES_DEFAULT_LIMIT = 50 POSTGRES_MAX_LIMIT = 1000 POSTGRES_DEFAULT_WINDOW_MS = 7 * 24 * 60 * 60 * 1000 # 7 days in milliseconds -SUPPORTED_TOPICS: list[str] = ["public.cps.za.runs"] +# Topic name constants +TOPIC_RUNS = "public.cps.za.runs" +TOPIC_DLCHANGE = "public.cps.za.dlchange" +TOPIC_TEST = "public.cps.za.test" + +SUPPORTED_TOPICS: list[str] = [TOPIC_RUNS] + + +class TopicTableConfig(TypedDict, total=False): + """Structure describing a topic's PostgreSQL table mapping.""" + main: str + jobs: str + columns: dict[str, list[str]] + -# Maps topic names to their PostgreSQL table(s) -TOPIC_TABLE_MAP: dict[str, dict[str, Any]] = { - "public.cps.za.runs": { +TOPIC_TABLE_MAP: dict[str, TopicTableConfig] = { + TOPIC_RUNS: { "main": "public_cps_za_runs", "jobs": "public_cps_za_runs_jobs", "columns": { @@ -60,7 +76,7 @@ ], }, }, - "public.cps.za.dlchange": { + TOPIC_DLCHANGE: { "main": "public_cps_za_dlchange", "columns": { "main": [ @@ -80,7 +96,7 @@ ], }, }, - "public.cps.za.test": { + TOPIC_TEST: { "main": "public_cps_za_test", "columns": { "main": [ From f9e9f296b585335c0fb7e66c84b58a4e39a17b31 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 21 Apr 2026 10:35:30 +0200 Subject: [PATCH 3/7] SQL queries. --- src/readers/sql/stats.sql | 28 ++++++++++++++++++++++++++++ src/writers/sql/inserts.sql | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 src/readers/sql/stats.sql create mode 100644 src/writers/sql/inserts.sql diff --git a/src/readers/sql/stats.sql b/src/readers/sql/stats.sql new file mode 100644 index 0000000..2f40afd --- /dev/null +++ b/src/readers/sql/stats.sql @@ -0,0 +1,28 @@ +-- name: get_stats(ts_start, ts_end, lim) +-- Get run/job statistics with keyset pagination. +SELECT r.event_id, r.job_ref, r.tenant_id, r.source_app, + r.source_app_version, r.environment, + r.timestamp_start AS run_timestamp_start, + r.timestamp_end AS run_timestamp_end, + j.internal_id, j.country, j.catalog_id, j.status, + j.timestamp_start, j.timestamp_end, j.message, j.additional_info + FROM public_cps_za_runs_jobs j + INNER JOIN public_cps_za_runs r ON j.event_id = r.event_id + WHERE r.timestamp_start >= :ts_start AND r.timestamp_start <= :ts_end + ORDER BY j.internal_id DESC + LIMIT :lim; + +-- name: get_stats_with_cursor(ts_start, ts_end, cursor_id, lim) +-- Get run/job statistics with cursor-based keyset pagination. +SELECT r.event_id, r.job_ref, r.tenant_id, r.source_app, + r.source_app_version, r.environment, + r.timestamp_start AS run_timestamp_start, + r.timestamp_end AS run_timestamp_end, + j.internal_id, j.country, j.catalog_id, j.status, + j.timestamp_start, j.timestamp_end, j.message, j.additional_info + FROM public_cps_za_runs_jobs j + INNER JOIN public_cps_za_runs r ON j.event_id = r.event_id + WHERE r.timestamp_start >= :ts_start AND r.timestamp_start <= :ts_end + AND j.internal_id < :cursor_id + ORDER BY j.internal_id DESC + LIMIT :lim; diff --git a/src/writers/sql/inserts.sql b/src/writers/sql/inserts.sql new file mode 100644 index 0000000..f6db68d --- /dev/null +++ b/src/writers/sql/inserts.sql @@ -0,0 +1,37 @@ +-- name: insert_dlchange(event_id, tenant_id, source_app, source_app_version, environment, timestamp_event, country, catalog_id, operation, location, format, format_options, additional_info)! +-- Insert a dlchange-style event row. +INSERT INTO public_cps_za_dlchange + (event_id, tenant_id, source_app, source_app_version, environment, + timestamp_event, country, catalog_id, operation, + "location", "format", format_options, additional_info) +VALUES + (:event_id, :tenant_id, :source_app, :source_app_version, :environment, + :timestamp_event, :country, :catalog_id, :operation, + :location, :format, :format_options, :additional_info); + +-- name: insert_run(event_id, job_ref, tenant_id, source_app, source_app_version, environment, timestamp_start, timestamp_end)! +-- Insert a run event row. +INSERT INTO public_cps_za_runs + (event_id, job_ref, tenant_id, source_app, source_app_version, + environment, timestamp_start, timestamp_end) +VALUES + (:event_id, :job_ref, :tenant_id, :source_app, :source_app_version, + :environment, :timestamp_start, :timestamp_end); + +-- name: insert_run_job(event_id, country, catalog_id, status, timestamp_start, timestamp_end, message, additional_info)! +-- Insert a run job row. +INSERT INTO public_cps_za_runs_jobs + (event_id, country, catalog_id, status, + timestamp_start, timestamp_end, message, additional_info) +VALUES + (:event_id, :country, :catalog_id, :status, + :timestamp_start, :timestamp_end, :message, :additional_info); + +-- name: insert_test(event_id, tenant_id, source_app, environment, timestamp_event, additional_info)! +-- Insert a test event row. +INSERT INTO public_cps_za_test + (event_id, tenant_id, source_app, environment, + timestamp_event, additional_info) +VALUES + (:event_id, :tenant_id, :source_app, :environment, + :timestamp_event, :additional_info); From 9d4cba7050388558c8ca9a33b4f458a22d6a605d Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 21 Apr 2026 12:01:35 +0200 Subject: [PATCH 4/7] postgres_base.py --- .github/copilot-instructions.md | 4 + src/utils/constants.py | 2 + src/utils/postgres_base.py | 142 +++++++++++++++++++++++++ src/utils/utils.py | 4 +- tests/unit/utils/test_trace_logging.py | 8 +- 5 files changed, 156 insertions(+), 4 deletions(-) create mode 100644 src/utils/postgres_base.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 96ffbe5..5fabdf0 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -51,3 +51,7 @@ Testing Quality gates (run after changes, fix only if below threshold) - Run all quality gates at once: `make qa` - Once a quality gate passes, do not re-run it in different scenarios + +Git workflow +- Do NOT create git commits; committing is the developer's responsibility + diff --git a/src/utils/constants.py b/src/utils/constants.py index 5d94c4b..8ba044b 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -27,6 +27,7 @@ # Postgres connection POSTGRES_STATEMENT_TIMEOUT_MS = 30000 POSTGRES_MAX_RETRIES = 2 +REQUIRED_CONNECTION_FIELDS = ("host", "user", "password", "port") # Postgres stats defaults POSTGRES_DEFAULT_LIMIT = 50 @@ -43,6 +44,7 @@ class TopicTableConfig(TypedDict, total=False): """Structure describing a topic's PostgreSQL table mapping.""" + main: str jobs: str columns: dict[str, list[str]] diff --git a/src/utils/postgres_base.py b/src/utils/postgres_base.py new file mode 100644 index 0000000..4e46c4d --- /dev/null +++ b/src/utils/postgres_base.py @@ -0,0 +1,142 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Shared base class for PostgreSQL reader and writer.""" + +import logging +import os +from collections.abc import Callable +from functools import cached_property +from typing import Any, TypedDict + +from src.utils.constants import POSTGRES_MAX_RETRIES +from src.utils.utils import load_postgres_config + +try: + import psycopg2 + from psycopg2 import Error as PsycopgError + from psycopg2 import OperationalError +except ImportError: + psycopg2 = None # type: ignore + + class PsycopgError(Exception): # type: ignore + """Shim psycopg2 error base when psycopg2 is not installed.""" + + class OperationalError(PsycopgError): # type: ignore + """Shim psycopg2 OperationalError when psycopg2 is not installed.""" + + +logger = logging.getLogger(__name__) + + +class PostgresConfig(TypedDict): + """PostgreSQL connection configuration.""" + + database: str + host: str + user: str + password: str + port: int + + +def _build_postgres_config(aws_secret: dict[str, Any]) -> PostgresConfig: + """Validate and build a `PostgresConfig` from an AWS Secrets Manager dict. + Args: + aws_secret: Dictionary loaded from AWS Secrets Manager. + Returns: + A validated `PostgresConfig`. + """ + return PostgresConfig( + database=str(aws_secret.get("database", "")), + host=str(aws_secret.get("host", "")), + user=str(aws_secret.get("user", "")), + password=str(aws_secret.get("password", "")), + port=int(aws_secret.get("port", 0)), + ) + + +class PostgresBase: + """Shared base for PostgreSQL reader and writer.""" + + def __init__(self) -> None: + self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") + self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") + # Any because psycopg2.extensions.connection is unavailable when psycopg2 is not installed. + self._connection: Any | None = None + + @cached_property + def _pg_config(self) -> PostgresConfig: + """Load database config from AWS Secrets Manager on first access.""" + aws_secret = load_postgres_config(self._secret_name, self._secret_region) + return _build_postgres_config(aws_secret) + + def _connect_options(self) -> str | None: + """Return psycopg2 connection `options` string. + Reader overrides this to inject connection-level settings. + """ + return None + + def _get_connection(self) -> Any: + """Return a cached database connection, creating one if needed.""" + if self._connection is not None and not self._connection.closed: + return self._connection + if psycopg2 is None: + raise RuntimeError("psycopg2 is not installed.") + pg_config = self._pg_config + connect_kwargs: dict[str, str | int] = { + "database": pg_config["database"], + "host": pg_config["host"], + "user": pg_config["user"], + "password": pg_config["password"], + "port": pg_config["port"], + } + options = self._connect_options() + if options: + connect_kwargs["options"] = options + self._connection = psycopg2.connect(**connect_kwargs) + logger.debug("New PostgreSQL connection established.") + return self._connection + + def _close_connection(self) -> None: + """Close and discard the cached connection.""" + conn_to_close = self._connection + self._connection = None + if conn_to_close is not None: + try: + conn_to_close.close() + except (PsycopgError, OSError): + logger.debug("Failed to close PostgreSQL connection.") + + def _execute_with_retry[T](self, operation: Callable[..., T]) -> T: + """Run `operation(connection)` with one retry on `OperationalError`. + Args: + operation: Callable receiving a psycopg2 connection. + Returns: + Whatever `operation` returns on success. + Raises: + RuntimeError: If the retry is also exhausted. + """ + last_exc: OperationalError | None = None + for attempt in range(POSTGRES_MAX_RETRIES): + try: + connection = self._get_connection() + return operation(connection) + except OperationalError as exc: + last_exc = exc + self._close_connection() + if attempt < POSTGRES_MAX_RETRIES - 1: + logger.warning("PostgreSQL connection lost, reconnecting.") + raise RuntimeError(f"Database connection failed after {POSTGRES_MAX_RETRIES} attempts: {last_exc}") diff --git a/src/utils/utils.py b/src/utils/utils.py index 8d80cfe..e19d494 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -95,6 +95,6 @@ def load_postgres_config(secret_name: str, secret_region: str) -> dict[str, Any] aws_secrets = boto3.Session().client(service_name="secretsmanager", region_name=secret_region) postgres_secret = aws_secrets.get_secret_value(SecretId=secret_name)["SecretString"] - config: dict[str, Any] = json.loads(postgres_secret) + aws_pg_secret: dict[str, Any] = json.loads(postgres_secret) logger.debug("Loaded PostgreSQL config from Secrets Manager.") - return config + return aws_pg_secret diff --git a/tests/unit/utils/test_trace_logging.py b/tests/unit/utils/test_trace_logging.py index ecaefbf..cda7ad6 100644 --- a/tests/unit/utils/test_trace_logging.py +++ b/tests/unit/utils/test_trace_logging.py @@ -18,6 +18,7 @@ from src.utils.logging_levels import TRACE_LEVEL from src.utils.trace_logging import log_payload_at_trace +import src.utils.postgres_base as postgres_base import src.writers.writer_eventbridge as writer_eventbridge import src.writers.writer_kafka as writer_kafka import src.writers.writer_postgres as writer_postgres @@ -93,6 +94,9 @@ def cursor(self): def commit(self): pass + def close(self): + pass + def __enter__(self): return self @@ -103,14 +107,14 @@ class DummyPsycopg2: def connect(self, **kwargs): return DummyConnection() - monkeypatch.setattr(writer_postgres, "psycopg2", DummyPsycopg2()) + monkeypatch.setattr(postgres_base, "psycopg2", DummyPsycopg2()) # Set trace level on the module's logger writer_postgres.logger.setLevel(TRACE_LEVEL) caplog.set_level(TRACE_LEVEL) writer = writer_postgres.WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + writer._pg_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} message = {"event_id": "e", "tenant_id": "t", "source_app": "a", "environment": "dev", "timestamp": 1} ok, err = writer.write("public.cps.za.test", message) From 3cdabd05c1f3991986999699f18edae392a82b82 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 21 Apr 2026 15:19:11 +0200 Subject: [PATCH 5/7] writer_postgres.py and reader_postgres.py update based on comments. --- .github/copilot-instructions.md | 1 - requirements.txt | 2 + src/readers/reader_postgres.py | 164 +++++----- src/writers/writer_postgres.py | 340 +++++++-------------- tests/unit/readers/test_reader_postgres.py | 53 ++-- tests/unit/utils/test_postgres_base.py | 220 +++++++++++++ tests/unit/writers/test_writer_postgres.py | 315 ++++++++++--------- 7 files changed, 589 insertions(+), 506 deletions(-) create mode 100644 tests/unit/utils/test_postgres_base.py diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 5fabdf0..6c1e5bf 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -54,4 +54,3 @@ Quality gates (run after changes, fix only if below threshold) Git workflow - Do NOT create git commits; committing is the developer's responsibility - diff --git a/requirements.txt b/requirements.txt index 60fe41e..dfd6aab 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,9 @@ cryptography==46.0.7 jsonschema==4.25.1 PyJWT==2.12.1 requests==2.32.5 +aiosql==15.0 boto3==1.42.88 +botocore==1.40.76 confluent-kafka==2.14.0 moto[s3,secretsmanager,events]==5.1.22 testcontainers==4.14.1 diff --git a/src/readers/reader_postgres.py b/src/readers/reader_postgres.py index dd7129f..0f546eb 100644 --- a/src/readers/reader_postgres.py +++ b/src/readers/reader_postgres.py @@ -17,90 +17,57 @@ """Postgres reader for run/job statistics.""" import logging -import os import time +from dataclasses import dataclass from datetime import datetime, timezone +from functools import cached_property +from pathlib import Path from typing import Any +import aiosql from botocore.exceptions import BotoCoreError, ClientError from src.utils.constants import ( POSTGRES_DEFAULT_LIMIT, POSTGRES_DEFAULT_WINDOW_MS, POSTGRES_MAX_LIMIT, + POSTGRES_STATEMENT_TIMEOUT_MS, + REQUIRED_CONNECTION_FIELDS, ) -from src.utils.utils import load_postgres_config - -try: - import psycopg2 - from psycopg2 import Error as PsycopgError - from psycopg2 import OperationalError - from psycopg2 import sql as psycopg2_sql -except ImportError: - psycopg2 = None # type: ignore - psycopg2_sql = None # type: ignore - - class PsycopgError(Exception): # type: ignore - """Shim psycopg2 error base when psycopg2 is not installed.""" - - class OperationalError(PsycopgError): # type: ignore - """Shim psycopg2 OperationalError when psycopg2 is not installed.""" - +from src.utils.postgres_base import PsycopgError, PostgresBase logger = logging.getLogger(__name__) -_RUNS_SQL_BASE = ( - "SELECT r.event_id, r.job_ref, r.tenant_id, r.source_app," - " r.source_app_version, r.environment," - " r.timestamp_start AS run_timestamp_start," - " r.timestamp_end AS run_timestamp_end," - " j.internal_id, j.country, j.catalog_id, j.status," - " j.timestamp_start, j.timestamp_end, j.message, j.additional_info" - " FROM public_cps_za_runs_jobs j" - " INNER JOIN public_cps_za_runs r ON j.event_id = r.event_id" - " WHERE r.timestamp_start >= %s AND r.timestamp_start <= %s" -) +_SQL_DIR = Path(__file__).parent / "sql" + -_RUNS_SQL_TAIL = " ORDER BY j.internal_id DESC LIMIT %s" +@dataclass(frozen=True) +class ReaderQueries: + """Typed holder for reader SQL query strings loaded via aiosql.""" -_RUNS_SQL = _RUNS_SQL_BASE + _RUNS_SQL_TAIL -_RUNS_SQL_WITH_CURSOR = _RUNS_SQL_BASE + " AND j.internal_id < %s" + _RUNS_SQL_TAIL + get_stats: str + get_stats_with_cursor: str -class ReaderPostgres: +class ReaderPostgres(PostgresBase): """Read-only Postgres accessor for run/job statistics.""" def __init__(self) -> None: - self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") - self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") - self._db_config: dict[str, Any] | None = None - self._connection: Any | None = None + super().__init__() logger.debug("Initialized PostgreSQL reader.") - def _load_db_config(self) -> dict[str, Any]: - """Load database config from AWS Secrets Manager if not already loaded.""" - if self._db_config is None: - self._db_config = load_postgres_config(self._secret_name, self._secret_region) - config = self._db_config - if config is None: - raise RuntimeError("Failed to load database configuration.") - return config - - def _get_connection(self) -> Any: - """Return a cached database connection, creating one if needed.""" - if self._connection is not None and not self._connection.closed: - return self._connection - db_config = self._load_db_config() - self._connection = psycopg2.connect( # type: ignore[attr-defined] - database=db_config["database"], - host=db_config["host"], - user=db_config["user"], - password=db_config["password"], - port=db_config["port"], - options="-c statement_timeout=30000 -c default_transaction_read_only=on", + def _connect_options(self) -> str | None: + """Set statement timeout and read-only mode for reader connections.""" + return f"-c statement_timeout={POSTGRES_STATEMENT_TIMEOUT_MS}" " -c default_transaction_read_only=on" + + @cached_property + def _queries(self) -> ReaderQueries: + """Load SQL queries from the `sql/` directory via aiosql.""" + queries = aiosql.from_path(_SQL_DIR, "psycopg2") + return ReaderQueries( + get_stats=queries.get_stats.sql, + get_stats_with_cursor=queries.get_stats_with_cursor.sql, ) - logger.debug("New PostgreSQL reader connection established.") - return self._connection def read_stats( self, @@ -123,44 +90,24 @@ def read_stats( Raises: RuntimeError: On database connectivity or query errors. """ - db_config = self._load_db_config() - required_keys = ("database", "host", "user", "password", "port") - missing_keys = [key for key in required_keys if not db_config.get(key)] - if missing_keys: - raise RuntimeError(f"PostgreSQL config missing: {', '.join(missing_keys)}.") - if psycopg2 is None: - raise RuntimeError("psycopg2 is not available.") + config = self._pg_config + if not config.get("database"): + raise RuntimeError("PostgreSQL config missing: database.") + if not all(config.get(field) for field in REQUIRED_CONNECTION_FIELDS): + missing = [field for field in REQUIRED_CONNECTION_FIELDS if not config.get(field)] + raise RuntimeError(f"PostgreSQL config missing: {', '.join(missing)}.") limit = max(1, min(limit, POSTGRES_MAX_LIMIT)) now_ms = int(time.time() * 1000) ts_start = timestamp_start if timestamp_start is not None else (now_ms - POSTGRES_DEFAULT_WINDOW_MS) ts_end = timestamp_end if timestamp_end is not None else now_ms - params: list[Any] = [ts_start, ts_end] - if cursor is not None: - params.append(cursor) - query = psycopg2_sql.SQL(_RUNS_SQL_WITH_CURSOR) - else: - query = psycopg2_sql.SQL(_RUNS_SQL) - params.append(limit + 1) - try: - for attempt in range(2): - try: - connection = self._get_connection() - with connection.cursor() as db_cursor: - db_cursor.execute(query, params) - col_names = [desc[0] for desc in db_cursor.description] # type: ignore[union-attr] - raw_rows = db_cursor.fetchall() - connection.rollback() - break - except OperationalError as exc: - self._connection = None - if attempt > 0: - raise RuntimeError(f"Database connection failed after retry: {exc}") from exc - logger.warning("PostgreSQL connection lost, reconnecting.") + col_names, raw_rows = self._execute_with_retry( + lambda conn: self._run_stats_query(conn, ts_start, ts_end, cursor, limit) + ) except PsycopgError as exc: - self._connection = None + self._close_connection() raise RuntimeError(f"Database query error: {exc}") from exc rows = [dict(zip(col_names, row, strict=True)) for row in raw_rows] @@ -184,6 +131,37 @@ def read_stats( logger.debug("Stats query returned %d rows.", len(rows)) return rows, pagination + def _run_stats_query( + self, + connection: Any, + ts_start: int, + ts_end: int, + cursor: int | None, + limit: int, + ) -> tuple[list[str], list[tuple[Any, ...]]]: + """Execute the stats SQL query and return column names and raw rows.""" + try: + with connection.cursor() as db_cursor: + if cursor is not None: + db_cursor.execute( + self._queries.get_stats_with_cursor, + {"ts_start": ts_start, "ts_end": ts_end, "cursor_id": cursor, "lim": limit + 1}, + ) + else: + db_cursor.execute( + self._queries.get_stats, + {"ts_start": ts_start, "ts_end": ts_end, "lim": limit + 1}, + ) + if db_cursor.description is None: + raise RuntimeError("Stats query returned no result description.") + col_names = [desc[0] for desc in db_cursor.description] + raw_rows = db_cursor.fetchall() + finally: + # Rollback closes the implicit transaction opened by the SELECT, + # leaving the cached connection in a clean idle state for reuse. + connection.rollback() + return col_names, raw_rows + @staticmethod def _format_row(row: dict[str, Any]) -> dict[str, Any]: """Add computed columns to a result row. @@ -249,14 +227,14 @@ def check_health(self) -> tuple[bool, str]: return False, "postgres secret not configured" try: - db_config = self._load_db_config() + pg_config = self._pg_config except (BotoCoreError, ClientError, RuntimeError, ValueError, KeyError) as err: return False, str(err) - if not db_config.get("database"): + if not pg_config.get("database"): return False, "database not configured" - missing = [f for f in ("host", "user", "password", "port") if not db_config.get(f)] + missing = [field for field in REQUIRED_CONNECTION_FIELDS if not pg_config.get(field)] if missing: return False, f"{missing[0]} not configured" diff --git a/src/writers/writer_postgres.py b/src/writers/writer_postgres.py index 45f79b6..85949cf 100644 --- a/src/writers/writer_postgres.py +++ b/src/writers/writer_postgres.py @@ -18,250 +18,136 @@ import json import logging -import os +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path from typing import Any +import aiosql from botocore.exceptions import BotoCoreError, ClientError -from src.utils.constants import TOPIC_TABLE_MAP +from src.utils.constants import REQUIRED_CONNECTION_FIELDS, TOPIC_DLCHANGE, TOPIC_RUNS, TOPIC_TABLE_MAP, TOPIC_TEST +from src.utils.postgres_base import PsycopgError, PostgresBase +import src.utils.postgres_base as _pb from src.utils.trace_logging import log_payload_at_trace -from src.utils.utils import load_postgres_config from src.writers.writer import Writer -try: - import psycopg2 - from psycopg2 import Error as PsycopgError - from psycopg2 import OperationalError - from psycopg2 import sql as psycopg2_sql -except ImportError: - psycopg2 = None # type: ignore - psycopg2_sql = None # type: ignore +logger = logging.getLogger(__name__) - class PsycopgError(Exception): # type: ignore - """Shim psycopg2 error base when psycopg2 is not installed.""" +_SQL_DIR = Path(__file__).parent / "sql" - class OperationalError(PsycopgError): # type: ignore - """Shim psycopg2 OperationalError when psycopg2 is not installed.""" +@dataclass(frozen=True) +class WriterQueries: + """Typed holder for writer SQL query strings loaded via aiosql.""" -logger = logging.getLogger(__name__) + insert_dlchange: str + insert_run: str + insert_run_job: str + insert_test: str -class WriterPostgres(Writer): - """Postgres writer for storing events in PostgreSQL database. - Database credentials are loaded lazily from AWS Secrets Manager on first use. - """ +class WriterPostgres(Writer, PostgresBase): + """Postgres writer for storing events in PostgreSQL database.""" def __init__(self, config: dict[str, Any]) -> None: - super().__init__(config) - self._secret_name = os.environ.get("POSTGRES_SECRET_NAME", "") - self._secret_region = os.environ.get("POSTGRES_SECRET_REGION", "") - self._db_config: dict[str, Any | None] | None = None - self._connection: Any | None = None + Writer.__init__(self, config) + PostgresBase.__init__(self) logger.debug("Initialized PostgreSQL writer.") - def _load_db_config(self) -> dict[str, Any]: - """Load database config from AWS Secrets Manager if not already loaded.""" - if self._db_config is None: - self._db_config = load_postgres_config(self._secret_name, self._secret_region) - return self._db_config # type: ignore[return-value] - - def _get_connection(self) -> Any: - """Return a cached database connection, creating one if needed.""" - if self._connection is not None and not self._connection.closed: - return self._connection - db_config = self._load_db_config() - self._connection = psycopg2.connect( # type: ignore[attr-defined] - database=db_config["database"], - host=db_config["host"], - user=db_config["user"], - password=db_config["password"], - port=db_config["port"], + @cached_property + def _queries(self) -> WriterQueries: + """Load SQL queries from the `sql/` directory via aiosql.""" + queries = aiosql.from_path(_SQL_DIR, "psycopg2") + return WriterQueries( + insert_dlchange=queries.insert_dlchange.sql, + insert_run=queries.insert_run.sql, + insert_run_job=queries.insert_run_job.sql, + insert_test=queries.insert_test.sql, ) - logger.debug("New PostgreSQL writer connection established.") - return self._connection - def _postgres_edla_write(self, cursor: Any, table: str, message: dict[str, Any]) -> None: + def _insert_dlchange(self, cursor: Any, message: dict[str, Any]) -> None: """Insert a dlchange style event row. Args: cursor: Database cursor. - table: Target table name. message: Event payload. """ - logger.debug("Sending to Postgres - %s.", table) - query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - tenant_id, - source_app, - source_app_version, - environment, - timestamp_event, - country, - catalog_id, - operation, - "location", - "format", - format_options, - additional_info - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table)) + logger.debug("Sending to Postgres - dlchange.") cursor.execute( - query, - ( - message["event_id"], - message["tenant_id"], - message["source_app"], - message["source_app_version"], - message["environment"], - message["timestamp_event"], - message.get("country", ""), - message["catalog_id"], - message["operation"], - message.get("location"), - message["format"], - (json.dumps(message.get("format_options")) if "format_options" in message else None), - (json.dumps(message.get("additional_info")) if "additional_info" in message else None), - ), + self._queries.insert_dlchange, + { + "event_id": message["event_id"], + "tenant_id": message["tenant_id"], + "source_app": message["source_app"], + "source_app_version": message["source_app_version"], + "environment": message["environment"], + "timestamp_event": message["timestamp_event"], + "country": message.get("country", ""), + "catalog_id": message["catalog_id"], + "operation": message["operation"], + "location": message.get("location"), + "format": message["format"], + "format_options": (json.dumps(message.get("format_options")) if "format_options" in message else None), + "additional_info": ( + json.dumps(message.get("additional_info")) if "additional_info" in message else None + ), + }, ) - def _postgres_run_write(self, cursor: Any, table_runs: str, table_jobs: str, message: dict[str, Any]) -> None: + def _insert_run(self, cursor: Any, message: dict[str, Any]) -> None: """Insert a run event row plus related job rows. Args: cursor: Database cursor. - table_runs: Runs table name. - table_jobs: Jobs table name. message: Event payload (includes jobs array). """ - logger.debug("Sending to Postgres - %s and %s.", table_runs, table_jobs) - runs_query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - job_ref, - tenant_id, - source_app, - source_app_version, - environment, - timestamp_start, - timestamp_end - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table_runs)) + logger.debug("Sending to Postgres - runs.") cursor.execute( - runs_query, - ( - message["event_id"], - message["job_ref"], - message["tenant_id"], - message["source_app"], - message["source_app_version"], - message["environment"], - message["timestamp_start"], - message["timestamp_end"], - ), + self._queries.insert_run, + { + "event_id": message["event_id"], + "job_ref": message["job_ref"], + "tenant_id": message["tenant_id"], + "source_app": message["source_app"], + "source_app_version": message["source_app_version"], + "environment": message["environment"], + "timestamp_start": message["timestamp_start"], + "timestamp_end": message["timestamp_end"], + }, ) - - jobs_query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - country, - catalog_id, - status, - timestamp_start, - timestamp_end, - message, - additional_info - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table_jobs)) for job in message["jobs"]: cursor.execute( - jobs_query, - ( - message["event_id"], - job.get("country", ""), - job["catalog_id"], - job["status"], - job["timestamp_start"], - job["timestamp_end"], - job.get("message"), - (json.dumps(job.get("additional_info")) if "additional_info" in job else None), - ), + self._queries.insert_run_job, + { + "event_id": message["event_id"], + "country": job.get("country", ""), + "catalog_id": job["catalog_id"], + "status": job["status"], + "timestamp_start": job["timestamp_start"], + "timestamp_end": job["timestamp_end"], + "message": job.get("message"), + "additional_info": (json.dumps(job.get("additional_info")) if "additional_info" in job else None), + }, ) - def _postgres_test_write(self, cursor: Any, table: str, message: dict[str, Any]) -> None: + def _insert_test(self, cursor: Any, message: dict[str, Any]) -> None: """Insert a test topic row. Args: cursor: Database cursor. - table: Target table name. message: Event payload. """ - logger.debug("Sending to Postgres - %s.", table) - query = psycopg2_sql.SQL(""" - INSERT INTO {} - ( - event_id, - tenant_id, - source_app, - environment, - timestamp_event, - additional_info - ) - VALUES - ( - %s, - %s, - %s, - %s, - %s, - %s - )""").format(psycopg2_sql.Identifier(table)) + logger.debug("Sending to Postgres - test.") cursor.execute( - query, - ( - message["event_id"], - message["tenant_id"], - message["source_app"], - message["environment"], - message["timestamp"], - (json.dumps(message.get("additional_info")) if "additional_info" in message else None), - ), + self._queries.insert_test, + { + "event_id": message["event_id"], + "tenant_id": message["tenant_id"], + "source_app": message["source_app"], + "environment": message["environment"], + "timestamp_event": message["timestamp"], + "additional_info": ( + json.dumps(message.get("additional_info")) if "additional_info" in message else None + ), + }, ) def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | None]: @@ -273,19 +159,19 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N Tuple of (success: bool, error_message: str | None). """ try: - db_config = self._load_db_config() + pg_config = self._pg_config - if not db_config.get("database"): + if not pg_config.get("database"): logger.debug("No Postgres - skipping Postgres writer.") return True, None - missing = [f for f in ("host", "user", "password", "port") if not db_config.get(f)] + missing = [field for field in REQUIRED_CONNECTION_FIELDS if not pg_config.get(field)] if missing: msg = f"PostgreSQL connection field '{missing[0]}' not configured." logger.error(msg) return False, msg - if psycopg2 is None: + if not self._is_psycopg2_available(): logger.debug("psycopg2 not available - skipping actual Postgres write.") return True, None @@ -296,53 +182,49 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N logger.error(msg) return False, msg - table_info = TOPIC_TABLE_MAP[topic_name] - - for attempt in range(2): - try: - connection = self._get_connection() - with connection.cursor() as cursor: - if topic_name == "public.cps.za.dlchange": - self._postgres_edla_write(cursor, table_info["main"], message) - elif topic_name == "public.cps.za.runs": - self._postgres_run_write(cursor, table_info["main"], table_info["jobs"], message) - elif topic_name == "public.cps.za.test": - self._postgres_test_write(cursor, table_info["main"], message) - connection.commit() - break - except OperationalError: - self._connection = None - if attempt > 0: - raise - logger.warning("PostgreSQL connection lost, reconnecting.") + self._execute_with_retry(lambda conn: self._write_topic(conn, topic_name, message)) except (RuntimeError, PsycopgError, BotoCoreError, ClientError, ValueError, KeyError) as e: - self._connection = None + self._close_connection() err_msg = f"The Postgres writer failed with unknown error: {str(e)}" logger.exception(err_msg) return False, err_msg return True, None + def _write_topic(self, connection: Any, topic_name: str, message: dict[str, Any]) -> None: + """Execute the insert for the given topic inside a transaction.""" + with connection.cursor() as cursor: + if topic_name == TOPIC_DLCHANGE: + self._insert_dlchange(cursor, message) + elif topic_name == TOPIC_RUNS: + self._insert_run(cursor, message) + elif topic_name == TOPIC_TEST: + self._insert_test(cursor, message) + connection.commit() + + @staticmethod + def _is_psycopg2_available() -> bool: + """Check whether psycopg2 is importable.""" + return _pb.psycopg2 is not None + def check_health(self) -> tuple[bool, str]: """Check PostgreSQL writer health. Returns: Tuple of (is_healthy: bool, message: str). """ - # Checking if Postgres intentionally disabled if not self._secret_name or not self._secret_region: return True, "not configured" try: - db_config = self._load_db_config() + pg_config = self._pg_config logger.debug("PostgreSQL config loaded during health check.") except (BotoCoreError, ClientError, ValueError, KeyError) as err: return False, str(err) - # Validate database configuration fields - if not db_config.get("database"): + if not pg_config.get("database"): return True, "database not configured" - missing_fields = [field for field in ("host", "user", "password", "port") if not db_config.get(field)] + missing_fields = [field for field in REQUIRED_CONNECTION_FIELDS if not pg_config.get(field)] if missing_fields: return False, f"{missing_fields[0]} not configured" diff --git a/tests/unit/readers/test_reader_postgres.py b/tests/unit/readers/test_reader_postgres.py index bab7c28..5d3f378 100644 --- a/tests/unit/readers/test_reader_postgres.py +++ b/tests/unit/readers/test_reader_postgres.py @@ -22,7 +22,7 @@ import pytest from src.readers.reader_postgres import ReaderPostgres -import src.readers.reader_postgres as rp +import src.utils.postgres_base as pb _STATS_DESCRIPTION = [ ("event_id",), @@ -69,17 +69,17 @@ def reader(mock_env: None) -> ReaderPostgres: return ReaderPostgres() -class TestLoadDbConfig: - """Tests for lazy database configuration loading.""" +class TestDbConfig: + """Tests for lazy database configuration loading via cached_property.""" def test_loads_config_from_secrets_manager(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: - """Test that _load_db_config loads from Secrets Manager.""" + """Test that _pg_config loads from Secrets Manager.""" mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} with patch("boto3.Session") as mock_session: mock_session.return_value.client.return_value = mock_client - result = reader._load_db_config() + result = reader._pg_config assert "eventgate" == result["database"] assert "localhost" == result["host"] @@ -91,8 +91,8 @@ def test_caches_config_after_first_load(self, reader: ReaderPostgres, pg_secret: with patch("boto3.Session") as mock_session: mock_session.return_value.client.return_value = mock_client - reader._load_db_config() - reader._load_db_config() + _ = reader._pg_config + _ = reader._pg_config assert 1 == mock_client.get_secret_value.call_count @@ -102,7 +102,7 @@ def test_returns_empty_db_when_no_env_vars(self, monkeypatch: pytest.MonkeyPatch monkeypatch.setenv("POSTGRES_SECRET_REGION", "") reader = ReaderPostgres() - result = reader._load_db_config() + result = reader._pg_config assert "" == result["database"] @@ -169,7 +169,7 @@ def test_returns_rows_and_pagination(self, reader: ReaderPostgres, pg_secret: di with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -196,7 +196,7 @@ def test_has_more_when_extra_row_returned(self, reader: ReaderPostgres, pg_secre with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -234,12 +234,12 @@ def test_missing_connection_field_raises_runtime_error( reader.read_stats() def test_cursor_filters_by_internal_id(self, reader: ReaderPostgres, pg_secret: dict[str, Any]) -> None: - """Test that passing cursor appends internal_id condition.""" + """Test that passing cursor uses the cursor query variant.""" mock_conn = _make_mock_connection(_STATS_DESCRIPTION, []) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -251,9 +251,8 @@ def test_cursor_filters_by_internal_id(self, reader: ReaderPostgres, pg_secret: executed_sql = mock_conn.cursor.return_value.__enter__.return_value.execute.call_args[0][0] executed_params = mock_conn.cursor.return_value.__enter__.return_value.execute.call_args[0][1] - sql_str = executed_sql.as_string(None) if hasattr(executed_sql, "as_string") else executed_sql - assert "j.internal_id < %s" in sql_str - assert 100 in executed_params + assert "j.internal_id <" in executed_sql + assert 100 == executed_params["cursor_id"] class TestFormatRow: @@ -350,8 +349,12 @@ def test_healthy_when_config_valid(self, reader: ReaderPostgres, pg_secret: dict assert "ok" == message def test_unhealthy_when_load_raises_runtime_error(self, reader: ReaderPostgres) -> None: - """Test returns unhealthy when _load_db_config raises RuntimeError.""" - with patch.object(reader, "_load_db_config", side_effect=RuntimeError("Failed to load.")): + """Test returns unhealthy when _pg_config raises RuntimeError.""" + with patch.object( + type(reader), + "_pg_config", + new_callable=lambda: property(lambda self: (_ for _ in ()).throw(RuntimeError("Failed to load."))), + ): healthy, message = reader.check_health() assert False is healthy @@ -367,7 +370,7 @@ def test_reconnects_on_closed_connection(self, reader: ReaderPostgres, pg_secret with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -386,13 +389,13 @@ def test_retries_on_operational_error(self, reader: ReaderPostgres, pg_secret: d """Test that OperationalError triggers retry with fresh connection.""" fail_conn = _make_mock_connection(_STATS_DESCRIPTION, []) fail_cursor = fail_conn.cursor.return_value.__enter__.return_value - fail_cursor.execute.side_effect = rp.OperationalError("connection reset") + fail_cursor.execute.side_effect = pb.OperationalError("connection reset") ok_conn = _make_mock_connection(_STATS_DESCRIPTION, []) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} @@ -409,20 +412,20 @@ def test_raises_after_retry_exhausted(self, reader: ReaderPostgres, pg_secret: d fail_conn = MagicMock() fail_conn.closed = 0 fail_cursor = MagicMock() - fail_cursor.execute.side_effect = rp.OperationalError("connection reset") + fail_cursor.execute.side_effect = pb.OperationalError("connection reset") fail_conn.cursor.return_value.__enter__ = MagicMock(return_value=fail_cursor) fail_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} mock_session.return_value.client.return_value = mock_client mock_pg.connect.return_value = fail_conn - with pytest.raises(RuntimeError, match="Database connection failed after retry"): + with pytest.raises(RuntimeError, match="Database connection failed after"): reader.read_stats(limit=10) def test_discards_connection_on_non_operational_error( @@ -432,13 +435,13 @@ def test_discards_connection_on_non_operational_error( fail_conn = MagicMock() fail_conn.closed = 0 fail_cursor = MagicMock() - fail_cursor.execute.side_effect = rp.PsycopgError("integrity error") + fail_cursor.execute.side_effect = pb.PsycopgError("integrity error") fail_conn.cursor.return_value.__enter__ = MagicMock(return_value=fail_cursor) fail_conn.cursor.return_value.__exit__ = MagicMock(return_value=False) with ( patch("boto3.Session") as mock_session, - patch("src.readers.reader_postgres.psycopg2") as mock_pg, + patch.object(pb, "psycopg2") as mock_pg, ): mock_client = MagicMock() mock_client.get_secret_value.return_value = {"SecretString": json.dumps(pg_secret)} diff --git a/tests/unit/utils/test_postgres_base.py b/tests/unit/utils/test_postgres_base.py new file mode 100644 index 0000000..714c7e2 --- /dev/null +++ b/tests/unit/utils/test_postgres_base.py @@ -0,0 +1,220 @@ +# +# Copyright 2026 ABSA Group Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from unittest.mock import MagicMock, patch + +import pytest + +import src.utils.postgres_base as pb +from src.utils.postgres_base import PostgresBase, _build_postgres_config + + +class _ConcreteBase(PostgresBase): + """Minimal concrete subclass for testing PostgresBase directly.""" + + +# _build_postgres_config + + +def test_build_postgres_config_full(): + raw = {"database": "mydb", "host": "localhost", "user": "admin", "password": "secret", "port": "5432"} + result = _build_postgres_config(raw) + assert "mydb" == result["database"] + assert "localhost" == result["host"] + assert "admin" == result["user"] + assert "secret" == result["password"] + assert 5432 == result["port"] + + +def test_build_postgres_config_defaults_for_missing_keys(): + result = _build_postgres_config({}) + assert "" == result["database"] + assert "" == result["host"] + assert "" == result["user"] + assert "" == result["password"] + assert 0 == result["port"] + + +# PostgresBase.__init__ + + +def test_init_reads_env_vars(monkeypatch): + monkeypatch.setenv("POSTGRES_SECRET_NAME", "my-secret") + monkeypatch.setenv("POSTGRES_SECRET_REGION", "eu-west-1") + instance = _ConcreteBase() + assert "my-secret" == instance._secret_name + assert "eu-west-1" == instance._secret_region + assert instance._connection is None + + +def test_init_defaults_when_env_vars_absent(monkeypatch): + monkeypatch.delenv("POSTGRES_SECRET_NAME", raising=False) + monkeypatch.delenv("POSTGRES_SECRET_REGION", raising=False) + instance = _ConcreteBase() + assert "" == instance._secret_name + assert "" == instance._secret_region + + +# _pg_config + + +def test_pg_config_is_cached(): + base = _ConcreteBase() + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret) as mock_load: + _ = base._pg_config + _ = base._pg_config + assert 1 == mock_load.call_count + + +def test_pg_config_builds_correct_values(): + base = _ConcreteBase() + secret = {"database": "testdb", "host": "dbhost", "user": "dbuser", "password": "dbpass", "port": 5433} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + config = base._pg_config + assert "testdb" == config["database"] + assert 5433 == config["port"] + + +# _connect_options + + +def test_connect_options_returns_none(): + assert _ConcreteBase()._connect_options() is None + + +# _get_connection + + +def test_get_connection_returns_cached_open_connection(): + base = _ConcreteBase() + mock_conn = MagicMock(closed=0) + base._connection = mock_conn + assert mock_conn is base._get_connection() + + +def test_get_connection_raises_when_psycopg2_not_installed(monkeypatch): + monkeypatch.setattr(pb, "psycopg2", None) + with pytest.raises(RuntimeError, match="psycopg2 is not installed"): + _ConcreteBase()._get_connection() + + +def test_get_connection_creates_new_when_closed(monkeypatch): + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = MagicMock(closed=1) + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + result = base._get_connection() + assert mock_conn is result + + +def test_get_connection_passes_options_from_subclass(monkeypatch): + class _BaseWithOptions(PostgresBase): + def _connect_options(self) -> str | None: + return "-c statement_timeout=5000" + + captured = {} + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.side_effect = lambda **kw: captured.update(kw) or MagicMock(closed=0) + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + _BaseWithOptions()._get_connection() + assert "-c statement_timeout=5000" == captured["options"] + + +# _close_connection + + +def test_close_connection_clears_reference_and_calls_close(): + base = _ConcreteBase() + mock_conn = MagicMock() + base._connection = mock_conn + base._close_connection() + mock_conn.close.assert_called_once() + assert base._connection is None + + +def test_close_connection_is_noop_when_no_connection(): + base = _ConcreteBase() + base._close_connection() + assert base._connection is None + + +def test_close_connection_silences_psycopg_error(): + base = _ConcreteBase() + mock_conn = MagicMock() + mock_conn.close.side_effect = pb.PsycopgError("gone") + base._connection = mock_conn + base._close_connection() + assert base._connection is None + + +def test_close_connection_silences_os_error(): + base = _ConcreteBase() + mock_conn = MagicMock() + mock_conn.close.side_effect = OSError("socket closed") + base._connection = mock_conn + base._close_connection() + assert base._connection is None + + +# _execute_with_retry + + +def test_execute_with_retry_returns_operation_result(): + base = _ConcreteBase() + base._connection = MagicMock(closed=0) + assert "ok" == base._execute_with_retry(lambda conn: "ok") + + +def test_execute_with_retry_reconnects_on_first_failure(monkeypatch): + calls = [] + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = mock_conn + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + + def op(conn): + calls.append(len(calls)) + if len(calls) == 1: + raise pb.OperationalError("timeout") + return "recovered" + + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + result = base._execute_with_retry(op) + + assert "recovered" == result + assert 2 == len(calls) + + +def test_execute_with_retry_raises_after_all_attempts_fail(monkeypatch): + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = mock_conn + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + with pytest.raises(RuntimeError, match="Database connection failed after"): + base._execute_with_retry(lambda conn: (_ for _ in ()).throw(pb.OperationalError("down"))) diff --git a/tests/unit/writers/test_writer_postgres.py b/tests/unit/writers/test_writer_postgres.py index 27a1a03..db4536c 100644 --- a/tests/unit/writers/test_writer_postgres.py +++ b/tests/unit/writers/test_writer_postgres.py @@ -21,35 +21,58 @@ import pytest from src.writers.writer_postgres import WriterPostgres -import src.writers.writer_postgres as wp +import src.utils.postgres_base as pb import src.utils.utils as secrets_mod class MockCursor: - def __init__(self): - self.executions = [] - - @staticmethod - def _sql_to_str(sql) -> str: - """Render a psycopg2 SQL Composable without a real connection.""" - if hasattr(sql, "_seq"): - return "".join(MockCursor._sql_to_str(part) for part in sql._seq) - if hasattr(sql, "_wrapped"): - wrapped = sql._wrapped - if isinstance(wrapped, str): - return wrapped - return ".".join(f'"{s}"' for s in wrapped) - return str(sql) + def __init__(self, store=None): + self.executions = store if store is not None else [] def execute(self, sql, params): - sql_str = self._sql_to_str(sql) if hasattr(sql, "_seq") or hasattr(sql, "_wrapped") else str(sql) - self.executions.append((sql_str.strip(), params)) + self.executions.append((sql, params)) + def __enter__(self): + return self + + def __exit__(self, *args): + return False + + +class DummyConnection: + def __init__(self, store): + self.commit_called = False + self.store = store + self.closed = 0 + + def cursor(self): + return MockCursor(self.store) + + def commit(self): + self.commit_called = True + + def close(self): + pass + + +class DummyPsycopg: + def __init__(self, store): + self.store = store + self.connect_count = 0 + + def connect(self, **kwargs): + self.connect_count += 1 + return DummyConnection(self.store) -# --- Insert helpers --- + +@pytest.fixture +def reset_env(): + yield + os.environ.pop("POSTGRES_SECRET_NAME", None) + os.environ.pop("POSTGRES_SECRET_REGION", None) -def test_postgres_edla_write_with_optional_fields(): +def test_insert_dlchange_with_optional_fields(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -67,19 +90,18 @@ def test_postgres_edla_write_with_optional_fields(): "format_options": {"compression": "snappy"}, "additional_info": {"foo": "bar"}, } - writer._postgres_edla_write(cur, "table_a", message) - assert len(cur.executions) == 1 + writer._insert_dlchange(cur, message) + assert 1 == len(cur.executions) _sql, params = cur.executions[0] - assert len(params) == 13 - assert params[0] == "e1" - assert params[6] == "za" - assert params[9] == "s3://bucket/path" - assert params[10] == "parquet" - assert json.loads(params[11]) == {"compression": "snappy"} - assert json.loads(params[12]) == {"foo": "bar"} + assert "e1" == params["event_id"] + assert "za" == params["country"] + assert "s3://bucket/path" == params["location"] + assert "parquet" == params["format"] + assert {"compression": "snappy"} == json.loads(params["format_options"]) + assert {"foo": "bar"} == json.loads(params["additional_info"]) -def test_postgres_edla_write_missing_optional(): +def test_insert_dlchange_missing_optional(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -93,16 +115,16 @@ def test_postgres_edla_write_missing_optional(): "operation": "overwrite", "format": "delta", } - writer._postgres_edla_write(cur, "table_a", message) + writer._insert_dlchange(cur, message) _sql, params = cur.executions[0] - assert params[6] == "" - assert params[9] is None - assert params[10] == "delta" - assert params[11] is None - assert params[12] is None + assert "" == params["country"] + assert params["location"] is None + assert "delta" == params["format"] + assert params["format_options"] is None + assert params["additional_info"] is None -def test_postgres_run_write(): +def test_insert_run(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -127,28 +149,24 @@ def test_postgres_run_write(): }, ], } - writer._postgres_run_write(cur, "runs_table", "jobs_table", message) - assert len(cur.executions) == 3 + writer._insert_run(cur, message) + assert 3 == len(cur.executions) - # Check run insert - run_sql, run_params = cur.executions[0] - assert "source_app_version" in run_sql - assert run_params[3] == "runapp" + _run_sql, run_params = cur.executions[0] + assert "runapp" == run_params["source_app"] - # Check first job _job1_sql, job1_params = cur.executions[1] - assert job1_params[1] == "" - assert job1_params[2] == "c1" + assert "" == job1_params["country"] + assert "c1" == job1_params["catalog_id"] - # Check second job _job2_sql, job2_params = cur.executions[2] - assert job2_params[1] == "bw" - assert job2_params[2] == "c2" - assert job2_params[6] == "err" - assert json.loads(job2_params[7]) == {"k": "v"} + assert "bw" == job2_params["country"] + assert "c2" == job2_params["catalog_id"] + assert "err" == job2_params["message"] + assert {"k": "v"} == json.loads(job2_params["additional_info"]) -def test_postgres_test_write(): +def test_insert_test(): writer = WriterPostgres({}) cur = MockCursor() message = { @@ -159,110 +177,67 @@ def test_postgres_test_write(): "timestamp": 999, "additional_info": {"a": 1}, } - writer._postgres_test_write(cur, "table_test", message) - assert len(cur.executions) == 1 + writer._insert_test(cur, message) + assert 1 == len(cur.executions) _sql, params = cur.executions[0] - assert params[0] == "t1" - assert params[1] == "tenant-x" - assert json.loads(params[5]) == {"a": 1} - - -# --- write() behavioral paths --- - - -@pytest.fixture -def reset_env(): - yield - os.environ.pop("POSTGRES_SECRET_NAME", None) - os.environ.pop("POSTGRES_SECRET_REGION", None) - - -class DummyCursor: - def __init__(self, store): - self.store = store - - def execute(self, sql, params): - self.store.append((sql, params)) - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -class DummyConnection: - def __init__(self, store): - self.commit_called = False - self.store = store - self.closed = 0 - - def cursor(self): - return DummyCursor(self.store) - - def commit(self): - self.commit_called = True - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - -class DummyPsycopg: - def __init__(self, store): - self.store = store - self.connect_count = 0 - - def connect(self, **kwargs): - self.connect_count += 1 - return DummyConnection(self.store) - - -# --- write() --- + assert "t1" == params["event_id"] + assert "tenant-x" == params["tenant_id"] + assert {"a": 1} == json.loads(params["additional_info"]) def test_write_skips_when_no_database(reset_env): writer = WriterPostgres({}) - writer._db_config = {"database": ""} + type(writer)._pg_config = property(lambda self: {"database": ""}) ok, err = writer.write("public.cps.za.test", {}) + del type(writer)._pg_config assert ok and err is None def test_write_fails_when_connection_field_missing(reset_env): writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "", "user": "u", "password": "p", "port": 5432} + ) ok, err = writer.write("public.cps.za.test", {}) + del type(writer)._pg_config assert not ok assert "host" in err and "not configured" in err def test_write_skips_when_psycopg2_missing(reset_env, monkeypatch): writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} - monkeypatch.setattr(wp, "psycopg2", None) + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) + monkeypatch.setattr(pb, "psycopg2", None) ok, err = writer.write("public.cps.za.test", {}) + del type(writer)._pg_config assert ok and err is None def test_write_unknown_topic_returns_false(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) ok, err = writer.write("public.cps.za.unknown", {}) + del type(writer)._pg_config assert not ok and "Unknown topic" in err def test_write_success_known_topic(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) - assert ok and err is None and len(store) == 1 + del type(writer)._pg_config + assert ok and err is None and 1 == len(store) def test_write_exception_returns_false(reset_env, monkeypatch): @@ -270,10 +245,13 @@ class FailingPsycopg: def connect(self, **kwargs): raise RuntimeError("boom") - monkeypatch.setattr(wp, "psycopg2", FailingPsycopg()) + monkeypatch.setattr(pb, "psycopg2", FailingPsycopg()) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) ok, err = writer.write("public.cps.za.test", {}) + del type(writer)._pg_config assert not ok and "failed with unknown error" in err @@ -289,17 +267,18 @@ def client(self, service_name, region_name): monkeypatch.setattr(secrets_mod.boto3, "Session", lambda: MockSession()) writer = WriterPostgres({}) - assert writer._db_config is None - # Trigger lazy load via check_health + assert "_pg_config" not in writer.__dict__ writer.check_health() - assert writer._db_config == secret_dict + assert "db" == writer._pg_config["database"] def test_write_dlchange_success(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = { "event_id": "e1", "tenant_id": "t", @@ -312,14 +291,17 @@ def test_write_dlchange_success(reset_env, monkeypatch): "format": "parquet", } ok, err = writer.write("public.cps.za.dlchange", message) - assert ok and err is None and len(store) == 1 + del type(writer)._pg_config + assert ok and err is None and 1 == len(store) def test_write_runs_success(reset_env, monkeypatch): store = [] - monkeypatch.setattr(wp, "psycopg2", DummyPsycopg(store)) + monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = { "event_id": "r1", "job_ref": "job", @@ -332,17 +314,14 @@ def test_write_runs_success(reset_env, monkeypatch): "jobs": [{"catalog_id": "c", "status": "ok", "timestamp_start": 1, "timestamp_end": 2}], } ok, err = writer.write("public.cps.za.runs", message) - assert ok and err is None and len(store) == 2 # run + job insert - - -# --- check_health() --- + del type(writer)._pg_config + assert ok and err is None and 2 == len(store) # run + job insert def test_check_health_not_configured(): - # No secret env vars set, so it's "not configured" writer = WriterPostgres({}) healthy, msg = writer.check_health() - assert healthy and msg == "not configured" + assert healthy and "not configured" == msg def test_check_health_success(reset_env, monkeypatch): @@ -358,7 +337,7 @@ def client(self, service_name, region_name): monkeypatch.setattr(secrets_mod.boto3, "Session", MockSession) writer = WriterPostgres({}) healthy, msg = writer.check_health() - assert healthy and msg == "ok" + assert healthy and "ok" == msg def test_check_health_missing_host(reset_env, monkeypatch): @@ -382,33 +361,37 @@ def test_check_health_database_not_configured(): writer = WriterPostgres({}) writer._secret_name = "mysecret" writer._secret_region = "eu-west-1" - writer._db_config = {"database": ""} + type(writer)._pg_config = property(lambda self: {"database": ""}) healthy, msg = writer.check_health() + del type(writer)._pg_config assert healthy assert "database not configured" == msg def test_check_health_load_config_exception(mocker): - """check_health returns (False, error) when _load_db_config raises.""" + """check_health returns (False, error) when _pg_config raises.""" writer = WriterPostgres({}) writer._secret_name = "mysecret" writer._secret_region = "eu-west-1" - mocker.patch.object(writer, "_load_db_config", side_effect=ValueError("secret fetch failed")) + mocker.patch.object( + type(writer), + "_pg_config", + new_callable=lambda: property(lambda self: (_ for _ in ()).throw(ValueError("secret fetch failed"))), + ) healthy, msg = writer.check_health() assert not healthy assert "secret fetch failed" == msg -# --- connection reuse --- - - def test_write_reconnects_on_closed_connection(reset_env, monkeypatch): store = [] psycopg = DummyPsycopg(store) - monkeypatch.setattr(wp, "psycopg2", psycopg) + monkeypatch.setattr(pb, "psycopg2", psycopg) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} writer.write("public.cps.za.test", message) @@ -417,6 +400,7 @@ def test_write_reconnects_on_closed_connection(reset_env, monkeypatch): writer._connection.closed = 2 writer.write("public.cps.za.test", message) + del type(writer)._pg_config assert 2 == psycopg.connect_count @@ -425,13 +409,10 @@ def test_write_retries_on_operational_error(reset_env, monkeypatch): fail_flag = [True] class RetryCursor: - def __init__(self): - pass - def execute(self, sql, params): if fail_flag[0]: fail_flag[0] = False - raise wp.OperationalError("connection reset") + raise pb.OperationalError("connection reset") store.append((sql, params)) def __enter__(self): @@ -450,6 +431,9 @@ def cursor(self): def commit(self): pass + def close(self): + pass + class RetryPsycopg: def __init__(self): self.connect_count = 0 @@ -459,12 +443,15 @@ def connect(self, **kwargs): return RetryConnection() psycopg = RetryPsycopg() - monkeypatch.setattr(wp, "psycopg2", psycopg) + monkeypatch.setattr(pb, "psycopg2", psycopg) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) + del type(writer)._pg_config assert ok and err is None assert 2 == psycopg.connect_count @@ -474,7 +461,7 @@ def connect(self, **kwargs): def test_write_fails_after_retry_exhausted(reset_env, monkeypatch): class AlwaysFailCursor: def execute(self, sql, params): - raise wp.OperationalError("connection reset") + raise pb.OperationalError("connection reset") def __enter__(self): return self @@ -492,16 +479,22 @@ def cursor(self): def commit(self): pass + def close(self): + pass + class AlwaysFailPsycopg: def connect(self, **kwargs): return AlwaysFailConnection() - monkeypatch.setattr(wp, "psycopg2", AlwaysFailPsycopg()) + monkeypatch.setattr(pb, "psycopg2", AlwaysFailPsycopg()) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) + del type(writer)._pg_config assert not ok assert "failed with unknown error" in err @@ -510,7 +503,7 @@ def connect(self, **kwargs): def test_write_discards_connection_on_non_operational_error(reset_env, monkeypatch): class FailCursor: def execute(self, sql, params): - raise wp.PsycopgError("integrity error") + raise pb.PsycopgError("integrity error") def __enter__(self): return self @@ -528,16 +521,22 @@ def cursor(self): def commit(self): pass + def close(self): + pass + class FailPsycopg: def connect(self, **kwargs): return FailConnection() - monkeypatch.setattr(wp, "psycopg2", FailPsycopg()) + monkeypatch.setattr(pb, "psycopg2", FailPsycopg()) writer = WriterPostgres({}) - writer._db_config = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + type(writer)._pg_config = property( + lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, _ = writer.write("public.cps.za.test", message) + del type(writer)._pg_config assert not ok assert writer._connection is None From 73d8b4dfd6cae498d3b402dfad54b2204e9ef910 Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 21 Apr 2026 15:32:07 +0200 Subject: [PATCH 6/7] requirements.txt update --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index cf41dee..23cdd46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ PyJWT==2.12.1 requests==2.32.5 aiosql==15.0 boto3==1.42.91 -botocore==1.40.76 +botocore==1.42.91 confluent-kafka==2.14.0 moto[s3,secretsmanager,events]==5.1.22 testcontainers==4.14.2 From 2955caf7e974650b0a8f36c2b5fea2b7b83ac36e Mon Sep 17 00:00:00 2001 From: "Tobias.Mikula" Date: Tue, 21 Apr 2026 15:59:11 +0200 Subject: [PATCH 7/7] CodeRabbit review comments implementation. --- src/readers/reader_postgres.py | 6 +- src/utils/constants.py | 1 + src/utils/postgres_base.py | 36 ++++++-- src/writers/writer_postgres.py | 6 +- tests/unit/readers/test_reader_postgres.py | 4 +- tests/unit/utils/test_postgres_base.py | 33 +++++++- tests/unit/writers/test_writer_postgres.py | 99 ++++++++++++---------- 7 files changed, 126 insertions(+), 59 deletions(-) diff --git a/src/readers/reader_postgres.py b/src/readers/reader_postgres.py index 0f546eb..c1a488c 100644 --- a/src/readers/reader_postgres.py +++ b/src/readers/reader_postgres.py @@ -159,7 +159,11 @@ def _run_stats_query( finally: # Rollback closes the implicit transaction opened by the SELECT, # leaving the cached connection in a clean idle state for reuse. - connection.rollback() + try: + connection.rollback() + except PsycopgError: + logger.debug("Failed to close the implicit transaction. Closing cached connection.", exc_info=True) + self._close_connection() return col_names, raw_rows @staticmethod diff --git a/src/utils/constants.py b/src/utils/constants.py index 8ba044b..2260105 100644 --- a/src/utils/constants.py +++ b/src/utils/constants.py @@ -25,6 +25,7 @@ SSL_CA_BUNDLE_KEY = "ssl_ca_bundle" # Postgres connection +POSTGRES_CONNECT_TIMEOUT_SECONDS = 5 POSTGRES_STATEMENT_TIMEOUT_MS = 30000 POSTGRES_MAX_RETRIES = 2 REQUIRED_CONNECTION_FIELDS = ("host", "user", "password", "port") diff --git a/src/utils/postgres_base.py b/src/utils/postgres_base.py index 4e46c4d..b37340e 100644 --- a/src/utils/postgres_base.py +++ b/src/utils/postgres_base.py @@ -22,7 +22,7 @@ from functools import cached_property from typing import Any, TypedDict -from src.utils.constants import POSTGRES_MAX_RETRIES +from src.utils.constants import POSTGRES_CONNECT_TIMEOUT_SECONDS, POSTGRES_MAX_RETRIES from src.utils.utils import load_postgres_config try: @@ -58,13 +58,31 @@ def _build_postgres_config(aws_secret: dict[str, Any]) -> PostgresConfig: aws_secret: Dictionary loaded from AWS Secrets Manager. Returns: A validated `PostgresConfig`. + Raises: + ValueError: If `database` is present but other required fields are + missing or invalid. """ + database = str(aws_secret.get("database", "")) + + if database: + required_fields = ("host", "user", "password", "port") + missing_fields = [field for field in required_fields if not aws_secret.get(field)] + if missing_fields: + raise ValueError(f"Missing PostgreSQL secret fields: {', '.join(missing_fields)}") + + try: + port = int(aws_secret["port"]) + except (TypeError, ValueError) as exc: + raise ValueError(f"Invalid PostgreSQL port: {aws_secret.get('port')}") from exc + else: + port = int(aws_secret.get("port", 0)) + return PostgresConfig( - database=str(aws_secret.get("database", "")), + database=database, host=str(aws_secret.get("host", "")), user=str(aws_secret.get("user", "")), password=str(aws_secret.get("password", "")), - port=int(aws_secret.get("port", 0)), + port=port, ) @@ -102,6 +120,7 @@ def _get_connection(self) -> Any: "user": pg_config["user"], "password": pg_config["password"], "port": pg_config["port"], + "connect_timeout": POSTGRES_CONNECT_TIMEOUT_SECONDS, } options = self._connect_options() if options: @@ -120,23 +139,26 @@ def _close_connection(self) -> None: except (PsycopgError, OSError): logger.debug("Failed to close PostgreSQL connection.") - def _execute_with_retry[T](self, operation: Callable[..., T]) -> T: + def _execute_with_retry[T](self, operation: Callable[..., T], *, retry: bool = True) -> T: """Run `operation(connection)` with one retry on `OperationalError`. Args: operation: Callable receiving a psycopg2 connection. + retry: Whether to retry on `OperationalError`. Disable for + non-idempotent operations (e.g. writes) to avoid duplicates. Returns: Whatever `operation` returns on success. Raises: RuntimeError: If the retry is also exhausted. """ + max_attempts = POSTGRES_MAX_RETRIES if retry else 1 last_exc: OperationalError | None = None - for attempt in range(POSTGRES_MAX_RETRIES): + for attempt in range(max_attempts): try: connection = self._get_connection() return operation(connection) except OperationalError as exc: last_exc = exc self._close_connection() - if attempt < POSTGRES_MAX_RETRIES - 1: + if attempt < max_attempts - 1: logger.warning("PostgreSQL connection lost, reconnecting.") - raise RuntimeError(f"Database connection failed after {POSTGRES_MAX_RETRIES} attempts: {last_exc}") + raise RuntimeError(f"Database connection failed after {max_attempts} attempts: {last_exc}") from last_exc diff --git a/src/writers/writer_postgres.py b/src/writers/writer_postgres.py index 85949cf..e062361 100644 --- a/src/writers/writer_postgres.py +++ b/src/writers/writer_postgres.py @@ -182,11 +182,11 @@ def write(self, topic_name: str, message: dict[str, Any]) -> tuple[bool, str | N logger.error(msg) return False, msg - self._execute_with_retry(lambda conn: self._write_topic(conn, topic_name, message)) + self._execute_with_retry(lambda conn: self._write_topic(conn, topic_name, message), retry=False) except (RuntimeError, PsycopgError, BotoCoreError, ClientError, ValueError, KeyError) as e: self._close_connection() - err_msg = f"The Postgres writer failed with unknown error: {str(e)}" - logger.exception(err_msg) + err_msg = f"The Postgres writer failed with unknown error: {e!s}" + logger.exception("The Postgres writer failed with unknown error: %s.", e) return False, err_msg return True, None diff --git a/tests/unit/readers/test_reader_postgres.py b/tests/unit/readers/test_reader_postgres.py index 5d3f378..c1aa958 100644 --- a/tests/unit/readers/test_reader_postgres.py +++ b/tests/unit/readers/test_reader_postgres.py @@ -228,7 +228,7 @@ def test_missing_connection_field_raises_runtime_error( with ( patch("boto3.Session") as mock_session, - pytest.raises(RuntimeError, match="host, user, password"), + pytest.raises(ValueError, match="Missing PostgreSQL secret fields"), ): mock_session.return_value.client.return_value = mock_client reader.read_stats() @@ -402,7 +402,7 @@ def test_retries_on_operational_error(self, reader: ReaderPostgres, pg_secret: d mock_session.return_value.client.return_value = mock_client mock_pg.connect.side_effect = [fail_conn, ok_conn] - rows, pagination = reader.read_stats(limit=10) + rows, _pagination = reader.read_stats(limit=10) assert 2 == mock_pg.connect.call_count assert [] == rows diff --git a/tests/unit/utils/test_postgres_base.py b/tests/unit/utils/test_postgres_base.py index 714c7e2..4c07176 100644 --- a/tests/unit/utils/test_postgres_base.py +++ b/tests/unit/utils/test_postgres_base.py @@ -38,7 +38,7 @@ def test_build_postgres_config_full(): assert 5432 == result["port"] -def test_build_postgres_config_defaults_for_missing_keys(): +def test_build_postgres_config_defaults_for_empty_database(): result = _build_postgres_config({}) assert "" == result["database"] assert "" == result["host"] @@ -47,6 +47,16 @@ def test_build_postgres_config_defaults_for_missing_keys(): assert 0 == result["port"] +def test_build_postgres_config_rejects_missing_fields_when_database_set(): + with pytest.raises(ValueError, match="Missing PostgreSQL secret fields"): + _build_postgres_config({"database": "mydb"}) + + +def test_build_postgres_config_rejects_invalid_port(): + with pytest.raises(ValueError, match="Invalid PostgreSQL port"): + _build_postgres_config({"database": "mydb", "host": "h", "user": "u", "password": "p", "port": "abc"}) + + # PostgresBase.__init__ @@ -218,3 +228,24 @@ def test_execute_with_retry_raises_after_all_attempts_fail(monkeypatch): with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): with pytest.raises(RuntimeError, match="Database connection failed after"): base._execute_with_retry(lambda conn: (_ for _ in ()).throw(pb.OperationalError("down"))) + + +def test_execute_with_retry_no_retry_fails_on_first_attempt(monkeypatch): + calls = [] + mock_conn = MagicMock(closed=0) + mock_psycopg2 = MagicMock() + mock_psycopg2.connect.return_value = mock_conn + monkeypatch.setattr(pb, "psycopg2", mock_psycopg2) + base = _ConcreteBase() + base._connection = mock_conn + secret = {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + + def op(conn): + calls.append(1) + raise pb.OperationalError("timeout") + + with patch("src.utils.postgres_base.load_postgres_config", return_value=secret): + with pytest.raises(RuntimeError, match="Database connection failed after 1 attempts"): + base._execute_with_retry(op, retry=False) + + assert 1 == len(calls) diff --git a/tests/unit/writers/test_writer_postgres.py b/tests/unit/writers/test_writer_postgres.py index db4536c..f5910f8 100644 --- a/tests/unit/writers/test_writer_postgres.py +++ b/tests/unit/writers/test_writer_postgres.py @@ -185,33 +185,34 @@ def test_insert_test(): assert {"a": 1} == json.loads(params["additional_info"]) -def test_write_skips_when_no_database(reset_env): +def test_write_skips_when_no_database(reset_env, monkeypatch): writer = WriterPostgres({}) - type(writer)._pg_config = property(lambda self: {"database": ""}) + monkeypatch.setattr(type(writer), "_pg_config", property(lambda self: {"database": ""})) ok, err = writer.write("public.cps.za.test", {}) - del type(writer)._pg_config assert ok and err is None -def test_write_fails_when_connection_field_missing(reset_env): +def test_write_fails_when_connection_field_missing(reset_env, monkeypatch): writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "", "user": "u", "password": "p", "port": 5432}), ) ok, err = writer.write("public.cps.za.test", {}) - del type(writer)._pg_config assert not ok assert "host" in err and "not configured" in err def test_write_skips_when_psycopg2_missing(reset_env, monkeypatch): writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) monkeypatch.setattr(pb, "psycopg2", None) ok, err = writer.write("public.cps.za.test", {}) - del type(writer)._pg_config assert ok and err is None @@ -219,11 +220,12 @@ def test_write_unknown_topic_returns_false(reset_env, monkeypatch): store = [] monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) ok, err = writer.write("public.cps.za.unknown", {}) - del type(writer)._pg_config assert not ok and "Unknown topic" in err @@ -231,12 +233,13 @@ def test_write_success_known_topic(reset_env, monkeypatch): store = [] monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) - del type(writer)._pg_config assert ok and err is None and 1 == len(store) @@ -247,11 +250,12 @@ def connect(self, **kwargs): monkeypatch.setattr(pb, "psycopg2", FailingPsycopg()) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) ok, err = writer.write("public.cps.za.test", {}) - del type(writer)._pg_config assert not ok and "failed with unknown error" in err @@ -276,8 +280,10 @@ def test_write_dlchange_success(reset_env, monkeypatch): store = [] monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = { "event_id": "e1", @@ -291,7 +297,6 @@ def test_write_dlchange_success(reset_env, monkeypatch): "format": "parquet", } ok, err = writer.write("public.cps.za.dlchange", message) - del type(writer)._pg_config assert ok and err is None and 1 == len(store) @@ -299,8 +304,10 @@ def test_write_runs_success(reset_env, monkeypatch): store = [] monkeypatch.setattr(pb, "psycopg2", DummyPsycopg(store)) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = { "event_id": "r1", @@ -314,7 +321,6 @@ def test_write_runs_success(reset_env, monkeypatch): "jobs": [{"catalog_id": "c", "status": "ok", "timestamp_start": 1, "timestamp_end": 2}], } ok, err = writer.write("public.cps.za.runs", message) - del type(writer)._pg_config assert ok and err is None and 2 == len(store) # run + job insert @@ -353,17 +359,16 @@ def client(self, service_name, region_name): monkeypatch.setattr(secrets_mod.boto3, "Session", MockSession) writer = WriterPostgres({}) healthy, msg = writer.check_health() - assert not healthy and "host not configured" in msg + assert not healthy and "Missing PostgreSQL secret fields" in msg -def test_check_health_database_not_configured(): +def test_check_health_database_not_configured(monkeypatch): """check_health returns (True, 'database not configured') when database field is empty.""" writer = WriterPostgres({}) writer._secret_name = "mysecret" writer._secret_region = "eu-west-1" - type(writer)._pg_config = property(lambda self: {"database": ""}) + monkeypatch.setattr(type(writer), "_pg_config", property(lambda self: {"database": ""})) healthy, msg = writer.check_health() - del type(writer)._pg_config assert healthy assert "database not configured" == msg @@ -389,8 +394,10 @@ def test_write_reconnects_on_closed_connection(reset_env, monkeypatch): psycopg = DummyPsycopg(store) monkeypatch.setattr(pb, "psycopg2", psycopg) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} @@ -400,11 +407,10 @@ def test_write_reconnects_on_closed_connection(reset_env, monkeypatch): writer._connection.closed = 2 writer.write("public.cps.za.test", message) - del type(writer)._pg_config assert 2 == psycopg.connect_count -def test_write_retries_on_operational_error(reset_env, monkeypatch): +def test_write_does_not_retry_on_operational_error(reset_env, monkeypatch): store = [] fail_flag = [True] @@ -445,17 +451,18 @@ def connect(self, **kwargs): psycopg = RetryPsycopg() monkeypatch.setattr(pb, "psycopg2", psycopg) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) - del type(writer)._pg_config - assert ok and err is None - assert 2 == psycopg.connect_count - assert 1 == len(store) + assert not ok + assert "failed with unknown error" in err + assert 1 == psycopg.connect_count def test_write_fails_after_retry_exhausted(reset_env, monkeypatch): @@ -488,13 +495,14 @@ def connect(self, **kwargs): monkeypatch.setattr(pb, "psycopg2", AlwaysFailPsycopg()) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, err = writer.write("public.cps.za.test", message) - del type(writer)._pg_config assert not ok assert "failed with unknown error" in err @@ -530,13 +538,14 @@ def connect(self, **kwargs): monkeypatch.setattr(pb, "psycopg2", FailPsycopg()) writer = WriterPostgres({}) - type(writer)._pg_config = property( - lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432} + monkeypatch.setattr( + type(writer), + "_pg_config", + property(lambda self: {"database": "db", "host": "h", "user": "u", "password": "p", "port": 5432}), ) message = {"event_id": "id", "tenant_id": "ten", "source_app": "app", "environment": "dev", "timestamp": 123} ok, _ = writer.write("public.cps.za.test", message) - del type(writer)._pg_config assert not ok assert writer._connection is None