Skip to content
Merged
456 changes: 456 additions & 0 deletions src/databricks_ai_bridge/long_running/AGENTS.md

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/databricks_ai_bridge/long_running/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,50 @@ def _set_statement_timeout(dbapi_conn, connection_record, connection_proxy):
await conn.execute(text(f"CREATE SCHEMA IF NOT EXISTS {AGENT_DB_SCHEMA}"))
await conn.run_sync(Base.metadata.create_all)

# Idempotent migration for tables created by earlier versions: add any
# columns introduced for durable-resume support. Each statement runs in
# its own transaction so an InsufficientPrivilege on one ALTER (another
# pod's SP owns the table but the schema is already migrated) doesn't
# poison the rest. A single mega-transaction would abort entirely on the
# first owner-check failure even with IF NOT EXISTS.
migration_stmts = (
f"ALTER TABLE {AGENT_DB_SCHEMA}.responses "
"ADD COLUMN IF NOT EXISTS heartbeat_at TIMESTAMPTZ",
f"ALTER TABLE {AGENT_DB_SCHEMA}.responses "
"ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1",
f"ALTER TABLE {AGENT_DB_SCHEMA}.responses ADD COLUMN IF NOT EXISTS original_request TEXT",
f"ALTER TABLE {AGENT_DB_SCHEMA}.messages "
"ADD COLUMN IF NOT EXISTS attempt_number INTEGER NOT NULL DEFAULT 1",
f"CREATE INDEX IF NOT EXISTS idx_responses_stale "
f"ON {AGENT_DB_SCHEMA}.responses (status, heartbeat_at) "
"WHERE status = 'in_progress'",
)
skipped_migrations: list[str] = []
for stmt in migration_stmts:
try:
async with _engine.begin() as conn:
await conn.execute(text(stmt))
except Exception as exc:
msg = str(exc).lower()
if "insufficientprivilege" in msg or "must be owner" in msg:
skipped_migrations.append(stmt.split("\n")[0])
continue
raise

_initialized = True
if skipped_migrations:
# WARN-level summary: if the DB was previously migrated by another SP
# this is fine, but if it's genuinely a new table and our SP lacks
# ALTER, claim/heartbeat queries will fail later with a confusing
# "column does not exist" — surface it clearly at startup.
logger.warning(
"[DB] Skipped %d durability migration(s) due to insufficient "
"privilege — assuming table was already migrated by another "
"service principal. Crash-resume will fail with 'column does "
"not exist' if this assumption is wrong. Skipped: %s",
len(skipped_migrations),
", ".join(skipped_migrations),
)
logger.info("[DB] Engine and schema ready")


Expand Down
24 changes: 22 additions & 2 deletions src/databricks_ai_bridge/long_running/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ class Base(DeclarativeBase):


class Response(Base):
"""Response status tracking for background agent tasks."""
"""Response status tracking for background agent tasks.

Durability columns (``heartbeat_at``, ``attempt_number``,
``original_request``) support crash-resume: another pod atomically
claims a stale in-progress row by CAS-ing on ``attempt_number`` and
replays the agent loop. The owning pod is implicit — it's whatever
pod last successfully heartbeat at the current attempt_number.
"""

__tablename__ = "responses"
__table_args__ = {"schema": AGENT_DB_SCHEMA}
Expand All @@ -25,12 +32,22 @@ class Response(Base):
DateTime(timezone=True), nullable=False, server_default=func.now()
)
trace_id: Mapped[str | None] = mapped_column(Text, nullable=True)
heartbeat_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
attempt_number: Mapped[int] = mapped_column(
Integer, nullable=False, server_default="1", default=1
)
original_request: Mapped[str | None] = mapped_column(Text, nullable=True)

messages = relationship("Message", back_populates="response", cascade="all, delete-orphan")


class Message(Base):
"""Stream events and output items for a response."""
"""Stream events and output items for a response.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for myself to come back later


``attempt_number`` tags events by which run attempt emitted them so that
resumed runs append to the same event log without overwriting earlier
(abandoned) attempts, and retrieve can filter to the latest attempt only.
"""

