From 050776924ae8a7c6d7c5ebf1b4652cb0ce02bc67 Mon Sep 17 00:00:00 2001 From: Kingston Date: Wed, 6 May 2026 14:43:32 +0800 Subject: [PATCH 1/4] Add first-class pi05_libero support to hosting layer Adds LIBERO observation factory + image specs to warmup.py, makes prepare-checkpoint's required asset id configurable (so libero's physical-intelligence/libero norm stats can be verified), threads a matching Terraform variable through the cloud-init bootstrap, and adds an --embodiment flag to the QUIC/WS smoke tests so they can target a libero deployment end-to-end. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../regional_inference_instance/main.tf | 1 + .../templates/user_data.yaml.tftpl | 3 +- .../regional_inference_instance/variables.tf | 6 +++ infra/regional-instance/main.tf | 1 + infra/regional-instance/variables.tf | 6 +++ main.py | 23 ++++++++- src/hosting/prepare_checkpoint.py | 11 +++-- src/hosting/warmup.py | 49 ++++++++++++++++++- tests/helpers.py | 37 +++++++++++++- tests/test_quic.py | 8 +-- tests/test_ws.py | 10 ++-- 11 files changed, 139 insertions(+), 16 deletions(-) diff --git a/infra/modules/regional_inference_instance/main.tf b/infra/modules/regional_inference_instance/main.tf index b47e0b7..2c55218 100644 --- a/infra/modules/regional_inference_instance/main.tf +++ b/infra/modules/regional_inference_instance/main.tf @@ -230,6 +230,7 @@ resource "aws_instance" "inference" { checkpoint_prep_model_id = var.checkpoint_prep_model_id checkpoint_prep_openpi_assets_uri = var.checkpoint_prep_openpi_assets_uri checkpoint_prep_output_dir = var.checkpoint_prep_output_dir + checkpoint_prep_required_asset_id = var.checkpoint_prep_required_asset_id container_name = var.container_name ecr_region = var.ecr_region ecr_registry_host = local.ecr_registry_host diff --git a/infra/modules/regional_inference_instance/templates/user_data.yaml.tftpl b/infra/modules/regional_inference_instance/templates/user_data.yaml.tftpl index a3e763c..f3648af 100644 --- a/infra/modules/regional_inference_instance/templates/user_data.yaml.tftpl +++ b/infra/modules/regional_inference_instance/templates/user_data.yaml.tftpl @@ -71,7 +71,8 @@ write_files: python main.py prepare-checkpoint \ --model-id ${checkpoint_prep_model_id} \ --openpi-assets-uri ${checkpoint_prep_openpi_assets_uri} \ - --output-dir ${checkpoint_prep_output_dir} + --output-dir ${checkpoint_prep_output_dir} \ + --required-asset-id ${checkpoint_prep_required_asset_id} %{ endif ~} %{ if prepare_planner_checkpoint ~} diff --git a/infra/modules/regional_inference_instance/variables.tf b/infra/modules/regional_inference_instance/variables.tf index 1bd5e43..0b20959 100644 --- a/infra/modules/regional_inference_instance/variables.tf +++ b/infra/modules/regional_inference_instance/variables.tf @@ -178,6 +178,12 @@ variable "checkpoint_prep_output_dir" { default = "/cache/models/pi05_base_openpi" } +variable "checkpoint_prep_required_asset_id" { + description = "Asset id whose norm_stats.json the Docker checkpoint preparation step must produce. Use 'trossen' for ALOHA, 'droid' for pi05_droid, 'physical-intelligence/libero' for pi05_libero." + type = string + default = "trossen" +} + # -- Planner checkpoint prep (planner slot) ----------------------------------- # # Optional one-shot step that downloads + extracts a JAX subtask planner tar diff --git a/infra/regional-instance/main.tf b/infra/regional-instance/main.tf index 5796de3..95ea290 100644 --- a/infra/regional-instance/main.tf +++ b/infra/regional-instance/main.tf @@ -38,6 +38,7 @@ module "regional_inference_instance" { checkpoint_prep_model_id = var.checkpoint_prep_model_id checkpoint_prep_openpi_assets_uri = var.checkpoint_prep_openpi_assets_uri checkpoint_prep_output_dir = var.checkpoint_prep_output_dir + checkpoint_prep_required_asset_id = var.checkpoint_prep_required_asset_id container_name = var.container_name deployment_name = var.deployment_name docker_image_tag = var.docker_image_tag diff --git a/infra/regional-instance/variables.tf b/infra/regional-instance/variables.tf index 5a2f2aa..90efbad 100644 --- a/infra/regional-instance/variables.tf +++ b/infra/regional-instance/variables.tf @@ -183,6 +183,12 @@ variable "checkpoint_prep_output_dir" { default = "/cache/models/pi05_base_openpi" } +variable "checkpoint_prep_required_asset_id" { + description = "Asset id whose norm_stats.json the Docker checkpoint preparation step must produce. Use 'trossen' for ALOHA, 'droid' for pi05_droid, 'physical-intelligence/libero' for pi05_libero." + type = string + default = "trossen" +} + variable "openpi_pytorch_compile_mode" { description = "Value for OPENPI_PYTORCH_COMPILE_MODE inside the inference container" type = string diff --git a/main.py b/main.py index 82149cc..ab059ad 100644 --- a/main.py +++ b/main.py @@ -43,6 +43,16 @@ def prepare_checkpoint( bool, typer.Option(help="Re-download upstream files and rebuild the prepared checkpoint."), ] = False, + required_asset_id: Annotated[ + str, + typer.Option( + help=( + "Asset id whose norm_stats.json must exist in the prepared " + "checkpoint's assets/ directory. Defaults to 'trossen' (ALOHA). " + "Use 'droid' for pi05_droid or 'physical-intelligence/libero' for pi05_libero." + ), + ), + ] = "trossen", ) -> None: """Prepare a local OpenPI-compatible checkpoint from upstream sources.""" from hosting.prepare_checkpoint import prepare_openpi_compatible_checkpoint @@ -52,6 +62,7 @@ def prepare_checkpoint( openpi_assets_uri=openpi_assets_uri, output_dir=output_dir, force_download=force_download, + required_asset_id=required_asset_id, ) @@ -113,23 +124,30 @@ def serve( InferenceModeChoice = Literal["default", "action_only", "subtask_only"] +EmbodimentChoice = Literal["aloha", "droid", "libero"] _MODE_OPTION_HELP = ( "Inference mode sent in the observation. Lets benchmarks target one phase " "of a combined-mode server: 'action_only' skips the planner, 'subtask_only' " "skips the action policy. Omit for the server default." ) +_EMBODIMENT_OPTION_HELP = ( + "Observation shape to send. Must match the embodiment the server's action " + "checkpoint was trained on (aloha for pi05_aloha, droid for pi05_droid, " + "libero for pi05_libero)." +) @test_app.command(name="ws") def test_websocket( url: Annotated[str, typer.Argument(help="WebSocket URL (ws://host:port or wss://host).")], mode: Annotated[InferenceModeChoice | None, typer.Option(help=_MODE_OPTION_HELP)] = None, + embodiment: Annotated[EmbodimentChoice, typer.Option(help=_EMBODIMENT_OPTION_HELP)] = "aloha", ) -> None: """Smoke test against a WebSocket server (EC2, Docker, or Modal ASGI).""" from tests.test_ws import run - run(url, mode=mode) + run(url, mode=mode, embodiment=embodiment) @test_app.command(name="quic") @@ -140,11 +158,12 @@ def test_quic( int, typer.Option(help="Server WebSocket/TCP port (for health check).") ] = 8000, mode: Annotated[InferenceModeChoice | None, typer.Option(help=_MODE_OPTION_HELP)] = None, + embodiment: Annotated[EmbodimentChoice, typer.Option(help=_EMBODIMENT_OPTION_HELP)] = "aloha", ) -> None: """Smoke test against a direct QUIC server (EC2 or Docker).""" from tests.test_quic import run - run(host, quic_port=quic_port, ws_port=ws_port, mode=mode) + run(host, quic_port=quic_port, ws_port=ws_port, mode=mode, embodiment=embodiment) @test_app.command(name="modal-tunnel") diff --git a/src/hosting/prepare_checkpoint.py b/src/hosting/prepare_checkpoint.py index 4f2ae82..9848a99 100644 --- a/src/hosting/prepare_checkpoint.py +++ b/src/hosting/prepare_checkpoint.py @@ -29,13 +29,15 @@ def get_default_output_dir() -> Path: return openpi_download.get_cache_dir() / "pi05_base_openpi" -def _assert_prepared_checkpoint_directory_is_complete(output_dir: Path) -> None: +def _assert_prepared_checkpoint_directory_is_complete( + output_dir: Path, required_asset_id: str +) -> None: missing_paths = [ str(output_dir / checkpoint_filename) for checkpoint_filename in _REQUIRED_CHECKPOINT_FILENAMES if not (output_dir / checkpoint_filename).exists() ] - required_norm_stats_path = output_dir / "assets" / DEFAULT_REQUIRED_ASSET_ID / "norm_stats.json" + required_norm_stats_path = output_dir / "assets" / required_asset_id / "norm_stats.json" if not required_norm_stats_path.exists(): missing_paths.append(str(required_norm_stats_path)) @@ -53,10 +55,11 @@ def prepare_openpi_compatible_checkpoint( openpi_assets_uri: str = DEFAULT_OPENPI_ASSETS_URI, output_dir: Path | None = None, force_download: bool = False, + required_asset_id: str = DEFAULT_REQUIRED_ASSET_ID, ) -> Path: resolved_output_dir = (output_dir or get_default_output_dir()).resolve() if resolved_output_dir.exists() and not force_download: - _assert_prepared_checkpoint_directory_is_complete(resolved_output_dir) + _assert_prepared_checkpoint_directory_is_complete(resolved_output_dir, required_asset_id) print(f"Prepared checkpoint already exists at {resolved_output_dir}") return resolved_output_dir @@ -98,7 +101,7 @@ def prepare_openpi_compatible_checkpoint( encoding="utf-8", ) - _assert_prepared_checkpoint_directory_is_complete(temporary_output_dir) + _assert_prepared_checkpoint_directory_is_complete(temporary_output_dir, required_asset_id) if resolved_output_dir.exists(): shutil.rmtree(resolved_output_dir) diff --git a/src/hosting/warmup.py b/src/hosting/warmup.py index b1443d9..ffec3b0 100644 --- a/src/hosting/warmup.py +++ b/src/hosting/warmup.py @@ -41,7 +41,16 @@ class DroidWarmupObservationSpec: prompt: str = "warmup" -WarmupObservationSpec = AlohaWarmupObservationSpec | DroidWarmupObservationSpec +@dataclass(frozen=True) +class LiberoWarmupObservationSpec: + """Warmup input spec for LIBERO-style embodiments.""" + + prompt: str = "warmup" + + +WarmupObservationSpec = ( + AlohaWarmupObservationSpec | DroidWarmupObservationSpec | LiberoWarmupObservationSpec +) def make_aloha_observation(prompt: str) -> dict[str, Any]: @@ -77,6 +86,24 @@ def make_droid_observation(prompt: str) -> dict[str, Any]: } +def make_libero_observation(prompt: str) -> dict[str, Any]: + """Create a dummy LIBERO observation matching the expected model input shape.""" + return { + "observation/state": np.random.rand(8), + "observation/image": np.random.randint( + 256, + size=(224, 224, 3), + dtype=np.uint8, + ), + "observation/wrist_image": np.random.randint( + 256, + size=(224, 224, 3), + dtype=np.uint8, + ), + "prompt": prompt, + } + + def get_warmup_observation_spec(train_config: Any) -> WarmupObservationSpec: """Derive the correct warmup input spec from the parsed OpenPI train config.""" data_config = getattr(train_config, "data", None) @@ -86,6 +113,8 @@ def get_warmup_observation_spec(train_config: Any) -> WarmupObservationSpec: if data_config_type_name == "LeRobotAlohaDataConfig": return AlohaWarmupObservationSpec() + if data_config_type_name == "LeRobotLiberoDataConfig": + return LiberoWarmupObservationSpec() if data_config_type_name == "SimpleDataConfig" and asset_id == "droid": return DroidWarmupObservationSpec() @@ -93,6 +122,8 @@ def get_warmup_observation_spec(train_config: Any) -> WarmupObservationSpec: return AlohaWarmupObservationSpec() if asset_id == "droid": return DroidWarmupObservationSpec() + if asset_id == "physical-intelligence/libero": + return LiberoWarmupObservationSpec() raise ValueError( "No warmup observation generator is registered for " @@ -109,6 +140,8 @@ def make_warmup_observation(train_config: Any) -> dict[str, Any]: return make_aloha_observation(prompt=prompt) case DroidWarmupObservationSpec(prompt=prompt): return make_droid_observation(prompt=prompt) + case LiberoWarmupObservationSpec(prompt=prompt): + return make_libero_observation(prompt=prompt) # Image specs the server advertises in metadata so the client transport @@ -130,6 +163,18 @@ def make_warmup_observation(train_config: Any) -> dict[str, Any]: dtype="uint8", ), ] +_LIBERO_IMAGE_SPECS: list[ImageSpec] = [ + ImageSpec( + path=["observation/image"], + target_shape=[224, 224, 3], + dtype="uint8", + ), + ImageSpec( + path=["observation/wrist_image"], + target_shape=[224, 224, 3], + dtype="uint8", + ), +] def make_image_specs(train_config: Any) -> list[ImageSpec]: @@ -143,6 +188,8 @@ def make_image_specs(train_config: Any) -> list[ImageSpec]: return _ALOHA_IMAGE_SPECS case DroidWarmupObservationSpec(): return _DROID_IMAGE_SPECS + case LiberoWarmupObservationSpec(): + return _LIBERO_IMAGE_SPECS def get_action_horizon(train_config: Any) -> int | None: diff --git a/tests/helpers.py b/tests/helpers.py index 2512be6..ae8f266 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,15 +1,21 @@ """Shared utilities for smoke test scripts.""" import time -from typing import Final +from typing import Final, Literal import httpx -from hosting.warmup import make_aloha_observation +from hosting.warmup import ( + make_aloha_observation, + make_droid_observation, + make_libero_observation, +) SERVER_READINESS_TIMEOUT_SECONDS: Final[float] = 300.0 SERVER_READINESS_POLL_INTERVAL_SECONDS: Final[float] = 5.0 +EmbodimentChoice = Literal["aloha", "droid", "libero"] + def random_observation_aloha(mode: str | None = None) -> dict: """Generate a random ALOHA observation for smoke testing. @@ -23,6 +29,33 @@ def random_observation_aloha(mode: str | None = None) -> dict: return observation +def random_observation_droid(mode: str | None = None) -> dict: + """Generate a random DROID observation for smoke testing.""" + observation = make_droid_observation(prompt="do something") + if mode is not None: + observation["mode"] = mode + return observation + + +def random_observation_libero(mode: str | None = None) -> dict: + """Generate a random LIBERO observation for smoke testing.""" + observation = make_libero_observation(prompt="do something") + if mode is not None: + observation["mode"] = mode + return observation + + +def random_observation(embodiment: EmbodimentChoice, mode: str | None = None) -> dict: + """Dispatch by embodiment to the matching observation factory.""" + match embodiment: + case "aloha": + return random_observation_aloha(mode=mode) + case "droid": + return random_observation_droid(mode=mode) + case "libero": + return random_observation_libero(mode=mode) + + def wait_for_server( health_url: str, timeout_seconds: float = SERVER_READINESS_TIMEOUT_SECONDS, diff --git a/tests/test_quic.py b/tests/test_quic.py index b8d3062..8d0381e 100644 --- a/tests/test_quic.py +++ b/tests/test_quic.py @@ -3,9 +3,10 @@ Connects directly to a QUIC server and runs benchmark inferences. """ +from openpi_flash_client import FlashTransportPolicy + from hosting.benchmark import run_benchmark -from hosting.flash_transport_policy import FlashTransportPolicy -from tests.helpers import random_observation_aloha, wait_for_server +from tests.helpers import EmbodimentChoice, random_observation, wait_for_server DEFAULT_QUIC_PORT = 5555 DEFAULT_WS_PORT = 8000 @@ -16,6 +17,7 @@ def run( quic_port: int = DEFAULT_QUIC_PORT, ws_port: int = DEFAULT_WS_PORT, mode: str | None = None, + embodiment: EmbodimentChoice = "aloha", ) -> None: health_url = f"http://{host}:{ws_port}/healthz" wait_for_server(health_url) @@ -24,7 +26,7 @@ def run( policy = FlashTransportPolicy(host=host, port=quic_port) print(f"Server metadata: {policy.get_server_metadata()}") - result = run_benchmark(policy, lambda: random_observation_aloha(mode=mode)) + result = run_benchmark(policy, lambda: random_observation(embodiment, mode=mode)) result.print_summary() policy.close() diff --git a/tests/test_ws.py b/tests/test_ws.py index 621cf78..5ca9ea3 100644 --- a/tests/test_ws.py +++ b/tests/test_ws.py @@ -9,10 +9,14 @@ from openpi_client import websocket_client_policy as _websocket_client_policy from hosting.benchmark import InferablePolicy, run_benchmark -from tests.helpers import random_observation_aloha, wait_for_server +from tests.helpers import EmbodimentChoice, random_observation, wait_for_server -def run(url: str, mode: str | None = None) -> None: +def run( + url: str, + mode: str | None = None, + embodiment: EmbodimentChoice = "aloha", +) -> None: parsed_server_url = urllib.parse.urlparse(url) if parsed_server_url.scheme not in {"ws", "wss"} or parsed_server_url.hostname is None: raise ValueError("Expected a WebSocket URL like ws://host[:port] or wss://host[:port].") @@ -35,6 +39,6 @@ def run(url: str, mode: str | None = None) -> None: result = run_benchmark( cast(InferablePolicy, policy), - lambda: random_observation_aloha(mode=mode), + lambda: random_observation(embodiment, mode=mode), ) result.print_summary() From 50076dbc375d7076e59231e6cda5ee5ae3c395ca Mon Sep 17 00:00:00 2001 From: Kingston Date: Wed, 6 May 2026 23:20:41 +0800 Subject: [PATCH 2/4] Tighten libero dispatch surface Drops the dead asset_id fallback block in get_warmup_observation_spec (every registered openpi config matches the class-name tier; the fallback never fired in production), inlines the per-embodiment random_observation_* helpers into the dispatcher, and dedupes the EmbodimentChoice literal so the CLI imports the canonical type from tests.helpers instead of redeclaring it. Co-Authored-By: Claude Opus 4.7 (1M context) --- main.py | 3 ++- src/hosting/warmup.py | 7 ------- tests/helpers.py | 27 +++++++-------------------- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/main.py b/main.py index ab059ad..9bad3da 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,8 @@ import typer +from tests.helpers import EmbodimentChoice + app = typer.Typer( name="openpi-flash", help="Real-time inference engine for openpi. Serves policy models over QUIC and WebSocket.", @@ -124,7 +126,6 @@ def serve( InferenceModeChoice = Literal["default", "action_only", "subtask_only"] -EmbodimentChoice = Literal["aloha", "droid", "libero"] _MODE_OPTION_HELP = ( "Inference mode sent in the observation. Lets benchmarks target one phase " diff --git a/src/hosting/warmup.py b/src/hosting/warmup.py index ffec3b0..b8cdcd7 100644 --- a/src/hosting/warmup.py +++ b/src/hosting/warmup.py @@ -118,13 +118,6 @@ def get_warmup_observation_spec(train_config: Any) -> WarmupObservationSpec: if data_config_type_name == "SimpleDataConfig" and asset_id == "droid": return DroidWarmupObservationSpec() - if asset_id == "trossen": - return AlohaWarmupObservationSpec() - if asset_id == "droid": - return DroidWarmupObservationSpec() - if asset_id == "physical-intelligence/libero": - return LiberoWarmupObservationSpec() - raise ValueError( "No warmup observation generator is registered for " f"config={config_name!r} data_config_type={data_config_type_name!r} asset_id={asset_id!r}." diff --git a/tests/helpers.py b/tests/helpers.py index ae8f266..133d819 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -29,31 +29,18 @@ def random_observation_aloha(mode: str | None = None) -> dict: return observation -def random_observation_droid(mode: str | None = None) -> dict: - """Generate a random DROID observation for smoke testing.""" - observation = make_droid_observation(prompt="do something") - if mode is not None: - observation["mode"] = mode - return observation - - -def random_observation_libero(mode: str | None = None) -> dict: - """Generate a random LIBERO observation for smoke testing.""" - observation = make_libero_observation(prompt="do something") - if mode is not None: - observation["mode"] = mode - return observation - - def random_observation(embodiment: EmbodimentChoice, mode: str | None = None) -> dict: - """Dispatch by embodiment to the matching observation factory.""" + """Generate a random observation for smoke testing the given embodiment.""" match embodiment: case "aloha": - return random_observation_aloha(mode=mode) + observation = make_aloha_observation(prompt="do something") case "droid": - return random_observation_droid(mode=mode) + observation = make_droid_observation(prompt="do something") case "libero": - return random_observation_libero(mode=mode) + observation = make_libero_observation(prompt="do something") + if mode is not None: + observation["mode"] = mode + return observation def wait_for_server( From a98802cd18d340c4de9286f49a425b4f56d98104 Mon Sep 17 00:00:00 2001 From: Kingston Date: Tue, 12 May 2026 16:40:41 +0800 Subject: [PATCH 3/4] Split flash-transport Python into local packages and add uv supply-chain guard - Extract hosting/flash_transport_* into two editable local packages under packages/: openpi-flash-transport (transport + codec, numpy-only) and openpi-flash-client (FlashTransportPolicy + openpi-client adapter). - Reduce src/hosting/.py to thin compatibility re-exports so existing callers continue to import from hosting.* unchanged. - Migrate direct internal consumers (serve.py, local_policy_socket_server, tests/*) to the new package paths. - Harden FlashTransportPolicy: close-on-init-failure, idempotent close(), __del__ finalizer, no hard import of quic_portal at module scope. - Add tool.uv.exclude-newer = "7 days" across the three pyproject files to refuse PyPI releases younger than 7 days; relock under the new constraint. Co-Authored-By: Claude Opus 4.7 (1M context) --- packages/openpi-flash-client/pyproject.toml | 51 ++++ .../src/openpi_flash_client/__init__.py | 5 + .../flash_transport_policy.py | 205 ++++++++++++++++ .../openpi-flash-transport/pyproject.toml | 46 ++++ .../src/openpi_flash_transport/__init__.py | 30 +++ .../flash_transport_binary.py | 112 +++++++++ .../src/openpi_flash_transport/local_frame.py | 227 +++++++++++++++++ .../local_transport_protocol.py | 64 +++++ pyproject.toml | 7 + src/hosting/flash_transport_binary.py | 128 ++-------- src/hosting/flash_transport_policy.py | 172 +------------ src/hosting/local_frame.py | 228 +----------------- src/hosting/local_policy_socket_server.py | 7 +- src/hosting/local_transport_protocol.py | 79 ++---- src/hosting/serve.py | 12 +- tests/test_arrow_wire.py | 3 +- tests/test_flash_transport_cli_drift.py | 5 +- tests/test_local_frame.py | 3 +- uv.lock | 71 +++--- 19 files changed, 843 insertions(+), 612 deletions(-) create mode 100644 packages/openpi-flash-client/pyproject.toml create mode 100644 packages/openpi-flash-client/src/openpi_flash_client/__init__.py create mode 100644 packages/openpi-flash-client/src/openpi_flash_client/flash_transport_policy.py create mode 100644 packages/openpi-flash-transport/pyproject.toml create mode 100644 packages/openpi-flash-transport/src/openpi_flash_transport/__init__.py create mode 100644 packages/openpi-flash-transport/src/openpi_flash_transport/flash_transport_binary.py create mode 100644 packages/openpi-flash-transport/src/openpi_flash_transport/local_frame.py create mode 100644 packages/openpi-flash-transport/src/openpi_flash_transport/local_transport_protocol.py diff --git a/packages/openpi-flash-client/pyproject.toml b/packages/openpi-flash-client/pyproject.toml new file mode 100644 index 0000000..0aaee21 --- /dev/null +++ b/packages/openpi-flash-client/pyproject.toml @@ -0,0 +1,51 @@ +[project] +name = "openpi-flash-client" +version = "0.1.0" +description = "Client policy adapter for openpi-flash transport servers." +requires-python = ">=3.11,<3.13" +dependencies = [ + "openpi-client", + "openpi-flash-transport", +] + +[tool.uv] +# Supply-chain hardening: refuse PyPI releases younger than 7 days. +# See: https://docs.astral.sh/uv/concepts/resolution/#exclude-newer +exclude-newer = "7 days" + +[tool.uv.sources] +openpi-client = { path = "../../../openpi/packages/openpi-client", editable = true } +openpi-flash-transport = { path = "../openpi-flash-transport", editable = true } + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/openpi_flash_client"] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "B", + "C4", + "UP", + "ANN", + "PTH", + "RET", + "SIM", + "TID", + "RUF", +] +ignore = ["E501", "B008", "C901", "RET504", "ANN401"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/packages/openpi-flash-client/src/openpi_flash_client/__init__.py b/packages/openpi-flash-client/src/openpi_flash_client/__init__.py new file mode 100644 index 0000000..c159a74 --- /dev/null +++ b/packages/openpi-flash-client/src/openpi_flash_client/__init__.py @@ -0,0 +1,5 @@ +"""Client policy adapters for openpi-flash.""" + +from openpi_flash_client.flash_transport_policy import FlashTransportPolicy + +__all__ = ["FlashTransportPolicy"] diff --git a/packages/openpi-flash-client/src/openpi_flash_client/flash_transport_policy.py b/packages/openpi-flash-client/src/openpi_flash_client/flash_transport_policy.py new file mode 100644 index 0000000..487b0cf --- /dev/null +++ b/packages/openpi-flash-client/src/openpi_flash_client/flash_transport_policy.py @@ -0,0 +1,205 @@ +"""Client policy backed by a local ``openpi-flash-transport`` subprocess. + +This preserves the normal Python ``BasePolicy`` interface used by openpi +clients while moving QUIC transport, Arrow IPC codec, image preprocessing, +action chunking, and server-timing instrumentation into the transport +binary. + +The QUIC path speaks Arrow IPC Streaming Format on the wire (the transport +binary owns the codec translation); ``openpi-client``'s WebSocket path is +the supported pure-Python msgpack alternative for customers who don't want +the transport binary as a dependency. +""" + +from __future__ import annotations + +import contextlib +import pathlib +import socket +import subprocess +import time +import uuid + +from openpi_client import base_policy as _base_policy +from openpi_client import msgpack_numpy +from openpi_flash_transport.flash_transport_binary import ( + BINARY_NAME, + ClientArgs, + resolve_binary_path, +) +from openpi_flash_transport.local_frame import pack_local_frame, unpack_local_frame +from openpi_flash_transport.local_transport_protocol import ( + TransportRequestType, + TransportResponseType, + recv_framed_message, + send_framed_message, +) +from typing_extensions import override + +DEFAULT_TRANSPORT_STARTUP_TIMEOUT_SECONDS = 30.0 +DEFAULT_TRANSPORT_POLL_INTERVAL_SECONDS = 0.1 + +# Unix sockets must fit in sun_path (104 bytes on macOS, 108 on Linux), so we +# can't use tempfile.gettempdir() here — macOS's default $TMPDIR is a long +# /var/folders/... path that overflows once the UUID filename is appended. +_UNIX_SOCKET_DIR = pathlib.Path("/tmp") + + +class FlashTransportPolicy(_base_policy.BasePolicy): + """Connects to a direct QUIC server through a local ``openpi-flash-transport`` subprocess.""" + + def __init__( + self, + host: str, + port: int = 5555, + local_port: int = 5556, + transport_options: object | None = None, + ) -> None: + if transport_options is not None: + raise ValueError(f"Custom transport_options are not supported by {BINARY_NAME} yet") + + self._closed = False + self._socket_path = _UNIX_SOCKET_DIR / f"{BINARY_NAME}-client-{uuid.uuid4().hex}.sock" + self._transport_process: subprocess.Popen[str] | None = None + self._transport_socket: socket.socket | None = None + try: + self._transport_process = self._spawn_transport_process( + host=host, + port=port, + local_port=local_port, + socket_path=self._socket_path, + ) + self._transport_socket = self._connect_to_transport_socket(self._socket_path) + self._server_metadata = self._request_metadata() + except BaseException: + self.close() + raise + + def _spawn_transport_process( + self, + *, + host: str, + port: int, + local_port: int, + socket_path: pathlib.Path, + ) -> subprocess.Popen[str]: + binary_path = resolve_binary_path() + args = ClientArgs( + server_host=host, + local_socket_path=socket_path, + server_port=port, + local_port=local_port, + ) + command = [str(binary_path), *args.to_argv()] + return subprocess.Popen(command, text=True) + + def _connect_to_transport_socket(self, socket_path: pathlib.Path) -> socket.socket: + transport_process = self._require_transport_process() + wait_deadline = time.monotonic() + DEFAULT_TRANSPORT_STARTUP_TIMEOUT_SECONDS + while time.monotonic() < wait_deadline: + if transport_process.poll() is not None: + raise RuntimeError( + f"{BINARY_NAME} client exited before opening its local socket " + f"(exit_code={transport_process.returncode})" + ) + + if socket_path.exists(): + transport_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + transport_socket.connect(str(socket_path)) + return transport_socket + except OSError: + transport_socket.close() + + time.sleep(DEFAULT_TRANSPORT_POLL_INTERVAL_SECONDS) + + raise TimeoutError(f"Timed out waiting for {BINARY_NAME} socket at {socket_path}") + + def _request(self, request_type: TransportRequestType, payload: bytes = b"") -> bytes: + transport_socket = self._require_transport_socket() + send_framed_message(transport_socket, bytes([request_type]) + payload) + framed_response = recv_framed_message(transport_socket) + if framed_response is None: + raise ConnectionError(f"{BINARY_NAME} disconnected unexpectedly") + if not framed_response: + raise RuntimeError(f"Received empty response from {BINARY_NAME}") + + response_type = TransportResponseType(framed_response[0]) + response_body = framed_response[1:] + + if response_type == TransportResponseType.ERROR: + raise RuntimeError(f"Error from {BINARY_NAME}:\n{response_body.decode('utf-8')}") + if ( + request_type == TransportRequestType.METADATA + and response_type != TransportResponseType.METADATA + ): + raise RuntimeError(f"Unexpected metadata response type: {response_type!r}") + if ( + request_type == TransportRequestType.INFER + and response_type != TransportResponseType.INFER + ): + raise RuntimeError(f"Unexpected inference response type: {response_type!r}") + if ( + request_type == TransportRequestType.RESET + and response_type != TransportResponseType.RESET + ): + raise RuntimeError(f"Unexpected reset response type: {response_type!r}") + + return response_body + + def _request_metadata(self) -> dict: + # Metadata stays msgpack_numpy end to end — it's openpi's blob, + # forwarded verbatim through the handshake. + return msgpack_numpy.unpackb(self._request(TransportRequestType.METADATA)) + + def get_server_metadata(self) -> dict: + return self._server_metadata + + @override + def infer(self, obs: dict) -> dict: + frame = pack_local_frame(obs) + response_body = self._request(TransportRequestType.INFER, frame) + return unpack_local_frame(response_body) + + @override + def reset(self) -> None: + self._request(TransportRequestType.RESET) + + def _require_transport_process(self) -> subprocess.Popen[str]: + if self._transport_process is None: + raise RuntimeError(f"{BINARY_NAME} transport process has not been started") + return self._transport_process + + def _require_transport_socket(self) -> socket.socket: + if self._transport_socket is None: + raise RuntimeError(f"{BINARY_NAME} transport socket has not been connected") + return self._transport_socket + + def close(self) -> None: + """Close the local socket and stop the subprocess.""" + if self._closed: + return + self._closed = True + + transport_socket = self._transport_socket + self._transport_socket = None + if transport_socket is not None: + with contextlib.suppress(OSError): + transport_socket.close() + + transport_process = self._transport_process + self._transport_process = None + if transport_process is not None and transport_process.poll() is None: + transport_process.terminate() + try: + transport_process.wait(timeout=5) + except subprocess.TimeoutExpired: + transport_process.kill() + transport_process.wait(timeout=5) + + with contextlib.suppress(FileNotFoundError): + self._socket_path.unlink() + + def __del__(self) -> None: + with contextlib.suppress(Exception): + self.close() diff --git a/packages/openpi-flash-transport/pyproject.toml b/packages/openpi-flash-transport/pyproject.toml new file mode 100644 index 0000000..88a1b94 --- /dev/null +++ b/packages/openpi-flash-transport/pyproject.toml @@ -0,0 +1,46 @@ +[project] +name = "openpi-flash-transport" +version = "0.1.0" +description = "Shared local transport protocol helpers for openpi-flash." +requires-python = ">=3.11,<3.13" +dependencies = [ + "numpy>=1.22.4", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["src/openpi_flash_transport"] + +[tool.uv] +# Supply-chain hardening: refuse PyPI releases younger than 7 days. +# See: https://docs.astral.sh/uv/concepts/resolution/#exclude-newer +exclude-newer = "7 days" + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = [ + "E", + "W", + "F", + "I", + "B", + "C4", + "UP", + "ANN", + "PTH", + "RET", + "SIM", + "TID", + "RUF", +] +ignore = ["E501", "B008", "C901", "RET504", "ANN401"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" diff --git a/packages/openpi-flash-transport/src/openpi_flash_transport/__init__.py b/packages/openpi-flash-transport/src/openpi_flash_transport/__init__.py new file mode 100644 index 0000000..eb0ffe9 --- /dev/null +++ b/packages/openpi-flash-transport/src/openpi_flash_transport/__init__.py @@ -0,0 +1,30 @@ +"""Shared transport helpers for openpi-flash clients and servers.""" + +from openpi_flash_transport.flash_transport_binary import ( + BINARY_NAME, + ENV_OVERRIDE, + ClientArgs, + ServerArgs, + resolve_binary_path, +) +from openpi_flash_transport.local_frame import pack_local_frame, unpack_local_frame +from openpi_flash_transport.local_transport_protocol import ( + TransportRequestType, + TransportResponseType, + recv_framed_message, + send_framed_message, +) + +__all__ = [ + "BINARY_NAME", + "ENV_OVERRIDE", + "ClientArgs", + "ServerArgs", + "TransportRequestType", + "TransportResponseType", + "pack_local_frame", + "recv_framed_message", + "resolve_binary_path", + "send_framed_message", + "unpack_local_frame", +] diff --git a/packages/openpi-flash-transport/src/openpi_flash_transport/flash_transport_binary.py b/packages/openpi-flash-transport/src/openpi_flash_transport/flash_transport_binary.py new file mode 100644 index 0000000..4feaf23 --- /dev/null +++ b/packages/openpi-flash-transport/src/openpi_flash_transport/flash_transport_binary.py @@ -0,0 +1,112 @@ +"""Locate and invoke the openpi-flash-transport binary. + +Used by both the backend (``hosting.serve``) and the client policy +(``openpi_flash_client.flash_transport_policy``). Resolution order for the +binary: + +1. ``OPENPI_FLASH_TRANSPORT_BINARY`` env var override. +2. The standard Docker install path at ``/usr/local/bin/``. +3. Local cargo build output (``flash-transport/target/{debug,release}/...``) + — useful for the client-side developer loop where it isn't installed + globally. + +Also holds ``ServerArgs`` / ``ClientArgs`` — typed mirrors of the Rust +``clap`` structs in ``flash-transport/src/main.rs``. Python callers +construct one of these dataclasses instead of hand-building argv strings, +so a Rust flag rename becomes a type error on the Python side. +""" + +from __future__ import annotations + +import os +import pathlib +from dataclasses import dataclass, fields +from typing import Any + +BINARY_NAME = "openpi-flash-transport" +DEFAULT_BINARY_PATH = pathlib.Path(f"/usr/local/bin/{BINARY_NAME}") +ENV_OVERRIDE = "OPENPI_FLASH_TRANSPORT_BINARY" + +# Defaults shared by both subcommands. Kept in sync with +# ``flash-transport/src/main.rs`` (see ``ServerArgs`` / ``ClientArgs``). +_DEFAULT_MAX_IDLE_TIMEOUT_SECS = 10 +_DEFAULT_KEEP_ALIVE_INTERVAL_SECS = 2 +_DEFAULT_INITIAL_WINDOW_BYTES = 1024 * 1024 +_DEFAULT_QUIC_PORT = 5555 +_DEFAULT_LOCAL_CLIENT_PORT = 5556 + + +def _hosting_repo_root() -> pathlib.Path: + module_path = pathlib.Path(__file__).resolve() + for parent_path in module_path.parents: + if (parent_path / "flash-transport").is_dir(): + return parent_path + return module_path.parents[4] + + +def _iter_binary_candidates() -> list[pathlib.Path]: + candidates: list[pathlib.Path] = [] + if configured := os.environ.get(ENV_OVERRIDE): + candidates.append(pathlib.Path(configured)) + candidates.append(DEFAULT_BINARY_PATH) + repo_root = _hosting_repo_root() + candidates.append(repo_root / "flash-transport" / "target" / "debug" / BINARY_NAME) + candidates.append(repo_root / "flash-transport" / "target" / "release" / BINARY_NAME) + return candidates + + +def resolve_binary_path() -> pathlib.Path: + """Return the first existing candidate path, or raise ``FileNotFoundError``.""" + for candidate in _iter_binary_candidates(): + if candidate.exists(): + return candidate + + searched = "\n".join(f" - {candidate}" for candidate in _iter_binary_candidates()) + raise FileNotFoundError( + f"{BINARY_NAME} binary not found. Searched:\n" + f"{searched}\n" + f"Set {ENV_OVERRIDE} to override the path." + ) + + +def _args_to_argv(subcommand: str, args: Any) -> list[str]: + """Turn a dataclass of CLI args into a ``clap``-compatible argv list. + + Each field becomes ``--kebab-case-name value``. Fields are emitted in + declaration order. + """ + argv: list[str] = [subcommand] + for field in fields(args): + flag = "--" + field.name.replace("_", "-") + argv.extend([flag, str(getattr(args, field.name))]) + return argv + + +@dataclass(frozen=True) +class ServerArgs: + """Typed mirror of ``openpi-flash-transport server`` CLI flags.""" + + backend_socket_path: pathlib.Path + listen_port: int = _DEFAULT_QUIC_PORT + max_idle_timeout_secs: int = _DEFAULT_MAX_IDLE_TIMEOUT_SECS + keep_alive_interval_secs: int = _DEFAULT_KEEP_ALIVE_INTERVAL_SECS + initial_window_bytes: int = _DEFAULT_INITIAL_WINDOW_BYTES + + def to_argv(self) -> list[str]: + return _args_to_argv("server", self) + + +@dataclass(frozen=True) +class ClientArgs: + """Typed mirror of ``openpi-flash-transport client`` CLI flags.""" + + server_host: str + local_socket_path: pathlib.Path + server_port: int = _DEFAULT_QUIC_PORT + local_port: int = _DEFAULT_LOCAL_CLIENT_PORT + max_idle_timeout_secs: int = _DEFAULT_MAX_IDLE_TIMEOUT_SECS + keep_alive_interval_secs: int = _DEFAULT_KEEP_ALIVE_INTERVAL_SECS + initial_window_bytes: int = _DEFAULT_INITIAL_WINDOW_BYTES + + def to_argv(self) -> list[str]: + return _args_to_argv("client", self) diff --git a/packages/openpi-flash-transport/src/openpi_flash_transport/local_frame.py b/packages/openpi-flash-transport/src/openpi_flash_transport/local_frame.py new file mode 100644 index 0000000..b4f9552 --- /dev/null +++ b/packages/openpi-flash-transport/src/openpi_flash_transport/local_frame.py @@ -0,0 +1,227 @@ +"""Python codec for the LocalFrame binary format used over the local Unix socket. + +See ``docs/arrow-wire.md`` for the wire format. Provides the thin writer +and reader used on the Python side; the Rust transport translates these +frames to/from Arrow IPC Streaming Format for the QUIC wire. Mirrors +``flash-transport/src/local_format.rs``. + +The format intentionally avoids any serialization framework so encoding is +roughly ``ndarray.tobytes()`` + a handful of ``struct.pack`` calls. On the +decode side, ``np.frombuffer`` is used so tensor data is a view over the +received bytes rather than a fresh copy. +""" + +from __future__ import annotations + +import json +import struct +from collections.abc import Iterator +from typing import Any, Final + +import numpy as np + +# Mirrors the Rust `DtypeCode` enum in `flash-transport/src/local_format.rs`. +DTYPE_CODE_UINT8: Final[int] = 0x01 +DTYPE_CODE_INT8: Final[int] = 0x02 +DTYPE_CODE_UINT16: Final[int] = 0x03 +DTYPE_CODE_INT16: Final[int] = 0x04 +DTYPE_CODE_UINT32: Final[int] = 0x05 +DTYPE_CODE_INT32: Final[int] = 0x06 +DTYPE_CODE_UINT64: Final[int] = 0x07 +DTYPE_CODE_INT64: Final[int] = 0x08 +DTYPE_CODE_FLOAT16: Final[int] = 0x09 +DTYPE_CODE_FLOAT32: Final[int] = 0x0A +DTYPE_CODE_FLOAT64: Final[int] = 0x0B +DTYPE_CODE_BOOL: Final[int] = 0x0C + +_DTYPE_TO_CODE: Final[dict[np.dtype, int]] = { + np.dtype(np.uint8): DTYPE_CODE_UINT8, + np.dtype(np.int8): DTYPE_CODE_INT8, + np.dtype(np.uint16): DTYPE_CODE_UINT16, + np.dtype(np.int16): DTYPE_CODE_INT16, + np.dtype(np.uint32): DTYPE_CODE_UINT32, + np.dtype(np.int32): DTYPE_CODE_INT32, + np.dtype(np.uint64): DTYPE_CODE_UINT64, + np.dtype(np.int64): DTYPE_CODE_INT64, + np.dtype(np.float16): DTYPE_CODE_FLOAT16, + np.dtype(np.float32): DTYPE_CODE_FLOAT32, + np.dtype(np.float64): DTYPE_CODE_FLOAT64, + np.dtype(np.bool_): DTYPE_CODE_BOOL, +} + +_CODE_TO_DTYPE: Final[dict[int, np.dtype]] = {code: dtype for dtype, code in _DTYPE_TO_CODE.items()} + + +def pack_local_frame(payload: dict[str, Any], *, schema_id: str = "unknown") -> bytes: + """Serialize ``payload`` into the local frame binary format. + + Nested dicts are supported. Numpy arrays become array entries keyed by + their dict path; any other value goes into the scalar JSON trailer. + """ + arrays: list[tuple[list[str], np.ndarray]] = [] + scalars: dict[str, Any] = {} + + for path, value in _walk_payload(payload): + if isinstance(value, np.ndarray): + arrays.append((path, value)) + elif isinstance(value, (np.integer, np.floating, np.bool_)): + _insert_scalar(scalars, path, value.item()) + else: + _insert_scalar(scalars, path, value) + + schema_id_bytes = schema_id.encode("utf-8") + if len(schema_id_bytes) > 255: + raise ValueError(f"schema_id too long: {len(schema_id_bytes)} bytes (max 255)") + if len(arrays) > 0xFFFF: + raise ValueError(f"too many arrays: {len(arrays)} (max 65535)") + + parts: list[bytes] = [bytes([len(schema_id_bytes)]), schema_id_bytes] + parts.append(struct.pack(">H", len(arrays))) + + for path, array in arrays: + parts.append(_encode_array_entry(path, array)) + + scalar_json = json.dumps(scalars, separators=(",", ":"), ensure_ascii=False).encode("utf-8") + if len(scalar_json) > 0xFFFFFFFF: + raise ValueError(f"scalar_json too long: {len(scalar_json)} bytes") + parts.append(struct.pack(">I", len(scalar_json))) + parts.append(scalar_json) + + return b"".join(parts) + + +def unpack_local_frame(frame: bytes) -> dict[str, Any]: + """Deserialize a local frame into the nested dict the caller sent. + + Numpy arrays are reconstructed via ``np.frombuffer``, so the returned + arrays are views over ``frame``. Callers must either consume the arrays + before ``frame`` is freed or copy them with ``.copy()``. + """ + frame_memoryview = memoryview(frame) + offset = 0 + + schema_id_len = frame_memoryview[offset] + offset += 1 + _ = frame_memoryview[offset : offset + schema_id_len].tobytes().decode("utf-8") + offset += schema_id_len + + (num_arrays,) = struct.unpack_from(">H", frame_memoryview, offset) + offset += 2 + + result: dict[str, Any] = {} + for _ in range(num_arrays): + path_depth = frame_memoryview[offset] + offset += 1 + if path_depth == 0: + raise ValueError("Array path_depth must be >= 1") + path = [] + for _ in range(path_depth): + component_len = frame_memoryview[offset] + offset += 1 + component = frame_memoryview[offset : offset + component_len].tobytes().decode("utf-8") + offset += component_len + path.append(component) + + dtype_code = frame_memoryview[offset] + offset += 1 + if dtype_code not in _CODE_TO_DTYPE: + raise ValueError(f"Unknown dtype code: 0x{dtype_code:02x}") + dtype = _CODE_TO_DTYPE[dtype_code] + + ndim = frame_memoryview[offset] + offset += 1 + shape = struct.unpack_from(f">{ndim}I", frame_memoryview, offset) + offset += 4 * ndim + + (data_len,) = struct.unpack_from(">Q", frame_memoryview, offset) + offset += 8 + array = np.frombuffer(frame, dtype=dtype, count=int(np.prod(shape)), offset=offset).reshape( + shape + ) + offset += data_len + + _insert_array(result, path, array) + + (scalar_json_len,) = struct.unpack_from(">I", frame_memoryview, offset) + offset += 4 + scalar_json = frame_memoryview[offset : offset + scalar_json_len].tobytes() + offset += scalar_json_len + if offset != len(frame): + raise ValueError(f"Trailing {len(frame) - offset} bytes in local frame") + + scalars = json.loads(scalar_json.decode("utf-8")) if scalar_json else {} + _merge_scalars(result, scalars) + return result + + +def _encode_array_entry(path: list[str], array: np.ndarray) -> bytes: + if len(path) == 0 or len(path) > 255: + raise ValueError(f"path must have 1..255 components, got {len(path)}") + if array.dtype not in _DTYPE_TO_CODE: + raise ValueError(f"Unsupported numpy dtype: {array.dtype}") + if array.ndim > 255: + raise ValueError(f"array ndim too large: {array.ndim}") + for dim in array.shape: + if dim < 0 or dim > 0xFFFFFFFF: + raise ValueError(f"array shape dimension out of u32 range: {dim}") + + array_contiguous = np.ascontiguousarray(array) + data_bytes = array_contiguous.tobytes() + + parts: list[bytes] = [bytes([len(path)])] + for component in path: + component_bytes = component.encode("utf-8") + if len(component_bytes) > 255: + raise ValueError(f"path component too long: {component!r}") + parts.append(bytes([len(component_bytes)])) + parts.append(component_bytes) + + parts.append(bytes([_DTYPE_TO_CODE[array_contiguous.dtype]])) + parts.append(bytes([array_contiguous.ndim])) + parts.append(struct.pack(f">{array_contiguous.ndim}I", *array_contiguous.shape)) + parts.append(struct.pack(">Q", len(data_bytes))) + parts.append(data_bytes) + + return b"".join(parts) + + +def _walk_payload( + payload: dict[str, Any], path: list[str] | None = None +) -> Iterator[tuple[list[str], Any]]: + path = path or [] + for key, value in payload.items(): + if not isinstance(key, str): + raise TypeError(f"Observation dict keys must be str, got {type(key).__name__}") + new_path = [*path, key] + if isinstance(value, dict): + yield from _walk_payload(value, new_path) + else: + yield new_path, value + + +def _insert_scalar(scalars: dict[str, Any], path: list[str], value: Any) -> None: + cursor = scalars + for component in path[:-1]: + cursor = cursor.setdefault(component, {}) + if not isinstance(cursor, dict): + raise ValueError(f"Scalar path {path} collides with an existing non-dict value") + cursor[path[-1]] = value + + +def _insert_array(result: dict[str, Any], path: list[str], array: np.ndarray) -> None: + cursor = result + for component in path[:-1]: + cursor = cursor.setdefault(component, {}) + if not isinstance(cursor, dict): + raise ValueError(f"Array path {path} collides with an existing non-dict value") + cursor[path[-1]] = array + + +def _merge_scalars(result: dict[str, Any], scalars: dict[str, Any]) -> None: + for key, value in scalars.items(): + if isinstance(value, dict) and isinstance(result.get(key), dict): + _merge_scalars(result[key], value) + elif key in result and not isinstance(value, dict): + raise ValueError(f"Scalar key {key!r} collides with an existing array") + else: + result[key] = value diff --git a/packages/openpi-flash-transport/src/openpi_flash_transport/local_transport_protocol.py b/packages/openpi-flash-transport/src/openpi_flash_transport/local_transport_protocol.py new file mode 100644 index 0000000..cce0727 --- /dev/null +++ b/packages/openpi-flash-transport/src/openpi_flash_transport/local_transport_protocol.py @@ -0,0 +1,64 @@ +"""Shared framed protocol helpers for the local transport Unix socket. + +Both the server-side and client-side transport processes speak the same +length-prefixed protocol over a Unix domain socket to their Python peers. +This module holds the shared message type constants and framing helpers +for the Python side. +""" + +from __future__ import annotations + +import socket +import struct +from enum import IntEnum + + +class TransportRequestType(IntEnum): + """Request message types sent to the transport over the local socket.""" + + METADATA = 0x01 + INFER = 0x02 + RESET = 0x03 + + +class TransportResponseType(IntEnum): + """Response message types returned by the transport over the local socket.""" + + METADATA = 0x11 + INFER = 0x12 + ERROR = 0x13 + RESET = 0x14 + + +def recv_exactly(stream_socket: socket.socket, num_bytes: int) -> bytes | None: + """Read exactly ``num_bytes`` from a stream socket or return ``None`` on EOF.""" + received_chunks = bytearray() + while len(received_chunks) < num_bytes: + chunk = stream_socket.recv(num_bytes - len(received_chunks)) + if not chunk: + return None + received_chunks.extend(chunk) + return bytes(received_chunks) + + +def recv_framed_message(stream_socket: socket.socket) -> bytes | None: + """Receive one length-prefixed message from a stream socket.""" + raw_length_prefix = recv_exactly(stream_socket, 4) + if raw_length_prefix is None: + return None + + message_length = struct.unpack(">I", raw_length_prefix)[0] + if message_length == 0: + return b"" + + payload = recv_exactly(stream_socket, message_length) + if payload is None: + raise ConnectionError("Unexpected EOF while reading framed transport message") + return payload + + +def send_framed_message(stream_socket: socket.socket, payload: bytes) -> None: + """Send one length-prefixed message over a stream socket.""" + stream_socket.sendall(struct.pack(">I", len(payload))) + if payload: + stream_socket.sendall(payload) diff --git a/pyproject.toml b/pyproject.toml index f44ffc2..2ec09ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,10 +22,14 @@ dependencies = [ "typer>=0.24.1", "google-genai>=1.55.0", "openai>=1.50.0", + "openpi-flash-transport", "tenacity>=9.1.4", ] [tool.uv] +# Supply-chain hardening: refuse PyPI releases younger than 7 days. +# See: https://docs.astral.sh/uv/concepts/resolution/#exclude-newer +exclude-newer = "7 days" override-dependencies = [ "jax[cuda12]==0.5.3; sys_platform == 'linux'", "jax==0.5.3; sys_platform != 'linux'", @@ -34,6 +38,8 @@ override-dependencies = [ [tool.uv.sources] openpi = { path = "../openpi", editable = true } openpi-client = { path = "../openpi/packages/openpi-client", editable = true } +openpi-flash-client = { path = "packages/openpi-flash-client", editable = true } +openpi-flash-transport = { path = "packages/openpi-flash-transport", editable = true } quic-portal = { git = "https://github.com/Hebbian-Robotics/quic-portal.git" } [build-system] @@ -42,6 +48,7 @@ build-backend = "hatchling.build" [dependency-groups] dev = [ + "openpi-flash-client", "ruff>=0.15.11", # "ty>=0.0.32", "pytest>=9.0.3", diff --git a/src/hosting/flash_transport_binary.py b/src/hosting/flash_transport_binary.py index 17a5a71..1fe9fac 100644 --- a/src/hosting/flash_transport_binary.py +++ b/src/hosting/flash_transport_binary.py @@ -1,107 +1,21 @@ -"""Locate and invoke the openpi-flash-transport binary. - -Used by both the backend (``serve.py``) and the client policy -(``flash_transport_policy.py``). Resolution order for the binary: - -1. ``OPENPI_FLASH_TRANSPORT_BINARY`` env var override. -2. The standard Docker install path at ``/usr/local/bin/``. -3. Local cargo build output (``flash-transport/target/{debug,release}/...``) - — useful for the client-side developer loop where it isn't installed - globally. - -Also holds ``ServerArgs`` / ``ClientArgs`` — typed mirrors of the Rust -``clap`` structs in ``flash-transport/src/main.rs``. Python callers -construct one of these dataclasses instead of hand-building argv strings, -so a Rust flag rename becomes a type error on the Python side. -""" - -from __future__ import annotations - -import os -import pathlib -from dataclasses import dataclass, fields -from typing import Any - -BINARY_NAME = "openpi-flash-transport" -DEFAULT_BINARY_PATH = pathlib.Path(f"/usr/local/bin/{BINARY_NAME}") -ENV_OVERRIDE = "OPENPI_FLASH_TRANSPORT_BINARY" - -# Defaults shared by both subcommands. Kept in sync with -# ``flash-transport/src/main.rs`` (see ``ServerArgs`` / ``ClientArgs``). -_DEFAULT_MAX_IDLE_TIMEOUT_SECS = 10 -_DEFAULT_KEEP_ALIVE_INTERVAL_SECS = 2 -_DEFAULT_INITIAL_WINDOW_BYTES = 1024 * 1024 -_DEFAULT_QUIC_PORT = 5555 -_DEFAULT_LOCAL_CLIENT_PORT = 5556 - - -def _hosting_repo_root() -> pathlib.Path: - return pathlib.Path(__file__).resolve().parents[2] - - -def _iter_binary_candidates() -> list[pathlib.Path]: - candidates: list[pathlib.Path] = [] - if configured := os.environ.get(ENV_OVERRIDE): - candidates.append(pathlib.Path(configured)) - candidates.append(DEFAULT_BINARY_PATH) - repo_root = _hosting_repo_root() - candidates.append(repo_root / "flash-transport" / "target" / "debug" / BINARY_NAME) - candidates.append(repo_root / "flash-transport" / "target" / "release" / BINARY_NAME) - return candidates - - -def resolve_binary_path() -> pathlib.Path: - """Return the first existing candidate path, or raise ``FileNotFoundError``.""" - for candidate in _iter_binary_candidates(): - if candidate.exists(): - return candidate - - searched = "\n".join(f" - {candidate}" for candidate in _iter_binary_candidates()) - raise FileNotFoundError( - f"{BINARY_NAME} binary not found. Searched:\n" - f"{searched}\n" - f"Set {ENV_OVERRIDE} to override the path." - ) - - -def _args_to_argv(subcommand: str, args: Any) -> list[str]: - """Turn a dataclass of CLI args into a ``clap``-compatible argv list. - - Each field becomes ``--kebab-case-name value``. Fields are emitted in - declaration order. - """ - argv: list[str] = [subcommand] - for field in fields(args): - flag = "--" + field.name.replace("_", "-") - argv.extend([flag, str(getattr(args, field.name))]) - return argv - - -@dataclass(frozen=True) -class ServerArgs: - """Typed mirror of ``openpi-flash-transport server`` CLI flags.""" - - backend_socket_path: pathlib.Path - listen_port: int = _DEFAULT_QUIC_PORT - max_idle_timeout_secs: int = _DEFAULT_MAX_IDLE_TIMEOUT_SECS - keep_alive_interval_secs: int = _DEFAULT_KEEP_ALIVE_INTERVAL_SECS - initial_window_bytes: int = _DEFAULT_INITIAL_WINDOW_BYTES - - def to_argv(self) -> list[str]: - return _args_to_argv("server", self) - - -@dataclass(frozen=True) -class ClientArgs: - """Typed mirror of ``openpi-flash-transport client`` CLI flags.""" - - server_host: str - local_socket_path: pathlib.Path - server_port: int = _DEFAULT_QUIC_PORT - local_port: int = _DEFAULT_LOCAL_CLIENT_PORT - max_idle_timeout_secs: int = _DEFAULT_MAX_IDLE_TIMEOUT_SECS - keep_alive_interval_secs: int = _DEFAULT_KEEP_ALIVE_INTERVAL_SECS - initial_window_bytes: int = _DEFAULT_INITIAL_WINDOW_BYTES - - def to_argv(self) -> list[str]: - return _args_to_argv("client", self) +"""Compatibility exports for the split openpi-flash transport package.""" + +from openpi_flash_transport.flash_transport_binary import ( + BINARY_NAME, + DEFAULT_BINARY_PATH, + ENV_OVERRIDE, + ClientArgs, + ServerArgs, + _iter_binary_candidates, + resolve_binary_path, +) + +__all__ = [ + "BINARY_NAME", + "DEFAULT_BINARY_PATH", + "ENV_OVERRIDE", + "ClientArgs", + "ServerArgs", + "_iter_binary_candidates", + "resolve_binary_path", +] diff --git a/src/hosting/flash_transport_policy.py b/src/hosting/flash_transport_policy.py index 34b6dfd..b906924 100644 --- a/src/hosting/flash_transport_policy.py +++ b/src/hosting/flash_transport_policy.py @@ -1,171 +1,5 @@ -"""Client policy backed by a local ``openpi-flash-transport`` subprocess. +"""Compatibility exports for the split openpi-flash client package.""" -This preserves the normal Python ``BasePolicy`` interface used by openpi -clients while moving QUIC transport, Arrow IPC codec, image preprocessing, -action chunking, and server-timing instrumentation into the transport -binary. +from openpi_flash_client.flash_transport_policy import FlashTransportPolicy -The QUIC path speaks Arrow IPC Streaming Format on the wire (the transport -binary owns the codec translation); ``openpi-client``'s WebSocket path is -the supported pure-Python msgpack alternative for customers who don't want -the transport binary as a dependency. -""" - -from __future__ import annotations - -import contextlib -import pathlib -import socket -import subprocess -import time -import uuid - -from openpi_client import base_policy as _base_policy -from openpi_client import msgpack_numpy -from quic_portal import QuicTransportOptions -from typing_extensions import override - -from hosting.flash_transport_binary import BINARY_NAME, ClientArgs, resolve_binary_path -from hosting.local_frame import pack_local_frame, unpack_local_frame -from hosting.local_transport_protocol import ( - TransportRequestType, - TransportResponseType, - recv_framed_message, - send_framed_message, -) - -DEFAULT_TRANSPORT_STARTUP_TIMEOUT_SECONDS = 30.0 -DEFAULT_TRANSPORT_POLL_INTERVAL_SECONDS = 0.1 - -# Unix sockets must fit in sun_path (104 bytes on macOS, 108 on Linux), so we -# can't use tempfile.gettempdir() here — macOS's default $TMPDIR is a long -# /var/folders/... path that overflows once the UUID filename is appended. -_UNIX_SOCKET_DIR = pathlib.Path("/tmp") - - -class FlashTransportPolicy(_base_policy.BasePolicy): - """Connects to a direct QUIC server through a local ``openpi-flash-transport`` subprocess.""" - - def __init__( - self, - host: str, - port: int = 5555, - local_port: int = 5556, - transport_options: QuicTransportOptions | None = None, - ) -> None: - if transport_options is not None: - raise ValueError(f"Custom transport_options are not supported by {BINARY_NAME} yet") - - self._socket_path = _UNIX_SOCKET_DIR / f"{BINARY_NAME}-client-{uuid.uuid4().hex}.sock" - self._transport_process = self._spawn_transport_process( - host=host, - port=port, - local_port=local_port, - socket_path=self._socket_path, - ) - self._transport_socket = self._connect_to_transport_socket(self._socket_path) - self._server_metadata = self._request_metadata() - - def _spawn_transport_process( - self, - *, - host: str, - port: int, - local_port: int, - socket_path: pathlib.Path, - ) -> subprocess.Popen[str]: - binary_path = resolve_binary_path() - args = ClientArgs( - server_host=host, - local_socket_path=socket_path, - server_port=port, - local_port=local_port, - ) - command = [str(binary_path), *args.to_argv()] - return subprocess.Popen(command, text=True) - - def _connect_to_transport_socket(self, socket_path: pathlib.Path) -> socket.socket: - wait_deadline = time.monotonic() + DEFAULT_TRANSPORT_STARTUP_TIMEOUT_SECONDS - while time.monotonic() < wait_deadline: - if self._transport_process.poll() is not None: - raise RuntimeError( - f"{BINARY_NAME} client exited before opening its local socket " - f"(exit_code={self._transport_process.returncode})" - ) - - if socket_path.exists(): - transport_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - try: - transport_socket.connect(str(socket_path)) - return transport_socket - except OSError: - transport_socket.close() - - time.sleep(DEFAULT_TRANSPORT_POLL_INTERVAL_SECONDS) - - raise TimeoutError(f"Timed out waiting for {BINARY_NAME} socket at {socket_path}") - - def _request(self, request_type: TransportRequestType, payload: bytes = b"") -> bytes: - send_framed_message(self._transport_socket, bytes([request_type]) + payload) - framed_response = recv_framed_message(self._transport_socket) - if framed_response is None: - raise ConnectionError(f"{BINARY_NAME} disconnected unexpectedly") - if not framed_response: - raise RuntimeError(f"Received empty response from {BINARY_NAME}") - - response_type = TransportResponseType(framed_response[0]) - response_body = framed_response[1:] - - if response_type == TransportResponseType.ERROR: - raise RuntimeError(f"Error from {BINARY_NAME}:\n{response_body.decode('utf-8')}") - if ( - request_type == TransportRequestType.METADATA - and response_type != TransportResponseType.METADATA - ): - raise RuntimeError(f"Unexpected metadata response type: {response_type!r}") - if ( - request_type == TransportRequestType.INFER - and response_type != TransportResponseType.INFER - ): - raise RuntimeError(f"Unexpected inference response type: {response_type!r}") - if ( - request_type == TransportRequestType.RESET - and response_type != TransportResponseType.RESET - ): - raise RuntimeError(f"Unexpected reset response type: {response_type!r}") - - return response_body - - def _request_metadata(self) -> dict: - # Metadata stays msgpack_numpy end to end — it's openpi's blob, - # forwarded verbatim through the handshake. - return msgpack_numpy.unpackb(self._request(TransportRequestType.METADATA)) - - def get_server_metadata(self) -> dict: - return self._server_metadata - - @override - def infer(self, obs: dict) -> dict: - frame = pack_local_frame(obs) - response_body = self._request(TransportRequestType.INFER, frame) - return unpack_local_frame(response_body) - - @override - def reset(self) -> None: - self._request(TransportRequestType.RESET) - - def close(self) -> None: - """Close the local socket and stop the subprocess.""" - with contextlib.suppress(OSError): - self._transport_socket.close() - - if self._transport_process.poll() is None: - self._transport_process.terminate() - try: - self._transport_process.wait(timeout=5) - except subprocess.TimeoutExpired: - self._transport_process.kill() - self._transport_process.wait(timeout=5) - - with contextlib.suppress(FileNotFoundError): - self._socket_path.unlink() +__all__ = ["FlashTransportPolicy"] diff --git a/src/hosting/local_frame.py b/src/hosting/local_frame.py index b4f9552..dfae3ed 100644 --- a/src/hosting/local_frame.py +++ b/src/hosting/local_frame.py @@ -1,227 +1,5 @@ -"""Python codec for the LocalFrame binary format used over the local Unix socket. +"""Compatibility exports for the split openpi-flash transport package.""" -See ``docs/arrow-wire.md`` for the wire format. Provides the thin writer -and reader used on the Python side; the Rust transport translates these -frames to/from Arrow IPC Streaming Format for the QUIC wire. Mirrors -``flash-transport/src/local_format.rs``. +from openpi_flash_transport.local_frame import pack_local_frame, unpack_local_frame -The format intentionally avoids any serialization framework so encoding is -roughly ``ndarray.tobytes()`` + a handful of ``struct.pack`` calls. On the -decode side, ``np.frombuffer`` is used so tensor data is a view over the -received bytes rather than a fresh copy. -""" - -from __future__ import annotations - -import json -import struct -from collections.abc import Iterator -from typing import Any, Final - -import numpy as np - -# Mirrors the Rust `DtypeCode` enum in `flash-transport/src/local_format.rs`. -DTYPE_CODE_UINT8: Final[int] = 0x01 -DTYPE_CODE_INT8: Final[int] = 0x02 -DTYPE_CODE_UINT16: Final[int] = 0x03 -DTYPE_CODE_INT16: Final[int] = 0x04 -DTYPE_CODE_UINT32: Final[int] = 0x05 -DTYPE_CODE_INT32: Final[int] = 0x06 -DTYPE_CODE_UINT64: Final[int] = 0x07 -DTYPE_CODE_INT64: Final[int] = 0x08 -DTYPE_CODE_FLOAT16: Final[int] = 0x09 -DTYPE_CODE_FLOAT32: Final[int] = 0x0A -DTYPE_CODE_FLOAT64: Final[int] = 0x0B -DTYPE_CODE_BOOL: Final[int] = 0x0C - -_DTYPE_TO_CODE: Final[dict[np.dtype, int]] = { - np.dtype(np.uint8): DTYPE_CODE_UINT8, - np.dtype(np.int8): DTYPE_CODE_INT8, - np.dtype(np.uint16): DTYPE_CODE_UINT16, - np.dtype(np.int16): DTYPE_CODE_INT16, - np.dtype(np.uint32): DTYPE_CODE_UINT32, - np.dtype(np.int32): DTYPE_CODE_INT32, - np.dtype(np.uint64): DTYPE_CODE_UINT64, - np.dtype(np.int64): DTYPE_CODE_INT64, - np.dtype(np.float16): DTYPE_CODE_FLOAT16, - np.dtype(np.float32): DTYPE_CODE_FLOAT32, - np.dtype(np.float64): DTYPE_CODE_FLOAT64, - np.dtype(np.bool_): DTYPE_CODE_BOOL, -} - -_CODE_TO_DTYPE: Final[dict[int, np.dtype]] = {code: dtype for dtype, code in _DTYPE_TO_CODE.items()} - - -def pack_local_frame(payload: dict[str, Any], *, schema_id: str = "unknown") -> bytes: - """Serialize ``payload`` into the local frame binary format. - - Nested dicts are supported. Numpy arrays become array entries keyed by - their dict path; any other value goes into the scalar JSON trailer. - """ - arrays: list[tuple[list[str], np.ndarray]] = [] - scalars: dict[str, Any] = {} - - for path, value in _walk_payload(payload): - if isinstance(value, np.ndarray): - arrays.append((path, value)) - elif isinstance(value, (np.integer, np.floating, np.bool_)): - _insert_scalar(scalars, path, value.item()) - else: - _insert_scalar(scalars, path, value) - - schema_id_bytes = schema_id.encode("utf-8") - if len(schema_id_bytes) > 255: - raise ValueError(f"schema_id too long: {len(schema_id_bytes)} bytes (max 255)") - if len(arrays) > 0xFFFF: - raise ValueError(f"too many arrays: {len(arrays)} (max 65535)") - - parts: list[bytes] = [bytes([len(schema_id_bytes)]), schema_id_bytes] - parts.append(struct.pack(">H", len(arrays))) - - for path, array in arrays: - parts.append(_encode_array_entry(path, array)) - - scalar_json = json.dumps(scalars, separators=(",", ":"), ensure_ascii=False).encode("utf-8") - if len(scalar_json) > 0xFFFFFFFF: - raise ValueError(f"scalar_json too long: {len(scalar_json)} bytes") - parts.append(struct.pack(">I", len(scalar_json))) - parts.append(scalar_json) - - return b"".join(parts) - - -def unpack_local_frame(frame: bytes) -> dict[str, Any]: - """Deserialize a local frame into the nested dict the caller sent. - - Numpy arrays are reconstructed via ``np.frombuffer``, so the returned - arrays are views over ``frame``. Callers must either consume the arrays - before ``frame`` is freed or copy them with ``.copy()``. - """ - frame_memoryview = memoryview(frame) - offset = 0 - - schema_id_len = frame_memoryview[offset] - offset += 1 - _ = frame_memoryview[offset : offset + schema_id_len].tobytes().decode("utf-8") - offset += schema_id_len - - (num_arrays,) = struct.unpack_from(">H", frame_memoryview, offset) - offset += 2 - - result: dict[str, Any] = {} - for _ in range(num_arrays): - path_depth = frame_memoryview[offset] - offset += 1 - if path_depth == 0: - raise ValueError("Array path_depth must be >= 1") - path = [] - for _ in range(path_depth): - component_len = frame_memoryview[offset] - offset += 1 - component = frame_memoryview[offset : offset + component_len].tobytes().decode("utf-8") - offset += component_len - path.append(component) - - dtype_code = frame_memoryview[offset] - offset += 1 - if dtype_code not in _CODE_TO_DTYPE: - raise ValueError(f"Unknown dtype code: 0x{dtype_code:02x}") - dtype = _CODE_TO_DTYPE[dtype_code] - - ndim = frame_memoryview[offset] - offset += 1 - shape = struct.unpack_from(f">{ndim}I", frame_memoryview, offset) - offset += 4 * ndim - - (data_len,) = struct.unpack_from(">Q", frame_memoryview, offset) - offset += 8 - array = np.frombuffer(frame, dtype=dtype, count=int(np.prod(shape)), offset=offset).reshape( - shape - ) - offset += data_len - - _insert_array(result, path, array) - - (scalar_json_len,) = struct.unpack_from(">I", frame_memoryview, offset) - offset += 4 - scalar_json = frame_memoryview[offset : offset + scalar_json_len].tobytes() - offset += scalar_json_len - if offset != len(frame): - raise ValueError(f"Trailing {len(frame) - offset} bytes in local frame") - - scalars = json.loads(scalar_json.decode("utf-8")) if scalar_json else {} - _merge_scalars(result, scalars) - return result - - -def _encode_array_entry(path: list[str], array: np.ndarray) -> bytes: - if len(path) == 0 or len(path) > 255: - raise ValueError(f"path must have 1..255 components, got {len(path)}") - if array.dtype not in _DTYPE_TO_CODE: - raise ValueError(f"Unsupported numpy dtype: {array.dtype}") - if array.ndim > 255: - raise ValueError(f"array ndim too large: {array.ndim}") - for dim in array.shape: - if dim < 0 or dim > 0xFFFFFFFF: - raise ValueError(f"array shape dimension out of u32 range: {dim}") - - array_contiguous = np.ascontiguousarray(array) - data_bytes = array_contiguous.tobytes() - - parts: list[bytes] = [bytes([len(path)])] - for component in path: - component_bytes = component.encode("utf-8") - if len(component_bytes) > 255: - raise ValueError(f"path component too long: {component!r}") - parts.append(bytes([len(component_bytes)])) - parts.append(component_bytes) - - parts.append(bytes([_DTYPE_TO_CODE[array_contiguous.dtype]])) - parts.append(bytes([array_contiguous.ndim])) - parts.append(struct.pack(f">{array_contiguous.ndim}I", *array_contiguous.shape)) - parts.append(struct.pack(">Q", len(data_bytes))) - parts.append(data_bytes) - - return b"".join(parts) - - -def _walk_payload( - payload: dict[str, Any], path: list[str] | None = None -) -> Iterator[tuple[list[str], Any]]: - path = path or [] - for key, value in payload.items(): - if not isinstance(key, str): - raise TypeError(f"Observation dict keys must be str, got {type(key).__name__}") - new_path = [*path, key] - if isinstance(value, dict): - yield from _walk_payload(value, new_path) - else: - yield new_path, value - - -def _insert_scalar(scalars: dict[str, Any], path: list[str], value: Any) -> None: - cursor = scalars - for component in path[:-1]: - cursor = cursor.setdefault(component, {}) - if not isinstance(cursor, dict): - raise ValueError(f"Scalar path {path} collides with an existing non-dict value") - cursor[path[-1]] = value - - -def _insert_array(result: dict[str, Any], path: list[str], array: np.ndarray) -> None: - cursor = result - for component in path[:-1]: - cursor = cursor.setdefault(component, {}) - if not isinstance(cursor, dict): - raise ValueError(f"Array path {path} collides with an existing non-dict value") - cursor[path[-1]] = array - - -def _merge_scalars(result: dict[str, Any], scalars: dict[str, Any]) -> None: - for key, value in scalars.items(): - if isinstance(value, dict) and isinstance(result.get(key), dict): - _merge_scalars(result[key], value) - elif key in result and not isinstance(value, dict): - raise ValueError(f"Scalar key {key!r} collides with an existing array") - else: - result[key] = value +__all__ = ["pack_local_frame", "unpack_local_frame"] diff --git a/src/hosting/local_policy_socket_server.py b/src/hosting/local_policy_socket_server.py index e0c8361..b929a79 100644 --- a/src/hosting/local_policy_socket_server.py +++ b/src/hosting/local_policy_socket_server.py @@ -16,10 +16,9 @@ from openpi_client import base_policy as _base_policy from openpi_client import msgpack_numpy - -from hosting.flash_transport_binary import BINARY_NAME -from hosting.local_frame import pack_local_frame, unpack_local_frame -from hosting.local_transport_protocol import ( +from openpi_flash_transport.flash_transport_binary import BINARY_NAME +from openpi_flash_transport.local_frame import pack_local_frame, unpack_local_frame +from openpi_flash_transport.local_transport_protocol import ( TransportRequestType, TransportResponseType, recv_framed_message, diff --git a/src/hosting/local_transport_protocol.py b/src/hosting/local_transport_protocol.py index cce0727..1eef706 100644 --- a/src/hosting/local_transport_protocol.py +++ b/src/hosting/local_transport_protocol.py @@ -1,64 +1,15 @@ -"""Shared framed protocol helpers for the local transport Unix socket. - -Both the server-side and client-side transport processes speak the same -length-prefixed protocol over a Unix domain socket to their Python peers. -This module holds the shared message type constants and framing helpers -for the Python side. -""" - -from __future__ import annotations - -import socket -import struct -from enum import IntEnum - - -class TransportRequestType(IntEnum): - """Request message types sent to the transport over the local socket.""" - - METADATA = 0x01 - INFER = 0x02 - RESET = 0x03 - - -class TransportResponseType(IntEnum): - """Response message types returned by the transport over the local socket.""" - - METADATA = 0x11 - INFER = 0x12 - ERROR = 0x13 - RESET = 0x14 - - -def recv_exactly(stream_socket: socket.socket, num_bytes: int) -> bytes | None: - """Read exactly ``num_bytes`` from a stream socket or return ``None`` on EOF.""" - received_chunks = bytearray() - while len(received_chunks) < num_bytes: - chunk = stream_socket.recv(num_bytes - len(received_chunks)) - if not chunk: - return None - received_chunks.extend(chunk) - return bytes(received_chunks) - - -def recv_framed_message(stream_socket: socket.socket) -> bytes | None: - """Receive one length-prefixed message from a stream socket.""" - raw_length_prefix = recv_exactly(stream_socket, 4) - if raw_length_prefix is None: - return None - - message_length = struct.unpack(">I", raw_length_prefix)[0] - if message_length == 0: - return b"" - - payload = recv_exactly(stream_socket, message_length) - if payload is None: - raise ConnectionError("Unexpected EOF while reading framed transport message") - return payload - - -def send_framed_message(stream_socket: socket.socket, payload: bytes) -> None: - """Send one length-prefixed message over a stream socket.""" - stream_socket.sendall(struct.pack(">I", len(payload))) - if payload: - stream_socket.sendall(payload) +"""Compatibility exports for the split openpi-flash transport package.""" + +from openpi_flash_transport.local_transport_protocol import ( + TransportRequestType, + TransportResponseType, + recv_framed_message, + send_framed_message, +) + +__all__ = [ + "TransportRequestType", + "TransportResponseType", + "recv_framed_message", + "send_framed_message", +] diff --git a/src/hosting/serve.py b/src/hosting/serve.py index 86d2c12..e810a6d 100644 --- a/src/hosting/serve.py +++ b/src/hosting/serve.py @@ -42,6 +42,12 @@ from openpi.shared import download as _download from openpi.training import config as _config from openpi_client import base_policy as _base_policy +from openpi_flash_transport.flash_transport_binary import ( + BINARY_NAME, + ENV_OVERRIDE, + ServerArgs, + resolve_binary_path, +) from hosting.admin_server import RuntimeConfig, start_admin_server from hosting.compile_mode import get_serving_pytorch_compile_mode @@ -52,12 +58,6 @@ SlotTransportConfig, load_config, ) -from hosting.flash_transport_binary import ( - BINARY_NAME, - ENV_OVERRIDE, - ServerArgs, - resolve_binary_path, -) from hosting.local_policy_socket_server import LocalPolicySocketServer from hosting.warmup import get_action_horizon, make_image_specs, make_warmup_observation diff --git a/tests/test_arrow_wire.py b/tests/test_arrow_wire.py index 742d975..cb257f4 100644 --- a/tests/test_arrow_wire.py +++ b/tests/test_arrow_wire.py @@ -19,8 +19,7 @@ import numpy as np import pytest - -from hosting.local_frame import pack_local_frame, unpack_local_frame +from openpi_flash_transport.local_frame import pack_local_frame, unpack_local_frame def _hosting_repo_root() -> pathlib.Path: diff --git a/tests/test_flash_transport_cli_drift.py b/tests/test_flash_transport_cli_drift.py index e1f2361..72a2992 100644 --- a/tests/test_flash_transport_cli_drift.py +++ b/tests/test_flash_transport_cli_drift.py @@ -3,7 +3,7 @@ Spawns the ``openpi-flash-transport`` binary with `` --help``, parses the ``--flag-name`` tokens out of the help text, and asserts the -Python dataclasses in ``hosting.flash_transport_binary`` have matching +Python dataclasses in ``openpi_flash_transport.flash_transport_binary`` have matching field names. Skipped when the binary isn't built (e.g. fresh checkout without cargo). @@ -18,8 +18,7 @@ from typing import Any import pytest - -from hosting.flash_transport_binary import ( +from openpi_flash_transport.flash_transport_binary import ( ClientArgs, ServerArgs, _iter_binary_candidates, diff --git a/tests/test_local_frame.py b/tests/test_local_frame.py index ea096a2..7370822 100644 --- a/tests/test_local_frame.py +++ b/tests/test_local_frame.py @@ -4,8 +4,7 @@ import numpy as np import pytest - -from hosting.local_frame import pack_local_frame, unpack_local_frame +from openpi_flash_transport.local_frame import pack_local_frame, unpack_local_frame def test_round_trip_droid_observation() -> None: diff --git a/uv.lock b/uv.lock index 72c0d5e..1f6100e 100644 --- a/uv.lock +++ b/uv.lock @@ -9,6 +9,10 @@ resolution-markers = [ "sys_platform == 'emscripten'", ] +[options] +exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer-span = "P7D" + [manifest] overrides = [ { name = "jax", marker = "sys_platform != 'linux'", specifier = "==0.5.3" }, @@ -524,7 +528,7 @@ wheels = [ [[package]] name = "dm-tree" -version = "0.1.10" +version = "0.1.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, @@ -532,12 +536,12 @@ dependencies = [ { name = "numpy" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5a/66/a3ec619d22b6baffa5ab853e8dc6ec9d0c837127948af59bb15b988d7312/dm_tree-0.1.10.tar.gz", hash = "sha256:22f37b599e01cc3402a17f79c257a802aebd8d326de05b54657650845956208a", size = 35748, upload-time = "2026-03-31T17:35:39.03Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a6/83/ce29720ccf934c6cfa9b9c95ebbe96558386e66886626066632b5e44afed/dm_tree-0.1.9.tar.gz", hash = "sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b", size = 35623, upload-time = "2025-01-30T20:45:37.13Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/dc/b01d0f70cde99b306731216a98287ba5926a50f27222f2ada0b99ad0911f/dm_tree-0.1.10-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8218af7b99701bb8b03001c82961dc2cf81d7a734958206d2ea1ede8fbbe2b5f", size = 314603, upload-time = "2026-03-31T17:35:10.052Z" }, - { url = "https://files.pythonhosted.org/packages/40/72/3bafa58492862360113c1cccb26747c7863d417271e1572bacb3c281162f/dm_tree-0.1.10-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cacef6180fcfef30bab2cac5164e753e2f7a2e60e5da0feb81f2d318416f8d98", size = 182657, upload-time = "2026-03-31T17:35:11.462Z" }, - { url = "https://files.pythonhosted.org/packages/78/10/587a2cdc05995069aa63b659d884eb3e58a3c86a5b4a00acdb7a316bddf3/dm_tree-0.1.10-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0e8907bb6809dc195be3af077e382126eaebe06c00f835d09ae26e36d2165ff", size = 185008, upload-time = "2026-03-31T17:35:12.838Z" }, - { url = "https://files.pythonhosted.org/packages/60/0e/08d938d84cbf791dde009b3d3a6637f27a0004235e700641a0ac038daac5/dm_tree-0.1.10-cp311-cp311-win_amd64.whl", hash = "sha256:a1c82dd4726a16ac6b6f7a77a5fb097ee396fd349ae301407eb5736f15b8fa16", size = 111472, upload-time = "2026-03-31T17:35:14.035Z" }, + { url = "https://files.pythonhosted.org/packages/ac/b6/2d2de9f8901ccc5b6f34aea678e732816853015b9d756c86efcec189bf4b/dm_tree-0.1.9-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7d7d784afaeb4b67d87d858261aaf02503939ddc1f09c4cca70728f9892ab004", size = 173561, upload-time = "2025-03-31T08:35:40.042Z" }, + { url = "https://files.pythonhosted.org/packages/3e/07/57459f32cf5683c25b596ab58f42a3305f91876c2f03d2fa6e9d0df75fcb/dm_tree-0.1.9-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c", size = 146926, upload-time = "2025-01-30T20:45:20.622Z" }, + { url = "https://files.pythonhosted.org/packages/e8/46/939fbf81177c7cb3b1e5ddebd696237b3be9520769cce882f064de497103/dm_tree-0.1.9-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7", size = 152851, upload-time = "2025-01-30T20:45:23.032Z" }, + { url = "https://files.pythonhosted.org/packages/35/3e/a46933e0157b0ac87619a754ce1a796b2afc6386fca7c11f95c010f40745/dm_tree-0.1.9-cp311-cp311-win_amd64.whl", hash = "sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf", size = 101522, upload-time = "2025-01-30T20:45:24.433Z" }, ] [[package]] @@ -2314,7 +2318,6 @@ dependencies = [ { name = "msgpack" }, { name = "numpy" }, { name = "pillow" }, - { name = "tree" }, { name = "websockets" }, ] @@ -2324,7 +2327,6 @@ requires-dist = [ { name = "msgpack", specifier = ">=1.0.5" }, { name = "numpy", specifier = ">=1.22.4,<2.0.0" }, { name = "pillow", specifier = ">=9.0.0" }, - { name = "tree", specifier = ">=0.2.4" }, { name = "websockets", specifier = ">=11.0" }, ] @@ -2345,6 +2347,7 @@ dependencies = [ { name = "openai" }, { name = "openpi" }, { name = "openpi-client" }, + { name = "openpi-flash-transport" }, { name = "pydantic" }, { name = "pydantic-settings" }, { name = "python-dotenv" }, @@ -2357,6 +2360,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "openpi-flash-client" }, { name = "pytest" }, { name = "ruff" }, { name = "ty" }, @@ -2373,6 +2377,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.50.0" }, { name = "openpi", editable = "../openpi" }, { name = "openpi-client", editable = "../openpi/packages/openpi-client" }, + { name = "openpi-flash-transport", editable = "packages/openpi-flash-transport" }, { name = "pydantic", specifier = ">=2.13.3" }, { name = "pydantic-settings", specifier = ">=2.14.0" }, { name = "python-dotenv", specifier = ">=1.2.2" }, @@ -2385,11 +2390,38 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "openpi-flash-client", editable = "packages/openpi-flash-client" }, { name = "pytest", specifier = ">=9.0.3" }, { name = "ruff", specifier = ">=0.15.11" }, { name = "ty", specifier = ">=0.0.32" }, ] +[[package]] +name = "openpi-flash-client" +version = "0.1.0" +source = { editable = "packages/openpi-flash-client" } +dependencies = [ + { name = "openpi-client" }, + { name = "openpi-flash-transport" }, +] + +[package.metadata] +requires-dist = [ + { name = "openpi-client", editable = "../openpi/packages/openpi-client" }, + { name = "openpi-flash-transport", editable = "packages/openpi-flash-transport" }, +] + +[[package]] +name = "openpi-flash-transport" +version = "0.1.0" +source = { editable = "packages/openpi-flash-transport" } +dependencies = [ + { name = "numpy" }, +] + +[package.metadata] +requires-dist = [{ name = "numpy", specifier = ">=1.22.4" }] + [[package]] name = "opt-einsum" version = "3.4.0" @@ -3319,15 +3351,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0b/c9/584bc9651441b4ba60cc4d557d8a547b5aff901af35bda3a4ee30c819b82/starlette-1.0.0-py3-none-any.whl", hash = "sha256:d3ec55e0bb321692d275455ddfd3df75fff145d009685eb40dc91fc66b03d38b", size = 72651, upload-time = "2026-03-22T18:29:45.111Z" }, ] -[[package]] -name = "svgwrite" -version = "1.4.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/16/c1/263d4e93b543390d86d8eb4fc23d9ce8a8d6efd146f9427364109004fa9b/svgwrite-1.4.3.zip", hash = "sha256:a8fbdfd4443302a6619a7f76bc937fc683daf2628d9b737c891ec08b8ce524c3", size = 189516, upload-time = "2022-07-14T14:05:26.107Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/84/15/640e399579024a6875918839454025bb1d5f850bb70d96a11eabb644d11c/svgwrite-1.4.3-py3-none-any.whl", hash = "sha256:bb6b2b5450f1edbfa597d924f9ac2dd099e625562e492021d7dd614f65f8a22d", size = 67122, upload-time = "2022-07-14T14:05:24.459Z" }, -] - [[package]] name = "sympy" version = "1.14.0" @@ -3535,18 +3558,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/96/88/beb33a79a382fcd2aed0be5222bdc47f41e4bfe7aaa90ae1374f1d8ea2af/transformers-4.53.2-py3-none-any.whl", hash = "sha256:db8f4819bb34f000029c73c3c557e7d06fc1b8e612ec142eecdae3947a9c78bf", size = 10826609, upload-time = "2025-07-11T12:39:05.461Z" }, ] -[[package]] -name = "tree" -version = "0.2.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click" }, - { name = "pillow" }, - { name = "setuptools" }, - { name = "svgwrite" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/29/3f/63cbed2909786f0e5ac30a4ae5791ad597c6b5fec7167e161c55bba511ce/Tree-0.2.4.tar.gz", hash = "sha256:f84d8ec9bf50dd69f551da78925a23d110864e7706551f590cdade27646f7883", size = 6489, upload-time = "2018-07-03T20:49:29.918Z" } - [[package]] name = "treescope" version = "0.1.10" @@ -3564,7 +3575,7 @@ name = "triton" version = "3.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform == 'linux'" }, + { name = "setuptools", marker = "platform_machine != 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/21/2f/3e56ea7b58f80ff68899b1dbe810ff257c9d177d288c6b0f55bf2fe4eb50/triton-3.3.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b31e3aa26f8cb3cc5bf4e187bf737cbacf17311e1112b781d4a059353dfd731b", size = 155689937, upload-time = "2025-05-29T23:39:44.182Z" }, From 01a3c3b89a18403b58ba4f5cb58eb80339243c0d Mon Sep 17 00:00:00 2001 From: Kingston Date: Tue, 12 May 2026 16:49:02 +0800 Subject: [PATCH 4/4] Wire flash-transport packages into Dockerfile The base stage's uv sync was bind-mounting only hosting/uv.lock, pyproject.toml, and the sibling openpi paths. After the split into packages/openpi-flash-transport and packages/openpi-flash-client, uv sync --frozen could no longer locate those editable sources and the CI base image build failed with: Distribution not found at: file:///build/hosting/packages/openpi-flash-transport Add bind mounts for each new package's pyproject.toml and src/ so the lockfile resolves, then COPY their src/ into /app at runtime and append the paths to PYTHONPATH so imports work past the bind-mount layer. Co-Authored-By: Claude Opus 4.7 (1M context) --- Dockerfile | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index cac2fa5..38f9745 100644 --- a/Dockerfile +++ b/Dockerfile @@ -53,6 +53,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \ --mount=type=bind,source=openpi/src,target=/build/openpi/src \ --mount=type=bind,source=openpi/packages/openpi-client/pyproject.toml,target=/build/openpi/packages/openpi-client/pyproject.toml \ --mount=type=bind,source=openpi/packages/openpi-client/src,target=/build/openpi/packages/openpi-client/src \ + --mount=type=bind,source=hosting/packages/openpi-flash-transport/pyproject.toml,target=/build/hosting/packages/openpi-flash-transport/pyproject.toml \ + --mount=type=bind,source=hosting/packages/openpi-flash-transport/src,target=/build/hosting/packages/openpi-flash-transport/src \ + --mount=type=bind,source=hosting/packages/openpi-flash-client/pyproject.toml,target=/build/hosting/packages/openpi-flash-client/pyproject.toml \ + --mount=type=bind,source=hosting/packages/openpi-flash-client/src,target=/build/hosting/packages/openpi-flash-client/src \ GIT_LFS_SKIP_SMUDGE=1 uv sync --project /build/hosting --frozen --no-install-project --no-dev # Copy transformers_replace files (required for PyTorch models). @@ -63,9 +67,11 @@ RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | x # Copy application code. COPY openpi/src /app/openpi-src COPY openpi/packages/openpi-client/src /app/openpi-client-src +COPY hosting/packages/openpi-flash-transport/src /app/openpi-flash-transport-src +COPY hosting/packages/openpi-flash-client/src /app/openpi-flash-client-src COPY hosting/src /app/hosting-src COPY hosting/main.py /app/main.py -ENV PYTHONPATH="/app/openpi-src:/app/openpi-client-src:/app/hosting-src" +ENV PYTHONPATH="/app/openpi-src:/app/openpi-client-src:/app/openpi-flash-transport-src:/app/openpi-flash-client-src:/app/hosting-src" # PyTorch inductor cache — persists within container lifetime (use a volume # mount at /cache for persistence across restarts).