diff --git a/TDD_IMPLEMENTATION_PLAN.md b/TDD_IMPLEMENTATION_PLAN.md index 38e7e28ab8..5300ec9e00 100644 --- a/TDD_IMPLEMENTATION_PLAN.md +++ b/TDD_IMPLEMENTATION_PLAN.md @@ -17,7 +17,7 @@ | ADB raw tap/swipe/keyevent | `adb_vision/server.py` | ✅ **DONE** | Via `adb_tap`, `adb_swipe`, `adb_keyevent` tools | | ALAS state machine | `alas_wrapped/module/ui/page.py` | Reference only | 43 pages, 98 transitions — extract knowledge, not code | | MEmu config | `docs/dev/memu_playbook.md` | Documented | Admin-at-startup solved via memuc.exe | -| ALAS MCP server | `agent_orchestrator/alas_mcp_server.py` | ⚠️ **DEPRECATED** | Do not extend — use `adb_vision/server.py` | +| ALAS MCP server | `agent_orchestrator/alas_mcp_server.py` | ⚠️ **DEPRECATED** | Do not extend — generic internal tool bridge is disabled by default | ### What Must Be Built (Greenfield) diff --git a/agent_orchestrator/alas_mcp_server.py b/agent_orchestrator/alas_mcp_server.py index f69b58a465..774668201d 100644 --- a/agent_orchestrator/alas_mcp_server.py +++ b/agent_orchestrator/alas_mcp_server.py @@ -2,52 +2,24 @@ import asyncio import base64 import io -import json import logging import os import sys import inspect import time -from datetime import datetime, timezone from pathlib import Path from typing import Optional, List, Dict, Any +from mcp_audit import record_command, record_event -# --------------------------------------------------------------------------- -# Always-on action log — every MCP tool call appended as a JSONL line. -# Screenshots are also saved as PNG files alongside the log. -# --------------------------------------------------------------------------- -_ACTION_LOG_PATH = Path(__file__).parent / "mcp_actions.jsonl" -_SCREENSHOT_DIR = Path(__file__).parent / "mcp_screenshots" -_action_seq = 0 # monotonic call counter within this server process - - -def _action_log(tool: str, args: dict, result_summary: str, error: str = "", duration_ms: int = 0): - """Append one JSONL record to the action log (never raises).""" - global _action_seq - _action_seq += 1 - record = { - "seq": _action_seq, - "ts": datetime.now(timezone.utc).isoformat(), - "tool": tool, - "args": args, - "result": result_summary, - "error": error, - "duration_ms": duration_ms, - } - try: - _ACTION_LOG_PATH.parent.mkdir(parents=True, exist_ok=True) - with open(_ACTION_LOG_PATH, "a", encoding="utf-8") as fh: - fh.write(json.dumps(record, ensure_ascii=True) + "\n") - except Exception: - pass # never disrupt a tool call due to logging failure +_SCREENSHOT_DIR = Path(__file__).parent / "mcp_screenshots" -def _save_screenshot_png(data_b64: str, seq: int) -> str: +def _save_screenshot_png(data_b64: str) -> str: """Save a base64 PNG to mcp_screenshots/_.png; return the path.""" try: _SCREENSHOT_DIR.mkdir(parents=True, exist_ok=True) - ts = datetime.now().strftime("%Y%m%dT%H%M%S") - fname = _SCREENSHOT_DIR / f"{seq:05d}_{ts}.png" + stamp = time.time_ns() + fname = _SCREENSHOT_DIR / f"{stamp}_{os.getpid()}.png" fname.write_bytes(base64.b64decode(data_b64)) return str(fname) except Exception: @@ -84,6 +56,22 @@ def _find_adb() -> str: # ADB serial — set from config in main(); used by async ADB CLI tools. # --------------------------------------------------------------------------- ADB_SERIAL: str = "127.0.0.1:21513" +ALLOW_UNSAFE_STATE_MACHINE_CALLS: bool = ( + os.environ.get("ALAS_ALLOW_UNSAFE_STATE_MACHINE_CALLS", "").strip().lower() + in {"1", "true", "yes"} +) +REMOTE_ALLOWED_ALAS_TOOLS: set[str] = { + "main.collect_mail", + "dorm.collect_rewards", + "dorm.feed_ships", + "dorm.buy_furniture", + "dorm.get_ship_count", + "commission.run", + "research.run", + "shop.run", + "guild.collect_lobby_rewards", + "workflow.daily_base_sweep", +} async def _adb_run(*args: str, timeout: float) -> bytes: @@ -93,8 +81,10 @@ async def _adb_run(*args: str, timeout: float) -> bytes: Raises TimeoutError if the command exceeds *timeout* seconds. Raises RuntimeError if adb exits non-zero. """ + argv = [ADB_EXECUTABLE, "-s", ADB_SERIAL, *args] + started = time.perf_counter() proc = await asyncio.create_subprocess_exec( - ADB_EXECUTABLE, "-s", ADB_SERIAL, *args, + *argv, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) @@ -103,12 +93,36 @@ async def _adb_run(*args: str, timeout: float) -> bytes: except asyncio.TimeoutError: proc.kill() await proc.wait() + record_command( + command_name="adb.exec", + argv=argv, + duration_ms=(time.perf_counter() - started) * 1000, + status="error", + error=f"Timeout after {timeout}s", + ) raise TimeoutError(f"adb timed out after {timeout}s: adb -s {ADB_SERIAL} {' '.join(args)}") if proc.returncode != 0: + stderr_text = stderr.decode(errors="replace").strip() + record_command( + command_name="adb.exec", + argv=argv, + duration_ms=(time.perf_counter() - started) * 1000, + status="error", + error=f"exit {proc.returncode}", + stderr=stderr_text, + ) raise RuntimeError( f"adb {' '.join(args)} failed (exit {proc.returncode}): " - f"{stderr.decode(errors='replace').strip()}" + f"{stderr_text}" ) + record_command( + command_name="adb.exec", + argv=argv, + duration_ms=(time.perf_counter() - started) * 1000, + status="success", + stdout=stdout, + stderr=stderr, + ) return stdout try: @@ -234,15 +248,27 @@ def _take_screenshot() -> str: timeout=25.0, ) except asyncio.TimeoutError: - _action_log("adb_screenshot", {"serial": ADB_SERIAL}, "TIMEOUT", - "screencap timed out after 25s", 25000) + record_event( + tool_name="adb_screenshot.capture", + arguments={"serial": ADB_SERIAL}, + status="error", + result_summary="capture timeout", + duration_ms=25000, + error="screencap timed out after 25s", + event_type="capture", + ) raise RuntimeError("adb_screenshot timed out after 25s") - ms = int((time.monotonic() - t0) * 1000) png_bytes = base64.b64decode(data) - saved = _save_screenshot_png(data, _action_seq + 1) - _action_log("adb_screenshot", {"serial": ADB_SERIAL}, - f"png_bytes={len(png_bytes)} saved={saved}", "", ms) + saved = _save_screenshot_png(data) + record_event( + tool_name="adb_screenshot.capture", + arguments={"serial": ADB_SERIAL}, + status="success", + result_summary=f"png_bytes={len(png_bytes)} saved={saved}", + duration_ms=(time.monotonic() - t0) * 1000, + event_type="capture", + ) return { "content": [ {"type": "image", "mimeType": "image/png", "data": data} @@ -260,10 +286,7 @@ async def adb_tap(x: int, y: int) -> str: x: X coordinate (integer) y: Y coordinate (integer) """ - t0 = time.monotonic() await _adb_run("shell", "input", "tap", str(x), str(y), timeout=5.0) - ms = int((time.monotonic() - t0) * 1000) - _action_log("adb_tap", {"x": x, "y": y, "serial": ADB_SERIAL}, f"tapped {x},{y}", "", ms) return f"tapped {x},{y}" @mcp.tool() @@ -276,7 +299,6 @@ async def adb_launch_game() -> str: Timeout: 10 seconds. """ - t0 = time.monotonic() await _adb_run( "shell", "am", "start", "-a", "android.intent.action.MAIN", @@ -284,8 +306,6 @@ async def adb_launch_game() -> str: "-n", "com.YoStarEN.AzurLane/com.manjuu.azurlane.PrePermissionActivity", timeout=10.0, ) - ms = int((time.monotonic() - t0) * 1000) - _action_log("adb_launch_game", {"serial": ADB_SERIAL}, "launch intent sent", "", ms) return "Azur Lane launch intent sent" @@ -303,12 +323,10 @@ async def adb_get_focus() -> Dict[str, Any]: "activity": "com.manjuu.azurlane.MainActivity" # or null } """ - t0 = time.monotonic() stdout = await _adb_run( "shell", "dumpsys", "window", "windows", timeout=8.0, ) - ms = int((time.monotonic() - t0) * 1000) raw_text = stdout.decode(errors="replace") focus_line = "" for line in raw_text.splitlines(): @@ -326,7 +344,6 @@ async def adb_get_focus() -> Dict[str, Any]: activity = m.group(2) result = {"raw": focus_line, "package": package, "activity": activity} - _action_log("adb_get_focus", {"serial": ADB_SERIAL}, f"{package}/{activity}", "", ms) return result @@ -344,16 +361,11 @@ async def adb_swipe(x1: int, y1: int, x2: int, y2: int, duration_ms: int = 300) y2: Ending Y coordinate duration_ms: Duration in milliseconds (default: 300) """ - t0 = time.monotonic() await _adb_run( "shell", "input", "swipe", str(x1), str(y1), str(x2), str(y2), str(duration_ms), timeout=5.0 + duration_ms / 1000.0, ) - ms = int((time.monotonic() - t0) * 1000) - _action_log("adb_swipe", {"x1": x1, "y1": y1, "x2": x2, "y2": y2, - "duration_ms": duration_ms, "serial": ADB_SERIAL}, - f"swiped {x1},{y1}->{x2},{y2}", "", ms) return f"swiped {x1},{y1}->{x2},{y2}" @mcp.tool() @@ -365,10 +377,7 @@ def alas_get_current_state() -> str: """ if ctx is None: raise RuntimeError("ALAS context not initialized") - t0 = time.monotonic() page = ctx._state_machine.get_current_state() - ms = int((time.monotonic() - t0) * 1000) - _action_log("alas_get_current_state", {}, str(page), "", ms) return str(page) @mcp.tool() @@ -383,17 +392,19 @@ def alas_goto(page: str) -> str: """ if ctx is None: raise RuntimeError("ALAS context not initialized") - t0 = time.monotonic() - err = "" from module.ui.page import Page destination = Page.all_pages.get(page) if destination is None: - err = f"unknown page: {page}" - _action_log("alas_goto", {"page": page}, "FAILED", err, 0) - raise ValueError(err) + raise ValueError(f"unknown page: {page}") ctx._state_machine.transition(destination) - ms = int((time.monotonic() - t0) * 1000) - _action_log("alas_goto", {"page": page}, f"navigated to {page}", "", ms) + record_event( + tool_name="state_machine.transition", + arguments={"page": page}, + status="success", + result_summary=f"navigated to {page}", + duration_ms=0, + event_type="delegate", + ) return f"navigated to {page}" @mcp.tool() @@ -412,8 +423,8 @@ def alas_list_tools() -> List[Dict[str, Any]]: "parameters": t.parameters } for t in ctx._state_machine.get_all_tools() + if ALLOW_UNSAFE_STATE_MACHINE_CALLS or t.name in REMOTE_ALLOWED_ALAS_TOOLS ] - _action_log("alas_list_tools", {}, f"{len(tools)} tools", "", 0) return tools @mcp.tool() @@ -427,17 +438,43 @@ def alas_call_tool(name: str, arguments: Optional[Dict[str, Any]] = None) -> Any if ctx is None: raise RuntimeError("ALAS context not initialized") t0 = time.monotonic() - err = "" args = arguments or {} + if not ALLOW_UNSAFE_STATE_MACHINE_CALLS and name not in REMOTE_ALLOWED_ALAS_TOOLS: + err = ( + f"state machine tool '{name}' is not exposed through MCP. " + "Add it to REMOTE_ALLOWED_ALAS_TOOLS for an intentional surface expansion." + ) + record_event( + tool_name="state_machine.call_tool", + arguments={"name": name, "args": args}, + status="blocked", + result_summary="blocked by MCP allowlist", + duration_ms=(time.monotonic() - t0) * 1000, + error=err, + event_type="delegate", + ) + raise PermissionError(err) try: result = ctx._state_machine.call_tool(name, **args) - except Exception as e: - err = str(e) - _action_log("alas_call_tool", {"name": name, "args": args}, "FAILED", err, int((time.monotonic()-t0)*1000)) + except Exception as exc: + record_event( + tool_name="state_machine.call_tool", + arguments={"name": name, "args": args}, + status="error", + result_summary="call failed", + duration_ms=(time.monotonic() - t0) * 1000, + error=f"{type(exc).__name__}: {exc}", + event_type="delegate", + ) raise - ms = int((time.monotonic() - t0) * 1000) - result_str = str(result)[:200] if result is not None else "None" - _action_log("alas_call_tool", {"name": name, "args": args}, result_str, "", ms) + record_event( + tool_name="state_machine.call_tool", + arguments={"name": name, "args": args}, + status="success", + result_summary=str(result)[:200] if result is not None else "None", + duration_ms=(time.monotonic() - t0) * 1000, + event_type="delegate", + ) return result @@ -462,7 +499,6 @@ def alas_login_ensure_main( from alas_wrapped.tools.login import ensure_main_with_config_device t0 = time.monotonic() - err = "" try: result = ensure_main_with_config_device( ctx.script.config, @@ -472,12 +508,25 @@ def alas_login_ensure_main( dismiss_popups=dismiss_popups, get_ship=get_ship, ) - except Exception as e: - err = str(e) - _action_log("alas_login_ensure_main", {"max_wait_s": max_wait_s}, "FAILED", err, int((time.monotonic()-t0)*1000)) + except Exception as exc: + record_event( + tool_name="login.ensure_main", + arguments={"max_wait_s": max_wait_s}, + status="error", + result_summary="ensure_main failed", + duration_ms=(time.monotonic() - t0) * 1000, + error=f"{type(exc).__name__}: {exc}", + event_type="delegate", + ) raise - ms = int((time.monotonic() - t0) * 1000) - _action_log("alas_login_ensure_main", {"max_wait_s": max_wait_s}, str(result.get("observed_state","?") if isinstance(result, dict) else result)[:200], "", ms) + record_event( + tool_name="login.ensure_main", + arguments={"max_wait_s": max_wait_s}, + status="success", + result_summary=str(result.get("observed_state", "?") if isinstance(result, dict) else result)[:200], + duration_ms=(time.monotonic() - t0) * 1000, + event_type="delegate", + ) return result def main(): @@ -504,4 +553,4 @@ def main(): mcp.run(transport="stdio") if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/agent_orchestrator/mcp_audit.py b/agent_orchestrator/mcp_audit.py index 11c4c87a5b..205cb0dc02 100644 --- a/agent_orchestrator/mcp_audit.py +++ b/agent_orchestrator/mcp_audit.py @@ -1,17 +1,12 @@ -"""MCP audit logging -- append-only JSONL with rotation. - -Every MCP tool call is logged with timestamp, tool name, arguments, -caller identity, duration, result summary, and error info. - -Two integration paths: -- FastMCP Middleware (AuditMiddleware) for the MCP stdio transport -- audit_cli_call() wrapper for the --cli subprocess path -""" +"""Canonical MCP audit logging with nested child-event support.""" from __future__ import annotations +import contextvars +import itertools import json import logging import os +import string import sys import time import traceback @@ -21,13 +16,9 @@ _logger = logging.getLogger("mcp_audit") -# --------------------------------------------------------------------------- -# Reusable JSONL append helper -# --------------------------------------------------------------------------- try: from module.base.jsonl import append_jsonl except ImportError: - # Fallback when ALAS is not importable (unit tests, standalone use). def append_jsonl(path, payload, rotate_bytes=None, error_callback=None): try: folder = os.path.dirname(path) @@ -41,47 +32,67 @@ def append_jsonl(path, payload, rotate_bytes=None, error_callback=None): root, ext = os.path.splitext(path) ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ") os.replace(path, f"{root}.{ts}{ext or '.jsonl'}") - with open(path, "a", encoding="utf-8") as f: - f.write(json.dumps(payload, ensure_ascii=True) + "\n") + with open(path, "a", encoding="utf-8") as handle: + handle.write(json.dumps(payload, ensure_ascii=True) + "\n") return True - except Exception as e: + except Exception as exc: if error_callback: try: - error_callback(e) + error_callback(exc) except Exception: pass return False -# --------------------------------------------------------------------------- -# Configuration -# --------------------------------------------------------------------------- _AUDIT_FILE = str(Path(__file__).parent / "mcp_audit.jsonl") -_ROTATE_BYTES = 20 * 1024 * 1024 # 20 MB +_ROTATE_BYTES = 20 * 1024 * 1024 _debug_mode: bool = False +_event_counter = itertools.count(1) +_current_context: contextvars.ContextVar[dict[str, Any] | None] = contextvars.ContextVar( + "mcp_audit_current_context", + default=None, +) def configure(*, debug: bool = False, audit_path: Optional[str] = None): - """Called once at startup to set module-level configuration.""" global _debug_mode, _AUDIT_FILE _debug_mode = debug if audit_path: _AUDIT_FILE = audit_path -# --------------------------------------------------------------------------- -# Result summarization -# --------------------------------------------------------------------------- +def _truncate(value: str, limit: int = 200) -> str: + if len(value) <= limit: + return value + return f"{value[:limit]}...<{len(value)} chars>" + + +def _looks_binary(data: bytes) -> bool: + if not data: + return False + if b"\x00" in data: + return True + sample = data[:256] + printable = sum(chr(byte) in string.printable for byte in sample) + return printable / len(sample) < 0.85 + + +def _summarize_blob(blob: str | bytes, limit: int = 200) -> str: + if isinstance(blob, bytes): + if _looks_binary(blob): + return f"" + blob = blob.decode("utf-8", errors="replace") + return _truncate(blob.strip(), limit) + + def _summarize_result(result: Any) -> str: - """Short human-readable summary. Never includes full base64 data.""" - # FastMCP ToolResult (has .content list of ContentBlock) if hasattr(result, "content") and isinstance(result.content, list): parts = [] for block in result.content: block_type = getattr(block, "type", None) if block_type == "text": text = getattr(block, "text", "") - parts.append(f"text: {text[:120]}" if len(text) > 120 else f"text: {text}") + parts.append(f"text: {_truncate(text, 120)}") elif block_type == "image": mime = getattr(block, "mimeType", "image/unknown") data = getattr(block, "data", "") @@ -91,59 +102,50 @@ def _summarize_result(result: Any) -> str: parts.append(f"{block_type or 'unknown'}: ...") return "; ".join(parts) if parts else "" - # Raw dict (CLI path or direct return) if isinstance(result, dict): if "content" in result and isinstance(result["content"], list): parts = [] for item in result["content"]: - if isinstance(item, dict): - if item.get("type") == "image": - mime = item.get("mimeType", "image/unknown") - data = item.get("data", "") - size_kb = (len(data) * 3 // 4) / 1024 if data else 0 - parts.append(f"image: {mime}, ~{size_kb:.0f}KB") - elif item.get("type") == "text": - text = item.get("text", "") - parts.append(f"text: {text[:120]}") - return "; ".join(parts) if parts else str(result)[:200] + if not isinstance(item, dict): + continue + if item.get("type") == "image": + mime = item.get("mimeType", "image/unknown") + data = item.get("data", "") + size_kb = (len(data) * 3 // 4) / 1024 if data else 0 + parts.append(f"image: {mime}, ~{size_kb:.0f}KB") + elif item.get("type") == "text": + parts.append(f"text: {_truncate(item.get('text', ''), 120)}") + return "; ".join(parts) if parts else _truncate(str(result)) if "success" in result: state = result.get("observed_state") or result.get("expected_state") status = "ok" if result.get("success") else "fail" return f"{status}, state={state}" - return str(result)[:200] + return _truncate(str(result)) if isinstance(result, list): return f"list[{len(result)} items]" - if isinstance(result, str): - return result[:200] + return _truncate(result) + return _truncate(str(result)) - return str(result)[:200] - -# --------------------------------------------------------------------------- -# Argument sanitization -# --------------------------------------------------------------------------- def _sanitize_arguments(arguments: Optional[Dict[str, Any]]) -> Dict[str, Any]: - """Strip large binary data and _caller from logged args.""" if not arguments: return {} - sanitized = {} - for k, v in arguments.items(): - if k == "_caller": + sanitized: Dict[str, Any] = {} + for key, value in arguments.items(): + if key == "_caller": continue - if isinstance(v, str) and len(v) > 500: - sanitized[k] = f"<{len(v)} chars>" + if isinstance(value, str): + sanitized[key] = _truncate(value, 500) + elif isinstance(value, (list, tuple)) and len(value) > 50: + sanitized[key] = f"<{len(value)} items>" else: - sanitized[k] = v + sanitized[key] = value return sanitized -# --------------------------------------------------------------------------- -# Caller identification -# --------------------------------------------------------------------------- def _detect_caller(arguments: Optional[Dict[str, Any]]) -> str: - """Extract caller identity from arguments, env, or heuristics.""" if arguments and "_caller" in arguments: return str(arguments["_caller"]) env_caller = os.environ.get("MCP_CALLER") @@ -155,9 +157,10 @@ def _detect_caller(arguments: Optional[Dict[str, Any]]) -> str: return "unknown" -# --------------------------------------------------------------------------- -# Record construction and writing -# --------------------------------------------------------------------------- +def _next_event_id() -> str: + return f"evt-{os.getpid()}-{next(_event_counter)}" + + def _build_audit_record( *, tool_name: str, @@ -168,11 +171,12 @@ def _build_audit_record( error: Optional[str], result_summary: str, mode: str, + event_type: str, + event_id: str, + parent_event_id: Optional[str], ) -> Dict[str, Any]: return { - "ts": datetime.now(timezone.utc) - .isoformat(timespec="milliseconds") - .replace("+00:00", "Z"), + "ts": datetime.now(timezone.utc).isoformat(timespec="milliseconds").replace("+00:00", "Z"), "tool": tool_name, "arguments": _sanitize_arguments(arguments), "caller": caller, @@ -182,47 +186,101 @@ def _build_audit_record( "result_summary": result_summary, "pid": os.getpid(), "mode": mode, + "event_type": event_type, + "event_id": event_id, + "parent_event_id": parent_event_id, } def _write_audit(record: Dict[str, Any]) -> None: - """Write one audit record to JSONL and optionally to stderr.""" append_jsonl( _AUDIT_FILE, record, rotate_bytes=_ROTATE_BYTES, - error_callback=lambda e: _logger.warning(f"audit write failed: {e}"), + error_callback=lambda exc: _logger.warning(f"audit write failed: {exc}"), ) if _debug_mode: - ts = record["ts"] - tool = record["tool"] - dur = record["duration_ms"] - caller = record["caller"] - status = record["status"] - summary = record["result_summary"] - err = record.get("error") - line = f"[AUDIT] {ts} {tool} caller={caller} {dur:.1f}ms {status}" - if err: - line += f" ERROR: {err}" - else: - line += f" -> {summary}" - print(line, file=sys.stderr) + line = json.dumps(record, ensure_ascii=True) + print(f"[AUDIT] {line}", file=sys.stderr) -# --------------------------------------------------------------------------- -# FastMCP Middleware -# --------------------------------------------------------------------------- -AuditMiddleware = None # Will be set below if FastMCP is available +def _current_parent_event_id() -> Optional[str]: + context = _current_context.get() + if not context: + return None + return context.get("event_id") -try: - from fastmcp.server.middleware import Middleware, MiddlewareContext, CallNext - from fastmcp.tools.tool import ToolResult as _ToolResult +def _current_caller(default_arguments: Optional[Dict[str, Any]] = None) -> str: + context = _current_context.get() + if context and context.get("caller"): + return str(context["caller"]) + return _detect_caller(default_arguments) + + +def record_event( + *, + tool_name: str, + arguments: Optional[Dict[str, Any]], + status: str, + result_summary: str, + duration_ms: float, + error: Optional[str] = None, + mode: str = "child", + event_type: str = "child", + caller: Optional[str] = None, +) -> None: + record = _build_audit_record( + tool_name=tool_name, + arguments=arguments, + caller=caller or _current_caller(arguments), + duration_ms=duration_ms, + status=status, + error=error, + result_summary=result_summary, + mode=mode, + event_type=event_type, + event_id=_next_event_id(), + parent_event_id=_current_parent_event_id(), + ) + _write_audit(record) + + +def record_command( + *, + command_name: str, + argv: list[str], + duration_ms: float, + status: str, + error: Optional[str] = None, + stdout: str | bytes = "", + stderr: str | bytes = "", +) -> None: + summary_parts = [f"argv={' '.join(argv)}"] + if stdout: + summary_parts.append(f"stdout={_summarize_blob(stdout, 120)}") + if stderr: + summary_parts.append(f"stderr={_summarize_blob(stderr, 120)}") + record_event( + tool_name=command_name, + arguments={"argv": argv}, + status=status, + result_summary=" | ".join(summary_parts), + duration_ms=duration_ms, + error=error, + mode="subprocess", + event_type="subprocess", + ) + + +AuditMiddleware = None + +try: import mcp.types as _mt + from fastmcp.server.middleware import CallNext, Middleware, MiddlewareContext + from fastmcp.tools.tool import ToolResult as _ToolResult class _AuditMiddleware(Middleware): - """Audit logging middleware for all MCP tool calls.""" - async def on_call_tool( self, context: MiddlewareContext[_mt.CallToolRequestParams], @@ -232,91 +290,86 @@ async def on_call_tool( arguments = dict(context.message.arguments or {}) caller = _detect_caller(arguments) - # Strip _caller before forwarding to real tool if "_caller" in arguments: clean_args = {k: v for k, v in arguments.items() if k != "_caller"} context = context.copy( - message=_mt.CallToolRequestParams( - name=tool_name, - arguments=clean_args, - ) + message=_mt.CallToolRequestParams(name=tool_name, arguments=clean_args) ) start = time.perf_counter() - error_msg = None status = "success" + error_msg = None result_summary = "" - + event_id = _next_event_id() + token = _current_context.set( + {"event_id": event_id, "tool_name": tool_name, "caller": caller} + ) try: result = await call_next(context) result_summary = _summarize_result(result) return result - except Exception as e: + except Exception as exc: status = "error" - error_msg = f"{type(e).__name__}: {e}" + error_msg = f"{type(exc).__name__}: {exc}" if _debug_mode: error_msg += "\n" + traceback.format_exc(limit=6) raise finally: - duration_ms = (time.perf_counter() - start) * 1000 + _current_context.reset(token) record = _build_audit_record( tool_name=tool_name, arguments=arguments, caller=caller, - duration_ms=duration_ms, + duration_ms=(time.perf_counter() - start) * 1000, status=status, error=error_msg, result_summary=result_summary, mode="mcp", + event_type="tool_call", + event_id=event_id, + parent_event_id=None, ) _write_audit(record) AuditMiddleware = _AuditMiddleware - except ImportError: - pass # FastMCP not installed; CLI path still works - + pass -# --------------------------------------------------------------------------- -# CLI path wrapper -# --------------------------------------------------------------------------- -def audit_cli_call( - tool_name: str, - arguments: Dict[str, Any], - func: Callable, -) -> Any: - """Wrap a CLI tool invocation with audit logging. - Returns the raw result; raises on error (after logging). - """ +def audit_cli_call(tool_name: str, arguments: Dict[str, Any], func: Callable) -> Any: caller = _detect_caller(arguments) clean_args = {k: v for k, v in arguments.items() if k != "_caller"} - start = time.perf_counter() - error_msg = None status = "success" + error_msg = None result_summary = "" - + event_id = _next_event_id() + token = _current_context.set( + {"event_id": event_id, "tool_name": tool_name, "caller": caller} + ) try: result = func(**clean_args) result_summary = _summarize_result(result) return result - except Exception as e: + except Exception as exc: status = "error" - error_msg = f"{type(e).__name__}: {e}" + error_msg = f"{type(exc).__name__}: {exc}" if _debug_mode: error_msg += "\n" + traceback.format_exc(limit=6) raise finally: - duration_ms = (time.perf_counter() - start) * 1000 + _current_context.reset(token) record = _build_audit_record( tool_name=tool_name, arguments=arguments, caller=caller, - duration_ms=duration_ms, + duration_ms=(time.perf_counter() - start) * 1000, status=status, error=error_msg, result_summary=result_summary, mode="cli", + event_type="tool_call", + event_id=event_id, + parent_event_id=None, ) _write_audit(record) diff --git a/agent_orchestrator/test_alas_mcp.py b/agent_orchestrator/test_alas_mcp.py index 4975fd5eee..ddcb25de6a 100644 --- a/agent_orchestrator/test_alas_mcp.py +++ b/agent_orchestrator/test_alas_mcp.py @@ -8,6 +8,7 @@ @pytest.fixture def mock_ctx(): ctx = mock.Mock() + previous_flag = alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS # Mock state machine ctx._state_machine = mock.Mock() ctx._state_machine.get_current_state.return_value = "page_main" @@ -20,7 +21,8 @@ def mock_ctx(): ctx._state_machine.get_all_tools.return_value = [mock_tool] alas_mcp_server.ctx = ctx - return ctx + yield ctx + alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS = previous_flag def _make_dummy_png() -> bytes: @@ -141,17 +143,45 @@ def test_alas_goto_invalid(mock_ctx, monkeypatch): alas_mcp_server.alas_goto("invalid_page") def test_alas_list_tools(mock_ctx): + alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS = False + result = alas_mcp_server.alas_list_tools() + assert result == [] + + +def test_alas_list_tools_includes_allowlisted_tools(mock_ctx): + allowlisted_tool = mock.Mock() + allowlisted_tool.name = "commission.run" + allowlisted_tool.description = "collect commissions" + allowlisted_tool.parameters = {} + mock_ctx._state_machine.get_all_tools.return_value = [allowlisted_tool] + alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS = False result = alas_mcp_server.alas_list_tools() assert len(result) == 1 - assert result[0]["name"] == "test_tool" + assert result[0]["name"] == "commission.run" + +def test_alas_call_tool_blocked_by_default(mock_ctx): + alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS = False + with pytest.raises(PermissionError, match="not exposed through MCP"): + alas_mcp_server.alas_call_tool("test_tool", {"arg": 1}) + mock_ctx._state_machine.call_tool.assert_not_called() + -def test_alas_call_tool(mock_ctx): +def test_alas_call_tool_allowed_when_explicit(mock_ctx): + alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS = True mock_ctx._state_machine.call_tool.return_value = {"success": True} result = alas_mcp_server.alas_call_tool("test_tool", {"arg": 1}) assert result == {"success": True} mock_ctx._state_machine.call_tool.assert_called_with("test_tool", arg=1) +def test_alas_call_tool_allowlisted_name(mock_ctx): + alas_mcp_server.ALLOW_UNSAFE_STATE_MACHINE_CALLS = False + mock_ctx._state_machine.call_tool.return_value = {"success": True} + result = alas_mcp_server.alas_call_tool("commission.run", {"arg": 1}) + assert result == {"success": True} + mock_ctx._state_machine.call_tool.assert_called_with("commission.run", arg=1) + + # --------------------------------------------------------------------------- # adb_launch_game — async, uses _adb_run CLI; no ctx needed # --------------------------------------------------------------------------- @@ -206,4 +236,3 @@ async def test_adb_get_focus_launcher(): assert result["package"] == "com.microvirt.launcher2" assert result["activity"] == "com.microvirt.launcher2.MainActivity" - diff --git a/agent_orchestrator/test_integration_mcp.py b/agent_orchestrator/test_integration_mcp.py index aa39a6b5d8..53c6dddaf4 100644 --- a/agent_orchestrator/test_integration_mcp.py +++ b/agent_orchestrator/test_integration_mcp.py @@ -28,28 +28,24 @@ async def test_server_startup_and_list_tools(): """ Test the actual MCP server startup and tool listing. """ - # Create a real mock hierarchy with config and click_methods for adb_tap dispatch mock_ctx = mock.Mock() - mock_ctx.script = mock.Mock() - mock_ctx.script.config.Emulator_ControlMethod = 'MaaTouch' - mock_ctx.script.device = mock.Mock() - mock_ctx.script.device.click_methods = { - 'MaaTouch': mock_ctx.script.device.click_maatouch, - } - mock_ctx.encode_screenshot_png_base64.return_value = "fake_base64" server.ctx = mock_ctx - - # FastMCP call_tool is async and returns a ToolResult - result = await _call_tool("adb_tap", {"x": 10, "y": 20}) - # For FastMCP 3.0, call_tool might return a ToolResult object - # We check its content + + async def _mock_adb_run(*args, timeout): + assert args == ("shell", "input", "tap", "10", "20") + assert timeout == 5.0 + return b"" + + with mock.patch.object(server, "_adb_run", side_effect=_mock_adb_run) as mock_adb: + result = await _call_tool("adb_tap", {"x": 10, "y": 20}) + if hasattr(result, "content"): text = result.content[0].text else: text = str(result) - + assert "tapped 10,20" in text - mock_ctx.script.device.click_maatouch.assert_called_with(10, 20) + assert mock_adb.await_count == 1 @pytest.mark.asyncio async def test_alas_goto_integration(monkeypatch): diff --git a/agent_orchestrator/test_mcp_audit.py b/agent_orchestrator/test_mcp_audit.py new file mode 100644 index 0000000000..5d2109e122 --- /dev/null +++ b/agent_orchestrator/test_mcp_audit.py @@ -0,0 +1,70 @@ +import json +from pathlib import Path + +import mcp_audit + + +def _read_records(path: Path) -> list[dict]: + return [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines()] + + +def test_record_command_writes_subprocess_event(tmp_path): + audit_path = tmp_path / "audit.jsonl" + mcp_audit.configure(audit_path=str(audit_path)) + + mcp_audit.record_command( + command_name="adb.exec", + argv=["adb", "-s", "serial", "get-state"], + duration_ms=12.5, + status="success", + stdout="device", + ) + + records = _read_records(audit_path) + assert len(records) == 1 + assert records[0]["tool"] == "adb.exec" + assert records[0]["mode"] == "subprocess" + assert records[0]["event_type"] == "subprocess" + assert records[0]["parent_event_id"] is None + + +def test_audit_cli_call_sets_parent_for_child_events(tmp_path): + audit_path = tmp_path / "audit.jsonl" + mcp_audit.configure(audit_path=str(audit_path)) + + def _run(): + mcp_audit.record_event( + tool_name="state_machine.call_tool", + arguments={"name": "commission.run"}, + status="success", + result_summary="ok", + duration_ms=3.0, + event_type="delegate", + ) + return "ok" + + result = mcp_audit.audit_cli_call("outer.tool", {}, _run) + assert result == "ok" + + records = _read_records(audit_path) + assert len(records) == 2 + child, parent = records[0], records[1] + assert child["parent_event_id"] == parent["event_id"] + assert child["event_type"] == "delegate" + assert parent["event_type"] == "tool_call" + + +def test_record_command_summarizes_binary_streams(tmp_path): + audit_path = tmp_path / "audit.jsonl" + mcp_audit.configure(audit_path=str(audit_path)) + + mcp_audit.record_command( + command_name="adb.exec", + argv=["adb", "exec-out", "screencap", "-p"], + duration_ms=10.0, + status="success", + stdout=b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR", + ) + + records = _read_records(audit_path) + assert " close announcements modal -> lobby via `adb_vision/pilot.py` - ⏳ Piloting mode (vision + manual recovery) not started diff --git a/docs/dev/logging.md b/docs/dev/logging.md index 89d81ba146..d89f4e040a 100644 --- a/docs/dev/logging.md +++ b/docs/dev/logging.md @@ -42,6 +42,7 @@ When LLM recovery is triggered: ## Service Surface Guardrails - Keep the MCP surface fixed and explicit; do not add generic shell/script execution tools. +- Deprecated `agent_orchestrator/alas_mcp_server.py` keeps its generic `alas_call_tool(...)` bridge disabled by default. - If a tool invocation is blocked by policy, log the blocked attempt with the same audit schema instead of failing silently. ## In-Flight Enhancements