Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 56 additions & 50 deletions backend/api/jobs/procrastinate_app.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,79 @@
from __future__ import annotations

import os
from typing import Optional

import psycopg
import psycopg_pool
import procrastinate
from procrastinate import exceptions as procrastinate_exceptions
from procrastinate.psycopg_connector import PsycopgConnector
import psycopg

from ..core.config import settings
from ..services.postgres_auth import postgres_server


def build_postgres_dsn() -> str:
"""Build a DSN for Procrastinate.
# ---------------------------------------------------------------------------
# Azure token-refreshing pool
# ---------------------------------------------------------------------------

Notes:
- For now we only support password auth (docker/local). If POSTGRES_MODE=azure,
we raise so we don't silently run jobs without a working queue.
- This can be extended later to use Azure token auth by injecting a password
token similarly to postgres_auth.PostgresServer.
def _make_azure_pool() -> psycopg_pool.AsyncConnectionPool:
"""Return an AsyncConnectionPool whose connections always carry a fresh
Entra ID token as their password.

We subclass psycopg.AsyncConnection and override connect() so that every
time the pool opens a new physical connection it calls
postgres_server._refresh_azure_token() first. That method is cached and
only hits Azure AD when the token is within 60 s of expiry.
"""
class AzureTokenConnection(psycopg.AsyncConnection):
@classmethod
async def connect(cls, conninfo: str = "", **kwargs): # type: ignore[override]
kwargs["password"] = postgres_server._refresh_azure_token()
return await super().connect(conninfo, **kwargs)

return psycopg_pool.AsyncConnectionPool(
conninfo=postgres_server.build_conninfo(include_password=False),
connection_class=AzureTokenConnection,
open=False, # caller must await pool.open() before use
)

mode = (settings.POSTGRES_MODE or "").lower().strip()
prof = settings.postgres_profile(mode)

if prof.get("mode") == "azure":
raise RuntimeError(
"Procrastinate asyncpg DSN for POSTGRES_MODE=azure is not implemented yet. "
"Run workers with docker/local Postgres first, or extend build_postgres_dsn() to use Entra tokens."
)
# ---------------------------------------------------------------------------
# Connector factory
# ---------------------------------------------------------------------------

host = prof.get("host")
db = prof.get("database")
user = prof.get("user")
password = prof.get("password")
port = int(prof.get("port") or 5432)
if not (host and db and user and password):
raise RuntimeError("Missing Postgres config for Procrastinate")
class _AzurePsycopgConnector(PsycopgConnector):
"""PsycopgConnector that creates a token-refreshing pool for Azure.

