diff --git a/README.md b/README.md index 16421b8..484accf 100644 --- a/README.md +++ b/README.md @@ -185,6 +185,7 @@ print(f"Replay with: agent-strace replay {meta.session_id}") | `inflation` | Token inflation across model versions | | `curve` | Personal cost-efficiency curve | | `a2a-tree` | Cross-agent trace correlation (A2A protocol) | +| `mcp` | MCP server — expose traces as queryable tools for a debugging agent | ``` agent-strace setup [--redact] [--global] Generate Claude Code hooks config @@ -1204,6 +1205,75 @@ agent-strace export --format otlp > trace.json | event_id | span ID | | parent_id | parent span ID | +## Debug with MCP + +`agent-strace mcp` starts an MCP server that exposes your session store as queryable tools. Any MCP-compatible client (Claude Code, Cursor, VS Code Copilot) can then query traces conversationally — the debugging agent reads its own execution history and surfaces what went wrong. + +```bash +agent-strace mcp +``` + +**Claude Code config** (`.claude/settings.json`): + +```json +{ + "mcpServers": { + "agent-trace": { + "command": "agent-strace", + "args": ["mcp"] + } + } +} +``` + +**Cursor config** (`.cursor/mcp.json`): + +```json +{ + "mcpServers": { + "agent-trace": { + "command": "agent-strace", + "args": ["mcp"] + } + } +} +``` + +Once connected, you can ask the debugging agent questions like: + +> "Look at the most recent session and tell me why it called bash three times in a row." +> "Which files did the agent write in session abc123 that it didn't write in def456?" +> "Find all sessions where the agent hit an error after calling npm test." 
+ +### MCP tools + +| Tool | Description | +|---|---| +| `list_sessions` | List captured sessions with metadata (timestamp, tool calls, cost, tokens) | +| `get_session` | Full event stream for a session, with optional event type filter | +| `search_events` | Filter events by tool name, file path, exit code, or error flag across sessions | +| `get_session_summary` | Plain-English phase breakdown — what the agent did, files touched, retries | +| `diff_sessions` | Compare two sessions: tool call delta, file overlap, cost delta, error delta | + +### Example interactions + +``` +# List recent sessions +list_sessions(limit=5) + +# Get all errors from a session +search_events(session_id="abc123", has_error=true) + +# Find all sessions where the agent wrote to package-lock.json +search_events(file_path="package-lock.json") + +# Compare two sessions after changing AGENTS.md +diff_sessions(session_a="before_change", session_b="after_change") + +# Get a plain-English summary of what went wrong +get_session_summary(session_id="abc123") +``` + ## How it works ### Claude Code hooks diff --git a/examples/ci/agent-eval.yml b/examples/ci/agent-eval.yml new file mode 100644 index 0000000..32882e7 --- /dev/null +++ b/examples/ci/agent-eval.yml @@ -0,0 +1,100 @@ +# Agent eval CI workflow +# +# Runs eval scorers on every PR that touches agent config files. +# Fails the PR if any scorer drops below its threshold. +# Posts a score summary as a PR comment. +# +# Prerequisites: +# 1. Capture at least one session: agent-strace record -- +# 2. Save a baseline: agent-strace eval ci --save-baseline .agent-traces/baselines/main.json +# 3. 
#      Commit .agent-evals.yaml and .agent-traces/baselines/main.json to the repo

name: Agent eval

on:
  pull_request:
    paths:
      - "AGENTS.md"
      - "CLAUDE.md"
      - ".claude/**"
      - ".agent-evals.yaml"
      - ".agent-traces/datasets/**"
  # The update-baseline job is gated on `github.event_name == 'push'`; without
  # this trigger the workflow never starts on a push and that job cannot run.
  push:
    branches:
      - main

# Workflow-wide default. Repositories created with a read-only GITHUB_TOKEN
# would otherwise reject both the PR comment and the baseline push.
permissions:
  contents: write        # update-baseline commits the refreshed baseline
  pull-requests: write   # eval posts the score summary as a PR comment

