Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,21 @@ mesh-iot = [
"awscrt>=0.20.0,<1.0.0",
"boto3>=1.34.0,<2.0.0",
]
cosmos = [
# Every dependency carries an upper bound so that a future breaking major
# release cannot be pulled in silently by `pip install strands-robots[cosmos]`.
"torch>=2.0.0,<3.0.0",
"torchvision>=0.15.0,<1.0.0",
"transformers>=4.40.0,<5.0.0",
"huggingface-hub>=0.20.0,<1.0.0",
"accelerate>=0.25.0,<2.0.0",
"requests>=2.28.0,<3.0.0",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing upper bounds on >=1.0 deps. AGENTS.md > Key Conventions #2: ">=1.0 deps: cap major."

torch>=2.0.0, torchvision>=0.15.0, transformers>=4.40.0, huggingface-hub>=0.20.0, accelerate>=0.25.0 are all unbounded. transformers and torch make breaking changes between minor versions; uncapped, a pip install strands-robots[cosmos] six months from now can silently pull a future torch 3.x (or later) that breaks cosmos-predict2's pinned ABI.

Proposed:

"torch>=2.0.0,<3.0.0",
"torchvision>=0.15.0,<1.0.0",
"transformers>=4.40.0,<5.0.0",
"huggingface-hub>=0.20.0,<1.0.0",
"accelerate>=0.25.0,<2.0.0",

(requests is correctly capped.)

]
# Aggregate extra: installs every optional feature set listed below.
all = [
"strands-robots[groot-service]",
"strands-robots[lerobot]",
"strands-robots[sim-mujoco]",
"strands-robots[mesh]",
"strands-robots[mesh-iot]",
"strands-robots[cosmos]",
]
dev = [
"pytest>=6.0,<9.0.0",
Expand Down Expand Up @@ -156,7 +165,7 @@ ignore_missing_imports = false

# Third-party libs without type stubs
[[tool.mypy.overrides]]
module = ["lerobot.*", "gr00t.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*", "libero.*", "zenoh.*", "boto3", "boto3.*", "awscrt", "awscrt.*", "awsiot", "awsiot.*", "botocore.*"]
module = ["lerobot.*", "gr00t.*", "cosmos_predict2.*", "requests.*", "draccus.*", "msgpack.*", "zmq.*", "huggingface_hub.*", "serial.*", "psutil.*", "torch.*", "torchvision.*", "transformers.*", "einops.*", "robot_descriptions.*", "mujoco.*", "imageio.*", "libero.*", "zenoh.*", "boto3", "boto3.*", "awscrt", "awscrt.*", "awsiot", "awsiot.*", "botocore.*"]
ignore_missing_imports = true

# @tool decorator injects runtime signatures mypy cannot check
Expand Down
33 changes: 33 additions & 0 deletions strands_robots/policies/cosmos_predict/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Cosmos Predict 2.5 policy provider for strands-robots.

Wraps NVIDIA's Cosmos-Predict2.5/robot/policy checkpoint for direct
action prediction via latent-diffusion denoising. Post-trained on
LIBERO (98.5% success) and RoboCasa benchmarks.

Architecture:
    [Camera Images + Proprio + Language] -> VAE Encoder -> Latent Sequence
        -> Rectified Flow DiT (2B) -> Denoised Latent
        -> Extract Action Chunk (16-step, 7-DoF)

Requirements:
- cosmos-predict2 package (from nvidia-cosmos/cosmos-predict2.5)
- CUDA GPU with 16GB+ VRAM

Usage::

    from strands_robots.policies import create_policy

    policy = create_policy(
        "cosmos_predict",
        model_id="nvidia/Cosmos-Policy-LIBERO-Predict2-2B",
        suite="libero",
    )

Reference:
    "World Simulation with Video Foundation Models for Physical AI", arXiv:2511.00062
    GitHub: https://github.com/nvidia-cosmos/cosmos-predict2.5
"""

from strands_robots.policies.cosmos_predict.policy import CosmosPredictPolicy

__all__ = ["CosmosPredictPolicy"]
Loading
Loading