Skip to content

feat(policies): add CosmosPredictPolicy for Cosmos-Predict2.5/robot/policy#163

Draft
cagataycali wants to merge 6 commits into
strands-labs:mainfrom
cagataycali:feat/cosmos-policy
Draft

feat(policies): add CosmosPredictPolicy for Cosmos-Predict2.5/robot/policy#163
cagataycali wants to merge 6 commits into
strands-labs:mainfrom
cagataycali:feat/cosmos-policy

Conversation

@cagataycali
Copy link
Copy Markdown
Member

Summary

Adds NVIDIA Cosmos-Predict2.5 as a policy provider (cosmos_predict), implementing the Policy interface for action-chunk prediction via latent diffusion denoising.

The model predicts 16-step, 7-DoF action chunks from camera observations + proprioception + language instruction using rectified flow. Post-trained on LIBERO (98.5% success) and RoboCasa benchmarks.

Key Design Decisions

  1. Direct from nvidia-cosmos source - uses cosmos-predict2 package from nvidia-cosmos/cosmos-predict2.5 directly, NOT through any wrapper package
  2. Server mode for env isolation - cosmos-predict2 requires Python 3.10 + CUDA, so the policy supports HTTP server mode (same pattern as Gr00tPolicy's ZMQ) for Python version isolation
  3. Follows existing patterns - structurally mirrors policies/groot/ and policies/lerobot_local/
  4. Gated by trust_remote_code - added to _HF_REMOTE_CODE_PROVIDERS since Cosmos models use HF custom code

Files Changed

File Change
strands_robots/policies/cosmos_predict/__init__.py New - package init
strands_robots/policies/cosmos_predict/policy.py New - CosmosPredictPolicy(Policy) implementation
strands_robots/policies/factory.py Add cosmos_predict to _HF_REMOTE_CODE_PROVIDERS
strands_robots/registry/policies.json Register cosmos_predict provider
pyproject.toml Add [cosmos] optional extras, mypy ignore for cosmos_predict2
tests/policies/cosmos_predict/test_policy.py 18 unit tests (mocked, no GPU needed)

Supported Suites

Suite Cameras Model
libero 1 wrist + 1 third-person nvidia/Cosmos-Policy-LIBERO-Predict2-2B
robocasa 1 wrist + 2 third-person nvidia/Cosmos-Policy-RoboCasa-Predict2-2B
aloha 2 wrist + 1 third-person (custom checkpoint)

Usage

from strands_robots.policies import create_policy

# Local mode (requires cosmos-predict2 + CUDA GPU)
policy = create_policy("cosmos_predict",
    model_id="nvidia/Cosmos-Policy-LIBERO-Predict2-2B",
    suite="libero")

# Server mode (for env isolation)
policy = create_policy("cosmos_predict",
    server_url="http://cosmos-server:8000",
    suite="libero")

# Use with Robot
robot = Robot("so100", mode="sim")
robot.run_policy(policy, instruction="pick up the red block")

Testing

  • 18 unit tests pass (all mocked, no GPU required)
  • Full test suite passes: 1130 passed, 59 skipped
  • ruff check clean
  • ruff format clean
  • mypy clean

Dependencies

The [cosmos] extra adds: torch, torchvision, transformers, huggingface-hub, accelerate.
The cosmos-predict2 package itself must be installed from source (not on PyPI):

git clone https://github.com/nvidia-cosmos/cosmos-predict2.5
cd cosmos-predict2.5 && pip install -e packages/cosmos-oss -e packages/cosmos-cuda -e .

Phase 1 of Cosmos Integration

This is the first (lowest-risk) phase of the Cosmos integration plan. Future phases will add:

  • Phase 2: Reasoner interface + CosmosReasoner (Cosmos-Reason2 VLM)
  • Phase 3: WorldModel interface (action-conditioned video generation)
  • Phase 4: Transfer2.5 dataset augmentation
  • Phase 5: End-to-end demo

References

cagataycali added a commit to cagataycali/robots that referenced this pull request May 16, 2026
The CosmosPredictPolicy uses requests for server-mode health checks
and inference. Without it in the [cosmos] optional dependency,
tests fail with ModuleNotFoundError and mypy reports import-untyped.

Fixes CI on PR strands-labs#163.
@yinsong1986
Copy link
Copy Markdown
Contributor

Review notes (automated review)

Nice work — the structure mirrors groot/ and lerobot_local/, the trust gate is wired up, and the test suite covers init / observation building / action decoding / server mode / registry. CI is green. A few questions before approving:

Worth a second look

  1. Trust gate provenancefactory.py adds cosmos_predict to _HF_REMOTE_CODE_PROVIDERS, which is documented as gating providers that load HF models with trust_remote_code=True. But the local-load path imports cosmos_predict2._src.predict2.cosmos_policy.experiments.robot.cosmos_utils.get_model directly, and the only HF interaction is snapshot_download to fetch dataset stats / T5 embedding pickles (no trust_remote_code flag visible). Is the gate really warranted here, or is it being used as a generic "this provider loads remote artifacts" gate? If the latter, worth a comment in factory.py so future readers don't get the wrong mental model.

  2. Sync requests.post inside async _infer_server (policy.py:537, timeout=120) — blocks the event loop for up to 2 min. groot/policy.py runs sync calls from async too, so this isn't a new pattern in the repo, but groot's local path is generally what's used; cosmos_predict's primary mode looks like server. Either run via asyncio.to_thread(...) or swap requests for httpx.AsyncClient. Not blocking IMO since it matches the repo's existing async-but-sync style, but flagging.

  3. Loose camera key matching in _build_observation (lines 591-599): if pattern in obs_key.lower() is substring-based, so pattern "wrist" matches both wrist_image and wrist_camera and silently picks whichever the dict iterates first. Robot configs with multiple cameras whose names share a substring could surprise users. Consider exact-match-then-fallback, or document the precedence.

  4. test_create_policy_blocked_without_trust uses os.environ.pop() instead of monkeypatch.delenv(). If a future test sets the env var and forgets to clean up, this test could become flaky depending on order. Pure nit, since right now it works in isolation.

Nits

  • _IMAGE_SIZE = 224 (line 109) is declared but never referenced — dead constant?
  • Self-claim in PR description "structurally mirrors policies/groot/ and policies/lerobot_local/" is a bit generous given those modules each have client.py + data_config.py; this is a single policy.py. Probably fine for v1, but readers expecting that structure may be confused.
  • Async signature mismatch: _infer_local is sync, _infer_server is async — get_actions calls _infer_local directly and await self._infer_server(...). Works but reads a bit oddly; might be worth a # sync; intentionally not awaited comment.

Confirmed (no action)

  • ✅ Implements Policy ABC: get_actions, set_robot_state_keys, provider_name.
  • policies.json entry has all schema fields and shorthands.
  • ✅ Optional-extras pattern ([cosmos] in pyproject.toml) matches [groot-service] / [lerobot].
  • ✅ Lazy _ensure_loaded defers heavy imports to first call — good.
  • ✅ Tests don't require GPU.

Happy to approve once the trust-gate comment / _IMAGE_SIZE cleanup are addressed (or the rationale for keeping them is documented).

cagataycali and others added 4 commits May 21, 2026 13:26
…olicy

Adds NVIDIA Cosmos-Predict2.5 as a policy provider, implementing
the Policy interface for action-chunk prediction via latent diffusion.

The model predicts 16-step, 7-DoF action chunks from camera observations
+ proprioception + language instruction using rectified flow denoising.

Key features:
- Supports libero, robocasa, and aloha evaluation suites
- Local mode: loads cosmos-predict2 directly (requires CUDA GPU)
- Server mode: HTTP client for env isolation (Python 3.10 server)
- Auto-resolves dataset stats and T5 embeddings from HF checkpoints
- Pattern-based camera key matching (works with any naming convention)
- Registered in policies.json with trust_remote_code gate

Dependencies:
- cosmos-predict2 package (from nvidia-cosmos/cosmos-predict2.5)
- Added [cosmos] optional extras in pyproject.toml
- Added cosmos_predict to _HF_REMOTE_CODE_PROVIDERS frozenset

Usage:
    from strands_robots.policies import create_policy
    policy = create_policy('cosmos_predict',
        model_id='nvidia/Cosmos-Policy-LIBERO-Predict2-2B',
        suite='libero')

Reference: arXiv:2511.00062, github.com/nvidia-cosmos/cosmos-predict2.5
The CosmosPredictPolicy uses requests for server-mode health checks
and inference. Without it in the [cosmos] optional dependency,
tests fail with ModuleNotFoundError and mypy reports import-untyped.

Fixes CI on PR strands-labs#163.
When the [cosmos] extras pull in torch+transformers, importing lerobot
policy classes can trigger a TypeError from draccus dataclass processing
(non-default argument follows default). This is not a bug in our code
but a version-specific interaction. Widen the expected exception set
so CI passes regardless of which extras are installed.
… resolution

lerobot's internal dataclasses can raise TypeError ('non-default argument
follows default argument') when imported in newer versions. This causes
resolve_policy_class_by_name to propagate TypeError instead of falling
through to ImportError. Catch TypeError alongside ImportError in all
resolution strategies so the function degrades gracefully.
@cagataycali cagataycali force-pushed the feat/cosmos-policy branch from 0cc1ce5 to 904809c Compare May 21, 2026 13:27
Copy link
Copy Markdown
Contributor

@yinsong1986 yinsong1986 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary

Adds CosmosPredictPolicy as a new Policy provider, mirroring the structure of policies/groot/ and policies/lerobot_local/. The policy supports two execution modes: a local mode that loads cosmos-predict2 directly (CUDA + Python 3.10) and an HTTP server mode for environment isolation. Eighteen unit tests cover init, observation building, action decoding, server-mode mocking and the trust-remote-code gate. The PR also threads cosmos_predict into _HF_REMOTE_CODE_PROVIDERS, registers it in policies.json, and incidentally widens three except ImportError clauses in lerobot_local/resolution.py to also catch TypeError.

What's good

  • Mirrors existing provider patterns (Gr00t/Lerobot) for layout and the lazy-load contract.
  • Server mode keeps the GPU-only dependency out of the default install path.
  • Trust-remote-code opt-in is wired in, with both positive and negative tests.
  • No host paths, no emojis in user-facing strings, AGENTS.local.md is not committed.
  • Reasonable mock-based unit coverage that keeps CI GPU-free.

Concerns

  • Scope creep, lightly. The lerobot_local except ImportError -> except (ImportError, TypeError) widening and the test-side mirror in tests/policies/lerobot_local/test_policy.py are unrelated to Cosmos. The commit message says it fixes a TypeError from "dataclass field ordering" but no regression test that fails on pre-fix code was added. AGENTS.md > Review Learnings (#85) > 'Pin regression tests for reviewed fixes' applies. Consider splitting that into its own PR with a targeted test, or adding one here.
  • No tests_integ/ coverage. AGENTS.md > Key Conventions #8 says "each policy needs tests_integ/ tests with real inference." Even a pytest.importorskip("cosmos_predict2")-gated smoke test (load -> single get_actions() -> shape assertion) would catch upstream-API drift before users hit it.
  • require_optional() not used. strands_robots/utils.py:require_optional() is the documented helper for optional deps (AGENTS.md > Key Conventions #7); the file uses ad-hoc try/except ImportError blocks for torch, cosmos_predict2, requests, and huggingface_hub. Not blocking, but inconsistent with the rest of the codebase.
  • Server mode dependency contract. _verify_server and _infer_server import requests at function scope, but requests is only declared in the [cosmos] extra. A user picking server-only mode still has to install the full [cosmos] (torch + transformers + accelerate, ~5GB) just to make HTTP calls. Consider a [cosmos-client] extra that pulls only requests, mirroring the groot-service split.
  • Docstring drift around trust_remote_code. _load_local_model does not actually pass trust_remote_code=True to anything visible in this diff -- it calls cosmos_get_model(cfg). The factory still gates the provider as if it does. Either it does (in which case worth a code comment pointing at where), or the gate is overcautious for the local path (server mode never loads code).

Verification suggestions

# Quick sanity check that the registry collision does what you expect:
STRANDS_TRUST_REMOTE_CODE=1 python -c "\
from strands_robots.policies import create_policy; \
p = create_policy('nvidia/some-other-model'); \
print(type(p).__name__)"
# Expected: CosmosPredictPolicy if you intend nvidia/* to default to Cosmos,
# else Gr00tPolicy. (Currently resolves to Gr00tPolicy due to dict iteration order.)

# Round-trip the action decoder against a 7-DoF arm:
python -c "\
import numpy as np; \
from strands_robots.policies.cosmos_predict.policy import CosmosPredictPolicy as P; \
p = P(server_url='http://x'); \
p.set_robot_state_keys(['j0','j1','j2','j3','j4','j5','j6']); \
print(p._vec_to_action_dict(np.arange(7, dtype=np.float32)))"
# Expected: every joint (j0..j6) present plus 'gripper'.
# Actual: j6 missing.

if j < len(action_vec) - 1:
action_dict[key] = float(action_vec[j])
if len(action_vec) > 0:
action_dict["gripper"] = float(action_vec[-1])
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Last joint silently dropped when len(robot_state_keys) >= len(action_vec).

With robot_state_keys=['j0',...,'j6'] (7 keys) and a 7-DoF action vector, the loop condition j < len(action_vec) - 1 (== j < 6) maps j0..j5 -> action_vec[0..5] and then sets gripper = action_vec[-1]. j6 is never written, so the last joint command is discarded and the gripper is double-counted. The unit test happens to use 6 state keys + a 7-element vector, which masks the off-by-one.

The intent ("last element is gripper, the rest are joints") needs explicit handling of three cases: more keys than action_vec, fewer keys, or equal:

if self._robot_state_keys:
    n_joints = len(action_vec) - 1  # last element reserved for gripper
    for j, key in enumerate(self._robot_state_keys[:n_joints]):
        action_dict[key] = float(action_vec[j])
    if len(action_vec) > 0:
        action_dict["gripper"] = float(action_vec[-1])

Add a regression test with set_robot_state_keys of length == action_dim (the realistic 7-DoF case) and assert every joint key appears in the output. AGENTS.md > Review Learnings (#85) > 'Per-name state copy, not flat index' is the same shape of bug.

resp = requests.get(f"{self._server_url}/health", timeout=5)
resp.raise_for_status()
logger.info("Cosmos server connected: %s", self._server_url)
except Exception as e:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bare except Exception is forbidden by AGENTS.md > Review Learnings (#86) > 'Exception Clauses Must Be Narrow'. Same pattern repeats at line 283 (_resolve_dataset_stats).

For the _verify_server health check, the realistic failure modes are network-level:

except (requests.exceptions.RequestException, OSError) as e:
    logger.warning(...)

For _resolve_dataset_stats (snapshot_download + JSON parse), use (OSError, ValueError, huggingface_hub.utils.HfHubHTTPError) or similar -- whatever the narrow set of recoverable failures is. A blanket except Exception will swallow KeyboardInterrupt-adjacent issues (no, KeyboardInterrupt inherits from BaseException, but it will swallow programmer errors like TypeError / AttributeError introduced by future refactors of loader_fn's contract).


for key, val in observation_dict.items():
if isinstance(val, np.ndarray):
payload[key] = val.tolist()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Server-mode camera serialisation is a performance trap.

val.tolist() on a 224x224x3 uint8 image expands ~150 KB of dense bytes into a Python list-of-lists-of-ints, which requests.post(json=...) then JSON-encodes to ~1.5 MB of ASCII. At 16-step chunk + multiple cameras + a control loop, this dominates per-step latency and is allocation-pressure-heavy on both client and server.

Prior art in the repo: Gr00tPolicy uses msgpack over ZMQ for exactly this reason. Even staying on HTTP, base64 + raw bytes (numpy.savez_compressed-into-BytesIO -> base64) is an order of magnitude smaller and faster. At minimum, document the throughput limit so users don't reach for server mode in a hot loop.

No regression test will catch this because the unit test uses np.zeros((224,224,3)) and a mock requests.post.

"cosmos-predict"
],
"hf_orgs": [
"nvidia"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hf_orgs: ["nvidia"] collides with the groot provider (line 38).

In strands_robots/registry/policies.py:resolve_policy, the model_id_overrides loop runs first (so the two specific Cosmos checkpoints route correctly), but for any other nvidia/* model the fallback hf_orgs loop iterates dict insertion order and matches groot first. So create_policy("nvidia/Cosmos-Reason2-7B") or any new Cosmos checkpoint NVIDIA publishes silently routes to Gr00tPolicy instead of CosmosPredictPolicy. This is exactly the "silently dropped kwargs / silently mis-routed" anti-pattern from AGENTS.md > Review Learnings (#86) > 'Public API Hygiene'.

Options: (a) tighten hf_orgs to a more specific prefix string (the registry doesn't currently support that), (b) widen model_id_overrides to include "nvidia/Cosmos-", or (c) add an explicit tie-breaker in resolve_policy. (b) is the smallest fix and matches the format of model_id_overrides already in this file.

Comment thread pyproject.toml
"transformers>=4.40.0",
"huggingface-hub>=0.20.0",
"accelerate>=0.25.0",
"requests>=2.28.0,<3.0.0",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing upper bounds on >=1.0 deps. AGENTS.md > Key Conventions #2: ">=1.0 deps: cap major."

torch>=2.0.0, torchvision>=0.15.0, transformers>=4.40.0, huggingface-hub>=0.20.0, accelerate>=0.25.0 are all unbounded. transformers and torch make breaking changes between minor versions; uncapped, a pip install strands-robots[cosmos] six months from now can silently pull a future torch 4.x that breaks cosmos-predict2's pinned ABI.

Proposed:

"torch>=2.0.0,<3.0.0",
"torchvision>=0.15.0,<1.0.0",
"transformers>=4.40.0,<5.0.0",
"huggingface-hub>=0.20.0,<1.0.0",
"accelerate>=0.25.0,<2.0.0",

(requests is correctly capped.)

@cagataycali cagataycali added the enhancement New feature or request label May 23, 2026
@cagataycali cagataycali moved this to In review in Strands Labs - Robots May 23, 2026
@cagataycali cagataycali marked this pull request as draft May 23, 2026 06:31
@cagataycali cagataycali added this to the 0.5.0 milestone May 23, 2026
Comment thread strands_robots/policies/cosmos_predict/policy.py Fixed
):
return obj
except ImportError:
except (ImportError, TypeError):
):
return obj
except ImportError:
except (ImportError, TypeError):
if not inspect.isabstract(PreTrainedPolicy):
return PreTrainedPolicy
except ImportError:
except (ImportError, TypeError):
@cagataycali
Copy link
Copy Markdown
Member Author

cagataycali commented May 23, 2026

Round 2: addressed mypy reset() signature mismatch (commit 9d9d7ec) + dropped dead _IMAGE_SIZE const. Other R1 concerns (trust-gate doc, async/sync requests, loose camera-key match) deferred to follow-up issue — those need design discussion not in scope for this draft. CI rebuilding.

Changes:

  • ✅ Fixed CosmosPredictPolicy.reset() signature to match Policy ABC: def reset(self, seed: int | None = None) -> None
  • ✅ Removed dead constant _IMAGE_SIZE = 224 at line 31
  • ✅ Mypy passes cleanly on strands_robots/policies/cosmos_predict/

Deferred to follow-up:

  • Trust-gate documentation requirements
  • Async/sync request handling in _infer_server
  • Camera key matching logic in _build_observation

🤖 This comment was written by an AI agent (Strands)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request policy

Projects

Status: In review

Development

Successfully merging this pull request may close these issues.

4 participants