diff --git a/backend/api/jobs/procrastinate_app.py b/backend/api/jobs/procrastinate_app.py index 932f9a44..68ad26e0 100644 --- a/backend/api/jobs/procrastinate_app.py +++ b/backend/api/jobs/procrastinate_app.py @@ -1,57 +1,79 @@ from __future__ import annotations -import os from typing import Optional +import psycopg +import psycopg_pool import procrastinate from procrastinate import exceptions as procrastinate_exceptions from procrastinate.psycopg_connector import PsycopgConnector -import psycopg from ..core.config import settings +from ..services.postgres_auth import postgres_server -def build_postgres_dsn() -> str: - """Build a DSN for Procrastinate. +# --------------------------------------------------------------------------- +# Azure token-refreshing pool +# --------------------------------------------------------------------------- - Notes: - - For now we only support password auth (docker/local). If POSTGRES_MODE=azure, - we raise so we don't silently run jobs without a working queue. - - This can be extended later to use Azure token auth by injecting a password - token similarly to postgres_auth.PostgresServer. +def _make_azure_pool() -> psycopg_pool.AsyncConnectionPool: + """Return an AsyncConnectionPool whose connections always carry a fresh + Entra ID token as their password. + + We subclass psycopg.AsyncConnection and override connect() so that every + time the pool opens a new physical connection it calls + postgres_server._refresh_azure_token() first. That method is cached and + only hits Azure AD when the token is within 60 s of expiry. """ + class AzureTokenConnection(psycopg.AsyncConnection): + @classmethod + async def connect(cls, conninfo: str = "", **kwargs): # type: ignore[override] + kwargs["password"] = postgres_server._refresh_azure_token() + return await super().connect(conninfo, **kwargs) + + return psycopg_pool.AsyncConnectionPool( + conninfo=postgres_server.build_conninfo(include_password=False), + connection_class=AzureTokenConnection, + open=False, # caller must await pool.open() before use + ) - mode = (settings.POSTGRES_MODE or "").lower().strip() - prof = settings.postgres_profile(mode) - if prof.get("mode") == "azure": - raise RuntimeError( - "Procrastinate asyncpg DSN for POSTGRES_MODE=azure is not implemented yet. " - "Run workers with docker/local Postgres first, or extend build_postgres_dsn() to use Entra tokens." - ) +# --------------------------------------------------------------------------- +# Connector factory +# --------------------------------------------------------------------------- - host = prof.get("host") - db = prof.get("database") - user = prof.get("user") - password = prof.get("password") - port = int(prof.get("port") or 5432) - if not (host and db and user and password): - raise RuntimeError("Missing Postgres config for Procrastinate") +class _AzurePsycopgConnector(PsycopgConnector): + """PsycopgConnector that creates a token-refreshing pool for Azure. + + PsycopgConnector.open_async() accepts an already-constructed pool via its + `pool` kwarg, so we build the pool here and hand it off — the rest of the + connector (execute_query_*, etc.) is unchanged. + """ - # psycopg DSN - return f"postgresql://{user}:{password}@{host}:{port}/{db}" + async def open_async(self, pool=None, **kwargs) -> None: # type: ignore[override] + if pool is None: + pool = _make_azure_pool() + await pool.open() + await super().open_async(pool=pool, **kwargs) -PROCRASTINATE_APP = procrastinate.App( - # Procrastinate 3.2.x no longer ships an asyncpg connector; use psycopg (v3). - # PsycopgConnector forwards kwargs to psycopg_pool.AsyncConnectionPool which expects - # `conninfo` (not `dsn`). - connector=PsycopgConnector(conninfo=build_postgres_dsn()), - # Namespace for internal procrastinate tables - # You can set this later to avoid collisions. -) +def _build_connector() -> PsycopgConnector: + if settings.POSTGRES_MODE == "azure": + return _AzurePsycopgConnector() + return PsycopgConnector(conninfo=postgres_server.build_conninfo()) +# --------------------------------------------------------------------------- +# Module-level app singleton +# --------------------------------------------------------------------------- + +PROCRASTINATE_APP = procrastinate.App(connector=_build_connector()) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + def jobs_enabled() -> bool: return bool(getattr(settings, "ENABLE_PROCRASTINATE", False)) @@ -90,8 +112,6 @@ async def _schema_installed() -> bool: except Exception: return False - # Ensure the app is open. If it was already open, do not close it here; - # the API process keeps it open for the whole lifespan. opened_here = False try: _ = PROCRASTINATE_APP.connector.pool @@ -106,8 +126,6 @@ async def _schema_installed() -> bool: try: await PROCRASTINATE_APP.schema_manager.apply_schema_async() except procrastinate_exceptions.ConnectorException as e: - # If schema is already present (possibly created by a previous run), - # treat duplicate-object errors as success. cause: BaseException | None = e.__cause__ if isinstance(cause, (psycopg.errors.DuplicateObject, psycopg.errors.DuplicateTable, psycopg.errors.DuplicateFunction)): if await _schema_installed(): @@ -136,11 +154,6 @@ async def clear_pending_jobs(*, queues: Optional[list[str]] = None) -> int: opened_here = True try: - # NOTE: status values are stored in the procrastinate_job_status enum. - # We only delete jobs that haven't started yet. - # Procrastinate's PsycopgConnector.execute_query_async does not accept a positional - # params dict; it only accepts keyword arguments. - # It also returns None, so to report a deleted count we use RETURNING. rows = await PROCRASTINATE_APP.connector.execute_query_all_async( """ DELETE FROM procrastinate_jobs @@ -159,10 +172,6 @@ async def clear_pending_jobs(*, queues: Optional[list[str]] = None) -> int: async def cancel_enqueued_jobs_for_run_all(job_id: str, *, queues: Optional[list[str]] = None) -> int: """Best-effort: delete enqueued (todo) Procrastinate jobs for a given run-all job_id. - This makes Cancel feel more responsive by removing jobs that haven't started yet. - Jobs already in `doing` cannot be safely removed here and will stop cooperatively - once they next check `run_all_repo.is_canceled(job_id)`. - Returns number of deleted procrastinate_jobs rows. """ @@ -193,10 +202,7 @@ async def cancel_enqueued_jobs_for_run_all(job_id: str, *, queues: Optional[list async def run_worker_once(*, queues: Optional[list[str]] = None) -> None: - """Run a worker loop. Intended to be launched as a background task. - - If you run this inside the API process, set concurrency low. - """ + """Run a worker loop. Intended to be launched as a background task.""" qs = queues or ["default"] await PROCRASTINATE_APP.run_worker_async( diff --git a/backend/api/services/postgres_auth.py b/backend/api/services/postgres_auth.py index e053339c..4e51db0b 100644 --- a/backend/api/services/postgres_auth.py +++ b/backend/api/services/postgres_auth.py @@ -176,6 +176,18 @@ def _candidate_kwargs(self, mode: str, psycopg3: bool = False) -> Dict[str, Any] return kwargs + def build_conninfo(self, *, include_password: bool = True) -> str: + """Build a psycopg key=value conninfo string for the current mode. + + Pass include_password=False to omit the password — useful as a pool + template where the token is injected per-connection (azure mode). + """ + kw = self._candidate_kwargs(self._mode(), psycopg3=True) + kw.pop("connect_timeout", None) + if not include_password: + kw.pop("password", None) + return " ".join(f"{k}={v}" for k, v in kw.items() if v is not None) + def _connect_with_mode(self, mode: str): kwargs = self._candidate_kwargs(mode) safe_kwargs = {k: ("***" if k == "password" else v) for k, v in kwargs.items()} diff --git a/backend/main.py b/backend/main.py index 606dc55f..d4b74536 100644 --- a/backend/main.py +++ b/backend/main.py @@ -45,6 +45,10 @@ async def startup_event(): except Exception as e: print(f"⚠️ Failed to ensure SR table exists: {e}", flush=True) + if user_db_service: + await user_db_service.ensure_table_exists() + print("✓ Users table initialized", flush=True) + # Procrastinate schema + run-all job tables try: from api.jobs.procrastinate_app import (