PsycopgConnector.open_async() accepts an already-constructed pool via its
`pool` kwarg, so we build the pool here and hand it off — the rest of the
connector (execute_query_*, etc.) is unchanged.
"""

# psycopg DSN
return f"postgresql://{user}:{password}@{host}:{port}/{db}"
async def open_async(self, pool=None, **kwargs) -> None: # type: ignore[override]
if pool is None:
pool = _make_azure_pool()
await pool.open()
await super().open_async(pool=pool, **kwargs)


PROCRASTINATE_APP = procrastinate.App(
# Procrastinate 3.2.x no longer ships an asyncpg connector; use psycopg (v3).
# PsycopgConnector forwards kwargs to psycopg_pool.AsyncConnectionPool which expects
# `conninfo` (not `dsn`).
connector=PsycopgConnector(conninfo=build_postgres_dsn()),
# Namespace for internal procrastinate tables
# You can set this later to avoid collisions.
)
def _build_connector() -> PsycopgConnector:
if settings.POSTGRES_MODE == "azure":
return _AzurePsycopgConnector()
return PsycopgConnector(conninfo=postgres_server.build_conninfo())


# ---------------------------------------------------------------------------
# Module-level app singleton
# ---------------------------------------------------------------------------

PROCRASTINATE_APP = procrastinate.App(connector=_build_connector())


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def jobs_enabled() -> bool:
return bool(getattr(settings, "ENABLE_PROCRASTINATE", False))

Expand Down Expand Up @@ -90,8 +112,6 @@ async def _schema_installed() -> bool:
except Exception:
return False

# Ensure the app is open. If it was already open, do not close it here;
# the API process keeps it open for the whole lifespan.
opened_here = False
try:
_ = PROCRASTINATE_APP.connector.pool
Expand All @@ -106,8 +126,6 @@ async def _schema_installed() -> bool:
try:
await PROCRASTINATE_APP.schema_manager.apply_schema_async()
except procrastinate_exceptions.ConnectorException as e:
# If schema is already present (possibly created by a previous run),
# treat duplicate-object errors as success.
cause: BaseException | None = e.__cause__
if isinstance(cause, (psycopg.errors.DuplicateObject, psycopg.errors.DuplicateTable, psycopg.errors.DuplicateFunction)):
if await _schema_installed():
Expand Down Expand Up @@ -136,11 +154,6 @@ async def clear_pending_jobs(*, queues: Optional[list[str]] = None) -> int:
opened_here = True

try:
# NOTE: status values are stored in the procrastinate_job_status enum.
# We only delete jobs that haven't started yet.
# Procrastinate's PsycopgConnector.execute_query_async does not accept a positional
# params dict; it only accepts keyword arguments.
# It also returns None, so to report a deleted count we use RETURNING.
rows = await PROCRASTINATE_APP.connector.execute_query_all_async(
"""
DELETE FROM procrastinate_jobs
Expand All @@ -159,10 +172,6 @@ async def clear_pending_jobs(*, queues: Optional[list[str]] = None) -> int:
async def cancel_enqueued_jobs_for_run_all(job_id: str, *, queues: Optional[list[str]] = None) -> int:
"""Best-effort: delete enqueued (todo) Procrastinate jobs for a given run-all job_id.

This makes Cancel feel more responsive by removing jobs that haven't started yet.
Jobs already in `doing` cannot be safely removed here and will stop cooperatively
once they next check `run_all_repo.is_canceled(job_id)`.

Returns number of deleted procrastinate_jobs rows.
"""

Expand Down Expand Up @@ -193,10 +202,7 @@ async def cancel_enqueued_jobs_for_run_all(job_id: str, *, queues: Optional[list


async def run_worker_once(*, queues: Optional[list[str]] = None) -> None:
"""Run a worker loop. Intended to be launched as a background task.

If you run this inside the API process, set concurrency low.
"""
"""Run a worker loop. Intended to be launched as a background task."""

qs = queues or ["default"]
await PROCRASTINATE_APP.run_worker_async(
Expand Down
12 changes: 12 additions & 0 deletions backend/api/services/postgres_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,18 @@ def _candidate_kwargs(self, mode: str, psycopg3: bool = False) -> Dict[str, Any]

return kwargs

def build_conninfo(self, *, include_password: bool = True) -> str:
"""Build a psycopg key=value conninfo string for the current mode.

Pass include_password=False to omit the password — useful as a pool
template where the token is injected per-connection (azure mode).
"""
kw = self._candidate_kwargs(self._mode(), psycopg3=True)
kw.pop("connect_timeout", None)
if not include_password:
kw.pop("password", None)
return " ".join(f"{k}={v}" for k, v in kw.items() if v is not None)

def _connect_with_mode(self, mode: str):
kwargs = self._candidate_kwargs(mode)
safe_kwargs = {k: ("***" if k == "password" else v) for k, v in kwargs.items()}
Expand Down
4 changes: 4 additions & 0 deletions backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ async def startup_event():
except Exception as e:
print(f"⚠️ Failed to ensure SR table exists: {e}", flush=True)

if user_db_service:
await user_db_service.ensure_table_exists()
print("✓ Users table initialized", flush=True)

# Procrastinate schema + run-all job tables
try:
from api.jobs.procrastinate_app import (
Expand Down
Loading