Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,source=openpi/src,target=/build/openpi/src \
--mount=type=bind,source=openpi/packages/openpi-client/pyproject.toml,target=/build/openpi/packages/openpi-client/pyproject.toml \
--mount=type=bind,source=openpi/packages/openpi-client/src,target=/build/openpi/packages/openpi-client/src \
--mount=type=bind,source=hosting/packages/openpi-flash-transport/pyproject.toml,target=/build/hosting/packages/openpi-flash-transport/pyproject.toml \
--mount=type=bind,source=hosting/packages/openpi-flash-transport/src,target=/build/hosting/packages/openpi-flash-transport/src \
--mount=type=bind,source=hosting/packages/openpi-flash-client/pyproject.toml,target=/build/hosting/packages/openpi-flash-client/pyproject.toml \
--mount=type=bind,source=hosting/packages/openpi-flash-client/src,target=/build/hosting/packages/openpi-flash-client/src \
GIT_LFS_SKIP_SMUDGE=1 uv sync --project /build/hosting --frozen --no-install-project --no-dev

# Copy transformers_replace files (required for PyTorch models).
Expand All @@ -63,9 +67,11 @@ RUN /.venv/bin/python -c "import transformers; print(transformers.__file__)" | x
# Copy application code.
COPY openpi/src /app/openpi-src
COPY openpi/packages/openpi-client/src /app/openpi-client-src
COPY hosting/packages/openpi-flash-transport/src /app/openpi-flash-transport-src
COPY hosting/packages/openpi-flash-client/src /app/openpi-flash-client-src
COPY hosting/src /app/hosting-src
COPY hosting/main.py /app/main.py
ENV PYTHONPATH="/app/openpi-src:/app/openpi-client-src:/app/hosting-src"
ENV PYTHONPATH="/app/openpi-src:/app/openpi-client-src:/app/openpi-flash-transport-src:/app/openpi-flash-client-src:/app/hosting-src"