jobs:
  eval:
    # Scoring and commenting only make sense in a pull-request context
    # (context.issue.number is undefined on push events).
    if: github.event_name == 'pull_request'
    runs-on: ubuntu-latest
    # Least privilege for the job that runs on untrusted PR content.
    permissions:
      contents: read
      pull-requests: write

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"

      - name: Install agent-strace
        run: pip install agent-strace

      # Score the latest session in the dataset against all configured scorers.
      # Exits 1 if any scorer is below threshold or regresses vs baseline.
      - name: Run eval
        env:
          # Required only if using the llm_judge scorer.
          # Remove if using heuristic scorers only (no_errors, cost_under, etc.)
          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
        run: |
          agent-strace eval ci \
            --baseline .agent-traces/baselines/main.json \
            --tolerance 0.05 \
            --github-summary

      # Post the Markdown summary as a PR comment so reviewers see the score delta.
      # `always()` keeps the comment step running when the eval step fails.
      - name: Post eval summary
        if: always()
        uses: actions/github-script@v7
        with:
          script: |
            const fs = require('fs');
            const summaryPath = '.agent-traces/eval-summary.md';
            if (!fs.existsSync(summaryPath)) {
              console.log('No eval summary found — skipping comment.');
              return;
            }
            const summary = fs.readFileSync(summaryPath, 'utf8');
            await github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: summary,
            });

  # Optional: update the baseline on every merge to main.
  # Commit the updated baseline back to the repo so future PRs compare against it.