__tablename__ = "messages"
__table_args__ = {"schema": AGENT_DB_SCHEMA}
Expand All @@ -44,6 +61,9 @@ class Message(Base):
sequence_number: Mapped[int] = mapped_column(
Integer, primary_key=True, nullable=False, default=0
)
attempt_number: Mapped[int] = mapped_column(
Integer, nullable=False, server_default="1", default=1
)
item: Mapped[str | None] = mapped_column(Text, nullable=True)
stream_event: Mapped[str | None] = mapped_column(Text, nullable=True)

Expand Down
173 changes: 162 additions & 11 deletions src/databricks_ai_bridge/long_running/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,64 @@
from typing import Any, NamedTuple

from sqlalchemy import select, update
from sqlalchemy.sql import bindparam, text

from databricks_ai_bridge.long_running.db import session_scope
from databricks_ai_bridge.long_running.models import Message, Response
from databricks_ai_bridge.long_running.models import AGENT_DB_SCHEMA, Message, Response


async def create_response(response_id: str, status: str) -> None:
"""Insert a new response."""
async def create_response(
response_id: str,
status: str,
*,
durable: bool = False,
original_request: dict[str, Any] | None = None,
) -> None:
"""Insert a new response row.

When ``durable=True``, ``heartbeat_at`` is initialized to ``now()`` so
the row doesn't immediately look stale. Non-durable callers (tests,
legacy flows) skip the heartbeat init.
"""
async with session_scope() as session:
session.add(Response(response_id=response_id, status=status))
session.add(
Response(
response_id=response_id,
status=status,
heartbeat_at=datetime.now().astimezone() if durable else None,
original_request=(
json.dumps(original_request) if original_request is not None else None
),
)
)
await session.commit()


async def update_response_status(
response_id: str, status: str, *, expected_current_status: str | None = None
response_id: str,
status: str,
*,
expected_current_status: str | None = None,
expected_attempt_number: int | None = None,
) -> bool:
"""Update response status. Returns True if a row was updated.

If *expected_current_status* is given the update only takes effect when the
row's current status matches, avoiding concurrent-update races.

If *expected_attempt_number* is given the update only takes effect when the
row's current ``attempt_number`` matches, ensuring only the pod that owns
the current attempt can transition the row to a terminal state. This
prevents a stale background task (e.g. a deferred-fail timer that fired
after another pod claimed the row for resume) from clobbering the new
owner's in-progress state.
"""
async with session_scope() as session:
stmt = update(Response).where(Response.response_id == response_id)
if expected_current_status is not None:
stmt = stmt.where(Response.status == expected_current_status)
if expected_attempt_number is not None:
stmt = stmt.where(Response.attempt_number == expected_attempt_number)
stmt = stmt.values(status=status)
result = await session.execute(stmt)
await session.commit()
Expand All @@ -43,18 +77,120 @@ async def update_response_trace_id(response_id: str, trace_id: str) -> None:
await session.commit()


async def heartbeat_response(response_id: str, expected_attempt_number: int) -> bool:
"""Update heartbeat_at for a response IFF the attempt is still ours.

Returns True on success. A False result means the claim has been lost —
another pod CAS-bumped ``attempt_number``, so this pod is no longer the
owner and the heartbeat task should stop. Implicit-ownership model:
whichever pod last successfully heartbeats at the current
``attempt_number`` is the de facto owner.
"""
async with session_scope() as session:
stmt = (
update(Response)
.where(
Response.response_id == response_id,
Response.attempt_number == expected_attempt_number,
)
.values(heartbeat_at=datetime.now().astimezone())
)
result = await session.execute(stmt)
await session.commit()
return result.rowcount > 0


async def claim_stale_response(
response_id: str,
stale_threshold_seconds: float,
) -> int | None:
"""Atomically claim an in-progress response whose heartbeat has gone stale.

Uses a single conditional UPDATE so exactly one caller wins on contention:
claim only succeeds if status is ``in_progress`` AND
(``heartbeat_at IS NULL`` OR ``heartbeat_at`` is older than the threshold).
The new attempt_number is the previous + 1; the prior attempt's heartbeat
task will detect this on its next heartbeat (rowcount=0) and stop.

Returns the new ``attempt_number`` on success, or ``None`` if the row did
not satisfy the claim conditions (already completed, heartbeat still fresh,
or nonexistent).
"""
# Raw SQL because SQLAlchemy's ORM-level update doesn't expose RETURNING for
# the incremented column as ergonomically. Using a single statement keeps the
# claim atomic without an explicit transaction-level lock.
stmt = text(
f"""
UPDATE {AGENT_DB_SCHEMA}.responses
SET heartbeat_at = now(),
attempt_number = attempt_number + 1
WHERE response_id = :rid
AND status = 'in_progress'
AND (heartbeat_at IS NULL
OR heartbeat_at < now() - make_interval(secs => :threshold))
RETURNING attempt_number
"""
).bindparams(
bindparam("rid", type_=None),
bindparam("threshold", type_=None),
)
async with session_scope() as session:
result = await session.execute(
stmt,
{"rid": response_id, "threshold": stale_threshold_seconds},
)
row = result.first()
await session.commit()
return int(row[0]) if row else None


