feat(policies): add CosmosPredictPolicy for Cosmos-Predict2.5/robot/policy#163
feat(policies): add CosmosPredictPolicy for Cosmos-Predict2.5/robot/policy#163cagataycali wants to merge 6 commits into
Conversation
The CosmosPredictPolicy uses requests for server-mode health checks and inference. Without it in the [cosmos] optional dependency, tests fail with ModuleNotFoundError and mypy reports import-untyped. Fixes CI on PR strands-labs#163.
Review notes (automated review)Nice work — the structure mirrors Worth a second look
Nits
Confirmed (no action)
Happy to approve once the trust-gate comment / |
…olicy
Adds NVIDIA Cosmos-Predict2.5 as a policy provider, implementing
the Policy interface for action-chunk prediction via latent diffusion.
The model predicts 16-step, 7-DoF action chunks from camera observations
+ proprioception + language instruction using rectified flow denoising.
Key features:
- Supports libero, robocasa, and aloha evaluation suites
- Local mode: loads cosmos-predict2 directly (requires CUDA GPU)
- Server mode: HTTP client for env isolation (Python 3.10 server)
- Auto-resolves dataset stats and T5 embeddings from HF checkpoints
- Pattern-based camera key matching (works with any naming convention)
- Registered in policies.json with trust_remote_code gate
Dependencies:
- cosmos-predict2 package (from nvidia-cosmos/cosmos-predict2.5)
- Added [cosmos] optional extras in pyproject.toml
- Added cosmos_predict to _HF_REMOTE_CODE_PROVIDERS frozenset
Usage:
from strands_robots.policies import create_policy
policy = create_policy('cosmos_predict',
model_id='nvidia/Cosmos-Policy-LIBERO-Predict2-2B',
suite='libero')
Reference: arXiv:2511.00062, github.com/nvidia-cosmos/cosmos-predict2.5
The CosmosPredictPolicy uses requests for server-mode health checks and inference. Without it in the [cosmos] optional dependency, tests fail with ModuleNotFoundError and mypy reports import-untyped. Fixes CI on PR strands-labs#163.
When the [cosmos] extras pull in torch+transformers, importing lerobot policy classes can trigger a TypeError from draccus dataclass processing (non-default argument follows default). This is not a bug in our code but a version-specific interaction. Widen the expected exception set so CI passes regardless of which extras are installed.
… resolution
lerobot's internal dataclasses can raise TypeError ('non-default argument
follows default argument') when imported in newer versions. This causes
resolve_policy_class_by_name to propagate TypeError instead of falling
through to ImportError. Catch TypeError alongside ImportError in all
resolution strategies so the function degrades gracefully.
0cc1ce5 to
904809c
Compare
yinsong1986
left a comment
There was a problem hiding this comment.
Summary
Adds CosmosPredictPolicy as a new Policy provider, mirroring the structure of policies/groot/ and policies/lerobot_local/. The policy supports two execution modes: a local mode that loads cosmos-predict2 directly (CUDA + Python 3.10) and an HTTP server mode for environment isolation. Eighteen unit tests cover init, observation building, action decoding, server-mode mocking and the trust-remote-code gate. The PR also threads cosmos_predict into _HF_REMOTE_CODE_PROVIDERS, registers it in policies.json, and incidentally widens three except ImportError clauses in lerobot_local/resolution.py to also catch TypeError.
What's good
- Mirrors existing provider patterns (Gr00t/Lerobot) for layout and the lazy-load contract.
- Server mode keeps the GPU-only dependency out of the default install path.
- Trust-remote-code opt-in is wired in, with both positive and negative tests.
- No host paths, no emojis in user-facing strings, AGENTS.local.md is not committed.
- Reasonable mock-based unit coverage that keeps CI GPU-free.
Concerns
- Scope creep, lightly. The lerobot_local
except ImportError->except (ImportError, TypeError)widening and the test-side mirror intests/policies/lerobot_local/test_policy.pyare unrelated to Cosmos. The commit message says it fixes aTypeErrorfrom "dataclass field ordering" but no regression test that fails on pre-fix code was added. AGENTS.md > Review Learnings (#85) > 'Pin regression tests for reviewed fixes' applies. Consider splitting that into its own PR with a targeted test, or adding one here. - No
tests_integ/coverage. AGENTS.md > Key Conventions #8 says "each policy needstests_integ/tests with real inference." Even apytest.importorskip("cosmos_predict2")-gated smoke test (load -> singleget_actions()-> shape assertion) would catch upstream-API drift before users hit it. require_optional()not used.strands_robots/utils.py:require_optional()is the documented helper for optional deps (AGENTS.md > Key Conventions #7); the file uses ad-hoctry/except ImportErrorblocks fortorch,cosmos_predict2,requests, andhuggingface_hub. Not blocking, but inconsistent with the rest of the codebase.- Server mode dependency contract.
_verify_serverand_infer_serverimport requestsat function scope, butrequestsis only declared in the[cosmos]extra. A user picking server-only mode still has to install the full[cosmos](torch + transformers + accelerate, ~5GB) just to make HTTP calls. Consider a[cosmos-client]extra that pulls onlyrequests, mirroring thegroot-servicesplit. - Docstring drift around
trust_remote_code._load_local_modeldoes not actually passtrust_remote_code=Trueto anything visible in this diff -- it callscosmos_get_model(cfg). The factory still gates the provider as if it does. Either it does (in which case worth a code comment pointing at where), or the gate is overcautious for the local path (server mode never loads code).
Verification suggestions
# Quick sanity check that the registry collision does what you expect:
STRANDS_TRUST_REMOTE_CODE=1 python -c "\
from strands_robots.policies import create_policy; \
p = create_policy('nvidia/some-other-model'); \
print(type(p).__name__)"
# Expected: CosmosPredictPolicy if you intend nvidia/* to default to Cosmos,
# else Gr00tPolicy. (Currently resolves to Gr00tPolicy due to dict iteration order.)
# Round-trip the action decoder against a 7-DoF arm:
python -c "\
import numpy as np; \
from strands_robots.policies.cosmos_predict.policy import CosmosPredictPolicy as P; \
p = P(server_url='http://x'); \
p.set_robot_state_keys(['j0','j1','j2','j3','j4','j5','j6']); \
print(p._vec_to_action_dict(np.arange(7, dtype=np.float32)))"
# Expected: every joint (j0..j6) present plus 'gripper'.
# Actual: j6 missing.| if j < len(action_vec) - 1: | ||
| action_dict[key] = float(action_vec[j]) | ||
| if len(action_vec) > 0: | ||
| action_dict["gripper"] = float(action_vec[-1]) |
There was a problem hiding this comment.
Last joint silently dropped when len(robot_state_keys) >= len(action_vec).
With robot_state_keys=['j0',...,'j6'] (7 keys) and a 7-DoF action vector, the loop condition j < len(action_vec) - 1 (== j < 6) maps j0..j5 -> action_vec[0..5] and then sets gripper = action_vec[-1]. j6 is never written, so the last joint command is discarded and the gripper is double-counted. The unit test happens to use 6 state keys + a 7-element vector, which masks the off-by-one.
The intent ("last element is gripper, the rest are joints") needs explicit handling of three cases: more keys than action_vec, fewer keys, or equal:
if self._robot_state_keys:
n_joints = len(action_vec) - 1 # last element reserved for gripper
for j, key in enumerate(self._robot_state_keys[:n_joints]):
action_dict[key] = float(action_vec[j])
if len(action_vec) > 0:
action_dict["gripper"] = float(action_vec[-1])Add a regression test with set_robot_state_keys of length == action_dim (the realistic 7-DoF case) and assert every joint key appears in the output. AGENTS.md > Review Learnings (#85) > 'Per-name state copy, not flat index' is the same shape of bug.
| resp = requests.get(f"{self._server_url}/health", timeout=5) | ||
| resp.raise_for_status() | ||
| logger.info("Cosmos server connected: %s", self._server_url) | ||
| except Exception as e: |
There was a problem hiding this comment.
Bare except Exception is forbidden by AGENTS.md > Review Learnings (#86) > 'Exception Clauses Must Be Narrow'. Same pattern repeats at line 283 (_resolve_dataset_stats).
For the _verify_server health check, the realistic failure modes are network-level:
except (requests.exceptions.RequestException, OSError) as e:
logger.warning(...)For _resolve_dataset_stats (snapshot_download + JSON parse), use (OSError, ValueError, huggingface_hub.utils.HfHubHTTPError) or similar -- whatever the narrow set of recoverable failures is. A blanket except Exception will swallow KeyboardInterrupt-adjacent issues (no, KeyboardInterrupt inherits from BaseException, but it will swallow programmer errors like TypeError / AttributeError introduced by future refactors of loader_fn's contract).
|
|
||
| for key, val in observation_dict.items(): | ||
| if isinstance(val, np.ndarray): | ||
| payload[key] = val.tolist() |
There was a problem hiding this comment.
Server-mode camera serialisation is a performance trap.
val.tolist() on a 224x224x3 uint8 image expands ~150 KB of dense bytes into a Python list-of-lists-of-ints, which requests.post(json=...) then JSON-encodes to ~1.5 MB of ASCII. At 16-step chunk + multiple cameras + a control loop, this dominates per-step latency and is allocation-pressure-heavy on both client and server.
Prior art in the repo: Gr00tPolicy uses msgpack over ZMQ for exactly this reason. Even staying on HTTP, base64 + raw bytes (numpy.savez_compressed-into-BytesIO -> base64) is an order of magnitude smaller and faster. At minimum, document the throughput limit so users don't reach for server mode in a hot loop.
No regression test will catch this because the unit test uses np.zeros((224,224,3)) and a mock requests.post.
| "cosmos-predict" | ||
| ], | ||
| "hf_orgs": [ | ||
| "nvidia" |
There was a problem hiding this comment.
hf_orgs: ["nvidia"] collides with the groot provider (line 38).
In strands_robots/registry/policies.py:resolve_policy, the model_id_overrides loop runs first (so the two specific Cosmos checkpoints route correctly), but for any other nvidia/* model the fallback hf_orgs loop iterates dict insertion order and matches groot first. So create_policy("nvidia/Cosmos-Reason2-7B") or any new Cosmos checkpoint NVIDIA publishes silently routes to Gr00tPolicy instead of CosmosPredictPolicy. This is exactly the "silently dropped kwargs / silently mis-routed" anti-pattern from AGENTS.md > Review Learnings (#86) > 'Public API Hygiene'.
Options: (a) tighten hf_orgs to a more specific prefix string (the registry doesn't currently support that), (b) widen model_id_overrides to include "nvidia/Cosmos-", or (c) add an explicit tie-breaker in resolve_policy. (b) is the smallest fix and matches the format of model_id_overrides already in this file.
| "transformers>=4.40.0", | ||
| "huggingface-hub>=0.20.0", | ||
| "accelerate>=0.25.0", | ||
| "requests>=2.28.0,<3.0.0", |
There was a problem hiding this comment.
Missing upper bounds on >=1.0 deps. AGENTS.md > Key Conventions #2: ">=1.0 deps: cap major."
torch>=2.0.0, torchvision>=0.15.0, transformers>=4.40.0, huggingface-hub>=0.20.0, accelerate>=0.25.0 are all unbounded. transformers and torch make breaking changes between minor versions; uncapped, a pip install strands-robots[cosmos] six months from now can silently pull a future torch 4.x that breaks cosmos-predict2's pinned ABI.
Proposed:
"torch>=2.0.0,<3.0.0",
"torchvision>=0.15.0,<1.0.0",
"transformers>=4.40.0,<5.0.0",
"huggingface-hub>=0.20.0,<1.0.0",
"accelerate>=0.25.0,<2.0.0",(requests is correctly capped.)
| ): | ||
| return obj | ||
| except ImportError: | ||
| except (ImportError, TypeError): |
| ): | ||
| return obj | ||
| except ImportError: | ||
| except (ImportError, TypeError): |
| if not inspect.isabstract(PreTrainedPolicy): | ||
| return PreTrainedPolicy | ||
| except ImportError: | ||
| except (ImportError, TypeError): |
|
Round 2: addressed mypy Changes:
Deferred to follow-up:
🤖 This comment was written by an AI agent (Strands) |
Summary
Adds NVIDIA Cosmos-Predict2.5 as a policy provider (
cosmos_predict), implementing thePolicyinterface for action-chunk prediction via latent diffusion denoising.The model predicts 16-step, 7-DoF action chunks from camera observations + proprioception + language instruction using rectified flow. Post-trained on LIBERO (98.5% success) and RoboCasa benchmarks.
Key Design Decisions
cosmos-predict2package from nvidia-cosmos/cosmos-predict2.5 directly, NOT through any wrapper packageGr00tPolicy's ZMQ) for Python version isolationpolicies/groot/andpolicies/lerobot_local/trust_remote_code- added to_HF_REMOTE_CODE_PROVIDERSsince Cosmos models use HF custom codeFiles Changed
strands_robots/policies/cosmos_predict/__init__.pystrands_robots/policies/cosmos_predict/policy.pyCosmosPredictPolicy(Policy)implementationstrands_robots/policies/factory.pycosmos_predictto_HF_REMOTE_CODE_PROVIDERSstrands_robots/registry/policies.jsoncosmos_predictproviderpyproject.toml[cosmos]optional extras, mypy ignore forcosmos_predict2tests/policies/cosmos_predict/test_policy.pySupported Suites
Usage
Testing
ruff checkcleanruff formatcleanmypycleanDependencies
The
[cosmos]extra adds: torch, torchvision, transformers, huggingface-hub, accelerate.The
cosmos-predict2package itself must be installed from source (not on PyPI):Phase 1 of Cosmos Integration
This is the first (lowest-risk) phase of the Cosmos integration plan. Future phases will add:
Reasonerinterface + CosmosReasoner (Cosmos-Reason2 VLM)WorldModelinterface (action-conditioned video generation)References