+ update-baseline: + if: github.event_name == 'push' && github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install agent-strace + run: pip install agent-strace + + - name: Save new baseline + run: | + mkdir -p .agent-traces/baselines + agent-strace eval ci \ + --save-baseline .agent-traces/baselines/main.json + + - name: Commit updated baseline + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + git add .agent-traces/baselines/main.json + git diff --staged --quiet || git commit -m "chore: update eval baseline [skip ci]" + git push diff --git a/src/agent_trace/cli.py b/src/agent_trace/cli.py index 29721aa..7bf7b5b 100644 --- a/src/agent_trace/cli.py +++ b/src/agent_trace/cli.py @@ -23,6 +23,7 @@ from .hooks import hook_main from .http_proxy import HTTPProxyServer from .a2a import cmd_a2a_tree +from .mcp_server import cmd_mcp from .annotate import cmd_annotate from .drift import cmd_drift from .langfuse_export import cmd_export_scores @@ -747,6 +748,18 @@ def build_parser() -> argparse.ArgumentParser: p_standup.add_argument("--no-llm", action="store_true", dest="no_llm", help="structured output only, no LLM narrative (default)") + # mcp (MCP server — expose traces as queryable tools) + p_mcp = sub.add_parser( + "mcp", + help="start an MCP server that exposes session traces as queryable tools", + ) + p_mcp.add_argument( + "--transport", + choices=["stdio"], + default="stdio", + help="transport protocol (default: stdio)", + ) + # diff --semantic and --eval-config flags (extend existing diff parser) p_diff.add_argument("--semantic", action="store_true", help="semantic outcome-level diff (files, cost, errors)") @@ -806,6 +819,7 @@ def main() -> None: "oncall": cmd_oncall, "freshness": cmd_freshness, "standup": cmd_standup, + "mcp": cmd_mcp, } handler 
= handlers.get(args.command) diff --git a/src/agent_trace/mcp_server.py b/src/agent_trace/mcp_server.py new file mode 100644 index 0000000..579a588 --- /dev/null +++ b/src/agent_trace/mcp_server.py @@ -0,0 +1,499 @@ +"""MCP server — expose agent-trace session store as queryable MCP tools. + +Implements the Model Context Protocol over stdio (JSON-RPC 2.0). +No external dependencies; uses only stdlib. + +Tools exposed: + list_sessions — list captured sessions with metadata + get_session — full event stream for a session + search_events — filter events by tool, file path, exit code, or time range + get_session_summary — plain-English phase breakdown (wraps explain_session) + diff_sessions — compare two sessions: what changed between runs + +Usage: + agent-strace mcp # stdio transport (default) + agent-strace mcp --trace-dir DIR # custom trace directory + +Claude Code config (.claude/settings.json): + { + "mcpServers": { + "agent-trace": { + "command": "agent-strace", + "args": ["mcp"] + } + } + } +""" + +from __future__ import annotations + +import argparse +import io +import json +import sys +from typing import Any + +from .explain import explain_session, format_explain +from .models import EventType, TraceEvent +from .store import TraceStore + + +# --------------------------------------------------------------------------- +# Tool schemas +# --------------------------------------------------------------------------- + +_TOOLS: list[dict] = [ + { + "name": "list_sessions", + "description": ( + "List captured agent sessions with metadata: session ID, start time, " + "tool call count, LLM requests, errors, total tokens, and estimated cost." 
+ ), + "inputSchema": { + "type": "object", + "properties": { + "limit": { + "type": "integer", + "description": "Maximum number of sessions to return (default: 20).", + "default": 20, + }, + "agent": { + "type": "string", + "description": "Filter by agent name (substring match).", + "default": "", + }, + }, + }, + }, + { + "name": "get_session", + "description": ( + "Return the full event stream for a session as structured JSON. " + "Includes every tool call, LLM request, file read/write, and error." + ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": { + "type": "string", + "description": "Session ID or unique prefix.", + }, + "event_types": { + "type": "array", + "items": {"type": "string"}, + "description": ( + "Filter to these event types only. " + "Valid values: tool_call, tool_result, llm_request, llm_response, " + "file_read, file_write, error, user_prompt, assistant_response. " + "Empty list returns all events." + ), + "default": [], + }, + }, + "required": ["session_id"], + }, + }, + { + "name": "search_events", + "description": ( + "Search events across one or all sessions. Filter by tool name, " + "file path substring, exit code, or time range." + ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": { + "type": "string", + "description": "Session ID or prefix. 
Omit to search all sessions.", + "default": "", + }, + "tool_name": { + "type": "string", + "description": "Filter tool_call events by tool name (case-insensitive substring).", + "default": "", + }, + "file_path": { + "type": "string", + "description": "Filter file_read/file_write events by path substring.", + "default": "", + }, + "exit_code": { + "type": "integer", + "description": "Filter tool_result events by exit code.", + }, + "has_error": { + "type": "boolean", + "description": "If true, return only error events.", + "default": False, + }, + "limit": { + "type": "integer", + "description": "Maximum events to return (default: 50).", + "default": 50, + }, + }, + }, + }, + { + "name": "get_session_summary", + "description": ( + "Return a plain-English summary of what the agent did in a session: " + "phases, files touched, commands run, retries, and wasted time." + ), + "inputSchema": { + "type": "object", + "properties": { + "session_id": { + "type": "string", + "description": "Session ID or unique prefix.", + }, + }, + "required": ["session_id"], + }, + }, + { + "name": "diff_sessions", + "description": ( + "Compare two sessions and return a structured diff: " + "which tools were added/removed, file overlap, cost delta, " + "error delta, and token delta." 
+ ), + "inputSchema": { + "type": "object", + "properties": { + "session_a": { + "type": "string", + "description": "First session ID or prefix (the 'before' session).", + }, + "session_b": { + "type": "string", + "description": "Second session ID or prefix (the 'after' session).", + }, + }, + "required": ["session_a", "session_b"], + }, + }, +] + + +# --------------------------------------------------------------------------- +# Tool implementations +# --------------------------------------------------------------------------- + +_COST_PER_1K = 0.003 # rough blended cost estimate + + +def _session_to_dict(meta) -> dict: + cost = meta.total_tokens / 1_000 * _COST_PER_1K + return { + "session_id": meta.session_id, + "started_at": meta.started_at, + "agent_name": meta.agent_name or "", + "command": meta.command or "", + "tool_calls": meta.tool_calls, + "llm_requests": meta.llm_requests, + "errors": meta.errors, + "total_tokens": meta.total_tokens, + "total_duration_ms": meta.total_duration_ms, + "estimated_cost_usd": round(cost, 4), + } + + +def _event_to_dict(ev: TraceEvent) -> dict: + return { + "event_id": ev.event_id, + "event_type": ev.event_type.value, + "timestamp": ev.timestamp, + "session_id": ev.session_id or "", + "parent_id": ev.parent_id or "", + "duration_ms": ev.duration_ms, + "data": ev.data, + } + + +def _tool_list_sessions(store: TraceStore, args: dict) -> str: + limit = int(args.get("limit") or 20) + agent_filter = str(args.get("agent") or "").lower() + sessions = store.list_sessions() + if agent_filter: + sessions = [s for s in sessions if agent_filter in (s.agent_name or "").lower()] + sessions = sessions[:limit] + result = [_session_to_dict(s) for s in sessions] + return json.dumps({"sessions": result, "count": len(result)}, indent=2) + + +def _tool_get_session(store: TraceStore, args: dict) -> str: + session_id = str(args.get("session_id") or "") + if not session_id: + return json.dumps({"error": "session_id is required"}) + full_id = 
store.find_session(session_id) + if not full_id: + return json.dumps({"error": f"session not found: {session_id}"}) + + meta = store.load_meta(full_id) + events = store.load_events(full_id) + + type_filter: list[str] = [t.lower() for t in (args.get("event_types") or [])] + if type_filter: + events = [e for e in events if e.event_type.value in type_filter] + + return json.dumps({ + "session": _session_to_dict(meta), + "events": [_event_to_dict(e) for e in events], + "event_count": len(events), + }, indent=2) + + +def _tool_search_events(store: TraceStore, args: dict) -> str: + session_id = str(args.get("session_id") or "") + tool_name = str(args.get("tool_name") or "").lower() + file_path = str(args.get("file_path") or "").lower() + exit_code = args.get("exit_code") + has_error = bool(args.get("has_error")) + limit = int(args.get("limit") or 50) + + if session_id: + full_id = store.find_session(session_id) + if not full_id: + return json.dumps({"error": f"session not found: {session_id}"}) + session_ids = [full_id] + else: + sessions = store.list_sessions() + session_ids = [s.session_id for s in sessions] + + matches: list[dict] = [] + for sid in session_ids: + try: + events = store.load_events(sid) + except Exception: + continue + for ev in events: + if has_error and ev.event_type != EventType.ERROR: + continue + if tool_name and ev.event_type == EventType.TOOL_CALL: + if tool_name not in str(ev.data.get("tool_name", "")).lower(): + continue + elif tool_name: + continue + if file_path: + if ev.event_type not in (EventType.FILE_READ, EventType.FILE_WRITE): + continue + path = str(ev.data.get("path", ev.data.get("file_path", ""))).lower() + if file_path not in path: + continue + if exit_code is not None: + if ev.event_type != EventType.TOOL_RESULT: + continue + if ev.data.get("exit_code") != exit_code: + continue + d = _event_to_dict(ev) + d["_session_id"] = sid + matches.append(d) + if len(matches) >= limit: + break + if len(matches) >= limit: + break + + return 
json.dumps({"events": matches, "count": len(matches)}, indent=2) + + +def _tool_get_session_summary(store: TraceStore, args: dict) -> str: + session_id = str(args.get("session_id") or "") + if not session_id: + return json.dumps({"error": "session_id is required"}) + full_id = store.find_session(session_id) + if not full_id: + return json.dumps({"error": f"session not found: {session_id}"}) + + result = explain_session(store, full_id) + buf = io.StringIO() + format_explain(result, out=buf) + return buf.getvalue() + + +def _tool_diff_sessions(store: TraceStore, args: dict) -> str: + sid_a = str(args.get("session_a") or "") + sid_b = str(args.get("session_b") or "") + if not sid_a or not sid_b: + return json.dumps({"error": "session_a and session_b are required"}) + + full_a = store.find_session(sid_a) + full_b = store.find_session(sid_b) + if not full_a: + return json.dumps({"error": f"session not found: {sid_a}"}) + if not full_b: + return json.dumps({"error": f"session not found: {sid_b}"}) + + meta_a = store.load_meta(full_a) + meta_b = store.load_meta(full_b) + events_a = store.load_events(full_a) + events_b = store.load_events(full_b) + + def _tools(events: list[TraceEvent]) -> dict[str, int]: + counts: dict[str, int] = {} + for e in events: + if e.event_type == EventType.TOOL_CALL: + name = e.data.get("tool_name", "unknown") + counts[name] = counts.get(name, 0) + 1 + return counts + + def _files(events: list[TraceEvent]) -> set[str]: + paths: set[str] = set() + for e in events: + if e.event_type in (EventType.FILE_READ, EventType.FILE_WRITE): + p = e.data.get("path", e.data.get("file_path", "")) + if p: + paths.add(p) + return paths + + def _errors(events: list[TraceEvent]) -> list[str]: + return [e.data.get("message", "") for e in events if e.event_type == EventType.ERROR] + + tools_a = _tools(events_a) + tools_b = _tools(events_b) + files_a = _files(events_a) + files_b = _files(events_b) + errors_a = _errors(events_a) + errors_b = _errors(events_b) + + 
all_tools = set(tools_a) | set(tools_b) + tool_diff = { + t: {"session_a": tools_a.get(t, 0), "session_b": tools_b.get(t, 0)} + for t in sorted(all_tools) + if tools_a.get(t, 0) != tools_b.get(t, 0) + } + + cost_a = meta_a.total_tokens / 1_000 * _COST_PER_1K + cost_b = meta_b.total_tokens / 1_000 * _COST_PER_1K + + return json.dumps({ + "session_a": full_a, + "session_b": full_b, + "tool_call_diff": tool_diff, + "files_only_in_a": sorted(files_a - files_b), + "files_only_in_b": sorted(files_b - files_a), + "files_in_both": sorted(files_a & files_b), + "token_delta": meta_b.total_tokens - meta_a.total_tokens, + "cost_delta_usd": round(cost_b - cost_a, 4), + "error_count_a": len(errors_a), + "error_count_b": len(errors_b), + "duration_delta_ms": meta_b.total_duration_ms - meta_a.total_duration_ms, + "errors_a": errors_a[:10], + "errors_b": errors_b[:10], + }, indent=2) + + +# --------------------------------------------------------------------------- +# JSON-RPC 2.0 dispatcher +# --------------------------------------------------------------------------- + +def _ok(req_id: Any, result: Any) -> dict: + return {"jsonrpc": "2.0", "id": req_id, "result": result} + + +def _err(req_id: Any, code: int, message: str) -> dict: + return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}} + + +def _handle(store: TraceStore, request: dict) -> dict | None: + req_id = request.get("id") + method = request.get("method", "") + params = request.get("params") or {} + + # Notifications (no id) — no response required + if req_id is None and method not in ("initialize",): + return None + + # MCP lifecycle + if method == "initialize": + return _ok(req_id, { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "agent-trace", "version": "0.38.0"}, + }) + + if method == "notifications/initialized": + return None + + if method == "tools/list": + return _ok(req_id, {"tools": _TOOLS}) + + if method == "tools/call": + name = 
params.get("name", "") + tool_args = params.get("arguments") or {} + + dispatch = { + "list_sessions": _tool_list_sessions, + "get_session": _tool_get_session, + "search_events": _tool_search_events, + "get_session_summary": _tool_get_session_summary, + "diff_sessions": _tool_diff_sessions, + } + fn = dispatch.get(name) + if fn is None: + return _err(req_id, -32601, f"unknown tool: {name}") + + try: + text = fn(store, tool_args) + except Exception as exc: + return _err(req_id, -32603, str(exc)) + + return _ok(req_id, { + "content": [{"type": "text", "text": text}], + "isError": False, + }) + + if method == "ping": + return _ok(req_id, {}) + + return _err(req_id, -32601, f"method not found: {method}") + + +# --------------------------------------------------------------------------- +# Stdio transport loop +# --------------------------------------------------------------------------- + +def run_stdio(store: TraceStore) -> None: + """Read JSON-RPC requests from stdin, write responses to stdout.""" + stdin = sys.stdin + stdout = sys.stdout + + # Use binary mode for reliable line reading across platforms + if hasattr(stdin, "buffer"): + reader = io.TextIOWrapper(stdin.buffer, encoding="utf-8", newline="\n") + else: + reader = stdin + + for raw_line in reader: + raw_line = raw_line.strip() + if not raw_line: + continue + try: + request = json.loads(raw_line) + except json.JSONDecodeError as exc: + response = _err(None, -32700, f"parse error: {exc}") + stdout.write(json.dumps(response) + "\n") + stdout.flush() + continue + + response = _handle(store, request) + if response is not None: + stdout.write(json.dumps(response) + "\n") + stdout.flush() + + +# --------------------------------------------------------------------------- +# CLI handler +# --------------------------------------------------------------------------- + +def cmd_mcp(args: argparse.Namespace) -> int: + store = TraceStore(args.trace_dir) + sys.stderr.write( + f"agent-trace MCP server started 
(trace_dir={args.trace_dir})\n" + "Waiting for JSON-RPC requests on stdin...\n" + ) + try: + run_stdio(store) + except KeyboardInterrupt: + pass + return 0 diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py new file mode 100644 index 0000000..151e673 --- /dev/null +++ b/tests/test_mcp_server.py @@ -0,0 +1,388 @@ +"""Tests for the MCP server (agent-strace mcp).""" + +import json +import tempfile +import time +import unittest + +from agent_trace.mcp_server import ( + _handle, + _tool_diff_sessions, + _tool_get_session, + _tool_get_session_summary, + _tool_list_sessions, + _tool_search_events, +) +from agent_trace.models import EventType, SessionMeta, TraceEvent +from agent_trace.store import TraceStore + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_store() -> TraceStore: + return TraceStore(tempfile.mkdtemp()) + + +def _add_session( + store: TraceStore, + session_id: str, + events: list[TraceEvent] | None = None, + agent_name: str = "", + total_tokens: int = 1000, + total_duration_ms: float = 60_000, +) -> SessionMeta: + ts = time.time() + meta = SessionMeta( + session_id=session_id, + started_at=ts, + ended_at=ts + 60, + agent_name=agent_name, + total_tokens=total_tokens, + total_duration_ms=total_duration_ms, + tool_calls=sum(1 for e in (events or []) if e.event_type == EventType.TOOL_CALL), + errors=sum(1 for e in (events or []) if e.event_type == EventType.ERROR), + ) + store.create_session(meta) + for ev in (events or []): + store.append_event(session_id, ev) + return meta + + +def _tool_call(name: str, ts: float = 0.0) -> TraceEvent: + return TraceEvent(event_type=EventType.TOOL_CALL, timestamp=ts, + data={"tool_name": name, "arguments": {}}) + + +def _file_write(path: str, ts: float = 0.0) -> TraceEvent: + return TraceEvent(event_type=EventType.FILE_WRITE, timestamp=ts, data={"path": path}) + + +def 
_file_read(path: str, ts: float = 0.0) -> TraceEvent: + return TraceEvent(event_type=EventType.FILE_READ, timestamp=ts, data={"path": path}) + + +def _error(msg: str = "fail", ts: float = 0.0) -> TraceEvent: + return TraceEvent(event_type=EventType.ERROR, timestamp=ts, data={"message": msg}) + + +def _tool_result(exit_code: int = 0, ts: float = 0.0) -> TraceEvent: + return TraceEvent(event_type=EventType.TOOL_RESULT, timestamp=ts, + data={"exit_code": exit_code}) + + +def _rpc(method: str, params: dict, req_id: int = 1) -> dict: + return {"jsonrpc": "2.0", "id": req_id, "method": method, "params": params} + + +# --------------------------------------------------------------------------- +# JSON-RPC lifecycle +# --------------------------------------------------------------------------- + +class TestMcpLifecycle(unittest.TestCase): + def setUp(self): + self.store = _make_store() + + def test_initialize_returns_server_info(self): + r = _handle(self.store, _rpc("initialize", {})) + self.assertEqual(r["result"]["serverInfo"]["name"], "agent-trace") + self.assertIn("protocolVersion", r["result"]) + + def test_tools_list_returns_five_tools(self): + r = _handle(self.store, _rpc("tools/list", {})) + names = [t["name"] for t in r["result"]["tools"]] + self.assertIn("list_sessions", names) + self.assertIn("get_session", names) + self.assertIn("search_events", names) + self.assertIn("get_session_summary", names) + self.assertIn("diff_sessions", names) + self.assertEqual(len(names), 5) + + def test_unknown_method_returns_error(self): + r = _handle(self.store, _rpc("unknown/method", {})) + self.assertIn("error", r) + self.assertEqual(r["error"]["code"], -32601) + + def test_notification_returns_none(self): + req = {"jsonrpc": "2.0", "method": "notifications/initialized", "params": {}} + r = _handle(self.store, req) + self.assertIsNone(r) + + def test_ping_returns_empty_result(self): + r = _handle(self.store, _rpc("ping", {})) + self.assertEqual(r["result"], {}) + + def 
test_unknown_tool_returns_error(self): + r = _handle(self.store, _rpc("tools/call", {"name": "nonexistent", "arguments": {}})) + self.assertIn("error", r) + + def test_malformed_json_handled_gracefully(self): + # _handle itself doesn't parse JSON — that's run_stdio's job + # but we can verify a missing method returns an error + r = _handle(self.store, {"jsonrpc": "2.0", "id": 1}) + self.assertIn("error", r) + + +# --------------------------------------------------------------------------- +# list_sessions +# --------------------------------------------------------------------------- + +class TestListSessions(unittest.TestCase): + def setUp(self): + self.store = _make_store() + + def test_empty_store(self): + result = json.loads(_tool_list_sessions(self.store, {})) + self.assertEqual(result["count"], 0) + self.assertEqual(result["sessions"], []) + + def test_returns_sessions(self): + _add_session(self.store, "sess1", agent_name="claude") + _add_session(self.store, "sess2", agent_name="cursor") + result = json.loads(_tool_list_sessions(self.store, {})) + self.assertEqual(result["count"], 2) + + def test_limit_respected(self): + for i in range(5): + _add_session(self.store, f"sess{i}") + result = json.loads(_tool_list_sessions(self.store, {"limit": 2})) + self.assertEqual(result["count"], 2) + + def test_agent_filter(self): + _add_session(self.store, "s1", agent_name="claude") + _add_session(self.store, "s2", agent_name="cursor") + result = json.loads(_tool_list_sessions(self.store, {"agent": "claude"})) + self.assertEqual(result["count"], 1) + self.assertEqual(result["sessions"][0]["agent_name"], "claude") + + def test_session_has_cost_field(self): + _add_session(self.store, "s1", total_tokens=1000) + result = json.loads(_tool_list_sessions(self.store, {})) + self.assertIn("estimated_cost_usd", result["sessions"][0]) + + +# --------------------------------------------------------------------------- +# get_session +# 
--------------------------------------------------------------------------- + +class TestGetSession(unittest.TestCase): + def setUp(self): + self.store = _make_store() + _add_session(self.store, "abc123", events=[ + _tool_call("Bash"), _file_write("src/main.py"), _error("oops"), + ]) + + def test_returns_events(self): + result = json.loads(_tool_get_session(self.store, {"session_id": "abc123"})) + self.assertEqual(result["event_count"], 3) + + def test_prefix_match(self): + result = json.loads(_tool_get_session(self.store, {"session_id": "abc"})) + self.assertEqual(result["event_count"], 3) + + def test_event_type_filter(self): + result = json.loads(_tool_get_session(self.store, { + "session_id": "abc123", + "event_types": ["tool_call"], + })) + self.assertEqual(result["event_count"], 1) + self.assertEqual(result["events"][0]["event_type"], "tool_call") + + def test_not_found_returns_error(self): + result = json.loads(_tool_get_session(self.store, {"session_id": "zzz"})) + self.assertIn("error", result) + + def test_missing_session_id_returns_error(self): + result = json.loads(_tool_get_session(self.store, {})) + self.assertIn("error", result) + + def test_session_metadata_included(self): + result = json.loads(_tool_get_session(self.store, {"session_id": "abc123"})) + self.assertIn("session", result) + self.assertEqual(result["session"]["session_id"], "abc123") + + +# --------------------------------------------------------------------------- +# search_events +# --------------------------------------------------------------------------- + +class TestSearchEvents(unittest.TestCase): + def setUp(self): + self.store = _make_store() + _add_session(self.store, "s1", events=[ + _tool_call("Bash"), _tool_call("Read"), + _file_write("src/app.py"), _file_read("README.md"), + _error("something failed"), + _tool_result(exit_code=1), + ]) + _add_session(self.store, "s2", events=[ + _tool_call("Write"), _file_write("tests/test_foo.py"), + ]) + + def 
test_filter_by_tool_name(self): + result = json.loads(_tool_search_events(self.store, {"tool_name": "bash"})) + self.assertEqual(result["count"], 1) + self.assertEqual(result["events"][0]["data"]["tool_name"], "Bash") + + def test_filter_by_file_path(self): + result = json.loads(_tool_search_events(self.store, {"file_path": "src/"})) + self.assertEqual(result["count"], 1) + + def test_filter_has_error(self): + result = json.loads(_tool_search_events(self.store, {"has_error": True})) + self.assertEqual(result["count"], 1) + self.assertEqual(result["events"][0]["event_type"], "error") + + def test_filter_by_exit_code(self): + result = json.loads(_tool_search_events(self.store, {"exit_code": 1})) + self.assertEqual(result["count"], 1) + + def test_scoped_to_session(self): + result = json.loads(_tool_search_events(self.store, { + "session_id": "s1", "tool_name": "bash", + })) + self.assertEqual(result["count"], 1) + + def test_cross_session_search(self): + # src/app.py (file_write s1) + tests/test_foo.py (file_write s2) = 2 + result = json.loads(_tool_search_events(self.store, {"file_path": ".py"})) + self.assertEqual(result["count"], 2) + + def test_limit_respected(self): + result = json.loads(_tool_search_events(self.store, {"has_error": True, "limit": 1})) + self.assertLessEqual(result["count"], 1) + + def test_session_not_found_returns_error(self): + result = json.loads(_tool_search_events(self.store, {"session_id": "zzz"})) + self.assertIn("error", result) + + +# --------------------------------------------------------------------------- +# get_session_summary +# --------------------------------------------------------------------------- + +class TestGetSessionSummary(unittest.TestCase): + def setUp(self): + self.store = _make_store() + _add_session(self.store, "sum1", events=[ + _tool_call("Bash", ts=time.time()), + _file_write("src/main.py", ts=time.time() + 1), + ]) + + def test_returns_text_summary(self): + result = _tool_get_session_summary(self.store, 
{"session_id": "sum1"}) + self.assertIsInstance(result, str) + self.assertIn("sum1", result) + + def test_not_found_returns_error_json(self): + result = json.loads(_tool_get_session_summary(self.store, {"session_id": "zzz"})) + self.assertIn("error", result) + + def test_missing_session_id_returns_error(self): + result = json.loads(_tool_get_session_summary(self.store, {})) + self.assertIn("error", result) + + +# --------------------------------------------------------------------------- +# diff_sessions +# --------------------------------------------------------------------------- + +class TestDiffSessions(unittest.TestCase): + def setUp(self): + self.store = _make_store() + _add_session(self.store, "before", events=[ + _tool_call("Bash"), _tool_call("Bash"), + _file_write("src/a.py"), _file_read("README.md"), + _error("fail"), + ], total_tokens=1000) + _add_session(self.store, "after", events=[ + _tool_call("Bash"), _tool_call("Write"), + _file_write("src/b.py"), _file_read("README.md"), + ], total_tokens=800) + + def test_returns_diff_structure(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "after", + })) + self.assertIn("tool_call_diff", result) + self.assertIn("files_only_in_a", result) + self.assertIn("files_only_in_b", result) + self.assertIn("files_in_both", result) + self.assertIn("token_delta", result) + self.assertIn("cost_delta_usd", result) + + def test_token_delta(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "after", + })) + self.assertEqual(result["token_delta"], -200) + + def test_files_only_in_a(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "after", + })) + self.assertIn("src/a.py", result["files_only_in_a"]) + + def test_files_only_in_b(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "after", + })) + self.assertIn("src/b.py", 
result["files_only_in_b"]) + + def test_files_in_both(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "after", + })) + self.assertIn("README.md", result["files_in_both"]) + + def test_error_counts(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "after", + })) + self.assertEqual(result["error_count_a"], 1) + self.assertEqual(result["error_count_b"], 0) + + def test_missing_session_returns_error(self): + result = json.loads(_tool_diff_sessions(self.store, { + "session_a": "before", "session_b": "zzz", + })) + self.assertIn("error", result) + + def test_missing_args_returns_error(self): + result = json.loads(_tool_diff_sessions(self.store, {})) + self.assertIn("error", result) + + +# --------------------------------------------------------------------------- +# tools/call dispatch via _handle +# --------------------------------------------------------------------------- + +class TestHandleToolsCall(unittest.TestCase): + def setUp(self): + self.store = _make_store() + _add_session(self.store, "t1", events=[_tool_call("Bash")]) + + def test_list_sessions_via_handle(self): + r = _handle(self.store, _rpc("tools/call", { + "name": "list_sessions", "arguments": {}, + })) + result = json.loads(r["result"]["content"][0]["text"]) + self.assertEqual(result["count"], 1) + + def test_get_session_via_handle(self): + r = _handle(self.store, _rpc("tools/call", { + "name": "get_session", "arguments": {"session_id": "t1"}, + })) + result = json.loads(r["result"]["content"][0]["text"]) + self.assertEqual(result["event_count"], 1) + + def test_is_error_false_on_success(self): + r = _handle(self.store, _rpc("tools/call", { + "name": "list_sessions", "arguments": {}, + })) + self.assertFalse(r["result"]["isError"]) + + +if __name__ == "__main__": + unittest.main()