diff --git a/CHANGELOG.md b/CHANGELOG.md index b02c9da1..d276f3a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,89 @@ All notable behavioural changes to `strands-robots` are logged here. Follows [Keep a Changelog](https://keepachangelog.com/) conventions. +## Unreleased - #90 (gr00t_inference validation hardening) + +### Added + +- ``validate_inputs()`` centralises all parameter validation with action-aware + scoping: read-only actions (``find_containers``, ``list``, ``status``, + ``stop``) only validate ``port``/``host``/``protocol``; mutating actions + (``start``, ``restart``, ``lifecycle``) validate the full parameter surface. +- ``host`` parameter now accepts both IP addresses and RFC-952 hostnames + (e.g. ``localhost``, ``host.docker.internal``). All-numeric strings that + fail ``ipaddress.ip_address()`` (e.g. ``127.0.01``, ``999.999.999.999``) + are rejected as obvious IP typos. +- ``protocol`` validation moved into ``validate_inputs()`` (previously + hand-rolled outside the helper, breaking the single-entry-point contract). +- ``pgrep`` patterns now match both ``--port N`` and ``--port=N`` forms, + preventing silent miss when the service is started with the ``=`` syntax. +- Integration tests that invoke ``gr00t_inference()`` end-to-end and assert + that invalid inputs are caught (pins the ``try/except ValueError`` wiring). +- End-to-end regression test for ``_stop_service`` cross-port-kill scenario: + verifies that a process on port 8000 is NOT killed when stopping port 80. +- Exception clauses narrowed throughout: ``_is_gr00t_process`` / ``_is_gr00t_host_process`` + use ``(OSError, subprocess.SubprocessError, UnicodeDecodeError)``; + ``_list_running_services``/``_is_service_running`` use ``OSError``; + ``_stop_service`` uses ``(OSError, subprocess.SubprocessError)``; + ``_start_service`` uses ``(OSError, RuntimeError)``. + Only ``_download_checkpoint`` retains ``except Exception`` (``# noqa: BLE001``) + because huggingface_hub raises varied, opaque exception types. +- ``action`` parameter validated against a complete allowlist of 10 valid + actions; unknown actions get a clear error with the valid set listed. +- ``image_name``, ``volumes``, and ``container_command`` parameters are now + validated (Docker image reference, path traversal, shell metacharacters). +- ``pgrep`` pattern factored into ``_PGREP_INFERENCE_PORT_FMT`` module-level + constant — single source of truth across all 4 usage sites. +- ``_PGREP_INFERENCE_PORT_FMT`` and ``_is_gr00t_*_process`` now match both + N1.5/N1.6 (``inference_service.py``) and N1.7 (``gr00t.eval.run_gr00t_server``) + entry-points. Closes the N1.7 stop/status identification gap. +- All exception types in process-probe helpers now log (WARNING for + PermissionError, DEBUG for other OSError/SubprocessError/UnicodeDecodeError). + +### Changed + +- ``validate_inputs()`` parameters are now all required (no defaults). + ``gr00t_inference()`` is the single source of truth for default values; + the validator no longer duplicates them (prevents silent drift). +- ``_DOCKER_IMAGE_RE`` extended to support private-registry references with + port numbers (e.g. ``localhost:5000/myorg/img:tag``). +- ``_is_gr00t_process`` / ``_is_gr00t_host_process`` now require ``--port`` + in the process cmdline to match — prevents false-killing editors or + log-tailers that happen to touch ``inference_service.py``. +- ``PermissionError`` in process probes now logs at WARNING level instead + of being silently swallowed. +- **BREAKING** Default ``host`` changed from ``0.0.0.0`` to ``127.0.0.1`` + (loopback-only, per AGENTS.md). Container actions (``start``/``restart``/ + ``lifecycle``) auto-flip to ``0.0.0.0`` internally since Docker's + ``-p {port}:{port}`` publish requires bind-all inside the container. + **Migration:** if your downstream connects from another host, pass + ``host="0.0.0.0"`` explicitly. +- Host-system fallback (``pgrep``) is documented as Linux-only. Non-Linux + platforms will see "No service running" rather than silently succeeding. + +### Fixed + +- Duplicate ``torch_mock.manual_seed`` assignment in ``tests/mocks/torch_mock.py``. +- Option-injection guard: ``repo_url``, ``repo_tag``, ``policy_name`` starting + with ``-`` are rejected (prevents git/docker flag injection via subprocess argv). +- Host-system fallback (pgrep) now returns a clear error on non-Linux platforms + instead of silently reporting success. +- All-numeric hostname guard narrowed to multi-label patterns only — single-label + numerics (e.g. ``123``) are valid per RFC-1123. + +### Notes + +- Host validation is **broader** than before for hostnames (RFC-952 names like + ``localhost`` and ``host.docker.internal`` now pass) but **stricter** for + IP-like typos (all-numeric labels like ``127.0.01`` are rejected). +- Validation scope covers ``port``, ``host``, ``protocol``, ``data_config``, + ``embodiment_tag``, ``container_name``, TRT dtypes, ``checkpoint_path``, + ``trt_engine_path``, ``image_name``, ``volumes``, and ``container_command``. + Parameters ``repo_url``, ``repo_tag``, ``hf_repo``, ``policy_name`` are + NOT validated here — they flow into argv-style subprocess calls which + are not shell-injection-vulnerable. + + ## Unreleased - #178 (LiberoOffScreenRenderEngine retired) ### Removed: ``LiberoOffScreenRenderEngine`` simulation backend (BREAKING) diff --git a/strands_robots/tools/gr00t_inference.py b/strands_robots/tools/gr00t_inference.py index 8bfd9c6f..0a03b2df 100644 --- a/strands_robots/tools/gr00t_inference.py +++ b/strands_robots/tools/gr00t_inference.py @@ -11,9 +11,12 @@ from a single prompt - see #148 for the motivation. """ +import ipaddress import os +import re import socket import subprocess +import sys import time from pathlib import Path from typing import Any @@ -41,6 +44,256 @@ def _checkpoints_dir() -> Path: return get_base_dir() / "checkpoints" +# ───────────────────────────────────────────────────────────────────── +# Input validation helpers +# ───────────────────────────────────────────────────────────────────── + +# Docker image reference pattern — supports registry:port/path:tag and @sha256:digest. +# Examples: "gr00t:latest", "nvcr.io/nvidia/gr00t:n1.7", "localhost:5000/myorg/img:tag" +# "nvcr.io/nvidia/gr00t@sha256:abcdef..." (digest-pinned, supply-chain recommended) +_DOCKER_IMAGE_RE = re.compile( + r"^[a-zA-Z0-9]" # must start with alnum + r"(?:[a-zA-Z0-9._\-]*[a-zA-Z0-9])?" # optional middle chars (host/path prefix) + r"(?::[0-9]{1,5})?" # optional registry port (:5000) + r"(?:/[a-zA-Z0-9][a-zA-Z0-9._\-]*)*" # path components (/org/img) + r"(?::[a-zA-Z0-9][a-zA-Z0-9._\-]*" # option A: :tag + r"|@sha256:[a-f0-9]{64})?$" # option B: @sha256:digest (mutually exclusive with tag) +) + +# Characters that cause harm in subprocess argv or shell interpolation. +# Narrowed per AGENTS.md review-learnings: quotes/bangs/parens/brackets +# appear in legitimate filesystem paths and all subprocess calls here are +# argv-style (no shell=True), so they pose no injection risk in path values. +# Backslash (\) is also legal on Linux (only / and NUL are forbidden by POSIX) +# and carries no special meaning in argv-style subprocess calls. +_SHELL_META = re.compile(r"[;&|`$<>\n\r\x00]") + +# Strict patterns for enumerable parameters. +_DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$") +_EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$") +_CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$") +# RFC-952 hostname pattern for host validation. +_HOSTNAME_RE = re.compile( + r"^[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?" + r"(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9\-]{0,61}[a-zA-Z0-9])?)*$" +) +# Reject multi-label all-numeric strings — prevents typos like "127.0.01" +# which pass _HOSTNAME_RE but are clearly malformed IP attempts, not hostnames. +# Single-label numerics (e.g. "123") are valid per RFC-1123. +_ALL_NUMERIC_RE = re.compile(r"^[0-9]+(?:\.[0-9]+)+$") + +# Factored pgrep pattern — single source of truth for both docker-exec and +# host-fallback discovery paths. ERE syntax (procps-ng on Linux). +# NOTE: _PGREP_INFERENCE_PORT_FMT uses `( |$)` (ERE, space-only boundary) while +# _PYTHON_PORT_RE_FMT uses `(?:\s|$)` (Python re, any-whitespace boundary). +# This is intentional: pgrep is constrained to ERE, and cmdlines are space-separated +# in practice (procps-ng converts NUL → space when reading /proc/*/cmdline). +# Matches both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) +_PGREP_INFERENCE_PORT_FMT = r"(inference_service\.py|gr00t\.eval\.run_gr00t_server).*--port[= ]{port}( |$)" +# Python-side equivalent for re.search — uses (?:\s|$) instead of ( |$) +# because Python re is always ERE-ish and \s is more precise. +_PYTHON_PORT_RE_FMT = r"--port[= ]{port}(?:\s|$)" + +# Allowlists for TensorRT dtype parameters. +_VALID_VIT_DTYPES = {"fp16", "fp8"} +_VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"} +_VALID_DIT_DTYPES = {"fp16", "fp8"} + +# Complete allowlist of valid actions for the tool. +_VALID_ACTIONS = frozenset( + { + "find_containers", + "list", + "status", + "stop", + "start", + "restart", + "build_image", + "download_checkpoint", + "start_container", + "lifecycle", + } +) + + +def _validate_path(value: str, label: str, *, reject_colon: bool = False) -> None: + """Reject paths containing shell metacharacters, null bytes, or traversal sequences. + + Args: + value: The path string to validate. + label: Human-readable label for error messages. + reject_colon: When True, reject ':' in the value. Required for Docker + volume mount paths where ':' would be re-interpreted as + host:container:options separator by docker -v. + """ + if "\x00" in value: + raise ValueError(f"{label} must not contain null bytes") + if value.startswith("-"): + raise ValueError(f"{label} must not start with '-' (got {value!r})") + if any(part == ".." for part in re.split(r"[/\\]", value)): + raise ValueError(f"{label} must not contain '..' path traversal components") + if _SHELL_META.search(value): + raise ValueError(f"{label} contains disallowed characters: {value!r}") + if reject_colon and ":" in value: + raise ValueError( + f"{label} must not contain ':' (docker -v interprets it as host:container:options separator; got {value!r})" + ) + + +def validate_inputs( + *, + action: str, + data_config: str, + embodiment_tag: str, + port: int, + host: str, + vit_dtype: str, + llm_dtype: str, + dit_dtype: str, + checkpoint_path: str | None, + trt_engine_path: str, + container_name: str | None, + protocol: str, + image_name: str | None = None, + volumes: dict[str, str] | None = None, + container_command: str | None = None, + repo_url: str | None = None, + repo_tag: str | None = None, + policy_name: str | None = None, +) -> None: + """Validate all user-supplied parameters in one place. + + Raises ValueError for any invalid input. Callers exposing this through + an AgentTool MUST wrap in try/except and convert to the structured error + dict (``{"status": "error", "message": str(e)}``). + + This centralises validation so that the main tool function stays focused + on orchestration and each check is independently testable via this + single entry-point. + + Validation is scoped to the action: actions whose only user-controlled + surface is port/host/protocol (find_containers, list, status, stop) + skip full parameter validation; mutating actions (start, restart, + lifecycle, build_image, download_checkpoint, start_container) validate + the full parameter surface. + """ + # Action allowlist — reject unknown actions early with a clear error + if action not in _VALID_ACTIONS: + raise ValueError(f"Unknown action {action!r}. Valid actions: {sorted(_VALID_ACTIONS)}") + + # Protocol — always validated regardless of action + valid_protocols = ("n1.5", "n1.6", "n1.7") + if protocol not in valid_protocols: + raise ValueError(f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}") + # Port range — always validated. Type-check first so callers get ValueError, not TypeError. + if not isinstance(port, int): + raise ValueError(f"port must be an integer, got {type(port).__name__}: {port!r}") + if not (1 <= port <= 65535): + raise ValueError(f"port must be between 1 and 65535, got {port}") + + # Host address validation — always validated (accept IPs and RFC-952 hostnames) + if not isinstance(host, str): + raise ValueError(f"host must be a string, got {type(host).__name__}: {host!r}") + # RFC 1035 §2.3.4: total hostname must not exceed 253 octets. + if len(host) > 253: + raise ValueError(f"host exceeds RFC 1035 maximum length of 253 chars (got {len(host)} chars)") + try: + ipaddress.ip_address(host) + except ValueError: + # Reject all-numeric labels (e.g. "127.0.01") — these are clearly IP typos + # not legitimate hostnames. Real hostnames must have at least one alpha label. + # Rejects all-numeric multi-label strings including Linux IPv4 short-forms + # like "127.1" — use canonical dotted-quad for clarity in agent-driven contexts. + if _ALL_NUMERIC_RE.match(host) or not _HOSTNAME_RE.match(host): + raise ValueError( + f"host must be a valid IP address or hostname (got {host!r}). " + f"Use '127.0.0.1' for loopback, '0.0.0.0' for all interfaces, " + f"or a valid hostname like 'localhost'." + ) from None + + # Port-only actions (find_containers, list, status, stop) only need + # port/host/protocol validation — the other params are unused by dispatch. + _port_only_actions = ("find_containers", "list", "status", "stop") + if action in _port_only_actions: + return + + # Image/download actions only consume image_name, paths, and volumes — not + # inference-time params (data_config, embodiment_tag, dtypes). + _image_only_actions = ("build_image", "download_checkpoint", "start_container") + if action in _image_only_actions: + # Validate container_name (used by start_container, interpolated into docker run --name) + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + # Validate image_name, volumes, container_command (relevant to these actions) + if image_name is not None and not _DOCKER_IMAGE_RE.match(image_name): + raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") + if volumes is not None: + for vol_host, vol_container in volumes.items(): + _validate_path(vol_host, "volumes key (host path)", reject_colon=True) + _validate_path(vol_container, "volumes value (container path)", reject_colon=True) + if container_command is not None and _SHELL_META.search(container_command): + raise ValueError(f"container_command contains disallowed characters: {container_command!r}") + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + # Option-injection guard for params used by these actions + for param_name, param_value in [("repo_url", repo_url), ("repo_tag", repo_tag), ("policy_name", policy_name)]: + if param_value is not None and param_value.startswith("-"): + raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + return + + # ── Full validation for inference-mutating actions (start, restart, lifecycle) ── + + # Enumerable string parameters + if not _DATA_CONFIG_RE.match(data_config): + raise ValueError( + f"data_config must be lowercase alphanumeric/underscore (got {data_config!r}). " + f"See the tool docstring for the full list of accepted configs." + ) + if not _EMBODIMENT_TAG_RE.match(embodiment_tag): + raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {embodiment_tag!r})") + + # Docker container name + if container_name is not None and not _CONTAINER_NAME_RE.match(container_name): + raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})") + + # Filesystem paths — reject shell metacharacters and traversal + if checkpoint_path is not None: + _validate_path(checkpoint_path, "checkpoint_path") + _validate_path(trt_engine_path, "trt_engine_path") + + # TensorRT dtype allowlists + if vit_dtype not in _VALID_VIT_DTYPES: + raise ValueError(f"vit_dtype must be one of {_VALID_VIT_DTYPES}, got {vit_dtype!r}") + if llm_dtype not in _VALID_LLM_DTYPES: + raise ValueError(f"llm_dtype must be one of {_VALID_LLM_DTYPES}, got {llm_dtype!r}") + if dit_dtype not in _VALID_DIT_DTYPES: + raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}") + + # Docker image reference (if provided via kwargs) + if image_name is not None and not _DOCKER_IMAGE_RE.match(image_name): + raise ValueError(f"image_name must be a valid Docker image reference (got {image_name!r})") + + # Volume paths validation + if volumes is not None: + for vol_host, vol_container in volumes.items(): + _validate_path(vol_host, "volumes key (host path)", reject_colon=True) + _validate_path(vol_container, "volumes value (container path)", reject_colon=True) + + # Container command — reject shell metacharacters + if container_command is not None and _SHELL_META.search(container_command): + raise ValueError(f"container_command contains disallowed characters: {container_command!r}") + + # Option-injection guard: reject LLM-controlled values starting with '-' + # which could be parsed as flags by git/docker/pgrep in subprocess argv. + for param_name, param_value in [ + ("repo_url", repo_url), + ("repo_tag", repo_tag), + ("policy_name", policy_name), + ]: + if param_value is not None and param_value.startswith("-"): + raise ValueError(f"{param_name} must not start with '-' (got {param_value!r})") + + @tool def gr00t_inference( action: str, @@ -50,7 +303,7 @@ def gr00t_inference( data_config: str = "fourier_gr1_arms_only", embodiment_tag: str = "gr1", denoising_steps: int = 4, - host: str = "0.0.0.0", + host: str | None = None, container_name: str | None = None, timeout: int = 60, use_tensorrt: bool = False, @@ -192,7 +445,9 @@ def gr00t_inference( ``libero_sim``). denoising_steps: Number of denoising steps for action generation (default: 4). N1.5/N1.6 only - the N1.7 server reads this from the checkpoint. - host: Host address to bind the service to (default: ``0.0.0.0``). + host: Host address to bind the service to (default: ``127.0.0.1`` + loopback only). Container actions auto-flip to ``0.0.0.0`` internally + since Docker -p port-publish requires bind-all inside the container. container_name: Specific Docker container name. Auto-detected if omitted. timeout: Seconds to wait for service startup (default: 60). use_tensorrt: Enable TensorRT acceleration (default: False). @@ -297,14 +552,37 @@ def gr00t_inference( if api_token is None: api_token = os.environ.get("GROOT_API_TOKEN") - # Validate protocol up-front so users get a friendly error rather than - # an opaque docker-exec failure inside _start_service. - valid_protocols = ("n1.5", "n1.6", "n1.7") - if protocol not in valid_protocols: - return { - "status": "error", - "message": f"Unknown protocol {protocol!r}. Valid: {list(valid_protocols)}", - } + # Sentinel default: None means "user did not pass host=". + # Default to 127.0.0.1 (loopback, per AGENTS.md § LLM Input Safety). + # _start_service auto-flips to 0.0.0.0 ONLY when host was not explicitly set. + _host_was_explicit = host is not None + if host is None: + host = "127.0.0.1" + + # ── Validate all inputs in one call (scoped per action) ───────── + try: + validate_inputs( + action=action, + data_config=data_config, + embodiment_tag=embodiment_tag, + port=port, + host=host, + vit_dtype=vit_dtype, + llm_dtype=llm_dtype, + dit_dtype=dit_dtype, + checkpoint_path=checkpoint_path, + trt_engine_path=trt_engine_path, + container_name=container_name, + protocol=protocol, + image_name=image_name, + volumes=volumes, + container_command=container_command, + repo_url=repo_url, + repo_tag=repo_tag, + policy_name=policy_name, + ) + except ValueError as e: + return {"status": "error", "message": str(e)} if action == "find_containers": return _find_gr00t_containers() @@ -378,6 +656,7 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=_host_was_explicit, ) elif action == "start": if checkpoint_path is None: @@ -404,6 +683,7 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=_host_was_explicit, ) elif action == "restart": if checkpoint_path is None: @@ -430,9 +710,11 @@ def gr00t_inference( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=_host_was_explicit, ) - else: - return {"status": "error", "message": f"Unknown action: {action}"} + + # Unreachable: validate_inputs() rejects unknown actions before dispatch. + return {"status": "error", "message": f"Unknown action: {action}"} # pragma: no cover def _find_gr00t_containers() -> dict[str, Any]: @@ -479,7 +761,7 @@ def _list_running_services() -> dict[str, Any]: return {"status": "success", "services": services, "message": f"Found {len(services)} running services"} - except Exception as e: + except OSError as e: return {"status": "error", "message": f"Failed to list services: {e}"} @@ -491,7 +773,7 @@ def _is_service_running(port: int) -> bool: result = sock.connect_ex(("localhost", port)) sock.close() return result == 0 - except Exception: + except OSError: return False @@ -509,6 +791,88 @@ def _check_service_status(port: int) -> dict[str, Any]: } +def _is_gr00t_process(container_name: str, pid: str, *, port: int | None = None) -> bool: + """Verify that a PID inside a container belongs to a GR00T inference process. + + This prevents accidentally killing unrelated processes that happen to + be listening on the same port. + + Args: + container_name: Docker container name to inspect. + pid: Process ID to check. + port: If provided, also verify the process is bound to this port. + """ + try: + result = subprocess.run( + ["docker", "exec", container_name, "cat", f"/proc/{pid}/cmdline"], + capture_output=True, + text=True, + check=False, + ) + if result.returncode == 0: + cmdline = result.stdout.replace("\x00", " ") + # Require both a Python interpreter AND inference_service.py in cmdline + # to avoid false-matching unrelated processes (e.g. vim editing a gr00t file) + # Match both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) + is_gr00t = ( + ("inference_service.py" in cmdline or "gr00t.eval.run_gr00t_server" in cmdline) + and ("python" in cmdline.lower() or "gr00t" in cmdline.lower()) + and "--port" in cmdline # Must have a --port flag to be a running service + ) + if is_gr00t and port is not None: + # Verify the process is serving on the requested port + # Use word-boundary regex to avoid partial matches (e.g. port 80 vs 8000) + return bool(re.search(_PYTHON_PORT_RE_FMT.format(port=port), cmdline)) + return is_gr00t + except (OSError, subprocess.SubprocessError, UnicodeDecodeError) as exc: + import logging + + _logger = logging.getLogger(__name__) + if isinstance(exc, PermissionError): + _logger.warning("Permission denied probing container process %s -- treating as non-GR00T", pid) + else: + _logger.debug("Failed to probe container process %s: %s", pid, exc) + return False + + +def _is_gr00t_host_process(pid: str, *, port: int | None = None) -> bool: + """Verify that a host PID belongs to a GR00T inference process. + + Reads /proc//cmdline directly (no Docker) to confirm the process + is a GR00T inference service, optionally bound to a specific port. + + Note: This function reads from /proc and is Linux-only. + + Args: + pid: Process ID to check. + port: If provided, also verify the process is bound to this port. + """ + try: + cmdline_path = Path(f"/proc/{pid}/cmdline") + if cmdline_path.exists(): + cmdline = cmdline_path.read_text().replace("\x00", " ") + # Require both a Python interpreter AND inference_service.py in cmdline + # to avoid false-matching unrelated processes (e.g. vim editing a gr00t file) + # Match both N1.5/N1.6 (inference_service.py) and N1.7 (gr00t.eval.run_gr00t_server) + is_gr00t = ( + ("inference_service.py" in cmdline or "gr00t.eval.run_gr00t_server" in cmdline) + and ("python" in cmdline.lower() or "gr00t" in cmdline.lower()) + and "--port" in cmdline # Must have a --port flag to be a running service + ) + if is_gr00t and port is not None: + return bool(re.search(_PYTHON_PORT_RE_FMT.format(port=port), cmdline)) + return is_gr00t + except (OSError, UnicodeDecodeError) as exc: + import logging + + _logger = logging.getLogger(__name__) + if isinstance(exc, PermissionError): + _logger.warning("Permission denied reading /proc/%s/cmdline -- treating as non-GR00T", pid) + else: + _logger.debug("Failed to probe host process %s: %s", pid, exc) + return False + + def _stop_service(port: int) -> dict[str, Any]: """Stop GR00T inference service running on specific port.""" try: @@ -520,7 +884,14 @@ def _stop_service(port: int) -> dict[str, Any]: container_name = container["name"] try: result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + _PGREP_INFERENCE_PORT_FMT.format(port=port), + ], capture_output=True, text=True, check=False, @@ -529,13 +900,21 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid, port=port): subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True) time.sleep(2) result = subprocess.run( - ["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"], + [ + "docker", + "exec", + container_name, + "pgrep", + "-f", + _PGREP_INFERENCE_PORT_FMT.format(port=port), + ], capture_output=True, text=True, check=False, @@ -544,7 +923,8 @@ def _stop_service(port: int) -> dict[str, Any]: if result.returncode == 0 and result.stdout.strip(): pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_process(container_name, pid, port=port): subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True) return { @@ -557,30 +937,54 @@ def _stop_service(port: int) -> dict[str, Any]: except subprocess.CalledProcessError: continue - # Fallback: try host system - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + # Fallback: try host system — verify via /proc//cmdline + # This path is Linux-only (ERE pgrep + /proc filesystem). + if sys.platform != "linux": + return { + "status": "error", + "message": ( + "No GR00T containers found. Host-fallback stop requires Linux " + "(pgrep + /proc). Run inside a Docker container or use action='find_containers' first." + ), + } + + result = subprocess.run( + # NOTE: ( |$) is ERE syntax; pgrep on Linux (procps-ng) defaults to ERE. + # This pattern is Linux-only; BSD pgrep may not match correctly. + ["pgrep", "-f", _PGREP_INFERENCE_PORT_FMT.format(port=port)], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_host_process(pid, port=port): subprocess.run(["kill", "-TERM", pid], check=True) time.sleep(2) - result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True) + result = subprocess.run( + # NOTE: ( |$) is ERE syntax; pgrep on Linux (procps-ng) defaults to ERE. + # This pattern is Linux-only; BSD pgrep may not match correctly. + ["pgrep", "-f", _PGREP_INFERENCE_PORT_FMT.format(port=port)], + capture_output=True, + text=True, + ) if result.returncode == 0: pids = result.stdout.strip().split("\n") for pid in pids: - if pid: + pid = pid.strip() + if pid and _is_gr00t_host_process(pid, port=port): subprocess.run(["kill", "-KILL", pid], check=True) return {"status": "success", "port": port, "message": f"Service on port {port} stopped"} else: return {"status": "success", "port": port, "message": f"No service running on port {port}"} - except Exception as e: + except (OSError, subprocess.SubprocessError) as e: return {"status": "error", "message": f"Failed to stop service: {e}"} @@ -713,9 +1117,22 @@ def _start_service( api_token: str | None, protocol: str = "n1.5", use_sim_policy_wrapper: bool = False, + host_was_explicit: bool = False, ) -> dict[str, Any]: """Start GR00T inference service using Isaac-GR00T's native inference service.""" try: + # Auto-flip host for container actions: Docker's -p port-publish requires the + # service to bind all interfaces inside the container. Only auto-flip if the + # user accepted the default (sentinel was None → resolved to 127.0.0.1). + # Users who explicitly pass host="127.0.0.1" get it honoured (e.g. --network=host). + if host == "127.0.0.1" and not host_was_explicit: + import logging as _logging + + _logging.getLogger(__name__).warning( + "Auto-flipping host from 127.0.0.1 to 0.0.0.0 for container " + "port-publish (-p). Pass host='127.0.0.1' explicitly to keep loopback." + ) + host = "0.0.0.0" # Find container if not specified if container_name is None: containers = _find_gr00t_containers() @@ -791,7 +1208,7 @@ def _start_service( except subprocess.CalledProcessError as e: return {"status": "error", "message": f"Failed to start service: {e.stderr or e}"} - except Exception as e: + except (OSError, RuntimeError) as e: return {"status": "error", "message": f"Unexpected error: {e}"} @@ -1187,6 +1604,7 @@ def _lifecycle( api_token: str | None, protocol: str, use_sim_policy_wrapper: bool, + host_was_explicit: bool = False, ) -> dict[str, Any]: """Orchestrate the four-step setup or tear down a previously-started container. @@ -1308,6 +1726,7 @@ def _lifecycle( api_token=api_token, protocol=protocol, use_sim_policy_wrapper=use_sim_policy_wrapper, + host_was_explicit=host_was_explicit, ) steps.append({"step": "start", "result": start_result}) @@ -1320,7 +1739,7 @@ def _lifecycle( if __name__ == "__main__": - print("🐳 GR00T Inference Service Manager (Isaac-GR00T Native)") + print("GR00T Inference Service Manager (Isaac-GR00T Native)") print("Supports ZMQ, HTTP, and TensorRT inference modes") print() print("Examples:") diff --git a/tests/policies/groot/test_gr00t_inference_validation.py b/tests/policies/groot/test_gr00t_inference_validation.py new file mode 100644 index 00000000..f79ea266 --- /dev/null +++ b/tests/policies/groot/test_gr00t_inference_validation.py @@ -0,0 +1,1625 @@ +"""Tests for gr00t_inference input validation. + +Covers the validate_inputs() function which centralises all parameter +validation for the gr00t_inference tool. +""" + +import pytest + +from strands_robots.tools.gr00t_inference import validate_inputs + +# Standard valid kwargs for validate_inputs — tests override individual fields. +# validate_inputs() no longer has defaults (gr00t_inference() is the single source +# of truth for defaults), so tests must supply all required params. +_VALID_KWARGS = { + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.0.1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": None, + "trt_engine_path": "gr00t_engine", + "container_name": None, + "protocol": "n1.5", +} + + +class TestValidateInputs: + """Tests for the validate_inputs() public function.""" + + def test_valid_defaults(self): + """Default values must pass validation.""" + validate_inputs(**_VALID_KWARGS) + + def test_valid_with_all_optional(self): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100_dualcam", + "embodiment_tag": "so100", + "port": 8000, + "vit_dtype": "fp16", + "llm_dtype": "fp8", + "dit_dtype": "fp16", + "checkpoint_path": "/data/checkpoints/model", + "trt_engine_path": "/engines/cache", + "container_name": "gr00t-n17", + } + ) + + def test_invalid_data_config_uppercase(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "FourierGR1", + "embodiment_tag": "gr1", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_data_config_shell_chars(self): + with pytest.raises(ValueError, match="data_config"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "foo;rm -rf /", + "embodiment_tag": "gr1", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_embodiment_tag(self): + with pytest.raises(ValueError, match="embodiment_tag"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "GR1-Sonic!", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_port_zero(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 0, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_port_too_high(self): + with pytest.raises(ValueError, match="port"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 70000, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_vit_dtype(self): + with pytest.raises(ValueError, match="vit_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "bf16", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_llm_dtype(self): + with pytest.raises(ValueError, match="llm_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "int4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_dit_dtype(self): + with pytest.raises(ValueError, match="dit_dtype"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "bf16", + } + ) + + def test_checkpoint_path_traversal(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": "/data/../../../etc/passwd", + } + ) + + def test_checkpoint_path_null_byte(self): + with pytest.raises(ValueError, match="checkpoint_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "checkpoint_path": "/data/model\x00.bin", + } + ) + + def test_trt_engine_path_shell_injection(self): + with pytest.raises(ValueError, match="trt_engine_path"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "trt_engine_path": "engine;rm -rf /", + } + ) + + def test_invalid_container_name(self): + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "container_name": "-invalid-start", + } + ) + + def test_container_name_none_is_ok(self): + """container_name=None should not raise.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "so100", + "embodiment_tag": "so100", + "port": 5555, + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + "container_name": None, + } + ) + + +class TestIsGr00tProcess: + """Test the _is_gr00t_process helper verifies port binding.""" + + def test_rejects_wrong_port(self, monkeypatch): + """_is_gr00t_process should reject a GR00T process on a different port.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + # Simulate cmdline: "python inference_service.py --port 8000" + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + # Asking for port 80 should return False even though it's a gr00t process + assert _is_gr00t_process("container", "123", port=80) is False + + def test_accepts_matching_port(self, monkeypatch): + """_is_gr00t_process should accept when port matches.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + assert _is_gr00t_process("container", "123", port=8000) is True + + def test_no_port_check_when_none(self, monkeypatch): + """_is_gr00t_process without port param should not verify port.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port\x008000\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + # Without port, just checks if it's a gr00t process + assert _is_gr00t_process("container", "123") is True + + def test_accepts_equals_style_port(self, monkeypatch): + """_is_gr00t_process should accept --port=N style.""" + import subprocess as sp + + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + fake_result = sp.CompletedProcess( + args=[], returncode=0, stdout="python\x00inference_service.py\x00--port=5555\x00" + ) + monkeypatch.setattr(sp, "run", lambda *a, **kw: fake_result) + + assert _is_gr00t_process("container", "123", port=5555) is True + assert _is_gr00t_process("container", "123", port=6666) is False + + +class TestIsGr00tHostProcess: + """Test the _is_gr00t_host_process helper for host-system PID verification.""" + + def test_rejects_wrong_port(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should reject a process on a different port.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + # Create a fake /proc//cmdline + proc_dir = tmp_path / "proc" / "123" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x008000\x00") + + # Monkeypatch Path to point at our fake proc, with reachability check + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("123", port=80) is False + assert called.get("p") == "/proc/123/cmdline" # patch was reached + + def test_accepts_matching_port(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should accept when port matches.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "456" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x008000\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("456", port=8000) is True + assert called.get("p") == "/proc/456/cmdline" # patch was reached + + def test_rejects_non_gr00t_process(self, tmp_path, monkeypatch): + """_is_gr00t_host_process should reject non-GR00T processes.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "789" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00some_other_service.py\x00--port\x008000\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("789", port=8000) is False + assert called.get("p") == "/proc/789/cmdline" # patch was reached + + def test_no_port_check_when_none(self, tmp_path, monkeypatch): + """_is_gr00t_host_process without port checks only process identity.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "321" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x009999\x00") + + from pathlib import Path as RealPath + + called = {} + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + # Without port kwarg, just checks identity + assert _is_gr00t_host_process("321") is True + assert called.get("p") == "/proc/321/cmdline" # patch was reached + + +class TestHostValidation: + """Tests for host address validation in validate_inputs().""" + + def test_valid_loopback(self): + """127.0.0.1 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.0.1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_all_interfaces(self): + """0.0.0.0 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "0.0.0.0", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_ipv6_loopback(self): + """::1 is valid.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "::1", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_with_spaces(self): + """Host with spaces must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "foo bar", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_empty_labels(self): + """Host with empty labels (double dot) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "a..b", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_hostname_localhost(self): + """Valid hostnames like localhost are now accepted.""" + # Should not raise — localhost is a valid RFC-952 hostname + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "localhost", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_valid_hostname_docker_internal(self): + """Docker internal hostname is accepted.""" + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "host.docker.internal", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_special_chars(self): + """Hostnames with special characters are rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "--invalid-host", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + +class TestGr00tInferenceToolIntegration: + """Integration tests verifying validate_inputs is wired into the tool entry point. + + These tests invoke gr00t_inference() directly and assert that invalid inputs + are caught and returned as error dicts, NOT silently passed through. + This pins the try/except ValueError wiring so a future refactor that drops + the validation call surfaces as a test failure. + """ + + def test_shell_injection_in_data_config_returns_error(self): + """Shell metacharacters in data_config must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", data_config="foo;rm -rf /") + assert result["status"] == "error" + assert "data_config" in result["message"] + + def test_path_traversal_in_checkpoint_returns_error(self): + """Path traversal in checkpoint_path must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", checkpoint_path="/tmp/../../../etc/passwd") + assert result["status"] == "error" + assert "checkpoint_path" in result["message"] + + def test_invalid_host_returns_error(self): + """Invalid host address must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", host="--not-valid") + assert result["status"] == "error" + assert "host" in result["message"] + + def test_invalid_port_returns_error(self): + """Out-of-range port must return error dict.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="start", port=99999) + assert result["status"] == "error" + assert "port" in result["message"] + + +class TestStopServiceCrossPortKill: + """End-to-end regression test for the cross-port-kill bug. + + Verifies that _stop_service(port=80) does NOT kill a GR00T process + running on port 8000. This pins the _is_gr00t_process(port=...) guard + so a future refactor that removes it will surface as a test failure. + """ + + def test_stop_service_does_not_kill_wrong_port(self, monkeypatch): + """_stop_service(port=80) must NOT kill a process on port 8000.""" + from strands_robots.tools.gr00t_inference import _stop_service + + killed_pids = [] + call_log = [] + + def _fake_run(cmd, *args, **kwargs): + call_log.append(cmd) + + # Mock _find_gr00t_containers returning no containers (forces host fallback) + if "docker" in cmd and "ps" in cmd: + import subprocess + + result = subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + return result + + # Mock pgrep finding PID 999 on the host + if cmd[0] == "pgrep": + import subprocess + + return subprocess.CompletedProcess(cmd, 0, stdout="999\n", stderr="") + + # Mock kill — record it + if cmd[0] == "kill": + killed_pids.append(cmd[-1]) + import subprocess + + return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + + import subprocess + + return subprocess.CompletedProcess(cmd, 1, stdout="", stderr="") + + def _fake_host_process(pid, *, port=None): + """Simulate a process running on port 8000, NOT port 80.""" + # The process is a GR00T process but on port 8000 + if port == 80: + return False # Not on port 80 + if port == 8000: + return True # Yes on port 8000 + return True # Generic check without port + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.subprocess.run", _fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_gr00t_host_process", _fake_host_process) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + + _stop_service(port=80) + + # No process should have been killed (the only process is on port 8000) + assert not killed_pids, ( + f"_stop_service(port=80) killed PIDs {killed_pids} but should not have " + f"(the only GR00T process is on port 8000)" + ) + + def test_stop_service_kills_correct_port(self, monkeypatch): + """_stop_service(port=8000) MUST kill a process on port 8000.""" + from strands_robots.tools.gr00t_inference import _stop_service + + killed_pids = [] + + def _fake_run(cmd, *args, **kwargs): + import subprocess + + if cmd[0] == "pgrep": + return subprocess.CompletedProcess(cmd, 0, stdout="999\n", stderr="") + if cmd[0] == "kill": + killed_pids.append(cmd[-1]) + return subprocess.CompletedProcess(cmd, 0, stdout="", stderr="") + return subprocess.CompletedProcess(cmd, 1, stdout="", stderr="") + + def _fake_host_process(pid, *, port=None): + """Simulate a process running on port 8000.""" + if port == 8000: + return True + return False + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.subprocess.run", _fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_gr00t_host_process", _fake_host_process) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + + _stop_service(port=8000) + + # Process should have been killed + assert "999" in killed_pids, f"_stop_service(port=8000) should have killed PID 999 but killed {killed_pids}" + + +class TestActionScopedValidation: + """Tests verifying that validate_inputs scopes checks per action. + + Read-only actions (find_containers, list, status, stop) should only + validate port/host/protocol, not the full parameter surface like + data_config, embodiment_tag, etc. + """ + + def test_read_only_action_accepts_any_data_config(self): + """Read-only actions should not validate data_config.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # This has hyphens/caps which WOULD fail for action="start" but passes for "list" + validate_inputs(**{**_VALID_KWARGS, "action": "list", "data_config": "Has-Hyphens-And-Caps"}) + + def test_read_only_action_still_validates_port(self): + """Read-only actions must still validate port.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="port must be between"): + validate_inputs(**{**_VALID_KWARGS, "action": "status", "port": 99999}) + + def test_read_only_action_still_validates_host(self): + """Read-only actions must still validate host.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="host must be a valid"): + validate_inputs(**{**_VALID_KWARGS, "action": "stop", "host": "--invalid"}) + + def test_read_only_action_still_validates_protocol(self): + """Read-only actions must still validate protocol.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="Unknown protocol"): + validate_inputs(**{**_VALID_KWARGS, "action": "list", "protocol": "invalid"}) + + def test_mutating_action_validates_data_config(self): + """Mutating actions must validate data_config.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="data_config"): + validate_inputs(**{**_VALID_KWARGS, "action": "start", "data_config": "foo;bar"}) + + def test_integration_read_only_action_skips_data_config_validation(self): + """gr00t_inference(action='list', data_config='invalid') must not error on data_config.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # action="list" should not validate data_config + # It will fail at runtime (no docker) but NOT on validation + result = gr00t_inference(action="list", data_config="invalid;stuff") + # Should NOT be a validation error about data_config + if result.get("status") == "error": + assert "data_config" not in result.get("message", "") + + +class TestHostNumericTypoRejection: + """Regression tests for all-numeric hostname typos. + + Verifies that "127.0.01" (typo for 127.0.0.1) and "999.999.999.999" + are rejected by validate_inputs. These strings pass _HOSTNAME_RE but + are caught by the _ALL_NUMERIC_RE guard introduced in review round-4. + """ + + def test_invalid_host_typo_dotted_numeric(self): + """127.0.01 (typo for 127.0.0.1) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "127.0.01", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_invalid_host_999_octets(self): + """999.999.999.999 (invalid IP, all-numeric) must be rejected.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs( + **{ + **_VALID_KWARGS, + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "port": 5555, + "host": "999.999.999.999", + "vit_dtype": "fp8", + "llm_dtype": "nvfp4", + "dit_dtype": "fp8", + } + ) + + def test_single_numeric_label_is_valid_hostname(self): + """A bare number like '8080' is a valid single-label hostname (RFC-1123).""" + # Single-label numerics are valid hostnames; only multi-label patterns + # like '127.0.01' (IP typos) are rejected. + validate_inputs( + **{ + **_VALID_KWARGS, + "host": "8080", + } + ) + + +class TestActionAllowlistValidation: + """Tests for the action allowlist in validate_inputs. + + Verifies that unknown actions are rejected with a clear error that + lists the valid options, rather than falling through to validation + of unrelated parameters. + """ + + def test_unknown_action_rejected(self): + """Typo'd action gets a clear error listing valid actions.""" + with pytest.raises(ValueError, match="Unknown action.*Valid actions"): + validate_inputs(**{**_VALID_KWARGS, "action": "strat"}) # typo for "start" + + def test_unknown_action_integration(self): + """gr00t_inference(action='typo') returns error about unknown action.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="typo") + assert result["status"] == "error" + assert "Unknown action" in result["message"] + + def test_all_valid_actions_accepted(self): + """All 10 valid actions pass action validation (may fail later).""" + from strands_robots.tools.gr00t_inference import _VALID_ACTIONS + + for action in _VALID_ACTIONS: + # Should not raise ValueError about unknown action + # (may raise about other params, but that's fine) + try: + validate_inputs(**{**_VALID_KWARGS, "action": action}) + except ValueError as e: + assert "Unknown action" not in str(e), f"Action {action!r} wrongly rejected" + + +class TestExpandedParamValidation: + """Tests for image_name, volumes, and container_command validation.""" + + def test_invalid_image_name_rejected(self): + """Docker image with shell chars must be rejected.""" + with pytest.raises(ValueError, match="image_name must be a valid Docker"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "image_name": "gr00t:latest; rm -rf /", + } + ) + + def test_valid_image_name_accepted(self): + """Standard Docker image references must pass.""" + # Should not raise + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "image_name": "nvcr.io/nvidia/gr00t:n1.7", + } + ) + + def test_volume_path_traversal_rejected(self): + """Volumes with path traversal must be rejected.""" + with pytest.raises(ValueError, match="volumes key"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "volumes": {"/../etc/passwd": "/data"}, + } + ) + + def test_container_command_shell_meta_rejected(self): + """Container command with shell metacharacters must be rejected.""" + with pytest.raises(ValueError, match="container_command contains disallowed"): + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "container_command": "tail -f /dev/null; rm -rf /", + } + ) + + def test_valid_container_command_accepted(self): + """Standard container commands must pass.""" + # Should not raise + validate_inputs( + **{ + **_VALID_KWARGS, + "action": "start", + "data_config": "fourier_gr1_arms_only", + "embodiment_tag": "gr1", + "container_command": "tail -f /dev/null", + } + ) + + +class TestHappyPathIntegration: + """Happy-path integration test for gr00t_inference. + + Verifies that valid inputs pass validation and proceed to runtime + (which will fail due to missing Docker, but NOT on validation). + """ + + def test_valid_list_action_passes_validation(self): + """gr00t_inference(action='list') with valid params does not error on validation.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="list") + # The error should be about runtime (no docker), NOT validation + if result.get("status") == "error": + msg = result.get("message", "") + # Must not be a validation error + assert "must be" not in msg or "port" not in msg + assert "Unknown action" not in msg + assert "data_config" not in msg + + def test_valid_status_action_passes_validation(self): + """gr00t_inference(action='status') with valid params proceeds past validation.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="status", port=5555, host="127.0.0.1") + # Should not be a validation error + if result.get("status") == "error": + msg = result.get("message", "") + assert "Unknown action" not in msg + assert "host must be" not in msg + assert "port must be" not in msg + + +class TestDockerImageRegistryPort: + """Tests that _DOCKER_IMAGE_RE supports private registries with port numbers.""" + + def test_registry_with_port_accepted(self): + """localhost:5000/myorg/img:tag must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "localhost:5000/myorg/img:tag"}) + + def test_registry_with_port_no_tag(self): + """registry.internal:5000/img must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "registry.internal:5000/img"}) + + def test_nvcr_standard_format(self): + """nvcr.io/nvidia/gr00t:n1.7 must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "nvcr.io/nvidia/gr00t:n1.7"}) + + def test_simple_image_tag(self): + """gr00t:latest must be accepted.""" + validate_inputs(**{**_VALID_KWARGS, "image_name": "gr00t:latest"}) + + +class TestProcessIdentificationRequiresPort: + """Tests that _is_gr00t_process requires --port in cmdline. + + Prevents false-matching unrelated processes like editors or log-tailers + that happen to have 'inference_service.py' and 'python' in their cmdline. + """ + + def test_process_without_port_flag_rejected(self, monkeypatch): + """A process with 'python inference_service.py' but no --port flag is not a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + # Mock docker exec to return a cmdline without --port + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "python inference_service.py --config test\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + # Without --port in cmdline, should return False + assert _is_gr00t_process("container", "123", port=5555) is False + + def test_process_with_port_flag_accepted(self, monkeypatch): + """A process with --port 5555 in cmdline is a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "python inference_service.py --port 5555\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + assert _is_gr00t_process("container", "123", port=5555) is True + + def test_editor_on_inference_service_rejected(self, monkeypatch): + """vim editing inference_service.py under a python venv is not a match.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_process + + def fake_run(*args, **kwargs): + class Result: + returncode = 0 + stdout = "/opt/conda/envs/gr00t/bin/python vim /opt/gr00t/inference_service.py\x00" + + return Result() + + monkeypatch.setattr("subprocess.run", fake_run) + # No --port flag → rejected + assert _is_gr00t_process("container", "123", port=5555) is False + + +class TestOptionInjectionGuard: + """Test option-injection guard for argv-interpolated parameters.""" + + def test_repo_url_starting_with_dash_rejected(self): + """repo_url='--upload-pack=evil' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="build_image", + repo_url="--upload-pack=touch /tmp/pwned", + ) + assert result["status"] == "error" + assert "repo_url" in result["message"] + assert "must not start with '-'" in result["message"] + + def test_repo_tag_starting_with_dash_rejected(self): + """repo_tag='--config=evil' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="build_image", + repo_tag="--config=core.fsmonitor=evil-cmd", + ) + assert result["status"] == "error" + assert "repo_tag" in result["message"] + + def test_policy_name_starting_with_dash_rejected(self): + """policy_name='--flag' must be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start", + checkpoint_path="/data/model", + policy_name="--malicious", + ) + assert result["status"] == "error" + assert "policy_name" in result["message"] + + def test_valid_repo_url_accepted(self, monkeypatch): + """Normal https:// URL must pass the guard.""" + from strands_robots.tools import gr00t_inference as gi_mod + + # Mock _build_image to avoid actual git/docker operations + monkeypatch.setattr( + gi_mod, + "_build_image", + lambda **kwargs: {"status": "success", "message": "mocked"}, + ) + + result = gi_mod.gr00t_inference( + action="build_image", + repo_url="https://github.com/NVIDIA/Isaac-GR00T", + repo_tag="n1.7-release", + ) + # Should not be an option-injection error + assert "must not start with '-'" not in result.get("message", "") + + +class TestHostAutoFlipForContainer: + """Test that container actions auto-flip 127.0.0.1 to 0.0.0.0.""" + + def test_default_host_is_loopback(self): + """Signature default must be 127.0.0.1 (AGENTS.md compliance).""" + import inspect + + from strands_robots.tools.gr00t_inference import gr00t_inference + + sig = inspect.signature(gr00t_inference) + # Sentinel default: None means "use 127.0.0.1" but distinguishes from explicit + assert sig.parameters["host"].default is None + + def test_start_service_auto_flips_loopback(self, monkeypatch): + """_start_service should auto-flip 127.0.0.1 to 0.0.0.0 for Docker.""" + from strands_robots.tools.gr00t_inference import _start_service + + captured_host = {} + + def fake_find(*args, **kwargs): + return { + "status": "success", + "containers": [{"name": "gr00t-test", "status": "Up 2 hours"}], + } + + def fake_build_cmd(**kwargs): + captured_host["host"] = kwargs.get("host") + return ["docker", "exec", "gr00t-test", "echo", "test"] + + monkeypatch.setattr("strands_robots.tools.gr00t_inference._find_gr00t_containers", fake_find) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._build_inference_command", fake_build_cmd) + + import subprocess + + def fake_run(*args, **kwargs): + return subprocess.CompletedProcess(args=args[0] if args else [], returncode=0, stdout="", stderr="") + + monkeypatch.setattr(subprocess, "run", fake_run) + monkeypatch.setattr("strands_robots.tools.gr00t_inference._is_service_running", lambda port: True) + + # Call with loopback default — should auto-flip + _start_service( + checkpoint_path="/data/model", + port=5555, + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + denoising_steps=4, + host="127.0.0.1", + container_name=None, + policy_name=None, + timeout=5, + use_tensorrt=False, + trt_engine_path="gr00t_engine", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + http_server=False, + api_token=None, + host_was_explicit=False, + ) + assert captured_host.get("host") == "0.0.0.0" + + +class TestSingleLabelNumericHostname: + """Verify single-label numeric hostnames (per RFC-1123) are accepted.""" + + def test_single_numeric_label_accepted(self): + """Single-label '123' is a valid hostname (RFC-1123).""" + validate_inputs(**{**_VALID_KWARGS, "host": "123"}) + + def test_multi_label_numeric_rejected(self): + """Multi-label '127.0.01' is rejected as an IP typo.""" + with pytest.raises(ValueError, match="host must be a valid IP address or hostname"): + validate_inputs(**{**_VALID_KWARGS, "host": "127.0.01"}) + + +class TestPlatformGuardForHostFallback: + """Test that host-fallback stop returns error on non-Linux platforms.""" + + def test_non_linux_platform_returns_error(self, monkeypatch): + """On non-Linux, _stop_service should error when no containers found.""" + import sys as _sys + + from strands_robots.tools.gr00t_inference import _stop_service + + # Mock no containers found + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._find_gr00t_containers", + lambda: {"status": "success", "containers": []}, + ) + monkeypatch.setattr(_sys, "platform", "darwin") + + result = _stop_service(5555) + assert result["status"] == "error" + assert "Linux" in result["message"] + + +class TestN17ProcessIdentification: + """Regression tests for N1.7 process identification — GH review thread. + + N1.7 services are started via `python -m gr00t.eval.run_gr00t_server` which + doesn't contain `inference_service.py` in cmdline. These tests ensure the + stop/status path can identify N1.7 services. + """ + + def test_n17_cmdline_detected_by_host_process_check(self, tmp_path, monkeypatch): + """_is_gr00t_host_process detects N1.7 server cmdline.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + # Simulate N1.7 cmdline: python -m gr00t.eval.run_gr00t_server --port 5555 + proc_dir = tmp_path / "proc" / "999" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00-m\x00gr00t.eval.run_gr00t_server\x00--port\x005555\x00") + + called = {} + from pathlib import Path as RealPath + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("999", port=5555) is True + assert called.get("p") == "/proc/999/cmdline" + + def test_n17_cmdline_wrong_port_rejected(self, tmp_path, monkeypatch): + """N1.7 server on wrong port is not killed.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "999" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00-m\x00gr00t.eval.run_gr00t_server\x00--port\x008000\x00") + + called = {} + from pathlib import Path as RealPath + + def _fake_path(p): + called["p"] = p + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + # Request port 80 — should not match 8000 + assert _is_gr00t_host_process("999", port=80) is False + assert called.get("p") == "/proc/999/cmdline" + + def test_n15_cmdline_still_detected(self, tmp_path, monkeypatch): + """N1.5/N1.6 cmdline (inference_service.py) still works after N1.7 support.""" + from strands_robots.tools.gr00t_inference import _is_gr00t_host_process + + proc_dir = tmp_path / "proc" / "123" + proc_dir.mkdir(parents=True) + cmdline_file = proc_dir / "cmdline" + cmdline_file.write_text("python\x00inference_service.py\x00--port\x005555\x00") + + from pathlib import Path as RealPath + + def _fake_path(p): + return RealPath(str(p).replace("/proc", str(tmp_path / "proc"))) + + monkeypatch.setattr("strands_robots.tools.gr00t_inference.Path", _fake_path) + + assert _is_gr00t_host_process("123", port=5555) is True + + +class TestExpandedParamValidationExtended: + """Extended tests for image_name, volumes, and container_command — covers happy paths.""" + + def test_valid_image_name(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + image_name="localhost:5000/myorg/img:tag", + ) + + def test_invalid_image_name_shell_meta(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="image_name"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + image_name="evil;rm -rf /", + ) + + def test_volumes_path_traversal_rejected(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="volumes"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + volumes={"../../etc/passwd": "/data"}, + ) + + def test_container_command_shell_meta_rejected(self): + import pytest + + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="container_command"): + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + container_command="tail -f /dev/null; rm -rf /", + ) + + def test_valid_container_command(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise - legitimate container commands without shell metas + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + container_command="tail -f /dev/null", + ) + + def test_valid_volumes(self): + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start", + data_config="fourier_gr1_arms_only", + embodiment_tag="gr1", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path="/tmp/ckpt", + trt_engine_path="gr00t_engine", + container_name="gr00t", + protocol="n1.5", + volumes={"/tmp/checkpoints": "/data/checkpoints"}, + ) + + +class TestHostAutoFlipSentinel: + """Regression tests for the sentinel-based host auto-flip logic. + + The auto-flip from 127.0.0.1 → 0.0.0.0 for Docker container actions + MUST only fire when the user accepted the default (i.e. did not pass + host= explicitly). Users who explicitly pass host="127.0.0.1" (e.g. + for --network=host deployments) must have their choice honoured. + """ + + def test_default_host_passes_not_explicit(self, monkeypatch): + """When host is NOT passed (None sentinel), host_was_explicit=False.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + # Call without host= (uses default None → 127.0.0.1, not explicit) + gr00t_inference(action="start", checkpoint_path="/data/model") + assert captured.get("host") == "127.0.0.1" + assert captured.get("host_was_explicit") is False, ( + "Default host (None sentinel) should pass host_was_explicit=False" + ) + + def test_explicit_loopback_passes_explicit_flag(self, monkeypatch): + """When user explicitly passes host='127.0.0.1', host_was_explicit=True.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + # Call WITH explicit host="127.0.0.1" — must pass host_was_explicit=True + gr00t_inference(action="start", checkpoint_path="/data/model", host="127.0.0.1") + assert captured.get("host") == "127.0.0.1" + assert captured.get("host_was_explicit") is True, "Explicit host='127.0.0.1' must pass host_was_explicit=True" + + def test_explicit_zero_passes_explicit_flag(self, monkeypatch): + """When user explicitly passes host='0.0.0.0', host_was_explicit=True.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "error", "message": "mocked"} + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + gr00t_inference(action="start", checkpoint_path="/data/model", host="0.0.0.0") + assert captured.get("host") == "0.0.0.0" + assert captured.get("host_was_explicit") is True + + +class TestReviewRound8Fixes: + """Regression tests for review round-8 fixes (2026-05-22 21:44 UTC). + + Covers: + - restart path forwarding host_was_explicit + - colon rejection in volume paths (docker -v mount-redirect) + - digest-pinned image references + - TypeError handling in validation wrapper + - dash-prefix rejection in volume paths + """ + + def test_restart_forwards_host_was_explicit(self, monkeypatch): + """action='restart' must forward host_was_explicit to _start_service.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "success", "message": "mocked"} + + def _mock_stop_service(port): + pass + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._stop_service", + _mock_stop_service, + ) + monkeypatch.setattr("time.sleep", lambda _: None) + + # Explicit host='127.0.0.1' on restart must pass host_was_explicit=True + gr00t_inference( + action="restart", + checkpoint_path="/data/model", + host="127.0.0.1", + ) + assert captured.get("host_was_explicit") is True, ( + "restart path must forward host_was_explicit=True for explicit host" + ) + + def test_restart_default_host_not_explicit(self, monkeypatch): + """action='restart' with default host must pass host_was_explicit=False.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + captured = {} + + def _mock_start_service(**kwargs): + captured.update(kwargs) + return {"status": "success", "message": "mocked"} + + def _mock_stop_service(port): + pass + + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._start_service", + _mock_start_service, + ) + monkeypatch.setattr( + "strands_robots.tools.gr00t_inference._stop_service", + _mock_stop_service, + ) + monkeypatch.setattr("time.sleep", lambda _: None) + + # Default host (not passed) on restart -> host_was_explicit=False + gr00t_inference(action="restart", checkpoint_path="/data/model") + assert captured.get("host_was_explicit") is False + + def test_volume_path_colon_rejected(self): + """Volume paths containing ':' must be rejected (docker -v mount-redirect).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"/legit/dir:rw,nosuid": "/container/path"}, + ) + assert result["status"] == "error" + assert ":" in result["message"] or "colon" in result["message"].lower() + + def test_volume_value_colon_rejected(self): + """Container-side volume paths containing ':' must also be rejected.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"/host/path": "/container:path"}, + ) + assert result["status"] == "error" + assert ":" in result["message"] + + def test_volume_path_dash_prefix_rejected(self): + """Volume paths starting with '-' must be rejected (option injection).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference( + action="start_container", + image_name="gr00t:latest", + volumes={"--privileged=foo": "/bar"}, + ) + assert result["status"] == "error" + assert "'-'" in result["message"] or "start with" in result["message"] + + def test_digest_pinned_image_accepted(self): + """Digest-pinned image refs (registry/path@sha256:hex) must be accepted.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # Should NOT fail validation on image_name (may fail later on docker ops) + result = gr00t_inference( + action="start_container", + image_name="nvcr.io/nvidia/gr00t@sha256:" + "a" * 64, + ) + # If it fails, it should NOT be an image_name validation error + if result["status"] == "error": + assert "valid Docker image" not in result["message"] + + def test_type_error_returns_structured_error(self): + """TypeError from bad parameter types must return structured error, not raise.""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + # port="5555" (str instead of int) -> TypeError on `1 <= port <= 65535` + result = gr00t_inference(action="start", checkpoint_path="/data/model", port="5555") + assert result["status"] == "error" + # Must not propagate as unhandled exception - returns dict + + def test_end_to_end_bogus_action_returns_error_dict(self): + """Bogus action returns structured error dict (not raw exception).""" + from strands_robots.tools.gr00t_inference import gr00t_inference + + result = gr00t_inference(action="bogus_action") + assert isinstance(result, dict) + assert result["status"] == "error" + assert "Unknown action" in result["message"] or "bogus_action" in result["message"] + + +class TestImageOnlyBranchValidation: + """Tests for validation on image-only actions (build_image, download_checkpoint, start_container).""" + + def test_container_name_validated_on_start_container(self): + """container_name must be validated on start_container (image-only branch).""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="container_name"): + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name="--privileged", + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name=None, + ) + + def test_policy_name_dash_rejected_on_start_container(self): + """policy_name starting with '-' must be rejected on image-only actions.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + with pytest.raises(ValueError, match="policy_name"): + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name=None, + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name="--malicious", + ) + + def test_valid_container_name_accepted_on_start_container(self): + """Valid container_name passes on start_container.""" + from strands_robots.tools.gr00t_inference import validate_inputs + + # Should not raise + validate_inputs( + action="start_container", + data_config="so100", + embodiment_tag="so100", + port=5555, + host="127.0.0.1", + vit_dtype="fp8", + llm_dtype="nvfp4", + dit_dtype="fp8", + checkpoint_path=None, + trt_engine_path="/opt/engine", + container_name="my-gr00t-container", + protocol="n1.5", + image_name=None, + volumes=None, + container_command=None, + repo_url=None, + repo_tag=None, + policy_name=None, + )