From 87537724535e2f6e3071f8d798a9dca473fe3337 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 28 Apr 2026 01:21:58 +0000 Subject: [PATCH 1/9] LongRunningAgentServer: durable resume via prose recovery + always-rotate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Alternative to PR #416 — same durable-execution capability (heartbeat, CAS claim, retrieve endpoints, conversation_id rotation) but with a prose-recovery resume mechanism instead of structured carry-forward. Forked from main, scoped tightly to what prose recovery actually requires. All defenses live in ONE place (the bridge), SDK-agnostic. What's in this PR (vs main) =========================== Durable execution infrastructure (mirrors #416): - long_running/db.py — durability column migrations - long_running/models.py — durability columns on Response / Message - long_running/repository.py — claim_stale_response (CAS), heartbeat, ResponseInfo, get_messages tagged with attempt_number - long_running/settings.py — heartbeat interval/threshold - long_running/server.py — heartbeat, _try_claim_and_resume, _rotate_conversation_id, _inject_conversation_id, retrieve endpoint with stream resume, /_debug/kill_task, [durable] lifecycle logs Resume mechanism (DIFFERS from #416): - _build_prose_recovery_message: walks prior attempt's events, returns one Responses-API user message containing a flat prose narrative ("Called f(...) and got: ...", "Called g(...) — interrupted before result", "Was generating: ..."). Replaces #416's structured walker (~600 LOC → ~80 LOC). - _try_claim_and_resume appends the prose message to original_request. input[] instead of carrying forward structured tool pairs + synthetic [INTERRUPTED] outputs. - response.resumed SSE sentinel includes the rotated conversation_id so cooperating clients can use the rotated session for subsequent turns. SDK-agnostic UI-echo dedup (replaces #416's per-SDK adapter wrappers): - _trim_echoed_history: if request.input contains an assistant message, the client is echoing prior conversation history → trim to the latest user message. The SDK's own session/checkpointer storage is the source of truth for prior turns; the echo is redundant. - Wired into _handle_invocations_request, runs on every fresh POST. Resume-built input has only role:user items so the trim is a no-op. - Single ~25 LOC function in the bridge replaces the per-SDK dedups templates used to carry (OpenAI session.get_items() comparison; LangGraph agent.aget_state() comparison). What's NOT in this PR (vs #416) ================================ - integrations/openai/.../session.py: identical to main. No AsyncDatabricksSession.get_items wrap. - integrations/langchain/.../checkpoint.py: identical to main. No _build_tool_resume_repair, no _repair_loaded_checkpoint_tuple, no aget_tuple/get_tuple overrides. - src/databricks_ai_bridge/tool_repair.py: not added. No sanitize_tool_items, no _sanitize_request_input, no auto_sanitize_input setting. Always-rotate makes the SDK wrappers unnecessary (rotated session is clean). The trim hook makes per-SDK dedup unnecessary (handles UI echo at the request boundary, before the handler runs). All dedup logic now lives in one SDK-agnostic place. The trade ========= | Axis | #416 (structured) | This PR (prose) | |---|---|---| | Files vs main | 12 | 7 | | Total LOC vs main | +2293 | +1709 | | Resume builder LOC | ~600 | ~80 | | SDK adapter wrappers | required (~200 LOC) | none | | Input sanitizer | yes (~150 LOC) | none | | Per-template dedup hooks | yes (per SDK) | none (SDK-agnostic in bridge) | | Resume input shape | structured pairs + [INTERRUPTED] | single prose user message | | Cache prefix on resume | ~95% hit | ~0% hit beyond [system] | | Cache prefix on subsequent turns | natural structured prefix | [system]+[prose] for conversation lifetime | | Multi-turn fix | read-time repair on aget_tuple/get_items | always-rotate redirect via chatbot alias | | New infra needed | none | chatbot alias capture from response.resumed | Trade-offs detailed in app-templates/durable-recovery-recommendation.md. Tests: 118 pass (91 long_running_server + 27 long_running_db). Status ====== POC for review against #416. Not intended to merge unless empirical data on cache cost and tool-use quality justifies the trade vs #416. Companion PR: app-templates#204. Co-authored-by: Isaac --- src/databricks_ai_bridge/long_running/db.py | 44 + .../long_running/models.py | 23 +- .../long_running/repository.py | 125 ++- .../long_running/server.py | 788 +++++++++++++++++- .../long_running/settings.py | 12 + .../test_long_running_db.py | 134 ++- .../test_long_running_server.py | 756 ++++++++++++++++- 7 files changed, 1806 insertions(+), 76 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 903d466f..aed2c903 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -79,7 +79,51 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}")) await conn.run_sync(Base.metadata.create_all) + # Idempotent migration for tables created by earlier versions: add any + # columns introduced for durable-resume support. Each statement runs in + # its own transaction so an InsufficientPrivilege on one ALTER (another + # pod's SP owns the table but the schema is already migrated) doesn't + # poison the rest. A single mega-transaction would abort entirely on the + # first owner-check failure even with IF NOT EXISTS. + migration_stmts = ( + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS original_request TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"CREATE INDEX IF NOT EXISTS idx_responses_stale " + f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " + "WHERE status = 'in_progress'", + ) + skipped_migrations: list[str] = [] + for stmt in migration_stmts: + try: + async with _engine.begin() as conn: + await conn.execute(text(stmt)) + except Exception as exc: + msg = str(exc).lower() + if "insufficientprivilege" in msg or "must be owner" in msg: + skipped_migrations.append(stmt.split("\n")[0]) + continue + raise + _initialized = True + if skipped_migrations: + # WARN-level summary: if the DB was previously migrated by another SP + # this is fine, but if it's genuinely a new table and our SP lacks + # ALTER, claim/heartbeat queries will fail later with a confusing + # "column does not exist" — surface it clearly at startup. + logger.warning( + "[DB] Skipped %d durability migration(s) due to insufficient " + "privilege — assuming table was already migrated by another " + "service principal. Crash-resume will fail with 'column does " + "not exist' if this assumption is wrong. Skipped: %s", + len(skipped_migrations), + ", ".join(skipped_migrations), + ) logger.info("[DB] Engine and schema ready") diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py index 1d876dc7..7014a7db 100644 --- a/src/databricks_ai_bridge/long_running/models.py +++ b/src/databricks_ai_bridge/long_running/models.py @@ -14,7 +14,12 @@ class Base(DeclarativeBase): class Response(Base): - """Response status tracking for background agent tasks.""" + """Response status tracking for background agent tasks. + + Durability columns (``owner_pod_id``, ``heartbeat_at``, ``attempt_number``, + ``original_request``) support crash-resume: another pod can atomically + claim a stale in-progress row and replay the agent loop. + """ __tablename__ = "responses" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -25,12 +30,23 @@ class Response(Base): DateTime(timezone=True), nullable=False, server_default=func.now() ) trace_id: Mapped[str | None] = mapped_column(Text, nullable=True) + owner_pod_id: Mapped[str | None] = mapped_column(Text, nullable=True) + heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) + original_request: Mapped[str | None] = mapped_column(Text, nullable=True) messages = relationship("Message", back_populates="response", cascade="all, delete-orphan") class Message(Base): - """Stream events and output items for a response.""" + """Stream events and output items for a response. + + ``attempt_number`` tags events by which run attempt emitted them so that + resumed runs append to the same event log without overwriting earlier + (abandoned) attempts, and retrieve can filter to the latest attempt only. + """ __tablename__ = "messages" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -44,6 +60,9 @@ class Message(Base): sequence_number: Mapped[int] = mapped_column( Integer, primary_key=True, nullable=False, default=0 ) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) item: Mapped[str | None] = mapped_column(Text, nullable=True) stream_event: Mapped[str | None] = mapped_column(Text, nullable=True) diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py index fcd86b29..06d30edb 100644 --- a/src/databricks_ai_bridge/long_running/repository.py +++ b/src/databricks_ai_bridge/long_running/repository.py @@ -5,15 +5,37 @@ from typing import Any, NamedTuple from sqlalchemy import select, update +from sqlalchemy.sql import bindparam, text from databricks_ai_bridge.long_running.db import session_scope -from databricks_ai_bridge.long_running.models import Message, Response +from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response -async def create_response(response_id: str, status: str) -> None: - """Insert a new response.""" +async def create_response( + response_id: str, + status: str, + *, + owner_pod_id: str | None = None, + original_request: dict[str, Any] | None = None, +) -> None: + """Insert a new response row. + + ``owner_pod_id`` and ``original_request`` are optional so that non-durable + callers (tests, legacy flows) can still create rows without durability + metadata. When present, they enable heartbeat + crash-resume semantics. + """ async with session_scope() as session: - session.add(Response(response_id=response_id, status=status)) + session.add( + Response( + response_id=response_id, + status=status, + owner_pod_id=owner_pod_id, + heartbeat_at=datetime.now().astimezone() if owner_pod_id else None, + original_request=( + json.dumps(original_request) if original_request is not None else None + ), + ) + ) await session.commit() @@ -43,18 +65,84 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None: await session.commit() +async def heartbeat_response(response_id: str, pod_id: str) -> bool: + """Update heartbeat_at for a response IFF this pod owns it. + + Returns True on success. A False result means the claim has been lost + (another pod took over, or the run finished and heartbeat should stop). + """ + async with session_scope() as session: + stmt = ( + update(Response) + .where(Response.response_id == response_id, Response.owner_pod_id == pod_id) + .values(heartbeat_at=datetime.now().astimezone()) + ) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 + + +async def claim_stale_response( + response_id: str, + new_owner_pod_id: str, + stale_threshold_seconds: float, +) -> int | None: + """Atomically claim an in-progress response whose heartbeat has gone stale. + + Uses a single conditional UPDATE so exactly one caller wins on contention: + claim only succeeds if status is ``in_progress`` AND + (``owner_pod_id IS NULL`` OR ``heartbeat_at`` is older than the threshold). + + Returns the new ``attempt_number`` on success, or ``None`` if the row did + not satisfy the claim conditions (already completed, already claimed by a + live pod, or nonexistent). + """ + # Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for + # the incremented column as ergonomically. Using a single statement keeps the + # claim atomic without an explicit transaction-level lock. + stmt = text( + f""" + UPDATE {AGENT_DB_SCHEMA}.responses + SET owner_pod_id = :pod, + heartbeat_at = now(), + attempt_number = attempt_number + 1 + WHERE response_id = :rid + AND status = 'in_progress' + AND (owner_pod_id IS NULL + OR heartbeat_at IS NULL + OR heartbeat_at < now() - make_interval(secs => :threshold)) + RETURNING attempt_number + """ + ).bindparams( + bindparam("pod", type_=None), + bindparam("rid", type_=None), + bindparam("threshold", type_=None), + ) + async with session_scope() as session: + result = await session.execute( + stmt, + {"pod": new_owner_pod_id, "rid": response_id, "threshold": stale_threshold_seconds}, + ) + row = result.first() + await session.commit() + return int(row[0]) if row else None + + async def append_message( response_id: str, sequence_number: int, item: str | None = None, stream_event: dict[str, Any] | None = None, + *, + attempt_number: int = 1, ) -> None: - """Append a message (stream event) for a response.""" + """Append a message (stream event) for a response, tagged with attempt_number.""" async with session_scope() as session: session.add( Message( response_id=response_id, sequence_number=sequence_number, + attempt_number=attempt_number, item=item, stream_event=json.dumps(stream_event) if stream_event is not None else None, ) @@ -65,22 +153,26 @@ async def append_message( async def get_messages( response_id: str, after_sequence: int | None = None, -) -> list[tuple[int, str | None, dict[str, Any] | None]]: - """Fetch messages for a response, optionally after a sequence number. + *, + attempt_number: int | None = None, +) -> list[tuple[int, str | None, dict[str, Any] | None, int]]: + """Fetch messages for a response, optionally filtering by sequence / attempt. - Returns list of (sequence_number, item, stream_event_dict). + Returns list of ``(sequence_number, item, stream_event_dict, attempt_number)``. """ async with session_scope() as session: stmt = select(Message).where(Message.response_id == response_id) if after_sequence is not None: stmt = stmt.where(Message.sequence_number > after_sequence) + if attempt_number is not None: + stmt = stmt.where(Message.attempt_number == attempt_number) stmt = stmt.order_by(Message.sequence_number) result = await session.execute(stmt) rows = result.scalars().all() out = [] for r in rows: evt = json.loads(r.stream_event) if r.stream_event else None - out.append((r.sequence_number, r.item, evt)) + out.append((r.sequence_number, r.item, evt, r.attempt_number)) return out @@ -89,6 +181,10 @@ class ResponseInfo(NamedTuple): status: str created_at: datetime trace_id: str | None + owner_pod_id: str | None + heartbeat_at: datetime | None + attempt_number: int + original_request: dict[str, Any] | None async def get_response(response_id: str) -> ResponseInfo | None: @@ -97,5 +193,14 @@ async def get_response(response_id: str) -> ResponseInfo | None: result = await session.execute(select(Response).where(Response.response_id == response_id)) row = result.scalar_one_or_none() if row: - return ResponseInfo(row.response_id, row.status, row.created_at, row.trace_id) + return ResponseInfo( + row.response_id, + row.status, + row.created_at, + row.trace_id, + row.owner_pod_id, + row.heartbeat_at, + row.attempt_number, + json.loads(row.original_request) if row.original_request else None, + ) return None diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index b3374d67..4672c365 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -6,9 +6,12 @@ raise RuntimeError("The long_running module requires Python 3.11 or later.") import asyncio +import copy import inspect import json import logging +import os +import socket import time import uuid from collections.abc import AsyncGenerator @@ -34,9 +37,11 @@ from databricks_ai_bridge.long_running.db import dispose_db, init_db, is_db_configured from databricks_ai_bridge.long_running.repository import ( append_message, + claim_stale_response, create_response, get_messages, get_response, + heartbeat_response, update_response_status, update_response_trace_id, ) @@ -47,6 +52,9 @@ BACKGROUND_KEY = "background" +# One ID per process so heartbeats + claims have a stable owner identity. +_POD_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" + async def _deferred_mark_failed( response_id: str, delay: float = 2.0, reason: str = "Task timed out" @@ -65,7 +73,8 @@ async def _deferred_mark_failed( # or SELECT FOR UPDATE on the response row to serialise writers. async with asyncio.timeout(delay): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) error_event = { "type": "error", @@ -75,7 +84,13 @@ async def _deferred_mark_failed( "code": "task_timeout", }, } - await append_message(response_id, next_seq, item=None, stream_event=error_event) + await append_message( + response_id, + next_seq, + item=None, + stream_event=error_event, + attempt_number=attempt, + ) await update_response_status(response_id, "failed") logger.info("Marked %s as failed (reason: %s)", response_id, reason) @@ -91,10 +106,21 @@ async def _deferred_mark_failed( ) +async def _current_attempt(response_id: str) -> int: + """Fetch the current attempt_number for a response, defaulting to 1.""" + resp = await get_response(response_id) + return resp.attempt_number if resp else 1 + + def _sse_event(event_type: str, data: dict[str, Any] | str) -> str: - """Format an SSE event per Open Responses spec.""" + """Emit ``data:``-only SSE frames. Match the non-durable stream format + so downstream SSE parsers dispatch on the payload's ``type`` field + rather than a leading ``event:`` name line. Claude's multi-response + stream (one response.created/completed pair per tool iteration) plus + the event-name prefix confuses the AI SDK's Databricks provider into + a retry loop.""" payload = data if isinstance(data, str) else json.dumps(data) - return f"event: {event_type}\ndata: {payload}\n\n" + return f"data: {payload}\n\n" def _age_seconds(created_at: datetime) -> float: @@ -105,9 +131,256 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() +def _extract_text_from_message(item: dict) -> str: + """Pull text out of a Responses-API message item's content array.""" + content = item.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for c in content: + if isinstance(c, dict): + t = c.get("text") + if isinstance(t, str): + parts.append(t) + return "".join(parts) + return "" + + +def _build_prose_recovery_message( + messages: list[tuple], prior_attempt_number: int +) -> dict[str, Any]: + """Narrate the prior attempt's events as a single user message. + + Walks the repository's ``(seq, item, stream_event, attempt)`` tuples + for the given prior attempt and produces a single Responses-API user + message item whose content is a flat prose summary of completed tool + calls, their outputs, narrative messages, and any partial in-flight + assistant text. + + This replaces the structured carry-forward of completed tool pairs + + synthetic ``[INTERRUPTED]`` outputs with a single recovery prompt the + LLM reads as: "the prior attempt crashed, here's what had happened, + please continue." Trades structural protocol fidelity for SDK-agnostic + simplicity — no provider-specific pairing rules, no synthetic events. + """ + completed_calls: dict[str, dict[str, Any]] = {} + call_order: list[str] = [] + narrative_texts: list[str] = [] + in_progress_text: dict[str, list[str]] = {} + in_progress_order: list[str] = [] + + for _seq, _item_json, evt, attempt_tag in messages: + if attempt_tag != prior_attempt_number or not isinstance(evt, dict): + continue + t = evt.get("type") + item = evt.get("item") + + if t == "response.output_item.added" and isinstance(item, dict): + if item.get("type") == "message" and (iid := item.get("id")): + in_progress_text.setdefault(iid, []) + if iid not in in_progress_order: + in_progress_order.append(iid) + + elif t == "response.output_item.done" and isinstance(item, dict): + itype = item.get("type") + if itype == "function_call": + cid = item.get("call_id") + if cid: + slot = completed_calls.setdefault(cid, {}) + slot["name"] = item.get("name") + slot["args"] = item.get("arguments") + if cid not in call_order: + call_order.append(cid) + elif itype == "function_call_output": + cid = item.get("call_id") + if cid: + slot = completed_calls.setdefault(cid, {}) + slot["output"] = item.get("output") + if cid not in call_order: + call_order.append(cid) + elif itype == "message": + text = _extract_text_from_message(item) + if text: + narrative_texts.append(text) + # The done event makes prior added/delta tracking moot. + iid = item.get("id") + if iid in in_progress_text: + in_progress_text.pop(iid, None) + if iid in in_progress_order: + in_progress_order.remove(iid) + + elif t == "response.output_text.delta": + iid = evt.get("item_id") + delta = evt.get("delta") + if iid and isinstance(delta, str) and iid in in_progress_text: + in_progress_text[iid].append(delta) + + lines: list[str] = [] + for cid in call_order: + info = completed_calls[cid] + name = info.get("name", "") + args = info.get("args", "") + if "output" in info: + lines.append(f"- Called `{name}({args})` and got result: {info['output']}") + else: + lines.append(f"- Called `{name}({args})` — interrupted before a result was returned.") + for text in narrative_texts: + lines.append(f"- Said: {text}") + for iid in in_progress_order: + chunks = in_progress_text.get(iid) or [] + partial = "".join(chunks).strip() + if partial: + lines.append(f"- Was generating: {partial!r} when the run was interrupted.") + + if lines: + body = ( + "[RECOVERY] The previous attempt of this agent task crashed mid-execution. " + "Here is what had completed before the crash:\n\n" + + "\n".join(lines) + + "\n\nPlease continue from where the prior attempt left off. " + "If a tool call was interrupted, you may re-invoke it if its result " + "is still needed." + ) + else: + body = ( + "[RECOVERY] The previous attempt of this agent task was interrupted " + "before any tool calls or assistant output completed. Please proceed " + "with the original user request." + ) + + return { + "type": "message", + "role": "user", + "content": body, + } + + +def _trim_echoed_history(items: list[Any]) -> list[Any]: + """If request.input contains an assistant message, the client is echoing + prior conversation history — trust the SDK's own session/checkpointer + storage as authoritative for prior turns and forward only the latest + user message. + + SDK-agnostic equivalent of the per-SDK dedup hooks templates used to + carry (OpenAI Session.get_items() comparison; LangGraph + agent.aget_state() comparison). The presence of any ``role:assistant`` + item is a reliable proxy for "this request is a continuation echo from + the chat UI" — first-turn POSTs have a single user message, while + continuations include the prior turns' assistant replies. + + Resume-path inputs built by ``_try_claim_and_resume`` are + [original_input + prose_msg] (both ``role:user``), so the trim is a + correct no-op there. + """ + if not items: + return items + has_assistant = any( + isinstance(item, dict) and item.get("role") == "assistant" for item in items + ) + if not has_assistant: + return items + user_idxs = [ + i for i, item in enumerate(items) if isinstance(item, dict) and item.get("role") == "user" + ] + if len(user_idxs) <= 1: + return items + trimmed = items[user_idxs[-1] :] + logger.info( + "[durable] trimmed echoed history: original=%d final=%d (kept latest user turn)", + len(items), + len(trimmed), + ) + return trimmed + + +def _rotate_conversation_id( + request_dict: dict[str, Any], + new_attempt_number: int, + response_id: str, +) -> dict[str, Any]: + """Rotate the conversation anchor to a per-attempt value. + + After a crash, attempt N+1 should see a FRESH checkpointer / session so it + doesn't inherit mid-turn state that the SDK can't repair cleanly (most + notably the LangGraph stream-event attempt-boundary orphan artifact). + The handler's priority chain is: + + 1. custom_inputs.thread_id / session_id (explicit, wins) + 2. context.conversation_id (fallback) + 3. auto-generated (last resort) + + We drop (1), pick the current base anchor, and write ``{base}::attempt-N`` + into (2). The handler then resolves to a fresh key for this attempt while + still being deterministic across retries of the same attempt. + + The LLM sees full turn history via ``original_request.input``, which was + captured at the initial POST — before any attempt ran, so it's clean by + construction. + """ + custom_inputs = request_dict.get("custom_inputs") + if not isinstance(custom_inputs, dict): + custom_inputs = {} + + base_anchor = ( + custom_inputs.get("thread_id") + or custom_inputs.get("session_id") + or (request_dict.get("context") or {}).get("conversation_id") + or response_id + ) + + custom_inputs.pop("thread_id", None) + custom_inputs.pop("session_id", None) + request_dict["custom_inputs"] = custom_inputs + + ctx = request_dict.get("context") or {} + ctx = dict(ctx) + rotated = f"{base_anchor}::attempt-{new_attempt_number}" + ctx["conversation_id"] = rotated + request_dict["context"] = ctx + logger.info( + "[durable] rotated conversation_id for resume response_id=%s attempt=%d base=%s rotated=%s", + response_id, + new_attempt_number, + base_anchor, + rotated, + ) + return request_dict + + +def _inject_conversation_id(request_dict: dict[str, Any], response_id: str) -> dict[str, Any]: + """Anchor the request to ``response_id`` as its conversation. + + Operates on a plain dict — the caller is responsible for converting to/from + pydantic via ``model_dump()`` and the server's validator. + + Templates that back this server use ``context.conversation_id`` (and + ``custom_inputs.thread_id`` / ``custom_inputs.session_id``) as priority-2 + fallbacks to derive their stateful thread/session key. If neither is + provided by the client, a resumed invocation from another pod would + generate a *fresh* ID and miss the checkpoint entirely — so we stamp the + conversation_id here before persisting the request, guaranteeing that + every replay hits the same memory store. + + Client-supplied values take precedence and are left untouched. + """ + out = copy.deepcopy(request_dict) if request_dict else {} + custom_inputs = out.get("custom_inputs") or {} + if custom_inputs.get("thread_id") or custom_inputs.get("session_id"): + return out + ctx = out.get("context") or {} + if ctx.get("conversation_id"): + return out + ctx = dict(ctx) + ctx["conversation_id"] = response_id + out["context"] = ctx + return out + + @experimental class LongRunningAgentServer(AgentServer): - """AgentServer subclass adding background mode and retrieve endpoints. + """AgentServer subclass adding background mode, retrieve endpoints, and + durable resume. Only compatible with ``ResponsesAgent`` mode. @@ -125,6 +398,16 @@ class LongRunningAgentServer(AgentServer): ``LAKEBASE_INSTANCE_NAME``, ``LAKEBASE_AUTOSCALING_ENDPOINT``, or both ``LAKEBASE_AUTOSCALING_PROJECT`` and ``LAKEBASE_AUTOSCALING_BRANCH``. + Durable resume: when ``GET /responses/{id}`` sees an ``in_progress`` run + whose owning pod has stopped heartbeating for more than + ``heartbeat_stale_threshold_seconds``, the retrieving pod atomically claims + the run and re-invokes the registered handler with a rotated + ``conversation_id`` (so the agent SDK resolves to a fresh thread/session), + the original request's ``input`` enriched with the prior attempt's already + emitted tool calls / outputs / narrative, and an ``[INTERRUPTED]`` synthetic + output paired with any tool call that didn't finish. Completed work is + preserved; only the interrupted step re-runs. + Args: enable_chat_proxy: Whether to enable the chat proxy endpoint. db_instance_name: Lakebase provisioned instance name. Overrides @@ -143,6 +426,12 @@ class LongRunningAgentServer(AgentServer): Defaults to 5000 (5 seconds). cleanup_timeout_seconds: Timeout for DB cleanup after task failure. Defaults to 7.0. + heartbeat_interval_seconds: How often the owning pod writes + ``heartbeat_at`` while a run is in flight. Defaults to 3.0. + heartbeat_stale_threshold_seconds: Age at which a heartbeat is + considered stale and another pod may claim the run. Also used + as the grace window for a freshly-created run that hasn't + written its first heartbeat yet. Defaults to 10.0. """ _SUPPORTED_AGENT_TYPE = "ResponsesAgent" @@ -162,6 +451,8 @@ def __init__( poll_interval_seconds: float = 1.0, db_statement_timeout_ms: int = 5000, cleanup_timeout_seconds: float = 7.0, + heartbeat_interval_seconds: float = 3.0, + heartbeat_stale_threshold_seconds: float = 10.0, ): if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( @@ -173,11 +464,18 @@ def __init__( poll_interval_seconds=poll_interval_seconds, db_statement_timeout_ms=db_statement_timeout_ms, cleanup_timeout_seconds=cleanup_timeout_seconds, + heartbeat_interval_seconds=heartbeat_interval_seconds, + heartbeat_stale_threshold_seconds=heartbeat_stale_threshold_seconds, ) self._db_instance_name = db_instance_name self._db_autoscaling_endpoint = db_autoscaling_endpoint self._db_project = db_project self._db_branch = db_branch + # Track in-flight background tasks per response_id so the debug-kill + # endpoint can simulate a pod crash without tearing the whole pod + # down. Not load-bearing for correctness — durability still relies on + # DB state, this is just a test affordance. + self._running_tasks: dict[str, asyncio.Task] = {} super().__init__(agent_type, enable_chat_proxy=enable_chat_proxy) def _setup_routes(self) -> None: @@ -195,6 +493,41 @@ async def cancel_endpoint(response_id: str): detail="Cancellation is not yet implemented.", ) + # Debug endpoint for testing durable resume: cancels the in-flight + # asyncio task that owns the given response_id WITHOUT running the + # _task_scope cleanup, so the DB row stays in_progress with a + # going-stale heartbeat — exactly the shape a real pod crash leaves. + # Opt-in via env var so it's never exposed in production. + if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") == "1": + + @self.app.post("/_debug/kill_task/{response_id}") + async def _debug_kill_task(response_id: str): + task = self._running_tasks.get(response_id) + if task is None: + logger.info( + "[durable] kill endpoint: no task response_id=%s on pod=%s", + response_id, + _POD_ID, + ) + raise HTTPException( + status_code=404, + detail=( + "No in-flight task for that response_id on this pod " + "(may already have finished or be running on another pod)." + ), + ) + logger.info( + "[durable] kill endpoint: cancelling task response_id=%s pod=%s", + response_id, + _POD_ID, + ) + task.cancel() + return { + "response_id": response_id, + "pod_id": _POD_ID, + "status": "task_cancelled", + } + db_configured = is_db_configured() @self.app.get("/responses/{response_id}") @@ -265,6 +598,17 @@ async def _handle_invocations_request( data = {k: v for k, v in data.items() if k not in (BACKGROUND_KEY, MLFLOW_STREAM_KEY)} return_trace_id = (get_request_headers().get(RETURN_TRACE_HEADER) or "").lower() == "true" + # For background+DB requests, trim happens INSIDE _handle_background_request + # AFTER storing the untrimmed input as original_request — so resume can + # recover the full prior-turn history. For non-background requests we + # trim here in place since there's no persistence step that needs the + # untrimmed copy. + is_background_with_db = is_background and is_db_configured() + if not is_background_with_db: + items = data.get("input") + if isinstance(items, list): + data["input"] = _trim_echoed_history(items) + try: request_data = self.validator.validate_and_convert_request(data) except ValueError as e: @@ -273,7 +617,7 @@ async def _handle_invocations_request( detail=f"Invalid parameters for {self.agent_type}: {e}", ) from None - if is_background and is_db_configured(): + if is_background_with_db: return await self._handle_background_request( request_data, is_streaming, return_trace_id ) @@ -290,11 +634,43 @@ async def _handle_background_request( ) -> dict[str, Any] | StreamingResponse: """Start a new conversation and return response_id immediately.""" response_id = f"resp_{uuid.uuid4().hex[:24]}" - await create_response(response_id, "in_progress") + # Anchor the conversation to response_id so any future replay from a + # different pod resolves to the same agent-SDK thread/session. We + # round-trip through dict + validator so the handler still receives a + # pydantic ResponsesAgentRequest (its declared arg type). The + # declared param type is ``dict`` but the runtime object is a pydantic + # model from ``validate_and_convert_request``; fall back to ``dict()`` + # when tests pass a plain dict directly. + dump = getattr(request_data, "model_dump", None) + request_dict = dump() if callable(dump) else dict(request_data) + durable_dict = _inject_conversation_id(request_dict, response_id) + # Store the FULL request (untrimmed) as `original_request` so resume can + # recover the entire prior-turn history. The handler invocation below + # uses a trimmed copy to avoid duplicating turns the SDK's session has + # already persisted, but on resume the rotated SDK session is empty — + # only the full conversation in `original_request.input` lets the model + # reconstruct what came before the crashed turn. + await create_response( + response_id, + "in_progress", + owner_pod_id=_POD_ID, + original_request=durable_dict, + ) - logger.debug( - "Background response created", - extra={"response_id": response_id, "stream": is_streaming}, + # Build a TRIMMED handler request from the same durable dict — drops + # echoed history that the SDK's session already has. Original_request + # above stays untrimmed for resume. + handler_dict = copy.deepcopy(durable_dict) + handler_items = handler_dict.get("input") + if isinstance(handler_items, list): + handler_dict["input"] = _trim_echoed_history(handler_items) + durable_request = self.validator.validate_and_convert_request(handler_dict) + + logger.info( + "Background response created response_id=%s stream=%s pod=%s", + response_id, + is_streaming, + _POD_ID, ) response_obj: dict[str, Any] = { @@ -309,21 +685,99 @@ async def _handle_background_request( } # Fire-and-forget is intentional — task status is persisted to the database. + # We still track the task handle so the debug-kill endpoint can simulate + # a crash (and so we know whether a claim target lives on this pod). if is_streaming: - asyncio.create_task( - self._run_background_stream(response_id, request_data, return_trace_id) + task = asyncio.create_task( + self._run_background_stream( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) + self._track_task(response_id, task) return await self._handle_retrieve_request( response_id, stream=True, starting_after=0, ) else: - asyncio.create_task( - self._run_background_invoke(response_id, request_data, return_trace_id) + task = asyncio.create_task( + self._run_background_invoke( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) + self._track_task(response_id, task) return response_obj + def _track_task(self, response_id: str, task: asyncio.Task) -> None: + """Record a background task so the debug-kill endpoint can find it.""" + self._running_tasks[response_id] = task + task.add_done_callback(lambda _t: self._running_tasks.pop(response_id, None)) + + @asynccontextmanager + async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: + """Keep the response row's heartbeat_at fresh while the body runs. + + A background task writes ``heartbeat_at = now()`` every + ``heartbeat_interval_seconds`` for the owning pod. It stops when the + body returns/raises. Heartbeat write failures are logged but do not + interrupt the agent run — the stale-run check will detect a dead pod. + """ + interval = self._settings.heartbeat_interval_seconds + stop = asyncio.Event() + + async def _beat(): + beats = 0 + logger.info( + "[durable] heartbeat start response_id=%s pod=%s interval=%.1fs", + response_id, + _POD_ID, + interval, + ) + try: + while not stop.is_set(): + try: + await heartbeat_response(response_id, _POD_ID) + beats += 1 + # Sampled heartbeat log so the lifecycle is visible + # without spamming every interval. Every 5th (~15s + # at 3s interval) is a good compromise. + if beats % 5 == 1: + logger.info( + "[durable] heartbeat beat#%d response_id=%s pod=%s", + beats, + response_id, + _POD_ID, + ) + except Exception: + logger.warning( + "[durable] heartbeat write failed response_id=%s; will retry", + response_id, + exc_info=True, + ) + try: + await asyncio.wait_for(stop.wait(), timeout=interval) + except TimeoutError: + pass + except asyncio.CancelledError: + pass + logger.info( + "[durable] heartbeat stop response_id=%s pod=%s total_beats=%d", + response_id, + _POD_ID, + beats, + ) + + hb_task = asyncio.create_task(_beat(), name=f"heartbeat-{response_id}") + try: + yield + finally: + stop.set() + hb_task.cancel() + try: + await hb_task + except (asyncio.CancelledError, Exception): + pass + @asynccontextmanager async def _task_scope( self, response_id: str, state: dict[str, Any] @@ -348,7 +802,8 @@ async def _task_scope( # TODO: sequence number computation is racy (see _deferred_mark_failed). async with asyncio.timeout(self._settings.cleanup_timeout_seconds): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -361,6 +816,7 @@ async def _task_scope( "code": "task_failed", }, }, + attempt_number=attempt, ) await update_response_status(response_id, "failed") except Exception: @@ -382,11 +838,19 @@ async def _run_background_stream( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the streaming agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_stream(response_id, request_data, return_trace_id, state) + async with self._task_scope(response_id, state), self._heartbeat(response_id): + await self._do_background_stream( + response_id, + request_data, + return_trace_id, + state, + attempt_number=attempt_number, + ) def transform_stream_event(self, event: dict, response_id: str) -> dict: """Override to transform events before persistence (e.g. replace placeholder IDs).""" @@ -398,6 +862,8 @@ async def _do_background_stream( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via stream_fn, persist each stream event as a message row.""" stream_fn = get_stream_function() @@ -406,8 +872,23 @@ async def _do_background_stream( raise RuntimeError("No stream function registered; cannot run background stream") func_name = stream_fn.__name__ + logger.info( + "[durable] background stream start response_id=%s attempt=%d pod=%s handler=%s", + response_id, + attempt_number, + _POD_ID, + func_name, + ) all_chunks: list[dict[str, Any]] = [] - seq = 0 + # Continue sequence numbering across attempts so the client's cursor + # never rewinds on resume. First attempt starts at 0 and skips the DB + # lookup — keeps the fast path identical to pre-resume behavior and + # avoids an extra query per background request. + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + seq = 0 with mlflow.start_span(name=func_name) as span: span.set_inputs(request_data) @@ -420,16 +901,27 @@ async def _do_background_stream( evt_type = evt.get("type", "message") logger.debug( "SSE event (background)", - extra={"response_id": response_id, "seq": seq, "type": evt_type}, + extra={ + "response_id": response_id, + "seq": seq, + "type": evt_type, + "attempt": attempt_number, + }, ) await append_message( response_id, seq, item=json.dumps(item) if item is not None else None, stream_event=evt, + attempt_number=attempt_number, ) seq += 1 state["seq"] = seq + # Explicit yield so task.cancel() propagates promptly on + # tight event streams. The OpenAI Agents Runner's + # stream_events() awaits a queue that empties fast enough + # that cancellation can sit for tens of seconds without this. + await asyncio.sleep(0) span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "openai") span.set_outputs(ResponsesAgent.responses_agent_output_reducer(all_chunks)) @@ -439,12 +931,17 @@ async def _do_background_stream( response_id, seq, stream_event={"trace_id": span.trace_id}, + attempt_number=attempt_number, ) await update_response_status(response_id, "completed") - logger.debug( - "Background stream completed", - extra={"response_id": response_id, "total_events": seq}, + logger.info( + "[durable] background stream completed response_id=%s attempt=%d " + "total_events=%d pod=%s", + response_id, + attempt_number, + seq, + _POD_ID, ) async def _run_background_invoke( @@ -452,11 +949,19 @@ async def _run_background_invoke( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the invoke agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_invoke(response_id, request_data, return_trace_id, state) + async with self._task_scope(response_id, state), self._heartbeat(response_id): + await self._do_background_invoke( + response_id, + request_data, + return_trace_id, + state, + attempt_number=attempt_number, + ) async def _do_background_invoke( self, @@ -464,6 +969,8 @@ async def _do_background_invoke( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via invoke_fn, persist each output item as a message row.""" invoke_fn = get_invoke_function() @@ -485,19 +992,27 @@ async def _do_background_invoke( span.set_outputs(result) output = result.get("output", []) + # Continue sequence numbering across attempts (see _do_background_stream). + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + base_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + base_seq = 0 for i, item in enumerate(output): item_dict = ( item if isinstance(item, dict) else (item.model_dump() if hasattr(item, "model_dump") else {"content": str(item)}) ) + seq = base_seq + i await append_message( response_id, - i, + seq, item=json.dumps(item_dict), stream_event={"type": "response.output_item.done", "item": item_dict}, + attempt_number=attempt_number, ) - state["seq"] = i + 1 + state["seq"] = seq + 1 if return_trace_id: await update_response_trace_id(response_id, span.trace_id) await update_response_status(response_id, "completed") @@ -506,6 +1021,153 @@ async def _do_background_invoke( extra={"response_id": response_id, "output_items": len(output)}, ) + async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: + """If ``resp`` is a stale in-progress run, attempt an atomic claim. + + On success, kick off a new background task that re-invokes the handler + on a rotated conversation anchor with the replayed input enriched by + the prior attempt's emitted items, and returns the new + ``attempt_number``. On failure (another pod won, or the run is no + longer stale), returns ``None``. + + This is the lazy resume path: triggered by a client retrieve. Pods + don't poll for stale work proactively in v1 — if no client ever calls + ``GET /responses/{id}``, the task_timeout sweep eventually marks it + failed. + """ + if resp.status != "in_progress": + return None + # The run may be freshly started but too young to have a heartbeat yet; + # respect the creation age as a grace period equal to the stale + # threshold. Otherwise a quick follow-up retrieve could hijack a + # running pod before it ever writes its first heartbeat. + if resp.heartbeat_at is None: + age = _age_seconds(resp.created_at) + if age < self._settings.heartbeat_stale_threshold_seconds: + logger.debug( + "[durable] claim skipped response_id=%s reason=grace_period " + "age=%.1fs threshold=%.1fs", + response_id, + age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + else: + hb_age = _age_seconds(resp.heartbeat_at) + if hb_age < self._settings.heartbeat_stale_threshold_seconds: + # Heartbeat is fresh — owner is alive. Common case, keep + # quiet at debug so we don't spam every poll iteration. + logger.debug( + "[durable] claim skipped response_id=%s reason=heartbeat_fresh " + "age=%.1fs threshold=%.1fs", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + logger.info( + "[durable] stale heartbeat detected response_id=%s " + "heartbeat_age=%.1fs threshold=%.1fs current_owner=%s", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + resp.owner_pod_id, + ) + if resp.original_request is None: + # Nothing to replay from — the run predates durability metadata. + logger.warning( + "[durable] cannot resume response_id=%s reason=no_original_request", + response_id, + ) + return None + + logger.info( + "[durable] attempting claim response_id=%s current_attempt=%d new_owner=%s", + response_id, + resp.attempt_number, + _POD_ID, + ) + new_attempt = await claim_stale_response( + response_id, + new_owner_pod_id=_POD_ID, + stale_threshold_seconds=self._settings.heartbeat_stale_threshold_seconds, + ) + if new_attempt is None: + # Someone else owns it, or the row was updated between the read and + # the claim. Expected under contention. + logger.info( + "[durable] claim lost response_id=%s (another pod won or row changed)", + response_id, + ) + return None + + # Build a "resume" request by REPLAYING the original POST's input on a + # ROTATED conversation anchor, plus a single prose user message that + # narrates the prior attempt's completed tool calls / outputs / narrative. + # + # Always-rotate + prose recovery design: + # 1. Rotation makes the handler's SDK helpers resolve to a FRESH + # thread_id / session_id, so the rotated session starts empty and + # cannot inherit orphan-poisoned mid-turn state from the crashed + # attempt. Subsequent turns from the client should also use the + # rotated anchor (templates return it via custom_outputs); the + # original session becomes orphaned permanently and is never read. + # 2. The prose user message is the single source of truth for what + # already ran. The LLM reads it as a recovery instruction and + # continues. No structural carry-forward, no synthetic outputs, + # no per-SDK adapter wrappers needed. + existing = await get_messages(response_id, after_sequence=None) + next_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + prose_msg = _build_prose_recovery_message(existing, prior_attempt_number=new_attempt - 1) + + resume_dict = copy.deepcopy(resp.original_request) + resume_input = list(resume_dict.get("input") or []) + resume_input.append(prose_msg) + resume_dict["input"] = resume_input + logger.info( + "[durable] resume built prose recovery message for attempt %d response_id=%s", + new_attempt - 1, + response_id, + ) + resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id) + resume_request = self.validator.validate_and_convert_request(resume_dict) + # Surface the rotated conversation_id in the sentinel so clients that + # cache `chat_id → conversation_id` can pick up the rotation and use + # the rotated session on subsequent turns. Without this the next turn + # lands on the original (orphan-poisoned) session. + rotated_conv_id = (resume_dict.get("context") or {}).get("conversation_id") + await append_message( + response_id, + next_seq, + stream_event={ + "type": "response.resumed", + "attempt": new_attempt, + "from_seq": next_seq, + "conversation_id": rotated_conv_id, + }, + attempt_number=new_attempt, + ) + + logger.info( + "[durable] claim succeeded response_id=%s new_attempt=%d pod=%s resume_from_seq=%d", + response_id, + new_attempt, + _POD_ID, + next_seq, + ) + + task = asyncio.create_task( + self._run_background_stream( + response_id, + resume_request, + return_trace_id=False, + attempt_number=new_attempt, + ), + name=f"resume-{response_id}-{new_attempt}", + ) + self._track_task(response_id, task) + return new_attempt + async def _handle_retrieve_request( self, response_id: str, @@ -523,7 +1185,20 @@ async def _handle_retrieve_request( if resp is None: raise HTTPException(status_code=404, detail="Response not found") - _, status, created_at, trace_id = resp + # Try a lazy resume before falling back to the absolute-timeout sweep. + # This gives us crash-recovery semantics: an idle client reconnecting + # after a pod died will reclaim the run and resume it here instead of + # just marking it failed. + await self._try_claim_and_resume(response_id, resp) + + # Refresh after the potential resume: status / attempt_number may have changed. + resp = await get_response(response_id) + if resp is None: + raise HTTPException(status_code=404, detail="Response not found") + + status = resp.status + created_at = resp.created_at + trace_id = resp.trace_id if ( status == "in_progress" @@ -542,10 +1217,9 @@ async def _handle_retrieve_request( }, ) # TODO: sequence number computation here is racy under concurrent writers. - # Acceptable at current scale; for high-QPS use a DB-assigned sequence or - # SELECT FOR UPDATE on the response row to serialise writers. existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -558,6 +1232,7 @@ async def _handle_retrieve_request( "code": "task_timeout", }, }, + attempt_number=attempt, ) status = "failed" @@ -579,25 +1254,45 @@ async def _handle_retrieve_request( messages = await get_messages(response_id, after_sequence=None) if not messages and status == "in_progress": - return {"id": response_id, "status": "in_progress"} + return { + "id": response_id, + "status": "in_progress", + "attempt_number": resp.attempt_number, + } if status == "completed" and messages: + # Only consider items from the final (successful) attempt so that + # abandoned in-progress items from crashed attempts don't leak + # into the authoritative response body. Completed output_item.done + # events across attempts together make up the conversation — the + # agent SDK's checkpointer guarantees done-items are not re-emitted + # by later attempts, so this is a union with no duplicates. output = [] - for _, _, evt in messages: - if evt and "item" in evt: - output.append(evt["item"]) + for _, _, evt, _attempt in messages: + if evt and evt.get("type") == "response.output_item.done": + output.append(evt.get("item")) result: dict[str, Any] = { "id": response_id, "status": "completed", - "output": output, + "output": [o for o in output if o is not None], + "attempt_number": resp.attempt_number, } if trace_id: result["metadata"] = {"trace_id": trace_id} return result if status == "failed" and messages: - for _, _, evt in messages: + for _, _, evt, _attempt in messages: if evt and evt.get("type") == "error": - return {"id": response_id, "status": "failed", "error": evt.get("error")} - return {"id": response_id, "status": status} + return { + "id": response_id, + "status": "failed", + "error": evt.get("error"), + "attempt_number": resp.attempt_number, + } + return { + "id": response_id, + "status": status, + "attempt_number": resp.attempt_number, + } async def _stream_retrieve( self, @@ -638,15 +1333,26 @@ async def _stream_retrieve( ) break - _, status, _, _ = resp + status = resp.status + # Self-heal: if this response is still in_progress but its owning + # pod has gone silent past heartbeat_stale_threshold, try to claim + # + resume on this pod. A no-op if heartbeat is fresh or another + # pod already won. Without this, a stream opened before the crash + # would idle forever polling a dead run — since _try_claim_and_resume + # is only triggered by the outer retrieve handler on fresh GETs. + if status == "in_progress": + await self._try_claim_and_resume(response_id, resp) + # starting_after=0 fetches all messages (sequence numbers start at 0). # We use after_sequence=-1 for the DB query so that seq 0 is included. after_seq = last_seq - 1 if last_seq == 0 else last_seq messages = await get_messages(response_id, after_sequence=after_seq) - for seq, _, evt in messages: + for seq, _, evt, _attempt in messages: if evt is not None: - evt = {**evt, "sequence_number": seq} + # Tag every SSE frame with the response_id so proxies / + # clients can discover it without parsing nested fields. + evt = {**evt, "sequence_number": seq, "response_id": response_id} event_type = evt.get("type", "message") logger.debug( "SSE event", diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 7b646116..30f1ad02 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -15,6 +15,8 @@ class LongRunningSettings: poll_interval_seconds: float = 1.0 db_statement_timeout_ms: int = 5000 cleanup_timeout_seconds: float = 7.0 + heartbeat_interval_seconds: float = 3.0 + heartbeat_stale_threshold_seconds: float = 10.0 def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: @@ -25,6 +27,16 @@ def __post_init__(self) -> None: raise ValueError("db_statement_timeout_ms must be positive") if self.cleanup_timeout_seconds <= 0: raise ValueError("cleanup_timeout_seconds must be positive") + if self.heartbeat_interval_seconds <= 0: + raise ValueError("heartbeat_interval_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= 0: + raise ValueError("heartbeat_stale_threshold_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= self.heartbeat_interval_seconds: + raise ValueError( + f"heartbeat_stale_threshold_seconds ({self.heartbeat_stale_threshold_seconds}) " + f"must be strictly greater than heartbeat_interval_seconds " + f"({self.heartbeat_interval_seconds}) to avoid false stale-run detection." + ) db_timeout_s = self.db_statement_timeout_ms / 1000.0 if self.cleanup_timeout_seconds <= db_timeout_s: raise ValueError( diff --git a/tests/databricks_ai_bridge/test_long_running_db.py b/tests/databricks_ai_bridge/test_long_running_db.py index a1290ba1..d425da44 100644 --- a/tests/databricks_ai_bridge/test_long_running_db.py +++ b/tests/databricks_ai_bridge/test_long_running_db.py @@ -160,10 +160,12 @@ async def test_get_messages(mock_session): result_mock.scalars.return_value.all.return_value = [msg1, msg2] mock_session.execute.return_value = result_mock + msg1.attempt_number = 1 + msg2.attempt_number = 1 messages = await get_messages("resp_abc123", after_sequence=None) assert len(messages) == 2 - assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}) - assert messages[1] == (1, None, None) + assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}, 1) + assert messages[1] == (1, None, None, 1) @pytest.mark.asyncio @@ -175,6 +177,10 @@ async def test_get_response(mock_session): row.created_at = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) row.trace_id = "trace_xyz" + row.owner_pod_id = None + row.heartbeat_at = None + row.attempt_number = 1 + row.original_request = None result_mock = MagicMock() result_mock.scalar_one_or_none.return_value = row mock_session.execute.return_value = result_mock @@ -185,6 +191,10 @@ async def test_get_response(mock_session): "completed", datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc), "trace_xyz", + None, # owner_pod_id + None, # heartbeat_at + 1, # attempt_number + None, # original_request ) @@ -283,9 +293,16 @@ async def test_creates_schema_and_tables(self, reset_db_globals): with patch(f"{DB_MODULE}.AsyncLakebaseSQLAlchemy", mock_cls), patch(f"{DB_MODULE}.event"): await init_db(autoscaling_endpoint="ep") - mock_conn.execute.assert_awaited_once() - sql_arg = str(mock_conn.execute.call_args[0][0]) - assert "CREATE SCHEMA IF NOT EXISTS" in sql_arg + # init_db runs: CREATE SCHEMA + run_sync(create_all) + a series of + # ADD COLUMN IF NOT EXISTS / CREATE INDEX IF NOT EXISTS to migrate + # the durability columns onto pre-existing tables. + all_sql = " | ".join(str(call.args[0]) for call in mock_conn.execute.call_args_list) + assert "CREATE SCHEMA IF NOT EXISTS" in all_sql + assert "ADD COLUMN IF NOT EXISTS owner_pod_id" in all_sql + assert "ADD COLUMN IF NOT EXISTS heartbeat_at" in all_sql + assert "ADD COLUMN IF NOT EXISTS attempt_number" in all_sql + assert "ADD COLUMN IF NOT EXISTS original_request" in all_sql + assert "idx_responses_stale" in all_sql mock_conn.run_sync.assert_awaited_once() @pytest.mark.asyncio @@ -346,3 +363,110 @@ async def fake_factory(): monkeypatch.setattr(db_mod, "_session_factory", fake_factory) async with session_scope() as session: assert session is mock_session + + +# --------------------------------------------------------------------------- +# Durability metadata: owner_pod_id, heartbeat, claim, attempt_number +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_response_with_owner_and_original_request(mock_session): + """New background callers stamp pod id + serialized request on creation — + without these, a resumed pod can't re-invoke the handler.""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response( + "resp_abc", + "in_progress", + owner_pod_id="pod-1", + original_request={"input": [{"role": "user", "content": "hi"}]}, + ) + added = mock_session.add.call_args[0][0] + assert added.owner_pod_id == "pod-1" + assert added.heartbeat_at is not None + # original_request is JSON-encoded for Text storage. + assert '"role": "user"' in added.original_request + + +@pytest.mark.asyncio +async def test_create_response_without_durability_metadata(mock_session): + """Legacy/no-durability callers should still work and write no + owner/heartbeat (so the stale sweep can't accidentally claim them).""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response("resp_x", "in_progress") + added = mock_session.add.call_args[0][0] + assert added.owner_pod_id is None + assert added.heartbeat_at is None + assert added.original_request is None + + +@pytest.mark.asyncio +async def test_heartbeat_response_updates_timestamp(mock_session): + from databricks_ai_bridge.long_running.repository import heartbeat_response + + result_mock = MagicMock() + result_mock.rowcount = 1 + mock_session.execute.return_value = result_mock + + ok = await heartbeat_response("resp_abc", "pod-1") + assert ok is True + mock_session.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_heartbeat_response_fails_when_not_owner(mock_session): + """If the CAS misses (owner changed / row deleted), heartbeat reports + failure so the caller can stop looping.""" + from databricks_ai_bridge.long_running.repository import heartbeat_response + + result_mock = MagicMock() + result_mock.rowcount = 0 + mock_session.execute.return_value = result_mock + + ok = await heartbeat_response("resp_abc", "pod-1") + assert ok is False + + +@pytest.mark.asyncio +async def test_claim_stale_response_returns_attempt_number(mock_session): + from databricks_ai_bridge.long_running.repository import claim_stale_response + + row = MagicMock() + row.__iter__ = lambda self: iter([2]) + row.__getitem__ = lambda self, i: 2 + result_mock = MagicMock() + result_mock.first.return_value = row + mock_session.execute.return_value = result_mock + + attempt = await claim_stale_response( + "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0 + ) + assert attempt == 2 + + +@pytest.mark.asyncio +async def test_claim_stale_response_returns_none_when_not_eligible(mock_session): + from databricks_ai_bridge.long_running.repository import claim_stale_response + + result_mock = MagicMock() + result_mock.first.return_value = None + mock_session.execute.return_value = result_mock + + attempt = await claim_stale_response( + "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0 + ) + assert attempt is None + + +@pytest.mark.asyncio +async def test_append_message_with_attempt_number(mock_session): + """Resumed events must be tagged with the resume attempt so retrieve can + filter or the client can render the response.resumed boundary cleanly.""" + from databricks_ai_bridge.long_running.repository import append_message + + await append_message("resp_abc", 5, stream_event={"x": 1}, attempt_number=3) + added = mock_session.add.call_args[0][0] + assert added.attempt_number == 3 + assert added.sequence_number == 5 diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 27b6eaf5..0cf00c53 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -14,10 +14,15 @@ pytest.importorskip("fastapi") pytest.importorskip("psycopg") +from databricks_ai_bridge.long_running.repository import ResponseInfo from databricks_ai_bridge.long_running.server import ( LongRunningAgentServer, + _build_prose_recovery_message, _deferred_mark_failed, + _inject_conversation_id, + _rotate_conversation_id, _sse_event, + _trim_echoed_history, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings @@ -34,6 +39,40 @@ def _make_server(**kwargs): return LongRunningAgentServer("ResponsesAgent", **kwargs) +def _resp_info( + response_id: str = "resp_123", + status: str = "in_progress", + created_at=None, + trace_id: str | None = None, + owner_pod_id: str | None = None, + heartbeat_at=None, + attempt_number: int = 1, + original_request: dict | None = None, +) -> ResponseInfo: + """Build a ResponseInfo with sensible defaults for tests. + + Mirrors the server's repository model so test setups stay terse even as + durability columns grow over time. + """ + if created_at is None: + created_at = datetime.now(timezone.utc) + return ResponseInfo( + response_id=response_id, + status=status, + created_at=created_at, + trace_id=trace_id, + owner_pod_id=owner_pod_id, + heartbeat_at=heartbeat_at, + attempt_number=attempt_number, + original_request=original_request, + ) + + +def _msg(seq: int, item=None, evt=None, attempt: int = 1): + """Build a (seq, item, stream_event, attempt_number) tuple for get_messages mocks.""" + return (seq, item, evt, attempt) + + def _mock_span(): """Return a mock MLflow span with the attributes the server uses.""" span = MagicMock() @@ -55,8 +94,8 @@ def _mock_validator(server): class TestSSEEvent: def test_dict_data(self): result = _sse_event("response.created", {"id": "resp_123", "status": "in_progress"}) - assert result.startswith("event: response.created\n") - assert "data: " in result + assert result.startswith("data: ") + assert "event:" not in result assert result.endswith("\n\n") data_line = result.split("data: ")[1].strip() parsed = json.loads(data_line) @@ -64,8 +103,8 @@ def test_dict_data(self): def test_string_data(self): result = _sse_event("error", "something went wrong") - assert "event: error\n" in result - assert "data: something went wrong\n\n" in result + assert "event:" not in result + assert result == "data: something went wrong\n\n" class TestLongRunningSettings: @@ -189,7 +228,7 @@ def test_starting_after_zero_without_stream_is_allowed(self): patch( f"{MODULE}.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( f"{MODULE}.get_messages", @@ -209,8 +248,13 @@ async def test_marks_response_failed(self): patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, - return_value=[(0, None, {"type": "response.created"})], + return_value=[_msg(0, None, {"type": "response.created"})], ) as mock_get, + patch( + "databricks_ai_bridge.long_running.server.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch( "databricks_ai_bridge.long_running.server.append_message", new_callable=AsyncMock, @@ -271,13 +315,13 @@ async def test_completed_returns_output(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), "trace_abc"), + return_value=_resp_info("resp_123", "completed", trace_id="trace_abc"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - ( + _msg( 0, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -305,7 +349,7 @@ async def test_stale_run_detection(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_stale", "in_progress", old_time, None), + return_value=_resp_info("resp_stale", "in_progress", created_at=old_time), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -336,7 +380,7 @@ async def test_in_progress_returns_status(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -347,7 +391,11 @@ async def test_in_progress_returns_status(self): result = await server._handle_retrieve_request( "resp_123", stream=False, starting_after=0 ) - assert result == {"id": "resp_123", "status": "in_progress"} + assert result == { + "id": "resp_123", + "status": "in_progress", + "attempt_number": 1, + } class TestStreamRetrieve: @@ -360,14 +408,14 @@ async def test_completed_stream(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "completed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created", "id": "resp_123"}), - ( + _msg(0, None, {"type": "response.created", "id": "resp_123"}), + _msg( 1, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -394,13 +442,13 @@ async def test_failed_stream_stops(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "failed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "failed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "error", "error": {"message": "boom"}}), + _msg(0, None, {"type": "error", "error": {"message": "boom"}}), ], ), ): @@ -668,10 +716,15 @@ async def test_exception_writes_error_event_inline(self): f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created"}), - (1, None, {"type": "response.output_text.delta"}), + _msg(0, None, {"type": "response.created"}), + _msg(1, None, {"type": "response.output_text.delta"}), ], ), + patch( + f"{MODULE}.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, ): @@ -875,3 +928,670 @@ async def test_lifespan_not_set_when_db_not_configured(self): routes = [r.path for r in server.app.routes if hasattr(r, "path")] assert "/responses/{response_id}" in routes mock_init.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Durable resume: claim/heartbeat/attempt_number/sentinel +# --------------------------------------------------------------------------- + + +class TestBuildProseRecoveryMessage: + """Prose recovery serializer: walk a prior attempt's events and produce a + single Responses-API user-message item narrating what happened, for the + next attempt's LLM to read as a recovery instruction.""" + + def _done(self, seq, attempt, item): + return (seq, None, {"type": "response.output_item.done", "item": item}, attempt) + + def test_returns_user_message_shape(self): + out = _build_prose_recovery_message([], prior_attempt_number=1) + assert out["type"] == "message" + assert out["role"] == "user" + assert isinstance(out["content"], str) + assert "[RECOVERY]" in out["content"] + + def test_completed_call_with_output(self): + messages = [ + self._done( + 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"} + ), + self._done(1, 1, {"type": "function_call_output", "call_id": "c1", "output": "ok"}), + ] + out = _build_prose_recovery_message(messages, prior_attempt_number=1) + body = out["content"] + assert "Called `f({})`" in body + assert "got result: ok" in body + + def test_call_without_output_marked_interrupted(self): + messages = [ + self._done( + 0, + 1, + { + "type": "function_call", + "call_id": "c1", + "name": "deep_research", + "arguments": "", + }, + ), + ] + out = _build_prose_recovery_message(messages, prior_attempt_number=1) + assert "interrupted before a result was returned" in out["content"] + + def test_filters_other_attempts(self): + messages = [ + self._done( + 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"} + ), + self._done( + 1, 2, {"type": "function_call", "call_id": "c2", "name": "g", "arguments": "{}"} + ), + ] + out = _build_prose_recovery_message(messages, prior_attempt_number=1) + assert "f({})" in out["content"] + assert "g(" not in out["content"] + + def test_partial_text_from_deltas(self): + messages = [ + ( + 0, + None, + {"type": "response.output_item.added", "item": {"type": "message", "id": "m1"}}, + 1, + ), + ( + 1, + None, + {"type": "response.output_text.delta", "item_id": "m1", "delta": "Hello, "}, + 1, + ), + (2, None, {"type": "response.output_text.delta", "item_id": "m1", "delta": "world"}, 1), + ] + out = _build_prose_recovery_message(messages, prior_attempt_number=1) + assert "Was generating" in out["content"] + assert "Hello, world" in out["content"] + + def test_empty_attempt_falls_back_to_default_recovery(self): + out = _build_prose_recovery_message([], prior_attempt_number=1) + assert "interrupted before any tool calls" in out["content"] + + +class TestTrimEchoedHistory: + """SDK-agnostic dedup: when request input contains an assistant message, + the client is echoing prior conversation history; trim to the latest + user message and trust the SDK's session/checkpointer storage as + authoritative for prior turns.""" + + def test_first_turn_passthrough(self): + items = [{"role": "user", "content": "hi"}] + assert _trim_echoed_history(items) is items + + def test_first_turn_with_no_assistant_passthrough(self): + # Multi-message input but no assistant role yet — first turn from a + # client preserving its own prior user turns. Pass through. + items = [ + {"role": "user", "content": "u1"}, + {"role": "user", "content": "u2"}, + ] + assert _trim_echoed_history(items) is items + + def test_continuation_trims_to_last_user(self): + items = [ + {"role": "user", "content": "u1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + ] + out = _trim_echoed_history(items) + assert len(out) == 1 + assert out[0]["role"] == "user" + assert out[0]["content"] == "u2" + + def test_continuation_with_tool_history_trims_correctly(self): + items = [ + {"role": "user", "content": "u1"}, + {"role": "assistant", "content": "a1"}, + {"role": "user", "content": "u2"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "c1", "output": "ok"}, + {"role": "assistant", "content": "a2"}, + {"role": "user", "content": "u3"}, + ] + out = _trim_echoed_history(items) + assert len(out) == 1 + assert out[0]["content"] == "u3" + + def test_resume_path_passthrough(self): + # Resume-built input: original input + prose user message, no + # assistant role. Trim is a no-op so the prose recovery payload is + # preserved. + items = [ + {"role": "user", "content": "original"}, + {"type": "message", "role": "user", "content": "[RECOVERY] ..."}, + ] + out = _trim_echoed_history(items) + assert out is items + + def test_empty_input(self): + assert _trim_echoed_history([]) == [] + + +class TestRotateConversationId: + def test_rotate_drops_thread_id_and_sets_rotated_context(self): + r = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert "thread_id" not in out["custom_inputs"] + assert out["custom_inputs"]["user_id"] == "u" + assert out["context"]["conversation_id"] == "t1::attempt-2" + + def test_rotate_drops_session_id(self): + r = {"custom_inputs": {"session_id": "s1"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert "session_id" not in out["custom_inputs"] + assert out["context"]["conversation_id"] == "s1::attempt-2" + + def test_rotate_falls_back_to_context_conversation_id(self): + r = {"custom_inputs": {}, "context": {"conversation_id": "c-abc"}} + out = _rotate_conversation_id(r, new_attempt_number=3, response_id="resp_x") + assert out["context"]["conversation_id"] == "c-abc::attempt-3" + + def test_rotate_falls_back_to_response_id_as_last_resort(self): + r = {"custom_inputs": {}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert out["context"]["conversation_id"] == "resp_x::attempt-2" + + def test_rotate_handles_missing_custom_inputs_key(self): + r = {"context": {"conversation_id": "c-abc"}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert out["context"]["conversation_id"] == "c-abc::attempt-2" + assert out["custom_inputs"] == {} + + +class TestInjectConversationId: + """Anchoring an otherwise-anonymous request to a response_id guarantees a + resumed run on a new pod resolves to the same agent-SDK thread/session.""" + + def test_injects_when_nothing_set(self): + r = {"input": [], "custom_inputs": {}, "context": {}} + out = _inject_conversation_id(r, "resp_abc") + assert out["context"]["conversation_id"] == "resp_abc" + + def test_respects_existing_conversation_id(self): + r = {"input": [], "context": {"conversation_id": "user-set"}} + out = _inject_conversation_id(r, "resp_abc") + assert out["context"]["conversation_id"] == "user-set" + + def test_respects_thread_id_from_custom_inputs(self): + r = {"input": [], "custom_inputs": {"thread_id": "t-1"}, "context": {}} + out = _inject_conversation_id(r, "resp_abc") + # When the client already pinned a thread, we don't overwrite — the + # template's _get_or_create_thread_id picks up custom_inputs first. + assert "conversation_id" not in (out["context"] or {}) + + def test_respects_session_id_from_custom_inputs(self): + r = {"input": [], "custom_inputs": {"session_id": "s-1"}, "context": {}} + out = _inject_conversation_id(r, "resp_abc") + assert "conversation_id" not in (out["context"] or {}) + + def test_handles_missing_context_key(self): + r = {"input": [], "custom_inputs": {}} + out = _inject_conversation_id(r, "resp_abc") + assert out["context"]["conversation_id"] == "resp_abc" + + def test_does_not_mutate_input(self): + r = {"input": [], "custom_inputs": {}, "context": {}} + _inject_conversation_id(r, "resp_abc") + assert r["context"] == {} # original untouched + + +class TestHandleBackgroundRequestPersistsDurabilityState: + """Background request entry point should now stamp the response row with + the caller's pod, the original request body, and a conversation anchor.""" + + @pytest.mark.asyncio + async def test_persists_owner_and_original_request(self): + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + _mock_validator(server) + + captured: dict = {} + + async def fake_create_response( + response_id, status, *, owner_pod_id=None, original_request=None + ): + captured["response_id"] = response_id + captured["status"] = status + captured["owner_pod_id"] = owner_pod_id + captured["original_request"] = original_request + + with ( + patch(f"{MODULE}.create_response", side_effect=fake_create_response), + patch("asyncio.create_task") as mock_create_task, + ): + result = await server._handle_background_request( + {"input": [{"role": "user", "content": "hi"}]}, + is_streaming=False, + return_trace_id=False, + ) + + assert captured["status"] == "in_progress" + assert captured["owner_pod_id"] # non-empty + # original_request should include input + injected conversation_id. + orig = captured["original_request"] + assert orig["input"] == [{"role": "user", "content": "hi"}] + assert orig["context"]["conversation_id"] == captured["response_id"] + # Return shape: immediate response_obj, not a stream. + assert result["id"] == captured["response_id"] + assert result["status"] == "in_progress" + mock_create_task.assert_called_once() + + +class TestTryClaimAndResume: + @pytest.mark.asyncio + async def test_no_op_when_completed(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + resp = _resp_info(status="completed") + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_grace_period_for_fresh_run(self): + """Just-started runs get a grace window before they're claim-eligible.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_stale_threshold_seconds=15.0 + ) + # created 2s ago, no heartbeat yet → should NOT be claimed. + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=2), + heartbeat_at=None, + original_request={"input": []}, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_op_without_original_request(self): + """Legacy rows created before durability metadata can't be resumed.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=None, + original_request=None, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_claim_fails_returns_none(self): + """Another pod won the race — we quietly step aside.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=300), + original_request={"input": [{"role": "user"}]}, + ) + with ( + patch( + f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=None + ) as mock_claim, + patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, + ): + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_awaited_once() + mock_append.assert_not_awaited() + + @pytest.mark.asyncio + async def test_successful_claim_spawns_resume_and_emits_sentinel(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"user_id": "u"}, + "context": {"conversation_id": "resp_x"}, + }, + ) + captured: dict = {} + + async def fake_append(response_id, seq, *, item=None, stream_event=None, attempt_number=1): + captured["seq"] = seq + captured["event"] = stream_event + captured["attempt_tag"] = attempt_number + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch( + f"{MODULE}.get_messages", + new_callable=AsyncMock, + return_value=[_msg(0, None, {}), _msg(1, None, {})], + ), + patch(f"{MODULE}.append_message", side_effect=fake_append), + patch("asyncio.create_task") as mock_create_task, + ): + attempt = await server._try_claim_and_resume("resp_x", resp) + + assert attempt == 2 + # Sentinel is written at next_seq (existing seqs were 0 and 1). + assert captured["seq"] == 2 + assert captured["event"]["type"] == "response.resumed" + assert captured["event"]["attempt"] == 2 + assert captured["attempt_tag"] == 2 + # A resume task is spawned; it was not awaited synchronously. + mock_create_task.assert_called_once() + + @pytest.mark.asyncio + async def test_resume_replays_input_and_rotates_conversation_id(self): + """Resume must replay original_request.input (not blank it) and rotate + the conversation anchor so the handler resolves to a fresh thread / + session for the new attempt. Prevents the LangGraph stream-event + attempt-boundary orphan artifact (rotation-findings.md).""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"thread_id": "t1", "user_id": "u"}, + "context": {}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) + # Input is REPLAYED (not blanked) and a prose-recovery user message is + # appended so attempt N+1's LLM sees the original request plus a + # narrative of what happened. The MLflow validator normalizes the shape. + assert len(dumped["input"]) == 2 + assert dumped["input"][0]["role"] == "user" + assert dumped["input"][0]["content"] == "hi" + assert dumped["input"][1]["role"] == "user" + assert "[RECOVERY]" in dumped["input"][1]["content"] + # thread_id was dropped so the handler's priority-2 fallback wins. + assert "thread_id" not in (dumped["custom_inputs"] or {}) + # Other custom_inputs keys are preserved. + assert dumped["custom_inputs"]["user_id"] == "u" + # conversation_id is rotated to a per-attempt value anchored on t1. + assert dumped["context"]["conversation_id"] == "t1::attempt-2" + assert kwargs.get("attempt_number") == 2 + + @pytest.mark.asyncio + async def test_resume_rotation_anchors_on_context_conversation_id(self): + """When the client didn't pin a thread_id/session_id, rotation uses + the injected context.conversation_id as the base anchor.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {}, + "context": {"conversation_id": "resp_x"}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=3), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) + # Rotation anchors on the stored context.conversation_id (priority 2). + # Note: re-rotating in a subsequent attempt would re-anchor on the + # ORIGINAL stored value, not the previous rotation — no stacking. + assert dumped["context"]["conversation_id"] == "resp_x::attempt-3" + + +class TestRetrieveTriggersLazyClaim: + @pytest.mark.asyncio + async def test_retrieve_calls_try_claim(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + + resp = _resp_info("resp_x", "in_progress") + with ( + patch(f"{MODULE}.get_response", new_callable=AsyncMock, return_value=resp), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch.object( + server, "_try_claim_and_resume", new_callable=AsyncMock, return_value=None + ) as mock_claim, + ): + await server._handle_retrieve_request("resp_x", stream=False, starting_after=0) + + mock_claim.assert_awaited_once() + + +class TestHeartbeatContextManager: + @pytest.mark.asyncio + async def test_writes_heartbeat_periodically(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x"): + await asyncio.sleep(0.2) # enough time for 2+ heartbeats + + # Heartbeat interval is 0.05s so we should see at least 2 writes. + assert mock_hb.await_count >= 2 + for call in mock_hb.await_args_list: + assert call.args[0] == "resp_x" + + @pytest.mark.asyncio + async def test_stops_cleanly_on_exit(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x"): + pass # immediate exit + + # Give the heartbeat loop a chance to observe the stop signal. + await asyncio.sleep(0.1) + writes_after_exit = mock_hb.await_count + + await asyncio.sleep(0.15) + # No new writes after the scope closed. + assert mock_hb.await_count == writes_after_exit + + @pytest.mark.asyncio + async def test_db_error_does_not_interrupt_body(self): + """Heartbeat failures are logged, not raised — the stale check catches + real death, so a transient write miss must not kill a live run.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", + heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + body_ran = False + with patch( + f"{MODULE}.heartbeat_response", + new_callable=AsyncMock, + side_effect=RuntimeError("db down"), + ): + async with server._heartbeat("resp_x"): + await asyncio.sleep(0.1) + body_ran = True + assert body_ran + + +class TestSettingsHeartbeatValidation: + def test_stale_must_exceed_interval(self): + with pytest.raises(ValueError, match="heartbeat_stale_threshold_seconds"): + LongRunningSettings( + heartbeat_interval_seconds=5.0, + heartbeat_stale_threshold_seconds=5.0, + ) + + def test_interval_must_be_positive(self): + with pytest.raises(ValueError, match="heartbeat_interval_seconds must be positive"): + LongRunningSettings(heartbeat_interval_seconds=0) + + def test_defaults_match_chat_ux(self): + # 3s interval + 15s stale gives ~5 heartbeats before a pod is considered + # dead — snug enough to recover conversations within a user's + # "reconnecting..." patience window. + s = LongRunningSettings() + assert s.heartbeat_interval_seconds == 3.0 + assert s.heartbeat_stale_threshold_seconds == 10.0 + + +class TestDebugKillTask: + """The opt-in debug-kill endpoint lets integration tests simulate a crash + against a deployed pod without restarting the whole app. Off by default + because exposing task cancellation bypasses the normal cleanup path.""" + + def test_endpoint_absent_by_default(self): + from starlette.testclient import TestClient + + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + resp = client.post("/_debug/kill_task/resp_x") + assert resp.status_code == 404 # route not registered + + def test_endpoint_registered_when_env_set(self, monkeypatch): + from starlette.testclient import TestClient + + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + # No in-flight task for this response_id on this pod → 404, not 405. + resp = client.post("/_debug/kill_task/resp_missing") + assert resp.status_code == 404 + assert "No in-flight task" in resp.json()["detail"] + + @pytest.mark.asyncio + async def test_cancels_tracked_task(self, monkeypatch): + """Direct-call variant: skip the TestClient (which is sync and blocks + the loop) and call the handler logic through _running_tasks directly. + Covers the important behavior: cancelling a tracked task propagates + CancelledError and the tracking dict is cleared by the done-callback. + """ + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + + cancel_event = asyncio.Event() + + async def long_running(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancel_event.set() + raise + + task = asyncio.create_task(long_running()) + server._track_task("resp_tracked", task) + + # Yield once so the new task can start waiting on sleep(60). + await asyncio.sleep(0) + assert "resp_tracked" in server._running_tasks + + task.cancel() + # Expect CancelledError from awaiting the task itself, and the cancel + # event set inside the except handler before the re-raise. + with pytest.raises(asyncio.CancelledError): + await task + assert cancel_event.is_set() + # done-callback (scheduled on loop) clears the registration after the + # task completes — give it one more tick. + await asyncio.sleep(0) + assert "resp_tracked" not in server._running_tasks From e0373caf2da81f3e152e7740cf8ab92f5673ac7c Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 28 Apr 2026 17:46:23 +0000 Subject: [PATCH 2/9] Add AGENTS.md describing LongRunningAgentServer's design MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Standalone documentation of the durable-resume design — for future maintainers to understand the system without reading the PR diff. Sections: - §1 Module purpose, capabilities, guarantees, non-goals - §2 Four customer journeys with sequence diagrams (author writes agent, pod crashes mid-tool, subsequent turn after crash, multi-pod contention) - §3 Architecture: storage layout (ER diagram), key flows (flowchart), prose-recovery construction (flowchart), UI-echo dedup (flowchart), heartbeat tuning, CAS atomicity (sequence diagram) - §4 Author-side requirements: what's invisible, chat-UI alias cooperation, exposed settings - §5 TaskFlow migration mapping (today → TaskFlow primitives, what stays, what gets deleted, sequencing) Six mermaid diagrams: 4 sequence diagrams, 1 ER, 4 flowcharts. Co-authored-by: Isaac --- .../long_running/AGENTS.md | 423 ++++++++++++++++++ 1 file changed, 423 insertions(+) create mode 100644 src/databricks_ai_bridge/long_running/AGENTS.md diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md new file mode 100644 index 00000000..0fd06ac9 --- /dev/null +++ b/src/databricks_ai_bridge/long_running/AGENTS.md @@ -0,0 +1,423 @@ +# LongRunningAgentServer + +Durable, crash-resumable agent execution for MLflow `ResponsesAgent` handlers. + +This document describes: +1. What `LongRunningAgentServer` does and the guarantees it gives callers ([§1](#1-what-this-module-does)). +2. The four customer journeys it covers, with sequence diagrams ([§2](#2-customer-journeys)). +3. The architecture: storage layout, claim mechanism, recovery, and stream resume ([§3](#3-architecture)). +4. Author-side requirements: what changes (and doesn't) when a handler opts into durable mode ([§4](#4-author-side-requirements)). +5. The interface today and how it's expected to evolve when [TaskFlow](https://github.com/databricks-eng/universe/tree/master/experimental/taskflow) lands ([§5](#5-future-direction-taskflow)). + +## 1. What this module does + +`LongRunningAgentServer` extends MLflow's `AgentServer` for `ResponsesAgent` handlers with three capabilities: + +1. **Background execution.** A `POST /responses` request with `background: true` returns a `response_id` immediately; the agent loop runs detached from the HTTP connection. State persists to Lakebase Postgres. +2. **Streaming retrieval.** `GET /responses/{response_id}?stream=true&starting_after=N` replays events past sequence `N` and tails new ones until the run finishes. Reconnects without losing events. +3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run and finishes the work via **prose recovery**: the new attempt receives a single user message narrating the crashed attempt's completed tool calls, results, and partial output, and continues from there on a freshly-rotated SDK session. Tool results that completed before the crash are preserved through the prose narrative. + +Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, others) is opaque to the server. + +### Guarantees + +- **At-most-once durable claim.** Only one pod runs a given response at a time. The handoff uses an atomic CAS on `attempt_number`. +- **Append-only event log.** Every SSE frame is persisted to `agent_server.messages` keyed by `(response_id, attempt_number, sequence_number)`. Clients cursor-resume from `starting_after`. +- **SDK-agnostic recovery.** The resumed attempt receives a flat prose narrative — no provider-specific tool-pair structure, no synthetic tool events, no per-SDK adapter code. +- **SDK-agnostic UI-echo dedup.** When the chatbot client echoes prior conversation history in `request.input`, the server detects this (presence of an `assistant`-role message) and trims to the latest user message. The SDK's session/checkpointer storage is treated as authoritative for prior turns. +- **Best-effort tool execution.** A tool call interrupted mid-flight may re-run on the resumed attempt. Idempotency is the tool author's responsibility. +- **No agent code changes required.** Templates that subclass `LongRunningAgentServer` keep using `@invoke()` / `@stream()` decorators. All durability lives below the handler boundary. + +### Non-goals + +- Cross-region failover. Pods are assumed to share one Lakebase. +- Tool-level checkpointing / exactly-once tool execution. +- A workflow DSL. Handlers are ordinary async generators / coroutines. + +## 2. Customer journeys + +### CUJ 1: Author writes a long-running agent + +The author subclasses `LongRunningAgentServer` and registers `@invoke()` / `@stream()` handlers like a regular MLflow agent server. **No durability code in `agent.py`.** + +```python +from databricks_ai_bridge.long_running import LongRunningAgentServer +from mlflow.genai.agent_server import invoke, stream + +agent_server = LongRunningAgentServer( + "ResponsesAgent", + db_instance_name="my-lakebase-instance", +) + +@stream() +async def stream_handler(request): + # ordinary agent code: build messages, call SDK, yield events + ... + +@invoke() +async def invoke_handler(request): + ... + +app = agent_server.app +``` + +The agent author writes their handler exactly the same way they would for the non-durable `AgentServer`. `LongRunningAgentServer` adds the durable wiring transparently. + +### CUJ 2: Pod crashes mid-tool, client polls + +A client posts a long-running request, the owning pod dies mid-tool, another pod takes over via prose recovery, and the client gets the final output without restarting. + +```mermaid +sequenceDiagram + autonumber + participant C as Client + participant A as Pod A (owner) + participant DB as Lakebase
(agent_server.*) + participant B as Pod B + participant SDK as SDK store
(rotated session) + + C->>A: POST /responses
{input, background:true, stream:true} + A->>DB: INSERT response_id, owner=A,
original_request (full input), attempt=1 + A-->>C: 200 {id: resp_xxx, status: in_progress} + activate A + Note over A: heartbeat loop (every 3s) + A->>DB: UPDATE heartbeat_at WHERE owner=A + + par Streaming + A->>SDK: write to session T (original conv_id) + A->>DB: append events seq=0..N (attempt=1) + and Polling + C->>A: GET /responses/{id}?stream=true&starting_after=N + A-->>C: SSE events + end + + Note over A: 💥 pod crashes mid-tool
(after tool_use, before tool_result) + deactivate A + Note over DB: heartbeat_at goes stale (>10s old) + + C->>B: GET /responses/{id}?stream=true&starting_after=K + B->>DB: SELECT heartbeat_at, attempt + Note over B: heartbeat stale → claim + B->>DB: CAS UPDATE owner=B, attempt=2
WHERE attempt=1 (atomic) + B->>DB: SELECT prior events WHERE attempt=1 + Note over B: build resume input:
· original_request.input (full prior turns)
· + prose recovery message narrating attempt 1
· rotate conv_id → ::attempt-2 + B->>DB: append response.resumed sentinel
{conversation_id: rotated value} + B->>B: re-invoke @stream() handler
(fresh rotated SDK session) + activate B + B->>SDK: write to T::attempt-2 (clean) + B->>DB: append events seq=K+1..M (attempt=2) + B-->>C: SSE events (response.resumed, then attempt-2 events) + deactivate B +``` + +**What the client observes:** a single SSE stream that may pause briefly during the heartbeat-stale window (~10s by default), then resumes. The `response.resumed` sentinel marks the attempt boundary and carries the rotated `conversation_id` so the chatbot can use the rotated session for subsequent turns. + +**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus a single prose user message describing what happened in attempt 1. The handler doesn't have to know any of this is durable resume — to the model, it just looks like a user said "the prior attempt crashed, here's what had happened, please continue." + +### CUJ 3: Subsequent turn after a crashed turn + +After a successful crash + resume, the next turn from the client lands on a fresh `POST /responses`. The chatbot uses the **rotated** `conversation_id` (captured from the `response.resumed` sentinel) so the handler resolves to the rotated SDK session — which was populated cleanly during attempt 2's prose-recovery run. + +```mermaid +sequenceDiagram + autonumber + participant C as Chatbot + participant S as Server (any pod) + participant SDK as SDK store + + Note over SDK: state at original conv_id T:
incomplete (orphan tool_use)
state at T::attempt-2:
complete (prose + resumed turn output) + + C->>C: lookup alias map:
chat_id → T::attempt-2 + C->>S: POST /responses
{input: [echo of full UI history, new user msg],
context.conversation_id: T::attempt-2} + S->>S: _trim_echoed_history
(detects assistant in input → trim to last user) + S->>S: handler resolves session_id = T::attempt-2 + S->>SDK: load history at T::attempt-2 + SDK-->>S: clean state (prose + attempt-2 emissions) + S->>S: model receives [clean history] + [new user msg] + Note over S: turn succeeds normally + + Note over SDK: original session T is now orphaned forever
(never read again — chatbot uses rotated alias) +``` + +Three things make this work without per-SDK repair code: + +1. **Server emits the rotated conv_id in `response.resumed`.** The chatbot reads it and updates its `Map` alias. +2. **Server-side trim of UI-echoed history.** The presence of an `assistant`-role item in `request.input` is a reliable proxy for "client is echoing prior history." Trim to the latest user message; trust the SDK's session as authoritative for prior turns. +3. **Always-rotate.** The rotated session was the one populated during attempt 2's run. Subsequent turns land on it. The original poisoned session is never read. + +### CUJ 4: Multi-pod stale-claim contention + +Two pods each see the same response in `in_progress` with a stale heartbeat. Only one wins the CAS. + +```mermaid +sequenceDiagram + autonumber + participant B as Pod B + participant DB as Lakebase + participant C as Pod C + + Note over B,C: both pods see response in_progress, heartbeat stale + par + B->>DB: UPDATE responses SET owner=B, attempt=N+1
WHERE response_id=R AND attempt=N + and + C->>DB: UPDATE responses SET owner=C, attempt=N+1
WHERE response_id=R AND attempt=N + end + Note over DB: only one row matches; the other UPDATE returns 0 rows + DB-->>B: RETURNING attempt_number=N+1 + DB-->>C: RETURNING (no row) + Note over B: B wins, builds resume input + spawns handler + Note over C: C aborts cleanly, returns to its retrieve loop +``` + +The `claim_stale_response` function (`repository.py`) executes a single `UPDATE … WHERE attempt_number = :current AND ((heartbeat_at IS NULL) OR (heartbeat_at < now() - interval))` with `RETURNING`. Postgres serializes the writes; only the pod whose `current` value was unmodified at commit time gets the `RETURNING` row. + +## 3. Architecture + +### 3.1 Storage layout + +Two tables in the `agent_server` schema: + +```mermaid +erDiagram + responses { + text response_id PK + text status "in_progress / completed / failed" + timestamptz created_at + text owner_pod_id + timestamptz heartbeat_at + int attempt_number + text original_request "JSON of initial POST (full input)" + text trace_id + } + messages { + text response_id FK + int sequence_number + int attempt_number + text item "JSON of output item" + text stream_event "JSON of SSE frame" + } + responses ||--o{ messages : "has" +``` + +- `responses.attempt_number` is the CAS guard for claim atomicity. +- `responses.original_request` stores the **full untrimmed input** so the resume path can recover the entire prior-turn history when the rotated SDK session starts empty. +- `messages.attempt_number` tags every event so retrieval can filter to the latest attempt's output (avoiding partial output from a crashed attempt leaking into the final response body). +- Schema migrations are idempotent (`ADD COLUMN IF NOT EXISTS`) so an existing deployment upgrades without downtime. + +### 3.2 The four key flows + +```mermaid +flowchart TD + POST["POST /responses
background=true"] --> CR[create_response
store FULL original_request] + CR --> TRIM[deepcopy + _trim_echoed_history
for handler invocation] + TRIM --> SPAWN[spawn @stream handler] + SPAWN --> HB[heartbeat loop
every 3s] + SPAWN --> EMIT[append SSE events
to messages table] + EMIT --> DONE{handler
exits?} + DONE -- yes --> COMPLETE[update status=completed] + DONE -- crash --> STALE[heartbeat stops
row goes stale] + + GET["GET /responses/{id}
?stream=true&starting_after=N"] --> CHECK[check heartbeat age] + CHECK -->|fresh or terminal| READ[read events from messages
where seq > N] + CHECK -->|stale| CLAIM[CAS claim
attempt += 1] + CLAIM -->|won| BUILD[build prose recovery message
rotate conv_id
append rotated id to response.resumed] + CLAIM -->|lost| READ + BUILD --> SPAWN + READ --> STREAM[SSE stream to client] + + KILL["POST /_debug/kill_task/{id}
(test-only)"] --> CANCEL[cancel asyncio task
without status update] + CANCEL --> STALE +``` + +### 3.3 Resume input construction + +When a stale-claim CAS succeeds, the new owner builds the resume input from the prior attempt's emitted events as a flat prose narrative: + +```mermaid +flowchart LR + PRIOR[prior attempt's events
from messages table] --> WALK[_build_prose_recovery_message] + WALK --> CALLS[walk completed function_call /
function_call_output items by call_id] + WALK --> NARR[walk completed message items
extract text] + WALK --> PARTIAL[walk in-progress message deltas
concat as 'was generating: ...'] + + CALLS --> COMPOSE + NARR --> COMPOSE + PARTIAL --> COMPOSE[compose single user message:
'[RECOVERY] previous attempt crashed.
Here is what had completed:
- Called f(...) and got result: ...
- Called g(...) — interrupted before result
- Was generating: ...'] + + ROT[_rotate_conversation_id
::attempt-N suffix] --> SUBMIT + COMPOSE --> SUBMIT[append to original_request.input
spawn handler with rotated request
emit response.resumed sentinel] +``` + +Why prose: the LLM reads it as a recovery instruction and decides what to do (re-run the interrupted tool, skip completed ones, summarize from there). No structural carry-forward, no synthetic tool events, no per-SDK pairing rules. Trades cache hit rate (the prose is a fresh user message, not a stable structural prefix) for SDK-agnostic simplicity. + +Why rotation: the original SDK session may carry mid-turn state from the crashed attempt (orphan `tool_use`, partial checkpoint) that's hard to repair from outside the SDK. Rotating to `{base}::attempt-N` opens a fresh, empty session for the resumed attempt; the prose narrative is the single source of truth for what already happened. + +Why the sentinel carries the rotated conv_id: cooperating chat clients capture it (via SSE) and use the rotated session for subsequent turns, so the original orphan-poisoned session is never read again. + +### 3.4 SDK-agnostic UI-echo dedup + +```mermaid +flowchart LR + REQ[POST /responses
request.input from client] --> CHECK{any item
has role: assistant?} + CHECK -->|no| PASS[pass through unchanged
typical first-turn POST] + CHECK -->|yes| TRIM[client is echoing history
find last role: user index
slice items from there] + TRIM --> RESULT[result: latest user message only] + PASS --> HANDLER + RESULT --> HANDLER[handler invocation] +``` + +The trim is the SDK-agnostic equivalent of a per-SDK "is this a continuation turn?" probe. By inspecting input shape, the server avoids: +- Querying the SDK's storage to ask "do you have prior turns?" (per-SDK call) +- Walking provider-specific tool-pair structure +- Any SDK adapter code + +The trim runs on every `POST /responses` and is a no-op when input has no assistant message (first-turn POSTs or the resume path's `[original_input + prose_msg]` payload, which contains only `role:user` items by construction). + +### 3.5 Heartbeat and stale threshold + +Defaults are tuned for a single Lakebase deployment with low-latency writes. + +| Setting | Default | Rationale | +|---|---|---| +| `heartbeat_interval_seconds` | 3.0 | Frequent enough that short pauses (GC, tokio task waits) don't trip stale detection | +| `heartbeat_stale_threshold_seconds` | 10.0 | Three missed heartbeats = unambiguously dead. Validated stale > interval at startup. | +| `task_timeout_seconds` | 3600 | Hard ceiling. After this, a stuck `in_progress` row is force-failed regardless of heartbeat. | +| `poll_interval_seconds` | 1.0 | Stream-retrieve polls the messages table at this rate while waiting for new events. | + +The stale threshold also applies as a grace period for newly-created responses that haven't written their first heartbeat yet — protects against an over-eager retrieve hijacking a still-starting handler. + +### 3.6 Claim atomicity + +```mermaid +sequenceDiagram + autonumber + participant Pod as Pod B (claimer) + participant DB as Lakebase + + Pod->>DB: SELECT attempt_number, heartbeat_at, status
FROM responses WHERE response_id=R + DB-->>Pod: attempt=1, heartbeat_at=12s ago, status=in_progress + Note over Pod: stale → attempt CAS + Pod->>DB: UPDATE responses
SET owner_pod_id=B, attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number + alt match + DB-->>Pod: attempt_number=2 + Note over Pod: claim succeeded + else no match (another pod beat us) + DB-->>Pod: (empty) + Note over Pod: claim lost — back off + end +``` + +Postgres row locking ensures only one of N concurrent UPDATEs matches, so at most one pod ends up owning a given resume. + +## 4. Author-side requirements + +### 4.1 What's invisible to authors + +| Concern | Where it lives | Author-visible? | +|---|---|---| +| Heartbeat + claim | `LongRunningAgentServer` | No | +| Conversation_id rotation | `LongRunningAgentServer._rotate_conversation_id` | No | +| Prose recovery message construction | `LongRunningAgentServer._build_prose_recovery_message` | No | +| UI-echo dedup | `LongRunningAgentServer._trim_echoed_history` | No | +| Stream resume cursor | `LongRunningAgentServer._stream_retrieve` | No | +| Tool/SDK selection | `agent.py` | Yes (this is the author's actual code) | + +The author's `agent.py` is unchanged from a non-durable agent. They construct an `AsyncCheckpointSaver` (LangGraph) or `AsyncDatabricksSession` (OpenAI) and use it normally. Durability fires entirely above the SDK boundary — the SDK adapters themselves contain zero durability code. + +### 4.2 Author-visible client cooperation (chat-UI side) + +For the always-rotate flow to work cross-turn, a cooperating chat UI needs to: + +1. **Capture the rotated `conversation_id` from the SSE `response.resumed` event** when one is emitted during a streaming retrieve. +2. **Use the rotated value as `context.conversation_id` on subsequent requests** for the same chat. + +The Express proxy in `e2e-chatbot-app-next/packages/ai-sdk-providers/src/providers-server.ts` does this with an in-memory `Map`. A multi-pod chatbot deployment would persist this on the chat row. + +Without client cooperation: the next turn lands on the original (orphan-poisoned) session and the LLM call fails on the provider's `tool_use ↔ tool_result` pairing rule. The bridge does not silently repair this — the cross-turn property requires the alias. + +### 4.3 Settings worth exposing to authors + +- `db_instance_name` / `db_autoscaling_endpoint` / `db_project` + `db_branch` — Lakebase connection config. +- `heartbeat_interval_seconds` / `heartbeat_stale_threshold_seconds` — for tuning under heavy load. +- `task_timeout_seconds` — per-attempt ceiling. + +Everything else is internal. + +## 5. Future direction: TaskFlow + +[TaskFlow](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/tree/experimental/taskflow) is a Rust-core durable-task engine being built in `experimental/taskflow`. It provides exactly the primitives `LongRunningAgentServer` hand-rolls today (heartbeat, CAS claim, recovery worker, event log with stream resume) — but as a library with WAL-first durability and proactive (not lazy-on-GET) recovery. + +When TaskFlow is production-ready, `LongRunningAgentServer` is expected to keep its **HTTP surface and author-visible API unchanged**, swapping only the engine internals. + +### Mapping today → TaskFlow + +```mermaid +flowchart LR + subgraph TODAY[LongRunningAgentServer today] + T1[create_response + asyncio.create_task] + T2[_heartbeat async CM] + T3[_try_claim_and_resume CAS] + T4[_build_prose_recovery_message] + T5[/responses/{id}?stream=true] + T6[/_debug/kill_task] + end + + subgraph TF[TaskFlow] + F1[Taskflow.start name input user_id] + F2[built-in executor heartbeat] + F3[built-in recovery worker + claim_for_recovery] + F4[TaskHandler.recover ctx previous_events] + F5[Taskflow.subscribe key last_seq] + F6[Taskflow.simulate_crash key] + end + + T1 --> F1 + T2 --> F2 + T3 --> F3 + T4 --> F4 + T5 --> F5 + T6 --> F6 +``` + +### What stays in `LongRunningAgentServer` after the swap + +- `POST /responses` / `GET /responses/{id}` HTTP routes (and their schemas). +- The MLflow `@invoke()` / `@stream()` handler convention. +- `_build_prose_recovery_message` — prose construction is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. +- `_rotate_conversation_id` and `_inject_conversation_id` — same reason. +- `_trim_echoed_history` — input cleanup happens at the HTTP boundary regardless of engine. +- Author-visible settings (db config, heartbeat tuning, task timeout). + +### What gets deleted + +- The hand-rolled heartbeat task and CAS claim CTE — replaced by TaskFlow's executor heartbeat + `claim_for_recovery`. +- The `_try_claim_and_resume` lazy-claim path — TaskFlow's recovery worker handles this proactively and across pods, fixing the documented "claim only fires on GET" limitation. +- The `agent_server.responses` + `agent_server.messages` schema — TaskFlow owns its storage layer. +- The stream-cursor logic in `_stream_retrieve` — `Taskflow.subscribe(key, last_seq)` is the cursor-based stream resume. + +### What requires a small TaskFlow API addition + +TaskFlow derives idempotency keys as `SHA256(name + canonical_input + user_id)`. The HTTP surface uses a server-generated `resp_{uuid}`. We've requested an `idempotency_key: Option` parameter on `Taskflow.submit()` so we can keep the existing HTTP contract while submitting to TaskFlow. See [`engine.rs:317`](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/blob/experimental/taskflow/engine/src/engine.rs?L317) for the current `generate_key` call site; the override would slot in there. + +### Migration sequencing + +1. Add `LongRunningAgentServer(backend="taskflow"|"builtin")` knob, default `"builtin"`. HTTP surface unchanged on either backend. +2. Port `agent-non-conversational` (the simplest template) to `backend="taskflow"`. Run the full crash-resume matrix. +3. Port the advanced templates. **Zero changes to `agent.py`** — the swap is the constructor argument. +4. Flip default to `"taskflow"`. Deprecate `"builtin"`. +5. Delete the heartbeat / claim / repository code. Big delete PR. + +The point of the `LongRunningAgentServer` abstraction is exactly this kind of swap: callers should never have to care which engine is underneath. + +--- + +## Quick reference + +- **Code:** `src/databricks_ai_bridge/long_running/` +- **Tests:** `tests/databricks_ai_bridge/test_long_running_server.py`, `test_long_running_db.py` +- **Settings:** `LongRunningSettings` in `settings.py` +- **Model:** `Response` and `Message` in `models.py` +- **HTTP routes:** registered in `LongRunningAgentServer._setup_routes` +- **Prose recovery:** `_build_prose_recovery_message` in `server.py` +- **UI-echo dedup:** `_trim_echoed_history` in `server.py` +- **Conversation rotation:** `_rotate_conversation_id` / `_inject_conversation_id` in `server.py` From 1b1fd7034d6f77fbac576b53c07c14021bc9632d Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 29 Apr 2026 21:47:00 +0000 Subject: [PATCH 3/9] Address Bryan's review feedback (#1, #2, #5, #6 + scan loop) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per review on PR #425. Skipping #3 (trim) and #4 (conv_id naming) — addressed separately. #1 Drop owner_pod_id; ownership via heartbeat CAS on attempt_number - Remove owner_pod_id column from Response model - heartbeat_response(response_id, expected_attempt_number) CAS-checks the attempt_number column. If a heartbeat write returns 0 rows, the prior owner has been bumped by another pod's claim and the heartbeat task knows to stop. - _heartbeat() context manager takes attempt_number; passes it through from _run_background_stream / _run_background_invoke. - claim_stale_response() no longer takes a pod parameter. - _POD_LOG_ID retained for log-line identity only (not stored in DB). #2 Simplify prose recovery to json.dumps the events array - _build_prose_recovery_message: was ~110 LOC structural walker (function_call/output pairs, narrative messages, partial-text reassembly). Now ~15 LOC: filter events by prior_attempt_number, json.dumps them, wrap in a directive prompt asking the model to figure out what's done vs interrupted. #5 Drop _inject_conversation_id - The function was defensive injection of response_id into context.conversation_id when no client anchor was supplied. With rotation handling resume, and templates / chatbot consistently setting conv_id, the injection was redundant. Top-level review: proactive stale-scan loop with jitter - New _stale_response_scanner_loop: every ~30s ± 50% jitter, queries responses for in_progress rows with stale heartbeats and tries to claim+resume them. The proactive counterpart to lazy-on-GET claim; ensures crashed responses get recovered even if no client polls. - find_stale_response_ids repository function with LIMIT 50. - Spawned in the FastAPI lifespan alongside init_db; cancelled on shutdown. - Settings: stale_scan_interval_seconds=30.0, stale_scan_jitter_fraction=0.5. #6 Document /_debug/kill_task in AGENTS.md - New §4.4 explaining the test-only debug endpoint, env-var gating, what state it leaves the row in. AGENTS.md updates: - ER diagram: drop owner_pod_id, annotate attempt_number as CAS guard. - New §3.5 documenting the proactive scanner with mermaid flowchart. - §4.3 includes new scanner settings. Tests: 110 pass. Ruff/format/ty all clean. Co-authored-by: Isaac --- .../long_running/AGENTS.md | 37 +- src/databricks_ai_bridge/long_running/db.py | 1 - .../long_running/models.py | 9 +- .../long_running/repository.py | 79 +++-- .../long_running/server.py | 328 ++++++++---------- .../long_running/settings.py | 11 + .../test_long_running_db.py | 37 +- .../test_long_running_server.py | 124 ++----- 8 files changed, 296 insertions(+), 330 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md index 0fd06ac9..c7331409 100644 --- a/src/databricks_ai_bridge/long_running/AGENTS.md +++ b/src/databricks_ai_bridge/long_running/AGENTS.md @@ -183,9 +183,8 @@ erDiagram text response_id PK text status "in_progress / completed / failed" timestamptz created_at - text owner_pod_id timestamptz heartbeat_at - int attempt_number + int attempt_number "CAS guard for claim atomicity" text original_request "JSON of initial POST (full input)" text trace_id } @@ -199,7 +198,7 @@ erDiagram responses ||--o{ messages : "has" ``` -- `responses.attempt_number` is the CAS guard for claim atomicity. +- `responses.attempt_number` is the CAS guard for claim atomicity. **There is no `owner_pod_id` column** — ownership is implicit. The pod that last successfully heartbeats at the current `attempt_number` is the de facto owner. A heartbeat write at attempt N stops working the moment another pod has CAS-bumped the row to N+1, so the prior owner detects it has lost the claim on its next heartbeat (rowcount=0) and shuts down its heartbeat task. - `responses.original_request` stores the **full untrimmed input** so the resume path can recover the entire prior-turn history when the rotated SDK session starts empty. - `messages.attempt_number` tags every event so retrieval can filter to the latest attempt's output (avoiding partial output from a crashed attempt leaking into the final response body). - Schema migrations are idempotent (`ADD COLUMN IF NOT EXISTS`) so an existing deployment upgrades without downtime. @@ -273,7 +272,24 @@ The trim is the SDK-agnostic equivalent of a per-SDK "is this a continuation tur The trim runs on every `POST /responses` and is a no-op when input has no assistant message (first-turn POSTs or the resume path's `[original_input + prose_msg]` payload, which contains only `role:user` items by construction). -### 3.5 Heartbeat and stale threshold +### 3.5 Proactive stale-scan loop + +In addition to the lazy-on-GET claim path (a stale heartbeat is detected when a client GETs `/responses/{id}`), each pod runs a background **scanner** that periodically queries for in-progress responses with stale heartbeats and tries to claim+resume them. This means crashed responses get recovered even when no client is actively polling. + +```mermaid +flowchart LR + A[every ~30s ± 50% jitter] --> B[SELECT response_id FROM responses
WHERE status='in_progress'
AND heartbeat_at < now-threshold
LIMIT 50] + B --> C{any rows?} + C -->|no| A + C -->|yes| D[for each id: get_response
+ _try_claim_and_resume] + D --> A +``` + +Each pod jitters its scan interval (`stale_scan_jitter_fraction = 0.5` by default) so multiple pods don't synchronize their queries. CAS-claim semantics ensure only one pod succeeds in claiming any given stale response. + +The scanner is a background task spawned in the FastAPI lifespan (alongside `init_db`) and cancelled on shutdown. + +### 3.6 Heartbeat and stale threshold Defaults are tuned for a single Lakebase deployment with low-latency writes. @@ -340,9 +356,22 @@ Without client cooperation: the next turn lands on the original (orphan-poisoned - `db_instance_name` / `db_autoscaling_endpoint` / `db_project` + `db_branch` — Lakebase connection config. - `heartbeat_interval_seconds` / `heartbeat_stale_threshold_seconds` — for tuning under heavy load. - `task_timeout_seconds` — per-attempt ceiling. +- `stale_scan_interval_seconds` / `stale_scan_jitter_fraction` — controls how often (and with how much randomness) each pod scans the DB for stale responses to claim. Defaults to 30s with ±50% jitter. Everything else is internal. +### 4.4 Test-only debug endpoint: `/_debug/kill_task/{response_id}` + +Cancels the in-flight asyncio task that owns the given response on this pod **without** running the `_task_scope` cleanup. The DB row stays `in_progress` with a heartbeat that's about to go stale — exactly the shape a real pod crash leaves. Used in integration tests to simulate a pod crash without restarting the container. + +Opt-in via env var: only registered when `LONG_RUNNING_ENABLE_DEBUG_KILL=1`. Never exposed in production. + +Returns 404 if no in-flight task for that response exists on this specific pod (the task may already have finished, or it may be running on a different pod). + +```bash +curl -sS -X POST -H "Authorization: Bearer $TOKEN" "$APP_URL/_debug/kill_task/$RID" +``` + ## 5. Future direction: TaskFlow [TaskFlow](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/tree/experimental/taskflow) is a Rust-core durable-task engine being built in `experimental/taskflow`. It provides exactly the primitives `LongRunningAgentServer` hand-rolls today (heartbeat, CAS claim, recovery worker, event log with stream resume) — but as a library with WAL-first durability and proactive (not lazy-on-GET) recovery. diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index aed2c903..69ca8a79 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -86,7 +86,6 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): # poison the rest. A single mega-transaction would abort entirely on the # first owner-check failure even with IF NOT EXISTS. migration_stmts = ( - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py index 7014a7db..dfdcca70 100644 --- a/src/databricks_ai_bridge/long_running/models.py +++ b/src/databricks_ai_bridge/long_running/models.py @@ -16,9 +16,11 @@ class Base(DeclarativeBase): class Response(Base): """Response status tracking for background agent tasks. - Durability columns (``owner_pod_id``, ``heartbeat_at``, ``attempt_number``, - ``original_request``) support crash-resume: another pod can atomically - claim a stale in-progress row and replay the agent loop. + Durability columns (``heartbeat_at``, ``attempt_number``, + ``original_request``) support crash-resume: another pod atomically + claims a stale in-progress row by CAS-ing on ``attempt_number`` and + replays the agent loop. The owning pod is implicit — it's whatever + pod last successfully heartbeat at the current attempt_number. """ __tablename__ = "responses" @@ -30,7 +32,6 @@ class Response(Base): DateTime(timezone=True), nullable=False, server_default=func.now() ) trace_id: Mapped[str | None] = mapped_column(Text, nullable=True) - owner_pod_id: Mapped[str | None] = mapped_column(Text, nullable=True) heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) attempt_number: Mapped[int] = mapped_column( Integer, nullable=False, server_default="1", default=1 diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py index 06d30edb..a3621ef2 100644 --- a/src/databricks_ai_bridge/long_running/repository.py +++ b/src/databricks_ai_bridge/long_running/repository.py @@ -15,22 +15,21 @@ async def create_response( response_id: str, status: str, *, - owner_pod_id: str | None = None, + durable: bool = False, original_request: dict[str, Any] | None = None, ) -> None: """Insert a new response row. - ``owner_pod_id`` and ``original_request`` are optional so that non-durable - callers (tests, legacy flows) can still create rows without durability - metadata. When present, they enable heartbeat + crash-resume semantics. + When ``durable=True``, ``heartbeat_at`` is initialized to ``now()`` so + the row doesn't immediately look stale. Non-durable callers (tests, + legacy flows) skip the heartbeat init. """ async with session_scope() as session: session.add( Response( response_id=response_id, status=status, - owner_pod_id=owner_pod_id, - heartbeat_at=datetime.now().astimezone() if owner_pod_id else None, + heartbeat_at=datetime.now().astimezone() if durable else None, original_request=( json.dumps(original_request) if original_request is not None else None ), @@ -65,16 +64,22 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None: await session.commit() -async def heartbeat_response(response_id: str, pod_id: str) -> bool: - """Update heartbeat_at for a response IFF this pod owns it. +async def heartbeat_response(response_id: str, expected_attempt_number: int) -> bool: + """Update heartbeat_at for a response IFF the attempt is still ours. - Returns True on success. A False result means the claim has been lost - (another pod took over, or the run finished and heartbeat should stop). + Returns True on success. A False result means the claim has been lost — + another pod CAS-bumped ``attempt_number``, so this pod is no longer the + owner and the heartbeat task should stop. Implicit-ownership model: + whichever pod last successfully heartbeats at the current + ``attempt_number`` is the de facto owner. """ async with session_scope() as session: stmt = ( update(Response) - .where(Response.response_id == response_id, Response.owner_pod_id == pod_id) + .where( + Response.response_id == response_id, + Response.attempt_number == expected_attempt_number, + ) .values(heartbeat_at=datetime.now().astimezone()) ) result = await session.execute(stmt) @@ -84,18 +89,19 @@ async def heartbeat_response(response_id: str, pod_id: str) -> bool: async def claim_stale_response( response_id: str, - new_owner_pod_id: str, stale_threshold_seconds: float, ) -> int | None: """Atomically claim an in-progress response whose heartbeat has gone stale. Uses a single conditional UPDATE so exactly one caller wins on contention: claim only succeeds if status is ``in_progress`` AND - (``owner_pod_id IS NULL`` OR ``heartbeat_at`` is older than the threshold). + (``heartbeat_at IS NULL`` OR ``heartbeat_at`` is older than the threshold). + The new attempt_number is the previous + 1; the prior attempt's heartbeat + task will detect this on its next heartbeat (rowcount=0) and stop. Returns the new ``attempt_number`` on success, or ``None`` if the row did - not satisfy the claim conditions (already completed, already claimed by a - live pod, or nonexistent). + not satisfy the claim conditions (already completed, heartbeat still fresh, + or nonexistent). """ # Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for # the incremented column as ergonomically. Using a single statement keeps the @@ -103,31 +109,60 @@ async def claim_stale_response( stmt = text( f""" UPDATE {AGENT_DB_SCHEMA}.responses - SET owner_pod_id = :pod, - heartbeat_at = now(), + SET heartbeat_at = now(), attempt_number = attempt_number + 1 WHERE response_id = :rid AND status = 'in_progress' - AND (owner_pod_id IS NULL - OR heartbeat_at IS NULL + AND (heartbeat_at IS NULL OR heartbeat_at < now() - make_interval(secs => :threshold)) RETURNING attempt_number """ ).bindparams( - bindparam("pod", type_=None), bindparam("rid", type_=None), bindparam("threshold", type_=None), ) async with session_scope() as session: result = await session.execute( stmt, - {"pod": new_owner_pod_id, "rid": response_id, "threshold": stale_threshold_seconds}, + {"rid": response_id, "threshold": stale_threshold_seconds}, ) row = result.first() await session.commit() return int(row[0]) if row else None +async def find_stale_response_ids( + stale_threshold_seconds: float, + limit: int = 50, +) -> list[str]: + """Return ids of in_progress responses whose heartbeat is older than the + threshold. Used by the proactive scanner to find candidates for resume + without waiting for a client GET. + + Limited to ``limit`` rows per scan to bound DB load. Ordered by + ``heartbeat_at`` ascending so the oldest staleness is handled first. + """ + stmt = text( + f""" + SELECT response_id FROM {AGENT_DB_SCHEMA}.responses + WHERE status = 'in_progress' + AND heartbeat_at IS NOT NULL + AND heartbeat_at < now() - make_interval(secs => :threshold) + ORDER BY heartbeat_at ASC + LIMIT :limit + """ + ).bindparams( + bindparam("threshold", type_=None), + bindparam("limit", type_=None), + ) + async with session_scope() as session: + result = await session.execute( + stmt, + {"threshold": stale_threshold_seconds, "limit": limit}, + ) + return [row[0] for row in result.all()] + + async def append_message( response_id: str, sequence_number: int, @@ -181,7 +216,6 @@ class ResponseInfo(NamedTuple): status: str created_at: datetime trace_id: str | None - owner_pod_id: str | None heartbeat_at: datetime | None attempt_number: int original_request: dict[str, Any] | None @@ -198,7 +232,6 @@ async def get_response(response_id: str) -> ResponseInfo | None: row.status, row.created_at, row.trace_id, - row.owner_pod_id, row.heartbeat_at, row.attempt_number, json.loads(row.original_request) if row.original_request else None, diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 4672c365..bcf4e835 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -11,6 +11,7 @@ import json import logging import os +import random import socket import time import uuid @@ -39,6 +40,7 @@ append_message, claim_stale_response, create_response, + find_stale_response_ids, get_messages, get_response, heartbeat_response, @@ -52,8 +54,9 @@ BACKGROUND_KEY = "background" -# One ID per process so heartbeats + claims have a stable owner identity. -_POD_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" +# Process-local identifier for log lines. Not stored in the DB — heartbeat +# ownership is implicit via attempt_number CAS. +_POD_LOG_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" async def _deferred_mark_failed( @@ -131,124 +134,34 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -def _extract_text_from_message(item: dict) -> str: - """Pull text out of a Responses-API message item's content array.""" - content = item.get("content") - if isinstance(content, str): - return content - if isinstance(content, list): - parts: list[str] = [] - for c in content: - if isinstance(c, dict): - t = c.get("text") - if isinstance(t, str): - parts.append(t) - return "".join(parts) - return "" - - def _build_prose_recovery_message( messages: list[tuple], prior_attempt_number: int ) -> dict[str, Any]: - """Narrate the prior attempt's events as a single user message. - - Walks the repository's ``(seq, item, stream_event, attempt)`` tuples - for the given prior attempt and produces a single Responses-API user - message item whose content is a flat prose summary of completed tool - calls, their outputs, narrative messages, and any partial in-flight - assistant text. - - This replaces the structured carry-forward of completed tool pairs + - synthetic ``[INTERRUPTED]`` outputs with a single recovery prompt the - LLM reads as: "the prior attempt crashed, here's what had happened, - please continue." Trades structural protocol fidelity for SDK-agnostic - simplicity — no provider-specific pairing rules, no synthetic events. + """Build a single user message containing the prior attempt's raw event + log + a directive that asks the LLM to figure out what already completed + and continue. + + The body is `json.dumps(events)` of the prior attempt's stream events + wrapped in a recovery prompt. SDK-agnostic — no provider-specific pairing + rules, no structured carry-forward, no synthetic events. The model reads + the JSON, decides which tool calls succeeded, which were interrupted, and + continues. """ - completed_calls: dict[str, dict[str, Any]] = {} - call_order: list[str] = [] - narrative_texts: list[str] = [] - in_progress_text: dict[str, list[str]] = {} - in_progress_order: list[str] = [] - - for _seq, _item_json, evt, attempt_tag in messages: - if attempt_tag != prior_attempt_number or not isinstance(evt, dict): - continue - t = evt.get("type") - item = evt.get("item") - - if t == "response.output_item.added" and isinstance(item, dict): - if item.get("type") == "message" and (iid := item.get("id")): - in_progress_text.setdefault(iid, []) - if iid not in in_progress_order: - in_progress_order.append(iid) - - elif t == "response.output_item.done" and isinstance(item, dict): - itype = item.get("type") - if itype == "function_call": - cid = item.get("call_id") - if cid: - slot = completed_calls.setdefault(cid, {}) - slot["name"] = item.get("name") - slot["args"] = item.get("arguments") - if cid not in call_order: - call_order.append(cid) - elif itype == "function_call_output": - cid = item.get("call_id") - if cid: - slot = completed_calls.setdefault(cid, {}) - slot["output"] = item.get("output") - if cid not in call_order: - call_order.append(cid) - elif itype == "message": - text = _extract_text_from_message(item) - if text: - narrative_texts.append(text) - # The done event makes prior added/delta tracking moot. - iid = item.get("id") - if iid in in_progress_text: - in_progress_text.pop(iid, None) - if iid in in_progress_order: - in_progress_order.remove(iid) - - elif t == "response.output_text.delta": - iid = evt.get("item_id") - delta = evt.get("delta") - if iid and isinstance(delta, str) and iid in in_progress_text: - in_progress_text[iid].append(delta) - - lines: list[str] = [] - for cid in call_order: - info = completed_calls[cid] - name = info.get("name", "") - args = info.get("args", "") - if "output" in info: - lines.append(f"- Called `{name}({args})` and got result: {info['output']}") - else: - lines.append(f"- Called `{name}({args})` — interrupted before a result was returned.") - for text in narrative_texts: - lines.append(f"- Said: {text}") - for iid in in_progress_order: - chunks = in_progress_text.get(iid) or [] - partial = "".join(chunks).strip() - if partial: - lines.append(f"- Was generating: {partial!r} when the run was interrupted.") - - if lines: - body = ( - "[RECOVERY] The previous attempt of this agent task crashed mid-execution. " - "Here is what had completed before the crash:\n\n" - + "\n".join(lines) - + "\n\nPlease continue from where the prior attempt left off. " - "If a tool call was interrupted, you may re-invoke it if its result " - "is still needed." - ) - else: - body = ( - "[RECOVERY] The previous attempt of this agent task was interrupted " - "before any tool calls or assistant output completed. Please proceed " - "with the original user request." - ) - + prior_events = [ + evt + for _seq, _item_json, evt, attempt_tag in messages + if attempt_tag == prior_attempt_number and isinstance(evt, dict) + ] + body = ( + "[RECOVERY] The previous attempt of this agent task crashed " + "mid-execution. Below is the raw stream-event log from that attempt " + "as JSON. Some tool calls may have completed and some may have been " + "interrupted before returning a result. Inspect the events, figure " + "out what is already done versus in-progress / not completed, and " + "continue the task from where it left off. If a tool call was " + "interrupted, you may re-invoke it if its result is still needed.\n\n" + f"Events:\n{json.dumps(prior_events)}" + ) return { "type": "message", "role": "user", @@ -348,35 +261,6 @@ def _rotate_conversation_id( return request_dict -def _inject_conversation_id(request_dict: dict[str, Any], response_id: str) -> dict[str, Any]: - """Anchor the request to ``response_id`` as its conversation. - - Operates on a plain dict — the caller is responsible for converting to/from - pydantic via ``model_dump()`` and the server's validator. - - Templates that back this server use ``context.conversation_id`` (and - ``custom_inputs.thread_id`` / ``custom_inputs.session_id``) as priority-2 - fallbacks to derive their stateful thread/session key. If neither is - provided by the client, a resumed invocation from another pod would - generate a *fresh* ID and miss the checkpoint entirely — so we stamp the - conversation_id here before persisting the request, guaranteeing that - every replay hits the same memory store. - - Client-supplied values take precedence and are left untouched. - """ - out = copy.deepcopy(request_dict) if request_dict else {} - custom_inputs = out.get("custom_inputs") or {} - if custom_inputs.get("thread_id") or custom_inputs.get("session_id"): - return out - ctx = out.get("context") or {} - if ctx.get("conversation_id"): - return out - ctx = dict(ctx) - ctx["conversation_id"] = response_id - out["context"] = ctx - return out - - @experimental class LongRunningAgentServer(AgentServer): """AgentServer subclass adding background mode, retrieve endpoints, and @@ -507,7 +391,7 @@ async def _debug_kill_task(response_id: str): logger.info( "[durable] kill endpoint: no task response_id=%s on pod=%s", response_id, - _POD_ID, + _POD_LOG_ID, ) raise HTTPException( status_code=404, @@ -519,12 +403,12 @@ async def _debug_kill_task(response_id: str): logger.info( "[durable] kill endpoint: cancelling task response_id=%s pod=%s", response_id, - _POD_ID, + _POD_LOG_ID, ) task.cancel() return { "response_id": response_id, - "pod_id": _POD_ID, + "pod_id": _POD_LOG_ID, "status": "task_cancelled", } @@ -567,8 +451,19 @@ async def _db_lifespan(app): branch=self._db_branch, db_statement_timeout_ms=self._settings.db_statement_timeout_ms, ) - yield - await dispose_db() + scanner_task = asyncio.create_task( + self._stale_response_scanner_loop(), + name="durable-stale-scanner", + ) + try: + yield + finally: + scanner_task.cancel() + try: + await scanner_task + except (asyncio.CancelledError, Exception): + pass + await dispose_db() self.app.router.lifespan_context = _db_lifespan @@ -643,7 +538,6 @@ async def _handle_background_request( # when tests pass a plain dict directly. dump = getattr(request_data, "model_dump", None) request_dict = dump() if callable(dump) else dict(request_data) - durable_dict = _inject_conversation_id(request_dict, response_id) # Store the FULL request (untrimmed) as `original_request` so resume can # recover the entire prior-turn history. The handler invocation below # uses a trimmed copy to avoid duplicating turns the SDK's session has @@ -653,14 +547,14 @@ async def _handle_background_request( await create_response( response_id, "in_progress", - owner_pod_id=_POD_ID, - original_request=durable_dict, + durable=True, + original_request=request_dict, ) - # Build a TRIMMED handler request from the same durable dict — drops - # echoed history that the SDK's session already has. Original_request - # above stays untrimmed for resume. - handler_dict = copy.deepcopy(durable_dict) + # Build a TRIMMED handler request from the same dict — drops echoed + # history that the SDK's session already has. Original_request above + # stays untrimmed for resume. + handler_dict = copy.deepcopy(request_dict) handler_items = handler_dict.get("input") if isinstance(handler_items, list): handler_dict["input"] = _trim_echoed_history(handler_items) @@ -670,7 +564,7 @@ async def _handle_background_request( "Background response created response_id=%s stream=%s pod=%s", response_id, is_streaming, - _POD_ID, + _POD_LOG_ID, ) response_obj: dict[str, Any] = { @@ -713,14 +607,70 @@ def _track_task(self, response_id: str, task: asyncio.Task) -> None: self._running_tasks[response_id] = task task.add_done_callback(lambda _t: self._running_tasks.pop(response_id, None)) + async def _stale_response_scanner_loop(self) -> None: + """Periodically scan for in_progress responses with stale heartbeats and + try to claim+resume them. The proactive counterpart to the lazy claim + path on ``GET /responses/{id}``. + + Each iteration sleeps for a jittered interval so multiple pods don't + synchronize their reads. Runs until cancelled (in the lifespan + teardown). + """ + base = self._settings.stale_scan_interval_seconds + jitter = self._settings.stale_scan_jitter_fraction + threshold = self._settings.heartbeat_stale_threshold_seconds + logger.info( + "[durable] stale-scan loop start interval=%.1fs jitter=±%.0f%% threshold=%.1fs pod=%s", + base, + jitter * 100, + threshold, + _POD_LOG_ID, + ) + try: + while True: + # Jittered sleep — random scaling of base interval centered on 1.0. + delay = base * (1.0 + random.uniform(-jitter, jitter)) + await asyncio.sleep(delay) + try: + stale_ids = await find_stale_response_ids(threshold) + if not stale_ids: + continue + logger.info( + "[durable] stale-scan found %d candidate(s): %s", + len(stale_ids), + stale_ids, + ) + for response_id in stale_ids: + try: + resp = await get_response(response_id) + if resp: + await self._try_claim_and_resume(response_id, resp) + except Exception: + logger.exception( + "[durable] stale-scan resume failed response_id=%s", + response_id, + ) + except Exception: + # Don't let an iteration failure kill the loop. + logger.exception("[durable] stale-scan iteration failed") + except asyncio.CancelledError: + logger.info("[durable] stale-scan loop stopped pod=%s", _POD_LOG_ID) + raise + @asynccontextmanager - async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: + async def _heartbeat(self, response_id: str, attempt_number: int) -> AsyncGenerator[None, None]: """Keep the response row's heartbeat_at fresh while the body runs. A background task writes ``heartbeat_at = now()`` every - ``heartbeat_interval_seconds`` for the owning pod. It stops when the - body returns/raises. Heartbeat write failures are logged but do not - interrupt the agent run — the stale-run check will detect a dead pod. + ``heartbeat_interval_seconds``, scoped to ``attempt_number``. The + update only matches if ``attempt_number`` still equals the value the + heartbeat was started with — if another pod has CAS-claimed the + row (bumping attempt_number), this heartbeat returns 0 rows and the + task knows it has lost ownership and stops. + + Implicit-ownership model: there is no ``owner_pod_id`` column. The + last pod to successfully heartbeat at the current attempt is the + de facto owner. """ interval = self._settings.heartbeat_interval_seconds stop = asyncio.Event() @@ -728,25 +678,42 @@ async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: async def _beat(): beats = 0 logger.info( - "[durable] heartbeat start response_id=%s pod=%s interval=%.1fs", + "[durable] heartbeat start response_id=%s attempt=%d pod=%s interval=%.1fs", response_id, - _POD_ID, + attempt_number, + _POD_LOG_ID, interval, ) try: while not stop.is_set(): try: - await heartbeat_response(response_id, _POD_ID) + ok = await heartbeat_response(response_id, attempt_number) + if not ok: + # CAS failed → attempt_number has moved past us, + # another pod owns this response now. Stop the + # heartbeat task; the handler is still running but + # its emissions to the message log will be tagged + # with this attempt and ignored on the resumed + # path's filter (which keys on the new attempt). + logger.info( + "[durable] heartbeat lost ownership response_id=%s " + "attempt=%d (another pod claimed); stopping", + response_id, + attempt_number, + ) + stop.set() + break beats += 1 # Sampled heartbeat log so the lifecycle is visible # without spamming every interval. Every 5th (~15s # at 3s interval) is a good compromise. if beats % 5 == 1: logger.info( - "[durable] heartbeat beat#%d response_id=%s pod=%s", + "[durable] heartbeat beat#%d response_id=%s attempt=%d pod=%s", beats, response_id, - _POD_ID, + attempt_number, + _POD_LOG_ID, ) except Exception: logger.warning( @@ -761,9 +728,10 @@ async def _beat(): except asyncio.CancelledError: pass logger.info( - "[durable] heartbeat stop response_id=%s pod=%s total_beats=%d", + "[durable] heartbeat stop response_id=%s attempt=%d pod=%s total_beats=%d", response_id, - _POD_ID, + attempt_number, + _POD_LOG_ID, beats, ) @@ -843,7 +811,10 @@ async def _run_background_stream( ) -> None: """Timeout-guarded wrapper around the streaming agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state), self._heartbeat(response_id): + async with ( + self._task_scope(response_id, state), + self._heartbeat(response_id, attempt_number), + ): await self._do_background_stream( response_id, request_data, @@ -876,7 +847,7 @@ async def _do_background_stream( "[durable] background stream start response_id=%s attempt=%d pod=%s handler=%s", response_id, attempt_number, - _POD_ID, + _POD_LOG_ID, func_name, ) all_chunks: list[dict[str, Any]] = [] @@ -941,7 +912,7 @@ async def _do_background_stream( response_id, attempt_number, seq, - _POD_ID, + _POD_LOG_ID, ) async def _run_background_invoke( @@ -954,7 +925,10 @@ async def _run_background_invoke( ) -> None: """Timeout-guarded wrapper around the invoke agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state), self._heartbeat(response_id): + async with ( + self._task_scope(response_id, state), + self._heartbeat(response_id, attempt_number), + ): await self._do_background_invoke( response_id, request_data, @@ -1067,11 +1041,10 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: return None logger.info( "[durable] stale heartbeat detected response_id=%s " - "heartbeat_age=%.1fs threshold=%.1fs current_owner=%s", + "heartbeat_age=%.1fs threshold=%.1fs", response_id, hb_age, self._settings.heartbeat_stale_threshold_seconds, - resp.owner_pod_id, ) if resp.original_request is None: # Nothing to replay from — the run predates durability metadata. @@ -1085,11 +1058,10 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: "[durable] attempting claim response_id=%s current_attempt=%d new_owner=%s", response_id, resp.attempt_number, - _POD_ID, + _POD_LOG_ID, ) new_attempt = await claim_stale_response( response_id, - new_owner_pod_id=_POD_ID, stale_threshold_seconds=self._settings.heartbeat_stale_threshold_seconds, ) if new_attempt is None: @@ -1152,7 +1124,7 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: "[durable] claim succeeded response_id=%s new_attempt=%d pod=%s resume_from_seq=%d", response_id, new_attempt, - _POD_ID, + _POD_LOG_ID, next_seq, ) diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 30f1ad02..8224b4d8 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -17,6 +17,13 @@ class LongRunningSettings: cleanup_timeout_seconds: float = 7.0 heartbeat_interval_seconds: float = 3.0 heartbeat_stale_threshold_seconds: float = 10.0 + # Proactive stale-scan loop: how often (on average) each pod queries the + # responses table for stale-heartbeat rows and tries to claim+resume them. + # Each pod jitters this interval so multiple pods don't all hit the DB at + # once. The loop is the proactive counterpart to the lazy-on-GET claim + # path; it ensures crashed responses get recovered even if no client polls. + stale_scan_interval_seconds: float = 30.0 + stale_scan_jitter_fraction: float = 0.5 def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: @@ -44,3 +51,7 @@ def __post_init__(self) -> None: f"strictly greater than db_statement_timeout_ms converted to seconds " f"({db_timeout_s})" ) + if self.stale_scan_interval_seconds <= 0: + raise ValueError("stale_scan_interval_seconds must be positive") + if not 0 <= self.stale_scan_jitter_fraction < 1: + raise ValueError("stale_scan_jitter_fraction must be in [0, 1)") diff --git a/tests/databricks_ai_bridge/test_long_running_db.py b/tests/databricks_ai_bridge/test_long_running_db.py index d425da44..2565a1e0 100644 --- a/tests/databricks_ai_bridge/test_long_running_db.py +++ b/tests/databricks_ai_bridge/test_long_running_db.py @@ -177,7 +177,6 @@ async def test_get_response(mock_session): row.created_at = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) row.trace_id = "trace_xyz" - row.owner_pod_id = None row.heartbeat_at = None row.attempt_number = 1 row.original_request = None @@ -191,7 +190,6 @@ async def test_get_response(mock_session): "completed", datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc), "trace_xyz", - None, # owner_pod_id None, # heartbeat_at 1, # attempt_number None, # original_request @@ -298,7 +296,6 @@ async def test_creates_schema_and_tables(self, reset_db_globals): # the durability columns onto pre-existing tables. all_sql = " | ".join(str(call.args[0]) for call in mock_conn.execute.call_args_list) assert "CREATE SCHEMA IF NOT EXISTS" in all_sql - assert "ADD COLUMN IF NOT EXISTS owner_pod_id" in all_sql assert "ADD COLUMN IF NOT EXISTS heartbeat_at" in all_sql assert "ADD COLUMN IF NOT EXISTS attempt_number" in all_sql assert "ADD COLUMN IF NOT EXISTS original_request" in all_sql @@ -366,24 +363,23 @@ async def fake_factory(): # --------------------------------------------------------------------------- -# Durability metadata: owner_pod_id, heartbeat, claim, attempt_number +# Durability metadata: heartbeat (CAS on attempt), claim, attempt_number # --------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_create_response_with_owner_and_original_request(mock_session): - """New background callers stamp pod id + serialized request on creation — +async def test_create_response_durable_stamps_heartbeat_and_original_request(mock_session): + """Durable callers stamp heartbeat_at + serialized request on creation — without these, a resumed pod can't re-invoke the handler.""" from databricks_ai_bridge.long_running.repository import create_response await create_response( "resp_abc", "in_progress", - owner_pod_id="pod-1", + durable=True, original_request={"input": [{"role": "user", "content": "hi"}]}, ) added = mock_session.add.call_args[0][0] - assert added.owner_pod_id == "pod-1" assert added.heartbeat_at is not None # original_request is JSON-encoded for Text storage. assert '"role": "user"' in added.original_request @@ -391,41 +387,40 @@ async def test_create_response_with_owner_and_original_request(mock_session): @pytest.mark.asyncio async def test_create_response_without_durability_metadata(mock_session): - """Legacy/no-durability callers should still work and write no - owner/heartbeat (so the stale sweep can't accidentally claim them).""" + """Non-durable callers (tests, legacy flows) write no heartbeat so the + stale sweep can't accidentally claim them.""" from databricks_ai_bridge.long_running.repository import create_response await create_response("resp_x", "in_progress") added = mock_session.add.call_args[0][0] - assert added.owner_pod_id is None assert added.heartbeat_at is None assert added.original_request is None @pytest.mark.asyncio -async def test_heartbeat_response_updates_timestamp(mock_session): +async def test_heartbeat_response_updates_when_attempt_matches(mock_session): from databricks_ai_bridge.long_running.repository import heartbeat_response result_mock = MagicMock() result_mock.rowcount = 1 mock_session.execute.return_value = result_mock - ok = await heartbeat_response("resp_abc", "pod-1") + ok = await heartbeat_response("resp_abc", expected_attempt_number=1) assert ok is True mock_session.commit.assert_awaited_once() @pytest.mark.asyncio -async def test_heartbeat_response_fails_when_not_owner(mock_session): - """If the CAS misses (owner changed / row deleted), heartbeat reports - failure so the caller can stop looping.""" +async def test_heartbeat_response_fails_when_attempt_changed(mock_session): + """If the CAS misses (attempt_number bumped by another pod's claim), + heartbeat reports failure so the caller can stop looping.""" from databricks_ai_bridge.long_running.repository import heartbeat_response result_mock = MagicMock() result_mock.rowcount = 0 mock_session.execute.return_value = result_mock - ok = await heartbeat_response("resp_abc", "pod-1") + ok = await heartbeat_response("resp_abc", expected_attempt_number=1) assert ok is False @@ -440,9 +435,7 @@ async def test_claim_stale_response_returns_attempt_number(mock_session): result_mock.first.return_value = row mock_session.execute.return_value = result_mock - attempt = await claim_stale_response( - "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0 - ) + attempt = await claim_stale_response("resp_abc", stale_threshold_seconds=15.0) assert attempt == 2 @@ -454,9 +447,7 @@ async def test_claim_stale_response_returns_none_when_not_eligible(mock_session) result_mock.first.return_value = None mock_session.execute.return_value = result_mock - attempt = await claim_stale_response( - "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0 - ) + attempt = await claim_stale_response("resp_abc", stale_threshold_seconds=15.0) assert attempt is None diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 0cf00c53..c222cc48 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -19,7 +19,6 @@ LongRunningAgentServer, _build_prose_recovery_message, _deferred_mark_failed, - _inject_conversation_id, _rotate_conversation_id, _sse_event, _trim_echoed_history, @@ -44,7 +43,6 @@ def _resp_info( status: str = "in_progress", created_at=None, trace_id: str | None = None, - owner_pod_id: str | None = None, heartbeat_at=None, attempt_number: int = 1, original_request: dict | None = None, @@ -61,7 +59,6 @@ def _resp_info( status=status, created_at=created_at, trace_id=trace_id, - owner_pod_id=owner_pod_id, heartbeat_at=heartbeat_at, attempt_number=attempt_number, original_request=original_request, @@ -936,9 +933,9 @@ async def test_lifespan_not_set_when_db_not_configured(self): class TestBuildProseRecoveryMessage: - """Prose recovery serializer: walk a prior attempt's events and produce a - single Responses-API user-message item narrating what happened, for the - next attempt's LLM to read as a recovery instruction.""" + """Prose recovery serializer: produce a single Responses-API user-message + item containing the prior attempt's stream events as JSON, plus a + directive that asks the LLM to figure out what's done vs interrupted.""" def _done(self, seq, attempt, item): return (seq, None, {"type": "response.output_item.done", "item": item}, attempt) @@ -950,7 +947,7 @@ def test_returns_user_message_shape(self): assert isinstance(out["content"], str) assert "[RECOVERY]" in out["content"] - def test_completed_call_with_output(self): + def test_includes_events_json(self): messages = [ self._done( 0, 1, {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"} @@ -959,24 +956,10 @@ def test_completed_call_with_output(self): ] out = _build_prose_recovery_message(messages, prior_attempt_number=1) body = out["content"] - assert "Called `f({})`" in body - assert "got result: ok" in body - - def test_call_without_output_marked_interrupted(self): - messages = [ - self._done( - 0, - 1, - { - "type": "function_call", - "call_id": "c1", - "name": "deep_research", - "arguments": "", - }, - ), - ] - out = _build_prose_recovery_message(messages, prior_attempt_number=1) - assert "interrupted before a result was returned" in out["content"] + # Body should contain the raw events JSON-serialized. + assert '"call_id": "c1"' in body + assert '"output": "ok"' in body + assert '"name": "f"' in body def test_filters_other_attempts(self): messages = [ @@ -988,32 +971,16 @@ def test_filters_other_attempts(self): ), ] out = _build_prose_recovery_message(messages, prior_attempt_number=1) - assert "f({})" in out["content"] - assert "g(" not in out["content"] - - def test_partial_text_from_deltas(self): - messages = [ - ( - 0, - None, - {"type": "response.output_item.added", "item": {"type": "message", "id": "m1"}}, - 1, - ), - ( - 1, - None, - {"type": "response.output_text.delta", "item_id": "m1", "delta": "Hello, "}, - 1, - ), - (2, None, {"type": "response.output_text.delta", "item_id": "m1", "delta": "world"}, 1), - ] - out = _build_prose_recovery_message(messages, prior_attempt_number=1) - assert "Was generating" in out["content"] - assert "Hello, world" in out["content"] + body = out["content"] + assert '"call_id": "c1"' in body + # attempt 2 events excluded + assert '"call_id": "c2"' not in body - def test_empty_attempt_falls_back_to_default_recovery(self): + def test_empty_attempt_emits_empty_events_array(self): out = _build_prose_recovery_message([], prior_attempt_number=1) - assert "interrupted before any tool calls" in out["content"] + # Body still contains the recovery directive and an empty events array. + assert "[RECOVERY]" in out["content"] + assert "Events:\n[]" in out["content"] class TestTrimEchoedHistory: @@ -1106,49 +1073,12 @@ def test_rotate_handles_missing_custom_inputs_key(self): assert out["custom_inputs"] == {} -class TestInjectConversationId: - """Anchoring an otherwise-anonymous request to a response_id guarantees a - resumed run on a new pod resolves to the same agent-SDK thread/session.""" - - def test_injects_when_nothing_set(self): - r = {"input": [], "custom_inputs": {}, "context": {}} - out = _inject_conversation_id(r, "resp_abc") - assert out["context"]["conversation_id"] == "resp_abc" - - def test_respects_existing_conversation_id(self): - r = {"input": [], "context": {"conversation_id": "user-set"}} - out = _inject_conversation_id(r, "resp_abc") - assert out["context"]["conversation_id"] == "user-set" - - def test_respects_thread_id_from_custom_inputs(self): - r = {"input": [], "custom_inputs": {"thread_id": "t-1"}, "context": {}} - out = _inject_conversation_id(r, "resp_abc") - # When the client already pinned a thread, we don't overwrite — the - # template's _get_or_create_thread_id picks up custom_inputs first. - assert "conversation_id" not in (out["context"] or {}) - - def test_respects_session_id_from_custom_inputs(self): - r = {"input": [], "custom_inputs": {"session_id": "s-1"}, "context": {}} - out = _inject_conversation_id(r, "resp_abc") - assert "conversation_id" not in (out["context"] or {}) - - def test_handles_missing_context_key(self): - r = {"input": [], "custom_inputs": {}} - out = _inject_conversation_id(r, "resp_abc") - assert out["context"]["conversation_id"] == "resp_abc" - - def test_does_not_mutate_input(self): - r = {"input": [], "custom_inputs": {}, "context": {}} - _inject_conversation_id(r, "resp_abc") - assert r["context"] == {} # original untouched - - class TestHandleBackgroundRequestPersistsDurabilityState: - """Background request entry point should now stamp the response row with - the caller's pod, the original request body, and a conversation anchor.""" + """Background request entry point should stamp the response row with the + full original_request body so resume can recover full prior-turn history.""" @pytest.mark.asyncio - async def test_persists_owner_and_original_request(self): + async def test_persists_durable_flag_and_original_request(self): with patch(f"{MODULE}.is_db_configured", return_value=True): server = LongRunningAgentServer("ResponsesAgent") _mock_validator(server) @@ -1156,11 +1086,11 @@ async def test_persists_owner_and_original_request(self): captured: dict = {} async def fake_create_response( - response_id, status, *, owner_pod_id=None, original_request=None + response_id, status, *, durable=False, original_request=None ): captured["response_id"] = response_id captured["status"] = status - captured["owner_pod_id"] = owner_pod_id + captured["durable"] = durable captured["original_request"] = original_request with ( @@ -1174,11 +1104,11 @@ async def fake_create_response( ) assert captured["status"] == "in_progress" - assert captured["owner_pod_id"] # non-empty - # original_request should include input + injected conversation_id. + assert captured["durable"] is True + # original_request preserves the input the client sent (no + # conversation_id injection — the client owns that decision). orig = captured["original_request"] assert orig["input"] == [{"role": "user", "content": "hi"}] - assert orig["context"]["conversation_id"] == captured["response_id"] # Return shape: immediate response_obj, not a stream. assert result["id"] == captured["response_id"] assert result["status"] == "in_progress" @@ -1459,7 +1389,7 @@ async def test_writes_heartbeat_periodically(self): ) with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: - async with server._heartbeat("resp_x"): + async with server._heartbeat("resp_x", attempt_number=1): await asyncio.sleep(0.2) # enough time for 2+ heartbeats # Heartbeat interval is 0.05s so we should see at least 2 writes. @@ -1477,7 +1407,7 @@ async def test_stops_cleanly_on_exit(self): ) with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: - async with server._heartbeat("resp_x"): + async with server._heartbeat("resp_x", attempt_number=1): pass # immediate exit # Give the heartbeat loop a chance to observe the stop signal. @@ -1505,7 +1435,7 @@ async def test_db_error_does_not_interrupt_body(self): new_callable=AsyncMock, side_effect=RuntimeError("db down"), ): - async with server._heartbeat("resp_x"): + async with server._heartbeat("resp_x", attempt_number=1): await asyncio.sleep(0.1) body_ran = True assert body_ran From 59dd6c52ee55cd4dfee4eb722a3f2aa54969e778 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 29 Apr 2026 23:00:02 +0000 Subject: [PATCH 4/9] LongRunningAgentServer: register /_debug/kill_task unconditionally The endpoint was gated by env-var check at route-registration time: if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") == "1": @self.app.post("/_debug/kill_task/{response_id}") This worked locally but failed on Databricks Apps where the env var appears in os.environ after the FastAPI app object is built. The endpoint was never registered, returning 404 even with the env var properly set in the deployment. Move the env check inside the handler: @self.app.post("/_debug/kill_task/{response_id}") async def _debug_kill_task(response_id: str): if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") != "1": raise HTTPException(404, "Debug kill endpoint is disabled.") ... Route is always registered; env var is checked per-request. Same security posture (404 when disabled), works regardless of when env vars become visible to the process. --- .../long_running/server.py | 57 ++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index bcf4e835..59234899 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -381,36 +381,41 @@ async def cancel_endpoint(response_id: str): # asyncio task that owns the given response_id WITHOUT running the # _task_scope cleanup, so the DB row stays in_progress with a # going-stale heartbeat — exactly the shape a real pod crash leaves. - # Opt-in via env var so it's never exposed in production. - if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") == "1": - - @self.app.post("/_debug/kill_task/{response_id}") - async def _debug_kill_task(response_id: str): - task = self._running_tasks.get(response_id) - if task is None: - logger.info( - "[durable] kill endpoint: no task response_id=%s on pod=%s", - response_id, - _POD_LOG_ID, - ) - raise HTTPException( - status_code=404, - detail=( - "No in-flight task for that response_id on this pod " - "(may already have finished or be running on another pod)." - ), - ) + # Opt-in via env var so it's never exposed in production. Env var + # is checked at request time (not registration time) because some + # platforms inject env vars after the FastAPI app object is built. + @self.app.post("/_debug/kill_task/{response_id}") + async def _debug_kill_task(response_id: str): + if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") != "1": + raise HTTPException( + status_code=404, + detail="Debug kill endpoint is disabled.", + ) + task = self._running_tasks.get(response_id) + if task is None: logger.info( - "[durable] kill endpoint: cancelling task response_id=%s pod=%s", + "[durable] kill endpoint: no task response_id=%s on pod=%s", response_id, _POD_LOG_ID, ) - task.cancel() - return { - "response_id": response_id, - "pod_id": _POD_LOG_ID, - "status": "task_cancelled", - } + raise HTTPException( + status_code=404, + detail=( + "No in-flight task for that response_id on this pod " + "(may already have finished or be running on another pod)." + ), + ) + logger.info( + "[durable] kill endpoint: cancelling task response_id=%s pod=%s", + response_id, + _POD_LOG_ID, + ) + task.cancel() + return { + "response_id": response_id, + "pod_id": _POD_LOG_ID, + "status": "task_cancelled", + } db_configured = is_db_configured() From 1b70d50dcebb1a8cf3f5565597c5c658f29360ee Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 30 Apr 2026 03:30:42 +0000 Subject: [PATCH 5/9] Move UI-echo dedup out of bridge; per-template handlers do it MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per design discussion with Bryan offline. Echo dedup is an agent-layer concern (the agent owns its SDK session/checkpointer, knows what's already persisted, decides how to combine that with caller input). The bridge should be minimal: heartbeat, scan loop, CAS claim, conv_id rotation, prose recovery — no manipulation of request.input shape. Removed: - _trim_echoed_history function (was ~30 LOC at module top of server.py) - Both call sites in _handle_invocations_request (non-background path) and _handle_background_request (background+DB path post-storage) - TestTrimEchoedHistory test class Templates now do dedup themselves (separate PR / templates branch): - agent-openai-advanced/utils.py:deduplicate_input checks session items - agent-langgraph-advanced/agent.py:stream_handler probes aget_state Updated AGENTS.md §3.4 to reflect the new placement; removed §3.4 mermaid diagram and the bullet referencing _trim_echoed_history in §5. Tests: 104 pass (was 110; 6 trim-specific tests removed). Co-authored-by: Isaac --- .../long_running/AGENTS.md | 50 ++++++++------ .../long_running/server.py | 69 ++----------------- .../test_long_running_server.py | 60 ---------------- 3 files changed, 34 insertions(+), 145 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md index c7331409..4559d1d3 100644 --- a/src/databricks_ai_bridge/long_running/AGENTS.md +++ b/src/databricks_ai_bridge/long_running/AGENTS.md @@ -24,7 +24,7 @@ Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, othe - **At-most-once durable claim.** Only one pod runs a given response at a time. The handoff uses an atomic CAS on `attempt_number`. - **Append-only event log.** Every SSE frame is persisted to `agent_server.messages` keyed by `(response_id, attempt_number, sequence_number)`. Clients cursor-resume from `starting_after`. - **SDK-agnostic recovery.** The resumed attempt receives a flat prose narrative — no provider-specific tool-pair structure, no synthetic tool events, no per-SDK adapter code. -- **SDK-agnostic UI-echo dedup.** When the chatbot client echoes prior conversation history in `request.input`, the server detects this (presence of an `assistant`-role message) and trims to the latest user message. The SDK's session/checkpointer storage is treated as authoritative for prior turns. +- **Per-template UI-echo dedup.** The bridge does NOT trim echoed history. When the chat client echoes the full prior conversation in `request.input`, the agent handler is responsible for deduping its input against the SDK's session/checkpointer state — typically by forwarding only the latest user message when the session already has prior turns. See the templates in `app-templates/agent-{openai,langgraph}-advanced/` for the canonical 1-2 line shape. - **Best-effort tool execution.** A tool call interrupted mid-flight may re-run on the resumed attempt. Idempotency is the tool author's responsibility. - **No agent code changes required.** Templates that subclass `LongRunningAgentServer` keep using `@invoke()` / `@stream()` decorators. All durability lives below the handler boundary. @@ -129,7 +129,7 @@ sequenceDiagram C->>C: lookup alias map:
chat_id → T::attempt-2 C->>S: POST /responses
{input: [echo of full UI history, new user msg],
context.conversation_id: T::attempt-2} - S->>S: _trim_echoed_history
(detects assistant in input → trim to last user) + S->>S: handler dedupes UI echo
(forwards only latest user when session has history) S->>S: handler resolves session_id = T::attempt-2 S->>SDK: load history at T::attempt-2 SDK-->>S: clean state (prose + attempt-2 emissions) @@ -142,7 +142,7 @@ sequenceDiagram Three things make this work without per-SDK repair code: 1. **Server emits the rotated conv_id in `response.resumed`.** The chatbot reads it and updates its `Map` alias. -2. **Server-side trim of UI-echoed history.** The presence of an `assistant`-role item in `request.input` is a reliable proxy for "client is echoing prior history." Trim to the latest user message; trust the SDK's session as authoritative for prior turns. +2. **Per-template UI-echo dedup in the handler.** When the SDK's session/checkpointer already has prior-turn state, the agent forwards only the latest user message (not the full echoed history). This prevents `Runner.run` from sending duplicates of session items + input items to the LLM. 3. **Always-rotate.** The rotated session was the one populated during attempt 2's run. Subsequent turns land on it. The original poisoned session is never read. ### CUJ 4: Multi-pod stale-claim contention @@ -208,8 +208,7 @@ erDiagram ```mermaid flowchart TD POST["POST /responses
background=true"] --> CR[create_response
store FULL original_request] - CR --> TRIM[deepcopy + _trim_echoed_history
for handler invocation] - TRIM --> SPAWN[spawn @stream handler] + CR --> SPAWN[spawn @stream handler
handler dedupes UI echo against SDK session] SPAWN --> HB[heartbeat loop
every 3s] SPAWN --> EMIT[append SSE events
to messages table] EMIT --> DONE{handler
exits?} @@ -253,24 +252,34 @@ Why rotation: the original SDK session may carry mid-turn state from the crashed Why the sentinel carries the rotated conv_id: cooperating chat clients capture it (via SSE) and use the rotated session for subsequent turns, so the original orphan-poisoned session is never read again. -### 3.4 SDK-agnostic UI-echo dedup +### 3.4 Per-template UI-echo dedup (NOT in the bridge) -```mermaid -flowchart LR - REQ[POST /responses
request.input from client] --> CHECK{any item
has role: assistant?} - CHECK -->|no| PASS[pass through unchanged
typical first-turn POST] - CHECK -->|yes| TRIM[client is echoing history
find last role: user index
slice items from there] - TRIM --> RESULT[result: latest user message only] - PASS --> HANDLER - RESULT --> HANDLER[handler invocation] +The bridge does **not** trim UI echo from `request.input`. Echo dedup is the agent handler's responsibility — it owns its SDK session/checkpointer and is the right layer to know what's already persisted vs what's a new turn. + +The canonical shape, per template: + +**OpenAI Agents SDK** (`agent-openai-advanced/agent_server/utils.py`): +```python +session_items = await session.get_items() +if session_items and len(messages) > 1: + return [messages[-1]] +return messages +``` + +**LangGraph** (`agent-langgraph-advanced/agent_server/agent.py`): +```python +state = await agent.aget_state(config) +if state and state.values.get("messages") and input_state["messages"]: + last_user = next( + (m for m in reversed(input_state["messages"]) if m.get("role") == "user"), + None, + ) + input_state["messages"] = [last_user] if last_user else [] ``` -The trim is the SDK-agnostic equivalent of a per-SDK "is this a continuation turn?" probe. By inspecting input shape, the server avoids: -- Querying the SDK's storage to ask "do you have prior turns?" (per-SDK call) -- Walking provider-specific tool-pair structure -- Any SDK adapter code +Both: when the SDK store already has prior turns, forward only the latest user message and let the SDK prepend its own history. Without dedup, `Runner.run` (OpenAI) or `add_messages` (LangGraph) end up combining session+input → duplicate items → malformed assistant.tool_calls block → 400. -The trim runs on every `POST /responses` and is a no-op when input has no assistant message (first-turn POSTs or the resume path's `[original_input + prose_msg]` payload, which contains only `role:user` items by construction). +Bridge's role here is just to pass `request.input` through untouched. ### 3.5 Proactive stale-scan loop @@ -334,7 +343,7 @@ Postgres row locking ensures only one of N concurrent UPDATEs matches, so at mos | Heartbeat + claim | `LongRunningAgentServer` | No | | Conversation_id rotation | `LongRunningAgentServer._rotate_conversation_id` | No | | Prose recovery message construction | `LongRunningAgentServer._build_prose_recovery_message` | No | -| UI-echo dedup | `LongRunningAgentServer._trim_echoed_history` | No | +| UI-echo dedup | per-template handler (see §3.4) | Yes — 1-2 lines in `agent.py` / `utils.py` | | Stream resume cursor | `LongRunningAgentServer._stream_retrieve` | No | | Tool/SDK selection | `agent.py` | Yes (this is the author's actual code) | @@ -414,7 +423,6 @@ flowchart LR - The MLflow `@invoke()` / `@stream()` handler convention. - `_build_prose_recovery_message` — prose construction is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. - `_rotate_conversation_id` and `_inject_conversation_id` — same reason. -- `_trim_echoed_history` — input cleanup happens at the HTTP boundary regardless of engine. - Author-visible settings (db config, heartbeat tuning, task timeout). ### What gets deleted diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 59234899..f7a90999 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -169,44 +169,6 @@ def _build_prose_recovery_message( } -def _trim_echoed_history(items: list[Any]) -> list[Any]: - """If request.input contains an assistant message, the client is echoing - prior conversation history — trust the SDK's own session/checkpointer - storage as authoritative for prior turns and forward only the latest - user message. - - SDK-agnostic equivalent of the per-SDK dedup hooks templates used to - carry (OpenAI Session.get_items() comparison; LangGraph - agent.aget_state() comparison). The presence of any ``role:assistant`` - item is a reliable proxy for "this request is a continuation echo from - the chat UI" — first-turn POSTs have a single user message, while - continuations include the prior turns' assistant replies. - - Resume-path inputs built by ``_try_claim_and_resume`` are - [original_input + prose_msg] (both ``role:user``), so the trim is a - correct no-op there. - """ - if not items: - return items - has_assistant = any( - isinstance(item, dict) and item.get("role") == "assistant" for item in items - ) - if not has_assistant: - return items - user_idxs = [ - i for i, item in enumerate(items) if isinstance(item, dict) and item.get("role") == "user" - ] - if len(user_idxs) <= 1: - return items - trimmed = items[user_idxs[-1] :] - logger.info( - "[durable] trimmed echoed history: original=%d final=%d (kept latest user turn)", - len(items), - len(trimmed), - ) - return trimmed - - def _rotate_conversation_id( request_dict: dict[str, Any], new_attempt_number: int, @@ -498,17 +460,6 @@ async def _handle_invocations_request( data = {k: v for k, v in data.items() if k not in (BACKGROUND_KEY, MLFLOW_STREAM_KEY)} return_trace_id = (get_request_headers().get(RETURN_TRACE_HEADER) or "").lower() == "true" - # For background+DB requests, trim happens INSIDE _handle_background_request - # AFTER storing the untrimmed input as original_request — so resume can - # recover the full prior-turn history. For non-background requests we - # trim here in place since there's no persistence step that needs the - # untrimmed copy. - is_background_with_db = is_background and is_db_configured() - if not is_background_with_db: - items = data.get("input") - if isinstance(items, list): - data["input"] = _trim_echoed_history(items) - try: request_data = self.validator.validate_and_convert_request(data) except ValueError as e: @@ -517,7 +468,7 @@ async def _handle_invocations_request( detail=f"Invalid parameters for {self.agent_type}: {e}", ) from None - if is_background_with_db: + if is_background and is_db_configured(): return await self._handle_background_request( request_data, is_streaming, return_trace_id ) @@ -544,26 +495,16 @@ async def _handle_background_request( dump = getattr(request_data, "model_dump", None) request_dict = dump() if callable(dump) else dict(request_data) # Store the FULL request (untrimmed) as `original_request` so resume can - # recover the entire prior-turn history. The handler invocation below - # uses a trimmed copy to avoid duplicating turns the SDK's session has - # already persisted, but on resume the rotated SDK session is empty — - # only the full conversation in `original_request.input` lets the model - # reconstruct what came before the crashed turn. + # recover the entire prior-turn history. Per-template handlers are + # responsible for deduping their own UI-echoed input against the SDK's + # session/checkpointer state — the bridge no longer trims input. await create_response( response_id, "in_progress", durable=True, original_request=request_dict, ) - - # Build a TRIMMED handler request from the same dict — drops echoed - # history that the SDK's session already has. Original_request above - # stays untrimmed for resume. - handler_dict = copy.deepcopy(request_dict) - handler_items = handler_dict.get("input") - if isinstance(handler_items, list): - handler_dict["input"] = _trim_echoed_history(handler_items) - durable_request = self.validator.validate_and_convert_request(handler_dict) + durable_request = self.validator.validate_and_convert_request(request_dict) logger.info( "Background response created response_id=%s stream=%s pod=%s", diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index c222cc48..e57c4e7e 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -21,7 +21,6 @@ _deferred_mark_failed, _rotate_conversation_id, _sse_event, - _trim_echoed_history, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings @@ -983,65 +982,6 @@ def test_empty_attempt_emits_empty_events_array(self): assert "Events:\n[]" in out["content"] -class TestTrimEchoedHistory: - """SDK-agnostic dedup: when request input contains an assistant message, - the client is echoing prior conversation history; trim to the latest - user message and trust the SDK's session/checkpointer storage as - authoritative for prior turns.""" - - def test_first_turn_passthrough(self): - items = [{"role": "user", "content": "hi"}] - assert _trim_echoed_history(items) is items - - def test_first_turn_with_no_assistant_passthrough(self): - # Multi-message input but no assistant role yet — first turn from a - # client preserving its own prior user turns. Pass through. - items = [ - {"role": "user", "content": "u1"}, - {"role": "user", "content": "u2"}, - ] - assert _trim_echoed_history(items) is items - - def test_continuation_trims_to_last_user(self): - items = [ - {"role": "user", "content": "u1"}, - {"role": "assistant", "content": "a1"}, - {"role": "user", "content": "u2"}, - ] - out = _trim_echoed_history(items) - assert len(out) == 1 - assert out[0]["role"] == "user" - assert out[0]["content"] == "u2" - - def test_continuation_with_tool_history_trims_correctly(self): - items = [ - {"role": "user", "content": "u1"}, - {"role": "assistant", "content": "a1"}, - {"role": "user", "content": "u2"}, - {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, - {"type": "function_call_output", "call_id": "c1", "output": "ok"}, - {"role": "assistant", "content": "a2"}, - {"role": "user", "content": "u3"}, - ] - out = _trim_echoed_history(items) - assert len(out) == 1 - assert out[0]["content"] == "u3" - - def test_resume_path_passthrough(self): - # Resume-built input: original input + prose user message, no - # assistant role. Trim is a no-op so the prose recovery payload is - # preserved. - items = [ - {"role": "user", "content": "original"}, - {"type": "message", "role": "user", "content": "[RECOVERY] ..."}, - ] - out = _trim_echoed_history(items) - assert out is items - - def test_empty_input(self): - assert _trim_echoed_history([]) == [] - - class TestRotateConversationId: def test_rotate_drops_thread_id_and_sets_rotated_context(self): r = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}} From 70698c88a3bce8231f84e144bf401f6d34283f1e Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 30 Apr 2026 04:31:15 +0000 Subject: [PATCH 6/9] AGENTS.md: drop stale references; reflect current design MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Updates AGENTS.md to describe only the current design: - §1 guarantees: prose recovery is json.dumps(events) + directive (not narrative) - CUJ 2 sequence diagram: heartbeat is implicit-ownership CAS on attempt_number, no owner_pod_id mentions - CUJ 4 multi-pod contention: same - §3.3 resume input flowchart: simplified to filter+json.dumps shape - §3.7 (was duplicated 3.6): claim atomicity diagram updated to match the actual SQL (no owner_pod_id column) - §5 TaskFlow mapping: removed reference to deleted _inject_conversation_id - Quick reference: dropped _trim_echoed_history and _inject_conversation_id; added _stale_response_scanner_loop; pointed UI-echo dedup to templates No new content, just trimming stale references to functions / columns that have been removed. --- .../long_running/AGENTS.md | 58 +++++++++---------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md index 4559d1d3..6b04a76c 100644 --- a/src/databricks_ai_bridge/long_running/AGENTS.md +++ b/src/databricks_ai_bridge/long_running/AGENTS.md @@ -15,7 +15,7 @@ This document describes: 1. **Background execution.** A `POST /responses` request with `background: true` returns a `response_id` immediately; the agent loop runs detached from the HTTP connection. State persists to Lakebase Postgres. 2. **Streaming retrieval.** `GET /responses/{response_id}?stream=true&starting_after=N` replays events past sequence `N` and tails new ones until the run finishes. Reconnects without losing events. -3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run and finishes the work via **prose recovery**: the new attempt receives a single user message narrating the crashed attempt's completed tool calls, results, and partial output, and continues from there on a freshly-rotated SDK session. Tool results that completed before the crash are preserved through the prose narrative. +3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run and finishes the work via **prose recovery**: the new attempt receives a single user message containing `json.dumps(events)` of the crashed attempt's stream-event log plus a directive asking the LLM to figure out what's done vs interrupted and continue. The handler runs on a freshly-rotated SDK session. Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, others) is opaque to the server. @@ -77,11 +77,11 @@ sequenceDiagram participant SDK as SDK store
(rotated session) C->>A: POST /responses
{input, background:true, stream:true} - A->>DB: INSERT response_id, owner=A,
original_request (full input), attempt=1 + A->>DB: INSERT response_id, attempt=1,
heartbeat_at=now(), original_request (full input) A-->>C: 200 {id: resp_xxx, status: in_progress} activate A - Note over A: heartbeat loop (every 3s) - A->>DB: UPDATE heartbeat_at WHERE owner=A + Note over A: heartbeat loop (every 3s)
CAS-checks attempt_number = 1 + A->>DB: UPDATE heartbeat_at WHERE attempt_number=1 par Streaming A->>SDK: write to session T (original conv_id) @@ -98,9 +98,9 @@ sequenceDiagram C->>B: GET /responses/{id}?stream=true&starting_after=K B->>DB: SELECT heartbeat_at, attempt Note over B: heartbeat stale → claim - B->>DB: CAS UPDATE owner=B, attempt=2
WHERE attempt=1 (atomic) + B->>DB: CAS UPDATE attempt=2, heartbeat_at=now()
WHERE attempt=1 AND heartbeat stale (atomic) B->>DB: SELECT prior events WHERE attempt=1 - Note over B: build resume input:
· original_request.input (full prior turns)
· + prose recovery message narrating attempt 1
· rotate conv_id → ::attempt-2 + Note over B: build resume input:
· original_request.input (full prior turns)
· + prose recovery message: json.dumps(events) + directive
· rotate conv_id → ::attempt-2 B->>DB: append response.resumed sentinel
{conversation_id: rotated value} B->>B: re-invoke @stream() handler
(fresh rotated SDK session) activate B @@ -112,7 +112,7 @@ sequenceDiagram **What the client observes:** a single SSE stream that may pause briefly during the heartbeat-stale window (~10s by default), then resumes. The `response.resumed` sentinel marks the attempt boundary and carries the rotated `conversation_id` so the chatbot can use the rotated session for subsequent turns. -**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus a single prose user message describing what happened in attempt 1. The handler doesn't have to know any of this is durable resume — to the model, it just looks like a user said "the prior attempt crashed, here's what had happened, please continue." +**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus a single user message whose body is `[RECOVERY] ... Events: `. The model reads it as "the prior attempt crashed, here's the raw event log, figure out what's done and continue." ### CUJ 3: Subsequent turn after a crashed turn @@ -158,9 +158,9 @@ sequenceDiagram Note over B,C: both pods see response in_progress, heartbeat stale par - B->>DB: UPDATE responses SET owner=B, attempt=N+1
WHERE response_id=R AND attempt=N + B->>DB: UPDATE responses SET attempt=N+1, heartbeat_at=now()
WHERE response_id=R AND attempt=N AND heartbeat stale and - C->>DB: UPDATE responses SET owner=C, attempt=N+1
WHERE response_id=R AND attempt=N + C->>DB: UPDATE responses SET attempt=N+1, heartbeat_at=now()
WHERE response_id=R AND attempt=N AND heartbeat stale end Note over DB: only one row matches; the other UPDATE returns 0 rows DB-->>B: RETURNING attempt_number=N+1 @@ -229,26 +229,21 @@ flowchart TD ### 3.3 Resume input construction -When a stale-claim CAS succeeds, the new owner builds the resume input from the prior attempt's emitted events as a flat prose narrative: +When a stale-claim CAS succeeds, the new owner builds the resume input by serializing the prior attempt's stream events as JSON in a single user message: ```mermaid flowchart LR - PRIOR[prior attempt's events
from messages table] --> WALK[_build_prose_recovery_message] - WALK --> CALLS[walk completed function_call /
function_call_output items by call_id] - WALK --> NARR[walk completed message items
extract text] - WALK --> PARTIAL[walk in-progress message deltas
concat as 'was generating: ...'] - - CALLS --> COMPOSE - NARR --> COMPOSE - PARTIAL --> COMPOSE[compose single user message:
'[RECOVERY] previous attempt crashed.
Here is what had completed:
- Called f(...) and got result: ...
- Called g(...) — interrupted before result
- Was generating: ...'] + PRIOR[prior attempt's events
from messages table] --> FILTER[filter events by
attempt_number = prior_attempt] + FILTER --> JSON[json.dumps the events list] + JSON --> COMPOSE[compose single user message:
'[RECOVERY] previous attempt crashed.
Below is the raw stream-event log...
Inspect the events, figure out what is
already done versus in-progress, and continue.

Events: <json.dumps>'] ROT[_rotate_conversation_id
::attempt-N suffix] --> SUBMIT - COMPOSE --> SUBMIT[append to original_request.input
spawn handler with rotated request
emit response.resumed sentinel] + COMPOSE --> SUBMIT[append to original_request.input
spawn handler with rotated request
emit response.resumed sentinel
with rotated conversation_id] ``` -Why prose: the LLM reads it as a recovery instruction and decides what to do (re-run the interrupted tool, skip completed ones, summarize from there). No structural carry-forward, no synthetic tool events, no per-SDK pairing rules. Trades cache hit rate (the prose is a fresh user message, not a stable structural prefix) for SDK-agnostic simplicity. +Why JSON-dumped events: the LLM reads them as the authoritative record of what attempt 1 did and decides what to do — re-run an interrupted tool, skip completed ones, summarize from there. No structural carry-forward, no synthetic tool events, no per-SDK pairing rules. The handler doesn't have to know any of this is durable resume — it just sees a recovery user message in `request.input`. -Why rotation: the original SDK session may carry mid-turn state from the crashed attempt (orphan `tool_use`, partial checkpoint) that's hard to repair from outside the SDK. Rotating to `{base}::attempt-N` opens a fresh, empty session for the resumed attempt; the prose narrative is the single source of truth for what already happened. +Why rotation: the original SDK session may carry mid-turn state from the crashed attempt (orphan `tool_use`, partial checkpoint) that's hard to repair from outside the SDK. Rotating to `{base}::attempt-N` opens a fresh, empty session for the resumed attempt; the recovery message is the single source of truth for what already happened. Why the sentinel carries the rotated conv_id: cooperating chat clients capture it (via SSE) and use the rotated session for subsequent turns, so the original orphan-poisoned session is never read again. @@ -311,7 +306,7 @@ Defaults are tuned for a single Lakebase deployment with low-latency writes. The stale threshold also applies as a grace period for newly-created responses that haven't written their first heartbeat yet — protects against an over-eager retrieve hijacking a still-starting handler. -### 3.6 Claim atomicity +### 3.7 Claim atomicity ```mermaid sequenceDiagram @@ -322,17 +317,17 @@ sequenceDiagram Pod->>DB: SELECT attempt_number, heartbeat_at, status
FROM responses WHERE response_id=R DB-->>Pod: attempt=1, heartbeat_at=12s ago, status=in_progress Note over Pod: stale → attempt CAS - Pod->>DB: UPDATE responses
SET owner_pod_id=B, attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number + Pod->>DB: UPDATE responses
SET attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number alt match DB-->>Pod: attempt_number=2 - Note over Pod: claim succeeded + Note over Pod: claim succeeded — Pod B is the new de facto owner
(prior owner's heartbeat at attempt=1 will fail next tick) else no match (another pod beat us) DB-->>Pod: (empty) Note over Pod: claim lost — back off end ``` -Postgres row locking ensures only one of N concurrent UPDATEs matches, so at most one pod ends up owning a given resume. +Postgres row locking ensures only one of N concurrent UPDATEs matches the `attempt_number = N` predicate, so at most one pod ends up owning a given resume. The bumped `attempt_number` simultaneously revokes the prior owner's heartbeat: their next heartbeat write `WHERE attempt_number = N` returns rowcount=0, telling them they've lost the claim. ## 4. Author-side requirements @@ -421,9 +416,9 @@ flowchart LR - `POST /responses` / `GET /responses/{id}` HTTP routes (and their schemas). - The MLflow `@invoke()` / `@stream()` handler convention. -- `_build_prose_recovery_message` — prose construction is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. -- `_rotate_conversation_id` and `_inject_conversation_id` — same reason. -- Author-visible settings (db config, heartbeat tuning, task timeout). +- `_build_prose_recovery_message` — recovery-message construction is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. +- `_rotate_conversation_id` — same reason. +- Author-visible settings (db config, heartbeat tuning, task timeout, stale-scan tuning). ### What gets deleted @@ -453,8 +448,9 @@ The point of the `LongRunningAgentServer` abstraction is exactly this kind of sw - **Code:** `src/databricks_ai_bridge/long_running/` - **Tests:** `tests/databricks_ai_bridge/test_long_running_server.py`, `test_long_running_db.py` - **Settings:** `LongRunningSettings` in `settings.py` -- **Model:** `Response` and `Message` in `models.py` +- **Models:** `Response` and `Message` in `models.py` - **HTTP routes:** registered in `LongRunningAgentServer._setup_routes` - **Prose recovery:** `_build_prose_recovery_message` in `server.py` -- **UI-echo dedup:** `_trim_echoed_history` in `server.py` -- **Conversation rotation:** `_rotate_conversation_id` / `_inject_conversation_id` in `server.py` +- **Conversation rotation:** `_rotate_conversation_id` in `server.py` +- **Stale scanner:** `LongRunningAgentServer._stale_response_scanner_loop` in `server.py` +- **UI-echo dedup:** in agent code, see `app-templates/agent-{openai,langgraph}-advanced/` From 58950623ea3a948281b307ef21c9c6d2af703c19 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 30 Apr 2026 19:12:34 +0000 Subject: [PATCH 7/9] CAS-check attempt_number on terminal-status writes from racy paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Address Bryan's nit (PR #425, comment PRRC_kwDONA8Fvc685ROc): when writing terminal status to the DB, check that the current attempt_number matches the one this pod thinks it owns — prevents a stale background task (deferred-fail timer, or post-cleanup) from clobbering an in-progress state that another pod has already claimed for resume. Targeted at the truly-racy paths only (sleep + write): - _deferred_mark_failed: takes new owning_attempt_number kwarg. Reads current attempt before writing; skips both the error event append AND the status update if ownership has moved. Status update also passes expected_attempt_number to the repository for belt-and- suspenders CAS at the SQL layer. - _task_scope: takes new attempt_number kwarg. Inline error-cleanup passes it as expected_attempt_number; the two _deferred_mark_failed fallback paths pass it as owning_attempt_number. - _run_background_stream / _run_background_invoke wire their attempt_number through to _task_scope. Repository: - update_response_status gains optional expected_attempt_number kwarg that adds `WHERE attempt_number = :expected` to the UPDATE. The success-path writes (status=completed in _do_background_stream / _do_background_invoke) and the stuck-row force-fail in _handle_retrieve_request are intentionally not CAS'd here — they happen synchronously after the handler exits while the heartbeat is still alive, so the race window is tiny. Can revisit if observed in prod. Tests: 105 pass (was 104 — added ownership-changed test for _deferred_mark_failed; updated 2 existing tests for the new kwarg). --- .../long_running/repository.py | 15 +++- .../long_running/server.py | 68 +++++++++++++++---- .../test_long_running_server.py | 40 ++++++++++- 3 files changed, 106 insertions(+), 17 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py index a3621ef2..2c6a1ff0 100644 --- a/src/databricks_ai_bridge/long_running/repository.py +++ b/src/databricks_ai_bridge/long_running/repository.py @@ -39,17 +39,30 @@ async def create_response( async def update_response_status( - response_id: str, status: str, *, expected_current_status: str | None = None + response_id: str, + status: str, + *, + expected_current_status: str | None = None, + expected_attempt_number: int | None = None, ) -> bool: """Update response status. Returns True if a row was updated. If *expected_current_status* is given the update only takes effect when the row's current status matches, avoiding concurrent-update races. + + If *expected_attempt_number* is given the update only takes effect when the + row's current ``attempt_number`` matches, ensuring only the pod that owns + the current attempt can transition the row to a terminal state. This + prevents a stale background task (e.g. a deferred-fail timer that fired + after another pod claimed the row for resume) from clobbering the new + owner's in-progress state. """ async with session_scope() as session: stmt = update(Response).where(Response.response_id == response_id) if expected_current_status is not None: stmt = stmt.where(Response.status == expected_current_status) + if expected_attempt_number is not None: + stmt = stmt.where(Response.attempt_number == expected_attempt_number) stmt = stmt.values(status=status) result = await session.execute(stmt) await session.commit() diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index f7a90999..38ce456b 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -60,13 +60,23 @@ async def _deferred_mark_failed( - response_id: str, delay: float = 2.0, reason: str = "Task timed out" + response_id: str, + delay: float = 2.0, + reason: str = "Task timed out", + *, + owning_attempt_number: int | None = None, ) -> None: """Mark a response as failed after a short delay. Runs as an independent asyncio task so the caller (``_task_scope``) can - return immediately. The delay lets the connection pool stabilise after - a cancellation before we attempt new DB writes. + return immediately. The delay lets the connection pool stabilise after a + cancellation before we attempt new DB writes. + + ``owning_attempt_number`` should be the attempt this pod was running when + the failure was scheduled. The terminal status update is CAS-checked + against it: if another pod has already claimed the row for a higher + attempt by the time this fires, we skip the failed-status write so we + don't clobber the new owner's state. """ try: await asyncio.sleep(delay) @@ -77,7 +87,16 @@ async def _deferred_mark_failed( async with asyncio.timeout(delay): existing = await get_messages(response_id, after_sequence=None) next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 - attempt = await _current_attempt(response_id) + current_attempt = await _current_attempt(response_id) + if owning_attempt_number is not None and current_attempt != owning_attempt_number: + logger.info( + "Skipping deferred fail for %s: ownership changed " + "(was attempt=%d, now attempt=%d)", + response_id, + owning_attempt_number, + current_attempt, + ) + return error_event = { "type": "error", @@ -92,9 +111,13 @@ async def _deferred_mark_failed( next_seq, item=None, stream_event=error_event, - attempt_number=attempt, + attempt_number=current_attempt, + ) + await update_response_status( + response_id, + "failed", + expected_attempt_number=owning_attempt_number, ) - await update_response_status(response_id, "failed") logger.info("Marked %s as failed (reason: %s)", response_id, reason) except TimeoutError: @@ -694,9 +717,18 @@ async def _beat(): @asynccontextmanager async def _task_scope( - self, response_id: str, state: dict[str, Any] + self, + response_id: str, + state: dict[str, Any], + *, + attempt_number: int = 1, ) -> AsyncGenerator[None, None]: - """Timeout + error handling wrapper for background tasks.""" + """Timeout + error handling wrapper for background tasks. + + ``attempt_number`` is CAS-checked on terminal-status writes so a + deferred fail / cleanup that fires after another pod has claimed the + row for resume doesn't clobber the new owner's in-progress state. + """ try: async with asyncio.timeout(self._settings.task_timeout_seconds): yield @@ -707,7 +739,11 @@ async def _task_scope( self._settings.task_timeout_seconds, ) asyncio.create_task( - _deferred_mark_failed(response_id, delay=self._settings.cleanup_timeout_seconds), + _deferred_mark_failed( + response_id, + delay=self._settings.cleanup_timeout_seconds, + owning_attempt_number=attempt_number, + ), name=f"deferred-fail-{response_id}", ) except Exception as exc: @@ -717,7 +753,6 @@ async def _task_scope( async with asyncio.timeout(self._settings.cleanup_timeout_seconds): existing = await get_messages(response_id, after_sequence=None) next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 - attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -730,9 +765,13 @@ async def _task_scope( "code": "task_failed", }, }, - attempt_number=attempt, + attempt_number=attempt_number, + ) + await update_response_status( + response_id, + "failed", + expected_attempt_number=attempt_number, ) - await update_response_status(response_id, "failed") except Exception: logger.exception( "[error-cleanup] Immediate update failed for %s, deferring", @@ -743,6 +782,7 @@ async def _task_scope( response_id, delay=self._settings.cleanup_timeout_seconds, reason=str(exc), + owning_attempt_number=attempt_number, ), name=f"deferred-fail-{response_id}", ) @@ -758,7 +798,7 @@ async def _run_background_stream( """Timeout-guarded wrapper around the streaming agent loop.""" state: dict[str, Any] = {"seq": 0} async with ( - self._task_scope(response_id, state), + self._task_scope(response_id, state, attempt_number=attempt_number), self._heartbeat(response_id, attempt_number), ): await self._do_background_stream( @@ -872,7 +912,7 @@ async def _run_background_invoke( """Timeout-guarded wrapper around the invoke agent loop.""" state: dict[str, Any] = {"seq": 0} async with ( - self._task_scope(response_id, state), + self._task_scope(response_id, state, attempt_number=attempt_number), self._heartbeat(response_id, attempt_number), ): await self._do_background_invoke( diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index e57c4e7e..5803a1e7 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -270,7 +270,9 @@ async def test_marks_response_failed(self): stream_event = args[1]["stream_event"] assert stream_event["type"] == "error" assert stream_event["error"]["code"] == "task_timeout" - mock_update.assert_awaited_once_with("resp_123", "failed") + mock_update.assert_awaited_once_with( + "resp_123", "failed", expected_attempt_number=None + ) @pytest.mark.asyncio async def test_handles_db_error_gracefully(self): @@ -282,6 +284,38 @@ async def test_handles_db_error_gracefully(self): # Should not raise await _deferred_mark_failed("resp_123", delay=0.01) + @pytest.mark.asyncio + async def test_skips_status_write_when_attempt_changed(self): + # The pod that scheduled this fail was running attempt=1; by the + # time this fires, another pod has bumped to attempt=2. We must NOT + # write terminal status. + with ( + patch( + "databricks_ai_bridge.long_running.server.get_messages", + new_callable=AsyncMock, + return_value=[_msg(0, None, {"type": "response.created"})], + ), + patch( + "databricks_ai_bridge.long_running.server.get_response", + new_callable=AsyncMock, + return_value=_resp_info(attempt_number=2), + ), + patch( + "databricks_ai_bridge.long_running.server.append_message", + new_callable=AsyncMock, + ) as mock_append, + patch( + "databricks_ai_bridge.long_running.server.update_response_status", + new_callable=AsyncMock, + ) as mock_update, + ): + await _deferred_mark_failed( + "resp_123", delay=0.01, owning_attempt_number=1 + ) + # Neither append nor status-write fires when we've lost ownership. + mock_append.assert_not_awaited() + mock_update.assert_not_awaited() + class TestRetrieveRequest: @pytest.mark.asyncio @@ -735,7 +769,9 @@ async def test_exception_writes_error_event_inline(self): assert evt["error"]["message"] == "something broke" assert evt["error"]["code"] == "task_failed" assert mock_append.call_args.args[1] == 2 # next_seq - mock_update.assert_awaited_once_with("resp_err", "failed") + mock_update.assert_awaited_once_with( + "resp_err", "failed", expected_attempt_number=1 + ) @pytest.mark.asyncio async def test_exception_falls_back_to_deferred_on_db_failure(self): From 1406dfe5f289b1cbe0927ec660f0184e8caecd04 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 30 Apr 2026 19:24:41 +0000 Subject: [PATCH 8/9] Fix ruff format: collapse single-line assertions CI's `ruff format --check` flagged three calls in test_long_running_server.py that fit on one line but my local formatter had wrapped onto multiple lines (different ruff config / line-length resolution between local and CI). Collapse them. --- .../databricks_ai_bridge/test_long_running_server.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 5803a1e7..304031b7 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -270,9 +270,7 @@ async def test_marks_response_failed(self): stream_event = args[1]["stream_event"] assert stream_event["type"] == "error" assert stream_event["error"]["code"] == "task_timeout" - mock_update.assert_awaited_once_with( - "resp_123", "failed", expected_attempt_number=None - ) + mock_update.assert_awaited_once_with("resp_123", "failed", expected_attempt_number=None) @pytest.mark.asyncio async def test_handles_db_error_gracefully(self): @@ -309,9 +307,7 @@ async def test_skips_status_write_when_attempt_changed(self): new_callable=AsyncMock, ) as mock_update, ): - await _deferred_mark_failed( - "resp_123", delay=0.01, owning_attempt_number=1 - ) + await _deferred_mark_failed("resp_123", delay=0.01, owning_attempt_number=1) # Neither append nor status-write fires when we've lost ownership. mock_append.assert_not_awaited() mock_update.assert_not_awaited() @@ -769,9 +765,7 @@ async def test_exception_writes_error_event_inline(self): assert evt["error"]["message"] == "something broke" assert evt["error"]["code"] == "task_failed" assert mock_append.call_args.args[1] == 2 # next_seq - mock_update.assert_awaited_once_with( - "resp_err", "failed", expected_attempt_number=1 - ) + mock_update.assert_awaited_once_with("resp_err", "failed", expected_attempt_number=1) @pytest.mark.asyncio async def test_exception_falls_back_to_deferred_on_db_failure(self): From 69bd088bf478cbf1a3b5ceb3d4635011aa59f5fe Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Fri, 1 May 2026 20:20:13 +0000 Subject: [PATCH 9/9] Extend attempt_number CAS to ALL terminal-status writes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per Bryan's follow-up review (PRRC_kwDONA8Fvc685TAX, PRRC_kwDONA8Fvc69ES6Z): the success-path completed-status writes and the pre-flight no-handler-registered failed writes also need the ownership check. Without it, a pod whose heartbeat stalls (network / DB hiccup) can still finish its handler and write status="completed" after another pod has CAS-claimed the row for resume — pollers and streamers then stop early on a stale terminal status. Sites updated to pass expected_attempt_number=attempt_number: - _do_background_stream success: status="completed" (was unguarded) - _do_background_invoke success: status="completed" (was unguarded) - _do_background_stream pre-flight: status="failed" when no stream_fn - _do_background_invoke pre-flight: status="failed" when no invoke_fn The two success paths log + early-return when the CAS misses, signaling "this pod no longer owns the run" without escalating to an error. Tests: 105 pass. Updated 5 tests for the new kwarg + return_value=True on the success-path mock (so the new `if not updated` branch doesn't trip in the no-contention case). This addresses Bryan's "lgtm after fixing the last nit". --- .../long_running/server.py | 34 +++++++++++++++--- .../test_long_running_server.py | 36 +++++++++++++------ 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 38ce456b..40007ef4 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -825,7 +825,9 @@ async def _do_background_stream( """Run agent via stream_fn, persist each stream event as a message row.""" stream_fn = get_stream_function() if stream_fn is None: - await update_response_status(response_id, "failed") + await update_response_status( + response_id, "failed", expected_attempt_number=attempt_number + ) raise RuntimeError("No stream function registered; cannot run background stream") func_name = stream_fn.__name__ @@ -891,7 +893,18 @@ async def _do_background_stream( attempt_number=attempt_number, ) - await update_response_status(response_id, "completed") + updated = await update_response_status( + response_id, "completed", expected_attempt_number=attempt_number + ) + if not updated: + logger.info( + "[durable] skipped completed-status write response_id=%s attempt=%d " + "(another pod claimed the row mid-handler); pod=%s", + response_id, + attempt_number, + _POD_LOG_ID, + ) + return logger.info( "[durable] background stream completed response_id=%s attempt=%d " "total_events=%d pod=%s", @@ -935,7 +948,9 @@ async def _do_background_invoke( """Run agent via invoke_fn, persist each output item as a message row.""" invoke_fn = get_invoke_function() if invoke_fn is None: - await update_response_status(response_id, "failed") + await update_response_status( + response_id, "failed", expected_attempt_number=attempt_number + ) raise RuntimeError("No invoke function registered; cannot run background invoke") func_name = invoke_fn.__name__ @@ -975,7 +990,18 @@ async def _do_background_invoke( state["seq"] = seq + 1 if return_trace_id: await update_response_trace_id(response_id, span.trace_id) - await update_response_status(response_id, "completed") + updated = await update_response_status( + response_id, "completed", expected_attempt_number=attempt_number + ) + if not updated: + logger.info( + "[durable] skipped completed-status write response_id=%s attempt=%d " + "(another pod claimed the row mid-handler); pod=%s", + response_id, + attempt_number, + _POD_LOG_ID, + ) + return logger.debug( "Background invoke completed", extra={"response_id": response_id, "output_items": len(output)}, diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 304031b7..ae33b05c 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -507,7 +507,9 @@ async def fake_stream(request_data): patch(f"{MODULE}.get_stream_function", return_value=fake_stream), patch(f"{MODULE}.mlflow") as mock_mlflow, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, patch(f"{MODULE}.ResponsesAgent") as mock_ra, ): mock_mlflow.start_span.return_value = span @@ -522,7 +524,7 @@ async def fake_stream(request_data): assert seqs == [0, 1, 2] # Verify state tracks final seq assert state["seq"] == 3 - mock_update.assert_awaited_once_with("resp_1", "completed") + mock_update.assert_awaited_once_with("resp_1", "completed", expected_attempt_number=1) @pytest.mark.asyncio async def test_calls_transform_stream_event(self): @@ -574,12 +576,14 @@ async def test_no_stream_fn_marks_failed(self): with ( patch(f"{MODULE}.get_stream_function", return_value=None), - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, ): state = {"seq": 0} with pytest.raises(RuntimeError, match="No stream function registered"): await server._do_background_stream("resp_x", {}, False, state) - mock_update.assert_awaited_once_with("resp_x", "failed") + mock_update.assert_awaited_once_with("resp_x", "failed", expected_attempt_number=1) @pytest.mark.asyncio async def test_persists_trace_id_when_requested(self): @@ -628,7 +632,9 @@ async def fake_invoke(request_data): patch(f"{MODULE}.get_invoke_function", return_value=fake_invoke), patch(f"{MODULE}.mlflow") as mock_mlflow, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, patch(f"{MODULE}.update_response_trace_id", new_callable=AsyncMock), ): mock_mlflow.start_span.return_value = span @@ -646,7 +652,7 @@ async def fake_invoke(request_data): assert evt["type"] == "response.output_item.done" assert "item" in evt assert state["seq"] == 2 - mock_update.assert_awaited_once_with("resp_inv", "completed") + mock_update.assert_awaited_once_with("resp_inv", "completed", expected_attempt_number=1) @pytest.mark.asyncio async def test_trace_id_persisted_when_requested(self): @@ -677,12 +683,14 @@ async def test_no_invoke_fn_marks_failed(self): with ( patch(f"{MODULE}.get_invoke_function", return_value=None), - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, ): state = {"seq": 0} with pytest.raises(RuntimeError, match="No invoke function registered"): await server._do_background_invoke("resp_x", {}, False, state) - mock_update.assert_awaited_once_with("resp_x", "failed") + mock_update.assert_awaited_once_with("resp_x", "failed", expected_attempt_number=1) @pytest.mark.asyncio async def test_sync_invoke_fn_supported(self): @@ -697,7 +705,9 @@ def sync_invoke(request_data): patch(f"{MODULE}.get_invoke_function", return_value=sync_invoke), patch(f"{MODULE}.mlflow") as mock_mlflow, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, patch(f"{MODULE}.update_response_trace_id", new_callable=AsyncMock), ): mock_mlflow.start_span.return_value = span @@ -706,7 +716,9 @@ def sync_invoke(request_data): await server._do_background_invoke("resp_sync", {"input": "hi"}, False, state) assert mock_append.await_count == 1 - mock_update.assert_awaited_once_with("resp_sync", "completed") + mock_update.assert_awaited_once_with( + "resp_sync", "completed", expected_attempt_number=1 + ) # --------------------------------------------------------------------------- @@ -752,7 +764,9 @@ async def test_exception_writes_error_event_inline(self): return_value=_resp_info(), ), patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, - patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, + patch( + f"{MODULE}.update_response_status", new_callable=AsyncMock, return_value=True + ) as mock_update, ): state = {"seq": 2} async with server._task_scope("resp_err", state):