From afd2d258b95b138294694be15e3f1bc8f24904f3 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 16 Apr 2026 22:22:11 +0000 Subject: [PATCH 01/39] LongRunningAgentServer: add durable resume via heartbeat + CAS claim MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mid-stream pod crashes currently mark in-progress runs as failed once the task_timeout elapses — any in-flight LLM/tool work is lost. This change adds best-effort crash recovery: another pod can atomically claim a stale run and re-invoke the registered handler, and the agent SDK's own checkpointer (LangGraph AsyncCheckpointSaver, databricks-openai Session) picks up prior progress so completed tool calls are not re-executed. Schema additions (idempotent ADD COLUMN IF NOT EXISTS): - responses.owner_pod_id, heartbeat_at, attempt_number, original_request - messages.attempt_number - idx_responses_stale partial index Server changes: - _handle_background_request anchors context.conversation_id to response_id when the client didn't pin a thread_id/session_id. Both agent templates' helpers read context.conversation_id as priority-2 fallback, so a replay from a different pod resolves to the same agent-SDK thread/session without any template code change. - original_request is persisted on create_response so another pod can re-invoke with the same arguments. - _heartbeat async context writes heartbeat_at every heartbeat_interval_seconds (default 3s) while the agent loop runs, stops cleanly on exit. - _handle_retrieve_request invokes _try_claim_and_resume before returning: if the run is in_progress and heartbeat is older than heartbeat_stale_threshold_seconds (default 15s), CAS the row, increment attempt_number, emit a response.resumed sentinel event at the next seq, and spawn _run_background_stream(attempt_number=N+1) with input=[]. - Sequence numbers stay monotonic across attempts so client cursors (GET /responses/{id}?stream=true&starting_after=N) resume correctly. 
- _handle_retrieve_request's non-stream path filters to output_item.done events (the authoritative conversation record). New settings: - heartbeat_interval_seconds: 3.0 - heartbeat_stale_threshold_seconds: 15.0 - validated: stale > interval Out of scope (follow-up PRs): cancellation implementation (still 501), eager multi-pod polling for stale work (v1 uses lazy reclaim on retrieve). Tests: - 11 new unit tests covering _inject_conversation_id priority rules, _try_claim_and_resume scenarios (grace period, missing original_request, failed claim, successful claim + sentinel emission, empty-input replay), _heartbeat context lifecycle, and settings validation. - Existing tests migrated to the new ResponseInfo (8-field) and get_messages 4-tuple shapes via _resp_info()/_msg() helpers. --- src/databricks_ai_bridge/long_running/db.py | 19 + .../long_running/models.py | 25 +- .../long_running/repository.py | 125 +++++- .../long_running/server.py | 311 +++++++++++-- .../long_running/settings.py | 12 + .../test_long_running_db.py | 121 ++++- .../test_long_running_server.py | 420 +++++++++++++++++- 7 files changed, 961 insertions(+), 72 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 903d466f..187a2873 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -78,6 +78,25 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): async with _engine.begin() as conn: await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}")) await conn.run_sync(Base.metadata.create_all) + # Idempotent migration for tables created by earlier versions: add any + # columns introduced for durable-resume support. Safe to run on a fresh + # schema (all columns exist) and on an upgraded one (only missing ones added). 
+ for stmt in ( + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS original_request TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"CREATE INDEX IF NOT EXISTS idx_responses_stale " + f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " + "WHERE status = 'in_progress'", + ): + await conn.execute(text(stmt)) _initialized = True logger.info("[DB] Engine and schema ready") diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py index 1d876dc7..6dade733 100644 --- a/src/databricks_ai_bridge/long_running/models.py +++ b/src/databricks_ai_bridge/long_running/models.py @@ -14,7 +14,12 @@ class Base(DeclarativeBase): class Response(Base): - """Response status tracking for background agent tasks.""" + """Response status tracking for background agent tasks. + + Durability columns (``owner_pod_id``, ``heartbeat_at``, ``attempt_number``, + ``original_request``) support crash-resume: another pod can atomically + claim a stale in-progress row and replay the agent loop. 
+ """ __tablename__ = "responses" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -25,12 +30,25 @@ class Response(Base): DateTime(timezone=True), nullable=False, server_default=func.now() ) trace_id: Mapped[str | None] = mapped_column(Text, nullable=True) + owner_pod_id: Mapped[str | None] = mapped_column(Text, nullable=True) + heartbeat_at: Mapped[datetime | None] = mapped_column( + DateTime(timezone=True), nullable=True + ) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) + original_request: Mapped[str | None] = mapped_column(Text, nullable=True) messages = relationship("Message", back_populates="response", cascade="all, delete-orphan") class Message(Base): - """Stream events and output items for a response.""" + """Stream events and output items for a response. + + ``attempt_number`` tags events by which run attempt emitted them so that + resumed runs append to the same event log without overwriting earlier + (abandoned) attempts, and retrieve can filter to the latest attempt only. 
+ """ __tablename__ = "messages" __table_args__ = {"schema": AGENT_DB_SCHEMA} @@ -44,6 +62,9 @@ class Message(Base): sequence_number: Mapped[int] = mapped_column( Integer, primary_key=True, nullable=False, default=0 ) + attempt_number: Mapped[int] = mapped_column( + Integer, nullable=False, server_default="1", default=1 + ) item: Mapped[str | None] = mapped_column(Text, nullable=True) stream_event: Mapped[str | None] = mapped_column(Text, nullable=True) diff --git a/src/databricks_ai_bridge/long_running/repository.py b/src/databricks_ai_bridge/long_running/repository.py index fcd86b29..06d30edb 100644 --- a/src/databricks_ai_bridge/long_running/repository.py +++ b/src/databricks_ai_bridge/long_running/repository.py @@ -5,15 +5,37 @@ from typing import Any, NamedTuple from sqlalchemy import select, update +from sqlalchemy.sql import bindparam, text from databricks_ai_bridge.long_running.db import session_scope -from databricks_ai_bridge.long_running.models import Message, Response +from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response -async def create_response(response_id: str, status: str) -> None: - """Insert a new response.""" +async def create_response( + response_id: str, + status: str, + *, + owner_pod_id: str | None = None, + original_request: dict[str, Any] | None = None, +) -> None: + """Insert a new response row. + + ``owner_pod_id`` and ``original_request`` are optional so that non-durable + callers (tests, legacy flows) can still create rows without durability + metadata. When present, they enable heartbeat + crash-resume semantics. 
+ """ async with session_scope() as session: - session.add(Response(response_id=response_id, status=status)) + session.add( + Response( + response_id=response_id, + status=status, + owner_pod_id=owner_pod_id, + heartbeat_at=datetime.now().astimezone() if owner_pod_id else None, + original_request=( + json.dumps(original_request) if original_request is not None else None + ), + ) + ) await session.commit() @@ -43,18 +65,84 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None: await session.commit() +async def heartbeat_response(response_id: str, pod_id: str) -> bool: + """Update heartbeat_at for a response IFF this pod owns it. + + Returns True on success. A False result means the claim has been lost + (another pod took over ownership, or the row no longer exists). + """ + async with session_scope() as session: + stmt = ( + update(Response) + .where(Response.response_id == response_id, Response.owner_pod_id == pod_id) + .values(heartbeat_at=datetime.now().astimezone()) + ) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 + + +async def claim_stale_response( + response_id: str, + new_owner_pod_id: str, + stale_threshold_seconds: float, +) -> int | None: + """Atomically claim an in-progress response whose heartbeat has gone stale. + + Uses a single conditional UPDATE so exactly one caller wins on contention: + claim only succeeds if status is ``in_progress`` AND (``owner_pod_id IS NULL`` + OR ``heartbeat_at IS NULL`` OR ``heartbeat_at`` is older than the threshold). + + Returns the new ``attempt_number`` on success, or ``None`` if the row did + not satisfy the claim conditions (already completed, already claimed by a + live pod, or nonexistent). + """ + # Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for + # the incremented column as ergonomically. Using a single statement keeps the + # claim atomic without an explicit transaction-level lock. 
+ stmt = text( + f""" + UPDATE {AGENT_DB_SCHEMA}.responses + SET owner_pod_id = :pod, + heartbeat_at = now(), + attempt_number = attempt_number + 1 + WHERE response_id = :rid + AND status = 'in_progress' + AND (owner_pod_id IS NULL + OR heartbeat_at IS NULL + OR heartbeat_at < now() - make_interval(secs => :threshold)) + RETURNING attempt_number + """ + ).bindparams( + bindparam("pod", type_=None), + bindparam("rid", type_=None), + bindparam("threshold", type_=None), + ) + async with session_scope() as session: + result = await session.execute( + stmt, + {"pod": new_owner_pod_id, "rid": response_id, "threshold": stale_threshold_seconds}, + ) + row = result.first() + await session.commit() + return int(row[0]) if row else None + + async def append_message( response_id: str, sequence_number: int, item: str | None = None, stream_event: dict[str, Any] | None = None, + *, + attempt_number: int = 1, ) -> None: - """Append a message (stream event) for a response.""" + """Append a message (stream event) for a response, tagged with attempt_number.""" async with session_scope() as session: session.add( Message( response_id=response_id, sequence_number=sequence_number, + attempt_number=attempt_number, item=item, stream_event=json.dumps(stream_event) if stream_event is not None else None, ) @@ -65,22 +153,26 @@ async def append_message( async def get_messages( response_id: str, after_sequence: int | None = None, -) -> list[tuple[int, str | None, dict[str, Any] | None]]: - """Fetch messages for a response, optionally after a sequence number. + *, + attempt_number: int | None = None, +) -> list[tuple[int, str | None, dict[str, Any] | None, int]]: + """Fetch messages for a response, optionally filtering by sequence / attempt. - Returns list of (sequence_number, item, stream_event_dict). + Returns list of ``(sequence_number, item, stream_event_dict, attempt_number)``. 
""" async with session_scope() as session: stmt = select(Message).where(Message.response_id == response_id) if after_sequence is not None: stmt = stmt.where(Message.sequence_number > after_sequence) + if attempt_number is not None: + stmt = stmt.where(Message.attempt_number == attempt_number) stmt = stmt.order_by(Message.sequence_number) result = await session.execute(stmt) rows = result.scalars().all() out = [] for r in rows: evt = json.loads(r.stream_event) if r.stream_event else None - out.append((r.sequence_number, r.item, evt)) + out.append((r.sequence_number, r.item, evt, r.attempt_number)) return out @@ -89,6 +181,10 @@ class ResponseInfo(NamedTuple): status: str created_at: datetime trace_id: str | None + owner_pod_id: str | None + heartbeat_at: datetime | None + attempt_number: int + original_request: dict[str, Any] | None async def get_response(response_id: str) -> ResponseInfo | None: @@ -97,5 +193,14 @@ async def get_response(response_id: str) -> ResponseInfo | None: result = await session.execute(select(Response).where(Response.response_id == response_id)) row = result.scalar_one_or_none() if row: - return ResponseInfo(row.response_id, row.status, row.created_at, row.trace_id) + return ResponseInfo( + row.response_id, + row.status, + row.created_at, + row.trace_id, + row.owner_pod_id, + row.heartbeat_at, + row.attempt_number, + json.loads(row.original_request) if row.original_request else None, + ) return None diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index b3374d67..5a9faf77 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -6,9 +6,12 @@ raise RuntimeError("The long_running module requires Python 3.11 or later.") import asyncio +import copy import inspect import json import logging +import os +import socket import time import uuid from collections.abc import AsyncGenerator @@ -34,9 +37,11 @@ from 
databricks_ai_bridge.long_running.db import dispose_db, init_db, is_db_configured from databricks_ai_bridge.long_running.repository import ( append_message, + claim_stale_response, create_response, get_messages, get_response, + heartbeat_response, update_response_status, update_response_trace_id, ) @@ -47,6 +52,9 @@ BACKGROUND_KEY = "background" +# One ID per process so heartbeats + claims have a stable owner identity. +_POD_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" + async def _deferred_mark_failed( response_id: str, delay: float = 2.0, reason: str = "Task timed out" @@ -65,7 +73,8 @@ async def _deferred_mark_failed( # or SELECT FOR UPDATE on the response row to serialise writers. async with asyncio.timeout(delay): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) error_event = { "type": "error", @@ -75,7 +84,10 @@ async def _deferred_mark_failed( "code": "task_timeout", }, } - await append_message(response_id, next_seq, item=None, stream_event=error_event) + await append_message( + response_id, next_seq, item=None, stream_event=error_event, + attempt_number=attempt, + ) await update_response_status(response_id, "failed") logger.info("Marked %s as failed (reason: %s)", response_id, reason) @@ -91,6 +103,12 @@ async def _deferred_mark_failed( ) +async def _current_attempt(response_id: str) -> int: + """Fetch the current attempt_number for a response, defaulting to 1.""" + resp = await get_response(response_id) + return resp.attempt_number if resp else 1 + + def _sse_event(event_type: str, data: dict[str, Any] | str) -> str: """Format an SSE event per Open Responses spec.""" payload = data if isinstance(data, str) else json.dumps(data) @@ -105,9 +123,36 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() 
+def _inject_conversation_id(request_data: dict[str, Any], response_id: str) -> dict[str, Any]: + """Anchor the request to ``response_id`` as its conversation. + + Templates that back this server use ``context.conversation_id`` (and + ``custom_inputs.thread_id`` / ``custom_inputs.session_id``) as priority-2 + fallbacks to derive their stateful thread/session key. If neither is + provided by the client, a resumed invocation from another pod would + generate a *fresh* ID and miss the checkpoint entirely — so we stamp the + conversation_id here before persisting the request, guaranteeing that + every replay hits the same memory store. + + Client-supplied values take precedence and are left untouched. + """ + out = copy.deepcopy(request_data) if request_data else {} + custom_inputs = out.get("custom_inputs") or {} + if custom_inputs.get("thread_id") or custom_inputs.get("session_id"): + return out + ctx = out.get("context") or {} + if ctx.get("conversation_id"): + return out + ctx = dict(ctx) + ctx["conversation_id"] = response_id + out["context"] = ctx + return out + + @experimental class LongRunningAgentServer(AgentServer): - """AgentServer subclass adding background mode and retrieve endpoints. + """AgentServer subclass adding background mode, retrieve endpoints, and + durable resume. Only compatible with ``ResponsesAgent`` mode. @@ -125,24 +170,14 @@ class LongRunningAgentServer(AgentServer): ``LAKEBASE_INSTANCE_NAME``, ``LAKEBASE_AUTOSCALING_ENDPOINT``, or both ``LAKEBASE_AUTOSCALING_PROJECT`` and ``LAKEBASE_AUTOSCALING_BRANCH``. - Args: - enable_chat_proxy: Whether to enable the chat proxy endpoint. - db_instance_name: Lakebase provisioned instance name. Overrides - ``LAKEBASE_INSTANCE_NAME``. - db_autoscaling_endpoint: Lakebase autoscaling endpoint URL. Overrides - ``LAKEBASE_AUTOSCALING_ENDPOINT``. - db_project: Lakebase autoscaling project. Overrides - ``LAKEBASE_AUTOSCALING_PROJECT``. - db_branch: Lakebase autoscaling branch. 
Overrides - ``LAKEBASE_AUTOSCALING_BRANCH``. - task_timeout_seconds: Max time for a background task before timeout. - Defaults to 3600 (1 hour). - poll_interval_seconds: Interval between DB polls when streaming. - Defaults to 1.0. - db_statement_timeout_ms: Postgres statement timeout. - Defaults to 5000 (5 seconds). - cleanup_timeout_seconds: Timeout for DB cleanup after task failure. - Defaults to 7.0. + Durable resume: when ``GET /responses/{id}`` sees an ``in_progress`` run + whose owning pod has stopped heartbeating for more than + ``heartbeat_stale_threshold_seconds``, the retrieving pod atomically claims + the run and re-invokes the registered handler with ``input=[]`` plus the + stamped ``conversation_id``. Agent SDKs (LangGraph checkpointer, + databricks-openai Session) load prior progress and continue — completed + tool calls are not re-executed. Tools interrupted mid-call may re-run; this + is the accepted best-effort tradeoff. """ _SUPPORTED_AGENT_TYPE = "ResponsesAgent" @@ -162,6 +197,8 @@ def __init__( poll_interval_seconds: float = 1.0, db_statement_timeout_ms: int = 5000, cleanup_timeout_seconds: float = 7.0, + heartbeat_interval_seconds: float = 3.0, + heartbeat_stale_threshold_seconds: float = 15.0, ): if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( @@ -173,6 +210,8 @@ def __init__( poll_interval_seconds=poll_interval_seconds, db_statement_timeout_ms=db_statement_timeout_ms, cleanup_timeout_seconds=cleanup_timeout_seconds, + heartbeat_interval_seconds=heartbeat_interval_seconds, + heartbeat_stale_threshold_seconds=heartbeat_stale_threshold_seconds, ) self._db_instance_name = db_instance_name self._db_autoscaling_endpoint = db_autoscaling_endpoint @@ -290,11 +329,19 @@ async def _handle_background_request( ) -> dict[str, Any] | StreamingResponse: """Start a new conversation and return response_id immediately.""" response_id = f"resp_{uuid.uuid4().hex[:24]}" - await create_response(response_id, "in_progress") + # Anchor the 
conversation to response_id so any future replay from a + # different pod resolves to the same agent-SDK thread/session. + durable_request = _inject_conversation_id(request_data, response_id) + await create_response( + response_id, + "in_progress", + owner_pod_id=_POD_ID, + original_request=durable_request, + ) logger.debug( "Background response created", - extra={"response_id": response_id, "stream": is_streaming}, + extra={"response_id": response_id, "stream": is_streaming, "pod": _POD_ID}, ) response_obj: dict[str, Any] = { @@ -311,7 +358,9 @@ async def _handle_background_request( # Fire-and-forget is intentional — task status is persisted to the database. if is_streaming: asyncio.create_task( - self._run_background_stream(response_id, request_data, return_trace_id) + self._run_background_stream( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) return await self._handle_retrieve_request( response_id, @@ -320,10 +369,52 @@ async def _handle_background_request( ) else: asyncio.create_task( - self._run_background_invoke(response_id, request_data, return_trace_id) + self._run_background_invoke( + response_id, durable_request, return_trace_id, attempt_number=1 + ) ) return response_obj + @asynccontextmanager + async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: + """Keep the response row's heartbeat_at fresh while the body runs. + + A background task writes ``heartbeat_at = now()`` every + ``heartbeat_interval_seconds`` for the owning pod. It stops when the + body returns/raises. Heartbeat write failures are logged but do not + interrupt the agent run — the stale-run check will detect a dead pod. 
+ """ + interval = self._settings.heartbeat_interval_seconds + stop = asyncio.Event() + + async def _beat(): + try: + while not stop.is_set(): + try: + await heartbeat_response(response_id, _POD_ID) + except Exception: + logger.warning( + "Heartbeat write failed for %s; will retry", response_id, + exc_info=True, + ) + try: + await asyncio.wait_for(stop.wait(), timeout=interval) + except TimeoutError: + pass + except asyncio.CancelledError: + pass + + hb_task = asyncio.create_task(_beat(), name=f"heartbeat-{response_id}") + try: + yield + finally: + stop.set() + hb_task.cancel() + try: + await hb_task + except (asyncio.CancelledError, Exception): + pass + @asynccontextmanager async def _task_scope( self, response_id: str, state: dict[str, Any] @@ -348,7 +439,8 @@ async def _task_scope( # TODO: sequence number computation is racy (see _deferred_mark_failed). async with asyncio.timeout(self._settings.cleanup_timeout_seconds): existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -361,6 +453,7 @@ async def _task_scope( "code": "task_failed", }, }, + attempt_number=attempt, ) await update_response_status(response_id, "failed") except Exception: @@ -382,11 +475,16 @@ async def _run_background_stream( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded wrapper around the streaming agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_stream(response_id, request_data, return_trace_id, state) + async with self._task_scope(response_id, state), self._heartbeat(response_id): + await self._do_background_stream( + response_id, request_data, return_trace_id, state, + 
attempt_number=attempt_number, + ) def transform_stream_event(self, event: dict, response_id: str) -> dict: """Override to transform events before persistence (e.g. replace placeholder IDs).""" @@ -398,6 +496,8 @@ async def _do_background_stream( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via stream_fn, persist each stream event as a message row.""" stream_fn = get_stream_function() @@ -407,7 +507,15 @@ async def _do_background_stream( func_name = stream_fn.__name__ all_chunks: list[dict[str, Any]] = [] - seq = 0 + # Continue sequence numbering across attempts so the client's cursor + # never rewinds on resume. First attempt starts at 0 and skips the DB + # lookup — keeps the fast path identical to pre-resume behavior and + # avoids an extra query per background request. + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + seq = 0 with mlflow.start_span(name=func_name) as span: span.set_inputs(request_data) @@ -420,13 +528,17 @@ async def _do_background_stream( evt_type = evt.get("type", "message") logger.debug( "SSE event (background)", - extra={"response_id": response_id, "seq": seq, "type": evt_type}, + extra={ + "response_id": response_id, "seq": seq, "type": evt_type, + "attempt": attempt_number, + }, ) await append_message( response_id, seq, item=json.dumps(item) if item is not None else None, stream_event=evt, + attempt_number=attempt_number, ) seq += 1 state["seq"] = seq @@ -439,6 +551,7 @@ async def _do_background_stream( response_id, seq, stream_event={"trace_id": span.trace_id}, + attempt_number=attempt_number, ) await update_response_status(response_id, "completed") @@ -452,11 +565,16 @@ async def _run_background_invoke( response_id: str, request_data: dict[str, Any], return_trace_id: bool = False, + *, + attempt_number: int = 1, ) -> None: """Timeout-guarded 
wrapper around the invoke agent loop.""" state: dict[str, Any] = {"seq": 0} - async with self._task_scope(response_id, state): - await self._do_background_invoke(response_id, request_data, return_trace_id, state) + async with self._task_scope(response_id, state), self._heartbeat(response_id): + await self._do_background_invoke( + response_id, request_data, return_trace_id, state, + attempt_number=attempt_number, + ) async def _do_background_invoke( self, @@ -464,6 +582,8 @@ async def _do_background_invoke( request_data: dict[str, Any], return_trace_id: bool, state: dict[str, Any], + *, + attempt_number: int = 1, ) -> None: """Run agent via invoke_fn, persist each output item as a message row.""" invoke_fn = get_invoke_function() @@ -485,19 +605,27 @@ async def _do_background_invoke( span.set_outputs(result) output = result.get("output", []) + # Continue sequence numbering across attempts (see _do_background_stream). + if attempt_number > 1: + existing = await get_messages(response_id, after_sequence=None) + base_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + else: + base_seq = 0 for i, item in enumerate(output): item_dict = ( item if isinstance(item, dict) else (item.model_dump() if hasattr(item, "model_dump") else {"content": str(item)}) ) + seq = base_seq + i await append_message( response_id, - i, + seq, item=json.dumps(item_dict), stream_event={"type": "response.output_item.done", "item": item_dict}, + attempt_number=attempt_number, ) - state["seq"] = i + 1 + state["seq"] = seq + 1 if return_trace_id: await update_response_trace_id(response_id, span.trace_id) await update_response_status(response_id, "completed") @@ -506,6 +634,80 @@ async def _do_background_invoke( extra={"response_id": response_id, "output_items": len(output)}, ) + async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: + """If ``resp`` is a stale in-progress run, attempt an atomic claim. 
+ + On success, kick off a new background task that re-invokes the handler + with ``input=[]`` and returns the new ``attempt_number``. On failure + (another pod won, or the run is no longer stale), returns ``None``. + + This is the lazy resume path: triggered by a client retrieve. Pods + don't poll for stale work proactively in v1 — if no client ever calls + ``GET /responses/{id}``, the task_timeout sweep eventually marks it + failed. + """ + if resp.status != "in_progress": + return None + # The run may be freshly started but too young to have a heartbeat yet; + # respect the creation age as a grace period equal to the stale + # threshold. Otherwise a quick follow-up retrieve could hijack a + # running pod before it ever writes its first heartbeat. + if resp.heartbeat_at is None: + if _age_seconds(resp.created_at) < self._settings.heartbeat_stale_threshold_seconds: + return None + if resp.original_request is None: + # Nothing to replay from — the run predates durability metadata. + logger.warning( + "Cannot resume %s: no original_request persisted (old row?)", response_id, + ) + return None + + new_attempt = await claim_stale_response( + response_id, + new_owner_pod_id=_POD_ID, + stale_threshold_seconds=self._settings.heartbeat_stale_threshold_seconds, + ) + if new_attempt is None: + # Someone else owns it, or the row was updated between the read and + # the claim. Expected under contention. + return None + + # Build a "resume" request: keep everything the handler cares about + # (custom_inputs carrying thread_id, context.conversation_id), but null + # out input so the handler passes {"messages": []} / [] to its agent. + # This is the single line that makes the design framework-agnostic. + resume_request = dict(resp.original_request) + resume_request["input"] = [] + + # Emit a marker so clients can reset any in-flight rendering from the + # prior attempt before seeing new events. 
+ existing = await get_messages(response_id, after_sequence=None) + next_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + await append_message( + response_id, + next_seq, + stream_event={ + "type": "response.resumed", + "attempt": new_attempt, + "from_seq": next_seq, + }, + attempt_number=new_attempt, + ) + + logger.info( + "Claimed stale run %s as attempt %s (pod=%s)", + response_id, new_attempt, _POD_ID, + ) + + asyncio.create_task( + self._run_background_stream( + response_id, resume_request, return_trace_id=False, + attempt_number=new_attempt, + ), + name=f"resume-{response_id}-{new_attempt}", + ) + return new_attempt + async def _handle_retrieve_request( self, response_id: str, @@ -523,7 +725,20 @@ async def _handle_retrieve_request( if resp is None: raise HTTPException(status_code=404, detail="Response not found") - _, status, created_at, trace_id = resp + # Try a lazy resume before falling back to the absolute-timeout sweep. + # This gives us crash-recovery semantics: an idle client reconnecting + # after a pod died will reclaim the run and resume it here instead of + # just marking it failed. + await self._try_claim_and_resume(response_id, resp) + + # Refresh after the potential resume: status / attempt_number may have changed. + resp = await get_response(response_id) + if resp is None: + raise HTTPException(status_code=404, detail="Response not found") + + status = resp.status + created_at = resp.created_at + trace_id = resp.trace_id if ( status == "in_progress" @@ -542,10 +757,9 @@ async def _handle_retrieve_request( }, ) # TODO: sequence number computation here is racy under concurrent writers. - # Acceptable at current scale; for high-QPS use a DB-assigned sequence or - # SELECT FOR UPDATE on the response row to serialise writers. 
existing = await get_messages(response_id, after_sequence=None) - next_seq = max((seq for seq, _, _ in existing), default=-1) + 1 + next_seq = max((seq for seq, _, _, _ in existing), default=-1) + 1 + attempt = await _current_attempt(response_id) await append_message( response_id, next_seq, @@ -558,6 +772,7 @@ async def _handle_retrieve_request( "code": "task_timeout", }, }, + attempt_number=attempt, ) status = "failed" @@ -581,20 +796,26 @@ async def _handle_retrieve_request( if not messages and status == "in_progress": return {"id": response_id, "status": "in_progress"} if status == "completed" and messages: + # Keep only ``response.output_item.done`` events, from every attempt, so + # abandoned in-progress items from crashed attempts don't leak + # into the authoritative response body. Completed output_item.done + # events across attempts together make up the conversation — the + # agent SDK's checkpointer guarantees done-items are not re-emitted + # by later attempts, so this is a union with no duplicates. output = [] - for _, _, evt in messages: - if evt and "item" in evt: - output.append(evt["item"]) + for _, _, evt, _attempt in messages: + if evt and evt.get("type") == "response.output_item.done": + output.append(evt.get("item")) result: dict[str, Any] = { "id": response_id, "status": "completed", - "output": output, + "output": [o for o in output if o is not None], } if trace_id: result["metadata"] = {"trace_id": trace_id} return result if status == "failed" and messages: - for _, _, evt in messages: + for _, _, evt, _attempt in messages: if evt and evt.get("type") == "error": return {"id": response_id, "status": "failed", "error": evt.get("error")} return {"id": response_id, "status": status} @@ -638,13 +859,13 @@ async def _stream_retrieve( ) break - _, status, _, _ = resp + status = resp.status # starting_after=0 fetches all messages (sequence numbers start at 0). # We use after_sequence=-1 for the DB query so that seq 0 is included. 
after_seq = last_seq - 1 if last_seq == 0 else last_seq messages = await get_messages(response_id, after_sequence=after_seq) - for seq, _, evt in messages: + for seq, _, evt, _attempt in messages: if evt is not None: evt = {**evt, "sequence_number": seq} event_type = evt.get("type", "message") diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 7b646116..8df21788 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -15,6 +15,8 @@ class LongRunningSettings: poll_interval_seconds: float = 1.0 db_statement_timeout_ms: int = 5000 cleanup_timeout_seconds: float = 7.0 + heartbeat_interval_seconds: float = 3.0 + heartbeat_stale_threshold_seconds: float = 15.0 def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: @@ -25,6 +27,16 @@ def __post_init__(self) -> None: raise ValueError("db_statement_timeout_ms must be positive") if self.cleanup_timeout_seconds <= 0: raise ValueError("cleanup_timeout_seconds must be positive") + if self.heartbeat_interval_seconds <= 0: + raise ValueError("heartbeat_interval_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= 0: + raise ValueError("heartbeat_stale_threshold_seconds must be positive") + if self.heartbeat_stale_threshold_seconds <= self.heartbeat_interval_seconds: + raise ValueError( + f"heartbeat_stale_threshold_seconds ({self.heartbeat_stale_threshold_seconds}) " + f"must be strictly greater than heartbeat_interval_seconds " + f"({self.heartbeat_interval_seconds}) to avoid false stale-run detection." 
+ ) db_timeout_s = self.db_statement_timeout_ms / 1000.0 if self.cleanup_timeout_seconds <= db_timeout_s: raise ValueError( diff --git a/tests/databricks_ai_bridge/test_long_running_db.py b/tests/databricks_ai_bridge/test_long_running_db.py index a1290ba1..7e5f4f76 100644 --- a/tests/databricks_ai_bridge/test_long_running_db.py +++ b/tests/databricks_ai_bridge/test_long_running_db.py @@ -160,10 +160,12 @@ async def test_get_messages(mock_session): result_mock.scalars.return_value.all.return_value = [msg1, msg2] mock_session.execute.return_value = result_mock + msg1.attempt_number = 1 + msg2.attempt_number = 1 messages = await get_messages("resp_abc123", after_sequence=None) assert len(messages) == 2 - assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}) - assert messages[1] == (1, None, None) + assert messages[0] == (0, '{"text": "hello"}', {"type": "response.output_item.done"}, 1) + assert messages[1] == (1, None, None, 1) @pytest.mark.asyncio @@ -175,6 +177,10 @@ async def test_get_response(mock_session): row.created_at = datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc) row.trace_id = "trace_xyz" + row.owner_pod_id = None + row.heartbeat_at = None + row.attempt_number = 1 + row.original_request = None result_mock = MagicMock() result_mock.scalar_one_or_none.return_value = row mock_session.execute.return_value = result_mock @@ -185,6 +191,10 @@ async def test_get_response(mock_session): "completed", datetime(2009, 2, 13, 23, 31, 30, tzinfo=timezone.utc), "trace_xyz", + None, # owner_pod_id + None, # heartbeat_at + 1, # attempt_number + None, # original_request ) @@ -346,3 +356,110 @@ async def fake_factory(): monkeypatch.setattr(db_mod, "_session_factory", fake_factory) async with session_scope() as session: assert session is mock_session + + +# --------------------------------------------------------------------------- +# Durability metadata: owner_pod_id, heartbeat, claim, attempt_number +# 
--------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_create_response_with_owner_and_original_request(mock_session): + """New background callers stamp pod id + serialized request on creation — + without these, a resumed pod can't re-invoke the handler.""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response( + "resp_abc", + "in_progress", + owner_pod_id="pod-1", + original_request={"input": [{"role": "user", "content": "hi"}]}, + ) + added = mock_session.add.call_args[0][0] + assert added.owner_pod_id == "pod-1" + assert added.heartbeat_at is not None + # original_request is JSON-encoded for Text storage. + assert '"role": "user"' in added.original_request + + +@pytest.mark.asyncio +async def test_create_response_without_durability_metadata(mock_session): + """Legacy/no-durability callers should still work and write no + owner/heartbeat (so the stale sweep can't accidentally claim them).""" + from databricks_ai_bridge.long_running.repository import create_response + + await create_response("resp_x", "in_progress") + added = mock_session.add.call_args[0][0] + assert added.owner_pod_id is None + assert added.heartbeat_at is None + assert added.original_request is None + + +@pytest.mark.asyncio +async def test_heartbeat_response_updates_timestamp(mock_session): + from databricks_ai_bridge.long_running.repository import heartbeat_response + + result_mock = MagicMock() + result_mock.rowcount = 1 + mock_session.execute.return_value = result_mock + + ok = await heartbeat_response("resp_abc", "pod-1") + assert ok is True + mock_session.commit.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_heartbeat_response_fails_when_not_owner(mock_session): + """If the CAS misses (owner changed / row deleted), heartbeat reports + failure so the caller can stop looping.""" + from databricks_ai_bridge.long_running.repository import heartbeat_response + + 
result_mock = MagicMock() + result_mock.rowcount = 0 + mock_session.execute.return_value = result_mock + + ok = await heartbeat_response("resp_abc", "pod-1") + assert ok is False + + +@pytest.mark.asyncio +async def test_claim_stale_response_returns_attempt_number(mock_session): + from databricks_ai_bridge.long_running.repository import claim_stale_response + + row = MagicMock() + row.__iter__ = lambda self: iter([2]) + row.__getitem__ = lambda self, i: 2 + result_mock = MagicMock() + result_mock.first.return_value = row + mock_session.execute.return_value = result_mock + + attempt = await claim_stale_response( + "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0 + ) + assert attempt == 2 + + +@pytest.mark.asyncio +async def test_claim_stale_response_returns_none_when_not_eligible(mock_session): + from databricks_ai_bridge.long_running.repository import claim_stale_response + + result_mock = MagicMock() + result_mock.first.return_value = None + mock_session.execute.return_value = result_mock + + attempt = await claim_stale_response( + "resp_abc", new_owner_pod_id="pod-2", stale_threshold_seconds=15.0 + ) + assert attempt is None + + +@pytest.mark.asyncio +async def test_append_message_with_attempt_number(mock_session): + """Resumed events must be tagged with the resume attempt so retrieve can + filter or the client can render the response.resumed boundary cleanly.""" + from databricks_ai_bridge.long_running.repository import append_message + + await append_message("resp_abc", 5, stream_event={"x": 1}, attempt_number=3) + added = mock_session.add.call_args[0][0] + assert added.attempt_number == 3 + assert added.sequence_number == 5 diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 27b6eaf5..b2e76abf 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -14,9 +14,11 @@ 
pytest.importorskip("fastapi") pytest.importorskip("psycopg") +from databricks_ai_bridge.long_running.repository import ResponseInfo from databricks_ai_bridge.long_running.server import ( LongRunningAgentServer, _deferred_mark_failed, + _inject_conversation_id, _sse_event, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings @@ -34,6 +36,40 @@ def _make_server(**kwargs): return LongRunningAgentServer("ResponsesAgent", **kwargs) +def _resp_info( + response_id: str = "resp_123", + status: str = "in_progress", + created_at=None, + trace_id: str | None = None, + owner_pod_id: str | None = None, + heartbeat_at=None, + attempt_number: int = 1, + original_request: dict | None = None, +) -> ResponseInfo: + """Build a ResponseInfo with sensible defaults for tests. + + Mirrors the server's repository model so test setups stay terse even as + durability columns grow over time. + """ + if created_at is None: + created_at = datetime.now(timezone.utc) + return ResponseInfo( + response_id=response_id, + status=status, + created_at=created_at, + trace_id=trace_id, + owner_pod_id=owner_pod_id, + heartbeat_at=heartbeat_at, + attempt_number=attempt_number, + original_request=original_request, + ) + + +def _msg(seq: int, item=None, evt=None, attempt: int = 1): + """Build a (seq, item, stream_event, attempt_number) tuple for get_messages mocks.""" + return (seq, item, evt, attempt) + + def _mock_span(): """Return a mock MLflow span with the attributes the server uses.""" span = MagicMock() @@ -189,7 +225,7 @@ def test_starting_after_zero_without_stream_is_allowed(self): patch( f"{MODULE}.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( f"{MODULE}.get_messages", @@ -209,8 +245,13 @@ async def test_marks_response_failed(self): patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, - return_value=[(0, 
None, {"type": "response.created"})], + return_value=[_msg(0, None, {"type": "response.created"})], ) as mock_get, + patch( + "databricks_ai_bridge.long_running.server.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch( "databricks_ai_bridge.long_running.server.append_message", new_callable=AsyncMock, @@ -271,13 +312,13 @@ async def test_completed_returns_output(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), "trace_abc"), + return_value=_resp_info("resp_123", "completed", trace_id="trace_abc"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - ( + _msg( 0, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -305,7 +346,7 @@ async def test_stale_run_detection(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_stale", "in_progress", old_time, None), + return_value=_resp_info("resp_stale", "in_progress", created_at=old_time), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -336,7 +377,7 @@ async def test_in_progress_returns_status(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "in_progress", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "in_progress"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", @@ -360,14 +401,14 @@ async def test_completed_stream(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "completed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "completed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created", "id": 
"resp_123"}), - ( + _msg(0, None, {"type": "response.created", "id": "resp_123"}), + _msg( 1, '{"text": "hi"}', {"type": "response.output_item.done", "item": {"text": "hi"}}, @@ -394,13 +435,13 @@ async def test_failed_stream_stops(self): patch( "databricks_ai_bridge.long_running.server.get_response", new_callable=AsyncMock, - return_value=("resp_123", "failed", datetime.now(timezone.utc), None), + return_value=_resp_info("resp_123", "failed"), ), patch( "databricks_ai_bridge.long_running.server.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "error", "error": {"message": "boom"}}), + _msg(0, None, {"type": "error", "error": {"message": "boom"}}), ], ), ): @@ -668,10 +709,15 @@ async def test_exception_writes_error_event_inline(self): f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[ - (0, None, {"type": "response.created"}), - (1, None, {"type": "response.output_text.delta"}), + _msg(0, None, {"type": "response.created"}), + _msg(1, None, {"type": "response.output_text.delta"}), ], ), + patch( + f"{MODULE}.get_response", + new_callable=AsyncMock, + return_value=_resp_info(), + ), patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, patch(f"{MODULE}.update_response_status", new_callable=AsyncMock) as mock_update, ): @@ -875,3 +921,351 @@ async def test_lifespan_not_set_when_db_not_configured(self): routes = [r.path for r in server.app.routes if hasattr(r, "path")] assert "/responses/{response_id}" in routes mock_init.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# Durable resume: claim/heartbeat/attempt_number/sentinel +# --------------------------------------------------------------------------- + + +class TestInjectConversationId: + """Anchoring an otherwise-anonymous request to a response_id guarantees a + resumed run on a new pod resolves to the same agent-SDK thread/session.""" + + def test_injects_when_nothing_set(self): + r = {"input": 
[], "custom_inputs": {}, "context": {}} + out = _inject_conversation_id(r, "resp_abc") + assert out["context"]["conversation_id"] == "resp_abc" + + def test_respects_existing_conversation_id(self): + r = {"input": [], "context": {"conversation_id": "user-set"}} + out = _inject_conversation_id(r, "resp_abc") + assert out["context"]["conversation_id"] == "user-set" + + def test_respects_thread_id_from_custom_inputs(self): + r = {"input": [], "custom_inputs": {"thread_id": "t-1"}, "context": {}} + out = _inject_conversation_id(r, "resp_abc") + # When the client already pinned a thread, we don't overwrite — the + # template's _get_or_create_thread_id picks up custom_inputs first. + assert "conversation_id" not in (out["context"] or {}) + + def test_respects_session_id_from_custom_inputs(self): + r = {"input": [], "custom_inputs": {"session_id": "s-1"}, "context": {}} + out = _inject_conversation_id(r, "resp_abc") + assert "conversation_id" not in (out["context"] or {}) + + def test_handles_missing_context_key(self): + r = {"input": [], "custom_inputs": {}} + out = _inject_conversation_id(r, "resp_abc") + assert out["context"]["conversation_id"] == "resp_abc" + + def test_does_not_mutate_input(self): + r = {"input": [], "custom_inputs": {}, "context": {}} + _inject_conversation_id(r, "resp_abc") + assert r["context"] == {} # original untouched + + +class TestHandleBackgroundRequestPersistsDurabilityState: + """Background request entry point should now stamp the response row with + the caller's pod, the original request body, and a conversation anchor.""" + + @pytest.mark.asyncio + async def test_persists_owner_and_original_request(self): + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + _mock_validator(server) + + captured: dict = {} + + async def fake_create_response(response_id, status, *, owner_pod_id=None, + original_request=None): + captured["response_id"] = response_id + captured["status"] = status 
+ captured["owner_pod_id"] = owner_pod_id + captured["original_request"] = original_request + + with ( + patch(f"{MODULE}.create_response", side_effect=fake_create_response), + patch("asyncio.create_task") as mock_create_task, + ): + result = await server._handle_background_request( + {"input": [{"role": "user", "content": "hi"}]}, + is_streaming=False, + return_trace_id=False, + ) + + assert captured["status"] == "in_progress" + assert captured["owner_pod_id"] # non-empty + # original_request should include input + injected conversation_id. + orig = captured["original_request"] + assert orig["input"] == [{"role": "user", "content": "hi"}] + assert orig["context"]["conversation_id"] == captured["response_id"] + # Return shape: immediate response_obj, not a stream. + assert result["id"] == captured["response_id"] + assert result["status"] == "in_progress" + mock_create_task.assert_called_once() + + +class TestTryClaimAndResume: + @pytest.mark.asyncio + async def test_no_op_when_completed(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + resp = _resp_info(status="completed") + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_grace_period_for_fresh_run(self): + """Just-started runs get a grace window before they're claim-eligible.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_stale_threshold_seconds=15.0 + ) + # created 2s ago, no heartbeat yet → should NOT be claimed. 
+ from datetime import timedelta + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=2), + heartbeat_at=None, + original_request={"input": []}, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_op_without_original_request(self): + """Legacy rows created before durability metadata can't be resumed.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=None, + original_request=None, + ) + with patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock) as mock_claim: + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_not_awaited() + + @pytest.mark.asyncio + async def test_claim_fails_returns_none(self): + """Another pod won the race — we quietly step aside.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=300), + original_request={"input": [{"role": "user"}]}, + ) + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, + return_value=None) as mock_claim, + patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, + ): + result = await server._try_claim_and_resume("resp_x", resp) + assert result is None + mock_claim.assert_awaited_once() + mock_append.assert_not_awaited() + + @pytest.mark.asyncio + async def 
test_successful_claim_spawns_resume_and_emits_sentinel(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"user_id": "u"}, + "context": {"conversation_id": "resp_x"}, + }, + ) + captured: dict = {} + + async def fake_append(response_id, seq, *, item=None, stream_event=None, + attempt_number=1): + captured["seq"] = seq + captured["event"] = stream_event + captured["attempt_tag"] = attempt_number + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, + return_value=2), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, + return_value=[_msg(0, None, {}), _msg(1, None, {})]), + patch(f"{MODULE}.append_message", side_effect=fake_append), + patch("asyncio.create_task") as mock_create_task, + ): + attempt = await server._try_claim_and_resume("resp_x", resp) + + assert attempt == 2 + # Sentinel is written at next_seq (existing seqs were 0 and 1). + assert captured["seq"] == 2 + assert captured["event"]["type"] == "response.resumed" + assert captured["event"]["attempt"] == 2 + assert captured["attempt_tag"] == 2 + # A resume task is spawned; it was not awaited synchronously. 
+ mock_create_task.assert_called_once() + + @pytest.mark.asyncio + async def test_resume_passes_empty_input_to_handler(self): + """The single line that makes the design framework-agnostic: input=[] so + LangGraph's checkpointer and databricks-openai Session resume cleanly + (verified by prototypes — see project memory).""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {"thread_id": "t1"}, + "context": {}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): pass + def add_done_callback(self, cb): pass + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + # asyncio.create_task is patched, so the resume coroutine was captured + # rather than scheduled. Because _run_background_stream is an AsyncMock, + # the captured coroutine is safe to await here; awaiting it registers + # the await on mock_run so we can check that the resume request was + # built with input=[]. 
+ import asyncio as _a + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + assert resume_request["input"] == [] + # Other request metadata is preserved so the handler can find + # thread_id / conversation_id / user_id. + assert resume_request["custom_inputs"]["thread_id"] == "t1" + assert kwargs.get("attempt_number") == 2 + + +class TestRetrieveTriggersLazyClaim: + @pytest.mark.asyncio + async def test_retrieve_calls_try_claim(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + + resp = _resp_info("resp_x", "in_progress") + with ( + patch(f"{MODULE}.get_response", new_callable=AsyncMock, return_value=resp), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch.object(server, "_try_claim_and_resume", new_callable=AsyncMock, + return_value=None) as mock_claim, + ): + await server._handle_retrieve_request("resp_x", stream=False, starting_after=0) + + mock_claim.assert_awaited_once() + + +class TestHeartbeatContextManager: + @pytest.mark.asyncio + async def test_writes_heartbeat_periodically(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x"): + await asyncio.sleep(0.2) # enough time for 2+ heartbeats + + # Heartbeat interval is 0.05s so we should see at least 2 writes. 
+ assert mock_hb.await_count >= 2 + for call in mock_hb.await_args_list: + assert call.args[0] == "resp_x" + + @pytest.mark.asyncio + async def test_stops_cleanly_on_exit(self): + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock) as mock_hb: + async with server._heartbeat("resp_x"): + pass # immediate exit + + # Give the heartbeat loop a chance to observe the stop signal. + await asyncio.sleep(0.1) + writes_after_exit = mock_hb.await_count + + await asyncio.sleep(0.15) + # No new writes after the scope closed. + assert mock_hb.await_count == writes_after_exit + + @pytest.mark.asyncio + async def test_db_error_does_not_interrupt_body(self): + """Heartbeat failures are logged, not raised — the stale check catches + real death, so a transient write miss must not kill a live run.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer( + "ResponsesAgent", heartbeat_interval_seconds=0.05, + heartbeat_stale_threshold_seconds=1.0, + ) + + body_ran = False + with patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock, + side_effect=RuntimeError("db down")): + async with server._heartbeat("resp_x"): + await asyncio.sleep(0.1) + body_ran = True + assert body_ran + + +class TestSettingsHeartbeatValidation: + def test_stale_must_exceed_interval(self): + with pytest.raises(ValueError, match="heartbeat_stale_threshold_seconds"): + LongRunningSettings( + heartbeat_interval_seconds=5.0, + heartbeat_stale_threshold_seconds=5.0, + ) + + def test_interval_must_be_positive(self): + with pytest.raises(ValueError, match="heartbeat_interval_seconds must be positive"): + LongRunningSettings(heartbeat_interval_seconds=0) + + def test_defaults_match_chat_ux(self): + # 3s interval + 15s stale gives ~5 heartbeats before a pod is 
considered + # dead — snug enough to recover conversations within a user's + # "reconnecting..." patience window. + s = LongRunningSettings() + assert s.heartbeat_interval_seconds == 3.0 + assert s.heartbeat_stale_threshold_seconds == 15.0 From 466a859cd57c1c37d72783e525c8c5f97acef6a3 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 16 Apr 2026 22:44:26 +0000 Subject: [PATCH 02/39] Update test_creates_schema_and_tables for new ADD COLUMN migration MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit init_db now runs idempotent ADD COLUMN IF NOT EXISTS for the durability columns after the initial create_all. The existing assert_awaited_once is obsolete — assert the full SQL corpus instead. --- tests/databricks_ai_bridge/test_long_running_db.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/databricks_ai_bridge/test_long_running_db.py b/tests/databricks_ai_bridge/test_long_running_db.py index 7e5f4f76..d425da44 100644 --- a/tests/databricks_ai_bridge/test_long_running_db.py +++ b/tests/databricks_ai_bridge/test_long_running_db.py @@ -293,9 +293,16 @@ async def test_creates_schema_and_tables(self, reset_db_globals): with patch(f"{DB_MODULE}.AsyncLakebaseSQLAlchemy", mock_cls), patch(f"{DB_MODULE}.event"): await init_db(autoscaling_endpoint="ep") - mock_conn.execute.assert_awaited_once() - sql_arg = str(mock_conn.execute.call_args[0][0]) - assert "CREATE SCHEMA IF NOT EXISTS" in sql_arg + # init_db runs: CREATE SCHEMA + run_sync(create_all) + a series of + # ADD COLUMN IF NOT EXISTS / CREATE INDEX IF NOT EXISTS to migrate + # the durability columns onto pre-existing tables. 
+ all_sql = " | ".join(str(call.args[0]) for call in mock_conn.execute.call_args_list) + assert "CREATE SCHEMA IF NOT EXISTS" in all_sql + assert "ADD COLUMN IF NOT EXISTS owner_pod_id" in all_sql + assert "ADD COLUMN IF NOT EXISTS heartbeat_at" in all_sql + assert "ADD COLUMN IF NOT EXISTS attempt_number" in all_sql + assert "ADD COLUMN IF NOT EXISTS original_request" in all_sql + assert "idx_responses_stale" in all_sql mock_conn.run_sync.assert_awaited_once() @pytest.mark.asyncio From e7cfedd338cbba28e95a51503fe912ef1f168de6 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 16 Apr 2026 23:03:21 +0000 Subject: [PATCH 03/39] Fix AttributeError when injecting conversation_id into a pydantic request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _handle_invocations_request returns a ResponsesAgentRequest pydantic model from the validator, not a dict. _inject_conversation_id was calling .get() on it — runtime AttributeError on the first POST with background=true. Fix: dump to dict, inject, and round-trip back through the validator so the downstream handler still sees the declared pydantic type. Same round-trip applied to the resume path in _try_claim_and_resume. --- .../long_running/server.py | 24 +++++++++++++------ .../test_long_running_server.py | 8 +++++-- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 5a9faf77..d358acf6 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -123,9 +123,12 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -def _inject_conversation_id(request_data: dict[str, Any], response_id: str) -> dict[str, Any]: +def _inject_conversation_id(request_dict: dict[str, Any], response_id: str) -> dict[str, Any]: """Anchor the request to ``response_id`` as its conversation. 
+ Operates on a plain dict — the caller is responsible for converting to/from + pydantic via ``model_dump()`` and the server's validator. + Templates that back this server use ``context.conversation_id`` (and ``custom_inputs.thread_id`` / ``custom_inputs.session_id``) as priority-2 fallbacks to derive their stateful thread/session key. If neither is @@ -136,7 +139,7 @@ def _inject_conversation_id(request_data: dict[str, Any], response_id: str) -> d Client-supplied values take precedence and are left untouched. """ - out = copy.deepcopy(request_data) if request_data else {} + out = copy.deepcopy(request_dict) if request_dict else {} custom_inputs = out.get("custom_inputs") or {} if custom_inputs.get("thread_id") or custom_inputs.get("session_id"): return out @@ -330,13 +333,19 @@ async def _handle_background_request( """Start a new conversation and return response_id immediately.""" response_id = f"resp_{uuid.uuid4().hex[:24]}" # Anchor the conversation to response_id so any future replay from a - # different pod resolves to the same agent-SDK thread/session. - durable_request = _inject_conversation_id(request_data, response_id) + # different pod resolves to the same agent-SDK thread/session. We + # round-trip through dict + validator so the handler still receives a + # pydantic ResponsesAgentRequest (its declared arg type). 
+ request_dict = ( + request_data.model_dump() if hasattr(request_data, "model_dump") else dict(request_data) + ) + durable_dict = _inject_conversation_id(request_dict, response_id) + durable_request = self.validator.validate_and_convert_request(durable_dict) await create_response( response_id, "in_progress", owner_pod_id=_POD_ID, - original_request=durable_request, + original_request=durable_dict, ) logger.debug( @@ -676,8 +685,9 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: # (custom_inputs carrying thread_id, context.conversation_id), but null # out input so the handler passes {"messages": []} / [] to its agent. # This is the single line that makes the design framework-agnostic. - resume_request = dict(resp.original_request) - resume_request["input"] = [] + resume_dict = dict(resp.original_request) + resume_dict["input"] = [] + resume_request = self.validator.validate_and_convert_request(resume_dict) # Emit a marker so clients can reset any in-flight rendering from the # prior attempt before seeing new events. diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index b2e76abf..0644865c 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1168,10 +1168,14 @@ def add_done_callback(self, cb): pass mock_run.assert_awaited_once() args, kwargs = mock_run.call_args resume_request = args[1] if len(args) > 1 else kwargs["request_data"] - assert resume_request["input"] == [] + # resume_request is a ResponsesAgentRequest pydantic object after + # round-tripping through the validator so the handler still gets its + # declared arg type. + dumped = resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + assert dumped["input"] == [] # Other request metadata is preserved so the handler can find # thread_id / conversation_id / user_id. 
- assert resume_request["custom_inputs"]["thread_id"] == "t1" + assert dumped["custom_inputs"]["thread_id"] == "t1" assert kwargs.get("attempt_number") == 2 From 19df0556ab5404cff8b8329989799ddb1e1bf90c Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 16 Apr 2026 23:36:53 +0000 Subject: [PATCH 04/39] Tolerate InsufficientPrivilege in ADD COLUMN migrations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ADD COLUMN IF NOT EXISTS is a no-op when the column already exists, but PG still enforces table-ownership before running the ALTER. Multiple pods running different SPs against the same agent_server schema — or a developer connecting as themselves to an already-migrated dogfood instance — hit 'must be owner of table responses' during init_db and fail to start. Split migrations into one-statement-per-transaction and swallow InsufficientPrivilege with an info log. If the column is missing and we really can't grant it, the next query that uses the column will fail loudly — better than refusing to boot the app on a shared DB. --- src/databricks_ai_bridge/long_running/db.py | 54 +++++++++++++-------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 187a2873..44edf16b 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -78,25 +78,41 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): async with _engine.begin() as conn: await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}")) await conn.run_sync(Base.metadata.create_all) - # Idempotent migration for tables created by earlier versions: add any - # columns introduced for durable-resume support. Safe to run on a fresh - # schema (all columns exist) and on an upgraded one (only missing ones added). 
- for stmt in ( - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " - "ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " - "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " - "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " - "ADD COLUMN IF NOT EXISTS original_request TEXT", - f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " - "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", - f"CREATE INDEX IF NOT EXISTS idx_responses_stale " - f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " - "WHERE status = 'in_progress'", - ): - await conn.execute(text(stmt)) + + # Idempotent migration for tables created by earlier versions: add any + # columns introduced for durable-resume support. Each statement runs in + # its own transaction so an InsufficientPrivilege on one ALTER (another + # pod's SP owns the table but the schema is already migrated) doesn't + # poison the rest. A single mega-transaction would abort entirely on the + # first owner-check failure even with IF NOT EXISTS. 
+ migration_stmts = ( + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " + "ADD COLUMN IF NOT EXISTS original_request TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " + "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", + f"CREATE INDEX IF NOT EXISTS idx_responses_stale " + f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " + "WHERE status = 'in_progress'", + ) + for stmt in migration_stmts: + try: + async with _engine.begin() as conn: + await conn.execute(text(stmt)) + except Exception as exc: + msg = str(exc).lower() + if "insufficientprivilege" in msg or "must be owner" in msg: + logger.info( + "[DB] Skipping migration (not owner, presumed already applied): %s", + stmt.split("\n")[0], + ) + continue + raise _initialized = True logger.info("[DB] Engine and schema ready") From 97a5dcba827c19db9f9b802511c6c323323fb6eb Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 16 Apr 2026 23:39:33 +0000 Subject: [PATCH 05/39] Include attempt_number in retrieve response for observability Clients and tests need a way to see that a response was resumed across pods (attempt_number > 1) without grepping server logs. Adds the field to every shape of the non-stream retrieve return body. 
--- .../long_running/server.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index d358acf6..3989a7e2 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -804,7 +804,11 @@ async def _handle_retrieve_request( messages = await get_messages(response_id, after_sequence=None) if not messages and status == "in_progress": - return {"id": response_id, "status": "in_progress"} + return { + "id": response_id, + "status": "in_progress", + "attempt_number": resp.attempt_number, + } if status == "completed" and messages: # Only consider items from the final (successful) attempt so that # abandoned in-progress items from crashed attempts don't leak @@ -820,6 +824,7 @@ async def _handle_retrieve_request( "id": response_id, "status": "completed", "output": [o for o in output if o is not None], + "attempt_number": resp.attempt_number, } if trace_id: result["metadata"] = {"trace_id": trace_id} @@ -827,8 +832,17 @@ async def _handle_retrieve_request( if status == "failed" and messages: for _, _, evt, _attempt in messages: if evt and evt.get("type") == "error": - return {"id": response_id, "status": "failed", "error": evt.get("error")} - return {"id": response_id, "status": status} + return { + "id": response_id, + "status": "failed", + "error": evt.get("error"), + "attempt_number": resp.attempt_number, + } + return { + "id": response_id, + "status": status, + "attempt_number": resp.attempt_number, + } async def _stream_retrieve( self, From d3adee7a6db510a189911a58e3e323193a87602b Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 16 Apr 2026 23:43:44 +0000 Subject: [PATCH 06/39] Add opt-in debug-kill endpoint for testing crash-resume on deployed apps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Integration tests for the durable-resume 
path want to simulate a pod crash without actually restarting the pod (costly, disruptive, affects other traffic). Add POST /_debug/kill_task/{response_id} that cancels the in-flight asyncio task owning that response. CancelledError propagates past _task_scope's Exception handlers (it's BaseException in py3.8+), so the DB row stays in_progress with a going-stale heartbeat — exactly the shape a real crash leaves. A client GET then triggers lazy reclaim and resume. Gated behind env var LONG_RUNNING_ENABLE_DEBUG_KILL=1 so the endpoint is never exposed in production. Tracks tasks in a per-server dict cleared via done-callback; purely an observation affordance, durability still hinges on DB state. --- .../long_running/server.py | 46 ++++++++++++- .../test_long_running_server.py | 69 ++++++++++++++++++- 2 files changed, 111 insertions(+), 4 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 3989a7e2..f760c608 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -220,6 +220,11 @@ def __init__( self._db_autoscaling_endpoint = db_autoscaling_endpoint self._db_project = db_project self._db_branch = db_branch + # Track in-flight background tasks per response_id so the debug-kill + # endpoint can simulate a pod crash without tearing the whole pod + # down. Not load-bearing for correctness — durability still relies on + # DB state, this is just a test affordance. 
+ self._running_tasks: dict[str, asyncio.Task] = {} super().__init__(agent_type, enable_chat_proxy=enable_chat_proxy) def _setup_routes(self) -> None: @@ -237,6 +242,31 @@ async def cancel_endpoint(response_id: str): detail="Cancellation is not yet implemented.", ) + # Debug endpoint for testing durable resume: cancels the in-flight + # asyncio task that owns the given response_id WITHOUT running the + # _task_scope cleanup, so the DB row stays in_progress with a + # going-stale heartbeat — exactly the shape a real pod crash leaves. + # Opt-in via env var so it's never exposed in production. + if os.getenv("LONG_RUNNING_ENABLE_DEBUG_KILL") == "1": + + @self.app.post("/_debug/kill_task/{response_id}") + async def _debug_kill_task(response_id: str): + task = self._running_tasks.get(response_id) + if task is None: + raise HTTPException( + status_code=404, + detail=( + "No in-flight task for that response_id on this pod " + "(may already have finished or be running on another pod)." + ), + ) + task.cancel() + return { + "response_id": response_id, + "pod_id": _POD_ID, + "status": "task_cancelled", + } + db_configured = is_db_configured() @self.app.get("/responses/{response_id}") @@ -365,25 +395,34 @@ async def _handle_background_request( } # Fire-and-forget is intentional — task status is persisted to the database. + # We still track the task handle so the debug-kill endpoint can simulate + # a crash (and so we know whether a claim target lives on this pod). 
if is_streaming: - asyncio.create_task( + task = asyncio.create_task( self._run_background_stream( response_id, durable_request, return_trace_id, attempt_number=1 ) ) + self._track_task(response_id, task) return await self._handle_retrieve_request( response_id, stream=True, starting_after=0, ) else: - asyncio.create_task( + task = asyncio.create_task( self._run_background_invoke( response_id, durable_request, return_trace_id, attempt_number=1 ) ) + self._track_task(response_id, task) return response_obj + def _track_task(self, response_id: str, task: asyncio.Task) -> None: + """Record a background task so the debug-kill endpoint can find it.""" + self._running_tasks[response_id] = task + task.add_done_callback(lambda _t: self._running_tasks.pop(response_id, None)) + @asynccontextmanager async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: """Keep the response row's heartbeat_at fresh while the body runs. @@ -709,13 +748,14 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: response_id, new_attempt, _POD_ID, ) - asyncio.create_task( + task = asyncio.create_task( self._run_background_stream( response_id, resume_request, return_trace_id=False, attempt_number=new_attempt, ), name=f"resume-{response_id}-{new_attempt}", ) + self._track_task(response_id, task) return new_attempt async def _handle_retrieve_request( diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 0644865c..62b55353 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -388,7 +388,11 @@ async def test_in_progress_returns_status(self): result = await server._handle_retrieve_request( "resp_123", stream=False, starting_after=0 ) - assert result == {"id": "resp_123", "status": "in_progress"} + assert result == { + "id": "resp_123", + "status": "in_progress", + "attempt_number": 1, + } class 
TestStreamRetrieve: @@ -1273,3 +1277,66 @@ def test_defaults_match_chat_ux(self): s = LongRunningSettings() assert s.heartbeat_interval_seconds == 3.0 assert s.heartbeat_stale_threshold_seconds == 15.0 + + +class TestDebugKillTask: + """The opt-in debug-kill endpoint lets integration tests simulate a crash + against a deployed pod without restarting the whole app. Off by default + because exposing task cancellation bypasses the normal cleanup path.""" + + def test_endpoint_absent_by_default(self): + from starlette.testclient import TestClient + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + resp = client.post("/_debug/kill_task/resp_x") + assert resp.status_code == 404 # route not registered + + def test_endpoint_registered_when_env_set(self, monkeypatch): + from starlette.testclient import TestClient + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + client = TestClient(server.app, raise_server_exceptions=False) + # No in-flight task for this response_id on this pod → 404, not 405. + resp = client.post("/_debug/kill_task/resp_missing") + assert resp.status_code == 404 + assert "No in-flight task" in resp.json()["detail"] + + @pytest.mark.asyncio + async def test_cancels_tracked_task(self, monkeypatch): + """Direct-call variant: skip the TestClient (which is sync and blocks + the loop) and call the handler logic through _running_tasks directly. + Covers the important behavior: cancelling a tracked task propagates + CancelledError and the tracking dict is cleared by the done-callback. 
+ """ + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") + with patch(f"{MODULE}.is_db_configured", return_value=True): + server = LongRunningAgentServer("ResponsesAgent") + + cancel_event = asyncio.Event() + + async def long_running(): + try: + await asyncio.sleep(60) + except asyncio.CancelledError: + cancel_event.set() + raise + + task = asyncio.create_task(long_running()) + server._track_task("resp_tracked", task) + + # Yield once so the new task can start waiting on sleep(60). + await asyncio.sleep(0) + assert "resp_tracked" in server._running_tasks + + task.cancel() + # Expect CancelledError from awaiting the task itself, and the cancel + # event set inside the except handler before the re-raise. + with pytest.raises(asyncio.CancelledError): + await task + assert cancel_event.is_set() + # done-callback (scheduled on loop) clears the registration after the + # task completes — give it one more tick. + await asyncio.sleep(0) + assert "resp_tracked" not in server._running_tasks From d7c33b7fafd5e77b176e490bf7d1be04b8c162e5 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 20 Apr 2026 20:20:39 +0000 Subject: [PATCH 07/39] Apply ruff format + fix ty diagnostic on request_data dump - ruff format: line-length normalization in models.py / db.py / server.py / test_long_running_server.py (no behavior change). - ruff F401: drop unused 'import asyncio as _a' inside TestTryClaimAndResume. - ty call-non-callable: replace hasattr + attribute access with getattr + callable check. request_data is typed as dict but is a pydantic model at runtime (returned by validate_and_convert_request); narrowing via getattr(...) keeps ty happy and preserves the fallback for tests that pass dicts directly. 
--- src/databricks_ai_bridge/long_running/db.py | 6 +- .../long_running/models.py | 4 +- .../long_running/server.py | 43 +++++++++---- .../test_long_running_server.py | 64 +++++++++++++------ 4 files changed, 77 insertions(+), 40 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 44edf16b..44ef19fd 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -86,14 +86,12 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): # poison the rest. A single mega-transaction would abort entirely on the # first owner-check failure even with IF NOT EXISTS. migration_stmts = ( - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " - "ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS owner_pod_id TEXT", f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " "ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ", f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", - f"ALTER TABLE {AGENT_DB_SCHEMA}.responses " - "ADD COLUMN IF NOT EXISTS original_request TEXT", + f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS original_request TEXT", f"ALTER TABLE {AGENT_DB_SCHEMA}.messages " "ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1", f"CREATE INDEX IF NOT EXISTS idx_responses_stale " diff --git a/src/databricks_ai_bridge/long_running/models.py b/src/databricks_ai_bridge/long_running/models.py index 6dade733..7014a7db 100644 --- a/src/databricks_ai_bridge/long_running/models.py +++ b/src/databricks_ai_bridge/long_running/models.py @@ -31,9 +31,7 @@ class Response(Base): ) trace_id: Mapped[str | None] = mapped_column(Text, nullable=True) owner_pod_id: Mapped[str | None] = mapped_column(Text, nullable=True) - heartbeat_at: Mapped[datetime | None] = mapped_column( - DateTime(timezone=True), nullable=True - 
) + heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) attempt_number: Mapped[int] = mapped_column( Integer, nullable=False, server_default="1", default=1 ) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index f760c608..f97e0622 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -85,7 +85,10 @@ async def _deferred_mark_failed( }, } await append_message( - response_id, next_seq, item=None, stream_event=error_event, + response_id, + next_seq, + item=None, + stream_event=error_event, attempt_number=attempt, ) await update_response_status(response_id, "failed") @@ -365,10 +368,12 @@ async def _handle_background_request( # Anchor the conversation to response_id so any future replay from a # different pod resolves to the same agent-SDK thread/session. We # round-trip through dict + validator so the handler still receives a - # pydantic ResponsesAgentRequest (its declared arg type). - request_dict = ( - request_data.model_dump() if hasattr(request_data, "model_dump") else dict(request_data) - ) + # pydantic ResponsesAgentRequest (its declared arg type). The + # declared param type is ``dict`` but the runtime object is a pydantic + # model from ``validate_and_convert_request``; fall back to ``dict()`` + # when tests pass a plain dict directly. 
+ dump = getattr(request_data, "model_dump", None) + request_dict = dump() if callable(dump) else dict(request_data) durable_dict = _inject_conversation_id(request_dict, response_id) durable_request = self.validator.validate_and_convert_request(durable_dict) await create_response( @@ -442,7 +447,8 @@ async def _beat(): await heartbeat_response(response_id, _POD_ID) except Exception: logger.warning( - "Heartbeat write failed for %s; will retry", response_id, + "Heartbeat write failed for %s; will retry", + response_id, exc_info=True, ) try: @@ -530,7 +536,10 @@ async def _run_background_stream( state: dict[str, Any] = {"seq": 0} async with self._task_scope(response_id, state), self._heartbeat(response_id): await self._do_background_stream( - response_id, request_data, return_trace_id, state, + response_id, + request_data, + return_trace_id, + state, attempt_number=attempt_number, ) @@ -577,7 +586,9 @@ async def _do_background_stream( logger.debug( "SSE event (background)", extra={ - "response_id": response_id, "seq": seq, "type": evt_type, + "response_id": response_id, + "seq": seq, + "type": evt_type, "attempt": attempt_number, }, ) @@ -620,7 +631,10 @@ async def _run_background_invoke( state: dict[str, Any] = {"seq": 0} async with self._task_scope(response_id, state), self._heartbeat(response_id): await self._do_background_invoke( - response_id, request_data, return_trace_id, state, + response_id, + request_data, + return_trace_id, + state, attempt_number=attempt_number, ) @@ -706,7 +720,8 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: if resp.original_request is None: # Nothing to replay from — the run predates durability metadata. 
logger.warning( - "Cannot resume %s: no original_request persisted (old row?)", response_id, + "Cannot resume %s: no original_request persisted (old row?)", + response_id, ) return None @@ -745,12 +760,16 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: logger.info( "Claimed stale run %s as attempt %s (pod=%s)", - response_id, new_attempt, _POD_ID, + response_id, + new_attempt, + _POD_ID, ) task = asyncio.create_task( self._run_background_stream( - response_id, resume_request, return_trace_id=False, + response_id, + resume_request, + return_trace_id=False, attempt_number=new_attempt, ), name=f"resume-{response_id}-{new_attempt}", diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 62b55353..831703ef 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -981,8 +981,9 @@ async def test_persists_owner_and_original_request(self): captured: dict = {} - async def fake_create_response(response_id, status, *, owner_pod_id=None, - original_request=None): + async def fake_create_response( + response_id, status, *, owner_pod_id=None, original_request=None + ): captured["response_id"] = response_id captured["status"] = status captured["owner_pod_id"] = owner_pod_id @@ -1030,6 +1031,7 @@ async def test_grace_period_for_fresh_run(self): ) # created 2s ago, no heartbeat yet → should NOT be claimed. 
from datetime import timedelta + resp = _resp_info( status="in_progress", created_at=datetime.now(timezone.utc) - timedelta(seconds=2), @@ -1047,6 +1049,7 @@ async def test_no_op_without_original_request(self): with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer("ResponsesAgent") from datetime import timedelta + resp = _resp_info( status="in_progress", created_at=datetime.now(timezone.utc) - timedelta(seconds=300), @@ -1064,6 +1067,7 @@ async def test_claim_fails_returns_none(self): with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer("ResponsesAgent") from datetime import timedelta + resp = _resp_info( status="in_progress", created_at=datetime.now(timezone.utc) - timedelta(seconds=300), @@ -1071,8 +1075,9 @@ async def test_claim_fails_returns_none(self): original_request={"input": [{"role": "user"}]}, ) with ( - patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, - return_value=None) as mock_claim, + patch( + f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=None + ) as mock_claim, patch(f"{MODULE}.append_message", new_callable=AsyncMock) as mock_append, ): result = await server._try_claim_and_resume("resp_x", resp) @@ -1085,6 +1090,7 @@ async def test_successful_claim_spawns_resume_and_emits_sentinel(self): with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer("ResponsesAgent") from datetime import timedelta + resp = _resp_info( status="in_progress", created_at=datetime.now(timezone.utc) - timedelta(seconds=300), @@ -1097,17 +1103,18 @@ async def test_successful_claim_spawns_resume_and_emits_sentinel(self): ) captured: dict = {} - async def fake_append(response_id, seq, *, item=None, stream_event=None, - attempt_number=1): + async def fake_append(response_id, seq, *, item=None, stream_event=None, attempt_number=1): captured["seq"] = seq captured["event"] = stream_event captured["attempt_tag"] = attempt_number 
with ( - patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, - return_value=2), - patch(f"{MODULE}.get_messages", new_callable=AsyncMock, - return_value=[_msg(0, None, {}), _msg(1, None, {})]), + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=2), + patch( + f"{MODULE}.get_messages", + new_callable=AsyncMock, + return_value=[_msg(0, None, {}), _msg(1, None, {})], + ), patch(f"{MODULE}.append_message", side_effect=fake_append), patch("asyncio.create_task") as mock_create_task, ): @@ -1130,6 +1137,7 @@ async def test_resume_passes_empty_input_to_handler(self): with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer("ResponsesAgent") from datetime import timedelta + resp = _resp_info( status="in_progress", created_at=datetime.now(timezone.utc) - timedelta(seconds=300), @@ -1147,8 +1155,12 @@ def capture_task(coro, *, name=None): captured_tasks.append((coro, name)) class _Fake: - def cancel(self): pass - def add_done_callback(self, cb): pass + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + return _Fake() with ( @@ -1165,7 +1177,6 @@ def add_done_callback(self, cb): pass # the captured coro with proper args. Simpler: check that the resume # coroutine was built with input=[]. Drive the coroutine so mock_run # receives the call args. - import asyncio as _a assert len(captured_tasks) == 1 coro, _name = captured_tasks[0] await coro @@ -1175,7 +1186,9 @@ def add_done_callback(self, cb): pass # resume_request is a ResponsesAgentRequest pydantic object after # round-tripping through the validator so the handler still gets its # declared arg type. 
- dumped = resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) assert dumped["input"] == [] # Other request metadata is preserved so the handler can find # thread_id / conversation_id / user_id. @@ -1193,8 +1206,9 @@ async def test_retrieve_calls_try_claim(self): with ( patch(f"{MODULE}.get_response", new_callable=AsyncMock, return_value=resp), patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), - patch.object(server, "_try_claim_and_resume", new_callable=AsyncMock, - return_value=None) as mock_claim, + patch.object( + server, "_try_claim_and_resume", new_callable=AsyncMock, return_value=None + ) as mock_claim, ): await server._handle_retrieve_request("resp_x", stream=False, starting_after=0) @@ -1206,7 +1220,8 @@ class TestHeartbeatContextManager: async def test_writes_heartbeat_periodically(self): with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer( - "ResponsesAgent", heartbeat_interval_seconds=0.05, + "ResponsesAgent", + heartbeat_interval_seconds=0.05, heartbeat_stale_threshold_seconds=1.0, ) @@ -1223,7 +1238,8 @@ async def test_writes_heartbeat_periodically(self): async def test_stops_cleanly_on_exit(self): with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer( - "ResponsesAgent", heartbeat_interval_seconds=0.05, + "ResponsesAgent", + heartbeat_interval_seconds=0.05, heartbeat_stale_threshold_seconds=1.0, ) @@ -1245,13 +1261,17 @@ async def test_db_error_does_not_interrupt_body(self): real death, so a transient write miss must not kill a live run.""" with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer( - "ResponsesAgent", heartbeat_interval_seconds=0.05, + "ResponsesAgent", + heartbeat_interval_seconds=0.05, heartbeat_stale_threshold_seconds=1.0, ) body_ran = False - with 
patch(f"{MODULE}.heartbeat_response", new_callable=AsyncMock, - side_effect=RuntimeError("db down")): + with patch( + f"{MODULE}.heartbeat_response", + new_callable=AsyncMock, + side_effect=RuntimeError("db down"), + ): async with server._heartbeat("resp_x"): await asyncio.sleep(0.1) body_ran = True @@ -1286,6 +1306,7 @@ class TestDebugKillTask: def test_endpoint_absent_by_default(self): from starlette.testclient import TestClient + with patch(f"{MODULE}.is_db_configured", return_value=True): server = LongRunningAgentServer("ResponsesAgent") client = TestClient(server.app, raise_server_exceptions=False) @@ -1294,6 +1315,7 @@ def test_endpoint_absent_by_default(self): def test_endpoint_registered_when_env_set(self, monkeypatch): from starlette.testclient import TestClient + monkeypatch.setenv("LONG_RUNNING_ENABLE_DEBUG_KILL", "1") with patch(f"{MODULE}.is_db_configured", return_value=True): server = LongRunningAgentServer("ResponsesAgent") From f9f8a73e1ac9c6dacd51ebd0077b44bfff9e2e29 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 20 Apr 2026 21:02:31 +0000 Subject: [PATCH 08/39] Log Background response created at INFO so response_id is visible in apps logs without --log-level debug --- src/databricks_ai_bridge/long_running/server.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index f97e0622..5a4f396d 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -383,9 +383,11 @@ async def _handle_background_request( original_request=durable_dict, ) - logger.debug( - "Background response created", - extra={"response_id": response_id, "stream": is_streaming, "pod": _POD_ID}, + logger.info( + "Background response created response_id=%s stream=%s pod=%s", + response_id, + is_streaming, + _POD_ID, ) response_obj: dict[str, Any] = { From d5666b2be75cc54e0b9acda7584619424fe2b310 Mon 
Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 20 Apr 2026 21:22:17 +0000 Subject: [PATCH 09/39] Tag every SSE frame in stream retrieve with top-level response_id MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit response_id lived only inside the nested 'response' object of the response.created event, so proxies and clients had to know that specific shape to extract it. Add it as a top-level field on every frame alongside sequence_number — discoverable without schema knowledge. --- src/databricks_ai_bridge/long_running/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 5a4f396d..815da13a 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -952,7 +952,9 @@ async def _stream_retrieve( for seq, _, evt, _attempt in messages: if evt is not None: - evt = {**evt, "sequence_number": seq} + # Tag every SSE frame with the response_id so proxies / + # clients can discover it without parsing nested fields. + evt = {**evt, "sequence_number": seq, "response_id": response_id} event_type = evt.get("type", "message") logger.debug( "SSE event", From f2ffb6e219fdb38e347501573891677f019cd17c Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 20 Apr 2026 21:25:58 +0000 Subject: [PATCH 10/39] Self-heal open streams: call _try_claim_and_resume from _stream_retrieve MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A stream opened BEFORE the owning pod died would poll the DB forever waiting for events that never come: _try_claim_and_resume only fires on fresh GET retrieve, not inside the existing poll loop. When the kill hit during a live stream, the UI just hung. Now every poll iteration also tries to reclaim — a no-op for fresh runs, a rescue path for dead ones. 
--- src/databricks_ai_bridge/long_running/server.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 815da13a..9ee3985a 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -945,6 +945,15 @@ async def _stream_retrieve( break status = resp.status + # Self-heal: if this response is still in_progress but its owning + # pod has gone silent past heartbeat_stale_threshold, try to claim + # + resume on this pod. A no-op if heartbeat is fresh or another + # pod already won. Without this, a stream opened before the crash + # would idle forever polling a dead run — since _try_claim_and_resume + # is only triggered by the outer retrieve handler on fresh GETs. + if status == "in_progress": + await self._try_claim_and_resume(response_id, resp) + # starting_after=0 fetches all messages (sequence numbers start at 0). # We use after_sequence=-1 for the DB query so that seq 0 is included. after_seq = last_seq - 1 if last_seq == 0 else last_seq From 6ee9f6c2d376e12d72c4c8bc7585a4ff97154b75 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 20 Apr 2026 21:27:11 +0000 Subject: [PATCH 11/39] Tighten heartbeat_stale_threshold_seconds default from 15s to 10s MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Recovers crashed conversations faster — ~3 heartbeat intervals tolerated before reclaim instead of 5. Still plenty of margin over the 3s heartbeat interval to avoid false positives. 
--- src/databricks_ai_bridge/long_running/server.py | 2 +- src/databricks_ai_bridge/long_running/settings.py | 2 +- tests/databricks_ai_bridge/test_long_running_server.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 9ee3985a..b421a715 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -204,7 +204,7 @@ def __init__( db_statement_timeout_ms: int = 5000, cleanup_timeout_seconds: float = 7.0, heartbeat_interval_seconds: float = 3.0, - heartbeat_stale_threshold_seconds: float = 15.0, + heartbeat_stale_threshold_seconds: float = 10.0, ): if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 8df21788..30f1ad02 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -16,7 +16,7 @@ class LongRunningSettings: db_statement_timeout_ms: int = 5000 cleanup_timeout_seconds: float = 7.0 heartbeat_interval_seconds: float = 3.0 - heartbeat_stale_threshold_seconds: float = 15.0 + heartbeat_stale_threshold_seconds: float = 10.0 def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 831703ef..736bca7d 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1296,7 +1296,7 @@ def test_defaults_match_chat_ux(self): # "reconnecting..." patience window. 
s = LongRunningSettings() assert s.heartbeat_interval_seconds == 3.0 - assert s.heartbeat_stale_threshold_seconds == 15.0 + assert s.heartbeat_stale_threshold_seconds == 10.0 class TestDebugKillTask: From c2383f239124a011f3cb7a73bb3dbc8ef8f5f6bd Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 20 Apr 2026 21:32:23 +0000 Subject: [PATCH 12/39] Add [durable] INFO-level lifecycle logs across the resume path Every transition now shows up in apps logs prefixed with [durable]: - heartbeat start / beat#N sampled / stop with total - stale heartbeat detected with age + threshold - attempting claim / claim succeeded / claim lost - claim skipped with reason (grace_period, heartbeat_fresh, no_original_request) - background stream start / completed with totals - kill endpoint: cancelling task Fresh-heartbeat skips stay at DEBUG to avoid spamming every poll tick. Periodic heartbeat sampled every 5 beats (~15s at default interval). --- .../long_running/server.py | 99 +++++++++++++++++-- 1 file changed, 92 insertions(+), 7 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index b421a715..64482731 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -256,6 +256,11 @@ async def cancel_endpoint(response_id: str): async def _debug_kill_task(response_id: str): task = self._running_tasks.get(response_id) if task is None: + logger.info( + "[durable] kill endpoint: no task response_id=%s on pod=%s", + response_id, + _POD_ID, + ) raise HTTPException( status_code=404, detail=( @@ -263,6 +268,11 @@ async def _debug_kill_task(response_id: str): "(may already have finished or be running on another pod)." 
), ) + logger.info( + "[durable] kill endpoint: cancelling task response_id=%s pod=%s", + response_id, + _POD_ID, + ) task.cancel() return { "response_id": response_id, @@ -443,13 +453,31 @@ async def _heartbeat(self, response_id: str) -> AsyncGenerator[None, None]: stop = asyncio.Event() async def _beat(): + beats = 0 + logger.info( + "[durable] heartbeat start response_id=%s pod=%s interval=%.1fs", + response_id, + _POD_ID, + interval, + ) try: while not stop.is_set(): try: await heartbeat_response(response_id, _POD_ID) + beats += 1 + # Sampled heartbeat log so the lifecycle is visible + # without spamming every interval. Every 5th (~15s + # at 3s interval) is a good compromise. + if beats % 5 == 1: + logger.info( + "[durable] heartbeat beat#%d response_id=%s pod=%s", + beats, + response_id, + _POD_ID, + ) except Exception: logger.warning( - "Heartbeat write failed for %s; will retry", + "[durable] heartbeat write failed response_id=%s; will retry", response_id, exc_info=True, ) @@ -459,6 +487,12 @@ async def _beat(): pass except asyncio.CancelledError: pass + logger.info( + "[durable] heartbeat stop response_id=%s pod=%s total_beats=%d", + response_id, + _POD_ID, + beats, + ) hb_task = asyncio.create_task(_beat(), name=f"heartbeat-{response_id}") try: @@ -565,6 +599,13 @@ async def _do_background_stream( raise RuntimeError("No stream function registered; cannot run background stream") func_name = stream_fn.__name__ + logger.info( + "[durable] background stream start response_id=%s attempt=%d pod=%s handler=%s", + response_id, + attempt_number, + _POD_ID, + func_name, + ) all_chunks: list[dict[str, Any]] = [] # Continue sequence numbering across attempts so the client's cursor # never rewinds on resume. 
First attempt starts at 0 and skips the DB @@ -616,9 +657,13 @@ async def _do_background_stream( ) await update_response_status(response_id, "completed") - logger.debug( - "Background stream completed", - extra={"response_id": response_id, "total_events": seq}, + logger.info( + "[durable] background stream completed response_id=%s attempt=%d " + "total_events=%d pod=%s", + response_id, + attempt_number, + seq, + _POD_ID, ) async def _run_background_invoke( @@ -717,16 +762,51 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: # threshold. Otherwise a quick follow-up retrieve could hijack a # running pod before it ever writes its first heartbeat. if resp.heartbeat_at is None: - if _age_seconds(resp.created_at) < self._settings.heartbeat_stale_threshold_seconds: + age = _age_seconds(resp.created_at) + if age < self._settings.heartbeat_stale_threshold_seconds: + logger.debug( + "[durable] claim skipped response_id=%s reason=grace_period " + "age=%.1fs threshold=%.1fs", + response_id, + age, + self._settings.heartbeat_stale_threshold_seconds, + ) + return None + else: + hb_age = _age_seconds(resp.heartbeat_at) + if hb_age < self._settings.heartbeat_stale_threshold_seconds: + # Heartbeat is fresh — owner is alive. Common case, keep + # quiet at debug so we don't spam every poll iteration. + logger.debug( + "[durable] claim skipped response_id=%s reason=heartbeat_fresh " + "age=%.1fs threshold=%.1fs", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + ) return None + logger.info( + "[durable] stale heartbeat detected response_id=%s " + "heartbeat_age=%.1fs threshold=%.1fs current_owner=%s", + response_id, + hb_age, + self._settings.heartbeat_stale_threshold_seconds, + resp.owner_pod_id, + ) if resp.original_request is None: # Nothing to replay from — the run predates durability metadata. 
logger.warning( - "Cannot resume %s: no original_request persisted (old row?)", + "[durable] cannot resume response_id=%s reason=no_original_request", response_id, ) return None + logger.info( + "[durable] attempting claim response_id=%s current_attempt=%d new_owner=%s", + response_id, + resp.attempt_number, + _POD_ID, + ) new_attempt = await claim_stale_response( response_id, new_owner_pod_id=_POD_ID, @@ -735,6 +815,10 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: if new_attempt is None: # Someone else owns it, or the row was updated between the read and # the claim. Expected under contention. + logger.info( + "[durable] claim lost response_id=%s (another pod won or row changed)", + response_id, + ) return None # Build a "resume" request: keep everything the handler cares about @@ -761,10 +845,11 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: ) logger.info( - "Claimed stale run %s as attempt %s (pod=%s)", + "[durable] claim succeeded response_id=%s new_attempt=%d pod=%s resume_from_seq=%d", response_id, new_attempt, _POD_ID, + next_seq, ) task = asyncio.create_task( From cbd2b0b107e661cadb94648c0c734cac421626bf Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 21 Apr 2026 00:37:31 +0000 Subject: [PATCH 13/39] Add public durable-resume repair helpers for openai + langchain MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AsyncDatabricksSession.repair(): Dedupes function_call/function_call_output items by call_id and injects synthetic function_call_outputs for orphans left behind by a mid-tool kill. Returns the count injected. No-op when clean. build_tool_resume_repair(messages) in databricks_langchain: Pure helper that returns synthetic ToolMessages for AIMessage.tool_calls in the trailing assistant turn whose paired ToolMessage never landed. Append via the add_messages reducer to satisfy Anthropic's tool_use ⇄ tool_result contract on resume. 
Both were previously in-template workarounds in agent-openai-advanced and agent-langgraph-advanced; moving them into the library so any user of LongRunningAgentServer + AsyncDatabricksSession or AsyncCheckpointSaver gets correct mid-tool crash-resume behavior without template-side fixes. Co-authored-by: Isaac --- .../src/databricks_langchain/__init__.py | 7 +- .../src/databricks_langchain/checkpoint.py | 82 +++++++++++++- .../src/databricks_openai/agents/session.py | 106 ++++++++++++++++++ 3 files changed, 193 insertions(+), 2 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index bfa52f8c..39db806b 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -18,7 +18,11 @@ from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit, UnityCatalogTool from databricks_langchain.chat_models import ChatDatabricks -from databricks_langchain.checkpoint import AsyncCheckpointSaver, CheckpointSaver +from databricks_langchain.checkpoint import ( + AsyncCheckpointSaver, + CheckpointSaver, + build_tool_resume_repair, +) from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent from databricks_langchain.multi_server_mcp_client import ( @@ -48,4 +52,5 @@ "DatabricksMultiServerMCPClient", "DatabricksMCPServer", "MCPServer", + "build_tool_resume_repair", ] diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index bcec679d..924bf93f 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Sequence from databricks.sdk import WorkspaceClient @@ -16,6 +16,86 @@ 
_checkpoint_imports_available = False +try: + from langchain_core.messages import AIMessage, ToolMessage + + _message_imports_available = True +except ImportError: + AIMessage = object # type: ignore + ToolMessage = object # type: ignore + _message_imports_available = False + + +DEFAULT_TOOL_RESUME_REPAIR_OUTPUT = ( + "Tool call was interrupted by a durable resume and did not complete. " + "Please retry if still needed." +) + + +def build_tool_resume_repair( + messages: Sequence[Any], + synthetic_output: str = DEFAULT_TOOL_RESUME_REPAIR_OUTPUT, +) -> list[Any]: + """Build synthetic ``ToolMessage`` responses for orphan tool calls. + + When a LangGraph run is killed mid-tool, the checkpointer preserves the + trailing ``AIMessage.tool_calls`` but the paired ``ToolMessage``s never + land. Replaying that state to the LLM on resume fails because the API + (Anthropic in particular) requires every ``tool_use`` to be immediately + followed by a matching ``tool_result``. + + Walks the trailing assistant turn (the last contiguous block of + ``AIMessage`` / ``ToolMessage``) and returns a synthetic ``ToolMessage`` + for each ``tool_call`` id that lacks a matching + ``ToolMessage.tool_call_id``. Appending the returned list via the + ``add_messages`` reducer restores a valid conversation. + + Example:: + + from databricks_langchain import build_tool_resume_repair + + state = await graph.aget_state(config) + repair = build_tool_resume_repair(state.values.get("messages", [])) + if repair: + await graph.aupdate_state(config, {"messages": repair}) + + Args: + messages: The current ``messages`` list from graph state. + synthetic_output: Text for each injected ``ToolMessage.content``. + + Returns: + A list of ``ToolMessage`` instances (possibly empty). Empty means + the state is already consistent — no repair needed. + """ + if not _message_imports_available or not messages: + return [] + + # Trailing assistant turn: walk backwards until we hit a non-assistant/ + # non-tool message. 
That block is the "pending" turn whose tool_use ↔ + # tool_result pairing we need to enforce. + trailing_start = len(messages) + for i in range(len(messages) - 1, -1, -1): + if isinstance(messages[i], (AIMessage, ToolMessage)): + trailing_start = i + else: + break + + tool_call_ids: list[str] = [] + answered: set[str] = set() + for msg in messages[trailing_start:]: + if isinstance(msg, AIMessage): + for tc in getattr(msg, "tool_calls", None) or []: + tc_id = tc.get("id") if isinstance(tc, dict) else getattr(tc, "id", None) + if tc_id and tc_id not in tool_call_ids: + tool_call_ids.append(tc_id) + elif isinstance(msg, ToolMessage): + tcid = getattr(msg, "tool_call_id", None) + if tcid: + answered.add(tcid) + + orphans = [tc_id for tc_id in tool_call_ids if tc_id not in answered] + return [ToolMessage(tool_call_id=tc_id, content=synthetic_output) for tc_id in orphans] + class CheckpointSaver(PostgresSaver): """ diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 6b12b430..96a140d9 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -35,6 +35,11 @@ async def main(): from threading import Lock from typing import Any, Optional +DEFAULT_REPAIR_SYNTHETIC_OUTPUT = ( + "Tool call was interrupted by a durable resume and did not complete. " + "Please retry if still needed." 
+) + try: from agents.extensions.memory import SQLAlchemySession from databricks.sdk import WorkspaceClient @@ -54,6 +59,21 @@ async def main(): logger = logging.getLogger(__name__) +def _item_get(item: Any, key: str) -> Any: + if isinstance(item, dict): + return item.get(key) + return getattr(item, key, None) + + +def _item_dict(item: Any) -> dict: + """Normalize a session item to a plain dict for re-persistence.""" + if isinstance(item, dict): + return dict(item) + if hasattr(item, "model_dump"): + return item.model_dump() + return dict(item.__dict__) if hasattr(item, "__dict__") else {} + + class AsyncDatabricksSession(SQLAlchemySession): """ Async OpenAI Agents SDK Session implementation for Databricks Lakebase. @@ -168,6 +188,92 @@ def __init__( session_id, ) + async def repair(self, synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT) -> int: + """Reconcile the session so the next LLM call sees a valid conversation. + + Handles two failure modes that durable-resume and client history echo + can introduce: + + 1. **Orphan function_calls.** A kill mid-tool leaves a ``function_call`` + item with no matching ``function_call_output``. The next LLM turn + fails with 'tool_calls must be followed by tool messages'. + + 2. **Duplicate items.** When the client re-sends prior history on a + resumed turn, the same ``call_id`` can land twice. Even with every + call_id eventually having an output, duplicates confuse the API. + + Walks items in chronological order, dedupes ``function_call`` / + ``function_call_output`` by ``call_id``, and injects a synthetic + ``function_call_output`` immediately after any orphan ``function_call``. + No-op when the session is already consistent. + + Args: + synthetic_output: Text used for the synthetic outputs inserted for + orphan tool calls. Defaults to an 'interrupted by resume' + message. + + Returns: + The number of synthetic outputs injected (0 if the session was + already clean). 
+ """ + items = await self.get_items() + if not items: + return 0 + + call_ids_with_output: set[str] = set() + for item in items: + if _item_get(item, "type") == "function_call_output": + cid = _item_get(item, "call_id") + if cid: + call_ids_with_output.add(cid) + + sanitized: list[dict] = [] + seen_calls: set[str] = set() + seen_outputs: set[str] = set() + injected_call_ids: list[str] = [] + + for item in items: + t = _item_get(item, "type") + cid = _item_get(item, "call_id") + if t == "function_call" and cid: + if cid in seen_calls: + continue # drop duplicate + seen_calls.add(cid) + sanitized.append(_item_dict(item)) + if cid not in call_ids_with_output: + sanitized.append( + { + "type": "function_call_output", + "call_id": cid, + "output": synthetic_output, + } + ) + injected_call_ids.append(cid) + elif t == "function_call_output" and cid: + if cid in seen_outputs: + continue # drop duplicate output + seen_outputs.add(cid) + sanitized.append(_item_dict(item)) + else: + sanitized.append(_item_dict(item)) + + if len(sanitized) == len(items) and not injected_call_ids: + return 0 + + logger.info( + "AsyncDatabricksSession.repair session_id=%s original=%d sanitized=%d " + "injected=%d call_ids=%s", + self.session_id, + len(items), + len(sanitized), + len(injected_call_ids), + injected_call_ids, + ) + await self.clear_session() + if sanitized: + await self.add_items(sanitized) + return len(injected_call_ids) + @classmethod def _build_cache_key( cls, From 5d70dde9ac927c439a2b65244e9b40826bd6568a Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 21 Apr 2026 20:20:36 +0000 Subject: [PATCH 14/39] Add pre_model_hook factory; WARN on skipped durability migrations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - databricks_langchain.build_tool_resume_repair_pre_model_hook(): returns a LangGraph pre_model_hook that wraps build_tool_resume_repair. 
Lets templates wire durable-resume recovery via one create_agent arg instead of manual aget_state/aupdate_state surgery in the handler (which also required as_node="tools" to avoid a KeyError: 'model' in the branch re-evaluation). Fires on every model turn; no-op when state is clean. - db.py: promote "skipped migration due to insufficient privilege" from INFO to a single WARN summary at startup. If the DB was previously migrated by a different service principal this is expected, but if our SP genuinely lacks ALTER on a fresh table the claim/heartbeat path will fail later with "column does not exist" — surface the risk clearly. Co-authored-by: Isaac --- .../src/databricks_langchain/__init__.py | 2 + .../src/databricks_langchain/checkpoint.py | 47 ++++++++++++++++++- src/databricks_ai_bridge/long_running/db.py | 19 ++++++-- 3 files changed, 63 insertions(+), 5 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 39db806b..ba7948d8 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -22,6 +22,7 @@ AsyncCheckpointSaver, CheckpointSaver, build_tool_resume_repair, + build_tool_resume_repair_pre_model_hook, ) from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent @@ -53,4 +54,5 @@ "DatabricksMCPServer", "MCPServer", "build_tool_resume_repair", + "build_tool_resume_repair_pre_model_hook", ] diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 924bf93f..ebcd40e3 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Sequence +from typing import Any, Callable, Sequence from 
databricks.sdk import WorkspaceClient @@ -97,6 +97,51 @@ def build_tool_resume_repair( return [ToolMessage(tool_call_id=tc_id, content=synthetic_output) for tc_id in orphans] +def build_tool_resume_repair_pre_model_hook( + synthetic_output: str = DEFAULT_TOOL_RESUME_REPAIR_OUTPUT, +) -> Callable[[dict], dict]: + """Return a LangGraph ``pre_model_hook`` that repairs orphan tool calls. + + Wires ``build_tool_resume_repair`` into the graph as a pre-model hook so + durable-resume recovery happens automatically before every LLM call. Keeps + repair logic off the handler — callers only add one argument to + ``create_agent``. + + Usage:: + + from databricks_langchain import build_tool_resume_repair_pre_model_hook + + agent = create_agent( + model=model, + tools=tools, + checkpointer=checkpointer, + pre_model_hook=build_tool_resume_repair_pre_model_hook(), + ) + + The hook fires on every model turn and is a no-op when state is clean, so + the happy path is free. On a mid-tool crash-resume, it injects synthetic + ``ToolMessage``s for any ``AIMessage.tool_calls`` in the trailing turn + whose paired ``ToolMessage`` never landed. Satisfies Anthropic's + ``tool_use`` ⇄ ``tool_result`` contract without needing manual + ``aupdate_state(..., as_node="tools")`` surgery. + + Args: + synthetic_output: Text for each injected ``ToolMessage.content``. + + Returns: + A callable suitable to pass as ``pre_model_hook`` to + ``langchain.agents.create_agent`` (or ``create_react_agent``). + """ + + def _hook(state: dict) -> dict: + repair = build_tool_resume_repair( + state.get("messages", []), synthetic_output=synthetic_output + ) + return {"messages": repair} if repair else {} + + return _hook + + class CheckpointSaver(PostgresSaver): """ LangGraph PostgresSaver using a Lakebase connection pool. 
diff --git a/src/databricks_ai_bridge/long_running/db.py b/src/databricks_ai_bridge/long_running/db.py index 44ef19fd..aed2c903 100644 --- a/src/databricks_ai_bridge/long_running/db.py +++ b/src/databricks_ai_bridge/long_running/db.py @@ -98,6 +98,7 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) " "WHERE status = 'in_progress'", ) + skipped_migrations: list[str] = [] for stmt in migration_stmts: try: async with _engine.begin() as conn: @@ -105,14 +106,24 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy): except Exception as exc: msg = str(exc).lower() if "insufficientprivilege" in msg or "must be owner" in msg: - logger.info( - "[DB] Skipping migration (not owner, presumed already applied): %s", - stmt.split("\n")[0], - ) + skipped_migrations.append(stmt.split("\n")[0]) continue raise _initialized = True + if skipped_migrations: + # WARN-level summary: if the DB was previously migrated by another SP + # this is fine, but if it's genuinely a new table and our SP lacks + # ALTER, claim/heartbeat queries will fail later with a confusing + # "column does not exist" — surface it clearly at startup. + logger.warning( + "[DB] Skipped %d durability migration(s) due to insufficient " + "privilege — assuming table was already migrated by another " + "service principal. Crash-resume will fail with 'column does " + "not exist' if this assumption is wrong. Skipped: %s", + len(skipped_migrations), + ", ".join(skipped_migrations), + ) logger.info("[DB] Engine and schema ready") From 62df014c7454d4a8999a5ccb3f5fe1d7f7754773 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 21 Apr 2026 20:37:07 +0000 Subject: [PATCH 15/39] Rename pre_model_hook factory to middleware factory langchain.agents.create_agent in 1.x uses AgentMiddleware (middleware=[...]) instead of the older pre_model_hook arg. 
Rename the helper and return an AgentMiddleware instance whose before_model / abefore_model methods run build_tool_resume_repair. Zero behavior change; matches the public API of the current langchain release. Co-authored-by: Isaac --- .../src/databricks_langchain/__init__.py | 4 +- .../src/databricks_langchain/checkpoint.py | 63 ++++++++++++------- 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index ba7948d8..0cc35380 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -22,7 +22,7 @@ AsyncCheckpointSaver, CheckpointSaver, build_tool_resume_repair, - build_tool_resume_repair_pre_model_hook, + build_tool_resume_repair_middleware, ) from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent @@ -54,5 +54,5 @@ "DatabricksMCPServer", "MCPServer", "build_tool_resume_repair", - "build_tool_resume_repair_pre_model_hook", + "build_tool_resume_repair_middleware", ] diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index ebcd40e3..9b20029e 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -97,49 +97,68 @@ def build_tool_resume_repair( return [ToolMessage(tool_call_id=tc_id, content=synthetic_output) for tc_id in orphans] -def build_tool_resume_repair_pre_model_hook( +def build_tool_resume_repair_middleware( synthetic_output: str = DEFAULT_TOOL_RESUME_REPAIR_OUTPUT, -) -> Callable[[dict], dict]: - """Return a LangGraph ``pre_model_hook`` that repairs orphan tool calls. +) -> Any: + """Return a LangChain ``AgentMiddleware`` that repairs orphan tool calls. 
- Wires ``build_tool_resume_repair`` into the graph as a pre-model hook so - durable-resume recovery happens automatically before every LLM call. Keeps - repair logic off the handler — callers only add one argument to - ``create_agent``. + Wires ``build_tool_resume_repair`` into ``langchain.agents.create_agent`` + via its middleware API so durable-resume recovery happens automatically + before every LLM call. Keeps repair logic off the handler — callers only + add one argument to ``create_agent``. Usage:: - from databricks_langchain import build_tool_resume_repair_pre_model_hook + from databricks_langchain import build_tool_resume_repair_middleware agent = create_agent( model=model, tools=tools, checkpointer=checkpointer, - pre_model_hook=build_tool_resume_repair_pre_model_hook(), + middleware=[build_tool_resume_repair_middleware()], ) - The hook fires on every model turn and is a no-op when state is clean, so - the happy path is free. On a mid-tool crash-resume, it injects synthetic - ``ToolMessage``s for any ``AIMessage.tool_calls`` in the trailing turn - whose paired ``ToolMessage`` never landed. Satisfies Anthropic's - ``tool_use`` ⇄ ``tool_result`` contract without needing manual + The middleware's ``before_model`` hook fires on every model turn and is a + no-op when state is clean, so the happy path is free. On a mid-tool + crash-resume, it injects synthetic ``ToolMessage``s for any + ``AIMessage.tool_calls`` in the trailing turn whose paired + ``ToolMessage`` never landed. Satisfies Anthropic's ``tool_use`` ⇄ + ``tool_result`` contract without needing manual ``aupdate_state(..., as_node="tools")`` surgery. Args: synthetic_output: Text for each injected ``ToolMessage.content``. Returns: - A callable suitable to pass as ``pre_model_hook`` to - ``langchain.agents.create_agent`` (or ``create_react_agent``). + An ``AgentMiddleware`` instance suitable for the ``middleware=`` + argument of ``langchain.agents.create_agent``. 
+ + Raises: + ImportError: If ``langchain.agents.middleware.AgentMiddleware`` is + unavailable (older langchain version or extra not installed). """ + try: + from langchain.agents.middleware import AgentMiddleware + except ImportError as exc: + raise ImportError( + "build_tool_resume_repair_middleware requires langchain>=1.0 with " + "the agents extra. Install via `pip install langchain[agents]` or " + "equivalent." + ) from exc + + class ToolResumeRepairMiddleware(AgentMiddleware): + """Repairs orphan tool_use AIMessages before each model invocation.""" + + def before_model(self, state, runtime): # type: ignore[override] + repair = build_tool_resume_repair( + state.get("messages", []), synthetic_output=synthetic_output + ) + return {"messages": repair} if repair else None - def _hook(state: dict) -> dict: - repair = build_tool_resume_repair( - state.get("messages", []), synthetic_output=synthetic_output - ) - return {"messages": repair} if repair else {} + async def abefore_model(self, state, runtime): # type: ignore[override] + return self.before_model(state, runtime) - return _hook + return ToolResumeRepairMiddleware() class CheckpointSaver(PostgresSaver): From 4af26f01bb33df860013d62f93c044d4da0d3e37 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 21 Apr 2026 22:33:20 +0000 Subject: [PATCH 16/39] Remove unused Callable import in checkpoint.py --- integrations/langchain/src/databricks_langchain/checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 9b20029e..05993ba3 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, Sequence +from typing import Any, Sequence from databricks.sdk import WorkspaceClient From 
bc573fc7a812279db926fa2335bccabc3850708e Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 22:20:23 +0000 Subject: [PATCH 17/39] LongRunning: rotate conv_id on resume + full-history input sanitizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move both durability concerns out of user-space and into the server. The LangGraph middleware and OpenAI session.repair() become optional — templates no longer need to install them. Two coordinated changes in server.py: 1. Rotate context.conversation_id on every resume attempt. _try_claim_and_resume now deep-copies original_request, replays input verbatim (instead of blanking to []), drops custom_inputs.thread_id/session_id, and sets context.conversation_id = f"{base}::attempt-{N}" so the handler's SDK helpers resolve to a fresh thread/session. Base anchors on whatever the user actually pinned (thread_id/session_id/conversation_id/response_id); rotation always re-anchors on original_request — no stacking across attempts. Trade-off: attempt N>1 re-runs the LLM on the pre-crash input instead of picking up mid-turn state; strictly safer than inheriting mid-turn checkpoint state that the SDK can't fully repair (notably the LangGraph stream-event attempt-boundary orphan artifact seen in the stress test). 2. Full-history _sanitize_request_input runs before the handler sees the request (both initial POST and resume replay). Drops duplicate call_ids, drops orphan function_call_output items, injects synthetic outputs for function_calls with no output. Walks the whole list — neither the LangGraph middleware (trailing-only) nor session.repair() (session-only) cover mid-history orphans that come in via UI echo. Gated by auto_sanitize_input (default True). 14 new unit tests cover the sanitizer (paired/orphan/duplicate/mid-history/ chat-completions shape), rotation (all four fallback priorities), and the resume-replays-input-not-empty flow. 
Co-authored-by: Isaac --- .../long_running/server.py | 197 +++++++++++++++- .../long_running/settings.py | 4 + .../test_long_running_server.py | 221 ++++++++++++++++-- 3 files changed, 399 insertions(+), 23 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 64482731..af63130b 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -52,6 +52,14 @@ BACKGROUND_KEY = "background" +# Synthetic output injected for an orphaned function_call whose matching +# function_call_output was lost to a pod crash. "interrupted" is the most +# honest label: the LLM decides whether to retry on the next turn. +DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT = ( + "Tool call was interrupted by a server-side event " + "(e.g., pod restart). No result was produced." +) + # One ID per process so heartbeats + claims have a stable owner identity. _POD_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" @@ -126,6 +134,165 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() +def _sanitize_request_input( + request_dict: dict[str, Any], + synthetic_output: str = DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, +) -> dict[str, Any]: + """Reconcile orphaned function_call / function_call_output items in input. + + Walks ``request['input']`` end-to-end (not just the trailing turn) and: + + * drops duplicate ``function_call`` items by ``call_id``; + * drops duplicate or orphan ``function_call_output`` items (no matching + ``function_call`` anywhere in the list); + * injects a synthetic ``function_call_output`` immediately after any + ``function_call`` that has no output, so every ``tool_use`` is paired. + + Also supports chat-completions-shape assistant items + (``{role: "assistant", tool_calls: [...]}``) as declaring call_ids. + + This is a pure transform on a dict — no pydantic round-trip, no DB I/O. 
+ Returns the same dict (mutated in place on the ``input`` key) for caller + convenience. + + Reason we walk the whole history: the LangGraph in-graph middleware can + only repair the trailing turn, but UI-echoed history can carry orphans + from prior crashed turns mid-list. See rotation-findings.md (Test E). + """ + items = request_dict.get("input") + if not isinstance(items, list) or not items: + return request_dict + + declared_call_ids: set[str] = set() + call_ids_with_output: set[str] = set() + for i in items: + if not isinstance(i, dict): + continue + t = i.get("type") + cid = i.get("call_id") + if t == "function_call" and cid: + declared_call_ids.add(cid) + if t == "function_call_output" and cid: + call_ids_with_output.add(cid) + if i.get("role") == "assistant" and isinstance(i.get("tool_calls"), list): + for tc in i["tool_calls"]: + if not isinstance(tc, dict): + continue + tc_id = tc.get("id") or (tc.get("function") or {}).get("id") + if tc_id: + declared_call_ids.add(tc_id) + + sanitized: list[dict[str, Any]] = [] + seen_calls: set[str] = set() + seen_outputs: set[str] = set() + injected = 0 + dropped_orphan_outputs = 0 + dropped_duplicates = 0 + + for item in items: + if not isinstance(item, dict): + sanitized.append(item) + continue + t = item.get("type") + cid = item.get("call_id") + + if t == "function_call" and cid: + if cid in seen_calls: + dropped_duplicates += 1 + continue + seen_calls.add(cid) + sanitized.append(item) + if cid not in call_ids_with_output: + sanitized.append( + { + "type": "function_call_output", + "call_id": cid, + "output": synthetic_output, + } + ) + injected += 1 + elif t == "function_call_output" and cid: + if cid in seen_outputs: + dropped_duplicates += 1 + continue + if cid not in declared_call_ids: + dropped_orphan_outputs += 1 + continue + seen_outputs.add(cid) + sanitized.append(item) + else: + sanitized.append(item) + + if injected or dropped_orphan_outputs or dropped_duplicates: + logger.info( + "[durable] input 
sanitized: injected=%d dropped_orphan_outputs=%d " + "dropped_duplicates=%d original_items=%d final_items=%d", + injected, + dropped_orphan_outputs, + dropped_duplicates, + len(items), + len(sanitized), + ) + + request_dict["input"] = sanitized + return request_dict + + +def _rotate_conversation_id( + request_dict: dict[str, Any], + new_attempt_number: int, + response_id: str, +) -> dict[str, Any]: + """Rotate the conversation anchor to a per-attempt value. + + After a crash, attempt N+1 should see a FRESH checkpointer / session so it + doesn't inherit mid-turn state that the SDK can't repair cleanly (most + notably the LangGraph stream-event attempt-boundary orphan artifact). + The handler's priority chain is: + + 1. custom_inputs.thread_id / session_id (explicit, wins) + 2. context.conversation_id (fallback) + 3. auto-generated (last resort) + + We drop (1), pick the current base anchor, and write ``{base}::attempt-N`` + into (2). The handler then resolves to a fresh key for this attempt while + still being deterministic across retries of the same attempt. + + The LLM sees full turn history via ``original_request.input``, which was + captured at the initial POST — before any attempt ran, so it's clean by + construction. 
+ """ + custom_inputs = request_dict.get("custom_inputs") + if not isinstance(custom_inputs, dict): + custom_inputs = {} + + base_anchor = ( + custom_inputs.get("thread_id") + or custom_inputs.get("session_id") + or (request_dict.get("context") or {}).get("conversation_id") + or response_id + ) + + custom_inputs.pop("thread_id", None) + custom_inputs.pop("session_id", None) + request_dict["custom_inputs"] = custom_inputs + + ctx = request_dict.get("context") or {} + ctx = dict(ctx) + rotated = f"{base_anchor}::attempt-{new_attempt_number}" + ctx["conversation_id"] = rotated + request_dict["context"] = ctx + logger.info( + "[durable] rotated conversation_id for resume response_id=%s " + "attempt=%d base=%s rotated=%s", + response_id, + new_attempt_number, + base_anchor, + rotated, + ) + return request_dict + + def _inject_conversation_id(request_dict: dict[str, Any], response_id: str) -> dict[str, Any]: """Anchor the request to ``response_id`` as its conversation. @@ -350,6 +517,9 @@ async def _handle_invocations_request( data = {k: v for k, v in data.items() if k not in (BACKGROUND_KEY, MLFLOW_STREAM_KEY)} return_trace_id = (get_request_headers().get(RETURN_TRACE_HEADER) or "").lower() == "true" + if self._settings.auto_sanitize_input: + data = _sanitize_request_input(data) + try: request_data = self.validator.validate_and_convert_request(data) except ValueError as e: @@ -821,12 +991,27 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: ) return None - # Build a "resume" request: keep everything the handler cares about - # (custom_inputs carrying thread_id, context.conversation_id), but null - # out input so the handler passes {"messages": []} / [] to its agent. - # This is the single line that makes the design framework-agnostic. - resume_dict = dict(resp.original_request) - resume_dict["input"] = [] + # Build a "resume" request by REPLAYING the original POST's input on a + # ROTATED conversation anchor. 
Two coordinated moves: + # + # 1. Keep original_request["input"] intact — it was captured pre-crash + # and is protocol-clean by construction. The LLM sees full history + # via the input list, not via checkpointer state. + # + # 2. Rotate conversation_id so the handler's SDK helpers resolve to a + # FRESH thread_id / session_id for this attempt. Without this, the + # handler would reload the crashed attempt's mid-turn checkpoint, + # which on LangGraph produces a stream-event orphan artifact at the + # attempt boundary (rotation-findings.md stress test). + # + # Trade-off: attempt N+1 re-runs the LLM from scratch on the replayed + # input (one extra LLM call + any side-effectful tool re-run). This is + # strictly safer than resuming on a checkpointer that may have state + # the SDK can't fully repair. + resume_dict = copy.deepcopy(resp.original_request) + if self._settings.auto_sanitize_input: + resume_dict = _sanitize_request_input(resume_dict) + resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id) resume_request = self.validator.validate_and_convert_request(resume_dict) # Emit a marker so clients can reset any in-flight rendering from the diff --git a/src/databricks_ai_bridge/long_running/settings.py b/src/databricks_ai_bridge/long_running/settings.py index 30f1ad02..4e41b715 100644 --- a/src/databricks_ai_bridge/long_running/settings.py +++ b/src/databricks_ai_bridge/long_running/settings.py @@ -17,6 +17,10 @@ class LongRunningSettings: cleanup_timeout_seconds: float = 7.0 heartbeat_interval_seconds: float = 3.0 heartbeat_stale_threshold_seconds: float = 10.0 + # Walk request.input[] on every request and drop/repair orphaned + # function_call / function_call_output pairs before the handler runs. + # Lets handlers stay framework-idiomatic without carrying repair logic. 
+ auto_sanitize_input: bool = True def __post_init__(self) -> None: if self.task_timeout_seconds <= 0: diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 736bca7d..b851481f 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -16,9 +16,12 @@ from databricks_ai_bridge.long_running.repository import ResponseInfo from databricks_ai_bridge.long_running.server import ( + DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, LongRunningAgentServer, _deferred_mark_failed, _inject_conversation_id, + _rotate_conversation_id, + _sanitize_request_input, _sse_event, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings @@ -932,6 +935,133 @@ async def test_lifespan_not_set_when_db_not_configured(self): # --------------------------------------------------------------------------- +class TestSanitizeRequestInput: + """Full-history orphan walker — catches mid-history orphans that neither + the LangGraph middleware (trailing-only) nor session.repair() (session-only) + cover. 
See rotation-findings.md Test E.""" + + def test_empty_input_is_noop(self): + assert _sanitize_request_input({}) == {} + assert _sanitize_request_input({"input": []}) == {"input": []} + + def test_passes_through_paired_call_and_output(self): + items = [ + {"role": "user", "content": "hi"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "c1", "output": "ok"}, + {"role": "assistant", "content": "done"}, + ] + out = _sanitize_request_input({"input": list(items)}) + assert out["input"] == items + + def test_injects_synthetic_output_for_trailing_orphan_call(self): + items = [ + {"role": "user", "content": "hi"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + ] + out = _sanitize_request_input({"input": items}) + assert len(out["input"]) == 3 + assert out["input"][2] == { + "type": "function_call_output", + "call_id": "c1", + "output": DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, + } + + def test_injects_synthetic_output_for_midhistory_orphan_call(self): + # The case that today's middleware misses (Test E). + items = [ + {"role": "user", "content": "hi"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"role": "user", "content": "different question"}, + ] + out = _sanitize_request_input({"input": items}) + assert len(out["input"]) == 4 + assert out["input"][1]["type"] == "function_call" + assert out["input"][2] == { + "type": "function_call_output", + "call_id": "c1", + "output": DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, + } + assert out["input"][3] == {"role": "user", "content": "different question"} + + def test_drops_orphan_output_with_no_matching_call(self): + # The LangGraph stream-event attempt-boundary artifact. 
+ items = [ + {"role": "user", "content": "hi"}, + {"type": "function_call_output", "call_id": "c-ghost", "output": "x"}, + {"role": "user", "content": "follow-up"}, + ] + out = _sanitize_request_input({"input": items}) + assert out["input"] == [ + {"role": "user", "content": "hi"}, + {"role": "user", "content": "follow-up"}, + ] + + def test_dedupes_duplicate_calls_and_outputs(self): + items = [ + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"type": "function_call_output", "call_id": "c1", "output": "ok"}, + {"type": "function_call_output", "call_id": "c1", "output": "ok"}, + ] + out = _sanitize_request_input({"input": items}) + assert len(out["input"]) == 2 + assert out["input"][0]["type"] == "function_call" + assert out["input"][1]["type"] == "function_call_output" + + def test_recognizes_chat_completions_shape_as_declaring_call_id(self): + # An assistant message with tool_calls counts as "declaring" a call_id, + # so a matching function_call_output further down is NOT dropped. 
+ items = [ + { + "role": "assistant", + "content": [], + "tool_calls": [ + {"id": "tc-1", "type": "function", "function": {"name": "f", "arguments": "{}"}} + ], + }, + {"type": "function_call_output", "call_id": "tc-1", "output": "ok"}, + ] + out = _sanitize_request_input({"input": list(items)}) + assert out["input"] == items + + def test_non_dict_items_pass_through(self): + items = [{"role": "user", "content": "hi"}, "not-a-dict", 42] + out = _sanitize_request_input({"input": list(items)}) + assert out["input"] == items + + +class TestRotateConversationId: + def test_rotate_drops_thread_id_and_sets_rotated_context(self): + r = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert "thread_id" not in out["custom_inputs"] + assert out["custom_inputs"]["user_id"] == "u" + assert out["context"]["conversation_id"] == "t1::attempt-2" + + def test_rotate_drops_session_id(self): + r = {"custom_inputs": {"session_id": "s1"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert "session_id" not in out["custom_inputs"] + assert out["context"]["conversation_id"] == "s1::attempt-2" + + def test_rotate_falls_back_to_context_conversation_id(self): + r = {"custom_inputs": {}, "context": {"conversation_id": "c-abc"}} + out = _rotate_conversation_id(r, new_attempt_number=3, response_id="resp_x") + assert out["context"]["conversation_id"] == "c-abc::attempt-3" + + def test_rotate_falls_back_to_response_id_as_last_resort(self): + r = {"custom_inputs": {}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert out["context"]["conversation_id"] == "resp_x::attempt-2" + + def test_rotate_handles_missing_custom_inputs_key(self): + r = {"context": {"conversation_id": "c-abc"}} + out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") + assert 
out["context"]["conversation_id"] == "c-abc::attempt-2" + assert out["custom_inputs"] == {} + + class TestInjectConversationId: """Anchoring an otherwise-anonymous request to a response_id guarantees a resumed run on a new pod resolves to the same agent-SDK thread/session.""" @@ -1130,10 +1260,11 @@ async def fake_append(response_id, seq, *, item=None, stream_event=None, attempt mock_create_task.assert_called_once() @pytest.mark.asyncio - async def test_resume_passes_empty_input_to_handler(self): - """The single line that makes the design framework-agnostic: input=[] so - LangGraph's checkpointer and databricks-openai Session resume cleanly - (verified by prototypes — see project memory).""" + async def test_resume_replays_input_and_rotates_conversation_id(self): + """Resume must replay original_request.input (not blank it) and rotate + the conversation anchor so the handler resolves to a fresh thread / + session for the new attempt. Prevents the LangGraph stream-event + attempt-boundary orphan artifact (rotation-findings.md).""" with patch(f"{MODULE}.is_db_configured", return_value=False): server = LongRunningAgentServer("ResponsesAgent") from datetime import timedelta @@ -1144,7 +1275,7 @@ async def test_resume_passes_empty_input_to_handler(self): heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), original_request={ "input": [{"role": "user", "content": "hi"}], - "custom_inputs": {"thread_id": "t1"}, + "custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}, }, ) @@ -1172,29 +1303,85 @@ def add_done_callback(self, cb): ): await server._try_claim_and_resume("resp_x", resp) - # We can't execute the coro (mock_run is AsyncMock so awaiting it is fine), - # but we can inspect the scheduled coroutine's frame locals or just await - # the captured coro with proper args. Simpler: check that the resume - # coroutine was built with input=[]. Drive the coroutine so mock_run - # receives the call args. 
assert len(captured_tasks) == 1 coro, _name = captured_tasks[0] await coro mock_run.assert_awaited_once() args, kwargs = mock_run.call_args resume_request = args[1] if len(args) > 1 else kwargs["request_data"] - # resume_request is a ResponsesAgentRequest pydantic object after - # round-tripping through the validator so the handler still gets its - # declared arg type. dumped = ( resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request ) - assert dumped["input"] == [] - # Other request metadata is preserved so the handler can find - # thread_id / conversation_id / user_id. - assert dumped["custom_inputs"]["thread_id"] == "t1" + # Input is REPLAYED (not blanked) — the LLM sees full pre-crash history. + # The MLflow validator normalizes the shape (adds "type": "message" etc.) + # so compare essentials. + assert len(dumped["input"]) == 1 + assert dumped["input"][0]["role"] == "user" + assert dumped["input"][0]["content"] == "hi" + # thread_id was dropped so the handler's priority-2 fallback wins. + assert "thread_id" not in (dumped["custom_inputs"] or {}) + # Other custom_inputs keys are preserved. + assert dumped["custom_inputs"]["user_id"] == "u" + # conversation_id is rotated to a per-attempt value anchored on t1. 
+ assert dumped["context"]["conversation_id"] == "t1::attempt-2" assert kwargs.get("attempt_number") == 2 + @pytest.mark.asyncio + async def test_resume_rotation_anchors_on_context_conversation_id(self): + """When the client didn't pin a thread_id/session_id, rotation uses + the injected context.conversation_id as the base anchor.""" + with patch(f"{MODULE}.is_db_configured", return_value=False): + server = LongRunningAgentServer("ResponsesAgent") + from datetime import timedelta + + resp = _resp_info( + status="in_progress", + created_at=datetime.now(timezone.utc) - timedelta(seconds=300), + heartbeat_at=datetime.now(timezone.utc) - timedelta(seconds=100), + original_request={ + "input": [{"role": "user", "content": "hi"}], + "custom_inputs": {}, + "context": {"conversation_id": "resp_x"}, + }, + ) + + captured_tasks = [] + + def capture_task(coro, *, name=None): + captured_tasks.append((coro, name)) + + class _Fake: + def cancel(self): + pass + + def add_done_callback(self, cb): + pass + + return _Fake() + + with ( + patch(f"{MODULE}.claim_stale_response", new_callable=AsyncMock, return_value=3), + patch(f"{MODULE}.get_messages", new_callable=AsyncMock, return_value=[]), + patch(f"{MODULE}.append_message", new_callable=AsyncMock), + patch("asyncio.create_task", side_effect=capture_task), + patch.object(server, "_run_background_stream", new_callable=AsyncMock) as mock_run, + ): + await server._try_claim_and_resume("resp_x", resp) + + assert len(captured_tasks) == 1 + coro, _name = captured_tasks[0] + await coro + mock_run.assert_awaited_once() + args, kwargs = mock_run.call_args + resume_request = args[1] if len(args) > 1 else kwargs["request_data"] + dumped = ( + resume_request.model_dump() if hasattr(resume_request, "model_dump") else resume_request + ) + # Rotation anchors on the stored context.conversation_id (priority 2). + # Note: re-rotating in a subsequent attempt would re-anchor on the + # ORIGINAL stored value, not the previous rotation — no stacking. 
+ assert dumped["context"]["conversation_id"] == "resp_x::attempt-3" + class TestRetrieveTriggersLazyClaim: @pytest.mark.asyncio From 91360a9e2845f5e1354a1c77df340cb0aa89c2c5 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 22:31:43 +0000 Subject: [PATCH 18/39] Checkpoint saver: read-time repair so middleware stays optional MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AsyncCheckpointSaver.aget_tuple / CheckpointSaver.get_tuple now run build_tool_resume_repair on returned state before handing it to the graph. With the LongRunningAgentServer rotate-on-resume + input-sanitizer landing in the same PR, this closes the last gap: a stable-thread_id client that relies on the checkpointer as cross-turn truth still survives a mid-tool crash, without the caller installing build_tool_resume_repair_middleware. Scenario the unit tests capture: turn 2 crashes mid-tool → checkpoint retains an orphan AIMessage.tool_calls on the stable thread → turn 3 loads that state and would fail the Anthropic tool_use ⇄ tool_result pairing check. Read-time repair injects a synthetic ToolMessage on every load, self-healing via the next checkpoint write at node boundary. - build_tool_resume_repair is idempotent and O(trailing-turn) — no perf regression on the happy path. - Never mutates caller-supplied lists; returns a new CheckpointTuple with a copied channel_values["messages"]. - Gracefully no-ops when langchain_core is not importable (non-agent checkpointer use cases). Middleware and session.repair() remain publicly exported for callers who want explicit control or are building on custom graphs outside create_agent's shape. 
Co-authored-by: Isaac --- .../src/databricks_langchain/checkpoint.py | 59 ++++++++++++++ .../tests/unit_tests/test_checkpoint.py | 79 +++++++++++++++++++ 2 files changed, 138 insertions(+) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 05993ba3..e0a8ba42 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -1,9 +1,13 @@ from __future__ import annotations +import copy +import logging from typing import Any, Sequence from databricks.sdk import WorkspaceClient +logger = logging.getLogger(__name__) + try: from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool from langgraph.checkpoint.postgres import PostgresSaver @@ -161,6 +165,53 @@ async def abefore_model(self, state, runtime): # type: ignore[override] return ToolResumeRepairMiddleware() +def _repair_loaded_checkpoint_tuple(tup: Any) -> Any: + """Return a copy of ``tup`` with orphan tool_calls in its ``messages`` + channel closed by synthetic ``ToolMessage`` s. + + Called on every ``(a)get_tuple`` to make the served checkpoint + protocol-valid (every ``tool_use`` paired with a ``tool_result``) without + requiring callers to install ``build_tool_resume_repair_middleware``. + A kill between the ``model`` and ``tools`` nodes leaves the trailing + ``AIMessage.tool_calls`` unpaired; on the NEXT turn that state would + otherwise leak into the LLM and be rejected by the provider's pairing + check. + + Idempotent — ``build_tool_resume_repair`` is a no-op when state is + already clean. Cheap — the walk is O(trailing-turn), same as the + in-graph middleware. + + Side effect: the synthetic ``ToolMessage`` s added here become part of + the state LangGraph writes on the NEXT node boundary, so the repair + self-heals the DB row over time rather than re-computing on every read. 
+ """ + if tup is None or not _message_imports_available: + return tup + + checkpoint = getattr(tup, "checkpoint", None) + if not isinstance(checkpoint, dict): + return tup + channel_values = checkpoint.get("channel_values") + if not isinstance(channel_values, dict): + return tup + messages = channel_values.get("messages") + if not isinstance(messages, list) or not messages: + return tup + + repair = build_tool_resume_repair(messages) + if not repair: + return tup + + logger.info( + "[durable] checkpoint read-time repair: injected %d synthetic ToolMessage(s)", + len(repair), + ) + new_checkpoint = copy.copy(checkpoint) + new_checkpoint["channel_values"] = dict(channel_values) + new_checkpoint["channel_values"]["messages"] = list(messages) + list(repair) + return tup._replace(checkpoint=new_checkpoint) + + class CheckpointSaver(PostgresSaver): """ LangGraph PostgresSaver using a Lakebase connection pool. @@ -205,6 +256,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._lakebase.close() return False + def get_tuple(self, config): + """Return the checkpoint tuple, with trailing orphan tool_calls paired.""" + return _repair_loaded_checkpoint_tuple(super().get_tuple(config)) + class AsyncCheckpointSaver(AsyncPostgresSaver): """ @@ -252,3 +307,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): """Exit async context manager and close the connection pool.""" await self._lakebase.close() return False + + async def aget_tuple(self, config): + """Return the checkpoint tuple, with trailing orphan tool_calls paired.""" + return _repair_loaded_checkpoint_tuple(await super().aget_tuple(config)) diff --git a/integrations/langchain/tests/unit_tests/test_checkpoint.py b/integrations/langchain/tests/unit_tests/test_checkpoint.py index 09b62a3f..4752e961 100644 --- a/integrations/langchain/tests/unit_tests/test_checkpoint.py +++ b/integrations/langchain/tests/unit_tests/test_checkpoint.py @@ -377,3 +377,82 @@ async def 
test_async_checkpoint_saver_branch_resource_path(monkeypatch): assert "host=auto-db-host" in test_pool.conninfo assert saver._lakebase._is_autoscaling is True + + +class TestReadTimeCheckpointRepair: + """Read-time repair: aget_tuple / get_tuple returns a state where every + trailing ``AIMessage.tool_calls`` is paired with a ``ToolMessage``. Keeps + user-space free of middleware when the app is built on our savers.""" + + def _make_tuple(self, messages): + from collections import namedtuple + + FakeTuple = namedtuple( + "CheckpointTuple", + ["config", "checkpoint", "metadata", "parent_config", "pending_writes"], + ) + return FakeTuple( + config={}, + checkpoint={ + "v": 1, + "id": "ckpt", + "channel_values": {"messages": list(messages)}, + }, + metadata={}, + parent_config=None, + pending_writes=None, + ) + + def test_repairs_trailing_orphan_tool_call(self): + from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + tup = self._make_tuple( + [ + HumanMessage("hi"), + AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]), + ] + ) + repaired = _repair_loaded_checkpoint_tuple(tup) + msgs = repaired.checkpoint["channel_values"]["messages"] + assert len(msgs) == 3 + assert isinstance(msgs[-1], ToolMessage) + assert msgs[-1].tool_call_id == "c1" + + def test_noop_when_state_is_clean(self): + from langchain_core.messages import AIMessage, HumanMessage, ToolMessage + + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + tup = self._make_tuple( + [ + HumanMessage("hi"), + AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]), + ToolMessage(tool_call_id="c1", content="ok"), + AIMessage(content="done"), + ] + ) + repaired = _repair_loaded_checkpoint_tuple(tup) + # No repair added → tuple unchanged. 
+ assert repaired is tup + + def test_none_tuple_passes_through(self): + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + assert _repair_loaded_checkpoint_tuple(None) is None + + def test_does_not_mutate_original_messages_list(self): + from langchain_core.messages import AIMessage, HumanMessage + + from databricks_langchain.checkpoint import _repair_loaded_checkpoint_tuple + + original_messages = [ + HumanMessage("hi"), + AIMessage(content="", tool_calls=[{"id": "c1", "name": "f", "args": {}}]), + ] + tup = self._make_tuple(original_messages) + original_len = len(original_messages) + _repair_loaded_checkpoint_tuple(tup) + # Calling repair must NOT mutate the caller's original list. + assert len(original_messages) == original_len From 5da7dbdd8d791a2aa0e5286d990d4b46823c8ac8 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 22:59:17 +0000 Subject: [PATCH 19/39] Session: auto-repair on get_items so middleware-free templates are safe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AsyncDatabricksSession.get_items() now applies the orphan/duplicate walker in memory before returning. Parallels the LangGraph checkpointer read-time repair landed earlier in this branch. Closes the last gap for middleware-free templates: after PR 195 strips session.repair() from the template, a base-session (stable session_id) crash leaves orphan function_calls in session state that OpenAI's Runner would happily echo to the LLM on the next turn — 400 on tool_calls pairing. - Factor the walk out of repair() into _sanitize_items(items, synthetic). Returns the caller's list unchanged on the happy path so callers can cheaply skip re-persistence. - Add auto_repair / auto_repair_synthetic_output constructor kwargs (default True + the same interrupted-by-resume text). - Override get_items() to apply the filter when auto_repair is on. Runner.get_items() now always sees protocol-valid items. 
- repair() keeps working (bypasses the override to read raw items) and continues to rewrite the DB for callers that want persistent cleanup. 7 new tests cover the sanitizer walker (trailing orphan, mid-history orphan, multi-parallel-orphan, duplicate dedup, orphan-output drop) and the get_items auto-repair path (on/off). Co-authored-by: Isaac --- .../src/databricks_openai/agents/session.py | 188 ++++++++++++------ .../openai/tests/unit_tests/test_session.py | 122 ++++++++++++ 2 files changed, 247 insertions(+), 63 deletions(-) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 96a140d9..e36df9f5 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -74,6 +74,86 @@ def _item_dict(item: Any) -> dict: return dict(item.__dict__) if hasattr(item, "__dict__") else {} +def _sanitize_items( + items: list[Any], + synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT, +) -> list[Any]: + """Return a protocol-valid view of session items. + + Walks items in chronological order, drops duplicate + ``function_call`` / ``function_call_output`` by ``call_id``, drops orphan + ``function_call_output`` items whose originating call is not present, + and injects a synthetic output immediately after any ``function_call`` + whose matching output never landed. + + Shared by ``repair()`` (destructive: rewrites the DB) and + ``get_items()`` (non-destructive: in-memory filter on read). Returning + the original ``items`` untouched when nothing needs repair lets callers + skip writes cheaply on the happy path. 
+ """ + if not items: + return items + + call_ids_with_output: set[str] = set() + declared_call_ids: set[str] = set() + for item in items: + t = _item_get(item, "type") + cid = _item_get(item, "call_id") + if t == "function_call" and cid: + declared_call_ids.add(cid) + if t == "function_call_output" and cid: + call_ids_with_output.add(cid) + + sanitized: list[Any] = [] + seen_calls: set[str] = set() + seen_outputs: set[str] = set() + injected_call_ids: list[str] = [] + dropped_orphan_outputs = 0 + + for item in items: + t = _item_get(item, "type") + cid = _item_get(item, "call_id") + if t == "function_call" and cid: + if cid in seen_calls: + continue + seen_calls.add(cid) + sanitized.append(item) + if cid not in call_ids_with_output: + sanitized.append( + { + "type": "function_call_output", + "call_id": cid, + "output": synthetic_output, + } + ) + injected_call_ids.append(cid) + elif t == "function_call_output" and cid: + if cid in seen_outputs: + continue + if cid not in declared_call_ids: + dropped_orphan_outputs += 1 + continue + seen_outputs.add(cid) + sanitized.append(item) + else: + sanitized.append(item) + + if len(sanitized) == len(items) and not injected_call_ids and not dropped_orphan_outputs: + # Happy path — return the caller's list reference so they can + # cheaply skip any re-persistence. + return items + + logger.info( + "[durable] session items sanitized: injected=%d dropped_orphan_outputs=%d " + "original=%d final=%d", + len(injected_call_ids), + dropped_orphan_outputs, + len(items), + len(sanitized), + ) + return sanitized + + class AsyncDatabricksSession(SQLAlchemySession): """ Async OpenAI Agents SDK Session implementation for Databricks Lakebase. 
@@ -128,6 +208,8 @@ def __init__( sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", use_cached_engine: bool = True, + auto_repair: bool = True, + auto_repair_synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT, **engine_kwargs, ) -> None: """ @@ -162,6 +244,9 @@ def __init__( "Please install with: pip install databricks-openai[memory]" ) + self._auto_repair = auto_repair + self._auto_repair_synthetic_output = auto_repair_synthetic_output + self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, autoscaling_endpoint=autoscaling_endpoint, @@ -188,24 +273,32 @@ def __init__( session_id, ) - async def repair(self, synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT) -> int: - """Reconcile the session so the next LLM call sees a valid conversation. + async def get_items(self, limit: Optional[int] = None) -> list[Any]: + """Return session items, repaired for protocol validity when enabled. - Handles two failure modes that durable-resume and client history echo - can introduce: + When ``auto_repair=True`` (default), the returned list has every + ``function_call`` paired with a ``function_call_output`` — orphans + from a durable-resume crash get a synthetic output appended, and + duplicates get deduped. The underlying DB rows are not modified; + this is a pure in-memory filter, cheap to re-run on every call. - 1. **Orphan function_calls.** A kill mid-tool leaves a ``function_call`` - item with no matching ``function_call_output``. The next LLM turn - fails with 'tool_calls must be followed by tool messages'. + Callers that want the raw persisted items can construct the session + with ``auto_repair=False``, or call ``repair()`` which writes the + sanitized state back to the DB. + """ + items = await super().get_items(limit=limit) + if not self._auto_repair: + return items + return _sanitize_items(items, synthetic_output=self._auto_repair_synthetic_output) - 2. 
**Duplicate items.** When the client re-sends prior history on a - resumed turn, the same ``call_id`` can land twice. Even with every - call_id eventually having an output, duplicates confuse the API. + async def repair(self, synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT) -> int: + """Reconcile the session DB so persisted rows are protocol-valid. - Walks items in chronological order, dedupes ``function_call`` / - ``function_call_output`` by ``call_id``, and injects a synthetic - ``function_call_output`` immediately after any orphan ``function_call``. - No-op when the session is already consistent. + Destructive — rewrites the ``agent_messages`` rows via + ``clear_session()`` + ``add_items(sanitized)``. Callers who only need + a clean view for the next LLM call should rely on ``get_items()``'s + auto-repair instead; ``repair()`` is for one-shot maintenance jobs + or tests that want to assert the DB itself is clean. Args: synthetic_output: Text used for the synthetic outputs inserted for @@ -213,65 +306,34 @@ async def repair(self, synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT) message. Returns: - The number of synthetic outputs injected (0 if the session was - already clean). + The number of synthetic outputs injected (0 if already clean). """ - items = await self.get_items() + # Bypass our auto-repair override so we see the raw items and can + # tell whether the DB is already clean. 
+ items = await super().get_items() if not items: return 0 - - call_ids_with_output: set[str] = set() - for item in items: - if _item_get(item, "type") == "function_call_output": - cid = _item_get(item, "call_id") - if cid: - call_ids_with_output.add(cid) - - sanitized: list[dict] = [] - seen_calls: set[str] = set() - seen_outputs: set[str] = set() - injected_call_ids: list[str] = [] - - for item in items: - t = _item_get(item, "type") - cid = _item_get(item, "call_id") - if t == "function_call" and cid: - if cid in seen_calls: - continue # drop duplicate - seen_calls.add(cid) - sanitized.append(_item_dict(item)) - if cid not in call_ids_with_output: - sanitized.append( - { - "type": "function_call_output", - "call_id": cid, - "output": synthetic_output, - } - ) - injected_call_ids.append(cid) - elif t == "function_call_output" and cid: - if cid in seen_outputs: - continue # drop duplicate output - seen_outputs.add(cid) - sanitized.append(_item_dict(item)) - else: - sanitized.append(_item_dict(item)) - - if len(sanitized) == len(items) and not injected_call_ids: + sanitized = _sanitize_items(items, synthetic_output=synthetic_output) + # When _sanitize_items has nothing to do it returns ``items`` itself. 
+ if sanitized is items: return 0 - + injected_call_ids = [ + _item_get(s, "call_id") + for s in sanitized + if _item_get(s, "type") == "function_call_output" + and _item_get(s, "call_id") not in {_item_get(i, "call_id") for i in items} + ] + sanitized_dicts = [_item_dict(i) for i in sanitized] logger.info( - "AsyncDatabricksSession.repair session_id=%s original=%d sanitized=%d " - "injected=%d call_ids=%s", + "AsyncDatabricksSession.repair session_id=%s original=%d sanitized=%d injected=%d", self.session_id, len(items), - len(sanitized), + len(sanitized_dicts), len(injected_call_ids), - injected_call_ids, ) await self.clear_session() - if sanitized: - await self.add_items(sanitized) + if sanitized_dicts: + await self.add_items(sanitized_dicts) return len(injected_call_ids) @classmethod diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index e3a348ba..1eec0f31 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1287,3 +1287,125 @@ def test_init_branch_resource_path_resolves_host( mock_autoscaling_workspace_client.postgres.list_endpoints.assert_called_once_with( parent="projects/my-project/branches/my-branch" ) + + +class TestSanitizeItems: + """Pure walker that reconciles orphan function_call / function_call_output + items. Shared by both the destructive ``repair()`` path and the read-time + ``get_items()`` filter.""" + + def _items_for(self, *types_and_ids): + # Helper: build items from (type, call_id) tuples. 
+ items = [] + for spec in types_and_ids: + if isinstance(spec, str): + items.append({"role": "user", "content": spec}) + else: + t, cid = spec + items.append( + {"type": t, "call_id": cid, "name": "f", "arguments": "{}"} + if t == "function_call" + else {"type": t, "call_id": cid, "output": "ok"} + ) + return items + + def test_noop_when_clean_returns_same_list(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for( + "hi", + ("function_call", "c1"), + ("function_call_output", "c1"), + "done", + ) + out = _sanitize_items(items) + assert out is items # caller can skip re-persistence + + def test_injects_synthetic_output_for_orphan_call(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for("hi", ("function_call", "c1")) + out = _sanitize_items(items) + assert len(out) == 3 + assert out[-1]["type"] == "function_call_output" + assert out[-1]["call_id"] == "c1" + + def test_injects_for_multiple_orphan_calls(self): + # Scenario the user hit: multiple parallel tool_calls, all orphaned. 
+ from databricks_openai.agents.session import _sanitize_items + + items = self._items_for( + "hi", + ("function_call", "c1"), + ("function_call", "c2"), + ("function_call", "c3"), + ) + out = _sanitize_items(items) + calls = [i for i in out if i.get("type") == "function_call"] + outputs = [i for i in out if i.get("type") == "function_call_output"] + assert len(calls) == 3 + assert len(outputs) == 3 + assert {o["call_id"] for o in outputs} == {"c1", "c2", "c3"} + + def test_drops_orphan_output_with_no_matching_call(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for("hi", ("function_call_output", "ghost")) + out = _sanitize_items(items) + assert all(i.get("type") != "function_call_output" for i in out) + + def test_dedupes_duplicate_calls_and_outputs(self): + from databricks_openai.agents.session import _sanitize_items + + items = self._items_for( + ("function_call", "c1"), + ("function_call", "c1"), + ("function_call_output", "c1"), + ("function_call_output", "c1"), + ) + out = _sanitize_items(items) + assert len(out) == 2 + + +class TestAsyncGetItemsAutoRepair: + """get_items() applies read-time repair when auto_repair=True. Uses a + minimal subclass that bypasses parent SQLAlchemySession init so we can + exercise the override without a DB.""" + + def _fake_session(self, items, auto_repair=True): + from databricks_openai.agents.session import AsyncDatabricksSession, _sanitize_items + + class _FakeSession(AsyncDatabricksSession): + def __init__(self, stored, auto): + # Bypass parent init — only need the auto-repair flags. 
+ self._auto_repair = auto + self._auto_repair_synthetic_output = "INTERRUPTED" + self._stored = stored + + async def get_items(self, limit=None): + items = list(self._stored) + if not self._auto_repair: + return items + return _sanitize_items(items, synthetic_output=self._auto_repair_synthetic_output) + + return _FakeSession(items, auto_repair) + + @pytest.mark.asyncio + async def test_auto_repair_injects_synthetic_outputs(self): + sess = self._fake_session( + [ + {"role": "user", "content": "hi"}, + {"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}, + {"type": "function_call", "call_id": "c2", "name": "f", "arguments": "{}"}, + ] + ) + items = await sess.get_items() + synth = [i for i in items if i.get("output") == "INTERRUPTED"] + assert len(synth) == 2 + + @pytest.mark.asyncio + async def test_auto_repair_off_returns_raw_items(self): + raw = [{"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}] + sess = self._fake_session(list(raw), auto_repair=False) + items = await sess.get_items() + assert items == raw From 51f0a4c721ba2630663930c45c88e361795a333b Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 23:09:15 +0000 Subject: [PATCH 20/39] Stamp custom_inputs.attempt_number on resume so handlers can see retries MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rotation gives each resume attempt a fresh conversation anchor, which means the LLM on attempt 2 sees no signal that it's a retry — it'll re-plan from scratch and can re-run read-only tools it already executed before the crash. Callers asked for a way to detect retries without reading the whole state machine. Zero-opinion breadcrumb: _rotate_conversation_id now stamps custom_inputs["attempt_number"] = N on every resume. Absent on normal first-attempt requests. 
Templates that want retry-aware behavior (e.g., appending "you are resuming a retry" to the system prompt, or gating side-effectful tool execution) can opt in by reading this field; others ignore it and get the current rotation-default behavior. Co-authored-by: Isaac --- src/databricks_ai_bridge/long_running/server.py | 5 +++++ .../databricks_ai_bridge/test_long_running_server.py | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index af63130b..3aad2c08 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -275,6 +275,11 @@ def _rotate_conversation_id( custom_inputs.pop("thread_id", None) custom_inputs.pop("session_id", None) + # Leave a breadcrumb so handlers that care about retry awareness (e.g., + # injecting a "you are resuming a retry" system prompt, or opting-out of + # retry-unsafe tools) can branch on it. Absent from normal first-attempt + # requests — handlers should default to "1" if missing. + custom_inputs["attempt_number"] = new_attempt_number request_dict["custom_inputs"] = custom_inputs ctx = request_dict.get("context") or {} diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index b851481f..a272544d 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1059,7 +1059,16 @@ def test_rotate_handles_missing_custom_inputs_key(self): r = {"context": {"conversation_id": "c-abc"}} out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") assert out["context"]["conversation_id"] == "c-abc::attempt-2" - assert out["custom_inputs"] == {} + # attempt_number breadcrumb is stamped even when custom_inputs was absent. 
+ assert out["custom_inputs"] == {"attempt_number": 2} + + def test_rotate_stamps_attempt_number_breadcrumb(self): + r = {"custom_inputs": {"user_id": "u"}, "context": {}} + out = _rotate_conversation_id(r, new_attempt_number=3, response_id="resp_x") + # Handlers can branch on this to inject retry-aware behavior. + assert out["custom_inputs"]["attempt_number"] == 3 + # Unrelated custom_inputs keys are preserved. + assert out["custom_inputs"]["user_id"] == "u" class TestInjectConversationId: From f3f8eb3015406e9ecb58ed9138d6fcff8cdd5d91 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 23:18:14 +0000 Subject: [PATCH 21/39] Synthetic-output text: informative, scoped, nudges against re-running peers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three sites mint synthetic outputs for orphaned tool calls (server input sanitizer, OpenAI session auto-repair, LangGraph checkpointer read-time repair). The old text ("Tool call was interrupted... retry if still needed") was too generic — dogfood UI testing showed models reading it and rationally concluding "let me re-run the whole sequence from the top to get a clean re-execution", re-invoking peers that had already completed. New text is scoped to the single interrupted call: - notes THIS tool call failed and no result is available - asserts other tool calls' results in history are still valid - suggests re-invoking only this specific tool if still needed Informative rather than directive — leaves room for the model's judgment while removing the ambiguity that drove full-sequence re-runs. All three constants carry the same text so LangGraph and OpenAI stay in sync regardless of which layer (server, session, checkpointer) minted the synthetic output. 
Co-authored-by: Isaac --- .../langchain/src/databricks_langchain/checkpoint.py | 7 +++++-- .../openai/src/databricks_openai/agents/session.py | 7 +++++-- src/databricks_ai_bridge/long_running/server.py | 12 ++++++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index e0a8ba42..95acd1bc 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -31,8 +31,11 @@ DEFAULT_TOOL_RESUME_REPAIR_OUTPUT = ( - "Tool call was interrupted by a durable resume and did not complete. " - "Please retry if still needed." + "[INTERRUPTED] This tool call did not complete due to a server " + "interruption, so no result is available. Other tool calls in the " + "conversation history completed normally and their results remain valid. " + "If the information is still needed, re-invoking only this specific tool " + "is usually sufficient." ) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index e36df9f5..0b7576a7 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -36,8 +36,11 @@ async def main(): from typing import Any, Optional DEFAULT_REPAIR_SYNTHETIC_OUTPUT = ( - "Tool call was interrupted by a durable resume and did not complete. " - "Please retry if still needed." + "[INTERRUPTED] This tool call did not complete due to a server " + "interruption, so no result is available. Other tool calls in the " + "conversation history completed normally and their results remain valid. " + "If the information is still needed, re-invoking only this specific tool " + "is usually sufficient." 
) try: diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 3aad2c08..634534f5 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -53,11 +53,15 @@ BACKGROUND_KEY = "background" # Synthetic output injected for an orphaned function_call whose matching -# function_call_output was lost to a pod crash. "interrupted" is the most -# honest label: the LLM decides whether to retry on the next turn. +# function_call_output was lost to a pod crash. The text is prescriptive: +# without the "do NOT re-invoke tools that already returned" guidance, models +# tend to restart the whole tool sequence on resume. DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT = ( - "Tool call was interrupted by a server-side event " - "(e.g., pod restart). No result was produced." + "[INTERRUPTED] This tool call did not complete due to a server " + "interruption, so no result is available. Other tool calls in the " + "conversation history completed normally and their results remain valid. " + "If the information is still needed, re-invoking only this specific tool " + "is usually sufficient." ) # One ID per process so heartbeats + claims have a stable owner identity. From 40d7e09ef9e579aaffb1e434e14d1cbed672f319 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 23:27:41 +0000 Subject: [PATCH 22/39] Resume inherits prior attempt's completed tool outputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rotation gave each resume attempt a fresh thread but ALSO blanked attempt N+1's visibility into what attempt N already accomplished — the LLM re-planned from just the user's latest message and re-emitted tool calls that had already completed. 
Dogfood UI testing showed this clearly: a "call get_time then deep_research" turn that crashed during deep_research would, on resume, re-run get_time too even though its result had already streamed to the client. _try_claim_and_resume now collects the prior attempt's emitted function_call / function_call_output items from agent_server.messages and appends them to the replayed input[]. The sanitizer's full-history walker then closes the interrupted call with a synthetic "[INTERRUPTED]" output. Net effect on attempt N+1's LLM: user: call get_time then deep_research AI: tool_call c1 (get_time) tool: (get_time result from attempt N) AI: tool_call c2 (deep_research) tool: [INTERRUPTED] This tool call did not complete... The LLM sees attempt N's get_time result as valid history and the informative synthetic output on the interrupted deep_research call, guiding it to re-invoke only the interrupted tool instead of the whole chain. Mechanics: - Read events from the same get_messages() call already used to compute the response.resumed sentinel's sequence_number — no extra DB round-trip. - Filter to events with attempt_number == new_attempt - 1 and item.type in {function_call, function_call_output}. - Append to deep-copied input[] before the existing sanitizer + rotation. Composes with the attempt_number breadcrumb (handlers that want to can still branch on custom_inputs.attempt_number > 1 for retry-aware behavior). 
Co-authored-by: Isaac --- .../long_running/server.py | 77 +++++++++++++++---- .../test_long_running_server.py | 53 +++++++++++++ 2 files changed, 115 insertions(+), 15 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 634534f5..40e611ed 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -138,6 +138,39 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() +def _collect_prior_attempt_tool_events( + messages: list[tuple], prior_attempt_number: int +) -> list[dict]: + """Return ``function_call`` / ``function_call_output`` items that the + given prior attempt already emitted as ``response.output_item.done``. + + Lets the next resume attempt inherit already-completed tool results + instead of starting blank. Without this, the new attempt's LLM re-plans + from just the user's latest message and will re-emit tool calls that + already ran successfully — exactly what ``get_time then deep_research`` + UI testing surfaced when only ``deep_research`` was interrupted. + + ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, + attempt_number)``. Only ``response.output_item.done`` events are + considered; other event types (text deltas, etc.) don't carry the + canonical item shape. 
+ """ + out: list[dict] = [] + for _seq, _item_json, evt, attempt_tag in messages: + if attempt_tag != prior_attempt_number: + continue + if not isinstance(evt, dict): + continue + if evt.get("type") != "response.output_item.done": + continue + item = evt.get("item") + if not isinstance(item, dict): + continue + if item.get("type") in ("function_call", "function_call_output"): + out.append(item) + return out + + def _sanitize_request_input( request_dict: dict[str, Any], synthetic_output: str = DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, @@ -1001,32 +1034,46 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: return None # Build a "resume" request by REPLAYING the original POST's input on a - # ROTATED conversation anchor. Two coordinated moves: + # ROTATED conversation anchor, enriched with the prior attempt's + # already-emitted tool events: # - # 1. Keep original_request["input"] intact — it was captured pre-crash - # and is protocol-clean by construction. The LLM sees full history - # via the input list, not via checkpointer state. + # 1. Carry forward the prior attempt's function_call / function_call_output + # items so the LLM sees what's already been done. Without this, + # attempt N+1 re-plans from just the user's latest message and + # re-emits tool calls that previously completed (e.g. it re-runs + # get_time even though only deep_research was interrupted). The + # interrupted tool's orphan function_call gets a synthetic + # "interrupted" output via the sanitizer below. # # 2. Rotate conversation_id so the handler's SDK helpers resolve to a # FRESH thread_id / session_id for this attempt. Without this, the # handler would reload the crashed attempt's mid-turn checkpoint, # which on LangGraph produces a stream-event orphan artifact at the - # attempt boundary (rotation-findings.md stress test). - # - # Trade-off: attempt N+1 re-runs the LLM from scratch on the replayed - # input (one extra LLM call + any side-effectful tool re-run). 
This is - # strictly safer than resuming on a checkpointer that may have state - # the SDK can't fully repair. + # attempt boundary (rotation-findings.md stress test). Attempt N+1 + # runs on a clean checkpointer; the prior-attempt tool events in + # input[] are the single source of truth for what already ran. + existing = await get_messages(response_id, after_sequence=None) + next_seq = max((s for s, _, _, _ in existing), default=-1) + 1 + prior_tool_events = _collect_prior_attempt_tool_events( + existing, prior_attempt_number=new_attempt - 1 + ) + resume_dict = copy.deepcopy(resp.original_request) + if prior_tool_events: + resume_input = list(resume_dict.get("input") or []) + resume_input.extend(prior_tool_events) + resume_dict["input"] = resume_input + logger.info( + "[durable] resume inherited %d tool-event item(s) from attempt %d " + "response_id=%s", + len(prior_tool_events), + new_attempt - 1, + response_id, + ) if self._settings.auto_sanitize_input: resume_dict = _sanitize_request_input(resume_dict) resume_dict = _rotate_conversation_id(resume_dict, new_attempt, response_id) resume_request = self.validator.validate_and_convert_request(resume_dict) - - # Emit a marker so clients can reset any in-flight rendering from the - # prior attempt before seeing new events. 
- existing = await get_messages(response_id, after_sequence=None) - next_seq = max((s for s, _, _, _ in existing), default=-1) + 1 await append_message( response_id, next_seq, diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index a272544d..9cd31a29 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -18,6 +18,7 @@ from databricks_ai_bridge.long_running.server import ( DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, LongRunningAgentServer, + _collect_prior_attempt_tool_events, _deferred_mark_failed, _inject_conversation_id, _rotate_conversation_id, @@ -1031,6 +1032,58 @@ def test_non_dict_items_pass_through(self): assert out["input"] == items +class TestCollectPriorAttemptToolEvents: + """Gather function_call / function_call_output items emitted during a + prior attempt so the next attempt can inherit already-completed tool + results instead of re-running them from scratch.""" + + def _event(self, seq, attempt, item_type, call_id, output=None): + item = {"type": item_type, "call_id": call_id, "name": "f", "arguments": "{}"} + if output is not None: + item = {"type": item_type, "call_id": call_id, "output": output} + evt = {"type": "response.output_item.done", "item": item} + return (seq, None, evt, attempt) + + def test_filters_to_requested_prior_attempt(self): + messages = [ + self._event(0, 1, "function_call", "c1"), + self._event(1, 1, "function_call_output", "c1", output="ok"), + # attempt 2's events should not be returned when asking for attempt 1. 
+ self._event(2, 2, "function_call", "c2"), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert [i["call_id"] for i in out] == ["c1", "c1"] + assert [i["type"] for i in out] == ["function_call", "function_call_output"] + + def test_only_output_item_done_events_count(self): + noise = ( + 0, + None, + {"type": "response.output_text.delta", "delta": "hi"}, + 1, + ) + messages = [noise, self._event(1, 1, "function_call", "c1")] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert len(out) == 1 + assert out[0]["call_id"] == "c1" + + def test_returns_empty_when_prior_attempt_emitted_no_tool_events(self): + # Attempt 1 was just text, no tool calls. + messages = [ + ( + 0, + None, + { + "type": "response.output_item.done", + "item": {"type": "message", "role": "assistant"}, + }, + 1, + ) + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert out == [] + + class TestRotateConversationId: def test_rotate_drops_thread_id_and_sets_rotated_context(self): r = {"custom_inputs": {"thread_id": "t1", "user_id": "u"}, "context": {}} From 77cd8a878460e94f6dc545b18c570883220030fe Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 23:41:18 +0000 Subject: [PATCH 23/39] Resume inheritance: include completed assistant message items MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend _collect_prior_attempt_tool_events to carry forward response.output_item.done items of type "message" (assistant narrative text) in addition to function_call / function_call_output. Lets the next attempt's LLM see its own prior narration as history so it continues where it left off instead of re-narrating from scratch. Only fully-completed items are recoverable — mid-stream partial text (output_text.delta frames that never reached output_item.done) can't be reassembled from the event log and is lost on crash. 
Limits the allow-list to the three known conversational types (function_call, function_call_output, message) so future item kinds don't auto-flow through without review. Co-authored-by: Isaac --- .../long_running/server.py | 29 +++++++++++------- .../test_long_running_server.py | 30 +++++++++++++++++-- 2 files changed, 45 insertions(+), 14 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 40e611ed..f14eb518 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -138,22 +138,29 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() +_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output", "message") + + def _collect_prior_attempt_tool_events( messages: list[tuple], prior_attempt_number: int ) -> list[dict]: - """Return ``function_call`` / ``function_call_output`` items that the - given prior attempt already emitted as ``response.output_item.done``. + """Return completed conversational items from the given prior attempt. + + Collects ``response.output_item.done`` events whose item is a + ``function_call``, ``function_call_output``, or assistant ``message``. + Letting the next attempt inherit these lets its LLM see the prior + attempt's completed work — tool results AND narrative text — so it can + continue from where things left off instead of re-planning from just + the user's latest message (which caused observed "re-run the whole + chain" behavior and regenerated narration in UI testing). - Lets the next resume attempt inherit already-completed tool results - instead of starting blank. Without this, the new attempt's LLM re-plans - from just the user's latest message and will re-emit tool calls that - already ran successfully — exactly what ``get_time then deep_research`` - UI testing surfaced when only ``deep_research`` was interrupted. 
+ Note: only *fully completed* items are recoverable here. Mid-stream + partial text (``response.output_text.delta`` frames that never reached + ``output_item.done``) is lost by design — the event log doesn't carry + reassemblable partial items. ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, - attempt_number)``. Only ``response.output_item.done`` events are - considered; other event types (text deltas, etc.) don't carry the - canonical item shape. + attempt_number)``. """ out: list[dict] = [] for _seq, _item_json, evt, attempt_tag in messages: @@ -166,7 +173,7 @@ def _collect_prior_attempt_tool_events( item = evt.get("item") if not isinstance(item, dict): continue - if item.get("type") in ("function_call", "function_call_output"): + if item.get("type") in _INHERITABLE_ITEM_TYPES: out.append(item) return out diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 9cd31a29..0370beb8 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1067,15 +1067,39 @@ def test_only_output_item_done_events_count(self): assert len(out) == 1 assert out[0]["call_id"] == "c1" - def test_returns_empty_when_prior_attempt_emitted_no_tool_events(self): - # Attempt 1 was just text, no tool calls. + def test_inherits_assistant_message_items(self): + # Completed assistant text messages inherit too, so the next attempt's + # LLM sees its prior narration and doesn't re-emit it from scratch. 
messages = [ ( 0, None, { "type": "response.output_item.done", - "item": {"type": "message", "role": "assistant"}, + "item": {"type": "message", "role": "assistant", "content": "Let me check"}, + }, + 1, + ), + self._event(1, 1, "function_call", "c1"), + self._event(2, 1, "function_call_output", "c1", output="ok"), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + # 1 message + 1 function_call + 1 function_call_output + assert len(out) == 3 + assert out[0]["type"] == "message" + assert out[1]["type"] == "function_call" + assert out[2]["type"] == "function_call_output" + + def test_skips_unknown_item_types(self): + # Item types outside the allow-list (e.g., future event kinds) are + # dropped — safer than forwarding them to the handler. + messages = [ + ( + 0, + None, + { + "type": "response.output_item.done", + "item": {"type": "reasoning", "content": "think"}, }, 1, ) From 7fecacddc8f30bfa2d9dd27edabd83bbb698b99e Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 23:47:15 +0000 Subject: [PATCH 24/39] Resume inheritance: reassemble mid-stream partial text from deltas MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A text-only crash — a model generating a long response that gets killed mid-token-stream — never emits response.output_item.done for the in-flight assistant message. My prior "only completed items" filter found nothing to inherit and attempt N+1 regenerated from the top. Observed directly in UI testing: 324 output_text.delta events on attempt 1, zero output_item.done, attempt 2 restarted from the beginning. Fix: walk the same event log twice in one pass. Besides the existing output_item.done collection, track in-flight message items by id via output_item.added, accumulate their output_text.delta frames, and emit any never-completed items as synthetic assistant message items at the end. 
The assembled text gives attempt N+1's LLM the partial narration as prior assistant context, letting it continue where the crash cut off instead of restarting. If a .done eventually lands for a tracked id, the in-progress tracker is cleared — no duplicate emission alongside the completed item. Co-authored-by: Isaac --- .../long_running/server.py | 85 +++++++++++++++---- .../test_long_running_server.py | 73 ++++++++++++++++ 2 files changed, 140 insertions(+), 18 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index f14eb518..c7458150 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -144,37 +144,86 @@ def _age_seconds(created_at: datetime) -> float: def _collect_prior_attempt_tool_events( messages: list[tuple], prior_attempt_number: int ) -> list[dict]: - """Return completed conversational items from the given prior attempt. + """Return conversational items the given prior attempt already emitted, + including partial-message reassembly. - Collects ``response.output_item.done`` events whose item is a - ``function_call``, ``function_call_output``, or assistant ``message``. - Letting the next attempt inherit these lets its LLM see the prior - attempt's completed work — tool results AND narrative text — so it can - continue from where things left off instead of re-planning from just - the user's latest message (which caused observed "re-run the whole - chain" behavior and regenerated narration in UI testing). + Two passes combined: - Note: only *fully completed* items are recoverable here. Mid-stream - partial text (``response.output_text.delta`` frames that never reached - ``output_item.done``) is lost by design — the event log doesn't carry - reassemblable partial items. + 1. **Completed items**: ``response.output_item.done`` events for + ``function_call`` / ``function_call_output`` / assistant ``message``. 
+ These are the reliable, self-contained pieces the prior attempt + finished. + + 2. **Partial in-flight messages**: if the crash happened mid-stream on a + text response, the message has ``response.output_item.added`` + + ``response.output_text.delta`` events but no ``done``. We reassemble + the accumulated deltas into a synthetic ``message`` item so the next + attempt's LLM sees where the narration trailed off and can continue. + Without this, a text-only crash leaves the LLM with just the user's + message and it restarts from the top. + + Events are walked in sequence_number order; partial items emit in the + position the attempt started them. ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, attempt_number)``. """ out: list[dict] = [] + # Track in-progress message items by id so we can reassemble their + # deltas. When a matching .done lands, clear the tracker — the completed + # item was already captured by the "done" branch below. + in_progress_text: dict[str, list[str]] = {} + in_progress_order: list[str] = [] + for _seq, _item_json, evt, attempt_tag in messages: if attempt_tag != prior_attempt_number: continue if not isinstance(evt, dict): continue - if evt.get("type") != "response.output_item.done": - continue - item = evt.get("item") - if not isinstance(item, dict): + t = evt.get("type") + + if t == "response.output_item.done": + item = evt.get("item") + if isinstance(item, dict) and item.get("type") in _INHERITABLE_ITEM_TYPES: + out.append(item) + if item.get("type") == "message": + iid = item.get("id") + if iid in in_progress_text: + del in_progress_text[iid] + in_progress_order.remove(iid) + elif t == "response.output_item.added": + item = evt.get("item") + if isinstance(item, dict) and item.get("type") == "message": + iid = item.get("id") + if iid: + in_progress_text.setdefault(iid, []) + if iid not in in_progress_order: + in_progress_order.append(iid) + elif t == "response.output_text.delta": + iid = evt.get("item_id") + delta 
= evt.get("delta") + if iid and isinstance(delta, str) and iid in in_progress_text: + in_progress_text[iid].append(delta) + + # Emit synthetic message items for any never-completed in-progress text. + for iid in in_progress_order: + chunks = in_progress_text.get(iid) or [] + if not chunks: continue - if item.get("type") in _INHERITABLE_ITEM_TYPES: - out.append(item) + partial_text = "".join(chunks) + out.append( + { + "type": "message", + "role": "assistant", + "content": [ + { + "type": "output_text", + "text": partial_text, + "annotations": [], + } + ], + } + ) return out diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 0370beb8..cda99525 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1090,6 +1090,79 @@ def test_inherits_assistant_message_items(self): assert out[1]["type"] == "function_call" assert out[2]["type"] == "function_call_output" + def test_reassembles_partial_text_from_delta_events(self): + # Attempt crashed mid-stream: item.added + deltas but no item.done. + # The collector should synthesize a message item from accumulated deltas + # so the next attempt sees where narration trailed off. + messages = [ + ( + 0, + None, + { + "type": "response.output_item.added", + "item": {"type": "message", "id": "msg_1"}, + }, + 1, + ), + ( + 1, + None, + {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "Hello, "}, + 1, + ), + ( + 2, + None, + {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "world"}, + 1, + ), + # No item.done — crash. 
+ ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert len(out) == 1 + assert out[0]["type"] == "message" + assert out[0]["role"] == "assistant" + assert out[0]["content"][0]["text"] == "Hello, world" + + def test_ignores_partial_text_if_item_eventually_completed(self): + # Deltas streamed, then item.done landed — the completed item is what + # we inherit; the deltas are just intermediate frames. + messages = [ + ( + 0, + None, + { + "type": "response.output_item.added", + "item": {"type": "message", "id": "msg_1"}, + }, + 1, + ), + ( + 1, + None, + {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "Hello"}, + 1, + ), + ( + 2, + None, + { + "type": "response.output_item.done", + "item": { + "type": "message", + "id": "msg_1", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hello, world"}], + }, + }, + 1, + ), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + # Only the completed item — NOT a duplicate from the partial deltas. + assert len(out) == 1 + assert out[0]["content"][0]["text"] == "Hello, world" + def test_skips_unknown_item_types(self): # Item types outside the allow-list (e.g., future event kinds) are # dropped — safer than forwarding them to the handler. From 6ef968f905f2456732ecb2ce683559812168956e Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 22 Apr 2026 23:56:19 +0000 Subject: [PATCH 25/39] Resume inheritance: also reassemble reasoning + function_call arg streams MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the partial-reassembly path beyond message text. Generalises into a dispatch table so adding a new item kind is two entries: a reassembler callable + its expected delta event type(s). - reasoning items: added to the inheritable allow-list. 
Partial reasoning (added + delta frames, no .done) is reassembled from response.reasoning.delta / response.reasoning_summary_text.delta frames into a synthetic reasoning item with a summary_text block. Relevant for thinking-mode models (Claude extended thinking, o1-style) where the reasoning token stream can crash mid-flight. - function_call argument streams: added + function_call_arguments.delta frames reassemble into a function_call item with the accumulated arguments text. If the partial JSON doesn't parse, we fall back to '{}' — the item is still protocol-valid input and the input sanitizer will pair it with a synthetic [INTERRUPTED] output, so the LLM sees "this tool call started but didn't finish" rather than losing the attempt entirely. File_search / code_interpreter / computer_call items are intentionally NOT auto-inherited — they're app-specific and their shapes vary; add to the allow-list if an app needs them. Co-authored-by: Isaac --- .../long_running/server.py | 170 +++++++++++++----- .../test_long_running_server.py | 125 ++++++++++++- 2 files changed, 245 insertions(+), 50 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index c7458150..d984bb9f 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -138,41 +138,110 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output", "message") +_INHERITABLE_ITEM_TYPES = ( + "function_call", + "function_call_output", + "message", + "reasoning", +) + +# Event types whose ``delta`` contributes to a specific in-progress item kind. +# Single dict so a new partial-reassembly type is just two entries: a +# reassembler below + an entry here. 
+_DELTA_EVENTS_BY_ITEM_TYPE = { + "message": {"response.output_text.delta"}, + "reasoning": { + "response.reasoning.delta", + "response.reasoning_summary_text.delta", + }, + "function_call": {"response.function_call_arguments.delta"}, +} + + +def _reassemble_message(text: str, template: dict) -> dict: + return { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": text, "annotations": []}, + ], + } + + +def _reassemble_reasoning(text: str, template: dict) -> dict: + return { + "type": "reasoning", + "summary": [{"type": "summary_text", "text": text}], + } + + +def _reassemble_function_call(arguments_text: str, template: dict) -> dict | None: + """Rebuild a function_call from its added-event template + accumulated + argument-delta text. + + Falls back to ``{}`` if the partial arguments don't parse as JSON, so the + item is still protocol-valid input (the sanitizer will pair it with a + synthetic "[INTERRUPTED]" output — the LLM reads that as "this call was + started but didn't finish" and can re-invoke cleanly). + """ + call_id = template.get("call_id") or template.get("id") + if not call_id: + return None + try: + json.loads(arguments_text) + args = arguments_text + except Exception: + args = "{}" + return { + "type": "function_call", + "call_id": call_id, + "name": template.get("name"), + "arguments": args, + } + + +_PARTIAL_REASSEMBLERS = { + "message": _reassemble_message, + "reasoning": _reassemble_reasoning, + "function_call": _reassemble_function_call, +} def _collect_prior_attempt_tool_events( messages: list[tuple], prior_attempt_number: int ) -> list[dict]: """Return conversational items the given prior attempt already emitted, - including partial-message reassembly. + including partial reassembly of messages, reasoning, and tool-call + argument streams. Two passes combined: 1. **Completed items**: ``response.output_item.done`` events for - ``function_call`` / ``function_call_output`` / assistant ``message``. 
- These are the reliable, self-contained pieces the prior attempt + ``function_call`` / ``function_call_output`` / assistant ``message`` + / ``reasoning``. Reliable self-contained pieces the prior attempt finished. - 2. **Partial in-flight messages**: if the crash happened mid-stream on a - text response, the message has ``response.output_item.added`` + - ``response.output_text.delta`` events but no ``done``. We reassemble - the accumulated deltas into a synthetic ``message`` item so the next - attempt's LLM sees where the narration trailed off and can continue. + 2. **Partial in-flight items**: if the crash happened mid-stream, an + item has ``response.output_item.added`` and a sequence of + type-specific delta frames (``output_text.delta``, + ``reasoning.delta`` / ``reasoning_summary_text.delta``, + ``function_call_arguments.delta``) but never reached ``done``. + We reassemble the deltas into a synthetic item so the next attempt's + LLM sees where narration / reasoning / tool-call args trailed off. Without this, a text-only crash leaves the LLM with just the user's message and it restarts from the top. - Events are walked in sequence_number order; partial items emit in the - position the attempt started them. + Events are walked in sequence_number order; partial items are emitted + last (after all completed items from the same attempt). ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, attempt_number)``. """ out: list[dict] = [] - # Track in-progress message items by id so we can reassemble their - # deltas. When a matching .done lands, clear the tracker — the completed - # item was already captured by the "done" branch below. - in_progress_text: dict[str, list[str]] = {} + # Track in-progress items by their server-assigned id: type + chunks + + # the original added-event item shell (needed for function_call's + # name / call_id, which don't appear on delta frames). 
+ in_progress: dict[str, dict] = {} in_progress_order: list[str] = [] for _seq, _item_json, evt, attempt_tag in messages: @@ -184,46 +253,53 @@ def _collect_prior_attempt_tool_events( if t == "response.output_item.done": item = evt.get("item") - if isinstance(item, dict) and item.get("type") in _INHERITABLE_ITEM_TYPES: + if not isinstance(item, dict): + continue + if item.get("type") in _INHERITABLE_ITEM_TYPES: out.append(item) - if item.get("type") == "message": - iid = item.get("id") - if iid in in_progress_text: - del in_progress_text[iid] - in_progress_order.remove(iid) + # Drop matching in-progress tracker — the completed item is + # authoritative; no duplicate emission from deltas. + iid = item.get("id") + if iid in in_progress: + del in_progress[iid] + in_progress_order.remove(iid) elif t == "response.output_item.added": item = evt.get("item") - if isinstance(item, dict) and item.get("type") == "message": - iid = item.get("id") - if iid: - in_progress_text.setdefault(iid, []) - if iid not in in_progress_order: - in_progress_order.append(iid) - elif t == "response.output_text.delta": + if not isinstance(item, dict): + continue + item_type = item.get("type") + iid = item.get("id") + if iid and item_type in _PARTIAL_REASSEMBLERS: + in_progress[iid] = { + "item_type": item_type, + "chunks": [], + "template": dict(item), + } + if iid not in in_progress_order: + in_progress_order.append(iid) + else: iid = evt.get("item_id") delta = evt.get("delta") - if iid and isinstance(delta, str) and iid in in_progress_text: - in_progress_text[iid].append(delta) + if not (iid and isinstance(delta, str) and iid in in_progress): + continue + expected_events = _DELTA_EVENTS_BY_ITEM_TYPE.get( + in_progress[iid]["item_type"], set() + ) + if t in expected_events: + in_progress[iid]["chunks"].append(delta) - # Emit synthetic message items for any never-completed in-progress text. 
for iid in in_progress_order: - chunks = in_progress_text.get(iid) or [] - if not chunks: + tracked = in_progress.get(iid) + if not tracked or not tracked["chunks"]: continue - partial_text = "".join(chunks) - out.append( - { - "type": "message", - "role": "assistant", - "content": [ - { - "type": "output_text", - "text": partial_text, - "annotations": [], - } - ], - } - ) + reassembler = _PARTIAL_REASSEMBLERS.get(tracked["item_type"]) + if reassembler is None: + continue + joined = "".join(tracked["chunks"]) + reassembled = reassembler(joined, tracked["template"]) + if reassembled is not None: + out.append(reassembled) + return out diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index cda99525..f8b41001 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1163,16 +1163,135 @@ def test_ignores_partial_text_if_item_eventually_completed(self): assert len(out) == 1 assert out[0]["content"][0]["text"] == "Hello, world" + def test_reassembles_partial_reasoning(self): + # Mid-stream reasoning: added + reasoning.delta frames, no .done. + messages = [ + ( + 0, + None, + { + "type": "response.output_item.added", + "item": {"type": "reasoning", "id": "r_1"}, + }, + 1, + ), + ( + 1, + None, + { + "type": "response.reasoning.delta", + "item_id": "r_1", + "delta": "Let me think about ", + }, + 1, + ), + ( + 2, + None, + { + "type": "response.reasoning_summary_text.delta", + "item_id": "r_1", + "delta": "this carefully.", + }, + 1, + ), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert len(out) == 1 + assert out[0]["type"] == "reasoning" + assert out[0]["summary"][0]["text"] == "Let me think about this carefully." + + def test_reassembles_partial_function_call_arguments(self): + # Tool call arg stream started but crash happened before .done. 
+ messages = [ + ( + 0, + None, + { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "id": "fc_1", + "call_id": "call_abc", + "name": "get_weather", + }, + }, + 1, + ), + ( + 1, + None, + { + "type": "response.function_call_arguments.delta", + "item_id": "fc_1", + "delta": '{"city":', + }, + 1, + ), + ( + 2, + None, + { + "type": "response.function_call_arguments.delta", + "item_id": "fc_1", + "delta": ' "Paris"}', + }, + 1, + ), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert len(out) == 1 + assert out[0]["type"] == "function_call" + assert out[0]["call_id"] == "call_abc" + assert out[0]["name"] == "get_weather" + assert out[0]["arguments"] == '{"city": "Paris"}' + + def test_partial_function_call_malformed_args_falls_back_to_empty_object(self): + # Args stream cut mid-JSON — not parseable. Fall back to {} so the + # function_call item stays protocol-valid; the sanitizer still pairs + # it with an INTERRUPTED output. + messages = [ + ( + 0, + None, + { + "type": "response.output_item.added", + "item": { + "type": "function_call", + "id": "fc_2", + "call_id": "call_xyz", + "name": "get_weather", + }, + }, + 1, + ), + ( + 1, + None, + { + "type": "response.function_call_arguments.delta", + "item_id": "fc_2", + "delta": '{"city":', + }, + 1, + ), + ] + out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) + assert len(out) == 1 + assert out[0]["arguments"] == "{}" + assert out[0]["call_id"] == "call_xyz" + def test_skips_unknown_item_types(self): - # Item types outside the allow-list (e.g., future event kinds) are - # dropped — safer than forwarding them to the handler. + # Item types outside the allow-list (e.g., future event kinds like + # file_search_call / code_interpreter_call) are dropped — safer than + # forwarding them to the handler without review. 
messages = [ ( 0, None, { "type": "response.output_item.done", - "item": {"type": "reasoning", "content": "think"}, + "item": {"type": "file_search_call", "results": []}, }, 1, ) From 2f26e26f489473792fa82b2803aa710ec24914bd Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 00:09:48 +0000 Subject: [PATCH 26/39] docs: update server.py docstrings for rotate+replay+inherit resume behavior MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two docstrings still described the older "input=[]" resume contract (before the rotation + replay + prior-attempt-inheritance work landed): - LongRunningAgentServer class docstring: said the handler is re-invoked with input=[] and SDKs load prior progress. Now it's re-invoked with rotated conversation_id + original input + inherited prior-attempt items + sanitizer-paired synthetic [INTERRUPTED] outputs. - _try_claim_and_resume method docstring: same correction. No behavior change — docs-only. Co-authored-by: Isaac --- src/databricks_ai_bridge/long_running/server.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index d984bb9f..8597ba97 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -520,11 +520,12 @@ class LongRunningAgentServer(AgentServer): Durable resume: when ``GET /responses/{id}`` sees an ``in_progress`` run whose owning pod has stopped heartbeating for more than ``heartbeat_stale_threshold_seconds``, the retrieving pod atomically claims - the run and re-invokes the registered handler with ``input=[]`` plus the - stamped ``conversation_id``. Agent SDKs (LangGraph checkpointer, - databricks-openai Session) load prior progress and continue — completed - tool calls are not re-executed. Tools interrupted mid-call may re-run; this - is the accepted best-effort tradeoff. 
+ the run and re-invokes the registered handler with a rotated + ``conversation_id`` (so the agent SDK resolves to a fresh thread/session), + the original request's ``input`` enriched with the prior attempt's already + emitted tool calls / outputs / narrative, and an ``[INTERRUPTED]`` synthetic + output paired with any tool call that didn't finish. Completed work is + preserved; only the interrupted step re-runs. """ _SUPPORTED_AGENT_TYPE = "ResponsesAgent" @@ -1091,8 +1092,10 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: """If ``resp`` is a stale in-progress run, attempt an atomic claim. On success, kick off a new background task that re-invokes the handler - with ``input=[]`` and returns the new ``attempt_number``. On failure - (another pod won, or the run is no longer stale), returns ``None``. + on a rotated conversation anchor with the replayed input enriched by + the prior attempt's emitted items, and returns the new + ``attempt_number``. On failure (another pod won, or the run is no + longer stale), returns ``None``. This is the lazy resume path: triggered by a client retrieve. Pods don't poll for stale work proactively in v1 — if no client ever calls From 68ce27663160873f8a30f06d6ebec2c163acba28 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 00:22:39 +0000 Subject: [PATCH 27/39] Strip PR to bare minimum essentials for final durable-resume contract Remove the exploratory surface that the final design doesn't actually need. The remaining code is the smallest set that makes rotate+replay + inheritance + read-time repair work end-to-end. Dropped: - **build_tool_resume_repair_middleware** (checkpoint.py, __init__.py export). The AsyncCheckpointSaver.aget_tuple read-time repair covers the same case without the middleware API surface. The underlying pure walker (build_tool_resume_repair) stays since both the read-time repair and custom-graph users can import it directly. 
- **AsyncDatabricksSession.repair()** (session.py). Destructive DB rewriter. Redundant now that get_items()'s auto_repair filter returns protocol-valid items on every read. The _sanitize_items helper stays, and _item_dict was removed as unused. - **Reasoning-item inheritance + partial reasoning reassembly** (server.py). Dropped from _INHERITABLE_ITEM_TYPES and the reassemblers dispatch. Templates don't exercise reasoning mode; re-add to the allow-list when someone ships a thinking-mode app. - **function_call argument-stream reassembly** (server.py). Edge case: a crash after output_item.added for a function_call but before args finish streaming. Partial JSON is risky to feed back anyway; we drop it and let the next attempt's LLM re-decide. - **custom_inputs.attempt_number breadcrumb** (server.py). Unused by any template; retry awareness comes from the synthetic [INTERRUPTED] output text on the inherited item, not from this breadcrumb. Kept essentials: - rotate + replay of original_request.input on resume - full-history input sanitizer - AsyncCheckpointSaver.aget_tuple read-time repair - AsyncDatabricksSession.get_items auto-repair - prior-attempt completed-item inheritance (function_call, function_call_output, message) - partial assistant-message text reassembly from output_text.delta frames - [durable] lifecycle logs - debug /_debug/kill_task endpoint (env-gated) Tests drop 4 (reasoning, function_call args parse+fallback, attempt_number breadcrumb). Net: 344 fewer lines; 152 passing tests. 
Co-authored-by: Isaac --- .../src/databricks_langchain/__init__.py | 2 - .../src/databricks_langchain/checkpoint.py | 78 +------- .../src/databricks_openai/agents/session.py | 58 ------ .../long_running/server.py | 175 +++++------------- .../test_long_running_server.py | 129 +------------ 5 files changed, 49 insertions(+), 393 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 0cc35380..39db806b 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -22,7 +22,6 @@ AsyncCheckpointSaver, CheckpointSaver, build_tool_resume_repair, - build_tool_resume_repair_middleware, ) from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent @@ -54,5 +53,4 @@ "DatabricksMCPServer", "MCPServer", "build_tool_resume_repair", - "build_tool_resume_repair_middleware", ] diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 95acd1bc..8e0d31ca 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -104,85 +104,19 @@ def build_tool_resume_repair( return [ToolMessage(tool_call_id=tc_id, content=synthetic_output) for tc_id in orphans] -def build_tool_resume_repair_middleware( - synthetic_output: str = DEFAULT_TOOL_RESUME_REPAIR_OUTPUT, -) -> Any: - """Return a LangChain ``AgentMiddleware`` that repairs orphan tool calls. - - Wires ``build_tool_resume_repair`` into ``langchain.agents.create_agent`` - via its middleware API so durable-resume recovery happens automatically - before every LLM call. Keeps repair logic off the handler — callers only - add one argument to ``create_agent``. 
- - Usage:: - - from databricks_langchain import build_tool_resume_repair_middleware - - agent = create_agent( - model=model, - tools=tools, - checkpointer=checkpointer, - middleware=[build_tool_resume_repair_middleware()], - ) - - The middleware's ``before_model`` hook fires on every model turn and is a - no-op when state is clean, so the happy path is free. On a mid-tool - crash-resume, it injects synthetic ``ToolMessage``s for any - ``AIMessage.tool_calls`` in the trailing turn whose paired - ``ToolMessage`` never landed. Satisfies Anthropic's ``tool_use`` ⇄ - ``tool_result`` contract without needing manual - ``aupdate_state(..., as_node="tools")`` surgery. - - Args: - synthetic_output: Text for each injected ``ToolMessage.content``. - - Returns: - An ``AgentMiddleware`` instance suitable for the ``middleware=`` - argument of ``langchain.agents.create_agent``. - - Raises: - ImportError: If ``langchain.agents.middleware.AgentMiddleware`` is - unavailable (older langchain version or extra not installed). - """ - try: - from langchain.agents.middleware import AgentMiddleware - except ImportError as exc: - raise ImportError( - "build_tool_resume_repair_middleware requires langchain>=1.0 with " - "the agents extra. Install via `pip install langchain[agents]` or " - "equivalent." - ) from exc - - class ToolResumeRepairMiddleware(AgentMiddleware): - """Repairs orphan tool_use AIMessages before each model invocation.""" - - def before_model(self, state, runtime): # type: ignore[override] - repair = build_tool_resume_repair( - state.get("messages", []), synthetic_output=synthetic_output - ) - return {"messages": repair} if repair else None - - async def abefore_model(self, state, runtime): # type: ignore[override] - return self.before_model(state, runtime) - - return ToolResumeRepairMiddleware() - - def _repair_loaded_checkpoint_tuple(tup: Any) -> Any: """Return a copy of ``tup`` with orphan tool_calls in its ``messages`` channel closed by synthetic ``ToolMessage`` s. 
Called on every ``(a)get_tuple`` to make the served checkpoint - protocol-valid (every ``tool_use`` paired with a ``tool_result``) without - requiring callers to install ``build_tool_resume_repair_middleware``. - A kill between the ``model`` and ``tools`` nodes leaves the trailing - ``AIMessage.tool_calls`` unpaired; on the NEXT turn that state would - otherwise leak into the LLM and be rejected by the provider's pairing - check. + protocol-valid (every ``tool_use`` paired with a ``tool_result``) + transparently. A kill between the ``model`` and ``tools`` nodes leaves + the trailing ``AIMessage.tool_calls`` unpaired; on the NEXT turn that + state would otherwise leak into the LLM and be rejected by the + provider's pairing check. Idempotent — ``build_tool_resume_repair`` is a no-op when state is - already clean. Cheap — the walk is O(trailing-turn), same as the - in-graph middleware. + already clean. Cheap — the walk is O(trailing-turn). Side effect: the synthetic ``ToolMessage`` s added here become part of the state LangGraph writes on the NEXT node boundary, so the repair diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 0b7576a7..164b2f86 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -68,15 +68,6 @@ def _item_get(item: Any, key: str) -> Any: return getattr(item, key, None) -def _item_dict(item: Any) -> dict: - """Normalize a session item to a plain dict for re-persistence.""" - if isinstance(item, dict): - return dict(item) - if hasattr(item, "model_dump"): - return item.model_dump() - return dict(item.__dict__) if hasattr(item, "__dict__") else {} - - def _sanitize_items( items: list[Any], synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT, @@ -284,61 +275,12 @@ async def get_items(self, limit: Optional[int] = None) -> list[Any]: from a durable-resume crash get a synthetic 
output appended, and duplicates get deduped. The underlying DB rows are not modified; this is a pure in-memory filter, cheap to re-run on every call. - - Callers that want the raw persisted items can construct the session - with ``auto_repair=False``, or call ``repair()`` which writes the - sanitized state back to the DB. """ items = await super().get_items(limit=limit) if not self._auto_repair: return items return _sanitize_items(items, synthetic_output=self._auto_repair_synthetic_output) - async def repair(self, synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT) -> int: - """Reconcile the session DB so persisted rows are protocol-valid. - - Destructive — rewrites the ``agent_messages`` rows via - ``clear_session()`` + ``add_items(sanitized)``. Callers who only need - a clean view for the next LLM call should rely on ``get_items()``'s - auto-repair instead; ``repair()`` is for one-shot maintenance jobs - or tests that want to assert the DB itself is clean. - - Args: - synthetic_output: Text used for the synthetic outputs inserted for - orphan tool calls. Defaults to an 'interrupted by resume' - message. - - Returns: - The number of synthetic outputs injected (0 if already clean). - """ - # Bypass our auto-repair override so we see the raw items and can - # tell whether the DB is already clean. - items = await super().get_items() - if not items: - return 0 - sanitized = _sanitize_items(items, synthetic_output=synthetic_output) - # When _sanitize_items has nothing to do it returns ``items`` itself. 
- if sanitized is items: - return 0 - injected_call_ids = [ - _item_get(s, "call_id") - for s in sanitized - if _item_get(s, "type") == "function_call_output" - and _item_get(s, "call_id") not in {_item_get(i, "call_id") for i in items} - ] - sanitized_dicts = [_item_dict(i) for i in sanitized] - logger.info( - "AsyncDatabricksSession.repair session_id=%s original=%d sanitized=%d injected=%d", - self.session_id, - len(items), - len(sanitized_dicts), - len(injected_call_ids), - ) - await self.clear_session() - if sanitized_dicts: - await self.add_items(sanitized_dicts) - return len(injected_call_ids) - @classmethod def _build_cache_key( cls, diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 8597ba97..ea972eba 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -138,110 +138,38 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -_INHERITABLE_ITEM_TYPES = ( - "function_call", - "function_call_output", - "message", - "reasoning", -) - -# Event types whose ``delta`` contributes to a specific in-progress item kind. -# Single dict so a new partial-reassembly type is just two entries: a -# reassembler below + an entry here. 
-_DELTA_EVENTS_BY_ITEM_TYPE = { - "message": {"response.output_text.delta"}, - "reasoning": { - "response.reasoning.delta", - "response.reasoning_summary_text.delta", - }, - "function_call": {"response.function_call_arguments.delta"}, -} - - -def _reassemble_message(text: str, template: dict) -> dict: - return { - "type": "message", - "role": "assistant", - "content": [ - {"type": "output_text", "text": text, "annotations": []}, - ], - } - - -def _reassemble_reasoning(text: str, template: dict) -> dict: - return { - "type": "reasoning", - "summary": [{"type": "summary_text", "text": text}], - } - - -def _reassemble_function_call(arguments_text: str, template: dict) -> dict | None: - """Rebuild a function_call from its added-event template + accumulated - argument-delta text. - - Falls back to ``{}`` if the partial arguments don't parse as JSON, so the - item is still protocol-valid input (the sanitizer will pair it with a - synthetic "[INTERRUPTED]" output — the LLM reads that as "this call was - started but didn't finish" and can re-invoke cleanly). - """ - call_id = template.get("call_id") or template.get("id") - if not call_id: - return None - try: - json.loads(arguments_text) - args = arguments_text - except Exception: - args = "{}" - return { - "type": "function_call", - "call_id": call_id, - "name": template.get("name"), - "arguments": args, - } - - -_PARTIAL_REASSEMBLERS = { - "message": _reassemble_message, - "reasoning": _reassemble_reasoning, - "function_call": _reassemble_function_call, -} +_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output", "message") def _collect_prior_attempt_tool_events( messages: list[tuple], prior_attempt_number: int ) -> list[dict]: - """Return conversational items the given prior attempt already emitted, - including partial reassembly of messages, reasoning, and tool-call - argument streams. + """Return conversational items the given prior attempt already emitted. - Two passes combined: + Collects: 1. 
**Completed items**: ``response.output_item.done`` events for - ``function_call`` / ``function_call_output`` / assistant ``message`` - / ``reasoning``. Reliable self-contained pieces the prior attempt - finished. - - 2. **Partial in-flight items**: if the crash happened mid-stream, an - item has ``response.output_item.added`` and a sequence of - type-specific delta frames (``output_text.delta``, - ``reasoning.delta`` / ``reasoning_summary_text.delta``, - ``function_call_arguments.delta``) but never reached ``done``. - We reassemble the deltas into a synthetic item so the next attempt's - LLM sees where narration / reasoning / tool-call args trailed off. - Without this, a text-only crash leaves the LLM with just the user's - message and it restarts from the top. - - Events are walked in sequence_number order; partial items are emitted - last (after all completed items from the same attempt). + ``function_call`` / ``function_call_output`` / assistant ``message``. + Reliable self-contained pieces the prior attempt finished. + + 2. **Partial assistant text**: if the crash happened mid-stream on a + text response, the message has ``response.output_item.added`` + + ``response.output_text.delta`` frames but no ``done``. We reassemble + the deltas into a synthetic ``message`` item so the next attempt's + LLM sees where narration trailed off and can continue. Without this, + a text-only crash leaves the LLM with just the user's message and it + restarts from the top. + + Events are walked in sequence_number order; partial items emit last + (after all completed items from the same attempt). ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, attempt_number)``. """ out: list[dict] = [] - # Track in-progress items by their server-assigned id: type + chunks + - # the original added-event item shell (needed for function_call's - # name / call_id, which don't appear on delta frames). 
- in_progress: dict[str, dict] = {} + # Track in-progress message items by id: accumulated text chunks. Reset + # when a matching .done arrives — that completed item is authoritative. + in_progress_text: dict[str, list[str]] = {} in_progress_order: list[str] = [] for _seq, _item_json, evt, attempt_tag in messages: @@ -253,53 +181,39 @@ def _collect_prior_attempt_tool_events( if t == "response.output_item.done": item = evt.get("item") - if not isinstance(item, dict): - continue - if item.get("type") in _INHERITABLE_ITEM_TYPES: + if isinstance(item, dict) and item.get("type") in _INHERITABLE_ITEM_TYPES: out.append(item) - # Drop matching in-progress tracker — the completed item is - # authoritative; no duplicate emission from deltas. iid = item.get("id") - if iid in in_progress: - del in_progress[iid] + if iid in in_progress_text: + del in_progress_text[iid] in_progress_order.remove(iid) elif t == "response.output_item.added": item = evt.get("item") - if not isinstance(item, dict): - continue - item_type = item.get("type") - iid = item.get("id") - if iid and item_type in _PARTIAL_REASSEMBLERS: - in_progress[iid] = { - "item_type": item_type, - "chunks": [], - "template": dict(item), - } - if iid not in in_progress_order: - in_progress_order.append(iid) - else: + if isinstance(item, dict) and item.get("type") == "message": + iid = item.get("id") + if iid: + in_progress_text.setdefault(iid, []) + if iid not in in_progress_order: + in_progress_order.append(iid) + elif t == "response.output_text.delta": iid = evt.get("item_id") delta = evt.get("delta") - if not (iid and isinstance(delta, str) and iid in in_progress): - continue - expected_events = _DELTA_EVENTS_BY_ITEM_TYPE.get( - in_progress[iid]["item_type"], set() - ) - if t in expected_events: - in_progress[iid]["chunks"].append(delta) + if iid and isinstance(delta, str) and iid in in_progress_text: + in_progress_text[iid].append(delta) for iid in in_progress_order: - tracked = in_progress.get(iid) - if not tracked 
or not tracked["chunks"]: - continue - reassembler = _PARTIAL_REASSEMBLERS.get(tracked["item_type"]) - if reassembler is None: + chunks = in_progress_text.get(iid) or [] + if not chunks: continue - joined = "".join(tracked["chunks"]) - reassembled = reassembler(joined, tracked["template"]) - if reassembled is not None: - out.append(reassembled) - + out.append( + { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "".join(chunks), "annotations": []}, + ], + } + ) return out @@ -444,11 +358,6 @@ def _rotate_conversation_id( custom_inputs.pop("thread_id", None) custom_inputs.pop("session_id", None) - # Leave a breadcrumb so handlers that care about retry awareness (e.g., - # injecting a "you are resuming a retry" system prompt, or opting-out of - # retry-unsafe tools) can branch on it. Absent from normal first-attempt - # requests — handlers should default to "1" if missing. - custom_inputs["attempt_number"] = new_attempt_number request_dict["custom_inputs"] = custom_inputs ctx = request_dict.get("context") or {} diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index f8b41001..b9518427 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1163,124 +1163,6 @@ def test_ignores_partial_text_if_item_eventually_completed(self): assert len(out) == 1 assert out[0]["content"][0]["text"] == "Hello, world" - def test_reassembles_partial_reasoning(self): - # Mid-stream reasoning: added + reasoning.delta frames, no .done. 
- messages = [ - ( - 0, - None, - { - "type": "response.output_item.added", - "item": {"type": "reasoning", "id": "r_1"}, - }, - 1, - ), - ( - 1, - None, - { - "type": "response.reasoning.delta", - "item_id": "r_1", - "delta": "Let me think about ", - }, - 1, - ), - ( - 2, - None, - { - "type": "response.reasoning_summary_text.delta", - "item_id": "r_1", - "delta": "this carefully.", - }, - 1, - ), - ] - out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - assert len(out) == 1 - assert out[0]["type"] == "reasoning" - assert out[0]["summary"][0]["text"] == "Let me think about this carefully." - - def test_reassembles_partial_function_call_arguments(self): - # Tool call arg stream started but crash happened before .done. - messages = [ - ( - 0, - None, - { - "type": "response.output_item.added", - "item": { - "type": "function_call", - "id": "fc_1", - "call_id": "call_abc", - "name": "get_weather", - }, - }, - 1, - ), - ( - 1, - None, - { - "type": "response.function_call_arguments.delta", - "item_id": "fc_1", - "delta": '{"city":', - }, - 1, - ), - ( - 2, - None, - { - "type": "response.function_call_arguments.delta", - "item_id": "fc_1", - "delta": ' "Paris"}', - }, - 1, - ), - ] - out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - assert len(out) == 1 - assert out[0]["type"] == "function_call" - assert out[0]["call_id"] == "call_abc" - assert out[0]["name"] == "get_weather" - assert out[0]["arguments"] == '{"city": "Paris"}' - - def test_partial_function_call_malformed_args_falls_back_to_empty_object(self): - # Args stream cut mid-JSON — not parseable. Fall back to {} so the - # function_call item stays protocol-valid; the sanitizer still pairs - # it with an INTERRUPTED output. 
- messages = [ - ( - 0, - None, - { - "type": "response.output_item.added", - "item": { - "type": "function_call", - "id": "fc_2", - "call_id": "call_xyz", - "name": "get_weather", - }, - }, - 1, - ), - ( - 1, - None, - { - "type": "response.function_call_arguments.delta", - "item_id": "fc_2", - "delta": '{"city":', - }, - 1, - ), - ] - out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - assert len(out) == 1 - assert out[0]["arguments"] == "{}" - assert out[0]["call_id"] == "call_xyz" - def test_skips_unknown_item_types(self): # Item types outside the allow-list (e.g., future event kinds like # file_search_call / code_interpreter_call) are dropped — safer than @@ -1328,16 +1210,7 @@ def test_rotate_handles_missing_custom_inputs_key(self): r = {"context": {"conversation_id": "c-abc"}} out = _rotate_conversation_id(r, new_attempt_number=2, response_id="resp_x") assert out["context"]["conversation_id"] == "c-abc::attempt-2" - # attempt_number breadcrumb is stamped even when custom_inputs was absent. - assert out["custom_inputs"] == {"attempt_number": 2} - - def test_rotate_stamps_attempt_number_breadcrumb(self): - r = {"custom_inputs": {"user_id": "u"}, "context": {}} - out = _rotate_conversation_id(r, new_attempt_number=3, response_id="resp_x") - # Handlers can branch on this to inject retry-aware behavior. - assert out["custom_inputs"]["attempt_number"] == 3 - # Unrelated custom_inputs keys are preserved. - assert out["custom_inputs"]["user_id"] == "u" + assert out["custom_inputs"] == {} class TestInjectConversationId: From c23d9b62db2e336781f7da56bc5e37e075108592 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 00:27:23 +0000 Subject: [PATCH 28/39] Make build_tool_resume_repair internal (rename to _build_tool_resume_repair) The public export existed for users who wanted to install the middleware in custom graphs. 
Since the middleware wrapper is gone (read-time repair in AsyncCheckpointSaver.aget_tuple covers the case transparently for users on our saver), only one internal caller remains: _repair_loaded_checkpoint_tuple. Rename to underscore-prefixed private, drop the databricks_langchain __init__ export and __all__ entry. No behavior change. Co-authored-by: Isaac --- .../src/databricks_langchain/__init__.py | 7 +--- .../src/databricks_langchain/checkpoint.py | 42 ++++++------------- 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/__init__.py b/integrations/langchain/src/databricks_langchain/__init__.py index 39db806b..bfa52f8c 100644 --- a/integrations/langchain/src/databricks_langchain/__init__.py +++ b/integrations/langchain/src/databricks_langchain/__init__.py @@ -18,11 +18,7 @@ from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit, UnityCatalogTool from databricks_langchain.chat_models import ChatDatabricks -from databricks_langchain.checkpoint import ( - AsyncCheckpointSaver, - CheckpointSaver, - build_tool_resume_repair, -) +from databricks_langchain.checkpoint import AsyncCheckpointSaver, CheckpointSaver from databricks_langchain.embeddings import DatabricksEmbeddings from databricks_langchain.genie import GenieAgent from databricks_langchain.multi_server_mcp_client import ( @@ -52,5 +48,4 @@ "DatabricksMultiServerMCPClient", "DatabricksMCPServer", "MCPServer", - "build_tool_resume_repair", ] diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 8e0d31ca..fa5b2a72 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -39,40 +39,24 @@ ) -def build_tool_resume_repair( +def _build_tool_resume_repair( messages: Sequence[Any], synthetic_output: str = DEFAULT_TOOL_RESUME_REPAIR_OUTPUT, ) -> list[Any]: """Build synthetic 
``ToolMessage`` responses for orphan tool calls. - When a LangGraph run is killed mid-tool, the checkpointer preserves the - trailing ``AIMessage.tool_calls`` but the paired ``ToolMessage``s never - land. Replaying that state to the LLM on resume fails because the API - (Anthropic in particular) requires every ``tool_use`` to be immediately - followed by a matching ``tool_result``. + Internal helper used by ``_repair_loaded_checkpoint_tuple``. When a + LangGraph run is killed mid-tool, the checkpointer preserves the + trailing ``AIMessage.tool_calls`` but the paired ``ToolMessage``s + never land. Replaying that state to the LLM fails because the API + (Anthropic in particular) requires every ``tool_use`` to be + immediately followed by a matching ``tool_result``. Walks the trailing assistant turn (the last contiguous block of - ``AIMessage`` / ``ToolMessage``) and returns a synthetic ``ToolMessage`` - for each ``tool_call`` id that lacks a matching - ``ToolMessage.tool_call_id``. Appending the returned list via the - ``add_messages`` reducer restores a valid conversation. - - Example:: - - from databricks_langchain import build_tool_resume_repair - - state = await graph.aget_state(config) - repair = build_tool_resume_repair(state.values.get("messages", [])) - if repair: - await graph.aupdate_state(config, {"messages": repair}) - - Args: - messages: The current ``messages`` list from graph state. - synthetic_output: Text for each injected ``ToolMessage.content``. - - Returns: - A list of ``ToolMessage`` instances (possibly empty). Empty means - the state is already consistent — no repair needed. + ``AIMessage`` / ``ToolMessage``) and returns a synthetic + ``ToolMessage`` for each ``tool_call`` id that lacks a matching + ``ToolMessage.tool_call_id``. The caller appends these to the + ``messages`` channel before the next model call. 
""" if not _message_imports_available or not messages: return [] @@ -115,7 +99,7 @@ def _repair_loaded_checkpoint_tuple(tup: Any) -> Any: state would otherwise leak into the LLM and be rejected by the provider's pairing check. - Idempotent — ``build_tool_resume_repair`` is a no-op when state is + Idempotent — ``_build_tool_resume_repair`` is a no-op when state is already clean. Cheap — the walk is O(trailing-turn). Side effect: the synthetic ``ToolMessage`` s added here become part of @@ -135,7 +119,7 @@ def _repair_loaded_checkpoint_tuple(tup: Any) -> Any: if not isinstance(messages, list) or not messages: return tup - repair = build_tool_resume_repair(messages) + repair = _build_tool_resume_repair(messages) if not repair: return tup From 8bd0718bf1fcce6fc74e2071728c920d28d0132d Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 00:53:26 +0000 Subject: [PATCH 29/39] server: add asyncio.sleep(0) yield point in stream loop OpenAI Agents Runner's stream_events() awaits a queue that drains fast; without an explicit yield point per event, task.cancel() can sit for tens of seconds during a text-heavy stream before propagating. The deep_research tool path doesn't hit this because asyncio.sleep(15) is cancellable by design. Reproduced: 2000-word essay streaming, /_debug/kill_task issued at 12s, task continued running for another 48s after cancel before exiting naturally. After this fix, cancel should propagate within a single event's worth of stream delay. 
Co-authored-by: Isaac --- src/databricks_ai_bridge/long_running/server.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index ea972eba..480c4f42 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -898,6 +898,11 @@ async def _do_background_stream( ) seq += 1 state["seq"] = seq + # Explicit yield so task.cancel() propagates promptly on + # tight event streams. The OpenAI Agents Runner's + # stream_events() awaits a queue that empties fast enough + # that cancellation can sit for tens of seconds without this. + await asyncio.sleep(0) span.set_attribute(SpanAttributeKey.MESSAGE_FORMAT, "openai") span.set_outputs(ResponsesAgent.responses_agent_output_reducer(all_chunks)) From 4d0756e2b7d9cb622a4ab5f4f998dfa73122446e Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 02:10:14 +0000 Subject: [PATCH 30/39] Drop event: line from durable SSE frames MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Match the non-durable stream format (data-only frames, type carried inside the payload). The previous event: prefix was benign for GPT-5 on openai-advanced because GPT-5 runs a single response.created/completed pair per turn — the AI SDK's Databricks provider parses the data, sees the type, and completes cleanly. For Claude via the OpenAI-compatible endpoint, each tool iteration in a multi-tool turn emits its own response.created/response.completed pair; combined with the event: prefix, the provider's state machine fails to emit a clean finish UIMessageChunk and the UI retries forever. The non-durable path never had this because its format is data-only; matching it here is the minimal fix. 
Co-authored-by: Isaac --- src/databricks_ai_bridge/long_running/server.py | 9 +++++++-- tests/databricks_ai_bridge/test_long_running_server.py | 8 ++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 480c4f42..8cd23e99 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -125,9 +125,14 @@ async def _current_attempt(response_id: str) -> int: def _sse_event(event_type: str, data: dict[str, Any] | str) -> str: - """Format an SSE event per Open Responses spec.""" + """Emit ``data:``-only SSE frames. Match the non-durable stream format + so downstream SSE parsers dispatch on the payload's ``type`` field + rather than a leading ``event:`` name line. Claude's multi-response + stream (one response.created/completed pair per tool iteration) plus + the event-name prefix confuses the AI SDK's Databricks provider into + a retry loop.""" payload = data if isinstance(data, str) else json.dumps(data) - return f"event: {event_type}\ndata: {payload}\n\n" + return f"data: {payload}\n\n" def _age_seconds(created_at: datetime) -> float: diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index b9518427..91f6f0b3 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -95,8 +95,8 @@ def _mock_validator(server): class TestSSEEvent: def test_dict_data(self): result = _sse_event("response.created", {"id": "resp_123", "status": "in_progress"}) - assert result.startswith("event: response.created\n") - assert "data: " in result + assert result.startswith("data: ") + assert "event:" not in result assert result.endswith("\n\n") data_line = result.split("data: ")[1].strip() parsed = json.loads(data_line) @@ -104,8 +104,8 @@ def test_dict_data(self): def 
test_string_data(self): result = _sse_event("error", "something went wrong") - assert "event: error\n" in result - assert "data: something went wrong\n\n" in result + assert "event:" not in result + assert result == "data: something went wrong\n\n" class TestLongRunningSettings: From 07ded9dbe1fce0549dbb497ff823a858f6f2839b Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 05:27:39 +0000 Subject: [PATCH 31/39] Inheritance: drop completed message items, preserve only tool pairs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Root cause of stuck durable resume on Claude multi-tool turns: the openai-agents SDK's Anthropic adapter rejects the replayed input with HTTP 400 — "tool_use ids were found without tool_result blocks immediately after". Claude's attempt 1 event stream interleaves function_call → narrative message → function_call_output, and my inheritance preserved that order. When the adapter converts each item to an Anthropic message, the narrative message lands between the assistant's tool_use and the user's tool_result, violating Anthropic's "tool_use immediately followed by tool_result" contract. Attempt 2 therefore fails on its first LLM call and the agent never continues. Fix: remove completed `message` items from the inheritable allow-list so only function_call + function_call_output pairs flow through the replay — and those are naturally paired in the event log, so the Anthropic adapter produces a valid message sequence. Preserved: - Partial mid-stream text reassembly (different code path) — a text-only crash still synthesizes the partial assistant message from output_text.delta frames and inherits it, so Claude's prefill continuation still works. - .done cleanup of the partial tracker fires regardless of whether the completed item is on the inherit list, so a fully-completed message no longer gets falsely "reassembled" as a partial at the end of the walk. 
Co-authored-by: Isaac --- .../long_running/server.py | 36 +++++++++++++------ .../test_long_running_server.py | 30 +++++++++------- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 8cd23e99..a9cdfa67 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -143,30 +143,41 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output", "message") +# Only tool pairs are inherited as completed items. Completed `message` +# items are intentionally excluded — Claude (via openai-agents → Anthropic +# adapter) emits a narrative `message` between each function_call and its +# function_call_output, and preserving that order in the replay produces +# an Anthropic message sequence where `tool_use` is not immediately +# followed by `tool_result`, which the provider rejects (HTTP 400). +# Dropping the narrative messages leaves tool_call/tool_result pairs +# adjacent in the replay. Partial mid-stream text (text-only crash) is a +# separate path below and still reassembled. +_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output") def _collect_prior_attempt_tool_events( messages: list[tuple], prior_attempt_number: int ) -> list[dict]: - """Return conversational items the given prior attempt already emitted. + """Return tool-related items the given prior attempt already emitted. Collects: - 1. **Completed items**: ``response.output_item.done`` events for - ``function_call`` / ``function_call_output`` / assistant ``message``. - Reliable self-contained pieces the prior attempt finished. + 1. **Completed tool pairs**: ``response.output_item.done`` events for + ``function_call`` / ``function_call_output``. 
Completed narrative + ``message`` items are skipped on purpose — see + ``_INHERITABLE_ITEM_TYPES`` for the Anthropic-adapter ordering + reason. 2. **Partial assistant text**: if the crash happened mid-stream on a text response, the message has ``response.output_item.added`` + ``response.output_text.delta`` frames but no ``done``. We reassemble the deltas into a synthetic ``message`` item so the next attempt's LLM sees where narration trailed off and can continue. Without this, - a text-only crash leaves the LLM with just the user's message and it + a text-only crash leaves the LLM with just the user's message and restarts from the top. - Events are walked in sequence_number order; partial items emit last - (after all completed items from the same attempt). + Events are walked in sequence_number order; the partial message, if + any, emits last. ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, attempt_number)``. @@ -186,8 +197,13 @@ def _collect_prior_attempt_tool_events( if t == "response.output_item.done": item = evt.get("item") - if isinstance(item, dict) and item.get("type") in _INHERITABLE_ITEM_TYPES: - out.append(item) + if isinstance(item, dict): + if item.get("type") in _INHERITABLE_ITEM_TYPES: + out.append(item) + # Always clear the partial-text tracker when a matching .done + # arrives, even if the completed item isn't on the inherit + # list. Otherwise a message that fully streamed would still + # get reassembled as a "partial" at the end. 
iid = item.get("id") if iid in in_progress_text: del in_progress_text[iid] diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 91f6f0b3..3cdea1a5 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1067,9 +1067,11 @@ def test_only_output_item_done_events_count(self): assert len(out) == 1 assert out[0]["call_id"] == "c1" - def test_inherits_assistant_message_items(self): - # Completed assistant text messages inherit too, so the next attempt's - # LLM sees its prior narration and doesn't re-emit it from scratch. + def test_completed_assistant_message_items_are_skipped(self): + # Narrative `message` items Claude interleaves between tool calls + # would break Anthropic's "tool_use immediately followed by + # tool_result" rule if inherited. Only function_call/output pairs + # flow through; completed messages are dropped. messages = [ ( 0, @@ -1084,11 +1086,11 @@ def test_inherits_assistant_message_items(self): self._event(2, 1, "function_call_output", "c1", output="ok"), ] out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - # 1 message + 1 function_call + 1 function_call_output - assert len(out) == 3 - assert out[0]["type"] == "message" - assert out[1]["type"] == "function_call" - assert out[2]["type"] == "function_call_output" + # Only the function_call + function_call_output pair — the message + # is intentionally skipped. + assert len(out) == 2 + assert out[0]["type"] == "function_call" + assert out[1]["type"] == "function_call_output" def test_reassembles_partial_text_from_delta_events(self): # Attempt crashed mid-stream: item.added + deltas but no item.done. 
@@ -1125,8 +1127,10 @@ def test_reassembles_partial_text_from_delta_events(self): assert out[0]["content"][0]["text"] == "Hello, world" def test_ignores_partial_text_if_item_eventually_completed(self): - # Deltas streamed, then item.done landed — the completed item is what - # we inherit; the deltas are just intermediate frames. + # Deltas streamed, then item.done landed — since completed message + # items are no longer inherited at all, and the partial reassembly + # only fires when .done is missing, this case produces an empty + # inherited list. messages = [ ( 0, @@ -1159,9 +1163,9 @@ def test_ignores_partial_text_if_item_eventually_completed(self): ), ] out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - # Only the completed item — NOT a duplicate from the partial deltas. - assert len(out) == 1 - assert out[0]["content"][0]["text"] == "Hello, world" + # Completed messages are intentionally NOT inherited, and partial + # reassembly cleared its tracker when .done arrived → empty list. + assert out == [] def test_skips_unknown_item_types(self): # Item types outside the allow-list (e.g., future event kinds like From 7bcb1f379a100297a0d8e287ff3dd4624f33bae1 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 05:30:42 +0000 Subject: [PATCH 32/39] Inheritance: hoist narrative messages after tool pairs instead of dropping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prior fix dropped completed `message` items to avoid Anthropic's "tool_use must be immediately followed by tool_result" 400 on replay. Better alternative: keep the narrative but move it past the tool pairs so each function_call stays adjacent to its function_call_output. 
Collector now accumulates into two buckets while walking events: - tool_items: function_call / function_call_output in event-log order - narrative_items: completed message items (and the partial reassembled one if text crashed mid-stream) in event-log order Returns tool_items + narrative_items. The resulting Anthropic sequence is user → [assistant(tool_use_A) → user(tool_result_A) → ...] → assistant(narrative text). Valid alternation, and the LLM still gets the prior narrative as trailing context for continuation decisions. Co-authored-by: Isaac --- .../long_running/server.py | 69 +++++++++---------- .../test_long_running_server.py | 55 ++++++++++----- 2 files changed, 72 insertions(+), 52 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index a9cdfa67..0c96bd71 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -143,46 +143,42 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -# Only tool pairs are inherited as completed items. Completed `message` -# items are intentionally excluded — Claude (via openai-agents → Anthropic -# adapter) emits a narrative `message` between each function_call and its -# function_call_output, and preserving that order in the replay produces -# an Anthropic message sequence where `tool_use` is not immediately -# followed by `tool_result`, which the provider rejects (HTTP 400). -# Dropping the narrative messages leaves tool_call/tool_result pairs -# adjacent in the replay. Partial mid-stream text (text-only crash) is a -# separate path below and still reassembled. 
-_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output") +_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output", "message") def _collect_prior_attempt_tool_events( messages: list[tuple], prior_attempt_number: int ) -> list[dict]: - """Return tool-related items the given prior attempt already emitted. + """Return conversational items the given prior attempt already emitted, + reordered so the replay is a valid provider message sequence. - Collects: + Collected pieces: 1. **Completed tool pairs**: ``response.output_item.done`` events for - ``function_call`` / ``function_call_output``. Completed narrative - ``message`` items are skipped on purpose — see - ``_INHERITABLE_ITEM_TYPES`` for the Anthropic-adapter ordering - reason. - - 2. **Partial assistant text**: if the crash happened mid-stream on a - text response, the message has ``response.output_item.added`` + - ``response.output_text.delta`` frames but no ``done``. We reassemble - the deltas into a synthetic ``message`` item so the next attempt's - LLM sees where narration trailed off and can continue. Without this, - a text-only crash leaves the LLM with just the user's message and - restarts from the top. - - Events are walked in sequence_number order; the partial message, if - any, emits last. + ``function_call`` + ``function_call_output``. + 2. **Completed narrative messages**: ``response.output_item.done`` + events for assistant ``message`` items. + 3. **Partial assistant text**: if the crash happened mid-stream on a + text response, the message has ``output_item.added`` + + ``output_text.delta`` frames but no ``done``. We reassemble the + deltas into a synthetic ``message`` item so the next attempt's LLM + sees where narration trailed off and can continue. + + The emitted order is (tool pairs in event order) → (narrative messages + in event order) → (partial reassembled message). 
Claude's raw stream + interleaves narrative `message` items between function_call and its + function_call_output, which in Anthropic format would look like + ``assistant(tool_use)`` → ``assistant(text)`` → ``user(tool_result)`` + and trip the provider's "tool_use must be immediately followed by + tool_result" rule (HTTP 400). Hoisting narrative messages to the end + keeps each function_call adjacent to its output and lets the narrative + flow as a trailing assistant block. ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, attempt_number)``. """ - out: list[dict] = [] + tool_items: list[dict] = [] + narrative_items: list[dict] = [] # Track in-progress message items by id: accumulated text chunks. Reset # when a matching .done arrives — that completed item is authoritative. in_progress_text: dict[str, list[str]] = {} @@ -198,12 +194,11 @@ def _collect_prior_attempt_tool_events( if t == "response.output_item.done": item = evt.get("item") if isinstance(item, dict): - if item.get("type") in _INHERITABLE_ITEM_TYPES: - out.append(item) - # Always clear the partial-text tracker when a matching .done - # arrives, even if the completed item isn't on the inherit - # list. Otherwise a message that fully streamed would still - # get reassembled as a "partial" at the end. + itype = item.get("type") + if itype == "message": + narrative_items.append(item) + elif itype in _INHERITABLE_ITEM_TYPES: + tool_items.append(item) iid = item.get("id") if iid in in_progress_text: del in_progress_text[iid] @@ -226,7 +221,7 @@ def _collect_prior_attempt_tool_events( chunks = in_progress_text.get(iid) or [] if not chunks: continue - out.append( + narrative_items.append( { "type": "message", "role": "assistant", @@ -235,7 +230,9 @@ def _collect_prior_attempt_tool_events( ], } ) - return out + # Tool pairs first (each function_call immediately followed by its + # function_call_output in the event-log order), narrative after. 
+ return tool_items + narrative_items def _sanitize_request_input( diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 3cdea1a5..2fc04621 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -1067,30 +1067,50 @@ def test_only_output_item_done_events_count(self): assert len(out) == 1 assert out[0]["call_id"] == "c1" - def test_completed_assistant_message_items_are_skipped(self): - # Narrative `message` items Claude interleaves between tool calls - # would break Anthropic's "tool_use immediately followed by - # tool_result" rule if inherited. Only function_call/output pairs - # flow through; completed messages are dropped. + def test_messages_hoisted_after_tool_pairs(self): + # Claude interleaves narrative `message` items between function_call + # and function_call_output in its event stream. Preserving that + # ordering in the replay would violate Anthropic's "tool_use + # immediately followed by tool_result" rule. Collector hoists all + # narrative messages to the end so tool pairs stay adjacent. messages = [ + self._event(0, 1, "function_call", "c1"), ( - 0, + 1, None, { "type": "response.output_item.done", - "item": {"type": "message", "role": "assistant", "content": "Let me check"}, + "item": {"type": "message", "role": "assistant", "content": "step one"}, }, 1, ), - self._event(1, 1, "function_call", "c1"), self._event(2, 1, "function_call_output", "c1", output="ok"), + self._event(3, 1, "function_call", "c2"), + ( + 4, + None, + { + "type": "response.output_item.done", + "item": {"type": "message", "role": "assistant", "content": "step two"}, + }, + 1, + ), + self._event(5, 1, "function_call_output", "c2", output="ok"), ] out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - # Only the function_call + function_call_output pair — the message - # is intentionally skipped. 
- assert len(out) == 2 - assert out[0]["type"] == "function_call" - assert out[1]["type"] == "function_call_output" + # 2 pairs + 2 messages. + assert len(out) == 6 + assert [i["type"] for i in out] == [ + "function_call", + "function_call_output", + "function_call", + "function_call_output", + "message", + "message", + ] + # call_ids paired up (c1,c1,c2,c2) and narrative in event order. + assert out[0]["call_id"] == "c1" and out[1]["call_id"] == "c1" + assert out[2]["call_id"] == "c2" and out[3]["call_id"] == "c2" def test_reassembles_partial_text_from_delta_events(self): # Attempt crashed mid-stream: item.added + deltas but no item.done. @@ -1163,9 +1183,12 @@ def test_ignores_partial_text_if_item_eventually_completed(self): ), ] out = _collect_prior_attempt_tool_events(messages, prior_attempt_number=1) - # Completed messages are intentionally NOT inherited, and partial - # reassembly cleared its tracker when .done arrived → empty list. - assert out == [] + # Completed message inherits via the narrative bucket; partial + # reassembly cleared its tracker when .done arrived so it does NOT + # also synthesize a duplicate from the deltas. + assert len(out) == 1 + assert out[0]["type"] == "message" + assert out[0]["content"][0]["text"] == "Hello, world" def test_skips_unknown_item_types(self): # Item types outside the allow-list (e.g., future event kinds like From 88815639009b7bd7b852b8879eb2cae527ea0128 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 05:43:00 +0000 Subject: [PATCH 33/39] =?UTF-8?q?Stable=20state=20=E2=80=94=20durable=20ex?= =?UTF-8?q?ecution=20verified=20end-to-end?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Marker commit: paired with app-templates PR #195 final state, this branch is now a correctly working stable baseline for durable execution across both LangGraph and OpenAI advanced templates. 
UI-validated scenarios that pass cleanly: - LangGraph + Claude, single-tool interrupt + resume (checkpointer read-time repair closes orphan tool_uses on stable thread). - LangGraph + Claude, multi-tool interrupt + resume, agent continues. - LangGraph + Claude, text-only mid-stream crash + resume with partial-text continuation (prefill). - OpenAI + GPT-5, single-tool + multi-tool interrupt + resume. - OpenAI + Claude, multi-tool interrupt + resume (previously hit the Anthropic "tool_use must be followed by tool_result" 400; fixed by hoisting inherited narrative messages past the tool pairs). - OpenAI + Claude, text-only mid-stream crash + resume with prefill continuation. - Cross-turn recall after crash-and-resume on both templates. Key pieces in this PR: - LongRunningAgentServer: rotate conversation_id + replay original_request.input + prior-attempt tool-pair inheritance + narrative hoist + partial-text reassembly + full-history input sanitizer. Data-only SSE frames (no event: prefix). - AsyncCheckpointSaver.aget_tuple: read-time orphan repair via build_tool_resume_repair for stable-thread baselines. - AsyncDatabricksSession.get_items: auto_repair sanitizer returns protocol-valid items without touching the DB. - asyncio.sleep(0) yield in stream loop so task.cancel() propagates promptly during tight text streams. Co-authored-by: Isaac From ef8f86ad0182be2b29eaa4aae2541d986c63f74f Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 06:03:02 +0000 Subject: [PATCH 34/39] Refactor: extract shared sanitize_tool_items helper Split the orphan tool-call sanitizer out of server.py / session.py into long_running/repair.py so the server input sanitizer and the OpenAI AsyncDatabricksSession.get_items auto-repair use the same walker (server was ~100 LOC, session was ~78 LOC, now one shared helper). Also: - Rename _INHERITABLE_ITEM_TYPES -> _TOOL_PAIR_TYPES to reflect that message items are hoisted separately, not inherited. 
- Split _collect_prior_attempt_tool_events into three focused helpers: _iter_attempt_events, _extract_completed_items, _reassemble_partial_message. - Restore the Args docstring on LongRunningAgentServer.__init__. No behavior change; all 179 unit tests pass (long_running + langchain checkpoint + openai session). Co-authored-by: Isaac --- .../src/databricks_openai/agents/session.py | 88 +----- .../long_running/repair.py | 137 +++++++++ .../long_running/server.py | 285 ++++++++---------- 3 files changed, 278 insertions(+), 232 deletions(-) create mode 100644 src/databricks_ai_bridge/long_running/repair.py diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 164b2f86..e991e75a 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -51,101 +51,37 @@ async def main(): DEFAULT_TOKEN_CACHE_DURATION_SECONDS, AsyncLakebaseSQLAlchemy, ) + from databricks_ai_bridge.long_running.repair import sanitize_tool_items _session_imports_available = True except ImportError: SQLAlchemySession = object # type: ignore DEFAULT_TOKEN_CACHE_DURATION_SECONDS = None # type: ignore DEFAULT_POOL_RECYCLE_SECONDS = None # type: ignore + sanitize_tool_items = None # type: ignore _session_imports_available = False logger = logging.getLogger(__name__) -def _item_get(item: Any, key: str) -> Any: - if isinstance(item, dict): - return item.get(key) - return getattr(item, key, None) - - def _sanitize_items( items: list[Any], synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT, ) -> list[Any]: - """Return a protocol-valid view of session items. 
- - Walks items in chronological order, drops duplicate - ``function_call`` / ``function_call_output`` by ``call_id``, drops orphan - ``function_call_output`` items whose originating call is not present, - and injects a synthetic output immediately after any ``function_call`` - whose matching output never landed. - - Shared by ``repair()`` (destructive: rewrites the DB) and - ``get_items()`` (non-destructive: in-memory filter on read). Returning - the original ``items`` untouched when nothing needs repair lets callers - skip writes cheaply on the happy path. + """Return a protocol-valid view of session items (thin wrapper around + the shared :func:`sanitize_tool_items` walker). + + Kept as a private alias so existing ``self._sanitize_items`` call sites + in this module stay stable. Behaviour: drop duplicate / orphan tool + items, inject synthetic outputs for unpaired function_calls, return + the caller's list reference unchanged on the happy path so + ``repair()`` can short-circuit re-persistence. 
""" if not items: return items - - call_ids_with_output: set[str] = set() - declared_call_ids: set[str] = set() - for item in items: - t = _item_get(item, "type") - cid = _item_get(item, "call_id") - if t == "function_call" and cid: - declared_call_ids.add(cid) - if t == "function_call_output" and cid: - call_ids_with_output.add(cid) - - sanitized: list[Any] = [] - seen_calls: set[str] = set() - seen_outputs: set[str] = set() - injected_call_ids: list[str] = [] - dropped_orphan_outputs = 0 - - for item in items: - t = _item_get(item, "type") - cid = _item_get(item, "call_id") - if t == "function_call" and cid: - if cid in seen_calls: - continue - seen_calls.add(cid) - sanitized.append(item) - if cid not in call_ids_with_output: - sanitized.append( - { - "type": "function_call_output", - "call_id": cid, - "output": synthetic_output, - } - ) - injected_call_ids.append(cid) - elif t == "function_call_output" and cid: - if cid in seen_outputs: - continue - if cid not in declared_call_ids: - dropped_orphan_outputs += 1 - continue - seen_outputs.add(cid) - sanitized.append(item) - else: - sanitized.append(item) - - if len(sanitized) == len(items) and not injected_call_ids and not dropped_orphan_outputs: - # Happy path — return the caller's list reference so they can - # cheaply skip any re-persistence. - return items - - logger.info( - "[durable] session items sanitized: injected=%d dropped_orphan_outputs=%d " - "original=%d final=%d", - len(injected_call_ids), - dropped_orphan_outputs, - len(items), - len(sanitized), + return sanitize_tool_items( + items, synthetic_output, log_prefix="[durable] session items sanitized" ) - return sanitized class AsyncDatabricksSession(SQLAlchemySession): diff --git a/src/databricks_ai_bridge/long_running/repair.py b/src/databricks_ai_bridge/long_running/repair.py new file mode 100644 index 00000000..de16ca39 --- /dev/null +++ b/src/databricks_ai_bridge/long_running/repair.py @@ -0,0 +1,137 @@ +"""Shared orphan-tool-call repair logic. 
+ +``sanitize_tool_items`` walks a list of Responses-API-style items and +reconciles orphan / duplicate ``function_call`` / ``function_call_output`` +items. Used by: + +* the server-side input sanitizer in :mod:`...long_running.server`, which + runs on every request before the handler is invoked; and +* the OpenAI :class:`AsyncDatabricksSession` ``get_items`` auto-repair, + which returns protocol-valid items without touching the underlying DB. + +The LangChain checkpointer has its own repair path +(``_build_tool_resume_repair``) that operates on ``AIMessage`` / +``ToolMessage`` shapes rather than the dict items here. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +def _default_item_get(item: Any, key: str) -> Any: + if isinstance(item, dict): + return item.get(key) + return getattr(item, key, None) + + +def sanitize_tool_items( + items: list[Any], + synthetic_output: str, + *, + item_get: Callable[[Any, str], Any] = _default_item_get, + log_prefix: str = "[durable] items sanitized", +) -> list[Any]: + """Return a protocol-valid view of ``items``. + + In order: + + * drops duplicate ``function_call`` items by ``call_id``; + * drops duplicate or orphan ``function_call_output`` items (no matching + ``function_call`` anywhere in the list); + * injects a synthetic ``function_call_output`` immediately after any + ``function_call`` that has no output in the list. + + Also recognises chat-completions-shape ``{role: assistant, tool_calls: + [...]}`` items as declaring call_ids, so mixed-shape histories don't + trip the orphan check. + + Returns the caller's ``items`` reference unchanged on the happy path so + downstream can skip any re-persistence cheaply. + + The ``synthetic_output`` text is passed in by the caller — each caller + owns its own copy of the string so product decisions about wording + stay scoped to the durable-resume path they belong to. 
+ + ``item_get`` lets session-style objects (ORM rows with attribute + access) reuse this walker; defaults to plain dict ``.get``. + """ + if not items: + return items + + declared_call_ids: set[str] = set() + call_ids_with_output: set[str] = set() + for item in items: + t = item_get(item, "type") + cid = item_get(item, "call_id") + if t == "function_call" and cid: + declared_call_ids.add(cid) + if t == "function_call_output" and cid: + call_ids_with_output.add(cid) + # Chat-completions shape: assistant message with tool_calls. + if item_get(item, "role") == "assistant": + tool_calls = item_get(item, "tool_calls") + if isinstance(tool_calls, list): + for tc in tool_calls: + if not isinstance(tc, dict): + continue + tc_id = tc.get("id") or (tc.get("function") or {}).get("id") + if tc_id: + declared_call_ids.add(tc_id) + + sanitized: list[Any] = [] + seen_calls: set[str] = set() + seen_outputs: set[str] = set() + injected = 0 + dropped_orphan_outputs = 0 + dropped_duplicates = 0 + + for item in items: + t = item_get(item, "type") + cid = item_get(item, "call_id") + if t == "function_call" and cid: + if cid in seen_calls: + dropped_duplicates += 1 + continue + seen_calls.add(cid) + sanitized.append(item) + if cid not in call_ids_with_output: + sanitized.append( + { + "type": "function_call_output", + "call_id": cid, + "output": synthetic_output, + } + ) + injected += 1 + elif t == "function_call_output" and cid: + if cid in seen_outputs: + dropped_duplicates += 1 + continue + if cid not in declared_call_ids: + dropped_orphan_outputs += 1 + continue + seen_outputs.add(cid) + sanitized.append(item) + else: + sanitized.append(item) + + if not (injected or dropped_orphan_outputs or dropped_duplicates): + # Happy path: hand back the original list so callers can skip + # re-persistence by identity comparison (``sanitized is items``). 
+ return items + + logger.info( + "%s: injected=%d dropped_orphan_outputs=%d dropped_duplicates=%d " + "original=%d final=%d", + log_prefix, + injected, + dropped_orphan_outputs, + dropped_duplicates, + len(items), + len(sanitized), + ) + return sanitized diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 0c96bd71..d0933eb1 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -35,6 +35,7 @@ from mlflow.tracing.constant import SpanAttributeKey from databricks_ai_bridge.long_running.db import dispose_db, init_db, is_db_configured +from databricks_ai_bridge.long_running.repair import sanitize_tool_items from databricks_ai_bridge.long_running.repository import ( append_message, claim_stale_response, @@ -143,67 +144,62 @@ def _age_seconds(created_at: datetime) -> float: return (now - created_at).total_seconds() -_INHERITABLE_ITEM_TYPES = ("function_call", "function_call_output", "message") +# Tool-pair items are collected into the main inheritance bucket and +# preserved in event-log order. Narrative ``message`` items are routed +# separately in the collector (see ``_collect_prior_attempt_tool_events``) +# so they can be hoisted past the tool pairs for Anthropic adapter +# compatibility — adding ``"message"`` here would break that hoist. +_TOOL_PAIR_TYPES = ("function_call", "function_call_output") -def _collect_prior_attempt_tool_events( - messages: list[tuple], prior_attempt_number: int -) -> list[dict]: - """Return conversational items the given prior attempt already emitted, - reordered so the replay is a valid provider message sequence. - - Collected pieces: - - 1. **Completed tool pairs**: ``response.output_item.done`` events for - ``function_call`` + ``function_call_output``. - 2. **Completed narrative messages**: ``response.output_item.done`` - events for assistant ``message`` items. - 3. 
**Partial assistant text**: if the crash happened mid-stream on a - text response, the message has ``output_item.added`` + - ``output_text.delta`` frames but no ``done``. We reassemble the - deltas into a synthetic ``message`` item so the next attempt's LLM - sees where narration trailed off and can continue. - - The emitted order is (tool pairs in event order) → (narrative messages - in event order) → (partial reassembled message). Claude's raw stream - interleaves narrative `message` items between function_call and its - function_call_output, which in Anthropic format would look like - ``assistant(tool_use)`` → ``assistant(text)`` → ``user(tool_result)`` - and trip the provider's "tool_use must be immediately followed by - tool_result" rule (HTTP 400). Hoisting narrative messages to the end - keeps each function_call adjacent to its output and lets the narrative - flow as a trailing assistant block. - - ``messages`` is the repository's tuples: ``(seq, item_json, stream_event, - attempt_number)``. - """ - tool_items: list[dict] = [] - narrative_items: list[dict] = [] - # Track in-progress message items by id: accumulated text chunks. Reset - # when a matching .done arrives — that completed item is authoritative. - in_progress_text: dict[str, list[str]] = {} - in_progress_order: list[str] = [] +def _iter_attempt_events(messages: list[tuple], attempt: int): + """Yield ``(event_type, event_dict)`` pairs for the given attempt. + Skips rows from other attempts and non-dict event payloads so callers + can write single-concern walkers without repeating the same filter. 
+ """ for _seq, _item_json, evt, attempt_tag in messages: - if attempt_tag != prior_attempt_number: + if attempt_tag != attempt: continue if not isinstance(evt, dict): continue - t = evt.get("type") + yield evt.get("type"), evt - if t == "response.output_item.done": - item = evt.get("item") - if isinstance(item, dict): - itype = item.get("type") - if itype == "message": - narrative_items.append(item) - elif itype in _INHERITABLE_ITEM_TYPES: - tool_items.append(item) - iid = item.get("id") - if iid in in_progress_text: - del in_progress_text[iid] - in_progress_order.remove(iid) - elif t == "response.output_item.added": + +def _extract_completed_items( + messages: list[tuple], attempt: int +) -> tuple[list[dict], list[dict]]: + """Scan ``.done`` events and partition into (tool pairs, narrative).""" + tool_items: list[dict] = [] + narrative_items: list[dict] = [] + for t, evt in _iter_attempt_events(messages, attempt): + if t != "response.output_item.done": + continue + item = evt.get("item") + if not isinstance(item, dict): + continue + itype = item.get("type") + if itype == "message": + narrative_items.append(item) + elif itype in _TOOL_PAIR_TYPES: + tool_items.append(item) + return tool_items, narrative_items + + +def _reassemble_partial_message(messages: list[tuple], attempt: int) -> dict | None: + """Return a synthetic assistant message if the attempt ended with a + never-completed in-flight text item, else ``None``. + + Tracks ``output_item.added`` for message items, accumulates their + ``output_text.delta`` frames, and drops the tracker when a matching + ``.done`` arrives (that item is authoritative). Anything left at the + end is an unfinished message whose deltas we stitch into a synthetic + item so the next attempt's LLM can continue the prior narration. 
+ """ + in_progress_text: dict[str, list[str]] = {} + in_progress_order: list[str] = [] + for t, evt in _iter_attempt_events(messages, attempt): + if t == "response.output_item.added": item = evt.get("item") if isinstance(item, dict) and item.get("type") == "message": iid = item.get("id") @@ -211,6 +207,13 @@ def _collect_prior_attempt_tool_events( in_progress_text.setdefault(iid, []) if iid not in in_progress_order: in_progress_order.append(iid) + elif t == "response.output_item.done": + item = evt.get("item") + if isinstance(item, dict): + iid = item.get("id") + if iid in in_progress_text: + del in_progress_text[iid] + in_progress_order.remove(iid) elif t == "response.output_text.delta": iid = evt.get("item_id") delta = evt.get("delta") @@ -221,17 +224,40 @@ def _collect_prior_attempt_tool_events( chunks = in_progress_text.get(iid) or [] if not chunks: continue - narrative_items.append( - { - "type": "message", - "role": "assistant", - "content": [ - {"type": "output_text", "text": "".join(chunks), "annotations": []}, - ], - } - ) - # Tool pairs first (each function_call immediately followed by its - # function_call_output in the event-log order), narrative after. + return { + "type": "message", + "role": "assistant", + "content": [ + {"type": "output_text", "text": "".join(chunks), "annotations": []}, + ], + } + return None + + +def _collect_prior_attempt_tool_events( + messages: list[tuple], prior_attempt_number: int +) -> list[dict]: + """Return items the given prior attempt emitted, reordered to be a + valid provider message sequence on replay. + + Composition: (completed tool pairs in event order) → (completed + narrative messages in event order) → (partial reassembled message, if + any). 
Claude's raw stream interleaves narrative ``message`` items + between each ``function_call`` and its ``function_call_output``, which + in Anthropic format would look like ``assistant(tool_use)`` → + ``assistant(text)`` → ``user(tool_result)`` and trip the provider's + "tool_use must be immediately followed by tool_result" rule (HTTP + 400). Hoisting narrative past the tool pairs keeps each function_call + adjacent to its output and lets the narrative flow as a trailing + assistant block. + + ``messages`` is the repository's tuples ``(seq, item_json, + stream_event, attempt_number)``. + """ + tool_items, narrative_items = _extract_completed_items(messages, prior_attempt_number) + partial = _reassemble_partial_message(messages, prior_attempt_number) + if partial is not None: + narrative_items.append(partial) return tool_items + narrative_items @@ -239,103 +265,20 @@ def _sanitize_request_input( request_dict: dict[str, Any], synthetic_output: str = DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, ) -> dict[str, Any]: - """Reconcile orphaned function_call / function_call_output items in input. - - Walks ``request['input']`` end-to-end (not just the trailing turn) and: - - * drops duplicate ``function_call`` items by ``call_id``; - * drops duplicate or orphan ``function_call_output`` items (no matching - ``function_call`` anywhere in the list); - * injects a synthetic ``function_call_output`` immediately after any - ``function_call`` that has no output, so every ``tool_use`` is paired. + """Reconcile orphaned function_call / function_call_output items in + ``request['input']`` via the shared :func:`sanitize_tool_items` walker. - Also supports chat-completions-shape assistant items - (``{role: "assistant", tool_calls: [...]}``) as declaring call_ids. - - This is a pure transform on a dict — no pydantic round-trip, no DB I/O. - Returns the same dict (mutated in place on the ``input`` key) for caller - convenience. 
- - Reason we walk the whole history: the LangGraph in-graph middleware can - only repair the trailing turn, but UI-echoed history can carry orphans - from prior crashed turns mid-list. See rotation-findings.md (Test E). + Walks the whole history (not just the trailing turn) because UI-echoed + history can carry orphans from prior crashed turns mid-list. + Mutates ``request_dict['input']`` in place and returns the dict for + caller convenience. """ items = request_dict.get("input") if not isinstance(items, list) or not items: return request_dict - - declared_call_ids: set[str] = set() - call_ids_with_output: set[str] = set() - for i in items: - if not isinstance(i, dict): - continue - t = i.get("type") - cid = i.get("call_id") - if t == "function_call" and cid: - declared_call_ids.add(cid) - if t == "function_call_output" and cid: - call_ids_with_output.add(cid) - if i.get("role") == "assistant" and isinstance(i.get("tool_calls"), list): - for tc in i["tool_calls"]: - if not isinstance(tc, dict): - continue - tc_id = tc.get("id") or (tc.get("function") or {}).get("id") - if tc_id: - declared_call_ids.add(tc_id) - - sanitized: list[dict[str, Any]] = [] - seen_calls: set[str] = set() - seen_outputs: set[str] = set() - injected = 0 - dropped_orphan_outputs = 0 - dropped_duplicates = 0 - - for item in items: - if not isinstance(item, dict): - sanitized.append(item) - continue - t = item.get("type") - cid = item.get("call_id") - - if t == "function_call" and cid: - if cid in seen_calls: - dropped_duplicates += 1 - continue - seen_calls.add(cid) - sanitized.append(item) - if cid not in call_ids_with_output: - sanitized.append( - { - "type": "function_call_output", - "call_id": cid, - "output": synthetic_output, - } - ) - injected += 1 - elif t == "function_call_output" and cid: - if cid in seen_outputs: - dropped_duplicates += 1 - continue - if cid not in declared_call_ids: - dropped_orphan_outputs += 1 - continue - seen_outputs.add(cid) - sanitized.append(item) - else: 
- sanitized.append(item) - - if injected or dropped_orphan_outputs or dropped_duplicates: - logger.info( - "[durable] input sanitized: injected=%d dropped_orphan_outputs=%d " - "dropped_duplicates=%d original_items=%d final_items=%d", - injected, - dropped_orphan_outputs, - dropped_duplicates, - len(items), - len(sanitized), - ) - - request_dict["input"] = sanitized + request_dict["input"] = sanitize_tool_items( + items, synthetic_output, log_prefix="[durable] input sanitized" + ) return request_dict @@ -475,6 +418,36 @@ def __init__( heartbeat_interval_seconds: float = 3.0, heartbeat_stale_threshold_seconds: float = 10.0, ): + """Create a durable-resume-enabled agent server. + + Args: + agent_type: Must be ``"ResponsesAgent"``; this class is + Responses-API-shaped only. + enable_chat_proxy: Forwarded to the parent ``AgentServer``. + db_instance_name: Provisioned Lakebase instance name. Mutually + exclusive with the autoscaling options. + db_autoscaling_endpoint: Lakebase autoscaling endpoint URL. + db_project: Lakebase autoscaling project name (requires + ``db_branch``). + db_branch: Lakebase autoscaling branch name (requires + ``db_project``). + task_timeout_seconds: Max wall-clock time for a background run + before ``_task_scope`` fires a timeout and marks the + response ``failed``. + poll_interval_seconds: Interval between DB polls when streaming + cached events to a retrieve-endpoint client. + db_statement_timeout_ms: Postgres ``statement_timeout`` applied + to every DB statement from this process. + cleanup_timeout_seconds: Deadline for the post-crash error + write path before giving up and letting the stale-run + sweep mark the row failed. + heartbeat_interval_seconds: How often the owning pod writes + ``heartbeat_at`` while a run is in flight. + heartbeat_stale_threshold_seconds: Age at which a heartbeat is + considered stale and another pod may claim the run. 
+ Also used as the grace window for a freshly-created run + that hasn't written its first heartbeat yet. + """ if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( f"LongRunningAgentServer only supports '{self._SUPPORTED_AGENT_TYPE}', " From 933333c92f2e10c8c32161364da28ad703029315 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 06:42:16 +0000 Subject: [PATCH 35/39] Consolidate [INTERRUPTED] synthetic output + simplify session API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cleanup pass on the durable-resume surface following PR review: 1. Single source of truth for the synthetic tool-output string. Previously three copies (server.py, session.py, checkpoint.py), each under a different constant name, all identical. Follow the existing pattern from lakebase.py (defined once, imported into integrations) and lift it to long_running/repair.py as ``DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT``. It becomes the default for ``sanitize_tool_items(synthetic_output=...)`` so callers that don't customise don't pass anything. 2. Drop the ``auto_repair`` / ``auto_repair_synthetic_output`` parameters from ``AsyncDatabricksSession.__init__``. No realistic caller wants auto-repair off (it yields protocol-invalid histories that break the next call), and no one needs to customise the synthetic output text. ``get_items()`` is now unconditionally repaired and ``_sanitize_items`` is a one-liner that only sets the session-scoped log prefix. 3. Remove the pass-through ``synthetic_output`` parameter from ``_sanitize_request_input`` (server.py) and ``_build_tool_resume_repair`` (checkpoint.py). Both were internal helpers nothing else overrode — the caller trail simply forwarded the default to ``sanitize_tool_items``. 4. Restore the class-level Args docstring on ``LongRunningAgentServer`` matching the origin/main shape. 
Existing param entries keep their original wording verbatim; the only additions are ``heartbeat_interval_seconds`` and ``heartbeat_stale_threshold_seconds``. The redundant ``__init__`` docstring is removed so there's one source of param docs. Net: -19 LOC in session.py, -27 LOC in server.py, -16 LOC in checkpoint.py; the shared constant exists once; all 151 tests (long_running + langchain checkpoint + openai session) pass. Co-authored-by: Isaac --- .../src/databricks_langchain/checkpoint.py | 20 ++--- .../src/databricks_openai/agents/session.py | 51 +++---------- .../openai/tests/unit_tests/test_session.py | 30 +++----- .../long_running/repair.py | 15 +++- .../long_running/server.py | 74 +++++++------------ .../test_long_running_server.py | 2 +- 6 files changed, 69 insertions(+), 123 deletions(-) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index fa5b2a72..068f92dc 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -5,6 +5,7 @@ from typing import Any, Sequence from databricks.sdk import WorkspaceClient +from databricks_ai_bridge.long_running.repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT logger = logging.getLogger(__name__) @@ -30,19 +31,7 @@ _message_imports_available = False -DEFAULT_TOOL_RESUME_REPAIR_OUTPUT = ( - "[INTERRUPTED] This tool call did not complete due to a server " - "interruption, so no result is available. Other tool calls in the " - "conversation history completed normally and their results remain valid. " - "If the information is still needed, re-invoking only this specific tool " - "is usually sufficient." 
-) - - -def _build_tool_resume_repair( - messages: Sequence[Any], - synthetic_output: str = DEFAULT_TOOL_RESUME_REPAIR_OUTPUT, -) -> list[Any]: +def _build_tool_resume_repair(messages: Sequence[Any]) -> list[Any]: """Build synthetic ``ToolMessage`` responses for orphan tool calls. Internal helper used by ``_repair_loaded_checkpoint_tuple``. When a @@ -85,7 +74,10 @@ def _build_tool_resume_repair( answered.add(tcid) orphans = [tc_id for tc_id in tool_call_ids if tc_id not in answered] - return [ToolMessage(tool_call_id=tc_id, content=synthetic_output) for tc_id in orphans] + return [ + ToolMessage(tool_call_id=tc_id, content=DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT) + for tc_id in orphans + ] def _repair_loaded_checkpoint_tuple(tup: Any) -> Any: diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index e991e75a..87821560 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -35,14 +35,6 @@ async def main(): from threading import Lock from typing import Any, Optional -DEFAULT_REPAIR_SYNTHETIC_OUTPUT = ( - "[INTERRUPTED] This tool call did not complete due to a server " - "interruption, so no result is available. Other tool calls in the " - "conversation history completed normally and their results remain valid. " - "If the information is still needed, re-invoking only this specific tool " - "is usually sufficient." -) - try: from agents.extensions.memory import SQLAlchemySession from databricks.sdk import WorkspaceClient @@ -64,24 +56,12 @@ async def main(): logger = logging.getLogger(__name__) -def _sanitize_items( - items: list[Any], - synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT, -) -> list[Any]: - """Return a protocol-valid view of session items (thin wrapper around - the shared :func:`sanitize_tool_items` walker). 
- - Kept as a private alias so existing ``self._sanitize_items`` call sites - in this module stay stable. Behaviour: drop duplicate / orphan tool - items, inject synthetic outputs for unpaired function_calls, return - the caller's list reference unchanged on the happy path so - ``repair()`` can short-circuit re-persistence. +def _sanitize_items(items: list[Any]) -> list[Any]: + """Session-scoped wrapper around :func:`sanitize_tool_items` that only + sets the log prefix. Kept as a one-liner so existing + ``self._sanitize_items`` call sites stay stable. """ - if not items: - return items - return sanitize_tool_items( - items, synthetic_output, log_prefix="[durable] session items sanitized" - ) + return sanitize_tool_items(items, log_prefix="[durable] session items sanitized") class AsyncDatabricksSession(SQLAlchemySession): @@ -138,8 +118,6 @@ def __init__( sessions_table: str = "agent_sessions", messages_table: str = "agent_messages", use_cached_engine: bool = True, - auto_repair: bool = True, - auto_repair_synthetic_output: str = DEFAULT_REPAIR_SYNTHETIC_OUTPUT, **engine_kwargs, ) -> None: """ @@ -174,9 +152,6 @@ def __init__( "Please install with: pip install databricks-openai[memory]" ) - self._auto_repair = auto_repair - self._auto_repair_synthetic_output = auto_repair_synthetic_output - self._lakebase = self._get_or_create_lakebase( instance_name=instance_name, autoscaling_endpoint=autoscaling_endpoint, @@ -204,18 +179,16 @@ def __init__( ) async def get_items(self, limit: Optional[int] = None) -> list[Any]: - """Return session items, repaired for protocol validity when enabled. + """Return session items, always repaired for protocol validity. - When ``auto_repair=True`` (default), the returned list has every - ``function_call`` paired with a ``function_call_output`` — orphans - from a durable-resume crash get a synthetic output appended, and - duplicates get deduped. 
The underlying DB rows are not modified; - this is a pure in-memory filter, cheap to re-run on every call. + The returned list has every ``function_call`` paired with a + ``function_call_output`` — orphans from a durable-resume crash get + a synthetic output appended, and duplicates get deduped. The + underlying DB rows are not modified; this is a pure in-memory + filter, cheap to re-run on every call. """ items = await super().get_items(limit=limit) - if not self._auto_repair: - return items - return _sanitize_items(items, synthetic_output=self._auto_repair_synthetic_output) + return _sanitize_items(items) @classmethod def _build_cache_key( diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py index 1eec0f31..11cf7c20 100644 --- a/integrations/openai/tests/unit_tests/test_session.py +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -1368,27 +1368,22 @@ def test_dedupes_duplicate_calls_and_outputs(self): class TestAsyncGetItemsAutoRepair: - """get_items() applies read-time repair when auto_repair=True. Uses a - minimal subclass that bypasses parent SQLAlchemySession init so we can - exercise the override without a DB.""" + """get_items() always applies read-time repair. Uses a minimal subclass + that bypasses parent SQLAlchemySession init so we can exercise the + override without a DB.""" - def _fake_session(self, items, auto_repair=True): + def _fake_session(self, items): from databricks_openai.agents.session import AsyncDatabricksSession, _sanitize_items class _FakeSession(AsyncDatabricksSession): - def __init__(self, stored, auto): - # Bypass parent init — only need the auto-repair flags. - self._auto_repair = auto - self._auto_repair_synthetic_output = "INTERRUPTED" + def __init__(self, stored): + # Bypass parent init — only need the stored items. 
self._stored = stored async def get_items(self, limit=None): - items = list(self._stored) - if not self._auto_repair: - return items - return _sanitize_items(items, synthetic_output=self._auto_repair_synthetic_output) + return _sanitize_items(list(self._stored)) - return _FakeSession(items, auto_repair) + return _FakeSession(items) @pytest.mark.asyncio async def test_auto_repair_injects_synthetic_outputs(self): @@ -1400,12 +1395,5 @@ async def test_auto_repair_injects_synthetic_outputs(self): ] ) items = await sess.get_items() - synth = [i for i in items if i.get("output") == "INTERRUPTED"] + synth = [i for i in items if i.get("type") == "function_call_output"] assert len(synth) == 2 - - @pytest.mark.asyncio - async def test_auto_repair_off_returns_raw_items(self): - raw = [{"type": "function_call", "call_id": "c1", "name": "f", "arguments": "{}"}] - sess = self._fake_session(list(raw), auto_repair=False) - items = await sess.get_items() - assert items == raw diff --git a/src/databricks_ai_bridge/long_running/repair.py b/src/databricks_ai_bridge/long_running/repair.py index de16ca39..a666288b 100644 --- a/src/databricks_ai_bridge/long_running/repair.py +++ b/src/databricks_ai_bridge/long_running/repair.py @@ -21,6 +21,19 @@ logger = logging.getLogger(__name__) +#: Default body for the synthetic ``function_call_output`` injected when a +#: prior attempt's tool call has no matching output (e.g. the pod was killed +#: between emitting the call and its result). Shared between the server-side +#: input sanitizer and integration-side read-time repair paths so the user- +#: visible text stays consistent across the durable-resume contract. +DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT = ( + "[INTERRUPTED] This tool call did not complete due to a server " + "interruption, so no result is available. Other tool calls in the " + "conversation history completed normally and their results remain valid. 
" + "If the information is still needed, re-invoking only this specific tool " + "is usually sufficient." +) + def _default_item_get(item: Any, key: str) -> Any: if isinstance(item, dict): @@ -30,7 +43,7 @@ def _default_item_get(item: Any, key: str) -> Any: def sanitize_tool_items( items: list[Any], - synthetic_output: str, + synthetic_output: str = DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, *, item_get: Callable[[Any, str], Any] = _default_item_get, log_prefix: str = "[durable] items sanitized", diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index d0933eb1..9d1bf636 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -53,18 +53,6 @@ BACKGROUND_KEY = "background" -# Synthetic output injected for an orphaned function_call whose matching -# function_call_output was lost to a pod crash. The text is prescriptive: -# without the "do NOT re-invoke tools that already returned" guidance, models -# tend to restart the whole tool sequence on resume. -DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT = ( - "[INTERRUPTED] This tool call did not complete due to a server " - "interruption, so no result is available. Other tool calls in the " - "conversation history completed normally and their results remain valid. " - "If the information is still needed, re-invoking only this specific tool " - "is usually sufficient." -) - # One ID per process so heartbeats + claims have a stable owner identity. 
_POD_ID = f"{socket.gethostname()}-{os.getpid()}-{uuid.uuid4().hex[:8]}" @@ -261,10 +249,7 @@ def _collect_prior_attempt_tool_events( return tool_items + narrative_items -def _sanitize_request_input( - request_dict: dict[str, Any], - synthetic_output: str = DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, -) -> dict[str, Any]: +def _sanitize_request_input(request_dict: dict[str, Any]) -> dict[str, Any]: """Reconcile orphaned function_call / function_call_output items in ``request['input']`` via the shared :func:`sanitize_tool_items` walker. @@ -277,7 +262,7 @@ def _sanitize_request_input( if not isinstance(items, list) or not items: return request_dict request_dict["input"] = sanitize_tool_items( - items, synthetic_output, log_prefix="[durable] input sanitized" + items, log_prefix="[durable] input sanitized" ) return request_dict @@ -396,6 +381,31 @@ class LongRunningAgentServer(AgentServer): emitted tool calls / outputs / narrative, and an ``[INTERRUPTED]`` synthetic output paired with any tool call that didn't finish. Completed work is preserved; only the interrupted step re-runs. + + Args: + enable_chat_proxy: Whether to enable the chat proxy endpoint. + db_instance_name: Lakebase provisioned instance name. Overrides + ``LAKEBASE_INSTANCE_NAME``. + db_autoscaling_endpoint: Lakebase autoscaling endpoint URL. Overrides + ``LAKEBASE_AUTOSCALING_ENDPOINT``. + db_project: Lakebase autoscaling project. Overrides + ``LAKEBASE_AUTOSCALING_PROJECT``. + db_branch: Lakebase autoscaling branch. Overrides + ``LAKEBASE_AUTOSCALING_BRANCH``. + task_timeout_seconds: Max time for a background task before timeout. + Defaults to 3600 (1 hour). + poll_interval_seconds: Interval between DB polls when streaming. + Defaults to 1.0. + db_statement_timeout_ms: Postgres statement timeout. + Defaults to 5000 (5 seconds). + cleanup_timeout_seconds: Timeout for DB cleanup after task failure. + Defaults to 7.0. 
+ heartbeat_interval_seconds: How often the owning pod writes + ``heartbeat_at`` while a run is in flight. Defaults to 3.0. + heartbeat_stale_threshold_seconds: Age at which a heartbeat is + considered stale and another pod may claim the run. Also used + as the grace window for a freshly-created run that hasn't + written its first heartbeat yet. Defaults to 10.0. """ _SUPPORTED_AGENT_TYPE = "ResponsesAgent" @@ -418,36 +428,6 @@ def __init__( heartbeat_interval_seconds: float = 3.0, heartbeat_stale_threshold_seconds: float = 10.0, ): - """Create a durable-resume-enabled agent server. - - Args: - agent_type: Must be ``"ResponsesAgent"``; this class is - Responses-API-shaped only. - enable_chat_proxy: Forwarded to the parent ``AgentServer``. - db_instance_name: Provisioned Lakebase instance name. Mutually - exclusive with the autoscaling options. - db_autoscaling_endpoint: Lakebase autoscaling endpoint URL. - db_project: Lakebase autoscaling project name (requires - ``db_branch``). - db_branch: Lakebase autoscaling branch name (requires - ``db_project``). - task_timeout_seconds: Max wall-clock time for a background run - before ``_task_scope`` fires a timeout and marks the - response ``failed``. - poll_interval_seconds: Interval between DB polls when streaming - cached events to a retrieve-endpoint client. - db_statement_timeout_ms: Postgres ``statement_timeout`` applied - to every DB statement from this process. - cleanup_timeout_seconds: Deadline for the post-crash error - write path before giving up and letting the stale-run - sweep mark the row failed. - heartbeat_interval_seconds: How often the owning pod writes - ``heartbeat_at`` while a run is in flight. - heartbeat_stale_threshold_seconds: Age at which a heartbeat is - considered stale and another pod may claim the run. - Also used as the grace window for a freshly-created run - that hasn't written its first heartbeat yet. 
- """ if agent_type != self._SUPPORTED_AGENT_TYPE: raise ValueError( f"LongRunningAgentServer only supports '{self._SUPPORTED_AGENT_TYPE}', " diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 2fc04621..226d6d1a 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -14,9 +14,9 @@ pytest.importorskip("fastapi") pytest.importorskip("psycopg") +from databricks_ai_bridge.long_running.repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT from databricks_ai_bridge.long_running.repository import ResponseInfo from databricks_ai_bridge.long_running.server import ( - DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT, LongRunningAgentServer, _collect_prior_attempt_tool_events, _deferred_mark_failed, From d42ceb2bfcdcb91e6c1248f3f061e59c973203cd Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 23:27:41 +0000 Subject: [PATCH 36/39] Apply ruff format + drop unused import Fix the two CI lint failures exposed after the main merge: - repair.py: drop unused ``Optional`` import (F401) - repair.py + server.py: ruff format reformatting Co-authored-by: Isaac --- src/databricks_ai_bridge/long_running/repair.py | 5 ++--- src/databricks_ai_bridge/long_running/server.py | 14 ++++---------- 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/src/databricks_ai_bridge/long_running/repair.py b/src/databricks_ai_bridge/long_running/repair.py index a666288b..99a7e43c 100644 --- a/src/databricks_ai_bridge/long_running/repair.py +++ b/src/databricks_ai_bridge/long_running/repair.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Optional +from typing import Any, Callable logger = logging.getLogger(__name__) @@ -138,8 +138,7 @@ def sanitize_tool_items( return items logger.info( - "%s: injected=%d dropped_orphan_outputs=%d dropped_duplicates=%d " - "original=%d final=%d", + "%s: 
injected=%d dropped_orphan_outputs=%d dropped_duplicates=%d original=%d final=%d", log_prefix, injected, dropped_orphan_outputs, diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index 9d1bf636..e23a30e8 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -154,9 +154,7 @@ def _iter_attempt_events(messages: list[tuple], attempt: int): yield evt.get("type"), evt -def _extract_completed_items( - messages: list[tuple], attempt: int -) -> tuple[list[dict], list[dict]]: +def _extract_completed_items(messages: list[tuple], attempt: int) -> tuple[list[dict], list[dict]]: """Scan ``.done`` events and partition into (tool pairs, narrative).""" tool_items: list[dict] = [] narrative_items: list[dict] = [] @@ -261,9 +259,7 @@ def _sanitize_request_input(request_dict: dict[str, Any]) -> dict[str, Any]: items = request_dict.get("input") if not isinstance(items, list) or not items: return request_dict - request_dict["input"] = sanitize_tool_items( - items, log_prefix="[durable] input sanitized" - ) + request_dict["input"] = sanitize_tool_items(items, log_prefix="[durable] input sanitized") return request_dict @@ -312,8 +308,7 @@ def _rotate_conversation_id( ctx["conversation_id"] = rotated request_dict["context"] = ctx logger.info( - "[durable] rotated conversation_id for resume response_id=%s " - "attempt=%d base=%s rotated=%s", + "[durable] rotated conversation_id for resume response_id=%s attempt=%d base=%s rotated=%s", response_id, new_attempt_number, base_anchor, @@ -1084,8 +1079,7 @@ async def _try_claim_and_resume(self, response_id: str, resp) -> int | None: resume_input.extend(prior_tool_events) resume_dict["input"] = resume_input logger.info( - "[durable] resume inherited %d tool-event item(s) from attempt %d " - "response_id=%s", + "[durable] resume inherited %d tool-event item(s) from attempt %d response_id=%s", len(prior_tool_events), 
new_attempt - 1, response_id, From 3c1ca101f4e8a71e006ce12665b7a1fae047603a Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Thu, 23 Apr 2026 23:33:24 +0000 Subject: [PATCH 37/39] Move tool_repair.py out of long_running/ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``databricks_langchain.checkpoint`` is Python 3.10+ and re-exported eagerly from ``databricks_langchain/__init__.py``. Placing ``DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT`` inside ``long_running/repair.py`` meant any ``import databricks_langchain`` triggered ``long_running/__init__.py``, which: 1. ``from .server import ...`` — ``server.py`` raises RuntimeError on Python < 3.11, breaking langchain tests on 3.10. 2. ``from .db import ...`` — pulls ``psycopg``, breaking any langchain test that doesn't have ``[memory]`` extras installed. Both failed in CI. Move the file to top-level ``databricks_ai_bridge.tool_repair`` so it can be imported without crossing the ``long_running/`` package boundary. Three import sites updated (server, session, checkpoint) plus the test. No API break — ``sanitize_tool_items`` and ``DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT`` were still pre-release. 
Co-authored-by: Isaac --- integrations/langchain/src/databricks_langchain/checkpoint.py | 2 +- integrations/openai/src/databricks_openai/agents/session.py | 2 +- src/databricks_ai_bridge/long_running/server.py | 2 +- .../{long_running/repair.py => tool_repair.py} | 0 tests/databricks_ai_bridge/test_long_running_server.py | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename src/databricks_ai_bridge/{long_running/repair.py => tool_repair.py} (100%) diff --git a/integrations/langchain/src/databricks_langchain/checkpoint.py b/integrations/langchain/src/databricks_langchain/checkpoint.py index 2ec87f7c..361a7040 100644 --- a/integrations/langchain/src/databricks_langchain/checkpoint.py +++ b/integrations/langchain/src/databricks_langchain/checkpoint.py @@ -5,7 +5,7 @@ from typing import Any, Sequence from databricks.sdk import WorkspaceClient -from databricks_ai_bridge.long_running.repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT +from databricks_ai_bridge.tool_repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT logger = logging.getLogger(__name__) diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py index 4e6c680c..5850892d 100644 --- a/integrations/openai/src/databricks_openai/agents/session.py +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -43,7 +43,7 @@ async def main(): DEFAULT_TOKEN_CACHE_DURATION_SECONDS, AsyncLakebaseSQLAlchemy, ) - from databricks_ai_bridge.long_running.repair import sanitize_tool_items + from databricks_ai_bridge.tool_repair import sanitize_tool_items _session_imports_available = True except ImportError: diff --git a/src/databricks_ai_bridge/long_running/server.py b/src/databricks_ai_bridge/long_running/server.py index e23a30e8..5d10dfdd 100644 --- a/src/databricks_ai_bridge/long_running/server.py +++ b/src/databricks_ai_bridge/long_running/server.py @@ -35,7 +35,6 @@ from mlflow.tracing.constant import SpanAttributeKey from 
databricks_ai_bridge.long_running.db import dispose_db, init_db, is_db_configured -from databricks_ai_bridge.long_running.repair import sanitize_tool_items from databricks_ai_bridge.long_running.repository import ( append_message, claim_stale_response, @@ -47,6 +46,7 @@ update_response_trace_id, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings +from databricks_ai_bridge.tool_repair import sanitize_tool_items from databricks_ai_bridge.utils.annotations import experimental logger = logging.getLogger(__name__) diff --git a/src/databricks_ai_bridge/long_running/repair.py b/src/databricks_ai_bridge/tool_repair.py similarity index 100% rename from src/databricks_ai_bridge/long_running/repair.py rename to src/databricks_ai_bridge/tool_repair.py diff --git a/tests/databricks_ai_bridge/test_long_running_server.py b/tests/databricks_ai_bridge/test_long_running_server.py index 226d6d1a..a8461379 100644 --- a/tests/databricks_ai_bridge/test_long_running_server.py +++ b/tests/databricks_ai_bridge/test_long_running_server.py @@ -14,7 +14,6 @@ pytest.importorskip("fastapi") pytest.importorskip("psycopg") -from databricks_ai_bridge.long_running.repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT from databricks_ai_bridge.long_running.repository import ResponseInfo from databricks_ai_bridge.long_running.server import ( LongRunningAgentServer, @@ -26,6 +25,7 @@ _sse_event, ) from databricks_ai_bridge.long_running.settings import LongRunningSettings +from databricks_ai_bridge.tool_repair import DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT # --------------------------------------------------------------------------- # Shared helpers From e3db5891eb154ab0bf54e02d9844d6eff04f4919 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 28 Apr 2026 05:07:11 +0000 Subject: [PATCH 38/39] Add AGENTS.md for LongRunningAgentServer design MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Addresses Bryan's review feedback on PR #416: 1. 
AGENTS.md describing all desired behaviors 2. CUJs explaining author requirements + today's interface vs TaskFlow 3. Mermaid diagrams for human reviewers (sequence, ER, flowcharts) Sections: - §1 Module purpose, guarantees, non-goals - §2 Four customer journeys with sequence diagrams (author writes agent, pod crashes mid-tool, subsequent turn after crash, multi-pod contention) - §3 Architecture: storage layout (ER diagram), four key flows (flowchart), resume input construction (flowchart), heartbeat tuning, CAS atomicity (sequence diagram) - §4 Author-side requirements: what's invisible vs what authors tune - §5 TaskFlow migration mapping (today → TaskFlow primitives, what stays, what gets deleted, sequencing) Co-authored-by: Isaac --- .../long_running/AGENTS.md | 385 ++++++++++++++++++ 1 file changed, 385 insertions(+) create mode 100644 src/databricks_ai_bridge/long_running/AGENTS.md diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md new file mode 100644 index 00000000..6956c473 --- /dev/null +++ b/src/databricks_ai_bridge/long_running/AGENTS.md @@ -0,0 +1,385 @@ +# LongRunningAgentServer + +Durable, crash-resumable agent execution for MLflow `ResponsesAgent` handlers. + +This document describes: +1. What `LongRunningAgentServer` does and the guarantees it gives callers ([§1](#1-what-this-module-does)). +2. The four customer journeys it covers, with sequence diagrams ([§2](#2-customer-journeys)). +3. The architecture: storage layout, claim mechanism, recovery, and stream resume ([§3](#3-architecture)). +4. Author-side requirements: what changes (and doesn't) when a handler opts into durable mode ([§4](#4-author-side-requirements)). +5. The interface today and how it's expected to evolve when [TaskFlow](https://github.com/databricks-eng/universe/tree/master/experimental/taskflow) lands ([§5](#5-future-direction-taskflow)). + +## 1. 
What this module does + +`LongRunningAgentServer` extends MLflow's `AgentServer` for `ResponsesAgent` handlers with three capabilities: + +1. **Background execution.** A `POST /responses` request with `background: true` returns a `response_id` immediately; the agent loop runs detached from the HTTP connection. State persists to Lakebase Postgres. +2. **Streaming retrieval.** `GET /responses/{response_id}?stream=true&starting_after=N` replays events past sequence `N` and tails new ones until the run finishes. Reconnects without losing events. +3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run, replays prior tool/output state, and finishes the work. Tool results that completed before the crash are preserved. + +Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, others) is opaque to the server. + +### Guarantees + +- **At-most-once durable claim.** Only one pod runs a given response at a time. The handoff uses an atomic CAS on `attempt_number`. +- **Append-only event log.** Every SSE frame is persisted to `agent_server.messages` keyed by `(response_id, attempt_number, sequence_number)`. Clients cursor-resume from `starting_after`. +- **Best-effort tool execution.** A tool call interrupted mid-flight may re-run on the resumed attempt. Idempotency is the tool author's responsibility. +- **No agent code changes required.** Templates that subclass `LongRunningAgentServer` keep using `@invoke()` / `@stream()` decorators. Repair logic lives below the handler boundary. + +### Non-goals + +- Cross-region failover. Pods are assumed to share one Lakebase. +- Tool-level checkpointing / exactly-once tool execution. +- A workflow DSL. Handlers are ordinary async generators / coroutines. + +## 2. Customer journeys + +### CUJ 1: Author writes a long-running agent + +The author subclasses `LongRunningAgentServer` and registers `@invoke()` / `@stream()` handlers like a regular MLflow agent server. 
**No durability code in `agent.py`.** + +```python +from databricks_ai_bridge.long_running import LongRunningAgentServer +from mlflow.genai.agent_server import invoke, stream + +agent_server = LongRunningAgentServer( + "ResponsesAgent", + db_instance_name="my-lakebase-instance", +) + +@stream() +async def stream_handler(request): + # ordinary agent code: build messages, call SDK, yield events + ... + +@invoke() +async def invoke_handler(request): + ... + +app = agent_server.app +``` + +The agent author writes their handler exactly the same way they would for the non-durable `AgentServer`. `LongRunningAgentServer` adds the durable wiring transparently. + +### CUJ 2: Pod crashes mid-tool, client polls + +A client posts a long-running request, the owning pod dies mid-tool, another pod takes over, and the client gets the final output without restarting. + +```mermaid +sequenceDiagram + autonumber + participant C as Client + participant A as Pod A (owner) + participant DB as Lakebase<br/>(agent_server.*) + participant B as Pod B + participant SDK as SDK store<br/>(checkpointer / session) + + C->>A: POST /responses<br/>{input, background:true, stream:true} + A->>DB: INSERT response_id, owner=A,<br/>original_request, attempt=1 + A-->>C: 200 {id: resp_xxx, status: in_progress} + activate A + Note over A: heartbeat loop (every 3s) + A->>DB: UPDATE heartbeat_at WHERE owner=A + + par Streaming + A->>SDK: write tool_use, tool_result + A->>DB: append events seq=0..N (attempt=1) + and Polling + C->>A: GET /responses/{id}?stream=true&starting_after=N + A-->>C: SSE events (response.output_text.delta, ...) + end + + Note over A: 💥 pod crashes mid-tool<br/>(after tool_use, before tool_result) + deactivate A + Note over DB: heartbeat_at goes stale (>10s old) + + C->>B: GET /responses/{id}?stream=true&starting_after=K + B->>DB: SELECT heartbeat_at, attempt + Note over B: heartbeat stale → claim + B->>DB: CAS UPDATE owner=B, attempt=2<br/>WHERE attempt=1 (atomic) + B->>DB: SELECT prior events WHERE attempt=1 + Note over B: build resume input:<br/>· original_request.input<br/>· + completed tool pairs<br/>· + synthetic [INTERRUPTED] for orphan<br/>· rotate conv_id → ::attempt-2 + B->>DB: append response.resumed sentinel + B->>B: re-invoke @stream() handler<br/>(fresh SDK session via rotation) + activate B + B->>SDK: write to {original_id}::attempt-2 store + B->>DB: append events seq=K+1..M (attempt=2) + B-->>C: SSE events (response.resumed, then attempt-2 events from K+1) + deactivate B +``` + +**What the client observes:** a single SSE stream that may pause briefly during the heartbeat-stale window (~10s by default), then resumes. The `response.resumed` sentinel marks the boundary between attempts so the UI can render appropriately (e.g., "reconnecting…" + a fresh attempt bubble). + +**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus the carried-forward tool events plus a synthetic `[INTERRUPTED]` for any tool whose result didn't land. The handler doesn't have to know any of this. + +### CUJ 3: Subsequent turn after a crashed turn + +After a crash + resume, the next turn from the client lands on a fresh `POST /responses`. The SDK's storage (LangGraph checkpointer or OpenAI session) still contains the orphan `tool_use` from the crashed attempt because the SDK persisted it before the result landed. Without intervention, the next LLM call would be rejected by the provider's `tool_use → tool_result` pairing rule. + +```mermaid +sequenceDiagram + autonumber + participant C as Client + participant S as Server (any pod) + participant SDK as SDK store + + Note over SDK: state from prior turn:
[..., tool_use_X (orphan)] + + C->>S: POST /responses
{input: [..., new user msg]} + S->>S: handler runs + S->>SDK: load history (e.g. aget_tuple / get_items) + Note over S,SDK: read-time repair wraps the SDK call:
injects synthetic tool_result for any
orphan tool_use in the trailing turn + SDK-->>S: clean state (orphan paired) + S->>S: model receives protocol-valid history + Note over S: turn succeeds normally +``` + +The repair wrappers are in `databricks-langchain.AsyncCheckpointSaver.aget_tuple` and `databricks-openai.AsyncDatabricksSession.get_items`. They're idempotent and no-op on clean state. + +### CUJ 4: Multi-pod stale-claim contention + +Two pods each see the same response in `in_progress` with a stale heartbeat. Only one wins the CAS. + +```mermaid +sequenceDiagram + autonumber + participant B as Pod B + participant DB as Lakebase + participant C as Pod C + + Note over B,C: both pods see response in_progress, heartbeat stale + par + B->>DB: UPDATE responses SET owner=B, attempt=N+1
WHERE response_id=R AND attempt=N + and + C->>DB: UPDATE responses SET owner=C, attempt=N+1
WHERE response_id=R AND attempt=N + end + Note over DB: only one row matches; the other UPDATE returns 0 rows + DB-->>B: RETURNING attempt_number=N+1 + DB-->>C: RETURNING (no row) + Note over B: B wins, spawns resume handler + Note over C: C aborts cleanly, returns to its retrieve loop +``` + +The `claim_stale_response` function (`repository.py`) executes a single `UPDATE … WHERE attempt_number = :current AND ((heartbeat_at IS NULL) OR (heartbeat_at < now() - interval))` with `RETURNING`. Postgres serializes the writes; only the pod whose `current` value was unmodified at commit time gets the `RETURNING` row. + +## 3. Architecture + +### 3.1 Storage layout + +Two tables in the `agent_server` schema: + +```mermaid +erDiagram + responses { + text response_id PK + text status "in_progress / completed / failed" + timestamptz created_at + text owner_pod_id + timestamptz heartbeat_at + int attempt_number + text original_request "JSON of initial POST" + text trace_id + } + messages { + text response_id FK + int sequence_number + int attempt_number + text item "JSON of output item" + text stream_event "JSON of SSE frame" + } + responses ||--o{ messages : "has" +``` + +- `responses.attempt_number` is the CAS guard for claim atomicity. +- `messages.attempt_number` tags every event so retrieval can filter to the latest attempt's output (avoiding partial output from a crashed attempt leaking into the final response body). +- Schema migrations are idempotent (`ADD COLUMN IF NOT EXISTS`) so an existing deployment upgrades without downtime. + +### 3.2 The four key flows + +```mermaid +flowchart TD + POST["POST /responses
background=true"] --> CR[create_response
store original_request] + CR --> SPAWN[spawn @stream handler] + SPAWN --> HB[heartbeat loop
every 3s] + SPAWN --> EMIT[append SSE events
to messages table] + EMIT --> DONE{handler
exits?} + DONE -- yes --> COMPLETE[update status=completed] + DONE -- crash --> STALE[heartbeat stops
row goes stale] + + GET["GET /responses/{id}
?stream=true&starting_after=N"] --> CHECK[check heartbeat age] + CHECK -->|fresh or terminal| READ[read events from messages
where seq > N] + CHECK -->|stale| CLAIM[CAS claim
attempt += 1] + CLAIM -->|won| BUILD[collect prior events
build resume input
rotate conv_id] + CLAIM -->|lost| READ + BUILD --> SPAWN + READ --> STREAM[SSE stream to client] + + KILL["POST /_debug/kill_task/{id}
(test-only)"] --> CANCEL[cancel asyncio task
without status update] + CANCEL --> STALE +``` + +### 3.3 Resume input construction + +When a stale-claim CAS succeeds, the new owner builds the resume input from the prior attempt's emitted events: + +```mermaid +flowchart LR + PRIOR[prior attempt's events
from messages table] --> WALK[_collect_prior_attempt_tool_events] + WALK --> POOL[completed tool pairs
function_call + function_call_output] + WALK --> NARR[completed narrative
output_item.done for messages] + WALK --> PARTIAL[partial assistant text
reassembled from deltas] + + POOL --> COMPOSE["compose: original_request.input
+ tool pairs
+ synthetic [INTERRUPTED] for orphan
+ narrative
+ partial"] + NARR --> COMPOSE + PARTIAL --> COMPOSE + + ROT[_rotate_conversation_id
::attempt-N suffix] --> SANITIZE + COMPOSE --> SANITIZE[_sanitize_request_input
pair orphan function_call ids] + SANITIZE --> INVOKE[re-invoke @stream handler
with rotated request] +``` + +Why the rotation: the SDK's storage may carry mid-turn state from the crashed attempt that's hard to repair without SDK-internals knowledge. Rotating to `{base}::attempt-N` opens a fresh thread/session for the resumed attempt; the structured input carries the prior work forward in a shape both LangGraph and OpenAI handle natively. + +Why the sanitizer: the trailing assistant turn always has at least one orphan `function_call` (the one whose tool was interrupted). The sanitizer pairs it with a synthetic `[INTERRUPTED]` `function_call_output` so the next LLM call is protocol-valid. + +### 3.4 Heartbeat and stale threshold + +Defaults are tuned for a single Lakebase deployment with low-latency writes. + +| Setting | Default | Rationale | +|---|---|---| +| `heartbeat_interval_seconds` | 3.0 | Frequent enough that short pauses (GC, asyncio event-loop stalls) don't trip stale detection | +| `heartbeat_stale_threshold_seconds` | 10.0 | Three missed heartbeats = unambiguously dead. Validated stale > interval at startup. | +| `task_timeout_seconds` | 3600 | Hard ceiling. After this, a stuck `in_progress` row is force-failed regardless of heartbeat. | +| `poll_interval_seconds` | 1.0 | Stream-retrieve polls the messages table at this rate while waiting for new events. | + +The stale threshold also applies as a grace period for newly-created responses that haven't written their first heartbeat yet — protects against an over-eager retrieve hijacking a still-starting handler. + +### 3.5 Claim atomicity + +```mermaid +sequenceDiagram + autonumber + participant Pod as Pod B (claimer) + participant DB as Lakebase + + Pod->>DB: SELECT attempt_number, heartbeat_at, status
FROM responses WHERE response_id=R + DB-->>Pod: attempt=1, heartbeat_at=12s ago, status=in_progress + Note over Pod: stale → attempt CAS + Pod->>DB: UPDATE responses
SET owner_pod_id=B, attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number + alt match + DB-->>Pod: attempt_number=2 + Note over Pod: claim succeeded + else no match (another pod beat us) + DB-->>Pod: (empty) + Note over Pod: claim lost — back off + end +``` + +Postgres row locking ensures only one of N concurrent UPDATEs matches, so at most one pod ends up owning a given resume. + +## 4. Author-side requirements + +### 4.1 Today + +| Concern | Where it lives | Author-visible? | +|---|---|---| +| Heartbeat + claim | `LongRunningAgentServer` | No | +| Conversation_id rotation | `LongRunningAgentServer._rotate_conversation_id` | No | +| Resume input construction | `LongRunningAgentServer._collect_prior_attempt_tool_events` + sanitizer | No | +| Stream resume cursor | `LongRunningAgentServer._stream_retrieve` | No | +| Read-time repair on subsequent turns | `databricks-langchain.AsyncCheckpointSaver.aget_tuple` (wrap) / `databricks-openai.AsyncDatabricksSession.get_items` (wrap) | No — invisible inside the SDK adapters | +| Tool/SDK selection | `agent.py` | Yes (this is the author's actual code) | + +The author's `agent.py` is unchanged from a non-durable agent. They construct an `AsyncCheckpointSaver` (LangGraph) or `AsyncDatabricksSession` (OpenAI) and use it normally. Durability fires below the SDK boundary. + +### 4.2 Settings worth exposing to authors + +- `db_instance_name` / `db_autoscaling_endpoint` / `db_project` + `db_branch` — Lakebase connection config. +- `heartbeat_interval_seconds` / `heartbeat_stale_threshold_seconds` — for tuning under heavy load. +- `task_timeout_seconds` — per-attempt ceiling. + +Everything else is internal. + +### 4.3 Settings authors should NOT need to override + +- `auto_sanitize_input` is true by default and should stay that way for chat UIs. +- The synthetic `[INTERRUPTED]` text (`DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT` in `tool_repair.py`) is part of the durable contract; changing it is a product decision, not a per-template knob. + +## 5. 
Future direction: TaskFlow + +[TaskFlow](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/tree/experimental/taskflow) is a Rust-core durable-task engine being built in `experimental/taskflow`. It provides exactly the primitives `LongRunningAgentServer` hand-rolls today (heartbeat, CAS claim, recovery worker, event log with stream resume) — but as a library with WAL-first durability and proactive (not lazy-on-GET) recovery. + +When TaskFlow is production-ready, `LongRunningAgentServer` is expected to keep its **HTTP surface and author-visible API unchanged**, swapping only the engine internals. + +### Mapping today → TaskFlow + +```mermaid +flowchart LR + subgraph TODAY[LongRunningAgentServer today] + T1[create_response + asyncio.create_task] + T2[_heartbeat async CM] + T3[_try_claim_and_resume CAS] + T4[_collect_prior_attempt_tool_events] + T5["/responses/{id}?stream=true"] + T6["/_debug/kill_task"] + end + + subgraph TF[TaskFlow] + F1[Taskflow.start name input user_id] + F2[built-in executor heartbeat] + F3[built-in recovery worker + claim_for_recovery] + F4[TaskHandler.recover ctx previous_events] + F5[Taskflow.subscribe key last_seq] + F6[Taskflow.simulate_crash key] + end + + T1 --> F1 + T2 --> F2 + T3 --> F3 + T4 --> F4 + T5 --> F5 + T6 --> F6 +``` + +### What stays in `LongRunningAgentServer` after the swap + +- `POST /responses` / `GET /responses/{id}` HTTP routes (and their schemas). +- The MLflow `@invoke()` / `@stream()` handler convention. +- `_rotate_conversation_id` — rotation is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. +- `_inject_conversation_id` at POST time. +- Author-visible settings (db config, heartbeat tuning, task timeout). + +### What gets deleted + +- The hand-rolled heartbeat task and CAS claim CTE — replaced by TaskFlow's executor heartbeat + `claim_for_recovery`. 
+- The `_try_claim_and_resume` lazy-claim path — TaskFlow's recovery worker handles this proactively and across pods, fixing the documented "claim only fires on GET" limitation. +- The `agent_server.responses` + `agent_server.messages` schema — TaskFlow owns its storage layer. +- The stream-cursor logic in `_stream_retrieve` — `Taskflow.subscribe(key, last_seq)` is the cursor-based stream resume. + +### What requires a small TaskFlow API addition + +TaskFlow derives idempotency keys as `SHA256(name + canonical_input + user_id)`. The HTTP surface uses a server-generated `resp_{uuid}`. We've requested an `idempotency_key: Option<String>` parameter on `Taskflow.submit()` so we can keep the existing HTTP contract while submitting to TaskFlow. See [`engine.rs:317`](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/blob/experimental/taskflow/engine/src/engine.rs?L317) for the current `generate_key` call site; the override would slot in there. + +### Migration sequencing + +1. Add `LongRunningAgentServer(backend="taskflow"|"builtin")` knob, default `"builtin"`. HTTP surface unchanged on either backend. +2. Port `agent-non-conversational` (the simplest template) to `backend="taskflow"`. Run the full crash-resume matrix. +3. Port the advanced templates. **Zero changes to `agent.py`** — the swap is the constructor argument. +4. Flip default to `"taskflow"`. Deprecate `"builtin"`. +5. Delete the heartbeat / claim / repository code. Big delete PR. + +The point of the `LongRunningAgentServer` abstraction is exactly this kind of swap: callers should never have to care which engine is underneath. 
+ +--- + +## Quick reference + +- **Code:** `src/databricks_ai_bridge/long_running/` +- **Tests:** `tests/databricks_ai_bridge/test_long_running_server.py`, `test_long_running_db.py` +- **Settings:** `LongRunningSettings` in `settings.py` +- **Model:** `Response` and `Message` in `models.py` +- **HTTP routes:** registered in `LongRunningAgentServer._setup_routes` +- **Read-time repair (LangGraph):** `databricks-langchain/checkpoint.py::_repair_loaded_checkpoint_tuple` +- **Read-time repair (OpenAI):** `databricks-openai/agents/session.py::AsyncDatabricksSession.get_items` +- **Shared sanitizer:** `tool_repair.py::sanitize_tool_items` From 7b9ae32d7a5edf237a8ad4bc67cd5f1ecac54030 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Tue, 28 Apr 2026 17:43:22 +0000 Subject: [PATCH 39/39] Revert "Add AGENTS.md for LongRunningAgentServer design" This reverts commit e3db5891eb154ab0bf54e02d9844d6eff04f4919. --- .../long_running/AGENTS.md | 385 ------------------ 1 file changed, 385 deletions(-) delete mode 100644 src/databricks_ai_bridge/long_running/AGENTS.md diff --git a/src/databricks_ai_bridge/long_running/AGENTS.md b/src/databricks_ai_bridge/long_running/AGENTS.md deleted file mode 100644 index 6956c473..00000000 --- a/src/databricks_ai_bridge/long_running/AGENTS.md +++ /dev/null @@ -1,385 +0,0 @@ -# LongRunningAgentServer - -Durable, crash-resumable agent execution for MLflow `ResponsesAgent` handlers. - -This document describes: -1. What `LongRunningAgentServer` does and the guarantees it gives callers ([§1](#1-what-this-module-does)). -2. The four customer journeys it covers, with sequence diagrams ([§2](#2-customer-journeys)). -3. The architecture: storage layout, claim mechanism, recovery, and stream resume ([§3](#3-architecture)). -4. Author-side requirements: what changes (and doesn't) when a handler opts into durable mode ([§4](#4-author-side-requirements)). -5. 
The interface today and how it's expected to evolve when [TaskFlow](https://github.com/databricks-eng/universe/tree/master/experimental/taskflow) lands ([§5](#5-future-direction-taskflow)). - -## 1. What this module does - -`LongRunningAgentServer` extends MLflow's `AgentServer` for `ResponsesAgent` handlers with three capabilities: - -1. **Background execution.** A `POST /responses` request with `background: true` returns a `response_id` immediately; the agent loop runs detached from the HTTP connection. State persists to Lakebase Postgres. -2. **Streaming retrieval.** `GET /responses/{response_id}?stream=true&starting_after=N` replays events past sequence `N` and tails new ones until the run finishes. Reconnects without losing events. -3. **Crash-resumable execution.** If the pod running an agent loop dies, another pod atomically claims the run, replays prior tool/output state, and finishes the work. Tool results that completed before the crash are preserved. - -Callers see one HTTP surface; the underlying SDK (LangGraph, OpenAI Agents, others) is opaque to the server. - -### Guarantees - -- **At-most-once durable claim.** Only one pod runs a given response at a time. The handoff uses an atomic CAS on `attempt_number`. -- **Append-only event log.** Every SSE frame is persisted to `agent_server.messages` keyed by `(response_id, attempt_number, sequence_number)`. Clients cursor-resume from `starting_after`. -- **Best-effort tool execution.** A tool call interrupted mid-flight may re-run on the resumed attempt. Idempotency is the tool author's responsibility. -- **No agent code changes required.** Templates that subclass `LongRunningAgentServer` keep using `@invoke()` / `@stream()` decorators. Repair logic lives below the handler boundary. - -### Non-goals - -- Cross-region failover. Pods are assumed to share one Lakebase. -- Tool-level checkpointing / exactly-once tool execution. -- A workflow DSL. Handlers are ordinary async generators / coroutines. - -## 2. 
Customer journeys - -### CUJ 1: Author writes a long-running agent - -The author subclasses `LongRunningAgentServer` and registers `@invoke()` / `@stream()` handlers like a regular MLflow agent server. **No durability code in `agent.py`.** - -```python -from databricks_ai_bridge.long_running import LongRunningAgentServer -from mlflow.genai.agent_server import invoke, stream - -agent_server = LongRunningAgentServer( - "ResponsesAgent", - db_instance_name="my-lakebase-instance", -) - -@stream() -async def stream_handler(request): - # ordinary agent code: build messages, call SDK, yield events - ... - -@invoke() -async def invoke_handler(request): - ... - -app = agent_server.app -``` - -The agent author writes their handler exactly the same way they would for the non-durable `AgentServer`. `LongRunningAgentServer` adds the durable wiring transparently. - -### CUJ 2: Pod crashes mid-tool, client polls - -A client posts a long-running request, the owning pod dies mid-tool, another pod takes over, and the client gets the final output without restarting. - -```mermaid -sequenceDiagram - autonumber - participant C as Client - participant A as Pod A (owner) - participant DB as Lakebase
(agent_server.*) - participant B as Pod B - participant SDK as SDK store
(checkpointer / session) - - C->>A: POST /responses
{input, background:true, stream:true} - A->>DB: INSERT response_id, owner=A,
original_request, attempt=1 - A-->>C: 200 {id: resp_xxx, status: in_progress} - activate A - Note over A: heartbeat loop (every 3s) - A->>DB: UPDATE heartbeat_at WHERE owner=A - - par Streaming - A->>SDK: write tool_use, tool_result - A->>DB: append events seq=0..N (attempt=1) - and Polling - C->>A: GET /responses/{id}?stream=true&starting_after=N - A-->>C: SSE events (response.output_text.delta, ...) - end - - Note over A: 💥 pod crashes mid-tool
(after tool_use, before tool_result) - deactivate A - Note over DB: heartbeat_at goes stale (>10s old) - - C->>B: GET /responses/{id}?stream=true&starting_after=K - B->>DB: SELECT heartbeat_at, attempt - Note over B: heartbeat stale → claim - B->>DB: CAS UPDATE owner=B, attempt=2
WHERE attempt=1 (atomic) - B->>DB: SELECT prior events WHERE attempt=1 - Note over B: build resume input:
· original_request.input
· + completed tool pairs
· + synthetic [INTERRUPTED] for orphan
· rotate conv_id → ::attempt-2 - B->>DB: append response.resumed sentinel - B->>B: re-invoke @stream() handler
(fresh SDK session via rotation) - activate B - B->>SDK: write to {original_id}::attempt-2 store - B->>DB: append events seq=K+1..M (attempt=2) - B-->>C: SSE events (response.resumed, then attempt-2 events from K+1) - deactivate B -``` - -**What the client observes:** a single SSE stream that may pause briefly during the heartbeat-stale window (~10s by default), then resumes. The `response.resumed` sentinel marks the boundary between attempts so the UI can render appropriately (e.g., "reconnecting…" + a fresh attempt bubble). - -**What the agent author observes:** their handler is invoked once for the original POST; a second time on resume. The second invocation's `request.input` contains the original input plus the carried-forward tool events plus a synthetic `[INTERRUPTED]` for any tool whose result didn't land. The handler doesn't have to know any of this. - -### CUJ 3: Subsequent turn after a crashed turn - -After a crash + resume, the next turn from the client lands on a fresh `POST /responses`. The SDK's storage (LangGraph checkpointer or OpenAI session) still contains the orphan `tool_use` from the crashed attempt because the SDK persisted it before the result landed. Without intervention, the next LLM call would be rejected by the provider's `tool_use → tool_result` pairing rule. - -```mermaid -sequenceDiagram - autonumber - participant C as Client - participant S as Server (any pod) - participant SDK as SDK store - - Note over SDK: state from prior turn:
[..., tool_use_X (orphan)] - - C->>S: POST /responses
{input: [..., new user msg]} - S->>S: handler runs - S->>SDK: load history (e.g. aget_tuple / get_items) - Note over S,SDK: read-time repair wraps the SDK call:
injects synthetic tool_result for any
orphan tool_use in the trailing turn - SDK-->>S: clean state (orphan paired) - S->>S: model receives protocol-valid history - Note over S: turn succeeds normally -``` - -The repair wrappers are in `databricks-langchain.AsyncCheckpointSaver.aget_tuple` and `databricks-openai.AsyncDatabricksSession.get_items`. They're idempotent and no-op on clean state. - -### CUJ 4: Multi-pod stale-claim contention - -Two pods each see the same response in `in_progress` with a stale heartbeat. Only one wins the CAS. - -```mermaid -sequenceDiagram - autonumber - participant B as Pod B - participant DB as Lakebase - participant C as Pod C - - Note over B,C: both pods see response in_progress, heartbeat stale - par - B->>DB: UPDATE responses SET owner=B, attempt=N+1
WHERE response_id=R AND attempt=N - and - C->>DB: UPDATE responses SET owner=C, attempt=N+1
WHERE response_id=R AND attempt=N - end - Note over DB: only one row matches; the other UPDATE returns 0 rows - DB-->>B: RETURNING attempt_number=N+1 - DB-->>C: RETURNING (no row) - Note over B: B wins, spawns resume handler - Note over C: C aborts cleanly, returns to its retrieve loop -``` - -The `claim_stale_response` function (`repository.py`) executes a single `UPDATE … WHERE attempt_number = :current AND ((heartbeat_at IS NULL) OR (heartbeat_at < now() - interval))` with `RETURNING`. Postgres serializes the writes; only the pod whose `current` value was unmodified at commit time gets the `RETURNING` row. - -## 3. Architecture - -### 3.1 Storage layout - -Two tables in the `agent_server` schema: - -```mermaid -erDiagram - responses { - text response_id PK - text status "in_progress / completed / failed" - timestamptz created_at - text owner_pod_id - timestamptz heartbeat_at - int attempt_number - text original_request "JSON of initial POST" - text trace_id - } - messages { - text response_id FK - int sequence_number - int attempt_number - text item "JSON of output item" - text stream_event "JSON of SSE frame" - } - responses ||--o{ messages : "has" -``` - -- `responses.attempt_number` is the CAS guard for claim atomicity. -- `messages.attempt_number` tags every event so retrieval can filter to the latest attempt's output (avoiding partial output from a crashed attempt leaking into the final response body). -- Schema migrations are idempotent (`ADD COLUMN IF NOT EXISTS`) so an existing deployment upgrades without downtime. - -### 3.2 The four key flows - -```mermaid -flowchart TD - POST["POST /responses
background=true"] --> CR[create_response
store original_request] - CR --> SPAWN[spawn @stream handler] - SPAWN --> HB[heartbeat loop
every 3s] - SPAWN --> EMIT[append SSE events
to messages table] - EMIT --> DONE{handler
exits?} - DONE -- yes --> COMPLETE[update status=completed] - DONE -- crash --> STALE[heartbeat stops
row goes stale] - - GET["GET /responses/{id}
?stream=true&starting_after=N"] --> CHECK[check heartbeat age] - CHECK -->|fresh or terminal| READ[read events from messages
where seq > N] - CHECK -->|stale| CLAIM[CAS claim
attempt += 1] - CLAIM -->|won| BUILD[collect prior events
build resume input
rotate conv_id] - CLAIM -->|lost| READ - BUILD --> SPAWN - READ --> STREAM[SSE stream to client] - - KILL["POST /_debug/kill_task/{id}
(test-only)"] --> CANCEL[cancel asyncio task
without status update] - CANCEL --> STALE -``` - -### 3.3 Resume input construction - -When a stale-claim CAS succeeds, the new owner builds the resume input from the prior attempt's emitted events: - -```mermaid -flowchart LR - PRIOR[prior attempt's events
from messages table] --> WALK[_collect_prior_attempt_tool_events] - WALK --> POOL[completed tool pairs
function_call + function_call_output] - WALK --> NARR[completed narrative
output_item.done for messages] - WALK --> PARTIAL[partial assistant text
reassembled from deltas] - - POOL --> COMPOSE["compose: original_request.input
+ tool pairs
+ synthetic [INTERRUPTED] for orphan
+ narrative
+ partial"] - NARR --> COMPOSE - PARTIAL --> COMPOSE - - ROT[_rotate_conversation_id
::attempt-N suffix] --> SANITIZE - COMPOSE --> SANITIZE[_sanitize_request_input
pair orphan function_call ids] - SANITIZE --> INVOKE[re-invoke @stream handler
with rotated request] -``` - -Why the rotation: the SDK's storage may carry mid-turn state from the crashed attempt that's hard to repair without SDK-internals knowledge. Rotating to `{base}::attempt-N` opens a fresh thread/session for the resumed attempt; the structured input carries the prior work forward in a shape both LangGraph and OpenAI handle natively. - -Why the sanitizer: the trailing assistant turn always has at least one orphan `function_call` (the one whose tool was interrupted). The sanitizer pairs it with a synthetic `[INTERRUPTED]` `function_call_output` so the next LLM call is protocol-valid. - -### 3.4 Heartbeat and stale threshold - -Defaults are tuned for a single Lakebase deployment with low-latency writes. - -| Setting | Default | Rationale | -|---|---|---| -| `heartbeat_interval_seconds` | 3.0 | Frequent enough that short pauses (GC, asyncio event-loop stalls) don't trip stale detection | -| `heartbeat_stale_threshold_seconds` | 10.0 | Three missed heartbeats = unambiguously dead. Validated stale > interval at startup. | -| `task_timeout_seconds` | 3600 | Hard ceiling. After this, a stuck `in_progress` row is force-failed regardless of heartbeat. | -| `poll_interval_seconds` | 1.0 | Stream-retrieve polls the messages table at this rate while waiting for new events. | - -The stale threshold also applies as a grace period for newly-created responses that haven't written their first heartbeat yet — protects against an over-eager retrieve hijacking a still-starting handler. - -### 3.5 Claim atomicity - -```mermaid -sequenceDiagram - autonumber - participant Pod as Pod B (claimer) - participant DB as Lakebase - - Pod->>DB: SELECT attempt_number, heartbeat_at, status
FROM responses WHERE response_id=R - DB-->>Pod: attempt=1, heartbeat_at=12s ago, status=in_progress - Note over Pod: stale → attempt CAS - Pod->>DB: UPDATE responses
SET owner_pod_id=B, attempt_number=2, heartbeat_at=now()
WHERE response_id=R AND attempt_number=1
AND (heartbeat_at IS NULL OR heartbeat_at < now() - interval '10s')
RETURNING attempt_number - alt match - DB-->>Pod: attempt_number=2 - Note over Pod: claim succeeded - else no match (another pod beat us) - DB-->>Pod: (empty) - Note over Pod: claim lost — back off - end -``` - -Postgres row locking ensures only one of N concurrent UPDATEs matches, so at most one pod ends up owning a given resume. - -## 4. Author-side requirements - -### 4.1 Today - -| Concern | Where it lives | Author-visible? | -|---|---|---| -| Heartbeat + claim | `LongRunningAgentServer` | No | -| Conversation_id rotation | `LongRunningAgentServer._rotate_conversation_id` | No | -| Resume input construction | `LongRunningAgentServer._collect_prior_attempt_tool_events` + sanitizer | No | -| Stream resume cursor | `LongRunningAgentServer._stream_retrieve` | No | -| Read-time repair on subsequent turns | `databricks-langchain.AsyncCheckpointSaver.aget_tuple` (wrap) / `databricks-openai.AsyncDatabricksSession.get_items` (wrap) | No — invisible inside the SDK adapters | -| Tool/SDK selection | `agent.py` | Yes (this is the author's actual code) | - -The author's `agent.py` is unchanged from a non-durable agent. They construct an `AsyncCheckpointSaver` (LangGraph) or `AsyncDatabricksSession` (OpenAI) and use it normally. Durability fires below the SDK boundary. - -### 4.2 Settings worth exposing to authors - -- `db_instance_name` / `db_autoscaling_endpoint` / `db_project` + `db_branch` — Lakebase connection config. -- `heartbeat_interval_seconds` / `heartbeat_stale_threshold_seconds` — for tuning under heavy load. -- `task_timeout_seconds` — per-attempt ceiling. - -Everything else is internal. - -### 4.3 Settings authors should NOT need to override - -- `auto_sanitize_input` is true by default and should stay that way for chat UIs. -- The synthetic `[INTERRUPTED]` text (`DEFAULT_SYNTHETIC_INTERRUPTED_OUTPUT` in `tool_repair.py`) is part of the durable contract; changing it is a product decision, not a per-template knob. - -## 5. 
Future direction: TaskFlow - -[TaskFlow](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/tree/experimental/taskflow) is a Rust-core durable-task engine being built in `experimental/taskflow`. It provides exactly the primitives `LongRunningAgentServer` hand-rolls today (heartbeat, CAS claim, recovery worker, event log with stream resume) — but as a library with WAL-first durability and proactive (not lazy-on-GET) recovery. - -When TaskFlow is production-ready, `LongRunningAgentServer` is expected to keep its **HTTP surface and author-visible API unchanged**, swapping only the engine internals. - -### Mapping today → TaskFlow - -```mermaid -flowchart LR - subgraph TODAY[LongRunningAgentServer today] - T1[create_response + asyncio.create_task] - T2[_heartbeat async CM] - T3[_try_claim_and_resume CAS] - T4[_collect_prior_attempt_tool_events] - T5["/responses/{id}?stream=true"] - T6["/_debug/kill_task"] - end - - subgraph TF[TaskFlow] - F1[Taskflow.start name input user_id] - F2[built-in executor heartbeat] - F3[built-in recovery worker + claim_for_recovery] - F4[TaskHandler.recover ctx previous_events] - F5[Taskflow.subscribe key last_seq] - F6[Taskflow.simulate_crash key] - end - - T1 --> F1 - T2 --> F2 - T3 --> F3 - T4 --> F4 - T5 --> F5 - T6 --> F6 -``` - -### What stays in `LongRunningAgentServer` after the swap - -- `POST /responses` / `GET /responses/{id}` HTTP routes (and their schemas). -- The MLflow `@invoke()` / `@stream()` handler convention. -- `_rotate_conversation_id` — rotation is handler-policy, not engine-internals; lives in the adapter that bridges MLflow handlers to TaskFlow's `recover()`. -- `_inject_conversation_id` at POST time. -- Author-visible settings (db config, heartbeat tuning, task timeout). - -### What gets deleted - -- The hand-rolled heartbeat task and CAS claim CTE — replaced by TaskFlow's executor heartbeat + `claim_for_recovery`. 
-- The `_try_claim_and_resume` lazy-claim path — TaskFlow's recovery worker handles this proactively and across pods, fixing the documented "claim only fires on GET" limitation. -- The `agent_server.responses` + `agent_server.messages` schema — TaskFlow owns its storage layer. -- The stream-cursor logic in `_stream_retrieve` — `Taskflow.subscribe(key, last_seq)` is the cursor-based stream resume. - -### What requires a small TaskFlow API addition - -TaskFlow derives idempotency keys as `SHA256(name + canonical_input + user_id)`. The HTTP surface uses a server-generated `resp_{uuid}`. We've requested an `idempotency_key: Option<String>` parameter on `Taskflow.submit()` so we can keep the existing HTTP contract while submitting to TaskFlow. See [`engine.rs:317`](https://sourcegraph.prod.databricks-corp.com/databricks-eng/universe/-/blob/experimental/taskflow/engine/src/engine.rs?L317) for the current `generate_key` call site; the override would slot in there. - -### Migration sequencing - -1. Add `LongRunningAgentServer(backend="taskflow"|"builtin")` knob, default `"builtin"`. HTTP surface unchanged on either backend. -2. Port `agent-non-conversational` (the simplest template) to `backend="taskflow"`. Run the full crash-resume matrix. -3. Port the advanced templates. **Zero changes to `agent.py`** — the swap is the constructor argument. -4. Flip default to `"taskflow"`. Deprecate `"builtin"`. -5. Delete the heartbeat / claim / repository code. Big delete PR. - -The point of the `LongRunningAgentServer` abstraction is exactly this kind of swap: callers should never have to care which engine is underneath. 
- ---- - -## Quick reference - -- **Code:** `src/databricks_ai_bridge/long_running/` -- **Tests:** `tests/databricks_ai_bridge/test_long_running_server.py`, `test_long_running_db.py` -- **Settings:** `LongRunningSettings` in `settings.py` -- **Model:** `Response` and `Message` in `models.py` -- **HTTP routes:** registered in `LongRunningAgentServer._setup_routes` -- **Read-time repair (LangGraph):** `databricks-langchain/checkpoint.py::_repair_loaded_checkpoint_tuple` -- **Read-time repair (OpenAI):** `databricks-openai/agents/session.py::AsyncDatabricksSession.get_items` -- **Shared sanitizer:** `tool_repair.py::sanitize_tool_items`