async def find_stale_response_ids(
stale_threshold_seconds: float,
limit: int = 50,
) -> list[str]:
"""Return ids of in_progress responses whose heartbeat is older than the
threshold. Used by the proactive scanner to find candidates for resume
without waiting for a client GET.

Limited to ``limit`` rows per scan to bound DB load. Ordered by
``heartbeat_at`` ascending so the oldest staleness is handled first.
"""
stmt = text(
f"""
SELECT response_id FROM {AGENT_DB_SCHEMA}.responses
WHERE status = 'in_progress'
AND heartbeat_at IS NOT NULL
AND heartbeat_at < now() - make_interval(secs => :threshold)
ORDER BY heartbeat_at ASC
LIMIT :limit
"""
).bindparams(
bindparam("threshold", type_=None),
bindparam("limit", type_=None),
)
async with session_scope() as session:
result = await session.execute(
stmt,
{"threshold": stale_threshold_seconds, "limit": limit},
)
return [row[0] for row in result.all()]


async def append_message(
response_id: str,
sequence_number: int,
item: str | None = None,
stream_event: dict[str, Any] | None = None,
*,
attempt_number: int = 1,
) -> None:
"""Append a message (stream event) for a response."""
"""Append a message (stream event) for a response, tagged with attempt_number."""
async with session_scope() as session:
session.add(
Message(
response_id=response_id,
sequence_number=sequence_number,
attempt_number=attempt_number,
item=item,
stream_event=json.dumps(stream_event) if stream_event is not None else None,
)
Expand All @@ -65,22 +201,26 @@ async def append_message(
async def get_messages(
response_id: str,
after_sequence: int | None = None,
) -> list[tuple[int, str | None, dict[str, Any] | None]]:
"""Fetch messages for a response, optionally after a sequence number.
*,
attempt_number: int | None = None,
) -> list[tuple[int, str | None, dict[str, Any] | None, int]]:
"""Fetch messages for a response, optionally filtering by sequence / attempt.

Returns list of (sequence_number, item, stream_event_dict).
Returns list of ``(sequence_number, item, stream_event_dict, attempt_number)``.
"""
async with session_scope() as session:
stmt = select(Message).where(Message.response_id == response_id)
if after_sequence is not None:
stmt = stmt.where(Message.sequence_number > after_sequence)
if attempt_number is not None:
stmt = stmt.where(Message.attempt_number == attempt_number)
stmt = stmt.order_by(Message.sequence_number)
result = await session.execute(stmt)
rows = result.scalars().all()
out = []
for r in rows:
evt = json.loads(r.stream_event) if r.stream_event else None
out.append((r.sequence_number, r.item, evt))
out.append((r.sequence_number, r.item, evt, r.attempt_number))
return out


Expand All @@ -89,6 +229,9 @@ class ResponseInfo(NamedTuple):
status: str
created_at: datetime
trace_id: str | None
heartbeat_at: datetime | None
attempt_number: int
original_request: dict[str, Any] | None


async def get_response(response_id: str) -> ResponseInfo | None:
Expand All @@ -97,5 +240,13 @@ async def get_response(response_id: str) -> ResponseInfo | None:
result = await session.execute(select(Response).where(Response.response_id == response_id))
row = result.scalar_one_or_none()
if row:
return ResponseInfo(row.response_id, row.status, row.created_at, row.trace_id)
return ResponseInfo(
row.response_id,
row.status,
row.created_at,
row.trace_id,
row.heartbeat_at,
row.attempt_number,
json.loads(row.original_request) if row.original_request else None,
)
return None
Loading
Loading