diff --git a/src/tools/dispatch.py b/src/tools/dispatch.py index 68e8cd7..eb70751 100644 --- a/src/tools/dispatch.py +++ b/src/tools/dispatch.py @@ -48,6 +48,23 @@ logger = logging.getLogger(__name__) AGENT_HUMAN_ONLY_PLACEHOLDER = "[human-only content hidden]" + +# Per-thread asyncio.Event registry for event-driven msg_wait wake-ups. +# When msg_post succeeds, the corresponding event is set so that all waiters +# on that thread wake up immediately instead of waiting for the 1s poll tick. +# Keys are thread_id strings. Access is safe within a single asyncio event loop +# (SSE mode — single uvicorn process). The 1s asyncio.sleep fallback in _poll() +# ensures correctness even if an event is missed (e.g. during hot-module reload). +_thread_events: dict[str, asyncio.Event] = {} + + +def _get_thread_event(thread_id: str) -> asyncio.Event: + """Return (or create) the asyncio.Event for a given thread_id.""" + if thread_id not in _thread_events: + _thread_events[thread_id] = asyncio.Event() + return _thread_events[thread_id] + + AGENT_HUMAN_ONLY_METADATA_KEYS = { "visibility", "audience", @@ -794,6 +811,12 @@ async def handle_msg_post(db, arguments: dict[str, Any]) -> list[types.TextConte result["handoff_target"] = meta["handoff_target"] if ENABLE_STOP_REASON and meta.get("stop_reason"): result["stop_reason"] = meta["stop_reason"] + + # Notify any msg_wait callers on this thread that a new message is available. + # This allows event-driven wake-up instead of waiting for the 1s poll tick. + if thread_id in _thread_events: + _thread_events[thread_id].set() + return [types.TextContent(type="text", text=json.dumps(result))] def _filter_metadata_fields(meta_str: str | None) -> str | None: @@ -1162,6 +1185,7 @@ async def _refresh_heartbeat() -> None: async def _poll(): last_heartbeat = asyncio.get_event_loop().time() local_after_seq = after_seq + event = _get_thread_event(thread_id) while True: raw_msgs = await crud.msg_list(db, thread_id, after_seq=local_after_seq, include_system_prompt=False) msgs = _project_messages_for_agent(raw_msgs) @@ -1208,7 +1232,15 @@ async def _poll(): await _refresh_heartbeat() last_heartbeat = now - await asyncio.sleep(1.0) + # Event-driven wake-up: wait for msg_post to signal this thread's event. + # Falls back to a 1s timeout so correctness is preserved even if the + # event fires before we start waiting (spurious-wakeup-safe: the outer + # while-True loop re-checks crud.msg_list after every wake-up). + event.clear() + try: + await asyncio.wait_for(event.wait(), timeout=1.0) + except asyncio.TimeoutError: + pass try: msgs = await asyncio.wait_for(_poll(), timeout=timeout_s) diff --git a/tests/test_timeout_handling.py b/tests/test_timeout_handling.py index 56b6dc3..60d666b 100644 --- a/tests/test_timeout_handling.py +++ b/tests/test_timeout_handling.py @@ -73,7 +73,7 @@ def custom_showwarning(self, message, category, filename, lineno, file=None, lin @pytest.mark.asyncio async def test_api_threads_timeout_on_get_db(): """Test that API returns 503 when get_db() times out.""" - with patch("asyncio.wait_for") as mock_wait_for: + with patch("src.main.asyncio.wait_for") as mock_wait_for: # First call to wait_for (get_db) times out mock_wait_for.side_effect = asyncio.TimeoutError() @@ -98,7 +98,7 @@ async def mock_wait_for_impl(coro, timeout): else: raise asyncio.TimeoutError() - with patch("asyncio.wait_for", side_effect=mock_wait_for_impl): + with patch("src.main.asyncio.wait_for", side_effect=mock_wait_for_impl): try: await api_threads() pytest.fail("Expected HTTPException with 503") @@ -113,7 +113,7 @@ async def test_api_agents_timeout(): async def mock_wait_for_impl(coro, timeout): raise asyncio.TimeoutError() - with patch("asyncio.wait_for", side_effect=mock_wait_for_impl): + with patch("src.main.asyncio.wait_for", side_effect=mock_wait_for_impl): try: await api_agents() pytest.fail("Expected HTTPException with 503") @@ -129,7 +129,6 @@ async def mock_wait_for_impl(coro, timeout): @pytest.mark.asyncio async def test_api_threads_success(): """Test successful thread listing with no timeout.""" - mock_db = AsyncMock() import datetime now = datetime.datetime.now() @@ -145,16 +144,12 @@ async def test_api_threads_success(): ) ] - async def mock_wait_for_get_db(coro, timeout): - return mock_db - - async def mock_gather(*coros): - return (mock_threads, len(mock_threads)) + mock_db = AsyncMock() - with patch("asyncio.wait_for", side_effect=mock_wait_for_get_db), \ - patch("asyncio.gather", side_effect=mock_gather): - # Since api_threads is an async function that returns an envelope dict, - # we need to test the actual return value + with patch("src.main.get_db", return_value=mock_db), \ + patch("src.main.crud.thread_list", new=AsyncMock(return_value=mock_threads)), \ + patch("src.main.crud.thread_count", new=AsyncMock(return_value=len(mock_threads))), \ + patch("src.main.crud.threads_agents_map", new=AsyncMock(return_value={})): result = await api_threads() # Verify result is an envelope dict with expected structure (UP-20) @@ -195,7 +190,7 @@ async def mock_wait_for_impl(coro, timeout): # Return mock_agents for agent_list calls return mock_agents - with patch("asyncio.wait_for", side_effect=mock_wait_for_impl): + with patch("src.main.asyncio.wait_for", side_effect=mock_wait_for_impl): result = await api_agents() assert isinstance(result, list)