diff --git a/src/acp/__init__.py b/src/acp/__init__.py index 71a693f..2d09a7d 100644 --- a/src/acp/__init__.py +++ b/src/acp/__init__.py @@ -4,7 +4,6 @@ Agent, Client, RequestError, - TerminalHandle, connect_to_agent, run_agent, ) @@ -133,7 +132,6 @@ "RequestError", "Agent", "Client", - "TerminalHandle", # stdio helper "stdio_streams", "spawn_stdio_connection", diff --git a/src/acp/agent/connection.py b/src/acp/agent/connection.py index 9fc55f2..30b1092 100644 --- a/src/acp/agent/connection.py +++ b/src/acp/agent/connection.py @@ -39,7 +39,6 @@ WriteTextFileRequest, WriteTextFileResponse, ) -from ..terminal import TerminalHandle from ..utils import compatible_class, notify_model, param_model, request_model, request_optional_model from .router import build_agent_router @@ -50,7 +49,9 @@ @final @compatible_class class AgentSideConnection: - """Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation.""" + """Agent-side connection wrapper that dispatches JSON-RPC messages to a Client implementation. + The agent can use this connection to communicate with the Client so it behaves like a Client. + """ def __init__( self, @@ -62,7 +63,7 @@ def __init__( use_unstable_protocol: bool = False, **connection_kwargs: Any, ) -> None: - agent = to_agent(cast(Client, self)) if callable(to_agent) else to_agent + agent = to_agent(self) if callable(to_agent) else to_agent if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_AGENT_CONNECTION_ERROR) handler = build_agent_router(cast(Agent, agent), use_unstable_protocol=use_unstable_protocol) @@ -141,8 +142,8 @@ async def create_terminal( env: list[EnvVariable] | None = None, output_byte_limit: int | None = None, **kwargs: Any, - ) -> TerminalHandle: - create_response = await request_model( + ) -> CreateTerminalResponse: + return await request_model( self._conn, CLIENT_METHODS["terminal_create"], CreateTerminalRequest( @@ -156,7 +157,6 @@ async def create_terminal( ), CreateTerminalResponse, ) - return TerminalHandle(create_response.terminal_id, session_id, self._conn) @param_model(TerminalOutputRequest) async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: @@ -214,3 +214,7 @@ async def __aenter__(self) -> AgentSideConnection: async def __aexit__(self, exc_type, exc, tb) -> None: await self.close() + + def on_connect(self, conn: Agent) -> None: + # A dummy method to match the Client protocol + pass diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index 7a5cdcb..c71da96 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -50,7 +50,9 @@ @final @compatible_class class ClientSideConnection: - """Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation.""" + """Client-side connection wrapper that dispatches JSON-RPC messages to an Agent implementation. + The client can use this connection to communicate with the Agent so it behaves like an Agent. + """ def __init__( self, @@ -63,7 +65,7 @@ def __init__( ) -> None: if not isinstance(input_stream, asyncio.StreamWriter) or not isinstance(output_stream, asyncio.StreamReader): raise TypeError(_CLIENT_CONNECTION_ERROR) - client = to_client(cast(Agent, self)) if callable(to_client) else to_client + client = to_client(self) if callable(to_client) else to_client handler = build_client_router(cast(Client, client), use_unstable_protocol=use_unstable_protocol) self._conn = Connection(handler, input_stream, output_stream, **connection_kwargs) if on_connect := getattr(client, "on_connect", None): @@ -221,3 +223,7 @@ async def __aenter__(self) -> ClientSideConnection: async def __aexit__(self, exc_type, exc, tb) -> None: await self.close() + + def on_connect(self, conn: Client) -> None: + # A dummy method to match the Agent protocol + pass diff --git a/src/acp/core.py b/src/acp/core.py index 1d440de..42632f6 100644 --- a/src/acp/core.py +++ b/src/acp/core.py @@ -14,7 +14,6 @@ from .connection import Connection, JsonValue, MethodHandler from .exceptions import RequestError from .interfaces import Agent, Client -from .terminal import TerminalHandle __all__ = [ "Agent", @@ -25,7 +24,6 @@ "JsonValue", "MethodHandler", "RequestError", - "TerminalHandle", "connect_to_agent", "run_agent", ] diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index 6d143fe..5cc45ff 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -65,7 +65,6 @@ WriteTextFileRequest, WriteTextFileResponse, ) -from .terminal import TerminalHandle from .utils import param_model __all__ = ["Agent", "Client"] @@ -114,7 +113,7 @@ async def create_terminal( env: list[EnvVariable] | None = None, output_byte_limit: int | None = None, **kwargs: Any, - ) -> CreateTerminalResponse | TerminalHandle: ... + ) -> CreateTerminalResponse: ... @param_model(TerminalOutputRequest) async def terminal_output(self, session_id: str, terminal_id: str, **kwargs: Any) -> TerminalOutputResponse: ... diff --git a/src/acp/terminal.py b/src/acp/terminal.py deleted file mode 100644 index fdc4777..0000000 --- a/src/acp/terminal.py +++ /dev/null @@ -1,69 +0,0 @@ -from __future__ import annotations - -from contextlib import suppress - -from .connection import Connection -from .meta import CLIENT_METHODS -from .schema import ( - KillTerminalCommandResponse, - ReleaseTerminalResponse, - TerminalOutputResponse, - WaitForTerminalExitResponse, -) - -__all__ = ["TerminalHandle"] - - -class TerminalHandle: - def __init__(self, terminal_id: str, session_id: str, conn: Connection) -> None: - self.id = terminal_id - self._session_id = session_id - self._conn = conn - - @property - def terminal_id(self) -> str: - return self.id - - async def current_output(self) -> TerminalOutputResponse: - response = await self._conn.send_request( - CLIENT_METHODS["terminal_output"], - {"sessionId": self._session_id, "terminalId": self.id}, - ) - return TerminalOutputResponse.model_validate(response) - - async def wait_for_exit(self) -> WaitForTerminalExitResponse: - response = await self._conn.send_request( - CLIENT_METHODS["terminal_wait_for_exit"], - {"sessionId": self._session_id, "terminalId": self.id}, - ) - return WaitForTerminalExitResponse.model_validate(response) - - async def kill(self) -> KillTerminalCommandResponse: - response = await self._conn.send_request( - CLIENT_METHODS["terminal_kill"], - {"sessionId": self._session_id, "terminalId": self.id}, - ) - payload = response if isinstance(response, dict) else {} - return KillTerminalCommandResponse.model_validate(payload) - - async def release(self) -> ReleaseTerminalResponse: - response = await self._conn.send_request( - CLIENT_METHODS["terminal_release"], - {"sessionId": self._session_id, "terminalId": self.id}, - ) - payload = response if isinstance(response, dict) else {} - return ReleaseTerminalResponse.model_validate(payload) - - async def aclose(self) -> None: - """Release the terminal, ignoring errors that occur during shutdown.""" - with suppress(Exception): - await self.release() - - async def close(self) -> None: - await self.aclose() - - async def __aenter__(self) -> TerminalHandle: - return self - - async def __aexit__(self, exc_type, exc, tb) -> None: - await self.aclose()