diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md
new file mode 100644
index 00000000..6b04a76c
--- /dev/null
+++ b/src/databricks_ai_bridge/long_running/AGENTS.md
@@ -0,0 +1,456 @@
+# LongRunningAgentServer
+
+Durable, crash-resumable agent execution for MLflow `ResponsesAgent` handlers.
+
+This document describes:
+1. What `LongRunningAgentServer` does and the guarantees it gives callers ([§1](#1-what-this-module-does)).
+2. The four customer journeys it covers, with sequence diagrams ([§2](#2-customer-journeys)).
+3. The architecture: storage layout, claim mechanism, recovery, and stream resume ([§3](#3-architecture)).
+4. Author-side requirements: what changes (and doesn't) when a handler opts into durable mode ([§4](#4-author-side-requirements)).
+5. The interface today and how it's expected to evolve when [TaskFlow](https://github.com/databricks-eng/universe/tree/master/experimental/taskflow) lands ([§5](#5-future-direction-taskflow)).
+
+## 1. What this module does
+
+`LongRunningAgentServer` extends MLflow's `AgentServer` for `ResponsesAgent` handlers with three capabilities:
+
+1. **Background execution.** A `POST /responses` request with `background: true` returns a `response_id` immediately; the agent loop runs detached from the HTTP connection. State persists to Lakebase Postgres.
+2. **Streaming retrieval.** `GET /responses/{response_id}?stream=true&starting_after=N` replays events past sequence `N` and tails new ones until the run finishes. Reconnects without losing events.
+3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run and finishes the work via **prose recovery**: the new attempt receives a single user message containing `json.dumps(events)` of the crashed attempt's stream-event log plus a directive asking the LLM to figure out what's done vs interrupted and continue. The handler runs on a freshly-rotated SDK session.
+
+Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, others) is opaque to the server.
+
+### Guarantees
+
+- **At-most-once durable claim.** Only one pod runs a given response at a time. The handoff uses an atomic CAS on `attempt_number`.
+- **Append-only event log.** Every SSE frame is persisted to `agent_server.messages` keyed by `(response_id, attempt_number, sequence_number)`. Clients cursor-resume from `starting_after`.
+- **SDK-agnostic recovery.** The resumed attempt receives a flat prose narrative — no provider-specific tool-pair structure, no synthetic tool events, no per-SDK adapter code.
+- **Per-template UI-echo dedup.** The bridge does NOT trim echoed history. When the chat client echoes the full prior conversation in `request.input`, the agent handler is responsible for deduping its input against the SDK's session/checkpointer state — typically by forwarding only the latest user message when the session already has prior turns. See the templates in `app-templates/agent-{openai,langgraph}-advanced/` for the canonical 1-2 line shape.
+- **Best-effort tool execution.** A tool call interrupted mid-flight may re-run on the resumed attempt. Idempotency is the tool author's responsibility.
+- **No agent code changes required.** Templates that subclass `LongRunningAgentServer` keep using `@invoke()` / `@stream()` decorators. All durability lives below the handler boundary.
+
+### Non-goals
+
+- Cross-region failover. Pods are assumed to share one Lakebase.
+- Tool-level checkpointing / exactly-once tool execution.
+- A workflow DSL. Handlers are ordinary async generators / coroutines.
+
+## 2. Customer journeys
+
+### CUJ 1: Author writes a long-running agent
+
+The author subclasses `LongRunningAgentServer` and registers `@invoke()` / `@stream()` handlers like a regular MLflow agent server. **No durability code in `agent.py`.**
+
+```python
+from databricks_ai_bridge.long_running import LongRunningAgentServer
+from mlflow.genai.agent_server import invoke, stream
+
+agent_server = LongRunningAgentServer(
+ "ResponsesAgent",
+ db_instance_name="my-lakebase-instance",
+)
+
+@stream()
+async def stream_handler(request):
+ # ordinary agent code: build messages, call SDK, yield events
+ ...
+
+@invoke()
+async def invoke_handler(request):
+ ...
+
+app = agent_server.app
+```
+
+The agent author writes their handler exactly the same way they would for the non-durable `AgentServer`. `LongRunningAgentServer` adds the durable wiring transparently.
+
+### CUJ 2: Pod crashes mid-tool, client polls
+
+A client posts a long-running request, the owning pod dies mid-tool, another pod takes over via prose recovery, and the client gets the final output without restarting.
+
+```mermaid
+sequenceDiagram
+ autonumber
+ participant C as Client
+ participant A as Pod A (owner)
+ participant DB as Lakebase
(agent_server.*)
+ participant B as Pod B
+ participant SDK as SDK store
(rotated session)
+
+ C->>A: POST /responses
{input, background:true, stream:true}
+ A->>DB: INSERT response_id, attempt=1,
heartbeat_at=now(), original_request (full input)
+ A-->>C: 200 {id: resp_xxx, status: in_progress}
+ activate A
+ Note over A: heartbeat loop (every 3s)
CAS-checks attempt_number = 1
+ A->>DB: UPDATE heartbeat_at WHERE attempt_number=1
+
+ par Streaming
+ A->>SDK: write to session T (original conv_id)
+ A->>DB: append events seq=0..N (attempt=1)
+ and Polling
+ C->>A: GET /responses/{id}?stream=true&starting_after=N
+ A-->>C: SSE events
+ end
+
+ Note over A: 💥 pod crashes mid-tool
(after tool_use, before tool_result)
+ deactivate A
+ Note over DB: heartbeat_at goes stale (>10s old)
+
+ C->>B: GET /responses/{id}?stream=true&starting_after=K
+ B->>DB: SELECT heartbeat_at, attempt
+ Note over B: heartbeat stale → claim
+ B->>DB: CAS UPDATE attempt=2, heartbeat_at=now()
WHERE attempt=1 AND heartbeat stale (atomic)
+ B->>DB: SELECT prior events WHERE attempt=1
+ Note over B: build resume input:
· original_request.input (full prior turns)
· + prose recovery message: json.dumps(events) + directive
· rotate conv_id → ::attempt-2
+ B->>DB: append response.resumed sentinel
{conversation_id: rotated value}
+ B->>B: re-invoke @stream() handler
(fresh rotated SDK session)
+ activate B
+ B->>SDK: write to T::attempt-2 (clean)
+ B->>DB: append events seq=K+1..M (attempt=2)
+ B-->>C: SSE events (response.resumed, then attempt-2 events)
+ deactivate B
+```
+
+**What the client observes:** a single SSE stream that may pause briefly during the heartbeat-stale window (~10s by default), then resumes. The `response.resumed` sentinel marks the attempt boundary and carries the rotated `conversation_id` so the chatbot can use the rotated session for subsequent turns.
+
+**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus a single user message whose body is `[RECOVERY] ... Events: `. The model reads it as "the prior attempt crashed, here's the raw event log, figure out what's done and continue."
+
+### CUJ 3: Subsequent turn after a crashed turn
+
+After a successful crash + resume, the next turn from the client lands on a fresh `POST /responses`. The chatbot uses the **rotated** `conversation_id` (captured from the `response.resumed` sentinel) so the handler resolves to the rotated SDK session — which was populated cleanly during attempt 2's prose-recovery run.
+
+```mermaid
+sequenceDiagram
+ autonumber
+ participant C as Chatbot
+ participant S as Server (any pod)
+ participant SDK as SDK store
+
+ Note over SDK: state at original conv_id T:
incomplete (orphan tool_use)
state at T::attempt-2:
complete (prose + resumed turn output)
+
+ C->>C: lookup alias map:
chat_id → T::attempt-2
+ C->>S: POST /responses
{input: [echo of full UI history, new user msg],
context.conversation_id: T::attempt-2}
+ S->>S: handler dedupes UI echo
(forwards only latest user when session has history)
+ S->>S: handler resolves session_id = T::attempt-2
+ S->>SDK: load history at T::attempt-2
+ SDK-->>S: clean state (prose + attempt-2 emissions)
+ S->>S: model receives [clean history] + [new user msg]
+ Note over S: turn succeeds normally
+
+ Note over SDK: original session T is now orphaned forever
(never read again — chatbot uses rotated alias)
+```
+
+Three things make this work without per-SDK repair code:
+
+1. **Server emits the rotated conv_id in `response.resumed`.** The chatbot reads it and updates its `Map` alias.
+2. **Per-template UI-echo dedup in the handler.** When the SDK's session/checkpointer already has prior-turn state, the agent forwards only the latest user message (not the full echoed history). This prevents `Runner.run` from sending duplicates of session items + input items to the LLM.
+3. **Always-rotate.** The rotated session was the one populated during attempt 2's run. Subsequent turns land on it. The original poisoned session is never read.
+
+### CUJ 4: Multi-pod stale-claim contention
+
+Two pods each see the same response in `in_progress` with a stale heartbeat. Only one wins the CAS.
+
+```mermaid
+sequenceDiagram
+ autonumber
+ participant B as Pod B
+ participant DB as Lakebase
+ participant C as Pod C
+
+ Note over B,C: both pods see response in_progress, heartbeat stale
+ par
+ B->>DB: UPDATE responses SET attempt=N+1, heartbeat_at=now()
WHERE response_id=R AND attempt=N AND heartbeat stale
+ and
+ C->>DB: UPDATE responses SET attempt=N+1, heartbeat_at=now()
WHERE response_id=R AND attempt=N AND heartbeat stale
+ end
+ Note over DB: only one row matches; the other UPDATE returns 0 rows
+ DB-->>B: RETURNING attempt_number=N+1
+ DB-->>C: RETURNING (no row)
+ Note over B: B wins, builds resume input + spawns handler
+ Note over C: C aborts cleanly, returns to its retrieve loop
+```
+
+The `claim_stale_response` function (`repository.py`) executes a single `UPDATE … WHERE attempt_number = :current AND ((heartbeat_at IS NULL) OR (heartbeat_at < now() - interval))` with `RETURNING`. Postgres serializes the writes; only the pod whose `current` value was unmodified at commit time gets the `RETURNING` row.
+
+## 3. Architecture
+
+### 3.1 Storage layout
+
+Two tables in the `agent_server` schema:
+
+```mermaid
+erDiagram
+ responses {
+ text response_id PK
+ text status "in_progress / completed / failed"
+ timestamptz created_at
+ timestamptz heartbeat_at
+ int attempt_number "CAS guard for claim atomicity"
+ text original_request "JSON of initial POST (full input)"
+ text trace_id
+ }
+ messages {
+ text response_id FK
+ int sequence_number
+ int attempt_number
+ text item "JSON of output item"
+ text stream_event "JSON of SSE frame"
+ }
+ responses ||--o{ messages : "has"
+```
+
+- `responses.attempt_number` is the CAS guard for claim atomicity. **There is no `owner_pod_id` column** — ownership is implicit. The pod that last successfully heartbeats at the current `attempt_number` is the de facto owner. A heartbeat write at attempt N stops working the moment another pod has CAS-bumped the row to N+1, so the prior owner detects it has lost the claim on its next heartbeat (rowcount=0) and shuts down its heartbeat task.
+- `responses.original_request` stores the **full untrimmed input** so the resume path can recover the entire prior-turn history when the rotated SDK session starts empty.
+- `messages.attempt_number` tags every event so retrieval can filter to the latest attempt's output (avoiding partial output from a crashed attempt leaking into the final response body).
+- Schema migrations are idempotent (`ADD COLUMN IF NOT EXISTS`) so an existing deployment upgrades without downtime.
+
+### 3.2 The four key flows
+
+```mermaid
+flowchart TD
+ POST["POST /responses
background=true"] --> CR[create_response
store FULL original_request]
+ CR --> SPAWN[spawn @stream handler
handler dedupes UI echo against SDK session]
+ SPAWN --> HB[heartbeat loop
every 3s]
+ SPAWN --> EMIT[append SSE events
to messages table]
+ EMIT --> DONE{handler
exits?}
+ DONE -- yes --> COMPLETE[update status=completed]
+ DONE -- crash --> STALE[heartbeat stops
row goes stale]
+
+ GET["GET /responses/{id}
?stream=true&starting_after=N"] --> CHECK[check heartbeat age]
+ CHECK -->|fresh or terminal| READ[read events from messages
where seq > N]
+ CHECK -->|stale| CLAIM[CAS claim
attempt += 1]
+ CLAIM -->|won| BUILD[build prose recovery message
rotate conv_id
append rotated id to response.resumed]
+ CLAIM -->|lost| READ
+ BUILD --> SPAWN
+ READ --> STREAM[SSE stream to client]
+
+ KILL["POST /_debug/kill_task/{id}
(test-only)"] --> CANCEL[cancel asyncio task
without status update]
+ CANCEL --> STALE
+```
+
+### 3.3 Resume input construction
+
+When a stale-claim CAS succeeds, the new owner builds the resume input by serializing the prior attempt's stream events as JSON in a single user message:
+
+```mermaid
+flowchart LR
+ PRIOR[prior attempt's events
from messages table] --> FILTER[filter events by
attempt_number = prior_attempt]
+ FILTER --> JSON[json.dumps the events list]
+ JSON --> COMPOSE[compose single user message:
'[RECOVERY] previous attempt crashed.
Below is the raw stream-event log...
Inspect the events, figure out what is
already done versus in-progress, and continue.
Events: <json.dumps>']
+
+ ROT[_rotate_conversation_id
::attempt-N suffix] --> SUBMIT
+ COMPOSE --> SUBMIT[append to original_request.input
spawn handler with rotated request
emit response.resumed sentinel
with rotated conversation_id]
+```
+
+Why JSON-dumped events: the LLM reads them as the authoritative record of what attempt 1 did and decides what to do — re-run an interrupted tool, skip completed ones, summarize from there. No structural carry-forward, no synthetic tool events, no per-SDK pairing rules. The handler doesn't have to know any of this is durable resume — it just sees a recovery user message in `request.input`.
+
+Why rotation: the original SDK session may carry mid-turn state from the crashed attempt (orphan `tool_use`, partial checkpoint) that's hard to repair from outside the SDK. Rotating to `{base}::attempt-N` opens a fresh, empty session for the resumed attempt; the recovery message is the single source of truth for what already happened.
+
+Why the sentinel carries the rotated conv_id: cooperating chat clients capture it (via SSE) and use the rotated session for subsequent turns, so the original orphan-poisoned session is never read again.
+
+### 3.4 Per-template UI-echo dedup (NOT in the bridge)
+
+The bridge does **not** trim UI echo from `request.input`. Echo dedup is the agent handler's responsibility — it owns its SDK session/checkpointer and is the right layer to know what's already persisted vs what's a new turn.
+
+The canonical shape, per template:
+
+**OpenAI Agents SDK** (`agent-openai-advanced/agent_server/utils.py`):
+```python
+session_items = await session.get_items()
+if session_items and len(messages) > 1:
+ return [messages[-1]]
+return messages
+```
+
+**LangGraph** (`agent-langgraph-advanced/agent_server/agent.py`):
+```python
+state = await agent.aget_state(config)
+if state and state.values.get("messages") and input_state["messages"]:
+ last_user = next(
+ (m for m in reversed(input_state["messages"]) if m.get("role") == "user"),
+ None,
+ )
+ input_state["messages"] = [last_user] if last_user else []
+```
+
+Both: when the SDK store already has prior turns, forward only the latest user message and let the SDK prepend its own history. Without dedup, `Runner.run` (OpenAI) or `add_messages` (LangGraph) end up combining session+input → duplicate items → malformed assistant.tool_calls block → 400.
+
+Bridge's role here is just to pass `request.input` through untouched.
+
+### 3.5 Proactive stale-scan loop
+
+In addition to the lazy-on-GET claim path (a stale heartbeat is detected when a client GETs `/responses/{id}`), each pod runs a background **scanner** that periodically queries for in-progress responses with stale heartbeats and tries to claim+resume them. This means crashed responses get recovered even when no client is actively polling.
+
+```mermaid
+flowchart LR
+ A[every ~30s ± 50% jitter] --> B[SELECT response_id FROM responses
WHERE status='in_progress'
AND heartbeat_at < now-threshold
LIMIT 50]
+ B --> C{any rows?}
+ C -->|no| A
+ C -->|yes| D[for each id: get_response
+ _try_claim_and_resume]
+ D --> A
+```
+
+Each pod jitters its scan interval (`stale_scan_jitter_fraction = 0.5` by default) so multiple pods don't synchronize their queries. CAS-claim semantics ensure only one pod succeeds in claiming any given stale response.
+
+The scanner is a background task spawned in the FastAPI lifespan (alongside `init_db`) and cancelled on shutdown.
+
+### 3.6 Heartbeat and stale threshold
+
+Defaults are tuned for a single Lakebase deployment with low-latency writes.
+
+| Setting | Default | Rationale |
+|---|---|---|
+| `heartbeat_interval_seconds` | 3.0 | Frequent enough that short pauses (GC, tokio task waits) don't trip stale detection |
+| `heartbeat_stale_threshold_seconds` | 10.0 | Three missed heartbeats = unambiguously dead. Validated stale > interval at startup. |
+| `task_timeout_seconds` | 3600 | Hard ceiling. After this, a stuck `in_progress` row is force-failed regardless of heartbeat. |
+| `poll_interval_seconds` | 1.0 | Stream-retrieve polls the messages table at this rate while waiting for new events. |
+
+The stale threshold also applies as a grace period for newly-created responses that haven't written their first heartbeat yet — protects against an over-eager retrieve hijacking a still-starting handler.
+
+### 3.7 Claim atomicity
+
+```mermaid
+sequenceDiagram
+ autonumber
+ participant Pod as Pod B (claimer)
+ participant DB as Lakebase
+
+ Pod->>DB: SELECT attempt_number, heartbeat_at, status
FROM responses WHERE response_id=R
+ DB-->>Pod: attempt=1, heartbeat_at=12s ago, status=in_progress
+ Note over Pod: stale → attempt CAS
+ Pod->>DB: UPDATE responses
SET attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number
+ alt match
+ DB-->>Pod: attempt_number=2
+ Note over Pod: claim succeeded — Pod B is the new de facto owner
(prior owner's heartbeat at attempt=1 will fail next tick)
+ else no match (another pod beat us)
+ DB-->>Pod: (empty)
+ Note over Pod: claim lost — back off
+ end
+```
+
+Postgres row locking ensures only one of N concurrent UPDATEs matches the `attempt_number = N` predicate, so at most one pod ends up owning a given resume. The bumped `attempt_number` simultaneously revokes the prior owner's heartbeat: their next heartbeat write `WHERE attempt_number = N` returns rowcount=0, telling them they've lost the claim.
+
+## 4. Author-side requirements
+
+### 4.1 What's invisible to authors
+
+| Concern | Where it lives | Author-visible? |
+|---|---|---|
+| Heartbeat + claim | `LongRunningAgentServer` | No |
+| Conversation_id rotation | `LongRunningAgentServer._rotate_conversation_id` | No |
+| Prose recovery message construction | `LongRunningAgentServer._build_prose_recovery_message` | No |
+| UI-echo dedup | per-template handler (see §3.4) | Yes — 1-2 lines in `agent.py` / `utils.py` |
+| Stream resume cursor | `LongRunningAgentServer._stream_retrieve` | No |
+| Tool/SDK selection | `agent.py` | Yes (this is the author's actual code) |
+
+The author's `agent.py` is unchanged from a non-durable agent. They construct an `AsyncCheckpointSaver` (LangGraph) or `AsyncDatabricksSession` (OpenAI) and use it normally. Durability fires entirely above the SDK boundary — the SDK adapters themselves contain zero durability code.
+
+### 4.2 Author-visible client cooperation (chat-UI side)
+
+For the always-rotate flow to work cross-turn, a cooperating chat UI needs to:
+
+1. **Capture the rotated `conversation_id` from the SSE `response.resumed` event** when one is emitted during a streaming retrieve.
+2. **Use the rotated value as `context.conversation_id` on subsequent requests** for the same chat.
+
+The Express proxy in `e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts` does this with an in-memory `Map`. A multi-pod chatbot deployment would persist this on the chat row.
+
+Without client cooperation: the next turn lands on the original (orphan-poisoned) session and the LLM call fails on the provider's `tool_use ↔ tool_result` pairing rule. The bridge does not silently repair this — the cross-turn property requires the alias.
+
+### 4.3 Settings worth exposing to authors
+
+- `db_instance_name` / `db_autoscaling_endpoint` / `db_project` + `db_branch` — Lakebase connection config.
+- `heartbeat_interval_seconds` / `heartbeat_stale_threshold_seconds` — for tuning under heavy load.
+- `task_timeout_seconds` — per-attempt ceiling.
+- `stale_scan_interval_seconds` / `stale_scan_jitter_fraction` — controls how often (and with how much randomness) each pod scans the DB for stale responses to claim. Defaults to 30s with ±50% jitter.
+
+Everything else is internal.
+
+### 4.4 Test-only debug endpoint: `/_debug/kill_task/{response_id}`
+
+Cancels the in-flight asyncio task that owns the given response on this pod **without** running the `_task_scope` cleanup. The DB row stays `in_progress` with a heartbeat that's about to go stale — exactly the shape a real pod crash leaves. Used in integration tests to simulate a pod crash without restarting the container.
+
+Opt-in via env var: only registered when `LONG_RUNNING_ENABLE_DEBUG_KILL=1`. Never exposed in production.
+
+Returns 404 if no in-flight task for that response exists on this specific pod (the task may already have finished, or it may be running on a different pod).
+
+```bash
+curl -sS -X POST -H "Authorization: Bearer $TOKEN" "$APP_URL/_debug/kill_task/$RID"
+```
+
+## 5. Future direction: TaskFlow
+
+[TaskFlow](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/tree/experimental/taskflow) is a Rust-core durable-task engine being built in `experimental/taskflow`. It provides exactly the primitives `LongRunningAgentServer` hand-rolls today (heartbeat, CAS claim, recovery worker, event log with stream resume) — but as a library with WAL-first durability and proactive (not lazy-on-GET) recovery.
+
+When TaskFlow is production-ready, `LongRunningAgentServer` is expected to keep its **HTTP surface and author-visible API unchanged**, swapping only the engine internals.
+
+### Mapping today → TaskFlow
+
+```mermaid
+flowchart LR
+ subgraph TODAY[LongRunningAgentServer today]
+ T1[create_response + asyncio.create_task]
+ T2[_heartbeat async CM]
+ T3[_try_claim_and_resume CAS]
+ T4[_build_prose_recovery_message]
+ T5[/responses/{id}?stream=true]
+ T6[/_debug/kill_task]
+ end
+
+ subgraph TF[TaskFlow]
+ F1[Taskflow.start name input user_id]
+ F2[built-in executor heartbeat]
+ F3[built-in recovery worker + claim_for_recovery]
+ F4[TaskHandler.recover ctx previous_events]
+ F5[Taskflow.subscribe key last_seq]
+ F6[Taskflow.simulate_crash key]
+ end
+
+ T1 --> F1
+ T2 --> F2
+ T3 --> F3
+ T4 --> F4
+ T5 --> F5
+ T6 --> F6
+```
+
+### What stays in `LongRunningAgentServer` after the swap
+
+- `POST /responses` / `GET /responses/{id}` HTTP routes (and their schemas).
+- The MLflow `@invoke()` / `@stream()` handler convention.
+- `_build_prose_recovery_message` — recovery-message construction is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`.
+- `_rotate_conversation_id` — same reason.
+- Author-visible settings (db config, heartbeat tuning, task timeout, stale-scan tuning).
+
+### What gets deleted
+
+- The hand-rolled heartbeat task and CAS claim CTE — replaced by TaskFlow's executor heartbeat + `claim_for_recovery`.
+- The `_try_claim_and_resume` lazy-claim path — TaskFlow's recovery worker handles this proactively and across pods, fixing the documented "claim only fires on GET" limitation.
+- The `agent_server.responses` + `agent_server.messages` schema — TaskFlow owns its storage layer.
+- The stream-cursor logic in `_stream_retrieve` — `Taskflow.subscribe(key, last_seq)` is the cursor-based stream resume.
+
+### What requires a small TaskFlow API addition
+
+TaskFlow derives idempotency keys as `SHA256(name + canonical_input + user_id)`. The HTTP surface uses a server-generated `resp_{uuid}`. We've requested an `idempotency_key: Option` parameter on `Taskflow.submit()` so we can keep the existing HTTP contract while submitting to TaskFlow. See [`engine.rs:317`](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/blob/experimental/taskflow/engine/src/engine.rs?L317) for the current `generate_key` call site; the override would slot in there.
+
+### Migration sequencing
+
+1. Add `LongRunningAgentServer(backend="taskflow"|"builtin")` knob, default `"builtin"`. HTTP surface unchanged on either backend.
+2. Port `agent-non-conversational` (the simplest template) to `backend="taskflow"`. Run the full crash-resume matrix.
+3. Port the advanced templates. **Zero changes to `agent.py`** — the swap is the constructor argument.
+4. Flip default to `"taskflow"`. Deprecate `"builtin"`.
+5. Delete the heartbeat / claim / repository code. Big delete PR.
+
+The point of the `LongRunningAgentServer` abstraction is exactly this kind of swap: callers should never have to care which engine is underneath.
+
+---
+
+## Quick reference
+
+- **Code:** `src/databricks_ai_bridge/long_running/`
+- **Tests:** `tests/databricks_ai_bridge/test_long_running_server.py`, `test_long_running_db.py`
+- **Settings:** `LongRunningSettings` in `settings.py`
+- **Models:** `Response` and `Message` in `models.py`
+- **HTTP routes:** registered in `LongRunningAgentServer._setup_routes`
+- **Prose recovery:** `_build_prose_recovery_message` in `server.py`
+- **Conversation rotation:** `_rotate_conversation_id` in `server.py`
+- **Stale scanner:** `LongRunningAgentServer._stale_response_scanner_loop` in `server.py`
+- **UI-echo dedup:** in agent code, see `app-templates/agent-{openai,langgraph}-advanced/`
diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py
index 903d466f..69ca8a79 100644
--- a/src/databricks_ai_bridge/long_running/db.py
+++ b/src/databricks_ai_bridge/long_running/db.py
@@ -79,7 +79,50 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy):
await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}"))
await conn.run_sync(Base.metadata.create_all)
+ # Idempotent migration for tables created by earlier versions: add any
+ # columns introduced for durable-resume support. Each statement runs in
+ # its own transaction so an InsufficientPrivilege on one ALTER (another
+ # pod's SP owns the table but the schema is already migrated) doesn't
+ # poison the rest. A single mega-transaction would abort entirely on the
+ # first owner-check failure even with IF NOT EXISTS.
+ migration_stmts = (
+ f"ALTER TABLE {AGENT_DB_SCHEMA}.responses "
+ "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ",
+ f"ALTER TABLE {AGENT_DB_SCHEMA}.responses "
+ "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1",
+ f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS original_request TEXT",
+ f"ALTER TABLE {AGENT_DB_SCHEMA}.messages "
+ "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1",
+ f"CREATE INDEX IF NOT EXISTS idx_responses_stale "
+ f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) "
+ "WHERE status = 'in_progress'",
+ )
+ skipped_migrations: list[str] = []
+ for stmt in migration_stmts:
+ try:
+ async with _engine.begin() as conn:
+ await conn.execute(text(stmt))
+ except Exception as exc:
+ msg = str(exc).lower()
+ if "insufficientprivilege" in msg or "must be owner" in msg:
+ skipped_migrations.append(stmt.split("\n")[0])
+ continue
+ raise
+
_initialized = True
+ if skipped_migrations:
+ # WARN-level summary: if the DB was previously migrated by another SP
+ # this is fine, but if it's genuinely a new table and our SP lacks
+ # ALTER, claim/heartbeat queries will fail later with a confusing
+ # "column does not exist" — surface it clearly at startup.
+ logger.warning(
+ "[DB] Skipped %d durability migration(s) due to insufficient "
+ "privilege — assuming table was already migrated by another "
+ "service principal. Crash-resume will fail with 'column does "
+ "not exist' if this assumption is wrong. Skipped: %s",
+ len(skipped_migrations),
+ ", ".join(skipped_migrations),
+ )
logger.info("[DB] Engine and schema ready")
diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py
index 1d876dc7..dfdcca70 100644
--- a/src/databricks_ai_bridge/long_running/models.py
+++ b/src/databricks_ai_bridge/long_running/models.py
@@ -14,7 +14,14 @@ class Base(DeclarativeBase):
class Response(Base):
- """Response status tracking for background agent tasks."""
+ """Response status tracking for background agent tasks.
+
+ Durability columns (``heartbeat_at``, ``attempt_number``,
+ ``original_request``) support crash-resume: another pod atomically
+ claims a stale in-progress row by CAS-ing on ``attempt_number`` and
+ replays the agent loop. The owning pod is implicit — it's whatever
+ pod last successfully heartbeat at the current attempt_number.
+ """
__tablename__ = "responses"
__table_args__ = {"schema": AGENT_DB_SCHEMA}
@@ -25,12 +32,22 @@ class Response(Base):
DateTime(timezone=True), nullable=False, server_default=func.now()
)
trace_id: Mapped[str | None] = mapped_column(Text, nullable=True)
+ heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
+ attempt_number: Mapped[int] = mapped_column(
+ Integer, nullable=False, server_default="1", default=1
+ )
+ original_request: Mapped[str | None] = mapped_column(Text, nullable=True)
messages = relationship("Message", back_populates="response", cascade="all, delete-orphan")
class Message(Base):
- """Stream events and output items for a response."""
+ """Stream events and output items for a response.
+
+ ``attempt_number`` tags events by which run attempt emitted them so that
+ resumed runs append to the same event log without overwriting earlier
+ (abandoned) attempts, and retrieve can filter to the latest attempt only.
+ """
__tablename__ = "messages"
__table_args__ = {"schema": AGENT_DB_SCHEMA}
@@ -44,6 +61,9 @@ class Message(Base):
sequence_number: Mapped[int] = mapped_column(
Integer, primary_key=True, nullable=False, default=0
)
+ attempt_number: Mapped[int] = mapped_column(
+ Integer, nullable=False, server_default="1", default=1
+ )
item: Mapped[str | None] = mapped_column(Text, nullable=True)
stream_event: Mapped[str | None] = mapped_column(Text, nullable=True)
diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py
index fcd86b29..2c6a1ff0 100644
--- a/src/databricks_ai_bridge/long_running/repository.py
+++ b/src/databricks_ai_bridge/long_running/repository.py
@@ -5,30 +5,64 @@
from typing import Any, NamedTuple
from sqlalchemy import select, update
+from sqlalchemy.sql import bindparam, text
from databricks_ai_bridge.long_running.db import session_scope
-from databricks_ai_bridge.long_running.models import Message, Response
+from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response
-async def create_response(response_id: str, status: str) -> None:
- """Insert a new response."""
+async def create_response(
+ response_id: str,
+ status: str,
+ *,
+ durable: bool = False,
+ original_request: dict[str, Any] | None = None,
+) -> None:
+ """Insert a new response row.
+
+ When ``durable=True``, ``heartbeat_at`` is initialized to ``now()`` so
+ the row doesn't immediately look stale. Non-durable callers (tests,
+ legacy flows) skip the heartbeat init.
+ """
async with session_scope() as session:
- session.add(Response(response_id=response_id, status=status))
+ session.add(
+ Response(
+ response_id=response_id,
+ status=status,
+ heartbeat_at=datetime.now().astimezone() if durable else None,
+ original_request=(
+ json.dumps(original_request) if original_request is not None else None
+ ),
+ )
+ )
await session.commit()
async def update_response_status(
- response_id: str, status: str, *, expected_current_status: str | None = None
+ response_id: str,
+ status: str,
+ *,
+ expected_current_status: str | None = None,
+ expected_attempt_number: int | None = None,
) -> bool:
"""Update response status. Returns True if a row was updated.
If *expected_current_status* is given the update only takes effect when the
row's current status matches, avoiding concurrent-update races.
+
+ If *expected_attempt_number* is given the update only takes effect when the
+ row's current ``attempt_number`` matches, ensuring only the pod that owns
+ the current attempt can transition the row to a terminal state. This
+ prevents a stale background task (e.g. a deferred-fail timer that fired
+ after another pod claimed the row for resume) from clobbering the new
+ owner's in-progress state.
"""
async with session_scope() as session:
stmt = update(Response).where(Response.response_id == response_id)
if expected_current_status is not None:
stmt = stmt.where(Response.status == expected_current_status)
+ if expected_attempt_number is not None:
+ stmt = stmt.where(Response.attempt_number == expected_attempt_number)
stmt = stmt.values(status=status)
result = await session.execute(stmt)
await session.commit()
@@ -43,18 +77,120 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None:
await session.commit()
+async def heartbeat_response(response_id: str, expected_attempt_number: int) -> bool:
+ """Update heartbeat_at for a response IFF the attempt is still ours.
+
+ Returns True on success. A False result means the claim has been lost —
+ another pod CAS-bumped ``attempt_number``, so this pod is no longer the
+ owner and the heartbeat task should stop. Implicit-ownership model:
+ whichever pod last successfully heartbeats at the current
+ ``attempt_number`` is the de facto owner.
+ """
+ async with session_scope() as session:
+ stmt = (
+ update(Response)
+ .where(
+ Response.response_id == response_id,
+ Response.attempt_number == expected_attempt_number,
+ )
+ .values(heartbeat_at=datetime.now().astimezone())
+ )
+ result = await session.execute(stmt)
+ await session.commit()
+ return result.rowcount > 0
+
+
+async def claim_stale_response(
+ response_id: str,
+ stale_threshold_seconds: float,
+) -> int | None:
+ """Atomically claim an in-progress response whose heartbeat has gone stale.
+
+ Uses a single conditional UPDATE so exactly one caller wins on contention:
+ claim only succeeds if status is ``in_progress`` AND
+ (``heartbeat_at IS NULL`` OR ``heartbeat_at`` is older than the threshold).
+ The new attempt_number is the previous + 1; the prior attempt's heartbeat
+ task will detect this on its next heartbeat (rowcount=0) and stop.
+
+ Returns the new ``attempt_number`` on success, or ``None`` if the row did
+ not satisfy the claim conditions (already completed, heartbeat still fresh,
+ or nonexistent).
+ """
+ # Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for
+ # the incremented column as ergonomically. Using a single statement keeps the
+ # claim atomic without an explicit transaction-level lock.
+ stmt = text(
+ f"""
+ UPDATE {AGENT_DB_SCHEMA}.responses
+ SET heartbeat_at = now(),
+ attempt_number = attempt_number + 1
+ WHERE response_id = :rid
+ AND status = 'in_progress'
+ AND (heartbeat_at IS NULL
+ OR heartbeat_at < now() - make_interval(secs => :threshold))
+ RETURNING attempt_number
+ """
+ ).bindparams(
+ bindparam("rid", type_=None),
+ bindparam("threshold", type_=None),
+ )
+ async with session_scope() as session:
+ result = await session.execute(
+ stmt,
+ {"rid": response_id, "threshold": stale_threshold_seconds},
+ )
+ row = result.first()
+ await session.commit()
+ return int(row[0]) if row else None
+
+
+async def find_stale_response_ids(
+ stale_threshold_seconds: float,
+ limit: int = 50,
+) -> list[str]:
+ """Return ids of in_progress responses whose heartbeat is older than the
+ threshold. Used by the proactive scanner to find candidates for resume
+ without waiting for a client GET.
+
+ Limited to ``limit`` rows per scan to bound DB load. Ordered by
+ ``heartbeat_at`` ascending so the oldest staleness is handled first.
+ """
+ stmt = text(
+ f"""
+ SELECT response_id FROM {AGENT_DB_SCHEMA}.responses
+ WHERE status = 'in_progress'
+ AND heartbeat_at IS NOT NULL
+ AND heartbeat_at < now() - make_interval(secs => :threshold)
+ ORDER BY heartbeat_at ASC
+ LIMIT :limit
+ """
+ ).bindparams(
+ bindparam("threshold", type_=None),
+ bindparam("limit", type_=None),
+ )
+ async with session_scope() as session:
+ result = await session.execute(
+ stmt,
+ {"threshold": stale_threshold_seconds, "limit": limit},
+ )
+ return [row[0] for row in result.all()]
+
+
async def append_message(
response_id: str,
sequence_number: int,
item: str | None = None,
stream_event: dict[str, Any] | None = None,
+ *,
+ attempt_number: int = 1,
) -> None:
- """Append a message (stream event) for a response."""
+ """Append a message (stream event) for a response, tagged with attempt_number."""
async with session_scope() as session:
session.add(
Message(
response_id=response_id,
sequence_number=sequence_number,
+ attempt_number=attempt_number,
item=item,
stream_event=json.dumps(stream_event) if stream_event is not None else None,
)
@@ -65,22 +201,26 @@ async def append_message(
async def get_messages(
response_id: str,
after_sequence: int | None = None,
-) -> list[tuple[int, str | None, dict[str, Any] | None]]:
- """Fetch messages for a response, optionally after a sequence number.
+ *,
+ attempt_number: int | None = None,
+) -> list[tuple[int, str | None, dict[str, Any] | None, int]]:
+ """Fetch messages for a response, optionally filtering by sequence / attempt.
- Returns list of (sequence_number, item, stream_event_dict).
+ Returns list of ``(sequence_number, item, stream_event_dict, attempt_number)``.
"""
async with session_scope() as session:
stmt = select(Message).where(Message.response_id == response_id)
if after_sequence is not None:
stmt = stmt.where(Message.sequence_number > after_sequence)
+ if attempt_number is not None:
+ stmt = stmt.where(Message.attempt_number == attempt_number)
stmt = stmt.order_by(Message.sequence_number)
result = await session.execute(stmt)
rows = result.scalars().all()
out = []
for r in rows:
evt = json.loads(r.stream_event) if r.stream_event else None
- out.append((r.sequence_number, r.item, evt))
+ out.append((r.sequence_number, r.item, evt, r.attempt_number))
return out
@@ -89,6 +229,9 @@ class ResponseInfo(NamedTuple):
status: str
created_at: datetime
trace_id: str | None
+ heartbeat_at: datetime | None
+ attempt_number: int
+ original_request: dict[str, Any] | None
async def get_response(response_id: str) -> ResponseInfo | None:
@@ -97,5 +240,13 @@ async def get_response(response_id: str) -> ResponseInfo | None:
result = await session.execute(select(Response).where(Response.response_id == response_id))
row = result.scalar_one_or_none()
if row:
- return ResponseInfo(row.response_id, row.status, row.created_at, row.trace_id)
+ return ResponseInfo(
+ row.response_id,
+ row.status,
+ row.created_at,
+ row.trace_id,
+ row.heartbeat_at,
+ row.attempt_number,
+ json.loads(row.original_request) if row.original_request else None,
+ )
return None
diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py
index b3374d67..40007ef4 100644
--- a/src/databricks_ai_bridge/long_running/server.py
+++ b/src/databricks_ai_bridge/long_running/server.py
@@ -6,9 +6,13 @@
raise RuntimeError("The long_running module requires Python 3.11 or later.")
import asyncio
+import copy
import inspect
import json
import logging
+import os
+import random
+import socket
import time
import uuid
from collections.abc import AsyncGenerator
@@ -34,9 +38,12 @@
from databricks_ai_bridge.long_running.db import dispose_db, init_db, is_db_configured
from databricks_ai_bridge.long_running.repository import (
append_message,
+ claim_stale_response,
create_response,
+ find_stale_response_ids,
get_messages,
get_response,
+ heartbeat_response,
update_response_status,
update_response_trace_id,
)
@@ -47,15 +54,29 @@
BACKGROUND_KEY = "background"
+# Process-local identifier for log lines. Not stored in the DB — heartbeat
+# ownership is implicit via attempt_number CAS.
+_POD_LOG_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}"
+
async def _deferred_mark_failed(
- response_id: str, delay: float = 2.0, reason: str = "Task timed out"
+ response_id: str,
+ delay: float = 2.0,
+ reason: str = "Task timed out",
+ *,
+ owning_attempt_number: int | None = None,
) -> None:
"""Mark a response as failed after a short delay.
Runs as an independent asyncio task so the caller (``_task_scope``) can
- return immediately. The delay lets the connection pool stabilise after
- a cancellation before we attempt new DB writes.
+ return immediately. The delay lets the connection pool stabilise after a
+ cancellation before we attempt new DB writes.
+
+ ``owning_attempt_number`` should be the attempt this pod was running when
+ the failure was scheduled. The terminal status update is CAS-checked
+ against it: if another pod has already claimed the row for a higher
+ attempt by the time this fires, we skip the failed-status write so we
+ don't clobber the new owner's state.
"""
try:
await asyncio.sleep(delay)
@@ -65,7 +86,17 @@ async def _deferred_mark_failed(
# or SELECT FOR UPDATE on the response row to serialise writers.
async with asyncio.timeout(delay):
existing = await get_messages(response_id, after_sequence=None)
- next_seq = max((seq for seq, _, _ in existing), default=-1) + 1
+ next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1
+ current_attempt = await _current_attempt(response_id)
+ if owning_attempt_number is not None and current_attempt != owning_attempt_number:
+ logger.info(
+ "Skipping deferred fail for %s: ownership changed "
+ "(was attempt=%d, now attempt=%d)",
+ response_id,
+ owning_attempt_number,
+ current_attempt,
+ )
+ return
error_event = {
"type": "error",
@@ -75,8 +106,18 @@ async def _deferred_mark_failed(
"code": "task_timeout",
},
}
- await append_message(response_id, next_seq, item=None, stream_event=error_event)
- await update_response_status(response_id, "failed")
+ await append_message(
+ response_id,
+ next_seq,
+ item=None,
+ stream_event=error_event,
+ attempt_number=current_attempt,
+ )
+ await update_response_status(
+ response_id,
+ "failed",
+ expected_attempt_number=owning_attempt_number,
+ )
logger.info("Marked %s as failed (reason: %s)", response_id, reason)
except TimeoutError:
@@ -91,10 +132,21 @@ async def _deferred_mark_failed(
)
+async def _current_attempt(response_id: str) -> int:
+ """Fetch the current attempt_number for a response, defaulting to 1."""
+ resp = await get_response(response_id)
+ return resp.attempt_number if resp else 1
+
+
def _sse_event(event_type: str, data: dict[str, Any] | str) -> str:
- """Format an SSE event per Open Responses spec."""
+ """Emit ``data:``-only SSE frames. Match the non-durable stream format
+ so downstream SSE parsers dispatch on the payload's ``type`` field
+ rather than a leading ``event:`` name line. Claude's multi-response
+ stream (one response.created/completed pair per tool iteration) plus
+ the event-name prefix confuses the AI SDK's Databricks provider into
+ a retry loop."""
payload = data if isinstance(data, str) else json.dumps(data)
- return f"event: {event_type}\ndata: {payload}\n\n"
+ return f"data: {payload}\n\n"
def _age_seconds(created_at: datetime) -> float:
@@ -105,9 +157,99 @@ def _age_seconds(created_at: datetime) -> float:
return (now - created_at).total_seconds()
+def _build_prose_recovery_message(
+ messages: list[tuple], prior_attempt_number: int
+) -> dict[str, Any]:
+ """Build a single user message containing the prior attempt's raw event
+ log + a directive that asks the LLM to figure out what already completed
+ and continue.
+
+ The body is `json.dumps(events)` of the prior attempt's stream events
+ wrapped in a recovery prompt. SDK-agnostic — no provider-specific pairing
+ rules, no structured carry-forward, no synthetic events. The model reads
+ the JSON, decides which tool calls succeeded, which were interrupted, and
+ continues.
+ """
+ prior_events = [
+ evt
+ for _seq, _item_json, evt, attempt_tag in messages
+ if attempt_tag == prior_attempt_number and isinstance(evt, dict)
+ ]
+ body = (
+ "[RECOVERY] The previous attempt of this agent task crashed "
+ "mid-execution. Below is the raw stream-event log from that attempt "
+ "as JSON. Some tool calls may have completed and some may have been "
+ "interrupted before returning a result. Inspect the events, figure "
+ "out what is already done versus in-progress / not completed, and "
+ "continue the task from where it left off. If a tool call was "
+ "interrupted, you may re-invoke it if its result is still needed.\n\n"
+ f"Events:\n{json.dumps(prior_events)}"
+ )
+ return {
+ "type": "message",
+ "role": "user",
+ "content": body,
+ }
+
+
+def _rotate_conversation_id(
+ request_dict: dict[str, Any],
+ new_attempt_number: int,
+ response_id: str,
+) -> dict[str, Any]:
+ """Rotate the conversation anchor to a per-attempt value.
+
+ After a crash, attempt N+1 should see a FRESH checkpointer / session so it
+ doesn't inherit mid-turn state that the SDK can't repair cleanly (most
+ notably the LangGraph stream-event attempt-boundary orphan artifact).
+ The handler's priority chain is:
+
+ 1. custom_inputs.thread_id / session_id (explicit, wins)
+ 2. context.conversation_id (fallback)
+ 3. auto-generated (last resort)
+
+ We drop (1), pick the current base anchor, and write ``{base}::attempt-N``
+ into (2). The handler then resolves to a fresh key for this attempt while
+ still being deterministic across retries of the same attempt.
+
+ The LLM sees full turn history via ``original_request.input``, which was
+ captured at the initial POST — before any attempt ran, so it's clean by
+ construction.
+ """
+ custom_inputs = request_dict.get("custom_inputs")
+ if not isinstance(custom_inputs, dict):
+ custom_inputs = {}
+
+ base_anchor = (
+ custom_inputs.get("thread_id")
+ or custom_inputs.get("session_id")
+ or (request_dict.get("context") or {}).get("conversation_id")
+ or response_id
+ )
+
+ custom_inputs.pop("thread_id", None)
+ custom_inputs.pop("session_id", None)
+ request_dict["custom_inputs"] = custom_inputs
+
+ ctx = request_dict.get("context") or {}
+ ctx = dict(ctx)
+ rotated = f"{base_anchor}::attempt-{new_attempt_number}"
+ ctx["conversation_id"] = rotated
+ request_dict["context"] = ctx
+ logger.info(
+ "[durable] rotated conversation_id for resume response_id=%s attempt=%d base=%s rotated=%s",
+ response_id,
+ new_attempt_number,
+ base_anchor,
+ rotated,
+ )
+ return request_dict
+
+
@experimental
class LongRunningAgentServer(AgentServer):
- """AgentServer subclass adding background mode and retrieve endpoints.
+ """AgentServer subclass adding background mode, retrieve endpoints, and
+ durable resume.
Only compatible with ``ResponsesAgent`` mode.
@@ -125,6 +267,16 @@ class LongRunningAgentServer(AgentServer):
``LAKEBASE_INSTANCE_NAME``, ``LAKEBASE_AUTOSCALING_ENDPOINT``, or both
``LAKEBASE_AUTOSCALING_PROJECT`` and ``LAKEBASE_AUTOSCALING_BRANCH``.
+ Durable resume: when ``GET /responses/{id}`` sees an ``in_progress`` run
+ whose owning pod has stopped heartbeating for more than
+ ``heartbeat_stale_threshold_seconds``, the retrieving pod atomically claims
+ the run and re-invokes the registered handler with a rotated
+ ``conversation_id`` (so the agent SDK resolves to a fresh thread/session),
+ the original request's ``input`` enriched with the prior attempt's already
+ emitted tool calls / outputs / narrative, and an ``[INTERRUPTED]`` synthetic
+ output paired with any tool call that didn't finish. Completed work is
+ preserved; only the interrupted step re-runs.
+
Args:
enable_chat_proxy: Whether to enable the chat proxy endpoint.
db_instance_name: Lakebase provisioned instance name. Overrides
@@ -143,6 +295,12 @@ class LongRunningAgentServer(AgentServer):
Defaults to 5000 (5 seconds).
cleanup_timeout_seconds: Timeout for DB cleanup after task failure.
Defaults to 7.0.
+ heartbeat_interval_seconds: How often the owning pod writes
+ ``heartbeat_at`` while a run is in flight. Defaults to 3.0.
+ heartbeat_stale_threshold_seconds: Age at which a heartbeat is
+ considered stale and another pod may claim the run. Also used
+ as the grace window for a freshly-created run that hasn't
+ written its first heartbeat yet. Defaults to 10.0.
"""
_SUPPORTED_AGENT_TYPE = "ResponsesAgent"
@@ -162,6 +320,8 @@ def __init__(
poll_interval_seconds: float = 1.0,
db_statement_timeout_ms: int = 5000,
cleanup_timeout_seconds: float = 7.0,
+ heartbeat_interval_seconds: float = 3.0,
+ heartbeat_stale_threshold_seconds: float = 10.0,
):
if agent_type != self._SUPPORTED_AGENT_TYPE:
raise ValueError(
@@ -173,11 +333,18 @@ def __init__(
poll_interval_seconds=poll_interval_seconds,
db_statement_timeout_ms=db_statement_timeout_ms,
cleanup_timeout_seconds=cleanup_timeout_seconds,
+ heartbeat_interval_seconds=heartbeat_interval_seconds,
+ heartbeat_stale_threshold_seconds=heartbeat_stale_threshold_seconds,
)
self._db_instance_name = db_instance_name
self._db_autoscaling_endpoint = db_autoscaling_endpoint
self._db_project = db_project
self._db_branch = db_branch
+ # Track in-flight background tasks per response_id so the debug-kill
+ # endpoint can simulate a pod crash without tearing the whole pod
+ # down. Not load-bearing for correctness — durability still relies on
+ # DB state, this is just a test affordance.
+ self._running_tasks: dict[str, asyncio.Task] = {}
super().__init__(agent_type, enable_chat_proxy=enable_chat_proxy)
def _setup_routes(self) -> None:
@@ -195,6 +362,46 @@ async def cancel_endpoint(response_id: str):
detail="Cancellation is not yet implemented.",
)
+ # Debug endpoint for testing durable resume: cancels the in-flight
+ # asyncio task that owns the given response_id WITHOUT running the
+ # _task_scope cleanup, so the DB row stays in_progress with a
+ # going-stale heartbeat — exactly the shape a real pod crash leaves.
+ # Opt-in via env var so it's never exposed in production. Env var
+ # is checked at request time (not registration time) because some
+ # platforms inject env vars after the FastAPI app object is built.
+ @self.app.post("/_debug/kill_task/{response_id}")
+ async def _debug_kill_task(response_id: str):
+ if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") != "1":
+ raise HTTPException(
+ status_code=404,
+ detail="Debug kill endpoint is disabled.",
+ )
+ task = self._running_tasks.get(response_id)
+ if task is None:
+ logger.info(
+ "[durable] kill endpoint: no task response_id=%s on pod=%s",
+ response_id,
+ _POD_LOG_ID,
+ )
+ raise HTTPException(
+ status_code=404,
+ detail=(
+ "No in-flight task for that response_id on this pod "
+ "(may already have finished or be running on another pod)."
+ ),
+ )
+ logger.info(
+ "[durable] kill endpoint: cancelling task response_id=%s pod=%s",
+ response_id,
+ _POD_LOG_ID,
+ )
+ task.cancel()
+ return {
+ "response_id": response_id,
+ "pod_id": _POD_LOG_ID,
+ "status": "task_cancelled",
+ }
+
db_configured = is_db_configured()
@self.app.get("/responses/{response_id}")
@@ -234,8 +441,19 @@ async def _db_lifespan(app):
branch=self._db_branch,
db_statement_timeout_ms=self._settings.db_statement_timeout_ms,
)
- yield
- await dispose_db()
+ scanner_task = asyncio.create_task(
+ self._stale_response_scanner_loop(),
+ name="durable-stale-scanner",
+ )
+ try:
+ yield
+ finally:
+ scanner_task.cancel()
+ try:
+ await scanner_task
+ except (asyncio.CancelledError, Exception):
+ pass
+ await dispose_db()
self.app.router.lifespan_context = _db_lifespan
@@ -290,11 +508,32 @@ async def _handle_background_request(
) -> dict[str, Any] | StreamingResponse:
"""Start a new conversation and return response_id immediately."""
response_id = f"resp_{uuid.uuid4().hex[:24]}"
- await create_response(response_id, "in_progress")
+ # Anchor the conversation to response_id so any future replay from a
+ # different pod resolves to the same agent-SDK thread/session. We
+ # round-trip through dict + validator so the handler still receives a
+ # pydantic ResponsesAgentRequest (its declared arg type). The
+ # declared param type is ``dict`` but the runtime object is a pydantic
+ # model from ``validate_and_convert_request``; fall back to ``dict()``
+ # when tests pass a plain dict directly.
+ dump = getattr(request_data, "model_dump", None)
+ request_dict = dump() if callable(dump) else dict(request_data)
+ # Store the FULL request (untrimmed) as `original_request` so resume can
+ # recover the entire prior-turn history. Per-template handlers are
+ # responsible for deduping their own UI-echoed input against the SDK's
+ # session/checkpointer state — the bridge no longer trims input.
+ await create_response(
+ response_id,
+ "in_progress",
+ durable=True,
+ original_request=request_dict,
+ )
+ durable_request = self.validator.validate_and_convert_request(request_dict)
- logger.debug(
- "Background response created",
- extra={"response_id": response_id, "stream": is_streaming},
+ logger.info(
+ "Background response created response_id=%s stream=%s pod=%s",
+ response_id,
+ is_streaming,
+ _POD_LOG_ID,
)
response_obj: dict[str, Any] = {
@@ -309,26 +548,187 @@ async def _handle_background_request(
}
# Fire-and-forget is intentional — task status is persisted to the database.
+ # We still track the task handle so the debug-kill endpoint can simulate
+ # a crash (and so we know whether a claim target lives on this pod).
if is_streaming:
- asyncio.create_task(
- self._run_background_stream(response_id, request_data, return_trace_id)
+ task = asyncio.create_task(
+ self._run_background_stream(
+ response_id, durable_request, return_trace_id, attempt_number=1
+ )
)
+ self._track_task(response_id, task)
return await self._handle_retrieve_request(
response_id,
stream=True,
starting_after=0,
)
else:
- asyncio.create_task(
- self._run_background_invoke(response_id, request_data, return_trace_id)
+ task = asyncio.create_task(
+ self._run_background_invoke(
+ response_id, durable_request, return_trace_id, attempt_number=1
+ )
)
+ self._track_task(response_id, task)
return response_obj
+ def _track_task(self, response_id: str, task: asyncio.Task) -> None:
+ """Record a background task so the debug-kill endpoint can find it."""
+ self._running_tasks[response_id] = task
+ task.add_done_callback(lambda _t: self._running_tasks.pop(response_id, None))
+
+ async def _stale_response_scanner_loop(self) -> None:
+ """Periodically scan for in_progress responses with stale heartbeats and
+ try to claim+resume them. The proactive counterpart to the lazy claim
+ path on ``GET /responses/{id}``.
+
+ Each iteration sleeps for a jittered interval so multiple pods don't
+ synchronize their reads. Runs until cancelled (in the lifespan
+ teardown).
+ """
+ base = self._settings.stale_scan_interval_seconds
+ jitter = self._settings.stale_scan_jitter_fraction
+ threshold = self._settings.heartbeat_stale_threshold_seconds
+ logger.info(
+ "[durable] stale-scan loop start interval=%.1fs jitter=±%.0f%% threshold=%.1fs pod=%s",
+ base,
+ jitter * 100,
+ threshold,
+ _POD_LOG_ID,
+ )
+ try:
+ while True:
+ # Jittered sleep — random scaling of base interval centered on 1.0.
+ delay = base * (1.0 + random.uniform(-jitter, jitter))
+ await asyncio.sleep(delay)
+ try:
+ stale_ids = await find_stale_response_ids(threshold)
+ if not stale_ids:
+ continue
+ logger.info(
+ "[durable] stale-scan found %d candidate(s): %s",
+ len(stale_ids),
+ stale_ids,
+ )
+ for response_id in stale_ids:
+ try:
+ resp = await get_response(response_id)
+ if resp:
+ await self._try_claim_and_resume(response_id, resp)
+ except Exception:
+ logger.exception(
+ "[durable] stale-scan resume failed response_id=%s",
+ response_id,
+ )
+ except Exception:
+ # Don't let an iteration failure kill the loop.
+ logger.exception("[durable] stale-scan iteration failed")
+ except asyncio.CancelledError:
+ logger.info("[durable] stale-scan loop stopped pod=%s", _POD_LOG_ID)
+ raise
+
+ @asynccontextmanager
+ async def _heartbeat(self, response_id: str, attempt_number: int) -> AsyncGenerator[None, None]:
+ """Keep the response row's heartbeat_at fresh while the body runs.
+
+ A background task writes ``heartbeat_at = now()`` every
+ ``heartbeat_interval_seconds``, scoped to ``attempt_number``. The
+ update only matches if ``attempt_number`` still equals the value the
+ heartbeat was started with — if another pod has CAS-claimed the
+ row (bumping attempt_number), this heartbeat returns 0 rows and the
+ task knows it has lost ownership and stops.
+
+ Implicit-ownership model: there is no ``owner_pod_id`` column. The
+ last pod to successfully heartbeat at the current attempt is the
+ de facto owner.
+ """
+ interval = self._settings.heartbeat_interval_seconds
+ stop = asyncio.Event()
+
+ async def _beat():
+ beats = 0
+ logger.info(
+ "[durable] heartbeat start response_id=%s attempt=%d pod=%s interval=%.1fs",
+ response_id,
+ attempt_number,
+ _POD_LOG_ID,
+ interval,
+ )
+ try:
+ while not stop.is_set():
+ try:
+ ok = await heartbeat_response(response_id, attempt_number)
+ if not ok:
+ # CAS failed → attempt_number has moved past us,
+ # another pod owns this response now. Stop the
+ # heartbeat task; the handler is still running but
+ # its emissions to the message log will be tagged
+ # with this attempt and ignored on the resumed
+ # path's filter (which keys on the new attempt).
+ logger.info(
+ "[durable] heartbeat lost ownership response_id=%s "
+ "attempt=%d (another pod claimed); stopping",
+ response_id,
+ attempt_number,
+ )
+ stop.set()
+ break
+ beats += 1
+ # Sampled heartbeat log so the lifecycle is visible
+ # without spamming every interval. Every 5th (~15s
+ # at 3s interval) is a good compromise.
+ if beats % 5 == 1:
+ logger.info(
+ "[durable] heartbeat beat#%d response_id=%s attempt=%d pod=%s",
+ beats,
+ response_id,
+ attempt_number,
+ _POD_LOG_ID,
+ )
+ except Exception:
+ logger.warning(
+ "[durable] heartbeat write failed response_id=%s; will retry",
+ response_id,
+ exc_info=True,
+ )
+ try:
+ await asyncio.wait_for(stop.wait(), timeout=interval)
+ except TimeoutError:
+ pass
+ except asyncio.CancelledError:
+ pass
+ logger.info(
+ "[durable] heartbeat stop response_id=%s attempt=%d pod=%s total_beats=%d",
+ response_id,
+ attempt_number,
+ _POD_LOG_ID,
+ beats,
+ )
+
+ hb_task = asyncio.create_task(_beat(), name=f"heartbeat-{response_id}")
+ try:
+ yield
+ finally:
+ stop.set()
+ hb_task.cancel()
+ try:
+ await hb_task
+ except (asyncio.CancelledError, Exception):
+ pass
+
@asynccontextmanager
async def _task_scope(
- self, response_id: str, state: dict[str, Any]
+ self,
+ response_id: str,
+ state: dict[str, Any],
+ *,
+ attempt_number: int = 1,
) -> AsyncGenerator[None, None]:
- """Timeout + error handling wrapper for background tasks."""
+ """Timeout + error handling wrapper for background tasks.
+
+ ``attempt_number`` is CAS-checked on terminal-status writes so a
+ deferred fail / cleanup that fires after another pod has claimed the
+ row for resume doesn't clobber the new owner's in-progress state.
+ """
try:
async with asyncio.timeout(self._settings.task_timeout_seconds):
yield
@@ -339,7 +739,11 @@ async def _task_scope(
self._settings.task_timeout_seconds,
)
asyncio.create_task(
- _deferred_mark_failed(response_id, delay=self._settings.cleanup_timeout_seconds),
+ _deferred_mark_failed(
+ response_id,
+ delay=self._settings.cleanup_timeout_seconds,
+ owning_attempt_number=attempt_number,
+ ),
name=f"deferred-fail-{response_id}",
)
except Exception as exc:
@@ -348,7 +752,7 @@ async def _task_scope(
# TODO: sequence number computation is racy (see _deferred_mark_failed).
async with asyncio.timeout(self._settings.cleanup_timeout_seconds):
existing = await get_messages(response_id, after_sequence=None)
- next_seq = max((seq for seq, _, _ in existing), default=-1) + 1
+ next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1
await append_message(
response_id,
next_seq,
@@ -361,8 +765,13 @@ async def _task_scope(
"code": "task_failed",
},
},
+ attempt_number=attempt_number,
+ )
+ await update_response_status(
+ response_id,
+ "failed",
+ expected_attempt_number=attempt_number,
)
- await update_response_status(response_id, "failed")
except Exception:
logger.exception(
"[error-cleanup] Immediate update failed for %s, deferring",
@@ -373,6 +782,7 @@ async def _task_scope(
response_id,
delay=self._settings.cleanup_timeout_seconds,
reason=str(exc),
+ owning_attempt_number=attempt_number,
),
name=f"deferred-fail-{response_id}",
)
@@ -382,11 +792,22 @@ async def _run_background_stream(
response_id: str,
request_data: dict[str, Any],
return_trace_id: bool = False,
+ *,
+ attempt_number: int = 1,
) -> None:
"""Timeout-guarded wrapper around the streaming agent loop."""
state: dict[str, Any] = {"seq": 0}
- async with self._task_scope(response_id, state):
- await self._do_background_stream(response_id, request_data, return_trace_id, state)
+ async with (
+ self._task_scope(response_id, state, attempt_number=attempt_number),
+ self._heartbeat(response_id, attempt_number),
+ ):
+ await self._do_background_stream(
+ response_id,
+ request_data,
+ return_trace_id,
+ state,
+ attempt_number=attempt_number,
+ )
def transform_stream_event(self, event: dict, response_id: str) -> dict:
"""Override to transform events before persistence (e.g. replace placeholder IDs)."""
@@ -398,16 +819,35 @@ async def _do_background_stream(
request_data: dict[str, Any],
return_trace_id: bool,
state: dict[str, Any],
+ *,
+ attempt_number: int = 1,
) -> None:
"""Run agent via stream_fn, persist each stream event as a message row."""
stream_fn = get_stream_function()
if stream_fn is None:
- await update_response_status(response_id, "failed")
+ await update_response_status(
+ response_id, "failed", expected_attempt_number=attempt_number
+ )
raise RuntimeError("No stream function registered; cannot run background stream")
func_name = stream_fn.__name__
+ logger.info(
+ "[durable] background stream start response_id=%s attempt=%d pod=%s handler=%s",
+ response_id,
+ attempt_number,
+ _POD_LOG_ID,
+ func_name,
+ )
all_chunks: list[dict[str, Any]] = []
- seq = 0
+ # Continue sequence numbering across attempts so the client's cursor
+ # never rewinds on resume. First attempt starts at 0 and skips the DB
+ # lookup — keeps the fast path identical to pre-resume behavior and
+ # avoids an extra query per background request.
+ if attempt_number > 1:
+ existing = await get_messages(response_id, after_sequence=None)
+ seq = max((s for s, _, _, _ in existing), default=-1) + 1
+ else:
+ seq = 0
with mlflow.start_span(name=func_name) as span:
span.set_inputs(request_data)
@@ -420,16 +860,27 @@ async def _do_background_stream(
evt_type = evt.get("type", "message")
logger.debug(
"SSE event (background)",
- extra={"response_id": response_id, "seq": seq, "type": evt_type},
+ extra={
+ "response_id": response_id,
+ "seq": seq,
+ "type": evt_type,
+ "attempt": attempt_number,
+ },
)
await append_message(
response_id,
seq,
item=json.dumps(item) if item is not None else None,
stream_event=evt,
+ attempt_number=attempt_number,
)
seq += 1
state["seq"] = seq
+ # Explicit yield so task.cancel() propagates promptly on
+ # tight event streams. The OpenAI Agents Runner's
+ # stream_events() awaits a queue that empties fast enough
+ # that cancellation can sit for tens of seconds without this.
+ await asyncio.sleep(0)
span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "openai")
span.set_outputs(ResponsesAgent.responses_agent_output_reducer(all_chunks))
@@ -439,12 +890,28 @@ async def _do_background_stream(
response_id,
seq,
stream_event={"trace_id": span.trace_id},
+ attempt_number=attempt_number,
)
- await update_response_status(response_id, "completed")
- logger.debug(
- "Background stream completed",
- extra={"response_id": response_id, "total_events": seq},
+ updated = await update_response_status(
+ response_id, "completed", expected_attempt_number=attempt_number
+ )
+ if not updated:
+ logger.info(
+ "[durable] skipped completed-status write response_id=%s attempt=%d "
+ "(another pod claimed the row mid-handler); pod=%s",
+ response_id,
+ attempt_number,
+ _POD_LOG_ID,
+ )
+ return
+ logger.info(
+ "[durable] background stream completed response_id=%s attempt=%d "
+ "total_events=%d pod=%s",
+ response_id,
+ attempt_number,
+ seq,
+ _POD_LOG_ID,
)
async def _run_background_invoke(
@@ -452,11 +919,22 @@ async def _run_background_invoke(
response_id: str,
request_data: dict[str, Any],
return_trace_id: bool = False,
+ *,
+ attempt_number: int = 1,
) -> None:
"""Timeout-guarded wrapper around the invoke agent loop."""
state: dict[str, Any] = {"seq": 0}
- async with self._task_scope(response_id, state):
- await self._do_background_invoke(response_id, request_data, return_trace_id, state)
+ async with (
+ self._task_scope(response_id, state, attempt_number=attempt_number),
+ self._heartbeat(response_id, attempt_number),
+ ):
+ await self._do_background_invoke(
+ response_id,
+ request_data,
+ return_trace_id,
+ state,
+ attempt_number=attempt_number,
+ )
async def _do_background_invoke(
self,
@@ -464,11 +942,15 @@ async def _do_background_invoke(
request_data: dict[str, Any],
return_trace_id: bool,
state: dict[str, Any],
+ *,
+ attempt_number: int = 1,
) -> None:
"""Run agent via invoke_fn, persist each output item as a message row."""
invoke_fn = get_invoke_function()
if invoke_fn is None:
- await update_response_status(response_id, "failed")
+ await update_response_status(
+ response_id, "failed", expected_attempt_number=attempt_number
+ )
raise RuntimeError("No invoke function registered; cannot run background invoke")
func_name = invoke_fn.__name__
@@ -485,27 +967,191 @@ async def _do_background_invoke(
span.set_outputs(result)
output = result.get("output", [])
+ # Continue sequence numbering across attempts (see _do_background_stream).
+ if attempt_number > 1:
+ existing = await get_messages(response_id, after_sequence=None)
+ base_seq = max((s for s, _, _, _ in existing), default=-1) + 1
+ else:
+ base_seq = 0
for i, item in enumerate(output):
item_dict = (
item
if isinstance(item, dict)
else (item.model_dump() if hasattr(item, "model_dump") else {"content": str(item)})
)
+ seq = base_seq + i
await append_message(
response_id,
- i,
+ seq,
item=json.dumps(item_dict),
stream_event={"type": "response.output_item.done", "item": item_dict},
+ attempt_number=attempt_number,
)
- state["seq"] = i + 1
+ state["seq"] = seq + 1
if return_trace_id:
await update_response_trace_id(response_id, span.trace_id)
- await update_response_status(response_id, "completed")
+ updated = await update_response_status(
+ response_id, "completed", expected_attempt_number=attempt_number
+ )
+ if not updated:
+ logger.info(
+ "[durable] skipped completed-status write response_id=%s attempt=%d "
+ "(another pod claimed the row mid-handler); pod=%s",
+ response_id,
+ attempt_number,
+ _POD_LOG_ID,
+ )
+ return
logger.debug(
"Background invoke completed",
extra={"response_id": response_id, "output_items": len(output)},
)
+ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None:
+ """If ``resp`` is a stale in-progress run, attempt an atomic claim.
+
+ On success, kick off a new background task that re-invokes the handler
+ on a rotated conversation anchor with the replayed input enriched by
+ the prior attempt's emitted items, and returns the new
+ ``attempt_number``. On failure (another pod won, or the run is no
+ longer stale), returns ``None``.
+
+ This is the lazy resume path: triggered by a client retrieve. Pods
+ don't poll for stale work proactively in v1 — if no client ever calls
+ ``GET /responses/{id}``, the task_timeout sweep eventually marks it
+ failed.
+ """
+ if resp.status != "in_progress":
+ return None
+ # The run may be freshly started but too young to have a heartbeat yet;
+ # respect the creation age as a grace period equal to the stale
+ # threshold. Otherwise a quick follow-up retrieve could hijack a
+ # running pod before it ever writes its first heartbeat.
+ if resp.heartbeat_at is None:
+ age = _age_seconds(resp.created_at)
+ if age < self._settings.heartbeat_stale_threshold_seconds:
+ logger.debug(
+ "[durable] claim skipped response_id=%s reason=grace_period "
+ "age=%.1fs threshold=%.1fs",
+ response_id,
+ age,
+ self._settings.heartbeat_stale_threshold_seconds,
+ )
+ return None
+ else:
+ hb_age = _age_seconds(resp.heartbeat_at)
+ if hb_age < self._settings.heartbeat_stale_threshold_seconds:
+ # Heartbeat is fresh — owner is alive. Common case, keep
+ # quiet at debug so we don't spam every poll iteration.
+ logger.debug(
+ "[durable] claim skipped response_id=%s reason=heartbeat_fresh "
+ "age=%.1fs threshold=%.1fs",
+ response_id,
+ hb_age,
+ self._settings.heartbeat_stale_threshold_seconds,
+ )
+ return None
+ logger.info(
+ "[durable] stale heartbeat detected response_id=%s "
+ "heartbeat_age=%.1fs threshold=%.1fs",
+ response_id,
+ hb_age,
+ self._settings.heartbeat_stale_threshold_seconds,
+ )
+ if resp.original_request is None:
+ # Nothing to replay from — the run predates durability metadata.
+ logger.warning(
+ "[durable] cannot resume response_id=%s reason=no_original_request",
+ response_id,
+ )
+ return None
+
+ logger.info(
+ "[durable] attempting claim response_id=%s current_attempt=%d new_owner=%s",
+ response_id,
+ resp.attempt_number,
+ _POD_LOG_ID,
+ )
+ new_attempt = await claim_stale_response(
+ response_id,
+ stale_threshold_seconds=self._settings.heartbeat_stale_threshold_seconds,
+ )
+ if new_attempt is None:
+ # Someone else owns it, or the row was updated between the read and
+ # the claim. Expected under contention.
+ logger.info(
+ "[durable] claim lost response_id=%s (another pod won or row changed)",
+ response_id,
+ )
+ return None
+
+ # Build a "resume" request by REPLAYING the original POST's input on a
+ # ROTATED conversation anchor, plus a single prose user message that
+ # narrates the prior attempt's completed tool calls / outputs / narrative.
+ #
+ # Always-rotate + prose recovery design:
+ # 1. Rotation makes the handler's SDK helpers resolve to a FRESH
+ # thread_id / session_id, so the rotated session starts empty and
+ # cannot inherit orphan-poisoned mid-turn state from the crashed
+ # attempt. Subsequent turns from the client should also use the
+ # rotated anchor (templates return it via custom_outputs); the
+ # original session becomes orphaned permanently and is never read.
+ # 2. The prose user message is the single source of truth for what
+ # already ran. The LLM reads it as a recovery instruction and
+ # continues. No structural carry-forward, no synthetic outputs,
+ # no per-SDK adapter wrappers needed.
+ existing = await get_messages(response_id, after_sequence=None)
+ next_seq = max((s for s, _, _, _ in existing), default=-1) + 1
+ prose_msg = _build_prose_recovery_message(existing, prior_attempt_number=new_attempt - 1)
+
+ resume_dict = copy.deepcopy(resp.original_request)
+ resume_input = list(resume_dict.get("input") or [])
+ resume_input.append(prose_msg)
+ resume_dict["input"] = resume_input
+ logger.info(
+ "[durable] resume built prose recovery message for attempt %d response_id=%s",
+ new_attempt - 1,
+ response_id,
+ )
+ resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id)
+ resume_request = self.validator.validate_and_convert_request(resume_dict)
+ # Surface the rotated conversation_id in the sentinel so clients that
+ # cache `chat_id → conversation_id` can pick up the rotation and use
+ # the rotated session on subsequent turns. Without this the next turn
+ # lands on the original (orphan-poisoned) session.
+ rotated_conv_id = (resume_dict.get("context") or {}).get("conversation_id")
+ await append_message(
+ response_id,
+ next_seq,
+ stream_event={
+ "type": "response.resumed",
+ "attempt": new_attempt,
+ "from_seq": next_seq,
+ "conversation_id": rotated_conv_id,
+ },
+ attempt_number=new_attempt,
+ )
+
+ logger.info(
+ "[durable] claim succeeded response_id=%s new_attempt=%d pod=%s resume_from_seq=%d",
+ response_id,
+ new_attempt,
+ _POD_LOG_ID,
+ next_seq,
+ )
+
+ task = asyncio.create_task(
+ self._run_background_stream(
+ response_id,
+ resume_request,
+ return_trace_id=False,
+ attempt_number=new_attempt,
+ ),
+ name=f"resume-{response_id}-{new_attempt}",
+ )
+ self._track_task(response_id, task)
+ return new_attempt
+
async def _handle_retrieve_request(
self,
response_id: str,
@@ -523,7 +1169,20 @@ async def _handle_retrieve_request(
if resp is None:
raise HTTPException(status_code=404, detail="Response not found")
- _, status, created_at, trace_id = resp
+ # Try a lazy resume before falling back to the absolute-timeout sweep.
+ # This gives us crash-recovery semantics: an idle client reconnecting
+ # after a pod died will reclaim the run and resume it here instead of
+ # just marking it failed.
+ await self._try_claim_and_resume(response_id, resp)
+
+ # Refresh after the potential resume: status / attempt_number may have changed.
+ resp = await get_response(response_id)
+ if resp is None:
+ raise HTTPException(status_code=404, detail="Response not found")
+
+ status = resp.status
+ created_at = resp.created_at
+ trace_id = resp.trace_id
if (
status == "in_progress"
@@ -542,10 +1201,9 @@ async def _handle_retrieve_request(
},
)
# TODO: sequence number computation here is racy under concurrent writers.
- # Acceptable at current scale; for high-QPS use a DB-assigned sequence or
- # SELECT FOR UPDATE on the response row to serialise writers.
existing = await get_messages(response_id, after_sequence=None)
- next_seq = max((seq for seq, _, _ in existing), default=-1) + 1
+ next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1
+ attempt = await _current_attempt(response_id)
await append_message(
response_id,
next_seq,
@@ -558,6 +1216,7 @@ async def _handle_retrieve_request(
"code": "task_timeout",
},
},
+ attempt_number=attempt,
)
status = "failed"
@@ -579,25 +1238,45 @@ async def _handle_retrieve_request(
messages = await get_messages(response_id, after_sequence=None)
if not messages and status == "in_progress":
- return {"id": response_id, "status": "in_progress"}
+ return {
+ "id": response_id,
+ "status": "in_progress",
+ "attempt_number": resp.attempt_number,
+ }
if status == "completed" and messages:
+ # Only consider items from the final (successful) attempt so that
+ # abandoned in-progress items from crashed attempts don't leak
+ # into the authoritative response body. Completed output_item.done
+ # events across attempts together make up the conversation — the
+ # agent SDK's checkpointer guarantees done-items are not re-emitted
+ # by later attempts, so this is a union with no duplicates.
output = []
- for _, _, evt in messages:
- if evt and "item" in evt:
- output.append(evt["item"])
+ for _, _, evt, _attempt in messages:
+ if evt and evt.get("type") == "response.output_item.done":
+ output.append(evt.get("item"))
result: dict[str, Any] = {
"id": response_id,
"status": "completed",
- "output": output,
+ "output": [o for o in output if o is not None],
+ "attempt_number": resp.attempt_number,
}
if trace_id:
result["metadata"] = {"trace_id": trace_id}
return result
if status == "failed" and messages:
- for _, _, evt in messages:
+ for _, _, evt, _attempt in messages:
if evt and evt.get("type") == "error":
- return {"id": response_id, "status": "failed", "error": evt.get("error")}
- return {"id": response_id, "status": status}
+ return {
+ "id": response_id,
+ "status": "failed",
+ "error": evt.get("error"),
+ "attempt_number": resp.attempt_number,
+ }
+ return {
+ "id": response_id,
+ "status": status,
+ "attempt_number": resp.attempt_number,
+ }
async def _stream_retrieve(
self,
@@ -638,15 +1317,26 @@ async def _stream_retrieve(
)
break
- _, status, _, _ = resp
+ status = resp.status
+ # Self-heal: if this response is still in_progress but its owning
+ # pod has gone silent past heartbeat_stale_threshold, try to claim
+ # + resume on this pod. A no-op if heartbeat is fresh or another
+ # pod already won. Without this, a stream opened before the crash
+ # would idle forever polling a dead run — since _try_claim_and_resume
+ # is only triggered by the outer retrieve handler on fresh GETs.
+ if status == "in_progress":
+ await self._try_claim_and_resume(response_id, resp)
+
# starting_after=0 fetches all messages (sequence numbers start at 0).
# We use after_sequence=-1 for the DB query so that seq 0 is included.
after_seq = last_seq - 1 if last_seq == 0 else last_seq
messages = await get_messages(response_id, after_sequence=after_seq)
- for seq, _, evt in messages:
+ for seq, _, evt, _attempt in messages:
if evt is not None:
- evt = {**evt, "sequence_number": seq}
+ # Tag every SSE frame with the response_id so proxies /
+ # clients can discover it without parsing nested fields.
+ evt = {**evt, "sequence_number": seq, "response_id": response_id}
event_type = evt.get("type", "message")
logger.debug(
"SSE event",
diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py
index 7b646116..8224b4d8 100644
--- a/src/databricks_ai_bridge/long_running/settings.py
+++ b/src/databricks_ai_bridge/long_running/settings.py
@@ -15,6 +15,15 @@ class LongRunningSettings:
poll_interval_seconds: float = 1.0
db_statement_timeout_ms: int = 5000
cleanup_timeout_seconds: float = 7.0
+ heartbeat_interval_seconds: float = 3.0
+ heartbeat_stale_threshold_seconds: float = 10.0
+ # Proactive stale-scan loop: how often (on average) each pod queries the
+ # responses table for stale-heartbeat rows and tries to claim+resume them.
+ # Each pod jitters this interval so multiple pods don't all hit the DB at
+ # once. The loop is the proactive counterpart to the lazy-on-GET claim
+ # path; it ensures crashed responses get recovered even if no client polls.
+ stale_scan_interval_seconds: float = 30.0
+ stale_scan_jitter_fraction: float = 0.5
def __post_init__(self) -> None:
if self.task_timeout_seconds <= 0:
@@ -25,6 +34,16 @@ def __post_init__(self) -> None:
raise ValueError("db_statement_timeout_ms must be positive")
if self.cleanup_timeout_seconds <= 0:
raise ValueError("cleanup_timeout_seconds must be positive")
+ if self.heartbeat_interval_seconds <= 0:
+ raise ValueError("heartbeat_interval_seconds must be positive")
+ if self.heartbeat_stale_threshold_seconds <= 0:
+ raise ValueError("heartbeat_stale_threshold_seconds must be positive")
+ if self.heartbeat_stale_threshold_seconds <= self.heartbeat_interval_seconds:
+ raise ValueError(
+ f"heartbeat_stale_threshold_seconds ({self.heartbeat_stale_threshold_seconds}) "
+ f"must be strictly greater than heartbeat_interval_seconds "
+ f"({self.heartbeat_interval_seconds}) to avoid false stale-run detection."
+ )
db_timeout_s = self.db_statement_timeout_ms / 1000.0
if self.cleanup_timeout_seconds <= db_timeout_s:
raise ValueError(
@@ -32,3 +51,7 @@ def __post_init__(self) -> None:
f"strictly greater than db_statement_timeout_ms converted to seconds "
f"({db_timeout_s})"
)
+ if self.stale_scan_interval_seconds <= 0:
+ raise ValueError("stale_scan_interval_seconds must be positive")
+ if not 0 <= self.stale_scan_jitter_fraction < 1:
+ raise ValueError("stale_scan_jitter_fraction must be in [0, 1)")
diff --git a/tests/databricks_ai_bridge/test_long_running_db.py b/tests/databricks_ai_bridge/test_long_running_db.py
index a1290ba1..2565a1e0 100644
--- a/tests/databricks_ai_bridge/test_long_running_db.py
+++ b/tests/databricks_ai_bridge/test_long_running_db.py
@@ -160,10 +160,12 @@ async def test_get_messages(mock_session):
result_mock.scalars.return_value.all.return_value = [msg1, msg2]
mock_session.execute.return_value = result_mock
+ msg1.attempt_number = 1
+ msg2.attempt_number = 1
messages = await get_messages("resp_abc123", after_sequence=None)
assert len(messages) == 2
- assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"})
- assert messages[1] == (1, None, None)
+ assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}, 1)
+ assert messages[1] == (1, None, None, 1)
@pytest.mark.asyncio
@@ -175,6 +177,9 @@ async def test_get_response(mock_session):
row.created_at = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc)
row.trace_id = "trace_xyz"
+ row.heartbeat_at = None
+ row.attempt_number = 1
+ row.original_request = None
result_mock = MagicMock()
result_mock.scalar_one_or_none.return_value = row
mock_session.execute.return_value = result_mock
@@ -185,6 +190,9 @@ async def test_get_response(mock_session):
"completed",
datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc),
"trace_xyz",
+ None, # heartbeat_at
+ 1, # attempt_number
+ None, # original_request
)
@@ -283,9 +291,15 @@ async def test_creates_schema_and_tables(self, reset_db_globals):
with patch(f"{DB_MODULE}.AsyncLakebaseSQLAlchemy", mock_cls), patch(f"{DB_MODULE}.event"):
await init_db(autoscaling_endpoint="ep")
- mock_conn.execute.assert_awaited_once()
- sql_arg = str(mock_conn.execute.call_args[0][0])
- assert "CREATE SCHEMA IF NOT EXISTS" in sql_arg
+ # init_db runs: CREATE SCHEMA + run_sync(create_all) + a series of
+ # ADD COLUMN IF NOT EXISTS / CREATE INDEX IF NOT EXISTS to migrate
+ # the durability columns onto pre-existing tables.
+ all_sql = " | ".join(str(call.args[0]) for call in mock_conn.execute.call_args_list)
+ assert "CREATE SCHEMA IF NOT EXISTS" in all_sql
+ assert "ADD COLUMN IF NOT EXISTS heartbeat_at" in all_sql
+ assert "ADD COLUMN IF NOT EXISTS attempt_number" in all_sql
+ assert "ADD COLUMN IF NOT EXISTS original_request" in all_sql
+ assert "idx_responses_stale" in all_sql
mock_conn.run_sync.assert_awaited_once()
@pytest.mark.asyncio
@@ -346,3 +360,104 @@ async def fake_factory():
monkeypatch.setattr(db_mod, "_session_factory", fake_factory)
async with session_scope() as session:
assert session is mock_session
+
+
+# ---------------------------------------------------------------------------
+# Durability metadata: heartbeat (CAS on attempt), claim, attempt_number
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_create_response_durable_stamps_heartbeat_and_original_request(mock_session):
+ """Durable callers stamp heartbeat_at + serialized request on creation —
+ without these, a resumed pod can't re-invoke the handler."""
+ from databricks_ai_bridge.long_running.repository import create_response
+
+ await create_response(
+ "resp_abc",
+ "in_progress",
+ durable=True,
+ original_request={"input": [{"role": "user", "content": "hi"}]},
+ )
+ added = mock_session.add.call_args[0][0]
+ assert added.heartbeat_at is not None
+ # original_request is JSON-encoded for Text storage.
+ assert '"role": "user"' in added.original_request
+
+
+@pytest.mark.asyncio
+async def test_create_response_without_durability_metadata(mock_session):
+ """Non-durable callers (tests, legacy flows) write no heartbeat so the
+ stale sweep can't accidentally claim them."""
+ from databricks_ai_bridge.long_running.repository import create_response
+
+ await create_response("resp_x", "in_progress")
+ added = mock_session.add.call_args[0][0]
+ assert added.heartbeat_at is None
+ assert added.original_request is None
+
+
+@pytest.mark.asyncio
+async def test_heartbeat_response_updates_when_attempt_matches(mock_session):
+ from databricks_ai_bridge.long_running.repository import heartbeat_response
+
+ result_mock = MagicMock()
+ result_mock.rowcount = 1
+ mock_session.execute.return_value = result_mock
+
+ ok = await heartbeat_response("resp_abc", expected_attempt_number=1)
+ assert ok is True
+ mock_session.commit.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_heartbeat_response_fails_when_attempt_changed(mock_session):
+ """If the CAS misses (attempt_number bumped by another pod's claim),
+ heartbeat reports failure so the caller can stop looping."""
+ from databricks_ai_bridge.long_running.repository import heartbeat_response
+
+ result_mock = MagicMock()
+ result_mock.rowcount = 0
+ mock_session.execute.return_value = result_mock
+
+ ok = await heartbeat_response("resp_abc", expected_attempt_number=1)
+ assert ok is False
+
+
+@pytest.mark.asyncio
+async def test_claim_stale_response_returns_attempt_number(mock_session):
+ from databricks_ai_bridge.long_running.repository import claim_stale_response
+
+ row = MagicMock()
+ row.__iter__ = lambda self: iter([2])
+ row.__getitem__ = lambda self, i: 2
+ result_mock = MagicMock()
+ result_mock.first.return_value = row
+ mock_session.execute.return_value = result_mock
+
+ attempt = await claim_stale_response("resp_abc", stale_threshold_seconds=15.0)
+ assert attempt == 2
+
+
+@pytest.mark.asyncio
+async def test_claim_stale_response_returns_none_when_not_eligible(mock_session):
+ from databricks_ai_bridge.long_running.repository import claim_stale_response
+
+ result_mock = MagicMock()
+ result_mock.first.return_value = None
+ mock_session.execute.return_value = result_mock
+
+ attempt = await claim_stale_response("resp_abc", stale_threshold_seconds=15.0)
+ assert attempt is None
+
+
+@pytest.mark.asyncio
+async def test_append_message_with_attempt_number(mock_session):
+ """Resumed events must be tagged with the resume attempt so retrieve can
+ filter or the client can render the response.resumed boundary cleanly."""
+ from databricks_ai_bridge.long_running.repository import append_message
+
+ await append_message("resp_abc", 5, stream_event={"x": 1}, attempt_number=3)
+ added = mock_session.add.call_args[0][0]
+ assert added.attempt_number == 3
+ assert added.sequence_number == 5
diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py
index 27b6eaf5..ae33b05c 100644
--- a/tests/databricks_ai_bridge/test_long_running_server.py
+++ b/tests/databricks_ai_bridge/test_long_running_server.py
@@ -14,9 +14,12 @@
pytest.importorskip("fastapi")
pytest.importorskip("psycopg")
+from databricks_ai_bridge.long_running.repository import ResponseInfo
from databricks_ai_bridge.long_running.server import (
LongRunningAgentServer,
+ _build_prose_recovery_message,
_deferred_mark_failed,
+ _rotate_conversation_id,
_sse_event,
)
from databricks_ai_bridge.long_running.settings import LongRunningSettings
@@ -34,6 +37,38 @@ def _make_server(**kwargs):
return LongRunningAgentServer("ResponsesAgent", **kwargs)
+def _resp_info(
+ response_id: str = "resp_123",
+ status: str = "in_progress",
+ created_at=None,
+ trace_id: str | None = None,
+ heartbeat_at=None,
+ attempt_number: int = 1,
+ original_request: dict | None = None,
+) -> ResponseInfo:
+ """Build a ResponseInfo with sensible defaults for tests.
+
+ Mirrors the server's repository model so test setups stay terse even as
+ durability columns grow over time.
+ """
+ if created_at is None:
+ created_at = datetime.now(timezone.utc)
+ return ResponseInfo(
+ response_id=response_id,
+ status=status,
+ created_at=created_at,
+ trace_id=trace_id,
+ heartbeat_at=heartbeat_at,
+ attempt_number=attempt_number,
+ original_request=original_request,
+ )
+
+
+def _msg(seq: int, item=None, evt=None, attempt: int = 1):
+ """Build a (seq, item, stream_event, attempt_number) tuple for get_messages mocks."""
+ return (seq, item, evt, attempt)
+
+
def _mock_span():
"""Return a mock MLflow span with the attributes the server uses."""
span = MagicMock()
@@ -55,8 +90,8 @@ def _mock_validator(server):
class TestSSEEvent:
def test_dict_data(self):
result = _sse_event("response.created", {"id": "resp_123", "status": "in_progress"})
- assert result.startswith("event: response.created\n")
- assert "data: " in result
+ assert result.startswith("data: ")
+ assert "event:" not in result
assert result.endswith("\n\n")
data_line = result.split("data: ")[1].strip()
parsed = json.loads(data_line)
@@ -64,8 +99,8 @@ def test_dict_data(self):
def test_string_data(self):
result = _sse_event("error", "something went wrong")
- assert "event: error\n" in result
- assert "data: something went wrong\n\n" in result
+ assert "event:" not in result
+ assert result == "data: something went wrong\n\n"
class TestLongRunningSettings:
@@ -189,7 +224,7 @@ def test_starting_after_zero_without_stream_is_allowed(self):
patch(
f"{MODULE}.get_response",
new_callable=AsyncMock,
- return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None),
+ return_value=_resp_info("resp_123", "in_progress"),
),
patch(
f"{MODULE}.get_messages",
@@ -209,8 +244,13 @@ async def test_marks_response_failed(self):
patch(
"databricks_ai_bridge.long_running.server.get_messages",
new_callable=AsyncMock,
- return_value=[(0, None, {"type": "response.created"})],
+ return_value=[_msg(0, None, {"type": "response.created"})],
) as mock_get,
+ patch(
+ "databricks_ai_bridge.long_running.server.get_response",
+ new_callable=AsyncMock,
+ return_value=_resp_info(),
+ ),
patch(
"databricks_ai_bridge.long_running.server.append_message",
new_callable=AsyncMock,
@@ -230,7 +270,7 @@ async def test_marks_response_failed(self):
stream_event = args[1]["stream_event"]
assert stream_event["type"] == "error"
assert stream_event["error"]["code"] == "task_timeout"
- mock_update.assert_awaited_once_with("resp_123", "failed")
+ mock_update.assert_awaited_once_with("resp_123", "failed", expected_attempt_number=None)
@pytest.mark.asyncio
async def test_handles_db_error_gracefully(self):
@@ -242,6 +282,36 @@ async def test_handles_db_error_gracefully(self):
# Should not raise
await _deferred_mark_failed("resp_123", delay=0.01)
+ @pytest.mark.asyncio
+ async def test_skips_status_write_when_attempt_changed(self):
+ # The pod that scheduled this fail was running attempt=1; by the
+ # time this fires, another pod has bumped to attempt=2. We must NOT
+ # write terminal status.
+ with (
+ patch(
+ "databricks_ai_bridge.long_running.server.get_messages",
+ new_callable=AsyncMock,
+ return_value=[_msg(0, None, {"type": "response.created"})],
+ ),
+ patch(
+ "databricks_ai_bridge.long_running.server.get_response",
+ new_callable=AsyncMock,
+ return_value=_resp_info(attempt_number=2),
+ ),
+ patch(
+ "databricks_ai_bridge.long_running.server.append_message",
+ new_callable=AsyncMock,
+ ) as mock_append,
+ patch(
+ "databricks_ai_bridge.long_running.server.update_response_status",
+ new_callable=AsyncMock,
+ ) as mock_update,
+ ):
+ await _deferred_mark_failed("resp_123", delay=0.01, owning_attempt_number=1)
+ # Neither append nor status-write fires when we've lost ownership.
+ mock_append.assert_not_awaited()
+ mock_update.assert_not_awaited()
+
class TestRetrieveRequest:
@pytest.mark.asyncio
@@ -271,13 +341,13 @@ async def test_completed_returns_output(self):
patch(
"databricks_ai_bridge.long_running.server.get_response",
new_callable=AsyncMock,
- return_value=("resp_123", "completed", datetime.now(timezone.utc), "trace_abc"),
+ return_value=_resp_info("resp_123", "completed", trace_id="trace_abc"),
),
patch(
"databricks_ai_bridge.long_running.server.get_messages",
new_callable=AsyncMock,
return_value=[
- (
+ _msg(
0,
'{"text": "hi"}',
{"type": "response.output_item.done", "item": {"text": "hi"}},
@@ -305,7 +375,7 @@ async def test_stale_run_detection(self):
patch(
"databricks_ai_bridge.long_running.server.get_response",
new_callable=AsyncMock,
- return_value=("resp_stale", "in_progress", old_time, None),
+ return_value=_resp_info("resp_stale", "in_progress", created_at=old_time),
),
patch(
"databricks_ai_bridge.long_running.server.get_messages",
@@ -336,7 +406,7 @@ async def test_in_progress_returns_status(self):
patch(
"databricks_ai_bridge.long_running.server.get_response",
new_callable=AsyncMock,
- return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None),
+ return_value=_resp_info("resp_123", "in_progress"),
),
patch(
"databricks_ai_bridge.long_running.server.get_messages",
@@ -347,7 +417,11 @@ async def test_in_progress_returns_status(self):
result = await server._handle_retrieve_request(
"resp_123", stream=False, starting_after=0
)
- assert result == {"id": "resp_123", "status": "in_progress"}
+ assert result == {
+ "id": "resp_123",
+ "status": "in_progress",
+ "attempt_number": 1,
+ }
class TestStreamRetrieve:
@@ -360,14 +434,14 @@ async def test_completed_stream(self):
patch(
"databricks_ai_bridge.long_running.server.get_response",
new_callable=AsyncMock,
- return_value=("resp_123", "completed", datetime.now(timezone.utc), None),
+ return_value=_resp_info("resp_123", "completed"),
),
patch(
"databricks_ai_bridge.long_running.server.get_messages",
new_callable=AsyncMock,
return_value=[
- (0, None, {"type": "response.created", "id": "resp_123"}),
- (
+ _msg(0, None, {"type": "response.created", "id": "resp_123"}),
+ _msg(
1,
'{"text": "hi"}',
{"type": "response.output_item.done", "item": {"text": "hi"}},
@@ -394,13 +468,13 @@ async def test_failed_stream_stops(self):
patch(
"databricks_ai_bridge.long_running.server.get_response",
new_callable=AsyncMock,
- return_value=("resp_123", "failed", datetime.now(timezone.utc), None),
+ return_value=_resp_info("resp_123", "failed"),
),
patch(
"databricks_ai_bridge.long_running.server.get_messages",
new_callable=AsyncMock,
return_value=[
- (0, None, {"type": "error", "error": {"message": "boom"}}),
+ _msg(0, None, {"type": "error", "error": {"message": "boom"}}),
],
),
):
@@ -433,7 +507,9 @@ async def fake_stream(request_data):
patch(f"{MODULE}.get_stream_function", return_value=fake_stream),
patch(f"{MODULE}.mlflow") as mock_mlflow,
patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append,
- patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update,
+ patch(
+ f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True
+ ) as mock_update,
patch(f"{MODULE}.ResponsesAgent") as mock_ra,
):
mock_mlflow.start_span.return_value = span
@@ -448,7 +524,7 @@ async def fake_stream(request_data):
assert seqs == [0, 1, 2]
# Verify state tracks final seq
assert state["seq"] == 3
- mock_update.assert_awaited_once_with("resp_1", "completed")
+ mock_update.assert_awaited_once_with("resp_1", "completed", expected_attempt_number=1)
@pytest.mark.asyncio
async def test_calls_transform_stream_event(self):
@@ -500,12 +576,14 @@ async def test_no_stream_fn_marks_failed(self):
with (
patch(f"{MODULE}.get_stream_function", return_value=None),
- patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update,
+ patch(
+ f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True
+ ) as mock_update,
):
state = {"seq": 0}
with pytest.raises(RuntimeError, match="No stream function registered"):
await server._do_background_stream("resp_x", {}, False, state)
- mock_update.assert_awaited_once_with("resp_x", "failed")
+ mock_update.assert_awaited_once_with("resp_x", "failed", expected_attempt_number=1)
@pytest.mark.asyncio
async def test_persists_trace_id_when_requested(self):
@@ -554,7 +632,9 @@ async def fake_invoke(request_data):
patch(f"{MODULE}.get_invoke_function", return_value=fake_invoke),
patch(f"{MODULE}.mlflow") as mock_mlflow,
patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append,
- patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update,
+ patch(
+ f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True
+ ) as mock_update,
patch(f"{MODULE}.update_response_trace_id", new_callable=AsyncMock),
):
mock_mlflow.start_span.return_value = span
@@ -572,7 +652,7 @@ async def fake_invoke(request_data):
assert evt["type"] == "response.output_item.done"
assert "item" in evt
assert state["seq"] == 2
- mock_update.assert_awaited_once_with("resp_inv", "completed")
+ mock_update.assert_awaited_once_with("resp_inv", "completed", expected_attempt_number=1)
@pytest.mark.asyncio
async def test_trace_id_persisted_when_requested(self):
@@ -603,12 +683,14 @@ async def test_no_invoke_fn_marks_failed(self):
with (
patch(f"{MODULE}.get_invoke_function", return_value=None),
- patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update,
+ patch(
+ f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True
+ ) as mock_update,
):
state = {"seq": 0}
with pytest.raises(RuntimeError, match="No invoke function registered"):
await server._do_background_invoke("resp_x", {}, False, state)
- mock_update.assert_awaited_once_with("resp_x", "failed")
+ mock_update.assert_awaited_once_with("resp_x", "failed", expected_attempt_number=1)
@pytest.mark.asyncio
async def test_sync_invoke_fn_supported(self):
@@ -623,7 +705,9 @@ def sync_invoke(request_data):
patch(f"{MODULE}.get_invoke_function", return_value=sync_invoke),
patch(f"{MODULE}.mlflow") as mock_mlflow,
patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append,
- patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update,
+ patch(
+ f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True
+ ) as mock_update,
patch(f"{MODULE}.update_response_trace_id", new_callable=AsyncMock),
):
mock_mlflow.start_span.return_value = span
@@ -632,7 +716,9 @@ def sync_invoke(request_data):
await server._do_background_invoke("resp_sync", {"input": "hi"}, False, state)
assert mock_append.await_count == 1
- mock_update.assert_awaited_once_with("resp_sync", "completed")
+ mock_update.assert_awaited_once_with(
+ "resp_sync", "completed", expected_attempt_number=1
+ )
# ---------------------------------------------------------------------------
@@ -668,12 +754,19 @@ async def test_exception_writes_error_event_inline(self):
f"{MODULE}.get_messages",
new_callable=AsyncMock,
return_value=[
- (0, None, {"type": "response.created"}),
- (1, None, {"type": "response.output_text.delta"}),
+ _msg(0, None, {"type": "response.created"}),
+ _msg(1, None, {"type": "response.output_text.delta"}),
],
),
+ patch(
+ f"{MODULE}.get_response",
+ new_callable=AsyncMock,
+ return_value=_resp_info(),
+ ),
patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append,
- patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update,
+ patch(
+ f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True
+ ) as mock_update,
):
state = {"seq": 2}
async with server._task_scope("resp_err", state):
@@ -686,7 +779,7 @@ async def test_exception_writes_error_event_inline(self):
assert evt["error"]["message"] == "something broke"
assert evt["error"]["code"] == "task_failed"
assert mock_append.call_args.args[1] == 2 # next_seq
- mock_update.assert_awaited_once_with("resp_err", "failed")
+ mock_update.assert_awaited_once_with("resp_err", "failed", expected_attempt_number=1)
@pytest.mark.asyncio
async def test_exception_falls_back_to_deferred_on_db_failure(self):
@@ -875,3 +968,544 @@ async def test_lifespan_not_set_when_db_not_configured(self):
routes = [r.path for r in server.app.routes if hasattr(r, "path")]
assert "/responses/{response_id}" in routes
mock_init.assert_not_awaited()
+
+
+# ---------------------------------------------------------------------------
+# Durable resume: claim/heartbeat/attempt_number/sentinel
+# ---------------------------------------------------------------------------
+
+
+class TestBuildProseRecoveryMessage:
+ """Prose recovery serializer: produce a single Responses-API user-message
+ item containing the prior attempt's stream events as JSON, plus a
+ directive that asks the LLM to figure out what's done vs interrupted."""
+
+ def _done(self, seq, attempt, item):
+ return (seq, None, {"type": "response.output_item.done", "item": item}, attempt)
+
+ def test_returns_user_message_shape(self):
+ out = _build_prose_recovery_message([], prior_attempt_number=1)
+ assert out["type"] == "message"
+ assert out["role"] == "user"
+ assert isinstance(out["content"], str)
+ assert "[RECOVERY]" in out["content"]
+
+ def test_includes_events_json(self):
+ messages = [
+ self._done(
+ 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}
+ ),
+ self._done(1, 1, {"type": "function_call_output", "call_id": "c1", "output": "ok"}),
+ ]
+ out = _build_prose_recovery_message(messages, prior_attempt_number=1)
+ body = out["content"]
+ # Body should contain the raw events JSON-serialized.
+ assert '"call_id": "c1"' in body
+ assert '"output": "ok"' in body
+ assert '"name": "f"' in body
+
+ def test_filters_other_attempts(self):
+ messages = [
+ self._done(
+ 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}
+ ),
+ self._done(
+ 1, 2, {"type": "function_call", "call_id": "c2", "name": "g", "arguments": "{}"}
+ ),
+ ]
+ out = _build_prose_recovery_message(messages, prior_attempt_number=1)
+ body = out["content"]
+ assert '"call_id": "c1"' in body
+ # attempt 2 events excluded
+ assert '"call_id": "c2"' not in body
+
+ def test_empty_attempt_emits_empty_events_array(self):
+ out = _build_prose_recovery_message([], prior_attempt_number=1)
+ # Body still contains the recovery directive and an empty events array.
+ assert "[RECOVERY]" in out["content"]
+ assert "Events:\n[]" in out["content"]
+
+
+class TestRotateConversationId:
+ def test_rotate_drops_thread_id_and_sets_rotated_context(self):
+ r = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}}
+ out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x")
+ assert "thread_id" not in out["custom_inputs"]
+ assert out["custom_inputs"]["user_id"] == "u"
+ assert out["context"]["conversation_id"] == "t1::attempt-2"
+
+ def test_rotate_drops_session_id(self):
+ r = {"custom_inputs": {"session_id": "s1"}, "context": {}}
+ out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x")
+ assert "session_id" not in out["custom_inputs"]
+ assert out["context"]["conversation_id"] == "s1::attempt-2"
+
+ def test_rotate_falls_back_to_context_conversation_id(self):
+ r = {"custom_inputs": {}, "context": {"conversation_id": "c-abc"}}
+ out = _rotate_conversation_id(r, new_attempt_number=3, response_id="resp_x")
+ assert out["context"]["conversation_id"] == "c-abc::attempt-3"
+
+ def test_rotate_falls_back_to_response_id_as_last_resort(self):
+ r = {"custom_inputs": {}, "context": {}}
+ out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x")
+ assert out["context"]["conversation_id"] == "resp_x::attempt-2"
+
+ def test_rotate_handles_missing_custom_inputs_key(self):
+ r = {"context": {"conversation_id": "c-abc"}}
+ out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x")
+ assert out["context"]["conversation_id"] == "c-abc::attempt-2"
+ assert out["custom_inputs"] == {}
+
+
+class TestHandleBackgroundRequestPersistsDurabilityState:
+ """Background request entry point should stamp the response row with the
+ full original_request body so resume can recover full prior-turn history."""
+
+ @pytest.mark.asyncio
+ async def test_persists_durable_flag_and_original_request(self):
+ with patch(f"{MODULE}.is_db_configured", return_value=True):
+ server = LongRunningAgentServer("ResponsesAgent")
+ _mock_validator(server)
+
+ captured: dict = {}
+
+ async def fake_create_response(
+ response_id, status, *, durable=False, original_request=None
+ ):
+ captured["response_id"] = response_id
+ captured["status"] = status
+ captured["durable"] = durable
+ captured["original_request"] = original_request
+
+ with (
+ patch(f"{MODULE}.create_response", side_effect=fake_create_response),
+ patch("asyncio.create_task") as mock_create_task,
+ ):
+ result = await server._handle_background_request(
+ {"input": [{"role": "user", "content": "hi"}]},
+ is_streaming=False,
+ return_trace_id=False,
+ )
+
+ assert captured["status"] == "in_progress"
+ assert captured["durable"] is True
+ # original_request preserves the input the client sent (no
+ # conversation_id injection — the client owns that decision).
+ orig = captured["original_request"]
+ assert orig["input"] == [{"role": "user", "content": "hi"}]
+ # Return shape: immediate response_obj, not a stream.
+ assert result["id"] == captured["response_id"]
+ assert result["status"] == "in_progress"
+ mock_create_task.assert_called_once()
+
+
+class TestTryClaimAndResume:
+ @pytest.mark.asyncio
+ async def test_no_op_when_completed(self):
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+ resp = _resp_info(status="completed")
+ with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim:
+ result = await server._try_claim_and_resume("resp_x", resp)
+ assert result is None
+ mock_claim.assert_not_awaited()
+
+ @pytest.mark.asyncio
+ async def test_grace_period_for_fresh_run(self):
+ """Just-started runs get a grace window before they're claim-eligible."""
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer(
+ "ResponsesAgent", heartbeat_stale_threshold_seconds=15.0
+ )
+ # created 2s ago, no heartbeat yet → should NOT be claimed.
+ from datetime import timedelta
+
+ resp = _resp_info(
+ status="in_progress",
+ created_at=datetime.now(timezone.utc) - timedelta(seconds=2),
+ heartbeat_at=None,
+ original_request={"input": []},
+ )
+ with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim:
+ result = await server._try_claim_and_resume("resp_x", resp)
+ assert result is None
+ mock_claim.assert_not_awaited()
+
+ @pytest.mark.asyncio
+ async def test_no_op_without_original_request(self):
+ """Legacy rows created before durability metadata can't be resumed."""
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+ from datetime import timedelta
+
+ resp = _resp_info(
+ status="in_progress",
+ created_at=datetime.now(timezone.utc) - timedelta(seconds=300),
+ heartbeat_at=None,
+ original_request=None,
+ )
+ with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim:
+ result = await server._try_claim_and_resume("resp_x", resp)
+ assert result is None
+ mock_claim.assert_not_awaited()
+
+ @pytest.mark.asyncio
+ async def test_claim_fails_returns_none(self):
+ """Another pod won the race — we quietly step aside."""
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+ from datetime import timedelta
+
+ resp = _resp_info(
+ status="in_progress",
+ created_at=datetime.now(timezone.utc) - timedelta(seconds=300),
+ heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=300),
+ original_request={"input": [{"role": "user"}]},
+ )
+ with (
+ patch(
+ f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=None
+ ) as mock_claim,
+ patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append,
+ ):
+ result = await server._try_claim_and_resume("resp_x", resp)
+ assert result is None
+ mock_claim.assert_awaited_once()
+ mock_append.assert_not_awaited()
+
+ @pytest.mark.asyncio
+ async def test_successful_claim_spawns_resume_and_emits_sentinel(self):
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+ from datetime import timedelta
+
+ resp = _resp_info(
+ status="in_progress",
+ created_at=datetime.now(timezone.utc) - timedelta(seconds=300),
+ heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100),
+ original_request={
+ "input": [{"role": "user", "content": "hi"}],
+ "custom_inputs": {"user_id": "u"},
+ "context": {"conversation_id": "resp_x"},
+ },
+ )
+ captured: dict = {}
+
+ async def fake_append(response_id, seq, *, item=None, stream_event=None, attempt_number=1):
+ captured["seq"] = seq
+ captured["event"] = stream_event
+ captured["attempt_tag"] = attempt_number
+
+ with (
+ patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2),
+ patch(
+ f"{MODULE}.get_messages",
+ new_callable=AsyncMock,
+ return_value=[_msg(0, None, {}), _msg(1, None, {})],
+ ),
+ patch(f"{MODULE}.append_message", side_effect=fake_append),
+ patch("asyncio.create_task") as mock_create_task,
+ ):
+ attempt = await server._try_claim_and_resume("resp_x", resp)
+
+ assert attempt == 2
+ # Sentinel is written at next_seq (existing seqs were 0 and 1).
+ assert captured["seq"] == 2
+ assert captured["event"]["type"] == "response.resumed"
+ assert captured["event"]["attempt"] == 2
+ assert captured["attempt_tag"] == 2
+ # A resume task is spawned; it was not awaited synchronously.
+ mock_create_task.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_resume_replays_input_and_rotates_conversation_id(self):
+ """Resume must replay original_request.input (not blank it) and rotate
+ the conversation anchor so the handler resolves to a fresh thread /
+ session for the new attempt. Prevents the LangGraph stream-event
+ attempt-boundary orphan artifact (rotation-findings.md)."""
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+ from datetime import timedelta
+
+ resp = _resp_info(
+ status="in_progress",
+ created_at=datetime.now(timezone.utc) - timedelta(seconds=300),
+ heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100),
+ original_request={
+ "input": [{"role": "user", "content": "hi"}],
+ "custom_inputs": {"thread_id": "t1", "user_id": "u"},
+ "context": {},
+ },
+ )
+
+ captured_tasks = []
+
+ def capture_task(coro, *, name=None):
+ captured_tasks.append((coro, name))
+
+ class _Fake:
+ def cancel(self):
+ pass
+
+ def add_done_callback(self, cb):
+ pass
+
+ return _Fake()
+
+ with (
+ patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2),
+ patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]),
+ patch(f"{MODULE}.append_message", new_callable=AsyncMock),
+ patch("asyncio.create_task", side_effect=capture_task),
+ patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run,
+ ):
+ await server._try_claim_and_resume("resp_x", resp)
+
+ assert len(captured_tasks) == 1
+ coro, _name = captured_tasks[0]
+ await coro
+ mock_run.assert_awaited_once()
+ args, kwargs = mock_run.call_args
+ resume_request = args[1] if len(args) > 1 else kwargs["request_data"]
+ dumped = (
+ resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request
+ )
+ # Input is REPLAYED (not blanked) and a prose-recovery user message is
+ # appended so attempt N+1's LLM sees the original request plus a
+ # narrative of what happened. The MLflow validator normalizes the shape.
+ assert len(dumped["input"]) == 2
+ assert dumped["input"][0]["role"] == "user"
+ assert dumped["input"][0]["content"] == "hi"
+ assert dumped["input"][1]["role"] == "user"
+ assert "[RECOVERY]" in dumped["input"][1]["content"]
+ # thread_id was dropped so the handler's priority-2 fallback wins.
+ assert "thread_id" not in (dumped["custom_inputs"] or {})
+ # Other custom_inputs keys are preserved.
+ assert dumped["custom_inputs"]["user_id"] == "u"
+ # conversation_id is rotated to a per-attempt value anchored on t1.
+ assert dumped["context"]["conversation_id"] == "t1::attempt-2"
+ assert kwargs.get("attempt_number") == 2
+
+ @pytest.mark.asyncio
+ async def test_resume_rotation_anchors_on_context_conversation_id(self):
+ """When the client didn't pin a thread_id/session_id, rotation uses
+ the injected context.conversation_id as the base anchor."""
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+ from datetime import timedelta
+
+ resp = _resp_info(
+ status="in_progress",
+ created_at=datetime.now(timezone.utc) - timedelta(seconds=300),
+ heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100),
+ original_request={
+ "input": [{"role": "user", "content": "hi"}],
+ "custom_inputs": {},
+ "context": {"conversation_id": "resp_x"},
+ },
+ )
+
+ captured_tasks = []
+
+ def capture_task(coro, *, name=None):
+ captured_tasks.append((coro, name))
+
+ class _Fake:
+ def cancel(self):
+ pass
+
+ def add_done_callback(self, cb):
+ pass
+
+ return _Fake()
+
+ with (
+ patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=3),
+ patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]),
+ patch(f"{MODULE}.append_message", new_callable=AsyncMock),
+ patch("asyncio.create_task", side_effect=capture_task),
+ patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run,
+ ):
+ await server._try_claim_and_resume("resp_x", resp)
+
+ assert len(captured_tasks) == 1
+ coro, _name = captured_tasks[0]
+ await coro
+ mock_run.assert_awaited_once()
+ args, kwargs = mock_run.call_args
+ resume_request = args[1] if len(args) > 1 else kwargs["request_data"]
+ dumped = (
+ resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request
+ )
+ # Rotation anchors on the stored context.conversation_id (priority 2).
+ # Note: re-rotating in a subsequent attempt would re-anchor on the
+ # ORIGINAL stored value, not the previous rotation — no stacking.
+ assert dumped["context"]["conversation_id"] == "resp_x::attempt-3"
+
+
+class TestRetrieveTriggersLazyClaim:
+ @pytest.mark.asyncio
+ async def test_retrieve_calls_try_claim(self):
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer("ResponsesAgent")
+
+ resp = _resp_info("resp_x", "in_progress")
+ with (
+ patch(f"{MODULE}.get_response", new_callable=AsyncMock, return_value=resp),
+ patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]),
+ patch.object(
+ server, "_try_claim_and_resume", new_callable=AsyncMock, return_value=None
+ ) as mock_claim,
+ ):
+ await server._handle_retrieve_request("resp_x", stream=False, starting_after=0)
+
+ mock_claim.assert_awaited_once()
+
+
+class TestHeartbeatContextManager:
+ @pytest.mark.asyncio
+ async def test_writes_heartbeat_periodically(self):
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer(
+ "ResponsesAgent",
+ heartbeat_interval_seconds=0.05,
+ heartbeat_stale_threshold_seconds=1.0,
+ )
+
+ with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb:
+ async with server._heartbeat("resp_x", attempt_number=1):
+ await asyncio.sleep(0.2) # enough time for 2+ heartbeats
+
+ # Heartbeat interval is 0.05s so we should see at least 2 writes.
+ assert mock_hb.await_count >= 2
+ for call in mock_hb.await_args_list:
+ assert call.args[0] == "resp_x"
+
+ @pytest.mark.asyncio
+ async def test_stops_cleanly_on_exit(self):
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer(
+ "ResponsesAgent",
+ heartbeat_interval_seconds=0.05,
+ heartbeat_stale_threshold_seconds=1.0,
+ )
+
+ with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb:
+ async with server._heartbeat("resp_x", attempt_number=1):
+ pass # immediate exit
+
+ # Give the heartbeat loop a chance to observe the stop signal.
+ await asyncio.sleep(0.1)
+ writes_after_exit = mock_hb.await_count
+
+ await asyncio.sleep(0.15)
+ # No new writes after the scope closed.
+ assert mock_hb.await_count == writes_after_exit
+
+ @pytest.mark.asyncio
+ async def test_db_error_does_not_interrupt_body(self):
+ """Heartbeat failures are logged, not raised — the stale check catches
+ real death, so a transient write miss must not kill a live run."""
+ with patch(f"{MODULE}.is_db_configured", return_value=False):
+ server = LongRunningAgentServer(
+ "ResponsesAgent",
+ heartbeat_interval_seconds=0.05,
+ heartbeat_stale_threshold_seconds=1.0,
+ )
+
+ body_ran = False
+ with patch(
+ f"{MODULE}.heartbeat_response",
+ new_callable=AsyncMock,
+ side_effect=RuntimeError("db down"),
+ ):
+ async with server._heartbeat("resp_x", attempt_number=1):
+ await asyncio.sleep(0.1)
+ body_ran = True
+ assert body_ran
+
+
+class TestSettingsHeartbeatValidation:
+ def test_stale_must_exceed_interval(self):
+ with pytest.raises(ValueError, match="heartbeat_stale_threshold_seconds"):
+ LongRunningSettings(
+ heartbeat_interval_seconds=5.0,
+ heartbeat_stale_threshold_seconds=5.0,
+ )
+
+ def test_interval_must_be_positive(self):
+ with pytest.raises(ValueError, match="heartbeat_interval_seconds must be positive"):
+ LongRunningSettings(heartbeat_interval_seconds=0)
+
+ def test_defaults_match_chat_ux(self):
+ # 3s interval + 15s stale gives ~5 heartbeats before a pod is considered
+ # dead — snug enough to recover conversations within a user's
+ # "reconnecting..." patience window.
+ s = LongRunningSettings()
+ assert s.heartbeat_interval_seconds == 3.0
+ assert s.heartbeat_stale_threshold_seconds == 10.0
+
+
+class TestDebugKillTask:
+ """The opt-in debug-kill endpoint lets integration tests simulate a crash
+ against a deployed pod without restarting the whole app. Off by default
+ because exposing task cancellation bypasses the normal cleanup path."""
+
+ def test_endpoint_absent_by_default(self):
+ from starlette.testclient import TestClient
+
+ with patch(f"{MODULE}.is_db_configured", return_value=True):
+ server = LongRunningAgentServer("ResponsesAgent")
+ client = TestClient(server.app, raise_server_exceptions=False)
+ resp = client.post("/_debug/kill_task/resp_x")
+ assert resp.status_code == 404 # route not registered
+
+ def test_endpoint_registered_when_env_set(self, monkeypatch):
+ from starlette.testclient import TestClient
+
+ monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1")
+ with patch(f"{MODULE}.is_db_configured", return_value=True):
+ server = LongRunningAgentServer("ResponsesAgent")
+ client = TestClient(server.app, raise_server_exceptions=False)
+ # No in-flight task for this response_id on this pod → 404, not 405.
+ resp = client.post("/_debug/kill_task/resp_missing")
+ assert resp.status_code == 404
+ assert "No in-flight task" in resp.json()["detail"]
+
+ @pytest.mark.asyncio
+ async def test_cancels_tracked_task(self, monkeypatch):
+ """Direct-call variant: skip the TestClient (which is sync and blocks
+ the loop) and call the handler logic through _running_tasks directly.
+ Covers the important behavior: cancelling a tracked task propagates
+ CancelledError and the tracking dict is cleared by the done-callback.
+ """
+ monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1")
+ with patch(f"{MODULE}.is_db_configured", return_value=True):
+ server = LongRunningAgentServer("ResponsesAgent")
+
+ cancel_event = asyncio.Event()
+
+ async def long_running():
+ try:
+ await asyncio.sleep(60)
+ except asyncio.CancelledError:
+ cancel_event.set()
+ raise
+
+ task = asyncio.create_task(long_running())
+ server._track_task("resp_tracked", task)
+
+ # Yield once so the new task can start waiting on sleep(60).
+ await asyncio.sleep(0)
+ assert "resp_tracked" in server._running_tasks
+
+ task.cancel()
+ # Expect CancelledError from awaiting the task itself, and the cancel
+ # event set inside the except handler before the re-raise.
+ with pytest.raises(asyncio.CancelledError):
+ await task
+ assert cancel_event.is_set()
+ # done-callback (scheduled on loop) clears the registration after the
+ # task completes — give it one more tick.
+ await asyncio.sleep(0)
+ assert "resp_tracked" not in server._running_tasks