Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/tools/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,23 @@

logger = logging.getLogger(__name__)
AGENT_HUMAN_ONLY_PLACEHOLDER = "[human-only content hidden]"

# Per-thread asyncio.Event registry for event-driven msg_wait wake-ups.
# When msg_post succeeds, the corresponding event is set so that all waiters
# on that thread wake up immediately instead of waiting for the 1s poll tick.
# Keys are thread_id strings. Access is safe within a single asyncio event loop
# (SSE mode — single uvicorn process). The 1s timeout on the event wait in
# _poll() ensures correctness even if an event is missed (e.g. during
# hot-module reload).
# NOTE(review): nothing visible here removes entries, so the registry appears
# to keep one Event per thread_id that has ever been waited on — confirm
# whether cleanup is needed for long-running processes.
_thread_events: dict[str, asyncio.Event] = {}


def _get_thread_event(thread_id: str) -> asyncio.Event:
    """Look up the wake-up event registered for ``thread_id``.

    The first request for a given thread creates a fresh ``asyncio.Event``
    and stores it in the module-level ``_thread_events`` registry; every
    later request for the same thread returns that same object.
    """
    try:
        return _thread_events[thread_id]
    except KeyError:
        # First caller for this thread: register a new, unset event.
        registered = _thread_events[thread_id] = asyncio.Event()
        return registered


AGENT_HUMAN_ONLY_METADATA_KEYS = {
"visibility",
"audience",
Expand Down Expand Up @@ -794,6 +811,12 @@ async def handle_msg_post(db, arguments: dict[str, Any]) -> list[types.TextConte
result["handoff_target"] = meta["handoff_target"]
if ENABLE_STOP_REASON and meta.get("stop_reason"):
result["stop_reason"] = meta["stop_reason"]

# Notify any msg_wait callers on this thread that a new message is available.
# This allows event-driven wake-up instead of waiting for the 1s poll tick.
if thread_id in _thread_events:
_thread_events[thread_id].set()

return [types.TextContent(type="text", text=json.dumps(result))]

def _filter_metadata_fields(meta_str: str | None) -> str | None:
Expand Down Expand Up @@ -1162,6 +1185,7 @@ async def _refresh_heartbeat() -> None:
async def _poll():
last_heartbeat = asyncio.get_event_loop().time()
local_after_seq = after_seq
event = _get_thread_event(thread_id)
while True:
raw_msgs = await crud.msg_list(db, thread_id, after_seq=local_after_seq, include_system_prompt=False)
msgs = _project_messages_for_agent(raw_msgs)
Expand Down Expand Up @@ -1208,7 +1232,15 @@ async def _poll():
await _refresh_heartbeat()
last_heartbeat = now

await asyncio.sleep(1.0)
# Event-driven wake-up: wait for msg_post to signal this thread's event.
# The 1s timeout is a fallback: a signal that fires before we start waiting
# (i.e. before the clear() below) is lost, so we re-poll after at most 1s.
# Spurious wake-ups are harmless because the outer while-True loop re-checks
# crud.msg_list after every wake-up.
event.clear()
try:
await asyncio.wait_for(event.wait(), timeout=1.0)
except asyncio.TimeoutError:
pass

try:
msgs = await asyncio.wait_for(_poll(), timeout=timeout_s)
Expand Down
23 changes: 9 additions & 14 deletions tests/test_timeout_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def custom_showwarning(self, message, category, filename, lineno, file=None, lin
@pytest.mark.asyncio
async def test_api_threads_timeout_on_get_db():
"""Test that API returns 503 when get_db() times out."""
with patch("asyncio.wait_for") as mock_wait_for:
with patch("src.main.asyncio.wait_for") as mock_wait_for:
# First call to wait_for (get_db) times out
mock_wait_for.side_effect = asyncio.TimeoutError()

Expand All @@ -98,7 +98,7 @@ async def mock_wait_for_impl(coro, timeout):
else:
raise asyncio.TimeoutError()

with patch("asyncio.wait_for", side_effect=mock_wait_for_impl):
with patch("src.main.asyncio.wait_for", side_effect=mock_wait_for_impl):
try:
await api_threads()
pytest.fail("Expected HTTPException with 503")
Expand All @@ -113,7 +113,7 @@ async def test_api_agents_timeout():
async def mock_wait_for_impl(coro, timeout):
raise asyncio.TimeoutError()

with patch("asyncio.wait_for", side_effect=mock_wait_for_impl):
with patch("src.main.asyncio.wait_for", side_effect=mock_wait_for_impl):
try:
await api_agents()
pytest.fail("Expected HTTPException with 503")
Expand All @@ -129,7 +129,6 @@ async def mock_wait_for_impl(coro, timeout):
@pytest.mark.asyncio
async def test_api_threads_success():
"""Test successful thread listing with no timeout."""
mock_db = AsyncMock()
import datetime
now = datetime.datetime.now()

Expand All @@ -145,16 +144,12 @@ async def test_api_threads_success():
)
]

async def mock_wait_for_get_db(coro, timeout):
return mock_db

async def mock_gather(*coros):
return (mock_threads, len(mock_threads))
mock_db = AsyncMock()

with patch("asyncio.wait_for", side_effect=mock_wait_for_get_db), \
patch("asyncio.gather", side_effect=mock_gather):
# Since api_threads is an async function that returns an envelope dict,
# we need to test the actual return value
with patch("src.main.get_db", return_value=mock_db), \
patch("src.main.crud.thread_list", new=AsyncMock(return_value=mock_threads)), \
patch("src.main.crud.thread_count", new=AsyncMock(return_value=len(mock_threads))), \
patch("src.main.crud.threads_agents_map", new=AsyncMock(return_value={})):
result = await api_threads()

# Verify result is an envelope dict with expected structure (UP-20)
Expand Down Expand Up @@ -195,7 +190,7 @@ async def mock_wait_for_impl(coro, timeout):
# Return mock_agents for agent_list calls
return mock_agents

with patch("asyncio.wait_for", side_effect=mock_wait_for_impl):
with patch("src.main.asyncio.wait_for", side_effect=mock_wait_for_impl):
result = await api_agents()

assert isinstance(result, list)
Expand Down