# PyTorch inductor cache — persists within container lifetime (use a volume
# mount at /cache for persistence across restarts).
Expand Down
1 change: 1 addition & 0 deletions infra/modules/regional_inference_instance/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ resource "aws_instance" "inference" {
checkpoint_prep_model_id = var.checkpoint_prep_model_id
checkpoint_prep_openpi_assets_uri = var.checkpoint_prep_openpi_assets_uri
checkpoint_prep_output_dir = var.checkpoint_prep_output_dir
checkpoint_prep_required_asset_id = var.checkpoint_prep_required_asset_id
container_name = var.container_name
ecr_region = var.ecr_region
ecr_registry_host = local.ecr_registry_host
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ write_files:
python main.py prepare-checkpoint \
--model-id ${checkpoint_prep_model_id} \
--openpi-assets-uri ${checkpoint_prep_openpi_assets_uri} \
--output-dir ${checkpoint_prep_output_dir}
--output-dir ${checkpoint_prep_output_dir} \
--required-asset-id ${checkpoint_prep_required_asset_id}
%{ endif ~}
%{ if prepare_planner_checkpoint ~}

Expand Down
6 changes: 6 additions & 0 deletions infra/modules/regional_inference_instance/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,12 @@ variable "checkpoint_prep_output_dir" {
default = "/cache/models/pi05_base_openpi"
}

variable "checkpoint_prep_required_asset_id" {
description = "Asset id whose norm_stats.json the Docker checkpoint preparation step must produce. Use 'trossen' for ALOHA, 'droid' for pi05_droid, 'physical-intelligence/libero' for pi05_libero."
type = string
default = "trossen"
}

# -- Planner checkpoint prep (planner slot) -----------------------------------
#
# Optional one-shot step that downloads + extracts a JAX subtask planner tar
Expand Down
1 change: 1 addition & 0 deletions infra/regional-instance/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ module "regional_inference_instance" {
checkpoint_prep_model_id = var.checkpoint_prep_model_id
checkpoint_prep_openpi_assets_uri = var.checkpoint_prep_openpi_assets_uri
checkpoint_prep_output_dir = var.checkpoint_prep_output_dir
checkpoint_prep_required_asset_id = var.checkpoint_prep_required_asset_id
container_name = var.container_name
deployment_name = var.deployment_name
docker_image_tag = var.docker_image_tag
Expand Down
6 changes: 6 additions & 0 deletions infra/regional-instance/variables.tf
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ variable "checkpoint_prep_output_dir" {
default = "/cache/models/pi05_base_openpi"
}

variable "checkpoint_prep_required_asset_id" {
description = "Asset id whose norm_stats.json the Docker checkpoint preparation step must produce. Use 'trossen' for ALOHA, 'droid' for pi05_droid, 'physical-intelligence/libero' for pi05_libero."
type = string
default = "trossen"
}

variable "openpi_pytorch_compile_mode" {
description = "Value for OPENPI_PYTORCH_COMPILE_MODE inside the inference container"
type = string
Expand Down
24 changes: 22 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import typer

from tests.helpers import EmbodimentChoice

app = typer.Typer(
name="openpi-flash",
help="Real-time inference engine for openpi. Serves policy models over QUIC and WebSocket.",
Expand Down Expand Up @@ -43,6 +45,16 @@ def prepare_checkpoint(
bool,
typer.Option(help="Re-download upstream files and rebuild the prepared checkpoint."),
] = False,
required_asset_id: Annotated[
str,
typer.Option(
help=(
"Asset id whose norm_stats.json must exist in the prepared "
"checkpoint's assets/ directory. Defaults to 'trossen' (ALOHA). "
"Use 'droid' for pi05_droid or 'physical-intelligence/libero' for pi05_libero."
),
),
] = "trossen",
) -> None:
"""Prepare a local OpenPI-compatible checkpoint from upstream sources."""
from hosting.prepare_checkpoint import prepare_openpi_compatible_checkpoint
Expand All @@ -52,6 +64,7 @@ def prepare_checkpoint(
openpi_assets_uri=openpi_assets_uri,
output_dir=output_dir,
force_download=force_download,
required_asset_id=required_asset_id,
)


Expand Down Expand Up @@ -119,17 +132,23 @@ def serve(
"of a combined-mode server: 'action_only' skips the planner, 'subtask_only' "
"skips the action policy. Omit for the server default."
)
_EMBODIMENT_OPTION_HELP = (
"Observation shape to send. Must match the embodiment the server's action "
"checkpoint was trained on (aloha for pi05_aloha, droid for pi05_droid, "
"libero for pi05_libero)."
)


@test_app.command(name="ws")
def test_websocket(
url: Annotated[str, typer.Argument(help="WebSocket URL (ws://host:port or wss://host).")],
mode: Annotated[InferenceModeChoice | None, typer.Option(help=_MODE_OPTION_HELP)] = None,
embodiment: Annotated[EmbodimentChoice, typer.Option(help=_EMBODIMENT_OPTION_HELP)] = "aloha",
) -> None:
"""Smoke test against a WebSocket server (EC2, Docker, or Modal ASGI)."""
from tests.test_ws import run

run(url, mode=mode)
run(url, mode=mode, embodiment=embodiment)


@test_app.command(name="quic")
Expand All @@ -140,11 +159,12 @@ def test_quic(
int, typer.Option(help="Server WebSocket/TCP port (for health check).")
] = 8000,
mode: Annotated[InferenceModeChoice | None, typer.Option(help=_MODE_OPTION_HELP)] = None,
embodiment: Annotated[EmbodimentChoice, typer.Option(help=_EMBODIMENT_OPTION_HELP)] = "aloha",
) -> None:
"""Smoke test against a direct QUIC server (EC2 or Docker)."""
from tests.test_quic import run

run(host, quic_port=quic_port, ws_port=ws_port, mode=mode)
run(host, quic_port=quic_port, ws_port=ws_port, mode=mode, embodiment=embodiment)


@test_app.command(name="modal-tunnel")
Expand Down
51 changes: 51 additions & 0 deletions packages/openpi-flash-client/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
[project]
name = "openpi-flash-client"
version = "0.1.0"
description = "Client policy adapter for openpi-flash transport servers."
requires-python = ">=3.11,<3.13"
dependencies = [
"openpi-client",
"openpi-flash-transport",
]

[tool.uv]
# Supply-chain hardening: refuse PyPI releases younger than 7 days.
# See: https://docs.astral.sh/uv/concepts/resolution/#exclude-newer
exclude-newer = "7 days"

[tool.uv.sources]
openpi-client = { path = "../../../openpi/packages/openpi-client", editable = true }
openpi-flash-transport = { path = "../openpi-flash-transport", editable = true }

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/openpi_flash_client"]

[tool.ruff]
line-length = 100
target-version = "py311"

[tool.ruff.lint]
select = [
"E",
"W",
"F",
"I",
"B",
"C4",
"UP",
"ANN",
"PTH",
"RET",
"SIM",
"TID",
"RUF",
]
ignore = ["E501", "B008", "C901", "RET504", "ANN401"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Client policy adapters for openpi-flash."""

from openpi_flash_client.flash_transport_policy import FlashTransportPolicy

__all__ = ["FlashTransportPolicy"]
Loading
Loading