diff --git a/src/predict_rlm/backends/base.py b/src/predict_rlm/backends/base.py index 8c95be0e..00d52d77 100644 --- a/src/predict_rlm/backends/base.py +++ b/src/predict_rlm/backends/base.py @@ -78,6 +78,21 @@ def async_tool_callback(self): finally: _TOOL_CALLBACK_GATES.reset(token) + def is_running(self) -> bool: + """Whether a top-level execution currently holds the gate.""" + with self._condition: + return self._running + + def wait_until_idle(self, timeout: float | None = None) -> bool: + """Block until no top-level execution holds the gate. + + Returns ``True`` once the gate is idle, or ``False`` if ``timeout`` + elapsed first. Callers must be on a different thread than the one + holding the gate (the executing worker), otherwise this deadlocks. + """ + with self._condition: + return self._condition.wait_for(lambda: not self._running, timeout) + def _acquire(self) -> None: with self._condition: while self._running: diff --git a/src/predict_rlm/backends/sbx/backend.py b/src/predict_rlm/backends/sbx/backend.py index ee58f8c5..2945abce 100644 --- a/src/predict_rlm/backends/sbx/backend.py +++ b/src/predict_rlm/backends/sbx/backend.py @@ -154,11 +154,16 @@ def __init__( self.on_runtime_hook_event = on_runtime_hook_event self._host_workspace = Path.cwd() self._owns_staging_root = _staging_root is None - self._staging_root = ( - Path(_staging_root) - if _staging_root - else (self._host_workspace / ".predict_rlm_sbx" / uuid.uuid4().hex) - ) + if _staging_root is not None: + self._staging_root = Path(_staging_root) + elif self.config.reuse and self.config.name: + self._staging_root = ( + self._host_workspace / ".predict_rlm_sbx" / self.config.name + ) + else: + self._staging_root = ( + self._host_workspace / ".predict_rlm_sbx" / uuid.uuid4().hex + ) self._staging_root.mkdir(parents=True, exist_ok=True) if self._owns_staging_root and not self.config.persist: _owned_staging_roots_pending_cleanup.add(str(self._staging_root)) @@ -168,10 +173,12 @@ def __init__( self._ws: ClientConnection | None = None self._pending_tool_calls: dict[concurrent.futures.Future[dict[str, Any]], int] = {} self._active_execute_timeout_deadline: float | None = None + self.cancellation_interrupt_timeout: float = 10.0 self._execution_gate = BackendExecutionGate("SBX backend") self._sandbox_name: str | None = None self._prepared_supervisor_path: Path | None = None self._published_websocket_url: str | None = None + self._active_websocket_port: int | None = None self._shutdown = False self._post_execute_hooks: list[Callable[[Any], Any]] = [] self._owned_direct_aliases: list[Path] = [] @@ -256,7 +263,76 @@ async def aexecute( *, timeout: float | None = None, ) -> Any: - return await asyncio.to_thread(self.execute, code, variables, timeout=timeout) + try: + return await asyncio.to_thread(self.execute, code, variables, timeout=timeout) + except asyncio.CancelledError: + await self._abort_execution_after_cancellation() + raise + + async def _abort_execution_after_cancellation(self) -> None: + try: + await asyncio.to_thread(self._interrupt_after_cancellation) + except asyncio.CancelledError: + raise + except Exception: + # Preserve the original cancellation even if cleanup has to be best-effort. + pass + + def _interrupt_after_cancellation(self) -> None: + if not self._uses_websocket_transport(): + return + try: + self.interrupt(timeout=self.cancellation_interrupt_timeout) + except Exception: + self._hard_abort_websocket_after_failed_interrupt() + + def interrupt(self, *, timeout: float | None = 10.0) -> bool: + """Abort the currently-running cell while keeping the warm sandbox. + + Sends an out-of-band ``interrupt`` frame over the websocket. ``ws.send`` + is thread-safe and we never call ``ws.recv`` here: the execute loop + blocked in ``recv`` (possibly on another thread) delivers the resulting + interrupted result through its normal path. Returns whether a cell was + running when the interrupt was issued. + """ + if not self._uses_websocket_transport(): + return False + was_running = self._execution_gate.is_running() + ws = self._ws + if ws is None: + return False + payload = { + "jsonrpc": "2.0", + "method": "interrupt", + "params": {}, + "id": self._next_request_id(), + } + try: + ws.send(self._serialize_supervisor_message(payload), text=True) + except TypeError: + ws.send(self._serialize_supervisor_message(payload)) + except Exception as exc: + raise SandboxFatalError( + f"Failed to send interrupt to Sbx WebSocket supervisor: {exc}" + ) from exc + if was_running and not self._execution_gate.wait_until_idle(timeout): + raise SandboxFatalError( + "Interrupt frame sent but the running cell did not release the " + f"execution gate within {timeout}s." + ) + return was_running + + async def ainterrupt(self, *, timeout: float | None = 10.0) -> bool: + return await asyncio.to_thread(self.interrupt, timeout=timeout) + + def _hard_abort_websocket_after_failed_interrupt(self) -> None: + """Tear down the websocket + supervisor when a graceful interrupt fails. + + This sacrifices the warm sandbox instead of reusing a connection with an + execute still draining in another thread. + """ + with contextlib.suppress(Exception): + self._discard_supervisor_process() def mount_file_at(self, host_path: str, virtual_path: str) -> None: source = Path(host_path) @@ -408,7 +484,8 @@ def _relocate_owned_staging_root_if_nested_in_direct_workspace(self) -> None: continue old_staging_root = self._staging_root _owned_staging_roots_pending_cleanup.discard(str(old_staging_root)) - self._staging_root = Path(tempfile.mkdtemp(prefix="predict-rlm-sbx-")) + self._staging_root = self._relocated_staging_root() + self._staging_root.mkdir(parents=True, exist_ok=True) if not self.config.persist: _owned_staging_roots_pending_cleanup.add(str(self._staging_root)) shutil.rmtree(old_staging_root, ignore_errors=True) @@ -418,6 +495,11 @@ def _relocate_owned_staging_root_if_nested_in_direct_workspace(self) -> None: pass return + def _relocated_staging_root(self) -> Path: + if self.config.reuse and self.config.name: + return Path(tempfile.gettempdir()) / f"predict-rlm-sbx-{self.config.name}" + return Path(tempfile.mkdtemp(prefix="predict-rlm-sbx-")) + def _same_direct_workspace_mounts(self, mounts: list[DirectWorkspaceMount]) -> bool: return self._direct_workspace_mount_keys(mounts) == self._direct_workspace_mount_keys( self._direct_workspace_mounts @@ -542,6 +624,19 @@ def shutdown(self) -> None: text=True, ) self._log_lifecycle("sbx.shutdown.rm") + elif ( + self._supervisor_command is None + and self._sandbox_name + and self.config.reuse + and self.config.stop_on_shutdown + ): + subprocess.run( + ["sbx", "stop", self._sandbox_name], + check=False, + capture_output=True, + text=True, + ) + self._log_lifecycle("sbx.shutdown.stop") self._cleanup_direct_workspace_aliases_host_side() self._cleanup_staging_root() self._log_lifecycle("sbx.shutdown.complete") @@ -556,6 +651,31 @@ def _cleanup_staging_root(self) -> None: except OSError: pass + def destroy(self) -> None: + """Force-remove the sandbox and delete its staging root.""" + self._log_lifecycle("sbx.destroy.start") + if not self._shutdown: + with contextlib.suppress(Exception): + self.shutdown() + if self._sandbox_name: + self.remove(self._sandbox_name) + _owned_staging_roots_pending_cleanup.discard(str(self._staging_root)) + shutil.rmtree(self._staging_root, ignore_errors=True) + with contextlib.suppress(OSError): + self._staging_root.parent.rmdir() + self._sandbox_name = None + self._log_lifecycle("sbx.destroy.complete") + + @classmethod + def remove(cls, name: str) -> None: + """Force-remove a persisted sandbox by name (no staging-root cleanup).""" + subprocess.run( + ["sbx", "rm", "--force", name], + check=False, + capture_output=True, + text=True, + ) + def _ensure_process(self) -> None: if self._uses_websocket_transport(): self._ensure_websocket_supervisor() @@ -586,6 +706,7 @@ def _ensure_process(self) -> None: text=True, env=env, bufsize=1, + start_new_session=True, ) self._start_stdout_reader() if self.output_fields: @@ -661,6 +782,7 @@ def _ensure_websocket_supervisor(self) -> None: error_type=type(exc).__name__, duration_ms=round((time.perf_counter() - start) * 1000), ) + self._teardown_failed_websocket_supervisor() raise self._log_lifecycle( "sbx.runner.started", @@ -688,11 +810,14 @@ def _start_local_websocket_supervisor(self) -> None: text=True, env=env, bufsize=1, + start_new_session=True, ) def _start_sbx_websocket_supervisor(self) -> None: supervisor_path = self._start_sbx_and_prepare_supervisor() assert self._sandbox_name is not None + websocket_port = self._resolve_websocket_port() + self._active_websocket_port = websocket_port runner_root = self._staging_root runner_root.mkdir(parents=True, exist_ok=True) command = [ @@ -709,7 +834,7 @@ def _start_sbx_websocket_supervisor(self) -> None: "--websocket-host", "0.0.0.0", "--websocket-port", - str(self.config.websocket_port), + str(websocket_port), "--websocket-path", self._websocket_path, "--websocket-max-message-bytes", @@ -727,13 +852,29 @@ def _start_sbx_websocket_supervisor(self) -> None: stderr=subprocess.PIPE, text=True, bufsize=1, + start_new_session=True, ) - self._websocket_url = self._publish_websocket_port() + self._websocket_url = self._publish_websocket_port(websocket_port) + + def _resolve_websocket_port(self) -> int: + if self.config.websocket_port: + return self.config.websocket_port + return self._choose_dynamic_websocket_port() + + def _choose_dynamic_websocket_port(self) -> int: + return 20_000 + secrets.randbelow(40_000) def _connect_websocket_supervisor(self, url: str) -> None: deadline = time.monotonic() + self.config.websocket_startup_timeout last_error: BaseException | None = None while True: + if self._proc is not None and self._proc.poll() is not None: + stderr = self._read_stderr_for_process(self._proc) + diagnostic = stderr.strip() or str(last_error or "process exited") + raise SandboxFatalError( + "Sbx WebSocket supervisor exited before accepting connections at " + f"{url}: {diagnostic}" + ) try: self._ws = websocket_connect( url, @@ -753,15 +894,18 @@ def _connect_websocket_supervisor(self, url: str) -> None: ) from last_error time.sleep(0.1) - def _publish_websocket_port(self) -> str: + def _publish_websocket_port(self, port: int | None = None) -> str: assert self._sandbox_name is not None + port = port or self._active_websocket_port or self.config.websocket_port + if not port: + raise SandboxFatalError("Cannot publish sbx WebSocket supervisor without a port") result = subprocess.run( [ "sbx", "ports", self._sandbox_name, "--publish", - str(self.config.websocket_port), + str(port), ], check=False, capture_output=True, @@ -771,7 +915,7 @@ def _publish_websocket_port(self) -> str: if result.returncode != 0: raise SandboxFatalError( "Failed to publish sbx WebSocket supervisor port " - f"{self.config.websocket_port}: exit code {result.returncode}; " + f"{port}: exit code {result.returncode}; " f"stdout: {result.stdout.strip()}; stderr: {result.stderr.strip()}" ) endpoint = self._parse_published_websocket_endpoint(result.stdout) @@ -818,71 +962,158 @@ def _start_sbx_and_prepare_supervisor(self) -> Path: "Install it with `brew install docker/tap/sbx` and run `sbx login`." ) - if self._sandbox_name is None: - supervisor_path = self._prepare_supervisor_script() + if self._sandbox_name is not None: + return self._prepared_supervisor_path or self._prepare_supervisor_script() - primary_workspace = str(self._staging_root) - if self.config.workspace_read_only: - primary_workspace = f"{primary_workspace}:ro" - direct_workspaces = self._direct_workspace_args() - sandbox_name = self.config.name or f"predict-rlm-{uuid.uuid4().hex[:12]}" - create_cmd = [ - "sbx", - "create", - "shell", - primary_workspace, - *self.config.extra_workspaces, - *direct_workspaces, - "--name", - sandbox_name, - ] - for flag, value in ( - ("--cpus", self.config.cpus), - ("--memory", self.config.memory), - ("--template", self.config.template), - ("--kit", self.config.kit), - ("--branch", self.config.branch), - ): - if value is not None: - create_cmd.extend([flag, str(value)]) - create_start = time.perf_counter() + supervisor_path = self._prepare_supervisor_script() + + if self.config.reuse and self._try_reattach_named_sandbox(): + return supervisor_path + + self._create_and_bootstrap_sandbox() + return supervisor_path + + def _create_and_bootstrap_sandbox(self) -> None: + primary_workspace = str(self._staging_root) + if self.config.workspace_read_only: + primary_workspace = f"{primary_workspace}:ro" + direct_workspaces = self._direct_workspace_args() + sandbox_name = self.config.name or f"predict-rlm-{uuid.uuid4().hex[:12]}" + create_cmd = [ + "sbx", + "create", + "shell", + primary_workspace, + *self.config.extra_workspaces, + *direct_workspaces, + "--name", + sandbox_name, + ] + for flag, value in ( + ("--cpus", self.config.cpus), + ("--memory", self.config.memory), + ("--template", self.config.template), + ("--kit", self.config.kit), + ("--branch", self.config.branch), + ): + if value is not None: + create_cmd.extend([flag, str(value)]) + create_start = time.perf_counter() + self._log_lifecycle( + "sbx.create.start", + create_timeout=self.config.create_timeout, + workspace_read_only=self.config.workspace_read_only, + extra_workspaces=len(self.config.extra_workspaces), + ) + try: + created = subprocess.run( + create_cmd, + check=True, + capture_output=True, + text=True, + timeout=self.config.create_timeout, + ) + except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: self._log_lifecycle( - "sbx.create.start", - create_timeout=self.config.create_timeout, - workspace_read_only=self.config.workspace_read_only, - extra_workspaces=len(self.config.extra_workspaces), + "sbx.create.error", + duration_ms=ms_since(create_start), + error_type=type(exc).__name__, + status="error", ) - try: - created = subprocess.run( - create_cmd, - check=True, - capture_output=True, - text=True, - timeout=self.config.create_timeout, - ) - except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as exc: + raise SandboxFatalError(f"Failed to create sbx sandbox: {exc}") from exc + + self._sandbox_name = sandbox_name + self._log_lifecycle( + "sbx.create.ok", + duration_ms=ms_since(create_start), + stdout_chars=len(created.stdout or ""), + stderr_chars=len(created.stderr or ""), + ) + self._apply_network_policy() + self._bootstrap_packages() + self._setup_direct_workspace_aliases_in_sandbox() + + def _probe_sandbox_state(self, name: str) -> str: + """Resolve a named sandbox to ``running`` / ``stopped`` / ``missing``.""" + result = subprocess.run( + ["sbx", "ls"], + check=False, + capture_output=True, + text=True, + timeout=self.config.exec_timeout, + ) + if result.returncode != 0: + return "missing" + for line in result.stdout.splitlines(): + fields = line.split() + if not fields or fields[0] != name: + continue + rest = " ".join(fields[1:]).lower() + if "stop" in rest or "exit" in rest: + return "stopped" + return "running" + return "missing" + + def _sbx_sandbox_healthy(self, name: str) -> bool: + """Cheap liveness probe: a trivial in-container command must succeed.""" + result = subprocess.run( + ["sbx", "exec", name, "true"], + check=False, + capture_output=True, + text=True, + timeout=self.config.exec_timeout, + ) + return result.returncode == 0 + + def _try_reattach_named_sandbox(self) -> bool: + """Return True when an existing named sandbox is ready to reuse.""" + name = self.config.name + assert name is not None + self._log_lifecycle("sbx.reattach.start", sandbox_name=name) + state = self._probe_sandbox_state(name) + + if state == "missing": + self._log_lifecycle("sbx.reattach.miss", sandbox_name=name) + return False + + if state == "stopped": + start_result = subprocess.run( + ["sbx", "start", name], + check=False, + capture_output=True, + text=True, + timeout=self.config.create_timeout, + ) + if start_result.returncode != 0: self._log_lifecycle( - "sbx.create.error", - duration_ms=ms_since(create_start), - error_type=type(exc).__name__, - status="error", + "sbx.reattach.unhealthy.recreate", + sandbox_name=name, + reason="start_failed", ) - raise SandboxFatalError(f"Failed to create sbx sandbox: {exc}") from exc + self._force_remove_sandbox(name) + return False - self._sandbox_name = sandbox_name + if not self._sbx_sandbox_healthy(name): self._log_lifecycle( - "sbx.create.ok", - duration_ms=ms_since(create_start), - stdout_chars=len(created.stdout or ""), - stderr_chars=len(created.stderr or ""), + "sbx.reattach.unhealthy.recreate", + sandbox_name=name, + reason="health_check_failed", ) - self._apply_network_policy() - self._bootstrap_packages() - self._setup_direct_workspace_aliases_in_sandbox() - else: - supervisor_path = self._prepared_supervisor_path or self._prepare_supervisor_script() + self._force_remove_sandbox(name) + return False - return supervisor_path + self._sandbox_name = name + self._setup_direct_workspace_aliases_in_sandbox() + self._log_lifecycle("sbx.reattach.ok", sandbox_name=name) + return True + + def _force_remove_sandbox(self, name: str) -> None: + subprocess.run( + ["sbx", "rm", "--force", name], + check=False, + capture_output=True, + text=True, + ) def _direct_workspace_args(self) -> list[str]: seen = {str(self._staging_root)} @@ -1766,6 +1997,23 @@ def _unwrap_execute_response(self, response: dict) -> Any: def _ensure_process_for_request(self, method: str) -> None: self._ensure_process_for_method(method) + def _teardown_failed_websocket_supervisor(self) -> None: + # A connect/handshake failure otherwise leaves a half-started supervisor + # alive with _websocket_url still set, which short-circuits relaunch so + # the next prewarm/execute reconnects to the dead endpoint. Kill it and + # reset transport state so the next attempt rebuilds from scratch. + if self._proc is not None and self._proc.poll() is None: + with contextlib.suppress(Exception): + self._proc.kill() + self._proc.wait(timeout=5) + self._discard_supervisor_process() + if self._websocket_supervisor_command is None: + # The sbx path's URL came from `sbx ports --publish`; drop it so the + # retry republishes instead of reusing the dead forward. The local + # runner's externally supplied URL stays put. + self._websocket_url = None + self._published_websocket_url = None + def _discard_supervisor_process(self) -> None: if self._ws is not None: with contextlib.suppress(Exception): diff --git a/src/predict_rlm/backends/sbx/config.py b/src/predict_rlm/backends/sbx/config.py index 2cced500..e1e3ac59 100644 --- a/src/predict_rlm/backends/sbx/config.py +++ b/src/predict_rlm/backends/sbx/config.py @@ -2,7 +2,7 @@ from __future__ import annotations -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator DEFAULT_SBX_TEMPLATE = "docker.io/docker/sandbox-templates:shell" @@ -18,10 +18,21 @@ class SbxConfig(BaseModel): branch: str | None = None persist: bool = False remove_on_shutdown: bool = True + reuse: bool = False + stop_on_shutdown: bool = False extra_workspaces: list[str] = Field(default_factory=list) workspace_read_only: bool = False create_timeout: float = 120.0 exec_timeout: float = 300.0 - websocket_port: int = 8765 + websocket_port: int = 0 websocket_startup_timeout: float = 30.0 websocket_max_message_bytes: int = 32 * 1024 * 1024 + + @model_validator(mode="after") + def _apply_reuse_semantics(self) -> "SbxConfig": + if self.reuse: + if not self.name: + raise ValueError("SbxConfig.reuse=True requires a non-empty `name`.") + self.persist = True + self.remove_on_shutdown = False + return self diff --git a/src/predict_rlm/backends/supervisor/_payload.py b/src/predict_rlm/backends/supervisor/_payload.py index a7cac39b..274395f3 100644 --- a/src/predict_rlm/backends/supervisor/_payload.py +++ b/src/predict_rlm/backends/supervisor/_payload.py @@ -55,6 +55,10 @@ _KERNEL_PROCESS: multiprocessing.Process | None = None _KERNEL_REQUEST_QUEUE: multiprocessing.Queue | None = None _KERNEL_RESULT_QUEUE: multiprocessing.Queue | None = None +# Interrupt requests are handled out of band from the serial request queue so +# the websocket receiver can abort a cell while the run loop waits on execute. +_INTERRUPT_REQUESTED = False +_EXECUTION_ACTIVE = False _DEFAULT_TIMEOUT_INTERRUPT_GRACE_SECONDS = 0.5 _INTERNAL_GLOBAL_NAMES = { "SUBMIT", @@ -817,6 +821,26 @@ def _pickle_snapshot_state( } +def _build_interrupt_result( + timeout_seconds: float | None, + oob_interrupted: bool, + stdout: str, + stderr: str, + snapshot: dict[str, Any], + reason: str, +) -> dict[str, Any]: + """Build the recoverable timeout-style result for an interrupted cell.""" + timeout_info: dict[str, Any] = {"seconds": timeout_seconds} + if oob_interrupted: + timeout_info["interrupted"] = True + return { + "timeout": timeout_info, + "stdout": stdout, + "stderr": stderr, + "state": _pickle_snapshot_state(snapshot, reason), + } + + def _is_user_global(name: str, globals_dict: dict[str, Any]) -> bool: if name.startswith("__") and name.endswith("__"): return False @@ -1334,6 +1358,27 @@ def _terminate_runner(process: multiprocessing.Process) -> None: process.join(timeout=0.5) +def _request_interrupt() -> bool: + """Latch an interrupt request and return whether a cell is running.""" + global _INTERRUPT_REQUESTED + _INTERRUPT_REQUESTED = True + return _EXECUTION_ACTIVE + + +def _consume_interrupt_request() -> bool: + """Atomically read-and-clear the latched interrupt flag.""" + global _INTERRUPT_REQUESTED + requested = _INTERRUPT_REQUESTED + _INTERRUPT_REQUESTED = False + return requested + + +async def _handle_interrupt_request(request: dict[str, Any]) -> dict[str, Any]: + """Handle an ``interrupt`` JSON-RPC request outside the serial queue.""" + running = _request_interrupt() + return _response(request.get("id"), {"running": running}) + + def _discard_kernel() -> None: global _KERNEL_PROCESS, _KERNEL_REQUEST_QUEUE, _KERNEL_RESULT_QUEUE process = _KERNEL_PROCESS @@ -1430,9 +1475,11 @@ async def _kernel_pickle_snapshot( async def _register_runtime_hooks_in_runner( - params: dict[str, Any], globals_dict: dict[str, Any] + params: dict[str, Any], + globals_dict: dict[str, Any], + host_tool_bridge: _HostToolBridge | None = None, ) -> dict[str, Any]: - process = _ensure_kernel(globals_dict) + process = _ensure_kernel(globals_dict, host_tool_bridge) assert _KERNEL_REQUEST_QUEUE is not None assert _KERNEL_RESULT_QUEUE is not None _KERNEL_REQUEST_QUEUE.put({"op": "register_runtime_hooks", "params": params}) @@ -1458,12 +1505,13 @@ async def _execute_code_in_runner_with_timeout( defer_final_output: bool = False, host_tool_bridge: _HostToolBridge | None = None, ) -> dict[str, Any]: + global _EXECUTION_ACTIVE process = _ensure_kernel(globals_dict, host_tool_bridge) assert _KERNEL_REQUEST_QUEUE is not None assert _KERNEL_RESULT_QUEUE is not None - pre_timeout_snapshot: dict[str, Any] | None = None - if timeout_seconds is not None: - pre_timeout_snapshot = await _kernel_pickle_snapshot(process) + pre_timeout_snapshot: dict[str, Any] | None = await _kernel_pickle_snapshot(process) + _consume_interrupt_request() + _EXECUTION_ACTIVE = True stdout_path = _capture_file_path("stdout") stderr_path = _capture_file_path("stderr") _KERNEL_REQUEST_QUEUE.put({ @@ -1478,6 +1526,7 @@ async def _execute_code_in_runner_with_timeout( ) interrupt_deadline: float | None = None interrupt_sent = False + oob_interrupted = False runner_message: dict[str, Any] | None = None try: while True: @@ -1492,6 +1541,24 @@ async def _execute_code_in_runner_with_timeout( if not process.is_alive(): process.join(timeout=0.5) break + if not interrupt_sent and _consume_interrupt_request(): + interrupt_sent = True + oob_interrupted = True + interrupt_deadline = now + timeout_interrupt_grace_seconds + interrupted = _signal_runner(process, signal.SIGINT) + _debug_event( + "sbx.python_runner.execute.interrupt", + code_hash=_code_hash(code), + code_len=len(code), + reason="out_of_band_interrupt", + interrupt_sent=interrupted, + interrupt_grace_seconds=timeout_interrupt_grace_seconds, + child_pid=process.pid, + child_exitcode=process.exitcode, + ) + if interrupted and timeout_interrupt_grace_seconds > 0: + await asyncio.sleep(0.01) + continue if deadline is not None and not interrupt_sent and now >= deadline: interrupt_sent = True interrupt_deadline = now + timeout_interrupt_grace_seconds @@ -1537,12 +1604,9 @@ async def _execute_code_in_runner_with_timeout( child_pid=process.pid, child_exitcode=process.exitcode, ) - return { - "timeout": {"seconds": timeout_seconds}, - "stdout": stdout, - "stderr": stderr, - "state": _pickle_snapshot_state(snapshot, reason), - } + return _build_interrupt_result( + timeout_seconds, oob_interrupted, stdout, stderr, snapshot, reason + ) await asyncio.sleep(0.01) if host_tool_bridge is not None: await host_tool_bridge.drain_requests() @@ -1578,12 +1642,9 @@ async def _execute_code_in_runner_with_timeout( child_pid=process.pid, child_exitcode=exitcode, ) - return { - "timeout": {"seconds": timeout_seconds}, - "stdout": stdout, - "stderr": stderr, - "state": _pickle_snapshot_state(snapshot, reason), - } + return _build_interrupt_result( + timeout_seconds, oob_interrupted, stdout, stderr, snapshot, reason + ) _debug_event( "sbx.python_runner.execute", code_hash=_code_hash(code), @@ -1601,6 +1662,35 @@ async def _execute_code_in_runner_with_timeout( raise RuntimeError( f"execution runner exited without a result (exitcode={exitcode})" ) + if interrupt_sent and not runner_message.get("ok"): + stdout = _read_capture_file(stdout_path) + stderr = _read_capture_file(stderr_path) + reason = "kernel interrupted by SIGINT" + snapshot = pre_timeout_snapshot or { + "globals": {}, + "restored_globals": [], + "lost_globals": [], + } + _reset_globals_from_pickle_snapshot(globals_dict, snapshot) + _debug_event( + "sbx.python_runner.execute", + code_hash=_code_hash(code), + code_len=len(code), + timeout=True, + interrupted=oob_interrupted, + state_preserved=False, + state_source="pickle_snapshot", + state_loss_reason=reason, + restored_globals=snapshot.get("restored_globals", []), + lost_globals=snapshot.get("lost_globals", []), + stdout_len=len(stdout), + stderr_len=len(stderr), + child_pid=process.pid, + child_exitcode=process.exitcode, + ) + return _build_interrupt_result( + timeout_seconds, oob_interrupted, stdout, stderr, snapshot, reason + ) if not runner_message.get("ok"): _debug_event( "sbx.python_runner.execute", @@ -1645,6 +1735,7 @@ async def _execute_code_in_runner_with_timeout( ) return result finally: + _EXECUTION_ACTIVE = False _unlink_capture_files(stdout_path, stderr_path) @@ -1728,7 +1819,10 @@ async def _handle_request( return _response(request_id, _register_tools(params, globals_dict)) if method == "register_runtime_hooks": return _response( - request_id, await _register_runtime_hooks_in_runner(params, globals_dict) + request_id, + await _register_runtime_hooks_in_runner( + params, globals_dict, host_tool_bridge + ), ) if method == "mount_file": return _response(request_id, _mount_file(params)) @@ -1738,6 +1832,8 @@ async def _handle_request( return _response(request_id, _list_dir(params)) if method == "sync_file": return _response(request_id, _sync_file(params)) + if method == "interrupt": + return await _handle_interrupt_request(request) if method == "shutdown": return _response(request_id, {"shutdown": True}) raise ValueError(f"Unknown method: {method}") @@ -1830,7 +1926,10 @@ async def _receive_messages(self) -> None: continue if not isinstance(message, dict): continue - if message.get("method"): + if message.get("method") == "interrupt": + response = await _handle_interrupt_request(message) + await self.connection.send(json.dumps(response, default=str)) + elif message.get("method"): await self.requests.put(message) elif "id" in message: self.host_tool_bridge.deliver_response(message) diff --git a/tests/test_sbx_interpreter.py b/tests/test_sbx_interpreter.py index 2496578d..54095f95 100644 --- a/tests/test_sbx_interpreter.py +++ b/tests/test_sbx_interpreter.py @@ -7,6 +7,7 @@ import json import logging import os +import secrets import select import shutil import socket @@ -22,6 +23,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +from pydantic import ValidationError pytest.importorskip("websockets") # SBX/supervisor backend requires the [sbx] extra @@ -40,6 +42,7 @@ _pickleable_globals_snapshot, ) from predict_rlm.files import SyncedFile # noqa: E402 +from predict_rlm.workspace import DirectWorkspaceMount # noqa: E402 PAYLOAD_PATH = Path(__file__).parents[1] / "src" / "predict_rlm" / "backends" / "supervisor" / "_payload.py" @@ -1692,9 +1695,13 @@ def make_interpreter( tools: dict | None = None, path: str | None = None, url_path: str | None = None, + port: int | None = None, + name: str = "local-websocket-test", + reuse: bool = False, + staging_root: Path | None = None, startup_timeout: float = 3, ) -> SbxBackend: - port = _free_local_port() + port = port or _free_local_port() websocket_path = path or f"/predict-rlm-test-{os.getpid()}-{time.time_ns()}" command = [ sys.executable, @@ -1711,7 +1718,8 @@ def make_interpreter( ] return SbxBackend( config=SbxConfig( - name="local-websocket-test", + name=name, + reuse=reuse, exec_timeout=5, websocket_startup_timeout=startup_timeout, websocket_max_message_bytes=32 * 1024 * 1024, @@ -1720,7 +1728,7 @@ def make_interpreter( preinstall_packages=False, _websocket_supervisor_command=command, _websocket_url=f"ws://127.0.0.1:{port}{url_path or websocket_path}", - _staging_root=tmp_path / "ws-staging", + _staging_root=staging_root or tmp_path / "ws-staging", ) def test_websocket_execute_and_state_persistence(self, tmp_path: Path): @@ -1819,6 +1827,63 @@ def predict(signature: str, **kwargs) -> dict: assert output == "4\n" assert seen_lengths == [950000] + def test_reusable_named_websocket_supervisors_run_concurrently( + self, tmp_path: Path + ): + staging_root = tmp_path / "shared-staging" + first = self.make_interpreter( + tmp_path, + name="shared-websocket-test", + reuse=True, + staging_root=staging_root, + path="/predict-rlm-first", + ) + second = self.make_interpreter( + tmp_path, + name="shared-websocket-test", + reuse=True, + staging_root=staging_root, + path="/predict-rlm-second", + ) + barrier = threading.Barrier(3) + errors: list[BaseException] = [] + outputs: dict[str, str] = {} + + def execute(interpreter: SbxBackend, label: str) -> None: + try: + barrier.wait(timeout=2) + outputs[label] = interpreter.execute( + f"owner = {label!r}\n" + "import time\n" + "time.sleep(0.2)\n" + "print(owner)" + ) + except BaseException as exc: + errors.append(exc) + + try: + first.prewarm() + second.prewarm() + + threads = [ + threading.Thread(target=execute, args=(first, "first")), + threading.Thread(target=execute, args=(second, "second")), + ] + for thread in threads: + thread.start() + barrier.wait(timeout=2) + for thread in threads: + thread.join(timeout=3) + + assert all(not thread.is_alive() for thread in threads) + assert errors == [] + assert outputs == {"first": "first\n", "second": "second\n"} + assert first.execute("print(owner)") == "first\n" + assert second.execute("print(owner)") == "second\n" + finally: + first.shutdown() + second.shutdown() + def test_predict_forwards_nested_pydantic_schemas_for_custom_types(self, tmp_path: Path): """predict() with a custom output type that nests sibling models. @@ -2105,6 +2170,166 @@ def test_websocket_auth_path_failure_is_reported(self, tmp_path: Path): interpreter.shutdown() +class TestSbxBackendInterrupt(TestSbxBackendLocalWebSocketRunner): + """On-demand interrupt and cancellation-safe async execution.""" + + @pytest.mark.local + def test_interrupt_unblocks_long_running_cell(self, tmp_path: Path): + interpreter = self.make_interpreter(tmp_path, startup_timeout=5) + try: + running_flag: dict[str, bool] = {} + + def fire_interrupt(): + time.sleep(1.0) + running_flag["was_running"] = interpreter.interrupt(timeout=10.0) + + thread = threading.Thread(target=fire_interrupt) + thread.start() + start = time.monotonic() + result = interpreter.execute("import time\ntime.sleep(120)\nprint('done')") + elapsed = time.monotonic() - start + thread.join(timeout=5) + + assert elapsed < 30, f"interrupt did not unblock promptly: {elapsed:.1f}s" + assert running_flag.get("was_running") is True + assert "done" not in str(result) + + assert interpreter.execute("print('alive')") == "alive\n" + finally: + interpreter.shutdown() + + @pytest.mark.local + def test_interrupt_preserves_warm_state(self, tmp_path: Path): + interpreter = self.make_interpreter(tmp_path, startup_timeout=5) + try: + assert interpreter.execute("kept = 99\nprint(kept)") == "99\n" + + def fire_interrupt(): + time.sleep(1.0) + interpreter.interrupt(timeout=10.0) + + thread = threading.Thread(target=fire_interrupt) + thread.start() + interpreter.execute("import time\ntime.sleep(120)\nprint('done')") + thread.join(timeout=5) + + assert interpreter.execute("print(kept)") == "99\n" + finally: + interpreter.shutdown() + + @pytest.mark.local + def test_interrupt_returns_only_after_cell_releases_gate(self, tmp_path: Path): + interpreter = self.make_interpreter(tmp_path, startup_timeout=5) + gate = interpreter._execution_gate + try: + interpreter.execute("warm = 1") + + def run_cell() -> None: + interpreter.execute("import time\ntime.sleep(120)\nprint('done')") + + worker = threading.Thread(target=run_cell) + worker.start() + while not gate.is_running(): + time.sleep(0.01) + time.sleep(0.5) + + was_running = interpreter.interrupt(timeout=10.0) + + assert was_running is True + assert ( + gate.is_running() is False + ), "interrupt returned before the interrupted cell released the gate" + + worker.join(timeout=5) + assert not worker.is_alive() + assert interpreter.execute("print(warm)") == "1\n" + finally: + interpreter.shutdown() + + @pytest.mark.local + def test_interrupt_returns_false_when_idle(self, tmp_path: Path): + interpreter = self.make_interpreter(tmp_path, startup_timeout=5) + try: + interpreter.execute("print('warm')") + assert interpreter.interrupt(timeout=5.0) is False + finally: + interpreter.shutdown() + + @pytest.mark.local + def test_aexecute_cancellation_is_prompt_and_keeps_sandbox_warm( + self, tmp_path: Path + ): + interpreter = self.make_interpreter(tmp_path, startup_timeout=5) + + async def scenario() -> float: + interpreter.execute("seed = 5") + task = asyncio.ensure_future( + interpreter.aexecute("import time\ntime.sleep(120)\nprint('done')") + ) + await asyncio.sleep(1.0) + start = time.monotonic() + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + return time.monotonic() - start + + try: + elapsed = asyncio.run(scenario()) + assert elapsed < 30, f"cancellation did not unwind promptly: {elapsed:.1f}s" + assert interpreter.execute("print(seed)") == "5\n" + finally: + interpreter.shutdown() + + +class TestSupervisorPayloadInterruptMethod: + @pytest.mark.local + def test_interrupt_method_acks_running_false_when_idle(self, tmp_path: Path): + import predict_rlm.backends.supervisor._payload as payload + + payload._consume_interrupt_request() + result = asyncio.run( + payload._handle_interrupt_request({"id": 1, "method": "interrupt"}) + ) + assert result["result"]["running"] is False + payload._consume_interrupt_request() + + +class TestSbxBackendLocalSupervisorInterrupt(TestSbxBackendLocalWebSocketRunner): + @pytest.mark.local + def test_interrupt_method_trips_interrupt_path_while_running(self, tmp_path: Path): + interpreter = self.make_interpreter(tmp_path, startup_timeout=5) + try: + interpreter.execute("flag = 1") + + def fire_interrupt(): + time.sleep(1.0) + interpreter.interrupt(timeout=10.0) + + thread = threading.Thread(target=fire_interrupt) + thread.start() + start = time.monotonic() + interpreter.execute("import time\ntime.sleep(120)") + elapsed = time.monotonic() - start + thread.join(timeout=5) + assert elapsed < 30 + assert interpreter.execute("print(flag)") == "1\n" + finally: + interpreter.shutdown() + + +class TestSbxSupervisorSignalIsolation(TestSbxBackendLocalWebSocketRunner): + @pytest.mark.local + def test_supervisor_runs_in_its_own_process_group(self, tmp_path: Path): + interpreter = self.make_interpreter(tmp_path) + try: + interpreter.execute("x = 1") + proc = interpreter._proc + assert proc is not None and proc.poll() is None + assert os.getpgid(proc.pid) != os.getpgid(0) + finally: + interpreter.shutdown() + + class TestSbxCommandConstruction: def test_default_template_uses_explicit_non_docker_shell_template( self, monkeypatch, tmp_path: Path @@ -2235,7 +2460,7 @@ def fake_popen(command, **kwargs): monkeypatch.setattr( interpreter, "_publish_websocket_port", - lambda: "ws://127.0.0.1:49152/test", + lambda port=None: "ws://127.0.0.1:49152/test", ) interpreter._start_sbx_websocket_supervisor() @@ -2253,6 +2478,51 @@ def fake_popen(command, **kwargs): assert interpreter._proc is not None assert not any(cmd[:3] == ["sbx", "exec", "-d"] for cmd in run_commands) + def test_websocket_supervisor_uses_dynamic_port_by_default( + self, monkeypatch, tmp_path: Path + ): + popen_commands: list[list[str]] = [] + published_ports: list[int | None] = [] + + class FakeProcess: + stdout = None + stderr = None + stdin = None + pid = 12345 + + def poll(self): + return None + + def fake_run(command, **kwargs): + return subprocess.CompletedProcess(command, 0, stdout="created-name\n", stderr="") + + def fake_popen(command, **kwargs): + popen_commands.append(command) + return FakeProcess() + + monkeypatch.setattr(shutil, "which", lambda name: "/usr/local/bin/sbx") + monkeypatch.setattr(subprocess, "run", fake_run) + monkeypatch.setattr(subprocess, "Popen", fake_popen) + monkeypatch.setattr(secrets, "randbelow", lambda upper: 12345) + interpreter = SbxBackend( + config=SbxConfig(name="created-name"), + preinstall_packages=False, + _staging_root=tmp_path / "staging", + ) + + def fake_publish(port=None): + published_ports.append(port) + return "ws://127.0.0.1:49152/test" + + monkeypatch.setattr(interpreter, "_publish_websocket_port", fake_publish) + + interpreter._start_sbx_websocket_supervisor() + + supervisor_exec = popen_commands[0] + assert supervisor_exec[supervisor_exec.index("--websocket-port") + 1] == "32345" + assert published_ports == [32345] + assert interpreter._active_websocket_port == 32345 + def test_websocket_recovery_restarts_detached_supervisor_after_kill( self, monkeypatch, tmp_path: Path ): @@ -3214,7 +3484,7 @@ def __getitem__(self, key: str) -> str: ]) pool = SbxPool( size=1, - config=SbxConfig(name=f"predict-rlm-test-predict-timeout-{os.getpid()}"), + config=SbxConfig(name=f"predict-rlm-test-predict-timeout-{os.getpid()}", exec_timeout=12.0), preinstall_packages=False, ) rlm = PredictRLM( @@ -3245,10 +3515,450 @@ def __getitem__(self, key: str) -> str: def test_predict_rlm_recovers_after_user_exceptions_and_tools_still_work(self): pool = SbxPool( size=1, - config=SbxConfig(name=f"predict-rlm-test-user-exceptions-{os.getpid()}"), + config=SbxConfig(name=f"predict-rlm-test-user-exceptions-{os.getpid()}", exec_timeout=12.0), preinstall_packages=False, ) try: assert_predict_rlm_recovers_after_user_exceptions_and_tools_still_work(pool) finally: pool.shutdown() + + +class TestSbxBackendReattachConfig: + def test_reuse_requires_name(self): + with pytest.raises(ValidationError, match="reuse=True"): + SbxConfig(reuse=True) + + def test_reuse_implies_persist_and_no_remove(self): + config = SbxConfig(name="hot-box", reuse=True) + assert config.reuse is True + assert config.persist is True + assert config.remove_on_shutdown is False + + def test_reuse_false_is_unchanged_default(self): + config = SbxConfig() + assert config.reuse is False + assert config.persist is False + assert config.remove_on_shutdown is True + assert config.stop_on_shutdown is False + + +class TestSbxBackendReattachStagingRoot: + def test_reuse_staging_root_is_deterministic_from_name(self, tmp_path: Path): + with patch( + "predict_rlm.backends.sbx.backend.Path.cwd", return_value=tmp_path + ): + backend_a = SbxBackend(config=SbxConfig(name="hot-box", reuse=True)) + backend_b = SbxBackend(config=SbxConfig(name="hot-box", reuse=True)) + assert backend_a._staging_root == backend_b._staging_root + assert backend_a._staging_root.name == "hot-box" + + def test_reuse_staging_root_not_marked_for_cleanup(self, tmp_path: Path): + from predict_rlm.backends.sbx import backend as backend_mod + + with patch( + "predict_rlm.backends.sbx.backend.Path.cwd", return_value=tmp_path + ): + backend = SbxBackend(config=SbxConfig(name="hot-box", reuse=True)) + assert ( + str(backend._staging_root) + not in backend_mod._owned_staging_roots_pending_cleanup + ) + + def test_ephemeral_staging_root_is_unique_uuid(self, tmp_path: Path): + with patch( + "predict_rlm.backends.sbx.backend.Path.cwd", return_value=tmp_path + ): + backend_a = SbxBackend(config=SbxConfig()) + backend_b = SbxBackend(config=SbxConfig()) + assert backend_a._staging_root != backend_b._staging_root + + def test_reuse_relocated_staging_root_is_deterministic_across_sessions( + self, tmp_path: Path + ): + mounts = [DirectWorkspaceMount(host_path=str(tmp_path), sandbox_path="/work")] + + def _make() -> SbxBackend: + with patch( + "predict_rlm.backends.sbx.backend.Path.cwd", return_value=tmp_path + ): + return SbxBackend( + config=SbxConfig(name="hot-box", reuse=True), + direct_workspace_mounts=mounts, + ) + + backend_a = _make() + backend_b = _make() + try: + assert tmp_path not in backend_a._staging_root.parents + assert backend_a._staging_root == backend_b._staging_root + assert backend_a._staging_root.name == "predict-rlm-sbx-hot-box" + finally: + for backend in (backend_a, backend_b): + shutil.rmtree(backend._staging_root, ignore_errors=True) + + def test_ephemeral_relocated_staging_root_stays_unique(self, tmp_path: Path): + mounts = [DirectWorkspaceMount(host_path=str(tmp_path), sandbox_path="/work")] + with patch( + "predict_rlm.backends.sbx.backend.Path.cwd", return_value=tmp_path + ): + backend_a = SbxBackend(config=SbxConfig(), direct_workspace_mounts=mounts) + backend_b = SbxBackend(config=SbxConfig(), direct_workspace_mounts=mounts) + try: + assert tmp_path not in backend_a._staging_root.parents + assert backend_a._staging_root != backend_b._staging_root + finally: + for backend in (backend_a, backend_b): + shutil.rmtree(backend._staging_root, ignore_errors=True) + + +def _reattach_backend(tmp_path: Path, *, name: str = "hot-box") -> SbxBackend: + return SbxBackend( + config=SbxConfig(name=name, reuse=True), + preinstall_packages=False, + _staging_root=tmp_path / "staging", + ) + + +class TestSbxBackendReattachDetection: + def _patches(self, backend: SbxBackend, *, ls_output: str): + runs: list[list[str]] = [] + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + if cmd[:2] == ["sbx", "ls"]: + return SimpleNamespace(returncode=0, stdout=ls_output, stderr="") + return SimpleNamespace(returncode=0, stdout="", stderr="") + + cm = [ + patch( + "predict_rlm.backends.sbx.backend.shutil.which", + return_value="/usr/bin/sbx", + ), + patch( + "predict_rlm.backends.sbx.backend.subprocess.run", + side_effect=fake_run, + ), + patch.object( + SbxBackend, "_prepare_supervisor_script", return_value=Path("/sup.py") + ), + ] + return runs, cm + + def test_running_named_sandbox_reattaches_without_create_or_bootstrap( + self, tmp_path: Path + ): + backend = _reattach_backend(tmp_path) + runs, cms = self._patches(backend, ls_output="hot-box running\n") + with ( + cms[0], + cms[1], + cms[2], + patch.object(SbxBackend, "_apply_network_policy") as net, + patch.object(SbxBackend, "_bootstrap_packages") as boot, + patch.object(SbxBackend, "_setup_direct_workspace_aliases_in_sandbox"), + patch.object(SbxBackend, "_sbx_sandbox_healthy", return_value=True), + ): + backend._start_sbx_and_prepare_supervisor() + assert backend._sandbox_name == "hot-box" + assert not any(r[:2] == ["sbx", "create"] for r in runs) + net.assert_not_called() + boot.assert_not_called() + + def test_stopped_named_sandbox_is_started_then_reattaches(self, tmp_path: Path): + backend = _reattach_backend(tmp_path) + runs, cms = self._patches(backend, ls_output="hot-box stopped\n") + with ( + cms[0], + cms[1], + cms[2], + patch.object(SbxBackend, "_apply_network_policy") as net, + patch.object(SbxBackend, "_bootstrap_packages") as boot, + patch.object(SbxBackend, "_setup_direct_workspace_aliases_in_sandbox"), + patch.object(SbxBackend, "_sbx_sandbox_healthy", return_value=True), + ): + backend._start_sbx_and_prepare_supervisor() + assert backend._sandbox_name == "hot-box" + assert any( + r[:2] == ["sbx", "start"] and "hot-box" in r for r in runs + ), runs + assert not any(r[:2] == ["sbx", "create"] for r in runs) + net.assert_not_called() + boot.assert_not_called() + + def test_missing_named_sandbox_falls_through_to_create(self, tmp_path: Path): + backend = _reattach_backend(tmp_path) + runs, cms = self._patches(backend, ls_output="other-box running\n") + with ( + cms[0], + cms[1], + cms[2], + patch.object(SbxBackend, "_apply_network_policy") as net, + patch.object(SbxBackend, "_bootstrap_packages") as boot, + patch.object(SbxBackend, "_setup_direct_workspace_aliases_in_sandbox"), + patch.object(SbxBackend, "_sbx_sandbox_healthy", return_value=True), + ): + backend._start_sbx_and_prepare_supervisor() + assert backend._sandbox_name == "hot-box" + assert any(r[:2] == ["sbx", "create"] for r in runs), runs + net.assert_called_once() + boot.assert_called_once() + + def test_running_but_unhealthy_recreates(self, tmp_path: Path): + backend = _reattach_backend(tmp_path) + runs, _ = self._patches(backend, ls_output="hot-box running\n") + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + if cmd[:2] == ["sbx", "ls"]: + return SimpleNamespace( + returncode=0, stdout="hot-box running\n", stderr="" + ) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + runs.clear() + with ( + patch( + "predict_rlm.backends.sbx.backend.shutil.which", + return_value="/usr/bin/sbx", + ), + patch( + "predict_rlm.backends.sbx.backend.subprocess.run", + side_effect=fake_run, + ), + patch.object( + SbxBackend, "_prepare_supervisor_script", return_value=Path("/sup.py") + ), + patch.object(SbxBackend, "_apply_network_policy") as net, + patch.object(SbxBackend, "_bootstrap_packages") as boot, + patch.object(SbxBackend, "_setup_direct_workspace_aliases_in_sandbox"), + patch.object(SbxBackend, "_sbx_sandbox_healthy", return_value=False), + ): + backend._start_sbx_and_prepare_supervisor() + assert any( + r[:2] == ["sbx", "rm"] and "hot-box" in r for r in runs + ), runs + assert any(r[:2] == ["sbx", "create"] for r in runs), runs + net.assert_called_once() + boot.assert_called_once() + + +class TestSbxBackendReattachShutdown: + def test_reuse_shutdown_does_not_rm_or_delete_staging(self, tmp_path: Path): + backend = _reattach_backend(tmp_path) + backend._sandbox_name = "hot-box" + staging = backend._staging_root + assert staging.exists() + runs: list[list[str]] = [] + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + with patch( + "predict_rlm.backends.sbx.backend.subprocess.run", side_effect=fake_run + ): + backend.shutdown() + assert not any(r[:2] == ["sbx", "rm"] for r in runs), runs + assert staging.exists() + + def test_reuse_stop_on_shutdown_stops_container(self, tmp_path: Path): + backend = SbxBackend( + config=SbxConfig(name="hot-box", reuse=True, stop_on_shutdown=True), + preinstall_packages=False, + _staging_root=tmp_path / "staging", + ) + backend._sandbox_name = "hot-box" + runs: list[list[str]] = [] + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + with patch( + "predict_rlm.backends.sbx.backend.subprocess.run", side_effect=fake_run + ): + backend.shutdown() + assert any( + r[:2] == ["sbx", "stop"] and "hot-box" in r for r in runs + ), runs + assert not any(r[:2] == ["sbx", "rm"] for r in runs), runs + + +class TestSbxBackendDestroy: + def test_destroy_removes_sandbox_and_staging_root(self, tmp_path: Path): + backend = _reattach_backend(tmp_path) + backend._sandbox_name = "hot-box" + staging = backend._staging_root + assert staging.exists() + runs: list[list[str]] = [] + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + with patch( + "predict_rlm.backends.sbx.backend.subprocess.run", side_effect=fake_run + ): + backend.destroy() + assert any( + r[:3] == ["sbx", "rm", "--force"] and "hot-box" in r for r in runs + ), runs + assert not staging.exists() + + def test_remove_classmethod_force_removes_named_sandbox(self): + runs: list[list[str]] = [] + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + return SimpleNamespace(returncode=0, stdout="", stderr="") + + with patch( + "predict_rlm.backends.sbx.backend.subprocess.run", side_effect=fake_run + ): + SbxBackend.remove("hot-box") + assert any( + r[:3] == ["sbx", "rm", "--force"] and "hot-box" in r for r in runs + ), runs + + +class TestSbxBackendReattachRegression: + def test_default_path_still_creates_without_ls_probe(self, tmp_path: Path): + backend = SbxBackend( + config=SbxConfig(), + preinstall_packages=False, + _staging_root=tmp_path / "staging", + ) + runs: list[list[str]] = [] + + def fake_run(cmd, *args, **kwargs): + runs.append(list(cmd)) + return SimpleNamespace(returncode=0, stdout="auto-name\n", stderr="") + + with ( + patch( + "predict_rlm.backends.sbx.backend.shutil.which", + return_value="/usr/bin/sbx", + ), + patch( + "predict_rlm.backends.sbx.backend.subprocess.run", side_effect=fake_run + ), + patch.object( + SbxBackend, "_prepare_supervisor_script", return_value=Path("/sup.py") + ), + patch.object(SbxBackend, "_apply_network_policy") as net, + patch.object(SbxBackend, "_bootstrap_packages") as boot, + patch.object(SbxBackend, "_setup_direct_workspace_aliases_in_sandbox"), + ): + backend._start_sbx_and_prepare_supervisor() + assert not any(r[:2] == ["sbx", "ls"] for r in runs), runs + assert any(r[:2] == ["sbx", "create"] for r in runs), runs + net.assert_called_once() + boot.assert_called_once() + + +@pytest.mark.integration +@pytest.mark.skipif( + not _real_sbx_available(), + reason="real Docker Sandboxes tests require PREDICT_RLM_RUN_SBX_TESTS=1, sbx CLI, and sbx login", +) +class TestSbxBackendRealSbxReattach: + def _list_names(self) -> list[str]: + result = subprocess.run( + ["sbx", "ls"], capture_output=True, text=True, check=False, timeout=15 + ) + return [line.split()[0] for line in result.stdout.splitlines() if line.split()] + + def test_persist_reattach_destroy_lifecycle(self): + name = f"predict-rlm-reattach-{os.getpid()}" + config = SbxConfig(name=name, reuse=True) + marker = f"state-{os.getpid()}" + + first = SbxBackend(config=config, preinstall_packages=False, debug=True) + try: + first.prewarm() + first.execute( + "from pathlib import Path\n" + f"Path('/sandbox/persisted.txt').write_text({marker!r})\n" + "print('wrote')" + ) + first.shutdown() + assert name in self._list_names() + + second = SbxBackend(config=config, preinstall_packages=False, debug=True) + events: list[str] = [] + orig_log = second._log_lifecycle + + def spy_log(event, **fields): + events.append(event) + return orig_log(event, **fields) + + with ( + patch.object(second, "_log_lifecycle", side_effect=spy_log), + patch.object( + SbxBackend, + "_bootstrap_packages", + side_effect=AssertionError("bootstrap must not run on reattach"), + ), + ): + second.prewarm() + out = second.execute( + "from pathlib import Path\n" + "print(Path('/sandbox/persisted.txt').read_text())" + ) + assert out.strip() == marker + assert any(e.startswith("sbx.reattach") for e in events), events + assert not any(e == "sbx.create.start" for e in events), events + second.shutdown() + assert name in self._list_names() + + second.destroy() + assert name not in self._list_names() + + third = SbxBackend(config=config, preinstall_packages=False, debug=True) + try: + third.prewarm() + fresh = third.execute( + "from pathlib import Path\n" + "print(Path('/sandbox/persisted.txt').exists())" + ) + assert fresh.strip() == "False" + finally: + third.destroy() + finally: + subprocess.run( + ["sbx", "rm", "--force", name], + capture_output=True, + text=True, + check=False, + ) + + def test_reattach_after_interpreter_error_recovers(self): + name = f"predict-rlm-recover-{os.getpid()}" + config = SbxConfig(name=name, reuse=True) + + first = SbxBackend(config=config, preinstall_packages=False, debug=True) + try: + first.prewarm() + first.execute("keep = 7\nprint('ready')") + with pytest.raises(CodeInterpreterError, match="ValueError"): + first.execute("raise ValueError('boom')") + assert first.execute("print(keep + 1)").strip() == "8" + first.shutdown() + assert name in self._list_names() + + second = SbxBackend(config=config, preinstall_packages=False, debug=True) + second.prewarm() + assert second.execute("print('recovered')").strip() == "recovered" + with pytest.raises(CodeInterpreterError, match="ValueError"): + second.execute("raise ValueError('again')") + assert second.execute("print(6 * 7)").strip() == "42" + second.destroy() + assert name not in self._list_names() + finally: + subprocess.run( + ["sbx", "rm", "--force", name], + capture_output=True, + text=True, + check=False, + )