diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md new file mode 100644 index 00000000..6b04a76c --- /dev/null +++ b/src/databricks_ai_bridge/long_running/AGENTS.md @@ -0,0 +1,456 @@ +# LongRunningAgentServer + +Durable, crash-resumable agent execution for MLflow `ResponsesAgent` handlers. + +This document describes: +1. What `LongRunningAgentServer` does and the guarantees it gives callers ([§1](#1-what-this-module-does)). +2. The four customer journeys it covers, with sequence diagrams ([§2](#2-customer-journeys)). +3. The architecture: storage layout, claim mechanism, recovery, and stream resume ([§3](#3-architecture)). +4. Author-side requirements: what changes (and doesn't) when a handler opts into durable mode ([§4](#4-author-side-requirements)). +5. The interface today and how it's expected to evolve when [TaskFlow](https://github.com/databricks-eng/universe/tree/master/experimental/taskflow) lands ([§5](#5-future-direction-taskflow)). + +## 1. What this module does + +`LongRunningAgentServer` extends MLflow's `AgentServer` for `ResponsesAgent` handlers with three capabilities: + +1. **Background execution.** A `POST /responses` request with `background: true` returns a `response_id` immediately; the agent loop runs detached from the HTTP connection. State persists to Lakebase Postgres. +2. **Streaming retrieval.** `GET /responses/{response_id}?stream=true&starting_after=N` replays events past sequence `N` and tails new ones until the run finishes. Reconnects without losing events. +3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run and finishes the work via **prose recovery**: the new attempt receives a single user message containing `json.dumps(events)` of the crashed attempt's stream-event log plus a directive asking the LLM to figure out what's done vs interrupted and continue. The handler runs on a freshly-rotated SDK session. 
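The atomic claim behind capability 3 can be illustrated with a small, self-contained sketch. It is not the module's implementation: it swaps Lakebase Postgres for an in-memory SQLite table, stores `heartbeat_at` as epoch seconds, and uses the hypothetical helper names `make_db`, `try_claim`, and `heartbeat`. It only aims to show why a single conditional `UPDATE` lets at most one pod win a stale claim, and why the previous owner's heartbeat stops matching afterwards.

```python
import sqlite3
from typing import Optional


def make_db() -> sqlite3.Connection:
    """Create an in-memory stand-in for the `responses` table (assumed schema)."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE responses (
            response_id    TEXT PRIMARY KEY,
            status         TEXT NOT NULL,
            heartbeat_at   REAL,
            attempt_number INTEGER NOT NULL DEFAULT 1
        )
        """
    )
    return conn


def try_claim(
    conn: sqlite3.Connection,
    response_id: str,
    current_attempt: int,
    now: float,
    stale_after: float = 10.0,
) -> Optional[int]:
    """Compare-and-swap on attempt_number; returns the new attempt or None."""
    cur = conn.execute(
        "UPDATE responses "
        "SET attempt_number = attempt_number + 1, heartbeat_at = ? "
        "WHERE response_id = ? AND status = 'in_progress' "
        "AND attempt_number = ? "
        "AND (heartbeat_at IS NULL OR heartbeat_at < ?)",
        (now, response_id, current_attempt, now - stale_after),
    )
    conn.commit()
    # Exactly one row updated means this caller won the claim.
    return current_attempt + 1 if cur.rowcount == 1 else None


def heartbeat(
    conn: sqlite3.Connection, response_id: str, attempt: int, now: float
) -> bool:
    """Refresh heartbeat_at only while `attempt` is still the current attempt."""
    cur = conn.execute(
        "UPDATE responses SET heartbeat_at = ? "
        "WHERE response_id = ? AND attempt_number = ?",
        (now, response_id, attempt),
    )
    conn.commit()
    return cur.rowcount == 1
```

With a row inserted at attempt 1 whose last heartbeat is 15 seconds old, the first `try_claim(..., current_attempt=1, ...)` returns `2`, a second identical call returns `None`, and a late `heartbeat(..., attempt=1, ...)` from the old owner returns `False`, which is the signal to stop.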
+ +Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, others) is opaque to the server. + +### Guarantees + +- **At-most-once durable claim.** Only one pod runs a given response at a time. The handoff uses an atomic CAS on `attempt_number`. +- **Append-only event log.** Every SSE frame is persisted to `agent_server.messages` with its `response_id`, `sequence_number`, and `attempt_number`. Clients cursor-resume from `starting_after`. +- **SDK-agnostic recovery.** The resumed attempt receives a flat prose narrative: no provider-specific tool-pair structure, no synthetic tool events, no per-SDK adapter code. +- **Per-template UI-echo dedup.** The bridge does NOT trim echoed history. When the chat client echoes the full prior conversation in `request.input`, the agent handler is responsible for deduping its input against the SDK's session/checkpointer state, typically by forwarding only the latest user message when the session already has prior turns. See the templates in `app-templates/agent-{openai,langgraph}-advanced/` for the canonical 1-2 line shape. +- **Best-effort tool execution.** A tool call interrupted mid-flight may re-run on the resumed attempt. Idempotency is the tool author's responsibility. +- **No durability code in the handler.** Templates that subclass `LongRunningAgentServer` keep using `@invoke()` / `@stream()` decorators. Apart from the UI-echo dedup above, all durability lives below the handler boundary. + +### Non-goals + +- Cross-region failover. Pods are assumed to share one Lakebase. +- Tool-level checkpointing / exactly-once tool execution. +- A workflow DSL. Handlers are ordinary async generators / coroutines. + +## 2. Customer journeys + +### CUJ 1: Author writes a long-running agent + +The author subclasses `LongRunningAgentServer` and registers `@invoke()` / `@stream()` handlers like a regular MLflow agent server. 
**No durability code in `agent.py`.** + +```python +from databricks_ai_bridge.long_running import LongRunningAgentServer +from mlflow.genai.agent_server import invoke, stream + +agent_server = LongRunningAgentServer( + "ResponsesAgent", + db_instance_name="my-lakebase-instance", +) + +@stream() +async def stream_handler(request): + # ordinary agent code: build messages, call SDK, yield events + ... + +@invoke() +async def invoke_handler(request): + ... + +app = agent_server.app +``` + +The agent author writes their handler exactly the same way they would for the non-durable `AgentServer`. `LongRunningAgentServer` adds the durable wiring transparently. + +### CUJ 2: Pod crashes mid-tool, client polls + +A client posts a long-running request, the owning pod dies mid-tool, another pod takes over via prose recovery, and the client gets the final output without restarting. + +```mermaid +sequenceDiagram + autonumber + participant C as Client + participant A as Pod A (owner) + participant DB as Lakebase
(agent_server.*) + participant B as Pod B + participant SDK as SDK store
(rotated session) + + C->>A: POST /responses
{input, background:true, stream:true} + A->>DB: INSERT response_id, attempt=1,
heartbeat_at=now(), original_request (full input) + A-->>C: 200 {id: resp_xxx, status: in_progress} + activate A + Note over A: heartbeat loop (every 3s)
CAS-checks attempt_number = 1 + A->>DB: UPDATE heartbeat_at WHERE attempt_number=1 + + par Streaming + A->>SDK: write to session T (original conv_id) + A->>DB: append events seq=0..N (attempt=1) + and Polling + C->>A: GET /responses/{id}?stream=true&starting_after=N + A-->>C: SSE events + end + + Note over A: 💥 pod crashes mid-tool
(after tool_use, before tool_result) + deactivate A + Note over DB: heartbeat_at goes stale (>10s old) + + C->>B: GET /responses/{id}?stream=true&starting_after=K + B->>DB: SELECT heartbeat_at, attempt + Note over B: heartbeat stale → claim + B->>DB: CAS UPDATE attempt=2, heartbeat_at=now()
WHERE attempt=1 AND heartbeat stale (atomic) + B->>DB: SELECT prior events WHERE attempt=1 + Note over B: build resume input:
· original_request.input (full prior turns)
· + prose recovery message: json.dumps(events) + directive
· rotate conv_id → ::attempt-2 + B->>DB: append response.resumed sentinel
{conversation_id: rotated value} + B->>B: re-invoke @stream() handler
(fresh rotated SDK session) + activate B + B->>SDK: write to T::attempt-2 (clean) + B->>DB: append events seq=K+1..M (attempt=2) + B-->>C: SSE events (response.resumed, then attempt-2 events) + deactivate B +``` + +**What the client observes:** a single SSE stream that may pause briefly during the heartbeat-stale window (~10s by default), then resumes. The `response.resumed` sentinel marks the attempt boundary and carries the rotated `conversation_id` so the chatbot can use the rotated session for subsequent turns. + +**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus a single user message whose body is `[RECOVERY] ... Events: `. The model reads it as "the prior attempt crashed, here's the raw event log, figure out what's done and continue." + +### CUJ 3: Subsequent turn after a crashed turn + +After a successful crash + resume, the next turn from the client lands on a fresh `POST /responses`. The chatbot uses the **rotated** `conversation_id` (captured from the `response.resumed` sentinel) so the handler resolves to the rotated SDK session — which was populated cleanly during attempt 2's prose-recovery run. + +```mermaid +sequenceDiagram + autonumber + participant C as Chatbot + participant S as Server (any pod) + participant SDK as SDK store + + Note over SDK: state at original conv_id T:
incomplete (orphan tool_use)
state at T::attempt-2:
complete (prose + resumed turn output) + + C->>C: lookup alias map:
chat_id → T::attempt-2 + C->>S: POST /responses
{input: [echo of full UI history, new user msg],
context.conversation_id: T::attempt-2} + S->>S: handler dedupes UI echo
(forwards only latest user when session has history) + S->>S: handler resolves session_id = T::attempt-2 + S->>SDK: load history at T::attempt-2 + SDK-->>S: clean state (prose + attempt-2 emissions) + S->>S: model receives [clean history] + [new user msg] + Note over S: turn succeeds normally + + Note over SDK: original session T is now orphaned forever
(never read again — chatbot uses rotated alias) +``` + +Three things make this work without per-SDK repair code: + +1. **Server emits the rotated conv_id in `response.resumed`.** The chatbot reads it and updates its `Map` alias. +2. **Per-template UI-echo dedup in the handler.** When the SDK's session/checkpointer already has prior-turn state, the agent forwards only the latest user message (not the full echoed history). This prevents `Runner.run` from sending duplicates of session items + input items to the LLM. +3. **Always-rotate.** The rotated session was the one populated during attempt 2's run. Subsequent turns land on it. The original poisoned session is never read. + +### CUJ 4: Multi-pod stale-claim contention + +Two pods each see the same response in `in_progress` with a stale heartbeat. Only one wins the CAS. + +```mermaid +sequenceDiagram + autonumber + participant B as Pod B + participant DB as Lakebase + participant C as Pod C + + Note over B,C: both pods see response in_progress, heartbeat stale + par + B->>DB: UPDATE responses SET attempt=N+1, heartbeat_at=now()
WHERE response_id=R AND attempt=N AND heartbeat stale + and + C->>DB: UPDATE responses SET attempt=N+1, heartbeat_at=now()
WHERE response_id=R AND attempt=N AND heartbeat stale + end + Note over DB: only one row matches; the other UPDATE returns 0 rows + DB-->>B: RETURNING attempt_number=N+1 + DB-->>C: RETURNING (no row) + Note over B: B wins, builds resume input + spawns handler + Note over C: C aborts cleanly, returns to its retrieve loop +``` + +The `claim_stale_response` function (`repository.py`) executes a single `UPDATE … WHERE attempt_number = :current AND ((heartbeat_at IS NULL) OR (heartbeat_at < now() - interval))` with `RETURNING`. Postgres serializes the writes; only the pod whose `current` value was unmodified at commit time gets the `RETURNING` row. + +## 3. Architecture + +### 3.1 Storage layout + +Two tables in the `agent_server` schema: + +```mermaid +erDiagram + responses { + text response_id PK + text status "in_progress / completed / failed" + timestamptz created_at + timestamptz heartbeat_at + int attempt_number "CAS guard for claim atomicity" + text original_request "JSON of initial POST (full input)" + text trace_id + } + messages { + text response_id FK + int sequence_number + int attempt_number + text item "JSON of output item" + text stream_event "JSON of SSE frame" + } + responses ||--o{ messages : "has" +``` + +- `responses.attempt_number` is the CAS guard for claim atomicity. **There is no `owner_pod_id` column** — ownership is implicit. The pod that last successfully heartbeats at the current `attempt_number` is the de facto owner. A heartbeat write at attempt N stops working the moment another pod has CAS-bumped the row to N+1, so the prior owner detects it has lost the claim on its next heartbeat (rowcount=0) and shuts down its heartbeat task. +- `responses.original_request` stores the **full untrimmed input** so the resume path can recover the entire prior-turn history when the rotated SDK session starts empty. 
+- `messages.attempt_number` tags every event so retrieval can filter to the latest attempt's output (avoiding partial output from a crashed attempt leaking into the final response body). +- Schema migrations are idempotent (`ADD COLUMN IF NOT EXISTS`) so an existing deployment upgrades without downtime. + +### 3.2 The four key flows + +```mermaid +flowchart TD + POST["POST /responses
background=true"] --> CR[create_response
store FULL original_request] + CR --> SPAWN[spawn @stream handler
handler dedupes UI echo against SDK session] + SPAWN --> HB[heartbeat loop
every 3s] + SPAWN --> EMIT[append SSE events
to messages table] + EMIT --> DONE{handler
exits?} + DONE -- yes --> COMPLETE[update status=completed] + DONE -- crash --> STALE[heartbeat stops
row goes stale] + + GET["GET /responses/{id}
?stream=true&starting_after=N"] --> CHECK[check heartbeat age] + CHECK -->|fresh or terminal| READ[read events from messages
where seq > N] + CHECK -->|stale| CLAIM[CAS claim
attempt += 1] + CLAIM -->|won| BUILD[build prose recovery message
rotate conv_id
append rotated id to response.resumed] + CLAIM -->|lost| READ + BUILD --> SPAWN + READ --> STREAM[SSE stream to client] + + KILL["POST /_debug/kill_task/{id}
(test-only)"] --> CANCEL[cancel asyncio task
without status update] + CANCEL --> STALE +``` + +### 3.3 Resume input construction + +When a stale-claim CAS succeeds, the new owner builds the resume input by serializing the prior attempt's stream events as JSON in a single user message: + +```mermaid +flowchart LR + PRIOR[prior attempt's events
from messages table] --> FILTER[filter events by
attempt_number = prior_attempt] + FILTER --> JSON[json.dumps the events list] + JSON --> COMPOSE[compose single user message:
'[RECOVERY] previous attempt crashed.
Below is the raw stream-event log...
Inspect the events, figure out what is
already done versus in-progress, and continue.

Events: <json.dumps>'] + + ROT[_rotate_conversation_id
::attempt-N suffix] --> SUBMIT + COMPOSE --> SUBMIT[append to original_request.input
spawn handler with rotated request
emit response.resumed sentinel
with rotated conversation_id] +``` + +Why JSON-dumped events: the LLM reads them as the authoritative record of what attempt 1 did and decides what to do — re-run an interrupted tool, skip completed ones, summarize from there. No structural carry-forward, no synthetic tool events, no per-SDK pairing rules. The handler doesn't have to know any of this is durable resume — it just sees a recovery user message in `request.input`. + +Why rotation: the original SDK session may carry mid-turn state from the crashed attempt (orphan `tool_use`, partial checkpoint) that's hard to repair from outside the SDK. Rotating to `{base}::attempt-N` opens a fresh, empty session for the resumed attempt; the recovery message is the single source of truth for what already happened. + +Why the sentinel carries the rotated conv_id: cooperating chat clients capture it (via SSE) and use the rotated session for subsequent turns, so the original orphan-poisoned session is never read again. + +### 3.4 Per-template UI-echo dedup (NOT in the bridge) + +The bridge does **not** trim UI echo from `request.input`. Echo dedup is the agent handler's responsibility — it owns its SDK session/checkpointer and is the right layer to know what's already persisted vs what's a new turn. 
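As a framework-agnostic illustration of that handler-side responsibility, here is a minimal sketch of the dedup rule as a pure function. The name `dedupe_ui_echo` and the `session_has_history` flag are hypothetical; a real handler would derive the flag from its own SDK session or checkpointer, and the message shape (`role` / `content` dicts) is an assumption.

```python
from typing import Any


def dedupe_ui_echo(
    messages: list[dict[str, Any]],
    session_has_history: bool,
) -> list[dict[str, Any]]:
    """Forward only the latest user message when the SDK store has prior turns.

    Hypothetical helper: `session_has_history` stands in for whatever check the
    handler performs against its own session or checkpointer.
    """
    if not session_has_history or not messages:
        # First turn (or nothing to trim): pass the input through unchanged.
        return messages
    for message in reversed(messages):
        if message.get("role") == "user":
            # The SDK prepends its stored history, so one new message suffices.
            return [message]
    # History exists but the echo contains no user message: forward nothing.
    return []
```

Keeping the rule in a pure function like this makes it easy to unit-test separately from any SDK session object.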
+ +The canonical shape, per template: + +**OpenAI Agents SDK** (`agent-openai-advanced/agent_server/utils.py`): +```python +session_items = await session.get_items() +if session_items and len(messages) > 1: + return [messages[-1]] +return messages +``` + +**LangGraph** (`agent-langgraph-advanced/agent_server/agent.py`): +```python +state = await agent.aget_state(config) +if state and state.values.get("messages") and input_state["messages"]: + last_user = next( + (m for m in reversed(input_state["messages"]) if m.get("role") == "user"), + None, + ) + input_state["messages"] = [last_user] if last_user else [] +``` + +Both follow the same rule: when the SDK store already has prior turns, forward only the latest user message and let the SDK prepend its own history. Without dedup, `Runner.run` (OpenAI) or `add_messages` (LangGraph) combines the session history with the echoed input, which produces duplicate items, a malformed `assistant.tool_calls` block, and a 400 error. + +The bridge's role here is just to pass `request.input` through untouched. + +### 3.5 Proactive stale-scan loop + +In addition to the lazy-on-GET claim path (a stale heartbeat is detected when a client GETs `/responses/{id}`), each pod runs a background **scanner** that periodically queries for in-progress responses with stale heartbeats and tries to claim and resume them. This means crashed responses get recovered even when no client is actively polling. + +```mermaid +flowchart LR + A[every ~30s ± 50% jitter] --> B[SELECT response_id FROM responses
WHERE status='in_progress'
AND heartbeat_at < now-threshold
LIMIT 50] + B --> C{any rows?} + C -->|no| A + C -->|yes| D[for each id: get_response
+ _try_claim_and_resume] + D --> A +``` + +Each pod jitters its scan interval (`stale_scan_jitter_fraction = 0.5` by default) so multiple pods don't synchronize their queries. CAS-claim semantics ensure only one pod succeeds in claiming any given stale response. + +The scanner is a background task spawned in the FastAPI lifespan (alongside `init_db`) and cancelled on shutdown. + +### 3.6 Heartbeat and stale threshold + +Defaults are tuned for a single Lakebase deployment with low-latency writes. + +| Setting | Default | Rationale | +|---|---|---| +| `heartbeat_interval_seconds` | 3.0 | Frequent enough that short pauses (GC, asyncio task waits) don't trip stale detection | +| `heartbeat_stale_threshold_seconds` | 10.0 | Three missed heartbeats means the owner is unambiguously dead. Startup validation requires the stale threshold to exceed the heartbeat interval. | +| `task_timeout_seconds` | 3600 | Hard ceiling. After this, a stuck `in_progress` row is force-failed regardless of heartbeat. | +| `poll_interval_seconds` | 1.0 | Stream-retrieve polls the messages table at this rate while waiting for new events. | + +The stale threshold also applies as a grace period for newly-created responses that haven't written their first heartbeat yet, which protects against an over-eager retrieve hijacking a still-starting handler. + +### 3.7 Claim atomicity + +```mermaid +sequenceDiagram + autonumber + participant Pod as Pod B (claimer) + participant DB as Lakebase + + Pod->>DB: SELECT attempt_number, heartbeat_at, status
FROM responses WHERE response_id=R + DB-->>Pod: attempt=1, heartbeat_at=12s ago, status=in_progress + Note over Pod: stale → attempt CAS + Pod->>DB: UPDATE responses
SET attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number + alt match + DB-->>Pod: attempt_number=2 + Note over Pod: claim succeeded — Pod B is the new de facto owner
(prior owner's heartbeat at attempt=1 will fail next tick) + else no match (another pod beat us) + DB-->>Pod: (empty) + Note over Pod: claim lost — back off + end +``` + +Postgres row locking ensures only one of N concurrent UPDATEs matches the `attempt_number = N` predicate, so at most one pod ends up owning a given resume. The bumped `attempt_number` simultaneously revokes the prior owner's heartbeat: their next heartbeat write `WHERE attempt_number = N` returns rowcount=0, telling them they've lost the claim. + +## 4. Author-side requirements + +### 4.1 What's invisible to authors + +| Concern | Where it lives | Author-visible? | +|---|---|---| +| Heartbeat + claim | `LongRunningAgentServer` | No | +| Conversation_id rotation | `LongRunningAgentServer._rotate_conversation_id` | No | +| Prose recovery message construction | `LongRunningAgentServer._build_prose_recovery_message` | No | +| UI-echo dedup | per-template handler (see §3.4) | Yes — 1-2 lines in `agent.py` / `utils.py` | +| Stream resume cursor | `LongRunningAgentServer._stream_retrieve` | No | +| Tool/SDK selection | `agent.py` | Yes (this is the author's actual code) | + +The author's `agent.py` is unchanged from a non-durable agent. They construct an `AsyncCheckpointSaver` (LangGraph) or `AsyncDatabricksSession` (OpenAI) and use it normally. Durability fires entirely above the SDK boundary — the SDK adapters themselves contain zero durability code. + +### 4.2 Author-visible client cooperation (chat-UI side) + +For the always-rotate flow to work cross-turn, a cooperating chat UI needs to: + +1. **Capture the rotated `conversation_id` from the SSE `response.resumed` event** when one is emitted during a streaming retrieve. +2. **Use the rotated value as `context.conversation_id` on subsequent requests** for the same chat. + +The Express proxy in `e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts` does this with an in-memory `Map`. 
A multi-pod chatbot deployment would persist this on the chat row. + +Without client cooperation: the next turn lands on the original (orphan-poisoned) session and the LLM call fails on the provider's `tool_use ↔ tool_result` pairing rule. The bridge does not silently repair this — the cross-turn property requires the alias. + +### 4.3 Settings worth exposing to authors + +- `db_instance_name` / `db_autoscaling_endpoint` / `db_project` + `db_branch` — Lakebase connection config. +- `heartbeat_interval_seconds` / `heartbeat_stale_threshold_seconds` — for tuning under heavy load. +- `task_timeout_seconds` — per-attempt ceiling. +- `stale_scan_interval_seconds` / `stale_scan_jitter_fraction` — controls how often (and with how much randomness) each pod scans the DB for stale responses to claim. Defaults to 30s with ±50% jitter. + +Everything else is internal. + +### 4.4 Test-only debug endpoint: `/_debug/kill_task/{response_id}` + +Cancels the in-flight asyncio task that owns the given response on this pod **without** running the `_task_scope` cleanup. The DB row stays `in_progress` with a heartbeat that's about to go stale — exactly the shape a real pod crash leaves. Used in integration tests to simulate a pod crash without restarting the container. + +Opt-in via env var: only registered when `LONG_RUNNING_ENABLE_DEBUG_KILL=1`. Never exposed in production. + +Returns 404 if no in-flight task for that response exists on this specific pod (the task may already have finished, or it may be running on a different pod). + +```bash +curl -sS -X POST -H "Authorization: Bearer $TOKEN" "$APP_URL/_debug/kill_task/$RID" +``` + +## 5. Future direction: TaskFlow + +[TaskFlow](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/tree/experimental/taskflow) is a Rust-core durable-task engine being built in `experimental/taskflow`. 
It provides exactly the primitives `LongRunningAgentServer` hand-rolls today (heartbeat, CAS claim, recovery worker, event log with stream resume) — but as a library with WAL-first durability and proactive (not lazy-on-GET) recovery. + +When TaskFlow is production-ready, `LongRunningAgentServer` is expected to keep its **HTTP surface and author-visible API unchanged**, swapping only the engine internals. + +### Mapping today → TaskFlow + +```mermaid +flowchart LR + subgraph TODAY[LongRunningAgentServer today] + T1[create_response + asyncio.create_task] + T2[_heartbeat async CM] + T3[_try_claim_and_resume CAS] + T4[_build_prose_recovery_message] + T5[/responses/{id}?stream=true] + T6[/_debug/kill_task] + end + + subgraph TF[TaskFlow] + F1[Taskflow.start name input user_id] + F2[built-in executor heartbeat] + F3[built-in recovery worker + claim_for_recovery] + F4[TaskHandler.recover ctx previous_events] + F5[Taskflow.subscribe key last_seq] + F6[Taskflow.simulate_crash key] + end + + T1 --> F1 + T2 --> F2 + T3 --> F3 + T4 --> F4 + T5 --> F5 + T6 --> F6 +``` + +### What stays in `LongRunningAgentServer` after the swap + +- `POST /responses` / `GET /responses/{id}` HTTP routes (and their schemas). +- The MLflow `@invoke()` / `@stream()` handler convention. +- `_build_prose_recovery_message` — recovery-message construction is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. +- `_rotate_conversation_id` — same reason. +- Author-visible settings (db config, heartbeat tuning, task timeout, stale-scan tuning). + +### What gets deleted + +- The hand-rolled heartbeat task and CAS claim CTE — replaced by TaskFlow's executor heartbeat + `claim_for_recovery`. +- The `_try_claim_and_resume` lazy-claim path — TaskFlow's recovery worker handles this proactively and across pods, fixing the documented "claim only fires on GET" limitation. 
+- The `agent_server.responses` + `agent_server.messages` schema — TaskFlow owns its storage layer. +- The stream-cursor logic in `_stream_retrieve` — `Taskflow.subscribe(key, last_seq)` is the cursor-based stream resume. + +### What requires a small TaskFlow API addition + +TaskFlow derives idempotency keys as `SHA256(name + canonical_input + user_id)`. The HTTP surface uses a server-generated `resp_{uuid}`. We've requested an `idempotency_key: Option` parameter on `Taskflow.submit()` so we can keep the existing HTTP contract while submitting to TaskFlow. See [`engine.rs:317`](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/blob/experimental/taskflow/engine/src/engine.rs?L317) for the current `generate_key` call site; the override would slot in there. + +### Migration sequencing + +1. Add `LongRunningAgentServer(backend="taskflow"|"builtin")` knob, default `"builtin"`. HTTP surface unchanged on either backend. +2. Port `agent-non-conversational` (the simplest template) to `backend="taskflow"`. Run the full crash-resume matrix. +3. Port the advanced templates. **Zero changes to `agent.py`** — the swap is the constructor argument. +4. Flip default to `"taskflow"`. Deprecate `"builtin"`. +5. Delete the heartbeat / claim / repository code. Big delete PR. + +The point of the `LongRunningAgentServer` abstraction is exactly this kind of swap: callers should never have to care which engine is underneath. 
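To make the idea of an engine swap behind a stable surface more concrete, here is a purely illustrative sketch. Everything in it is hypothetical: the `DurableEngine` protocol, the `InMemoryEngine` stand-in, and the `make_engine` factory are not part of `LongRunningAgentServer` or TaskFlow. It only shows how a `backend` argument could select an implementation while callers keep using the same two operations (append an event, read events past a cursor).

```python
from typing import Any, Protocol


class DurableEngine(Protocol):
    """Hypothetical seam between the HTTP layer and the durability engine."""

    def append_event(self, response_id: str, event: dict[str, Any]) -> int: ...

    def events_after(self, response_id: str, last_seq: int) -> list[dict[str, Any]]: ...


class InMemoryEngine:
    """Stand-in used for both backends here; real engines would differ internally."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._log: dict[str, list[dict[str, Any]]] = {}

    def append_event(self, response_id: str, event: dict[str, Any]) -> int:
        log = self._log.setdefault(response_id, [])
        seq = len(log)  # sequence numbers start at 0 and only grow
        log.append({"sequence_number": seq, **event})
        return seq

    def events_after(self, response_id: str, last_seq: int) -> list[dict[str, Any]]:
        # Cursor-based read: only events strictly past the caller's cursor.
        return [
            e for e in self._log.get(response_id, []) if e["sequence_number"] > last_seq
        ]


def make_engine(backend: str = "builtin") -> DurableEngine:
    """Select an engine by name; callers see the same interface either way."""
    if backend not in {"builtin", "taskflow"}:
        raise ValueError(f"unknown backend: {backend!r}")
    return InMemoryEngine(backend)
```

The design point is that code written against `DurableEngine` does not change when the `backend` value does, which is the property the migration steps above rely on.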
+ +--- + +## Quick reference + +- **Code:** `src/databricks_ai_bridge/long_running/` +- **Tests:** `tests/databricks_ai_bridge/test_long_running_server.py`, `test_long_running_db.py` +- **Settings:** `LongRunningSettings` in `settings.py` +- **Models:** `Response` and `Message` in `models.py` +- **HTTP routes:** registered in `LongRunningAgentServer._setup_routes` +- **Prose recovery:** `_build_prose_recovery_message` in `server.py` +- **Conversation rotation:** `_rotate_conversation_id` in `server.py` +- **Stale scanner:** `LongRunningAgentServer._stale_response_scanner_loop` in `server.py` +- **UI-echo dedup:** in agent code, see `app-templates/agent-{openai,langgraph}-advanced/` diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 903d466f..69ca8a79 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -79,7 +79,50 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}")) await conn.run_sync(Base.metadata.create_all) + # Idempotent migration for tables created by earlier versions: add any + # columns introduced for durable-resume support. Each statement runs in + # its own transaction so an InsufficientPrivilege on one ALTER (another + # pod's SP owns the table but the schema is already migrated) doesn't + # poison the rest. A single mega-transaction would abort entirely on the + # first owner-check failure even with IF NOT EXISTS. 
+ migration_stmts = ( + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS original_request TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"CREATE INDEX IF NOT EXISTS idx_responses_stale " + f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " + "WHERE status = 'in_progress'", + ) + skipped_migrations: list[str] = [] + for stmt in migration_stmts: + try: + async with _engine.begin() as conn: + await conn.execute(text(stmt)) + except Exception as exc: + msg = str(exc).lower() + if "insufficientprivilege" in msg or "must be owner" in msg: + skipped_migrations.append(stmt.split("\n")[0]) + continue + raise + _initialized = True + if skipped_migrations: + # WARN-level summary: if the DB was previously migrated by another SP + # this is fine, but if it's genuinely a new table and our SP lacks + # ALTER, claim/heartbeat queries will fail later with a confusing + # "column does not exist" — surface it clearly at startup. + logger.warning( + "[DB] Skipped %d durability migration(s) due to insufficient " + "privilege — assuming table was already migrated by another " + "service principal. Crash-resume will fail with 'column does " + "not exist' if this assumption is wrong. 
Skipped: %s", + len(skipped_migrations), + ", ".join(skipped_migrations), + ) logger.info("[DB] Engine and schema ready") diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py index 1d876dc7..dfdcca70 100644 --- a/src/databricks_ai_bridge/long_running/models.py +++ b/src/databricks_ai_bridge/long_running/models.py @@ -14,7 +14,14 @@ class Base(DeclarativeBase): class Response(Base): - """Response status tracking for background agent tasks.""" + """Response status tracking for background agent tasks. + + Durability columns (``heartbeat_at``, ``attempt_number``, + ``original_request``) support crash-resume: another pod atomically + claims a stale in-progress row by CAS-ing on ``attempt_number`` and + replays the agent loop. The owning pod is implicit — it's whatever + pod last successfully heartbeat at the current attempt_number. + """ __tablename__ = "responses" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -25,12 +32,22 @@ class Response(Base): DateTime(timezone=True), nullable=False, server_default=func.now() ) trace_id: Mapped[str | None] = mapped_column(Text, nullable=True) + heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) + original_request: Mapped[str | None] = mapped_column(Text, nullable=True) messages = relationship("Message", back_populates="response", cascade="all, delete-orphan") class Message(Base): - """Stream events and output items for a response.""" + """Stream events and output items for a response. + + ``attempt_number`` tags events by which run attempt emitted them so that + resumed runs append to the same event log without overwriting earlier + (abandoned) attempts, and retrieve can filter to the latest attempt only. 
+ """ __tablename__ = "messages" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -44,6 +61,9 @@ class Message(Base): sequence_number: Mapped[int] = mapped_column( Integer, primary_key=True, nullable=False, default=0 ) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) item: Mapped[str | None] = mapped_column(Text, nullable=True) stream_event: Mapped[str | None] = mapped_column(Text, nullable=True) diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py index fcd86b29..2c6a1ff0 100644 --- a/src/databricks_ai_bridge/long_running/repository.py +++ b/src/databricks_ai_bridge/long_running/repository.py @@ -5,30 +5,64 @@ from typing import Any, NamedTuple from sqlalchemy import select, update +from sqlalchemy.sql import bindparam, text from databricks_ai_bridge.long_running.db import session_scope -from databricks_ai_bridge.long_running.models import Message, Response +from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response -async def create_response(response_id: str, status: str) -> None: - """Insert a new response.""" +async def create_response( + response_id: str, + status: str, + *, + durable: bool = False, + original_request: dict[str, Any] | None = None, +) -> None: + """Insert a new response row. + + When ``durable=True``, ``heartbeat_at`` is initialized to ``now()`` so + the row doesn't immediately look stale. Non-durable callers (tests, + legacy flows) skip the heartbeat init. 
+ """ async with session_scope() as session: - session.add(Response(response_id=response_id, status=status)) + session.add( + Response( + response_id=response_id, + status=status, + heartbeat_at=datetime.now().astimezone() if durable else None, + original_request=( + json.dumps(original_request) if original_request is not None else None + ), + ) + ) await session.commit() async def update_response_status( - response_id: str, status: str, *, expected_current_status: str | None = None + response_id: str, + status: str, + *, + expected_current_status: str | None = None, + expected_attempt_number: int | None = None, ) -> bool: """Update response status. Returns True if a row was updated. If *expected_current_status* is given the update only takes effect when the row's current status matches, avoiding concurrent-update races. + + If *expected_attempt_number* is given the update only takes effect when the + row's current ``attempt_number`` matches, ensuring only the pod that owns + the current attempt can transition the row to a terminal state. This + prevents a stale background task (e.g. a deferred-fail timer that fired + after another pod claimed the row for resume) from clobbering the new + owner's in-progress state. """ async with session_scope() as session: stmt = update(Response).where(Response.response_id == response_id) if expected_current_status is not None: stmt = stmt.where(Response.status == expected_current_status) + if expected_attempt_number is not None: + stmt = stmt.where(Response.attempt_number == expected_attempt_number) stmt = stmt.values(status=status) result = await session.execute(stmt) await session.commit() @@ -43,18 +77,120 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None: await session.commit() +async def heartbeat_response(response_id: str, expected_attempt_number: int) -> bool: + """Update heartbeat_at for a response IFF the attempt is still ours. + + Returns True on success. 
A False result means the claim has been lost — + another pod CAS-bumped ``attempt_number``, so this pod is no longer the + owner and the heartbeat task should stop. Implicit-ownership model: + whichever pod last successfully heartbeats at the current + ``attempt_number`` is the de facto owner. + """ + async with session_scope() as session: + stmt = ( + update(Response) + .where( + Response.response_id == response_id, + Response.attempt_number == expected_attempt_number, + ) + .values(heartbeat_at=datetime.now().astimezone()) + ) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 + + +async def claim_stale_response( + response_id: str, + stale_threshold_seconds: float, +) -> int | None: + """Atomically claim an in-progress response whose heartbeat has gone stale. + + Uses a single conditional UPDATE so exactly one caller wins on contention: + claim only succeeds if status is ``in_progress`` AND + (``heartbeat_at IS NULL`` OR ``heartbeat_at`` is older than the threshold). + The new attempt_number is the previous + 1; the prior attempt's heartbeat + task will detect this on its next heartbeat (rowcount=0) and stop. + + Returns the new ``attempt_number`` on success, or ``None`` if the row did + not satisfy the claim conditions (already completed, heartbeat still fresh, + or nonexistent). + """ + # Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for + # the incremented column as ergonomically. Using a single statement keeps the + # claim atomic without an explicit transaction-level lock. 
+ stmt = text( + f""" + UPDATE {AGENT_DB_SCHEMA}.responses + SET heartbeat_at = now(), + attempt_number = attempt_number + 1 + WHERE response_id = :rid + AND status = 'in_progress' + AND (heartbeat_at IS NULL + OR heartbeat_at < now() - make_interval(secs => :threshold)) + RETURNING attempt_number + """ + ).bindparams( + bindparam("rid", type_=None), + bindparam("threshold", type_=None), + ) + async with session_scope() as session: + result = await session.execute( + stmt, + {"rid": response_id, "threshold": stale_threshold_seconds}, + ) + row = result.first() + await session.commit() + return int(row[0]) if row else None + + +async def find_stale_response_ids( + stale_threshold_seconds: float, + limit: int = 50, +) -> list[str]: + """Return ids of in_progress responses whose heartbeat is older than the + threshold. Used by the proactive scanner to find candidates for resume + without waiting for a client GET. + + Limited to ``limit`` rows per scan to bound DB load. Ordered by + ``heartbeat_at`` ascending so the oldest staleness is handled first. 
+ """ + stmt = text( + f""" + SELECT response_id FROM {AGENT_DB_SCHEMA}.responses + WHERE status = 'in_progress' + AND heartbeat_at IS NOT NULL + AND heartbeat_at < now() - make_interval(secs => :threshold) + ORDER BY heartbeat_at ASC + LIMIT :limit + """ + ).bindparams( + bindparam("threshold", type_=None), + bindparam("limit", type_=None), + ) + async with session_scope() as session: + result = await session.execute( + stmt, + {"threshold": stale_threshold_seconds, "limit": limit}, + ) + return [row[0] for row in result.all()] + + async def append_message( response_id: str, sequence_number: int, item: str | None = None, stream_event: dict[str, Any] | None = None, + *, + attempt_number: int = 1, ) -> None: - """Append a message (stream event) for a response.""" + """Append a message (stream event) for a response, tagged with attempt_number.""" async with session_scope() as session: session.add( Message( response_id=response_id, sequence_number=sequence_number, + attempt_number=attempt_number, item=item, stream_event=json.dumps(stream_event) if stream_event is not None else None, ) @@ -65,22 +201,26 @@ async def append_message( async def get_messages( response_id: str, after_sequence: int | None = None, -) -> list[tuple[int, str | None, dict[str, Any] | None]]: - """Fetch messages for a response, optionally after a sequence number. + *, + attempt_number: int | None = None, +) -> list[tuple[int, str | None, dict[str, Any] | None, int]]: + """Fetch messages for a response, optionally filtering by sequence / attempt. - Returns list of (sequence_number, item, stream_event_dict). + Returns list of ``(sequence_number, item, stream_event_dict, attempt_number)``. 
""" async with session_scope() as session: stmt = select(Message).where(Message.response_id == response_id) if after_sequence is not None: stmt = stmt.where(Message.sequence_number > after_sequence) + if attempt_number is not None: + stmt = stmt.where(Message.attempt_number == attempt_number) stmt = stmt.order_by(Message.sequence_number) result = await session.execute(stmt) rows = result.scalars().all() out = [] for r in rows: evt = json.loads(r.stream_event) if r.stream_event else None - out.append((r.sequence_number, r.item, evt)) + out.append((r.sequence_number, r.item, evt, r.attempt_number)) return out @@ -89,6 +229,9 @@ class ResponseInfo(NamedTuple): status: str created_at: datetime trace_id: str | None + heartbeat_at: datetime | None + attempt_number: int + original_request: dict[str, Any] | None async def get_response(response_id: str) -> ResponseInfo | None: @@ -97,5 +240,13 @@ async def get_response(response_id: str) -> ResponseInfo | None: result = await session.execute(select(Response).where(Response.response_id == response_id)) row = result.scalar_one_or_none() if row: - return ResponseInfo(row.response_id, row.status, row.created_at, row.trace_id) + return ResponseInfo( + row.response_id, + row.status, + row.created_at, + row.trace_id, + row.heartbeat_at, + row.attempt_number, + json.loads(row.original_request) if row.original_request else None, + ) return None diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index b3374d67..40007ef4 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -6,9 +6,13 @@ raise RuntimeError("The long_running module requires Python 3.11 or later.") import asyncio +import copy import inspect import json import logging +import os +import random +import socket import time import uuid from collections.abc import AsyncGenerator @@ -34,9 +38,12 @@ from databricks_ai_bridge.long_running.db import 
dispose_db, init_db, is_db_configured from databricks_ai_bridge.long_running.repository import ( append_message, + claim_stale_response, create_response, + find_stale_response_ids, get_messages, get_response, + heartbeat_response, update_response_status, update_response_trace_id, ) @@ -47,15 +54,29 @@ BACKGROUND_KEY = "background" +# Process-local identifier for log lines. Not stored in the DB — heartbeat +# ownership is implicit via attempt_number CAS. +_POD_LOG_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" + async def _deferred_mark_failed( - response_id: str, delay: float = 2.0, reason: str = "Task timed out" + response_id: str, + delay: float = 2.0, + reason: str = "Task timed out", + *, + owning_attempt_number: int | None = None, ) -> None: """Mark a response as failed after a short delay. Runs as an independent asyncio task so the caller (``_task_scope``) can - return immediately. The delay lets the connection pool stabilise after - a cancellation before we attempt new DB writes. + return immediately. The delay lets the connection pool stabilise after a + cancellation before we attempt new DB writes. + + ``owning_attempt_number`` should be the attempt this pod was running when + the failure was scheduled. The terminal status update is CAS-checked + against it: if another pod has already claimed the row for a higher + attempt by the time this fires, we skip the failed-status write so we + don't clobber the new owner's state. """ try: await asyncio.sleep(delay) @@ -65,7 +86,17 @@ async def _deferred_mark_failed( # or SELECT FOR UPDATE on the response row to serialise writers. 
async with asyncio.timeout(delay): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + current_attempt = await _current_attempt(response_id) + if owning_attempt_number is not None and current_attempt != owning_attempt_number: + logger.info( + "Skipping deferred fail for %s: ownership changed " + "(was attempt=%d, now attempt=%d)", + response_id, + owning_attempt_number, + current_attempt, + ) + return error_event = { "type": "error", @@ -75,8 +106,18 @@ async def _deferred_mark_failed( "code": "task_timeout", }, } - await append_message(response_id, next_seq, item=None, stream_event=error_event) - await update_response_status(response_id, "failed") + await append_message( + response_id, + next_seq, + item=None, + stream_event=error_event, + attempt_number=current_attempt, + ) + await update_response_status( + response_id, + "failed", + expected_attempt_number=owning_attempt_number, + ) logger.info("Marked %s as failed (reason: %s)", response_id, reason) except TimeoutError: @@ -91,10 +132,21 @@ async def _deferred_mark_failed( ) +async def _current_attempt(response_id: str) -> int: + """Fetch the current attempt_number for a response, defaulting to 1.""" + resp = await get_response(response_id) + return resp.attempt_number if resp else 1 + + def _sse_event(event_type: str, data: dict[str, Any] | str) -> str: - """Format an SSE event per Open Responses spec.""" + """Emit ``data:``-only SSE frames. Match the non-durable stream format + so downstream SSE parsers dispatch on the payload's ``type`` field + rather than a leading ``event:`` name line. 
Claude's multi-response + stream (one response.created/completed pair per tool iteration) plus + the event-name prefix confuses the AI SDK's Databricks provider into + a retry loop.""" payload = data if isinstance(data, str) else json.dumps(data) - return f"event: {event_type}\ndata: {payload}\n\n" + return f"data: {payload}\n\n" def _age_seconds(created_at: datetime) -> float: @@ -105,9 +157,99 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() +def _build_prose_recovery_message( + messages: list[tuple], prior_attempt_number: int +) -> dict[str, Any]: + """Build a single user message containing the prior attempt's raw event + log + a directive that asks the LLM to figure out what already completed + and continue. + + The body is `json.dumps(events)` of the prior attempt's stream events + wrapped in a recovery prompt. SDK-agnostic — no provider-specific pairing + rules, no structured carry-forward, no synthetic events. The model reads + the JSON, decides which tool calls succeeded, which were interrupted, and + continues. + """ + prior_events = [ + evt + for _seq, _item_json, evt, attempt_tag in messages + if attempt_tag == prior_attempt_number and isinstance(evt, dict) + ] + body = ( + "[RECOVERY] The previous attempt of this agent task crashed " + "mid-execution. Below is the raw stream-event log from that attempt " + "as JSON. Some tool calls may have completed and some may have been " + "interrupted before returning a result. Inspect the events, figure " + "out what is already done versus in-progress / not completed, and " + "continue the task from where it left off. 
If a tool call was " + "interrupted, you may re-invoke it if its result is still needed.\n\n" + f"Events:\n{json.dumps(prior_events)}" + ) + return { + "type": "message", + "role": "user", + "content": body, + } + + +def _rotate_conversation_id( + request_dict: dict[str, Any], + new_attempt_number: int, + response_id: str, +) -> dict[str, Any]: + """Rotate the conversation anchor to a per-attempt value. + + After a crash, attempt N+1 should see a FRESH checkpointer / session so it + doesn't inherit mid-turn state that the SDK can't repair cleanly (most + notably the LangGraph stream-event attempt-boundary orphan artifact). + The handler's priority chain is: + + 1. custom_inputs.thread_id / session_id (explicit, wins) + 2. context.conversation_id (fallback) + 3. auto-generated (last resort) + + We drop (1), pick the current base anchor, and write ``{base}::attempt-N`` + into (2). The handler then resolves to a fresh key for this attempt while + still being deterministic across retries of the same attempt. + + The LLM sees full turn history via ``original_request.input``, which was + captured at the initial POST — before any attempt ran, so it's clean by + construction. 
+ """ + custom_inputs = request_dict.get("custom_inputs") + if not isinstance(custom_inputs, dict): + custom_inputs = {} + + base_anchor = ( + custom_inputs.get("thread_id") + or custom_inputs.get("session_id") + or (request_dict.get("context") or {}).get("conversation_id") + or response_id + ) + + custom_inputs.pop("thread_id", None) + custom_inputs.pop("session_id", None) + request_dict["custom_inputs"] = custom_inputs + + ctx = request_dict.get("context") or {} + ctx = dict(ctx) + rotated = f"{base_anchor}::attempt-{new_attempt_number}" + ctx["conversation_id"] = rotated + request_dict["context"] = ctx + logger.info( + "[durable] rotated conversation_id for resume response_id=%s attempt=%d base=%s rotated=%s", + response_id, + new_attempt_number, + base_anchor, + rotated, + ) + return request_dict + + @experimental class LongRunningAgentServer(AgentServer): - """AgentServer subclass adding background mode and retrieve endpoints. + """AgentServer subclass adding background mode, retrieve endpoints, and + durable resume. Only compatible with ``ResponsesAgent`` mode. @@ -125,6 +267,16 @@ class LongRunningAgentServer(AgentServer): ``LAKEBASE_INSTANCE_NAME``, ``LAKEBASE_AUTOSCALING_ENDPOINT``, or both ``LAKEBASE_AUTOSCALING_PROJECT`` and ``LAKEBASE_AUTOSCALING_BRANCH``. + Durable resume: when an ``in_progress`` run's owning pod has stopped + heartbeating for more than ``heartbeat_stale_threshold_seconds``, another + pod (serving ``GET /responses/{id}`` or running the stale scanner) + atomically claims the run and re-invokes the registered handler with a + rotated ``conversation_id`` (so the agent SDK resolves to a fresh + thread/session) and the original request's ``input`` plus one prose + recovery message carrying the prior attempt's stream-event log as JSON. + No synthetic tool outputs are injected; the model infers what already + completed, and an interrupted tool call may re-run. + Args: enable_chat_proxy: Whether to enable the chat proxy endpoint.
db_instance_name: Lakebase provisioned instance name. Overrides @@ -143,6 +295,12 @@ class LongRunningAgentServer(AgentServer): Defaults to 5000 (5 seconds). cleanup_timeout_seconds: Timeout for DB cleanup after task failure. Defaults to 7.0. + heartbeat_interval_seconds: How often the owning pod writes + ``heartbeat_at`` while a run is in flight. Defaults to 3.0. + heartbeat_stale_threshold_seconds: Age at which a heartbeat is + considered stale and another pod may claim the run. Also used + as the grace window for a freshly-created run that hasn't + written its first heartbeat yet. Defaults to 10.0. """ _SUPPORTED_AGENT_TYPE = "ResponsesAgent" @@ -162,6 +320,8 @@ def __init__( poll_interval_seconds: float = 1.0, db_statement_timeout_ms: int = 5000, cleanup_timeout_seconds: float = 7.0, + heartbeat_interval_seconds: float = 3.0, + heartbeat_stale_threshold_seconds: float = 10.0, ): if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( @@ -173,11 +333,18 @@ def __init__( poll_interval_seconds=poll_interval_seconds, db_statement_timeout_ms=db_statement_timeout_ms, cleanup_timeout_seconds=cleanup_timeout_seconds, + heartbeat_interval_seconds=heartbeat_interval_seconds, + heartbeat_stale_threshold_seconds=heartbeat_stale_threshold_seconds, ) self._db_instance_name = db_instance_name self._db_autoscaling_endpoint = db_autoscaling_endpoint self._db_project = db_project self._db_branch = db_branch + # Track in-flight background tasks per response_id so the debug-kill + # endpoint can simulate a pod crash without tearing the whole pod + # down. Not load-bearing for correctness — durability still relies on + # DB state, this is just a test affordance. 
+ self._running_tasks: dict[str, asyncio.Task] = {} super().__init__(agent_type, enable_chat_proxy=enable_chat_proxy) def _setup_routes(self) -> None: @@ -195,6 +362,46 @@ async def cancel_endpoint(response_id: str): detail="Cancellation is not yet implemented.", ) + # Debug endpoint for testing durable resume: cancels the in-flight + # asyncio task that owns the given response_id WITHOUT running the + # _task_scope cleanup, so the DB row stays in_progress with a + # going-stale heartbeat — exactly the shape a real pod crash leaves. + # Opt-in via env var so it's never exposed in production. Env var + # is checked at request time (not registration time) because some + # platforms inject env vars after the FastAPI app object is built. + @self.app.post("/_debug/kill_task/{response_id}") + async def _debug_kill_task(response_id: str): + if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") != "1": + raise HTTPException( + status_code=404, + detail="Debug kill endpoint is disabled.", + ) + task = self._running_tasks.get(response_id) + if task is None: + logger.info( + "[durable] kill endpoint: no task response_id=%s on pod=%s", + response_id, + _POD_LOG_ID, + ) + raise HTTPException( + status_code=404, + detail=( + "No in-flight task for that response_id on this pod " + "(may already have finished or be running on another pod)." 
+ ), + ) + logger.info( + "[durable] kill endpoint: cancelling task response_id=%s pod=%s", + response_id, + _POD_LOG_ID, + ) + task.cancel() + return { + "response_id": response_id, + "pod_id": _POD_LOG_ID, + "status": "task_cancelled", + } + db_configured = is_db_configured() @self.app.get("/responses/{response_id}") @@ -234,8 +441,19 @@ async def _db_lifespan(app): branch=self._db_branch, db_statement_timeout_ms=self._settings.db_statement_timeout_ms, ) - yield - await dispose_db() + scanner_task = asyncio.create_task( + self._stale_response_scanner_loop(), + name="durable-stale-scanner", + ) + try: + yield + finally: + scanner_task.cancel() + try: + await scanner_task + except (asyncio.CancelledError, Exception): + pass + await dispose_db() self.app.router.lifespan_context = _db_lifespan @@ -290,11 +508,32 @@ async def _handle_background_request( ) -> dict[str, Any] | StreamingResponse: """Start a new conversation and return response_id immediately.""" response_id = f"resp_{uuid.uuid4().hex[:24]}" - await create_response(response_id, "in_progress") + # Anchor the conversation to response_id so any future replay from a + # different pod resolves to the same agent-SDK thread/session. We + # round-trip through dict + validator so the handler still receives a + # pydantic ResponsesAgentRequest (its declared arg type). The + # declared param type is ``dict`` but the runtime object is a pydantic + # model from ``validate_and_convert_request``; fall back to ``dict()`` + # when tests pass a plain dict directly. + dump = getattr(request_data, "model_dump", None) + request_dict = dump() if callable(dump) else dict(request_data) + # Store the FULL request (untrimmed) as `original_request` so resume can + # recover the entire prior-turn history. Per-template handlers are + # responsible for deduping their own UI-echoed input against the SDK's + # session/checkpointer state — the bridge no longer trims input. 
+ await create_response( + response_id, + "in_progress", + durable=True, + original_request=request_dict, + ) + durable_request = self.validator.validate_and_convert_request(request_dict) - logger.debug( - "Background response created", - extra={"response_id": response_id, "stream": is_streaming}, + logger.info( + "Background response created response_id=%s stream=%s pod=%s", + response_id, + is_streaming, + _POD_LOG_ID, ) response_obj: dict[str, Any] = { @@ -309,26 +548,187 @@ async def _handle_background_request( } # Fire-and-forget is intentional — task status is persisted to the database. + # We still track the task handle so the debug-kill endpoint can simulate + # a crash (and so we know whether a claim target lives on this pod). if is_streaming: - asyncio.create_task( - self._run_background_stream(response_id, request_data, return_trace_id) + task = asyncio.create_task( + self._run_background_stream( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) + self._track_task(response_id, task) return await self._handle_retrieve_request( response_id, stream=True, starting_after=0, ) else: - asyncio.create_task( - self._run_background_invoke(response_id, request_data, return_trace_id) + task = asyncio.create_task( + self._run_background_invoke( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) + self._track_task(response_id, task) return response_obj + def _track_task(self, response_id: str, task: asyncio.Task) -> None: + """Record a background task so the debug-kill endpoint can find it.""" + self._running_tasks[response_id] = task + task.add_done_callback(lambda _t: self._running_tasks.pop(response_id, None)) + + async def _stale_response_scanner_loop(self) -> None: + """Periodically scan for in_progress responses with stale heartbeats and + try to claim+resume them. The proactive counterpart to the lazy claim + path on ``GET /responses/{id}``. 
+ + Each iteration sleeps for a jittered interval so multiple pods don't + synchronize their reads. Runs until cancelled (in the lifespan + teardown). + """ + base = self._settings.stale_scan_interval_seconds + jitter = self._settings.stale_scan_jitter_fraction + threshold = self._settings.heartbeat_stale_threshold_seconds + logger.info( + "[durable] stale-scan loop start interval=%.1fs jitter=±%.0f%% threshold=%.1fs pod=%s", + base, + jitter * 100, + threshold, + _POD_LOG_ID, + ) + try: + while True: + # Jittered sleep — random scaling of base interval centered on 1.0. + delay = base * (1.0 + random.uniform(-jitter, jitter)) + await asyncio.sleep(delay) + try: + stale_ids = await find_stale_response_ids(threshold) + if not stale_ids: + continue + logger.info( + "[durable] stale-scan found %d candidate(s): %s", + len(stale_ids), + stale_ids, + ) + for response_id in stale_ids: + try: + resp = await get_response(response_id) + if resp: + await self._try_claim_and_resume(response_id, resp) + except Exception: + logger.exception( + "[durable] stale-scan resume failed response_id=%s", + response_id, + ) + except Exception: + # Don't let an iteration failure kill the loop. + logger.exception("[durable] stale-scan iteration failed") + except asyncio.CancelledError: + logger.info("[durable] stale-scan loop stopped pod=%s", _POD_LOG_ID) + raise + + @asynccontextmanager + async def _heartbeat(self, response_id: str, attempt_number: int) -> AsyncGenerator[None, None]: + """Keep the response row's heartbeat_at fresh while the body runs. + + A background task writes ``heartbeat_at = now()`` every + ``heartbeat_interval_seconds``, scoped to ``attempt_number``. The + update only matches if ``attempt_number`` still equals the value the + heartbeat was started with — if another pod has CAS-claimed the + row (bumping attempt_number), this heartbeat returns 0 rows and the + task knows it has lost ownership and stops. + + Implicit-ownership model: there is no ``owner_pod_id`` column. 
The + last pod to successfully heartbeat at the current attempt is the + de facto owner. + """ + interval = self._settings.heartbeat_interval_seconds + stop = asyncio.Event() + + async def _beat(): + beats = 0 + logger.info( + "[durable] heartbeat start response_id=%s attempt=%d pod=%s interval=%.1fs", + response_id, + attempt_number, + _POD_LOG_ID, + interval, + ) + try: + while not stop.is_set(): + try: + ok = await heartbeat_response(response_id, attempt_number) + if not ok: + # CAS failed → attempt_number has moved past us, + # another pod owns this response now. Stop the + # heartbeat task; the handler is still running but + # its emissions to the message log will be tagged + # with this attempt and ignored on the resumed + # path's filter (which keys on the new attempt). + logger.info( + "[durable] heartbeat lost ownership response_id=%s " + "attempt=%d (another pod claimed); stopping", + response_id, + attempt_number, + ) + stop.set() + break + beats += 1 + # Sampled heartbeat log so the lifecycle is visible + # without spamming every interval. Every 5th (~15s + # at 3s interval) is a good compromise. 
+ if beats % 5 == 1: + logger.info( + "[durable] heartbeat beat#%d response_id=%s attempt=%d pod=%s", + beats, + response_id, + attempt_number, + _POD_LOG_ID, + ) + except Exception: + logger.warning( + "[durable] heartbeat write failed response_id=%s; will retry", + response_id, + exc_info=True, + ) + try: + await asyncio.wait_for(stop.wait(), timeout=interval) + except TimeoutError: + pass + except asyncio.CancelledError: + pass + logger.info( + "[durable] heartbeat stop response_id=%s attempt=%d pod=%s total_beats=%d", + response_id, + attempt_number, + _POD_LOG_ID, + beats, + ) + + hb_task = asyncio.create_task(_beat(), name=f"heartbeat-{response_id}") + try: + yield + finally: + stop.set() + hb_task.cancel() + try: + await hb_task + except (asyncio.CancelledError, Exception): + pass + @asynccontextmanager async def _task_scope( - self, response_id: str, state: dict[str, Any] + self, + response_id: str, + state: dict[str, Any], + *, + attempt_number: int = 1, ) -> AsyncGenerator[None, None]: - """Timeout + error handling wrapper for background tasks.""" + """Timeout + error handling wrapper for background tasks. + + ``attempt_number`` is CAS-checked on terminal-status writes so a + deferred fail / cleanup that fires after another pod has claimed the + row for resume doesn't clobber the new owner's in-progress state. + """ try: async with asyncio.timeout(self._settings.task_timeout_seconds): yield @@ -339,7 +739,11 @@ async def _task_scope( self._settings.task_timeout_seconds, ) asyncio.create_task( - _deferred_mark_failed(response_id, delay=self._settings.cleanup_timeout_seconds), + _deferred_mark_failed( + response_id, + delay=self._settings.cleanup_timeout_seconds, + owning_attempt_number=attempt_number, + ), name=f"deferred-fail-{response_id}", ) except Exception as exc: @@ -348,7 +752,7 @@ async def _task_scope( # TODO: sequence number computation is racy (see _deferred_mark_failed). 
async with asyncio.timeout(self._settings.cleanup_timeout_seconds): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 await append_message( response_id, next_seq, @@ -361,8 +765,13 @@ async def _task_scope( "code": "task_failed", }, }, + attempt_number=attempt_number, + ) + await update_response_status( + response_id, + "failed", + expected_attempt_number=attempt_number, ) - await update_response_status(response_id, "failed") except Exception: logger.exception( "[error-cleanup] Immediate update failed for %s, deferring", @@ -373,6 +782,7 @@ async def _task_scope( response_id, delay=self._settings.cleanup_timeout_seconds, reason=str(exc), + owning_attempt_number=attempt_number, ), name=f"deferred-fail-{response_id}", ) @@ -382,11 +792,22 @@ async def _run_background_stream( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the streaming agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_stream(response_id, request_data, return_trace_id, state) + async with ( + self._task_scope(response_id, state, attempt_number=attempt_number), + self._heartbeat(response_id, attempt_number), + ): + await self._do_background_stream( + response_id, + request_data, + return_trace_id, + state, + attempt_number=attempt_number, + ) def transform_stream_event(self, event: dict, response_id: str) -> dict: """Override to transform events before persistence (e.g. 
replace placeholder IDs).""" @@ -398,16 +819,35 @@ async def _do_background_stream( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via stream_fn, persist each stream event as a message row.""" stream_fn = get_stream_function() if stream_fn is None: - await update_response_status(response_id, "failed") + await update_response_status( + response_id, "failed", expected_attempt_number=attempt_number + ) raise RuntimeError("No stream function registered; cannot run background stream") func_name = stream_fn.__name__ + logger.info( + "[durable] background stream start response_id=%s attempt=%d pod=%s handler=%s", + response_id, + attempt_number, + _POD_LOG_ID, + func_name, + ) all_chunks: list[dict[str, Any]] = [] - seq = 0 + # Continue sequence numbering across attempts so the client's cursor + # never rewinds on resume. First attempt starts at 0 and skips the DB + # lookup — keeps the fast path identical to pre-resume behavior and + # avoids an extra query per background request. + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + seq = 0 with mlflow.start_span(name=func_name) as span: span.set_inputs(request_data) @@ -420,16 +860,27 @@ async def _do_background_stream( evt_type = evt.get("type", "message") logger.debug( "SSE event (background)", - extra={"response_id": response_id, "seq": seq, "type": evt_type}, + extra={ + "response_id": response_id, + "seq": seq, + "type": evt_type, + "attempt": attempt_number, + }, ) await append_message( response_id, seq, item=json.dumps(item) if item is not None else None, stream_event=evt, + attempt_number=attempt_number, ) seq += 1 state["seq"] = seq + # Explicit yield so task.cancel() propagates promptly on + # tight event streams. 
The OpenAI Agents Runner's + # stream_events() awaits a queue that empties fast enough + # that cancellation can sit for tens of seconds without this. + await asyncio.sleep(0) span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "openai") span.set_outputs(ResponsesAgent.responses_agent_output_reducer(all_chunks)) @@ -439,12 +890,28 @@ async def _do_background_stream( response_id, seq, stream_event={"trace_id": span.trace_id}, + attempt_number=attempt_number, ) - await update_response_status(response_id, "completed") - logger.debug( - "Background stream completed", - extra={"response_id": response_id, "total_events": seq}, + updated = await update_response_status( + response_id, "completed", expected_attempt_number=attempt_number + ) + if not updated: + logger.info( + "[durable] skipped completed-status write response_id=%s attempt=%d " + "(another pod claimed the row mid-handler); pod=%s", + response_id, + attempt_number, + _POD_LOG_ID, + ) + return + logger.info( + "[durable] background stream completed response_id=%s attempt=%d " + "total_events=%d pod=%s", + response_id, + attempt_number, + seq, + _POD_LOG_ID, ) async def _run_background_invoke( @@ -452,11 +919,22 @@ async def _run_background_invoke( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the invoke agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_invoke(response_id, request_data, return_trace_id, state) + async with ( + self._task_scope(response_id, state, attempt_number=attempt_number), + self._heartbeat(response_id, attempt_number), + ): + await self._do_background_invoke( + response_id, + request_data, + return_trace_id, + state, + attempt_number=attempt_number, + ) async def _do_background_invoke( self, @@ -464,11 +942,15 @@ async def _do_background_invoke( request_data: dict[str, Any], return_trace_id: bool, 
state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via invoke_fn, persist each output item as a message row.""" invoke_fn = get_invoke_function() if invoke_fn is None: - await update_response_status(response_id, "failed") + await update_response_status( + response_id, "failed", expected_attempt_number=attempt_number + ) raise RuntimeError("No invoke function registered; cannot run background invoke") func_name = invoke_fn.__name__ @@ -485,27 +967,191 @@ async def _do_background_invoke( span.set_outputs(result) output = result.get("output", []) + # Continue sequence numbering across attempts (see _do_background_stream). + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + base_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + base_seq = 0 for i, item in enumerate(output): item_dict = ( item if isinstance(item, dict) else (item.model_dump() if hasattr(item, "model_dump") else {"content": str(item)}) ) + seq = base_seq + i await append_message( response_id, - i, + seq, item=json.dumps(item_dict), stream_event={"type": "response.output_item.done", "item": item_dict}, + attempt_number=attempt_number, ) - state["seq"] = i + 1 + state["seq"] = seq + 1 if return_trace_id: await update_response_trace_id(response_id, span.trace_id) - await update_response_status(response_id, "completed") + updated = await update_response_status( + response_id, "completed", expected_attempt_number=attempt_number + ) + if not updated: + logger.info( + "[durable] skipped completed-status write response_id=%s attempt=%d " + "(another pod claimed the row mid-handler); pod=%s", + response_id, + attempt_number, + _POD_LOG_ID, + ) + return logger.debug( "Background invoke completed", extra={"response_id": response_id, "output_items": len(output)}, ) + async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: + """If ``resp`` is a stale in-progress run, attempt an atomic claim. 
+ + On success, kick off a new background task that re-invokes the handler + on a rotated conversation anchor with the replayed input enriched by + the prior attempt's emitted items, and returns the new + ``attempt_number``. On failure (another pod won, or the run is no + longer stale), returns ``None``. + + This is the lazy resume path: triggered by a client retrieve. Pods + don't poll for stale work proactively in v1 — if no client ever calls + ``GET /responses/{id}``, the task_timeout sweep eventually marks it + failed. + """ + if resp.status != "in_progress": + return None + # The run may be freshly started but too young to have a heartbeat yet; + # respect the creation age as a grace period equal to the stale + # threshold. Otherwise a quick follow-up retrieve could hijack a + # running pod before it ever writes its first heartbeat. + if resp.heartbeat_at is None: + age = _age_seconds(resp.created_at) + if age < self._settings.heartbeat_stale_threshold_seconds: + logger.debug( + "[durable] claim skipped response_id=%s reason=grace_period " + "age=%.1fs threshold=%.1fs", + response_id, + age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + else: + hb_age = _age_seconds(resp.heartbeat_at) + if hb_age < self._settings.heartbeat_stale_threshold_seconds: + # Heartbeat is fresh — owner is alive. Common case, keep + # quiet at debug so we don't spam every poll iteration. + logger.debug( + "[durable] claim skipped response_id=%s reason=heartbeat_fresh " + "age=%.1fs threshold=%.1fs", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + logger.info( + "[durable] stale heartbeat detected response_id=%s " + "heartbeat_age=%.1fs threshold=%.1fs", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + ) + if resp.original_request is None: + # Nothing to replay from — the run predates durability metadata. 
+ logger.warning( + "[durable] cannot resume response_id=%s reason=no_original_request", + response_id, + ) + return None + + logger.info( + "[durable] attempting claim response_id=%s current_attempt=%d new_owner=%s", + response_id, + resp.attempt_number, + _POD_LOG_ID, + ) + new_attempt = await claim_stale_response( + response_id, + stale_threshold_seconds=self._settings.heartbeat_stale_threshold_seconds, + ) + if new_attempt is None: + # Someone else owns it, or the row was updated between the read and + # the claim. Expected under contention. + logger.info( + "[durable] claim lost response_id=%s (another pod won or row changed)", + response_id, + ) + return None + + # Build a "resume" request by REPLAYING the original POST's input on a + # ROTATED conversation anchor, plus a single prose user message that + # narrates the prior attempt's completed tool calls / outputs / narrative. + # + # Always-rotate + prose recovery design: + # 1. Rotation makes the handler's SDK helpers resolve to a FRESH + # thread_id / session_id, so the rotated session starts empty and + # cannot inherit orphan-poisoned mid-turn state from the crashed + # attempt. Subsequent turns from the client should also use the + # rotated anchor (templates return it via custom_outputs); the + # original session becomes orphaned permanently and is never read. + # 2. The prose user message is the single source of truth for what + # already ran. The LLM reads it as a recovery instruction and + # continues. No structural carry-forward, no synthetic outputs, + # no per-SDK adapter wrappers needed. 
+ existing = await get_messages(response_id, after_sequence=None) + next_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + prose_msg = _build_prose_recovery_message(existing, prior_attempt_number=new_attempt - 1) + + resume_dict = copy.deepcopy(resp.original_request) + resume_input = list(resume_dict.get("input") or []) + resume_input.append(prose_msg) + resume_dict["input"] = resume_input + logger.info( + "[durable] resume built prose recovery message for attempt %d response_id=%s", + new_attempt - 1, + response_id, + ) + resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id) + resume_request = self.validator.validate_and_convert_request(resume_dict) + # Surface the rotated conversation_id in the sentinel so clients that + # cache `chat_id → conversation_id` can pick up the rotation and use + # the rotated session on subsequent turns. Without this the next turn + # lands on the original (orphan-poisoned) session. + rotated_conv_id = (resume_dict.get("context") or {}).get("conversation_id") + await append_message( + response_id, + next_seq, + stream_event={ + "type": "response.resumed", + "attempt": new_attempt, + "from_seq": next_seq, + "conversation_id": rotated_conv_id, + }, + attempt_number=new_attempt, + ) + + logger.info( + "[durable] claim succeeded response_id=%s new_attempt=%d pod=%s resume_from_seq=%d", + response_id, + new_attempt, + _POD_LOG_ID, + next_seq, + ) + + task = asyncio.create_task( + self._run_background_stream( + response_id, + resume_request, + return_trace_id=False, + attempt_number=new_attempt, + ), + name=f"resume-{response_id}-{new_attempt}", + ) + self._track_task(response_id, task) + return new_attempt + async def _handle_retrieve_request( self, response_id: str, @@ -523,7 +1169,20 @@ async def _handle_retrieve_request( if resp is None: raise HTTPException(status_code=404, detail="Response not found") - _, status, created_at, trace_id = resp + # Try a lazy resume before falling back to the 
absolute-timeout sweep. + # This gives us crash-recovery semantics: an idle client reconnecting + # after a pod died will reclaim the run and resume it here instead of + # just marking it failed. + await self._try_claim_and_resume(response_id, resp) + + # Refresh after the potential resume: status / attempt_number may have changed. + resp = await get_response(response_id) + if resp is None: + raise HTTPException(status_code=404, detail="Response not found") + + status = resp.status + created_at = resp.created_at + trace_id = resp.trace_id if ( status == "in_progress" @@ -542,10 +1201,9 @@ async def _handle_retrieve_request( }, ) # TODO: sequence number computation here is racy under concurrent writers. - # Acceptable at current scale; for high-QPS use a DB-assigned sequence or - # SELECT FOR UPDATE on the response row to serialise writers. existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -558,6 +1216,7 @@ async def _handle_retrieve_request( "code": "task_timeout", }, }, + attempt_number=attempt, ) status = "failed" @@ -579,25 +1238,45 @@ async def _handle_retrieve_request( messages = await get_messages(response_id, after_sequence=None) if not messages and status == "in_progress": - return {"id": response_id, "status": "in_progress"} + return { + "id": response_id, + "status": "in_progress", + "attempt_number": resp.attempt_number, + } if status == "completed" and messages: + # Only consider items from the final (successful) attempt so that + # abandoned in-progress items from crashed attempts don't leak + # into the authoritative response body. 
Completed output_item.done + # events across attempts together make up the conversation — the + # agent SDK's checkpointer guarantees done-items are not re-emitted + # by later attempts, so this is a union with no duplicates. output = [] - for _, _, evt in messages: - if evt and "item" in evt: - output.append(evt["item"]) + for _, _, evt, _attempt in messages: + if evt and evt.get("type") == "response.output_item.done": + output.append(evt.get("item")) result: dict[str, Any] = { "id": response_id, "status": "completed", - "output": output, + "output": [o for o in output if o is not None], + "attempt_number": resp.attempt_number, } if trace_id: result["metadata"] = {"trace_id": trace_id} return result if status == "failed" and messages: - for _, _, evt in messages: + for _, _, evt, _attempt in messages: if evt and evt.get("type") == "error": - return {"id": response_id, "status": "failed", "error": evt.get("error")} - return {"id": response_id, "status": status} + return { + "id": response_id, + "status": "failed", + "error": evt.get("error"), + "attempt_number": resp.attempt_number, + } + return { + "id": response_id, + "status": status, + "attempt_number": resp.attempt_number, + } async def _stream_retrieve( self, @@ -638,15 +1317,26 @@ async def _stream_retrieve( ) break - _, status, _, _ = resp + status = resp.status + # Self-heal: if this response is still in_progress but its owning + # pod has gone silent past heartbeat_stale_threshold, try to claim + # + resume on this pod. A no-op if heartbeat is fresh or another + # pod already won. Without this, a stream opened before the crash + # would idle forever polling a dead run — since _try_claim_and_resume + # is only triggered by the outer retrieve handler on fresh GETs. + if status == "in_progress": + await self._try_claim_and_resume(response_id, resp) + # starting_after=0 fetches all messages (sequence numbers start at 0). # We use after_sequence=-1 for the DB query so that seq 0 is included. 
after_seq = last_seq - 1 if last_seq == 0 else last_seq messages = await get_messages(response_id, after_sequence=after_seq) - for seq, _, evt in messages: + for seq, _, evt, _attempt in messages: if evt is not None: - evt = {**evt, "sequence_number": seq} + # Tag every SSE frame with the response_id so proxies / + # clients can discover it without parsing nested fields. + evt = {**evt, "sequence_number": seq, "response_id": response_id} event_type = evt.get("type", "message") logger.debug( "SSE event", diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 7b646116..8224b4d8 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -15,6 +15,15 @@ class LongRunningSettings: poll_interval_seconds: float = 1.0 db_statement_timeout_ms: int = 5000 cleanup_timeout_seconds: float = 7.0 + heartbeat_interval_seconds: float = 3.0 + heartbeat_stale_threshold_seconds: float = 10.0 + # Proactive stale-scan loop: how often (on average) each pod queries the + # responses table for stale-heartbeat rows and tries to claim+resume them. + # Each pod jitters this interval so multiple pods don't all hit the DB at + # once. The loop is the proactive counterpart to the lazy-on-GET claim + # path; it ensures crashed responses get recovered even if no client polls. 
+ stale_scan_interval_seconds: float = 30.0 + stale_scan_jitter_fraction: float = 0.5 def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: @@ -25,6 +34,16 @@ def __post_init__(self) -> None: raise ValueError("db_statement_timeout_ms must be positive") if self.cleanup_timeout_seconds <= 0: raise ValueError("cleanup_timeout_seconds must be positive") + if self.heartbeat_interval_seconds <= 0: + raise ValueError("heartbeat_interval_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= 0: + raise ValueError("heartbeat_stale_threshold_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= self.heartbeat_interval_seconds: + raise ValueError( + f"heartbeat_stale_threshold_seconds ({self.heartbeat_stale_threshold_seconds}) " + f"must be strictly greater than heartbeat_interval_seconds " + f"({self.heartbeat_interval_seconds}) to avoid false stale-run detection." + ) db_timeout_s = self.db_statement_timeout_ms / 1000.0 if self.cleanup_timeout_seconds <= db_timeout_s: raise ValueError( @@ -32,3 +51,7 @@ def __post_init__(self) -> None: f"strictly greater than db_statement_timeout_ms converted to seconds " f"({db_timeout_s})" ) + if self.stale_scan_interval_seconds <= 0: + raise ValueError("stale_scan_interval_seconds must be positive") + if not 0 <= self.stale_scan_jitter_fraction < 1: + raise ValueError("stale_scan_jitter_fraction must be in [0, 1)") diff --git a/tests/databricks_ai_bridge/test_long_running_db.py b/tests/databricks_ai_bridge/test_long_running_db.py index a1290ba1..2565a1e0 100644 --- a/tests/databricks_ai_bridge/test_long_running_db.py +++ b/tests/databricks_ai_bridge/test_long_running_db.py @@ -160,10 +160,12 @@ async def test_get_messages(mock_session): result_mock.scalars.return_value.all.return_value = [msg1, msg2] mock_session.execute.return_value = result_mock + msg1.attempt_number = 1 + msg2.attempt_number = 1 messages = await get_messages("resp_abc123", after_sequence=None) assert len(messages) 
== 2 - assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}) - assert messages[1] == (1, None, None) + assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}, 1) + assert messages[1] == (1, None, None, 1) @pytest.mark.asyncio @@ -175,6 +177,9 @@ async def test_get_response(mock_session): row.created_at = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) row.trace_id = "trace_xyz" + row.heartbeat_at = None + row.attempt_number = 1 + row.original_request = None result_mock = MagicMock() result_mock.scalar_one_or_none.return_value = row mock_session.execute.return_value = result_mock @@ -185,6 +190,9 @@ async def test_get_response(mock_session): "completed", datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc), "trace_xyz", + None, # heartbeat_at + 1, # attempt_number + None, # original_request ) @@ -283,9 +291,15 @@ async def test_creates_schema_and_tables(self, reset_db_globals): with patch(f"{DB_MODULE}.AsyncLakebaseSQLAlchemy", mock_cls), patch(f"{DB_MODULE}.event"): await init_db(autoscaling_endpoint="ep") - mock_conn.execute.assert_awaited_once() - sql_arg = str(mock_conn.execute.call_args[0][0]) - assert "CREATE SCHEMA IF NOT EXISTS" in sql_arg + # init_db runs: CREATE SCHEMA + run_sync(create_all) + a series of + # ADD COLUMN IF NOT EXISTS / CREATE INDEX IF NOT EXISTS to migrate + # the durability columns onto pre-existing tables. 
+ all_sql = " | ".join(str(call.args[0]) for call in mock_conn.execute.call_args_list) + assert "CREATE SCHEMA IF NOT EXISTS" in all_sql + assert "ADD COLUMN IF NOT EXISTS heartbeat_at" in all_sql + assert "ADD COLUMN IF NOT EXISTS attempt_number" in all_sql + assert "ADD COLUMN IF NOT EXISTS original_request" in all_sql + assert "idx_responses_stale" in all_sql mock_conn.run_sync.assert_awaited_once() @pytest.mark.asyncio @@ -346,3 +360,104 @@ async def fake_factory(): monkeypatch.setattr(db_mod, "_session_factory", fake_factory) async with session_scope() as session: assert session is mock_session + + +# --------------------------------------------------------------------------- +# Durability metadata: heartbeat (CAS on attempt), claim, attempt_number +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_response_durable_stamps_heartbeat_and_original_request(mock_session): + """Durable callers stamp heartbeat_at + serialized request on creation — + without these, a resumed pod can't re-invoke the handler.""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response( + "resp_abc", + "in_progress", + durable=True, + original_request={"input": [{"role": "user", "content": "hi"}]}, + ) + added = mock_session.add.call_args[0][0] + assert added.heartbeat_at is not None + # original_request is JSON-encoded for Text storage. 
+ assert '"role": "user"' in added.original_request + + +@pytest.mark.asyncio +async def test_create_response_without_durability_metadata(mock_session): + """Non-durable callers (tests, legacy flows) write no heartbeat so the + stale sweep can't accidentally claim them.""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response("resp_x", "in_progress") + added = mock_session.add.call_args[0][0] + assert added.heartbeat_at is None + assert added.original_request is None + + +@pytest.mark.asyncio +async def test_heartbeat_response_updates_when_attempt_matches(mock_session): + from databricks_ai_bridge.long_running.repository import heartbeat_response + + result_mock = MagicMock() + result_mock.rowcount = 1 + mock_session.execute.return_value = result_mock + + ok = await heartbeat_response("resp_abc", expected_attempt_number=1) + assert ok is True + mock_session.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_heartbeat_response_fails_when_attempt_changed(mock_session): + """If the CAS misses (attempt_number bumped by another pod's claim), + heartbeat reports failure so the caller can stop looping.""" + from databricks_ai_bridge.long_running.repository import heartbeat_response + + result_mock = MagicMock() + result_mock.rowcount = 0 + mock_session.execute.return_value = result_mock + + ok = await heartbeat_response("resp_abc", expected_attempt_number=1) + assert ok is False + + +@pytest.mark.asyncio +async def test_claim_stale_response_returns_attempt_number(mock_session): + from databricks_ai_bridge.long_running.repository import claim_stale_response + + row = MagicMock() + row.__iter__ = lambda self: iter([2]) + row.__getitem__ = lambda self, i: 2 + result_mock = MagicMock() + result_mock.first.return_value = row + mock_session.execute.return_value = result_mock + + attempt = await claim_stale_response("resp_abc", stale_threshold_seconds=15.0) + assert attempt == 2 + + +@pytest.mark.asyncio +async 
def test_claim_stale_response_returns_none_when_not_eligible(mock_session): + from databricks_ai_bridge.long_running.repository import claim_stale_response + + result_mock = MagicMock() + result_mock.first.return_value = None + mock_session.execute.return_value = result_mock + + attempt = await claim_stale_response("resp_abc", stale_threshold_seconds=15.0) + assert attempt is None + + +@pytest.mark.asyncio +async def test_append_message_with_attempt_number(mock_session): + """Resumed events must be tagged with the resume attempt so retrieve can + filter or the client can render the response.resumed boundary cleanly.""" + from databricks_ai_bridge.long_running.repository import append_message + + await append_message("resp_abc", 5, stream_event={"x": 1}, attempt_number=3) + added = mock_session.add.call_args[0][0] + assert added.attempt_number == 3 + assert added.sequence_number == 5 diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 27b6eaf5..ae33b05c 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -14,9 +14,12 @@ pytest.importorskip("fastapi") pytest.importorskip("psycopg") +from databricks_ai_bridge.long_running.repository import ResponseInfo from databricks_ai_bridge.long_running.server import ( LongRunningAgentServer, + _build_prose_recovery_message, _deferred_mark_failed, + _rotate_conversation_id, _sse_event, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings @@ -34,6 +37,38 @@ def _make_server(**kwargs): return LongRunningAgentServer("ResponsesAgent", **kwargs) +def _resp_info( + response_id: str = "resp_123", + status: str = "in_progress", + created_at=None, + trace_id: str | None = None, + heartbeat_at=None, + attempt_number: int = 1, + original_request: dict | None = None, +) -> ResponseInfo: + """Build a ResponseInfo with sensible defaults for tests. 
+ + Mirrors the server's repository model so test setups stay terse even as + durability columns grow over time. + """ + if created_at is None: + created_at = datetime.now(timezone.utc) + return ResponseInfo( + response_id=response_id, + status=status, + created_at=created_at, + trace_id=trace_id, + heartbeat_at=heartbeat_at, + attempt_number=attempt_number, + original_request=original_request, + ) + + +def _msg(seq: int, item=None, evt=None, attempt: int = 1): + """Build a (seq, item, stream_event, attempt_number) tuple for get_messages mocks.""" + return (seq, item, evt, attempt) + + def _mock_span(): """Return a mock MLflow span with the attributes the server uses.""" span = MagicMock() @@ -55,8 +90,8 @@ def _mock_validator(server): class TestSSEEvent: def test_dict_data(self): result = _sse_event("response.created", {"id": "resp_123", "status": "in_progress"}) - assert result.startswith("event: response.created\n") - assert "data: " in result + assert result.startswith("data: ") + assert "event:" not in result assert result.endswith("\n\n") data_line = result.split("data: ")[1].strip() parsed = json.loads(data_line) @@ -64,8 +99,8 @@ def test_dict_data(self): def test_string_data(self): result = _sse_event("error", "something went wrong") - assert "event: error\n" in result - assert "data: something went wrong\n\n" in result + assert "event:" not in result + assert result == "data: something went wrong\n\n" class TestLongRunningSettings: @@ -189,7 +224,7 @@ def test_starting_after_zero_without_stream_is_allowed(self): patch( f"{MODULE}.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( f"{MODULE}.get_messages", @@ -209,8 +244,13 @@ async def test_marks_response_failed(self): patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, - return_value=[(0, None, {"type": "response.created"})], + 
return_value=[_msg(0, None, {"type": "response.created"})], ) as mock_get, + patch( + "databricks_ai_bridge.long_running.server.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch( "databricks_ai_bridge.long_running.server.append_message", new_callable=AsyncMock, @@ -230,7 +270,7 @@ async def test_marks_response_failed(self): stream_event = args[1]["stream_event"] assert stream_event["type"] == "error" assert stream_event["error"]["code"] == "task_timeout" - mock_update.assert_awaited_once_with("resp_123", "failed") + mock_update.assert_awaited_once_with("resp_123", "failed", expected_attempt_number=None) @pytest.mark.asyncio async def test_handles_db_error_gracefully(self): @@ -242,6 +282,36 @@ async def test_handles_db_error_gracefully(self): # Should not raise await _deferred_mark_failed("resp_123", delay=0.01) + @pytest.mark.asyncio + async def test_skips_status_write_when_attempt_changed(self): + # The pod that scheduled this fail was running attempt=1; by the + # time this fires, another pod has bumped to attempt=2. We must NOT + # write terminal status. + with ( + patch( + "databricks_ai_bridge.long_running.server.get_messages", + new_callable=AsyncMock, + return_value=[_msg(0, None, {"type": "response.created"})], + ), + patch( + "databricks_ai_bridge.long_running.server.get_response", + new_callable=AsyncMock, + return_value=_resp_info(attempt_number=2), + ), + patch( + "databricks_ai_bridge.long_running.server.append_message", + new_callable=AsyncMock, + ) as mock_append, + patch( + "databricks_ai_bridge.long_running.server.update_response_status", + new_callable=AsyncMock, + ) as mock_update, + ): + await _deferred_mark_failed("resp_123", delay=0.01, owning_attempt_number=1) + # Neither append nor status-write fires when we've lost ownership. 
+ mock_append.assert_not_awaited() + mock_update.assert_not_awaited() + class TestRetrieveRequest: @pytest.mark.asyncio @@ -271,13 +341,13 @@ async def test_completed_returns_output(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), "trace_abc"), + return_value=_resp_info("resp_123", "completed", trace_id="trace_abc"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - ( + _msg( 0, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -305,7 +375,7 @@ async def test_stale_run_detection(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_stale", "in_progress", old_time, None), + return_value=_resp_info("resp_stale", "in_progress", created_at=old_time), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -336,7 +406,7 @@ async def test_in_progress_returns_status(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -347,7 +417,11 @@ async def test_in_progress_returns_status(self): result = await server._handle_retrieve_request( "resp_123", stream=False, starting_after=0 ) - assert result == {"id": "resp_123", "status": "in_progress"} + assert result == { + "id": "resp_123", + "status": "in_progress", + "attempt_number": 1, + } class TestStreamRetrieve: @@ -360,14 +434,14 @@ async def test_completed_stream(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "completed"), ), patch( 
"databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created", "id": "resp_123"}), - ( + _msg(0, None, {"type": "response.created", "id": "resp_123"}), + _msg( 1, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -394,13 +468,13 @@ async def test_failed_stream_stops(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "failed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "failed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "error", "error": {"message": "boom"}}), + _msg(0, None, {"type": "error", "error": {"message": "boom"}}), ], ), ): @@ -433,7 +507,9 @@ async def fake_stream(request_data): patch(f"{MODULE}.get_stream_function", return_value=fake_stream), patch(f"{MODULE}.mlflow") as mock_mlflow, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, patch(f"{MODULE}.ResponsesAgent") as mock_ra, ): mock_mlflow.start_span.return_value = span @@ -448,7 +524,7 @@ async def fake_stream(request_data): assert seqs == [0, 1, 2] # Verify state tracks final seq assert state["seq"] == 3 - mock_update.assert_awaited_once_with("resp_1", "completed") + mock_update.assert_awaited_once_with("resp_1", "completed", expected_attempt_number=1) @pytest.mark.asyncio async def test_calls_transform_stream_event(self): @@ -500,12 +576,14 @@ async def test_no_stream_fn_marks_failed(self): with ( patch(f"{MODULE}.get_stream_function", return_value=None), - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", 
new_callable=AsyncMock, return_value=True + ) as mock_update, ): state = {"seq": 0} with pytest.raises(RuntimeError, match="No stream function registered"): await server._do_background_stream("resp_x", {}, False, state) - mock_update.assert_awaited_once_with("resp_x", "failed") + mock_update.assert_awaited_once_with("resp_x", "failed", expected_attempt_number=1) @pytest.mark.asyncio async def test_persists_trace_id_when_requested(self): @@ -554,7 +632,9 @@ async def fake_invoke(request_data): patch(f"{MODULE}.get_invoke_function", return_value=fake_invoke), patch(f"{MODULE}.mlflow") as mock_mlflow, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, patch(f"{MODULE}.update_response_trace_id", new_callable=AsyncMock), ): mock_mlflow.start_span.return_value = span @@ -572,7 +652,7 @@ async def fake_invoke(request_data): assert evt["type"] == "response.output_item.done" assert "item" in evt assert state["seq"] == 2 - mock_update.assert_awaited_once_with("resp_inv", "completed") + mock_update.assert_awaited_once_with("resp_inv", "completed", expected_attempt_number=1) @pytest.mark.asyncio async def test_trace_id_persisted_when_requested(self): @@ -603,12 +683,14 @@ async def test_no_invoke_fn_marks_failed(self): with ( patch(f"{MODULE}.get_invoke_function", return_value=None), - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, ): state = {"seq": 0} with pytest.raises(RuntimeError, match="No invoke function registered"): await server._do_background_invoke("resp_x", {}, False, state) - mock_update.assert_awaited_once_with("resp_x", "failed") + mock_update.assert_awaited_once_with("resp_x", "failed", 
expected_attempt_number=1) @pytest.mark.asyncio async def test_sync_invoke_fn_supported(self): @@ -623,7 +705,9 @@ def sync_invoke(request_data): patch(f"{MODULE}.get_invoke_function", return_value=sync_invoke), patch(f"{MODULE}.mlflow") as mock_mlflow, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, patch(f"{MODULE}.update_response_trace_id", new_callable=AsyncMock), ): mock_mlflow.start_span.return_value = span @@ -632,7 +716,9 @@ def sync_invoke(request_data): await server._do_background_invoke("resp_sync", {"input": "hi"}, False, state) assert mock_append.await_count == 1 - mock_update.assert_awaited_once_with("resp_sync", "completed") + mock_update.assert_awaited_once_with( + "resp_sync", "completed", expected_attempt_number=1 + ) # --------------------------------------------------------------------------- @@ -668,12 +754,19 @@ async def test_exception_writes_error_event_inline(self): f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created"}), - (1, None, {"type": "response.output_text.delta"}), + _msg(0, None, {"type": "response.created"}), + _msg(1, None, {"type": "response.output_text.delta"}), ], ), + patch( + f"{MODULE}.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, ): state = {"seq": 2} async with server._task_scope("resp_err", state): @@ -686,7 +779,7 @@ async def test_exception_writes_error_event_inline(self): assert evt["error"]["message"] == "something broke" assert evt["error"]["code"] == 
"task_failed" assert mock_append.call_args.args[1] == 2 # next_seq - mock_update.assert_awaited_once_with("resp_err", "failed") + mock_update.assert_awaited_once_with("resp_err", "failed", expected_attempt_number=1) @pytest.mark.asyncio async def test_exception_falls_back_to_deferred_on_db_failure(self): @@ -875,3 +968,544 @@ async def test_lifespan_not_set_when_db_not_configured(self): routes = [r.path for r in server.app.routes if hasattr(r, "path")] assert "/responses/{response_id}" in routes mock_init.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Durable resume: claim/heartbeat/attempt_number/sentinel +# --------------------------------------------------------------------------- + + +class TestBuildProseRecoveryMessage: + """Prose recovery serializer: produce a single Responses-API user-message + item containing the prior attempt's stream events as JSON, plus a + directive that asks the LLM to figure out what's done vs interrupted.""" + + def _done(self, seq, attempt, item): + return (seq, None, {"type": "response.output_item.done", "item": item}, attempt) + + def test_returns_user_message_shape(self): + out = _build_prose_recovery_message([], prior_attempt_number=1) + assert out["type"] == "message" + assert out["role"] == "user" + assert isinstance(out["content"], str) + assert "[RECOVERY]" in out["content"] + + def test_includes_events_json(self): + messages = [ + self._done( + 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"} + ), + self._done(1, 1, {"type": "function_call_output", "call_id": "c1", "output": "ok"}), + ] + out = _build_prose_recovery_message(messages, prior_attempt_number=1) + body = out["content"] + # Body should contain the raw events JSON-serialized. 
+ assert '"call_id": "c1"' in body + assert '"output": "ok"' in body + assert '"name": "f"' in body + + def test_filters_other_attempts(self): + messages = [ + self._done( + 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"} + ), + self._done( + 1, 2, {"type": "function_call", "call_id": "c2", "name": "g", "arguments": "{}"} + ), + ] + out = _build_prose_recovery_message(messages, prior_attempt_number=1) + body = out["content"] + assert '"call_id": "c1"' in body + # attempt 2 events excluded + assert '"call_id": "c2"' not in body + + def test_empty_attempt_emits_empty_events_array(self): + out = _build_prose_recovery_message([], prior_attempt_number=1) + # Body still contains the recovery directive and an empty events array. + assert "[RECOVERY]" in out["content"] + assert "Events:\n[]" in out["content"] + + +class TestRotateConversationId: + def test_rotate_drops_thread_id_and_sets_rotated_context(self): + r = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert "thread_id" not in out["custom_inputs"] + assert out["custom_inputs"]["user_id"] == "u" + assert out["context"]["conversation_id"] == "t1::attempt-2" + + def test_rotate_drops_session_id(self): + r = {"custom_inputs": {"session_id": "s1"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert "session_id" not in out["custom_inputs"] + assert out["context"]["conversation_id"] == "s1::attempt-2" + + def test_rotate_falls_back_to_context_conversation_id(self): + r = {"custom_inputs": {}, "context": {"conversation_id": "c-abc"}} + out = _rotate_conversation_id(r, new_attempt_number=3, response_id="resp_x") + assert out["context"]["conversation_id"] == "c-abc::attempt-3" + + def test_rotate_falls_back_to_response_id_as_last_resort(self): + r = {"custom_inputs": {}, "context": {}} + out = _rotate_conversation_id(r, 
new_attempt_number=2, response_id="resp_x") + assert out["context"]["conversation_id"] == "resp_x::attempt-2" + + def test_rotate_handles_missing_custom_inputs_key(self): + r = {"context": {"conversation_id": "c-abc"}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert out["context"]["conversation_id"] == "c-abc::attempt-2" + assert out["custom_inputs"] == {} + + +class TestHandleBackgroundRequestPersistsDurabilityState: + """Background request entry point should stamp the response row with the + full original_request body so resume can recover full prior-turn history.""" + + @pytest.mark.asyncio + async def test_persists_durable_flag_and_original_request(self): + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + _mock_validator(server) + + captured: dict = {} + + async def fake_create_response( + response_id, status, *, durable=False, original_request=None + ): + captured["response_id"] = response_id + captured["status"] = status + captured["durable"] = durable + captured["original_request"] = original_request + + with ( + patch(f"{MODULE}.create_response", side_effect=fake_create_response), + patch("asyncio.create_task") as mock_create_task, + ): + result = await server._handle_background_request( + {"input": [{"role": "user", "content": "hi"}]}, + is_streaming=False, + return_trace_id=False, + ) + + assert captured["status"] == "in_progress" + assert captured["durable"] is True + # original_request preserves the input the client sent (no + # conversation_id injection — the client owns that decision). + orig = captured["original_request"] + assert orig["input"] == [{"role": "user", "content": "hi"}] + # Return shape: immediate response_obj, not a stream. 
+ assert result["id"] == captured["response_id"] + assert result["status"] == "in_progress" + mock_create_task.assert_called_once() + + +class TestTryClaimAndResume: + @pytest.mark.asyncio + async def test_no_op_when_completed(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + resp = _resp_info(status="completed") + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_grace_period_for_fresh_run(self): + """Just-started runs get a grace window before they're claim-eligible.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_stale_threshold_seconds=15.0 + ) + # created 2s ago, no heartbeat yet → should NOT be claimed. + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=2), + heartbeat_at=None, + original_request={"input": []}, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_op_without_original_request(self): + """Legacy rows created before durability metadata can't be resumed.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=None, + original_request=None, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + 
assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_claim_fails_returns_none(self): + """Another pod won the race — we quietly step aside.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=300), + original_request={"input": [{"role": "user"}]}, + ) + with ( + patch( + f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=None + ) as mock_claim, + patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, + ): + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_awaited_once() + mock_append.assert_not_awaited() + + @pytest.mark.asyncio + async def test_successful_claim_spawns_resume_and_emits_sentinel(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"user_id": "u"}, + "context": {"conversation_id": "resp_x"}, + }, + ) + captured: dict = {} + + async def fake_append(response_id, seq, *, item=None, stream_event=None, attempt_number=1): + captured["seq"] = seq + captured["event"] = stream_event + captured["attempt_tag"] = attempt_number + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch( + f"{MODULE}.get_messages", + new_callable=AsyncMock, + return_value=[_msg(0, None, {}), _msg(1, None, {})], + ), + patch(f"{MODULE}.append_message", 
side_effect=fake_append), + patch("asyncio.create_task") as mock_create_task, + ): + attempt = await server._try_claim_and_resume("resp_x", resp) + + assert attempt == 2 + # Sentinel is written at next_seq (existing seqs were 0 and 1). + assert captured["seq"] == 2 + assert captured["event"]["type"] == "response.resumed" + assert captured["event"]["attempt"] == 2 + assert captured["attempt_tag"] == 2 + # A resume task is spawned; it was not awaited synchronously. + mock_create_task.assert_called_once() + + @pytest.mark.asyncio + async def test_resume_replays_input_and_rotates_conversation_id(self): + """Resume must replay original_request.input (not blank it) and rotate + the conversation anchor so the handler resolves to a fresh thread / + session for the new attempt. Prevents the LangGraph stream-event + attempt-boundary orphan artifact (rotation-findings.md).""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"thread_id": "t1", "user_id": "u"}, + "context": {}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", 
resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) + # Input is REPLAYED (not blanked) and a prose-recovery user message is + # appended so attempt N+1's LLM sees the original request plus a + # narrative of what happened. The MLflow validator normalizes the shape. + assert len(dumped["input"]) == 2 + assert dumped["input"][0]["role"] == "user" + assert dumped["input"][0]["content"] == "hi" + assert dumped["input"][1]["role"] == "user" + assert "[RECOVERY]" in dumped["input"][1]["content"] + # thread_id was dropped so the handler's priority-2 fallback wins. + assert "thread_id" not in (dumped["custom_inputs"] or {}) + # Other custom_inputs keys are preserved. + assert dumped["custom_inputs"]["user_id"] == "u" + # conversation_id is rotated to a per-attempt value anchored on t1. 
+ assert dumped["context"]["conversation_id"] == "t1::attempt-2" + assert kwargs.get("attempt_number") == 2 + + @pytest.mark.asyncio + async def test_resume_rotation_anchors_on_context_conversation_id(self): + """When the client didn't pin a thread_id/session_id, rotation uses + the injected context.conversation_id as the base anchor.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {}, + "context": {"conversation_id": "resp_x"}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=3), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) + # Rotation anchors on the stored context.conversation_id (priority 2). 
+ # Note: re-rotating in a subsequent attempt would re-anchor on the + # ORIGINAL stored value, not the previous rotation — no stacking. + assert dumped["context"]["conversation_id"] == "resp_x::attempt-3" + + +class TestRetrieveTriggersLazyClaim: + @pytest.mark.asyncio + async def test_retrieve_calls_try_claim(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + + resp = _resp_info("resp_x", "in_progress") + with ( + patch(f"{MODULE}.get_response", new_callable=AsyncMock, return_value=resp), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch.object( + server, "_try_claim_and_resume", new_callable=AsyncMock, return_value=None + ) as mock_claim, + ): + await server._handle_retrieve_request("resp_x", stream=False, starting_after=0) + + mock_claim.assert_awaited_once() + + +class TestHeartbeatContextManager: + @pytest.mark.asyncio + async def test_writes_heartbeat_periodically(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x", attempt_number=1): + await asyncio.sleep(0.2) # enough time for 2+ heartbeats + + # Heartbeat interval is 0.05s so we should see at least 2 writes. 
+ assert mock_hb.await_count >= 2 + for call in mock_hb.await_args_list: + assert call.args[0] == "resp_x" + + @pytest.mark.asyncio + async def test_stops_cleanly_on_exit(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x", attempt_number=1): + pass # immediate exit + + # Give the heartbeat loop a chance to observe the stop signal. + await asyncio.sleep(0.1) + writes_after_exit = mock_hb.await_count + + await asyncio.sleep(0.15) + # No new writes after the scope closed. + assert mock_hb.await_count == writes_after_exit + + @pytest.mark.asyncio + async def test_db_error_does_not_interrupt_body(self): + """Heartbeat failures are logged, not raised — the stale check catches + real death, so a transient write miss must not kill a live run.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + body_ran = False + with patch( + f"{MODULE}.heartbeat_response", + new_callable=AsyncMock, + side_effect=RuntimeError("db down"), + ): + async with server._heartbeat("resp_x", attempt_number=1): + await asyncio.sleep(0.1) + body_ran = True + assert body_ran + + +class TestSettingsHeartbeatValidation: + def test_stale_must_exceed_interval(self): + with pytest.raises(ValueError, match="heartbeat_stale_threshold_seconds"): + LongRunningSettings( + heartbeat_interval_seconds=5.0, + heartbeat_stale_threshold_seconds=5.0, + ) + + def test_interval_must_be_positive(self): + with pytest.raises(ValueError, match="heartbeat_interval_seconds must be positive"): + LongRunningSettings(heartbeat_interval_seconds=0) + + def test_defaults_match_chat_ux(self): + # 3s interval 
+ 15s stale gives ~5 heartbeats before a pod is considered + # dead — snug enough to recover conversations within a user's + # "reconnecting..." patience window. + s = LongRunningSettings() + assert s.heartbeat_interval_seconds == 3.0 + assert s.heartbeat_stale_threshold_seconds == 15.0 + + +class TestDebugKillTask: + """The opt-in debug-kill endpoint lets integration tests simulate a crash + against a deployed pod without restarting the whole app. Off by default + because exposing task cancellation bypasses the normal cleanup path.""" + + def test_endpoint_absent_by_default(self): + from starlette.testclient import TestClient + + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + resp = client.post("/_debug/kill_task/resp_x") + assert resp.status_code == 404 # route not registered + + def test_endpoint_registered_when_env_set(self, monkeypatch): + from starlette.testclient import TestClient + + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + # No in-flight task for this response_id on this pod → 404, not 405. + resp = client.post("/_debug/kill_task/resp_missing") + assert resp.status_code == 404 + assert "No in-flight task" in resp.json()["detail"] + + @pytest.mark.asyncio + async def test_cancels_tracked_task(self, monkeypatch): + """Direct-call variant: skip the TestClient (which is sync and blocks + the loop) and call the handler logic through _running_tasks directly. + Covers the important behavior: cancelling a tracked task propagates + CancelledError and the tracking dict is cleared by the done-callback. 
+ """ + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + + cancel_event = asyncio.Event() + + async def long_running(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancel_event.set() + raise + + task = asyncio.create_task(long_running()) + server._track_task("resp_tracked", task) + + # Yield once so the new task can start waiting on sleep(60). + await asyncio.sleep(0) + assert "resp_tracked" in server._running_tasks + + task.cancel() + # Expect CancelledError from awaiting the task itself, and the cancel + # event set inside the except handler before the re-raise. + with pytest.raises(asyncio.CancelledError): + await task + assert cancel_event.is_set() + # done-callback (scheduled on loop) clears the registration after the + # task completes — give it one more tick. + await asyncio.sleep(0) + assert "resp_tracked" not in server._running_tasks
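
Reviewer note: the conversation-id rotation contract exercised by `TestRotateConversationId` can be summarized in a short sketch. This is not the module's `_rotate_conversation_id`; the function name `rotate_conversation_id_sketch` is hypothetical, and the precedence of `thread_id` over `session_id` when both are present is an assumption, since the tests above only cover each key in isolation.

```python
import copy


def rotate_conversation_id_sketch(request, new_attempt_number, response_id):
    """Hypothetical sketch of the rotation behaviour the tests describe."""
    out = copy.deepcopy(request)
    # Tolerate a missing or None custom_inputs / context, as the tests expect.
    custom_inputs = dict(out.get("custom_inputs") or {})
    context = dict(out.get("context") or {})

    # Drop both pinned identifiers so the handler resolves a fresh session.
    thread_id = custom_inputs.pop("thread_id", None)
    session_id = custom_inputs.pop("session_id", None)

    # Assumed anchor priority: thread_id, session_id, the stored
    # context.conversation_id, then the response_id as a last resort.
    anchor = thread_id or session_id or context.get("conversation_id") or response_id

    context["conversation_id"] = f"{anchor}::attempt-{new_attempt_number}"
    out["custom_inputs"] = custom_inputs
    out["context"] = context
    return out
```

Because the sketch always starts from the stored original request rather than from a previously rotated one, repeated resumes yield `<anchor>::attempt-N` and never stack suffixes, which matches the "no stacking" comment in `test_resume_rotation_anchors_on_context_conversation_id`.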