diff --git a/docker/setup/download_wbc_models.sh b/docker/setup/download_wbc_models.sh
new file mode 100755
index 000000000..112befc51
--- /dev/null
+++ b/docker/setup/download_wbc_models.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+set -euo pipefail
+
+# Script to download external WBC policy models.
+# This script is called from the Dockerfile or can be run manually.
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+REPO_ROOT="$(cd "${SCRIPT_DIR}/../.." && pwd)"
+MODELS_DIR="${REPO_ROOT}/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/models"
+
+# --- WBC-AGILE e2e velocity policy for G1 ---
+AGILE_MODEL_DIR="${MODELS_DIR}/agile"
+AGILE_MODEL_PATH="${AGILE_MODEL_DIR}/unitree_g1_velocity_e2e.onnx"
+AGILE_MODEL_URL="https://github.com/nvidia-isaac/WBC-AGILE/raw/main/agile/data/policy/velocity_g1/unitree_g1_velocity_e2e.onnx"
+AGILE_MODEL_SHA256="8995f2462ba2d0d83afe08905148f6373990d50018610663a539225d268ef33b"
+
+# Print the SHA256 hex digest of the file given as $1.
+file_sha256() {
+    sha256sum "$1" | awk '{print $1}'
+}
+
+# Download <url> ($1) to <dest> ($2) and verify it against <expected_sha256> ($3).
+#
+# The payload is fetched into "<dest>.part" and only moved into place after the
+# checksum matches, so an interrupted or corrupted download never leaves a file
+# at <dest>. An existing <dest> is reused only if its checksum still matches.
+# Returns non-zero if the download fails or the checksum does not match.
+download_model() {
+    local url="$1"
+    local dest="$2"
+    local expected_sha256="$3"
+    local tmp="${dest}.part"
+    local actual_sha256
+
+    if [ -f "$dest" ]; then
+        actual_sha256="$(file_sha256 "$dest")"
+        if [ "$actual_sha256" = "$expected_sha256" ]; then
+            echo "Model already exists: $dest"
+            return 0
+        fi
+        echo "WARNING: SHA256 mismatch for existing $dest, downloading again" >&2
+        rm -f "$dest"
+    fi
+
+    mkdir -p "$(dirname "$dest")"
+    echo "Downloading $(basename "$dest") from ${url} ..."
+    # -f: treat HTTP errors (e.g. 404) as failures instead of saving the error page.
+    if ! curl -fL -o "$tmp" "$url"; then
+        echo "ERROR: download failed: ${url}" >&2
+        rm -f "$tmp"
+        return 1
+    fi
+
+    actual_sha256="$(file_sha256 "$tmp")"
+    if [ "$actual_sha256" != "$expected_sha256" ]; then
+        echo "ERROR: SHA256 mismatch for $dest" >&2
+        echo "  expected: $expected_sha256" >&2
+        echo "  actual:   $actual_sha256" >&2
+        rm -f "$tmp"
+        return 1
+    fi
+
+    mv -f "$tmp" "$dest"
+    echo "Downloaded and verified: $dest"
+}
+
+download_model "$AGILE_MODEL_URL" "$AGILE_MODEL_PATH" "$AGILE_MODEL_SHA256"
+
+echo "All WBC models downloaded successfully."
diff --git a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/configs.py b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/configs.py
index ccca89bb6..13b5e049f 100644
--- a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/configs.py
+++ b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/configs.py
@@ -71,3 +71,27 @@ class HomieV2Config(BaseConfig):
 
     policy_config_path: str = "config/g1_homie_v2.yaml"
     """Policy related configuration to specify inputs/outputs dim"""
+
+
+@dataclass
+class AgileConfig(BaseConfig):
+    """Config for the WBC-AGILE end-to-end velocity policy for G1.
+
+    Both paths below are resolved relative to the wbc_policy directory.
+    """
+
+    # WBC Configuration
+    wbc_version: Literal["agile"] = "agile"
+    """Version of the whole body controller."""
+
+    # Must be fetched beforehand with docker/setup/download_wbc_models.sh.
+    wbc_model_path: str = "models/agile/unitree_g1_velocity_e2e.onnx"
+    """Path to WBC model file (relative to wbc_policy directory)"""
+
+    # Robot Configuration
+    enable_waist: bool = False
+    """Whether to include waist joints in IK."""
+
+    # YAML with the ONNX input/output joint ordering, num_actions and cmd_init.
+    policy_config_path: str = "config/g1_agile.yaml"
+    """Policy related configuration to specify inputs/outputs dim"""
diff --git a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/g1_agile.yaml b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/g1_agile.yaml
new file mode 100644
index 000000000..cbf62143a
--- /dev/null
+++ b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/config/g1_agile.yaml
@@ -0,0 +1,61 @@
+# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Joint ordering for ONNX model inputs (29 body joints in agile order).
+# Must match the element_names from unitree_g1_velocity_e2e.yaml.
+onnx_input_joint_names:
+  - left_hip_pitch_joint
+  - right_hip_pitch_joint
+  - waist_yaw_joint
+  - left_hip_roll_joint
+  - right_hip_roll_joint
+  - waist_roll_joint
+  - left_hip_yaw_joint
+  - right_hip_yaw_joint
+  - waist_pitch_joint
+  - left_knee_joint
+  - right_knee_joint
+  - left_shoulder_pitch_joint
+  - right_shoulder_pitch_joint
+  - left_ankle_pitch_joint
+  - right_ankle_pitch_joint
+  - left_shoulder_roll_joint
+  - right_shoulder_roll_joint
+  - left_ankle_roll_joint
+  - right_ankle_roll_joint
+  - left_shoulder_yaw_joint
+  - right_shoulder_yaw_joint
+  - left_elbow_joint
+  - right_elbow_joint
+  - left_wrist_roll_joint
+  - right_wrist_roll_joint
+  - left_wrist_pitch_joint
+  - right_wrist_pitch_joint
+  - left_wrist_yaw_joint
+  - right_wrist_yaw_joint
+
+# Joint ordering for ONNX model outputs (14 controlled joints in agile output order).
+# Must match the action_joint_pos element_names from unitree_g1_velocity_e2e.yaml.
+controlled_joint_names:
+  - left_hip_pitch_joint
+  - right_hip_pitch_joint
+  - left_hip_roll_joint
+  - right_hip_roll_joint
+  - waist_roll_joint
+  - left_hip_yaw_joint
+  - right_hip_yaw_joint
+  - waist_pitch_joint
+  - left_knee_joint
+  - right_knee_joint
+  - left_ankle_pitch_joint
+  - right_ankle_pitch_joint
+  - left_ankle_roll_joint
+  - right_ankle_roll_joint
+
+num_actions: 14
+num_body_joints: 29
+
+# Initial commands
+cmd_init: [0.0, 0.0, 0.0]
diff --git a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/g1_agile_policy.py b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/g1_agile_policy.py
new file mode 100644
index 000000000..46f710411
--- /dev/null
+++ b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/g1_agile_policy.py
@@ -0,0 +1,233 @@
+# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pathlib
+import torch
+from typing import Any
+
+import onnxruntime as ort
+
+from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.base import WBCPolicy
+from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.utils.homie_utils import load_config
+
+
+class G1AgilePolicy(WBCPolicy):
+    """G1 robot policy using the WBC-AGILE end-to-end neural network.
+
+    This policy uses a single ONNX model that takes raw sensor inputs and
+    manages observation history internally via feedback connections. The model
+    outputs target joint positions along with per-joint Kp/Kd gains for 14
+    controlled joints (legs + waist_roll + waist_pitch). Only the position
+    targets are forwarded by get_action(); the gain outputs are not used.
+    """
+
+    # Names of ONNX state inputs that receive feedback from the previous step's outputs.
+    # Input "<key>" is fed from the previous step's output "<key>_out".
+    _STATE_KEYS = [
+        "last_actions",
+        "base_ang_vel_history",
+        "projected_gravity_history",
+        "velocity_commands_history",
+        "controlled_joint_pos_history",
+        "controlled_joint_vel_history",
+        "actions_history",
+    ]
+
+    def __init__(self, robot_model, config_path: str, model_path: str, num_envs: int = 1):
+        """Initialize G1AgilePolicy.
+
+        Args:
+            robot_model: Robot model containing joint ordering info.
+            config_path: Path to policy YAML configuration file (relative to wbc_policy dir).
+            model_path: Path to the ONNX model file (relative to wbc_policy dir).
+            num_envs: Number of environments.
+        """
+        self.parent_dir = pathlib.Path(__file__).parent.parent
+        self.config = load_config(str(self.parent_dir / config_path))
+        self.robot_model = robot_model
+        self.num_envs = num_envs
+
+        # Load ONNX model (must be downloaded beforehand via docker/setup/download_wbc_models.sh)
+        model_full_path = self.parent_dir / model_path
+        if not model_full_path.exists():
+            raise FileNotFoundError(
+                f"AGILE ONNX model not found at {model_full_path}. "
+                "Run docker/setup/download_wbc_models.sh to download it."
+            )
+        self.session = ort.InferenceSession(str(model_full_path))
+        self.output_names = [out.name for out in self.session.get_outputs()]
+        print(f"Successfully loaded ONNX policy from {model_full_path}")
+
+        # Build joint index mappings between WBC order and agile ONNX order
+        self._build_joint_mappings()
+
+        # Initialize state
+        self._init_state()
+
+    def _build_joint_mappings(self):
+        """Build index mappings between WBC joint order and agile ONNX joint order."""
+        wbc_order = self.robot_model.wbc_g1_joints_order  # {joint_name: wbc_index}
+
+        # Mapping for ONNX input: indices into the WBC-ordered observation to select
+        # the 29 body joints in the order the ONNX model expects.
+        onnx_input_names = self.config["onnx_input_joint_names"]
+        self.wbc_to_agile_input = [wbc_order[name] for name in onnx_input_names]
+
+        # Mapping for ONNX output: for each of the 14 agile output joints, the
+        # position in the 15-element lower_body array to write to.
+        controlled_names = self.config["controlled_joint_names"]
+        lower_body_indices = self.robot_model.get_joint_group_indices("lower_body")
+        self.agile_output_to_lower_body = [
+            lower_body_indices.index(wbc_order[name]) for name in controlled_names
+        ]
+
+        self.num_lower_body = len(lower_body_indices)
+
+    def _init_state(self):
+        """Initialize all per-environment state variables."""
+        self.observation = None
+        self.use_policy_action = True
+        self.cmd = np.tile(self.config["cmd_init"], (self.num_envs, 1))
+
+        # Per-environment ONNX feedback state. Each entry is shaped for batch=1
+        # as the ONNX model expects, matching the input tensor shapes from the YAML.
+        self.states = [self._make_zero_state() for _ in range(self.num_envs)]
+
+    def _make_zero_state(self) -> dict[str, np.ndarray]:
+        """Create a zeroed feedback state dict for one environment."""
+        return {
+            "last_actions": np.zeros((1, 14), dtype=np.float32),
+            "base_ang_vel_history": np.zeros((1, 5, 3), dtype=np.float32),
+            "projected_gravity_history": np.zeros((1, 5, 3), dtype=np.float32),
+            "velocity_commands_history": np.zeros((1, 5, 3), dtype=np.float32),
+            "controlled_joint_pos_history": np.zeros((1, 5, 14), dtype=np.float32),
+            "controlled_joint_vel_history": np.zeros((1, 5, 14), dtype=np.float32),
+            "actions_history": np.zeros((1, 5, 14), dtype=np.float32),
+        }
+
+    def reset(self, env_ids: torch.Tensor):
+        """Reset the policy state for the given environment ids.
+
+        Zeroes the ONNX feedback state and restores the initial velocity
+        command for the listed environments only; all other environments
+        keep their current state and command.
+
+        Args:
+            env_ids: The environment ids to reset.
+        """
+        reset_indices = [int(env_id) for env_id in env_ids]
+        for idx in reset_indices:
+            self.states[idx] = self._make_zero_state()
+        if reset_indices:
+            # Copy before writing: self.cmd may alias the array passed to set_goal().
+            cmd = np.array(self.cmd)
+            cmd[reset_indices] = self.config["cmd_init"]
+            self.cmd = cmd
+
+    def set_observation(self, observation: dict[str, Any]):
+        """Store the current observation for the next get_action call.
+
+        Args:
+            observation: Dictionary containing robot state from prepare_observations().
+        """
+        self.observation = observation
+
+    def set_goal(self, goal: dict[str, Any]):
+        """Set the goal for the policy.
+
+        Args:
+            goal: Dictionary containing goals. Supported keys:
+                - "navigate_cmd": velocity command array of shape (num_envs, 3)
+                - "toggle_policy_action": bool to toggle policy action on/off
+        """
+        if "toggle_policy_action" in goal:
+            if goal["toggle_policy_action"]:
+                self.use_policy_action = not self.use_policy_action
+
+        if "navigate_cmd" in goal:
+            self.cmd = goal["navigate_cmd"]
+
+    def get_action(self) -> dict[str, Any]:
+        """Compute and return the next action based on current observation.
+
+        Inference runs once per environment (the feedback state is kept with
+        batch size 1), and the feedback state is advanced on every call, also
+        while the policy action is toggled off.
+
+        Returns:
+            Dictionary with "body_action" key containing joint position targets
+            of shape (num_envs, num_lower_body_joints). While the policy action
+            is toggled off, the targets are the first num_lower_body entries of
+            the observed joint positions ``q`` instead of the network output.
+
+        Raises:
+            ValueError: If no observation has been set.
+        """
+        if self.observation is None:
+            raise ValueError("No observation set. Call set_observation() first.")
+
+        obs = self.observation
+        body_action = np.zeros((self.num_envs, self.num_lower_body), dtype=np.float32)
+
+        for env_idx in range(self.num_envs):
+            # Build ONNX inputs for this environment
+            ort_inputs = self._build_onnx_inputs(obs, env_idx)
+
+            # Run inference
+            outputs = self.session.run(self.output_names, ort_inputs)
+            result = dict(zip(self.output_names, outputs))
+
+            # Update feedback state for next step: output "<key>_out" feeds input "<key>".
+            state = self.states[env_idx]
+            for key in self._STATE_KEYS:
+                state[key] = result[f"{key}_out"]
+
+            if self.use_policy_action:
+                # Scatter the 14 agile outputs (shape [1, 14]) into the lower_body array.
+                # waist_yaw (not controlled by agile) stays at 0.0.
+                body_action[env_idx, self.agile_output_to_lower_body] = result["action_joint_pos"][0]
+            else:
+                # Policy toggled off via set_goal(): pass the measured joint positions through.
+                # NOTE(review): assumes the lower_body joints are the first num_lower_body
+                # entries of q -- confirm against the robot model's joint ordering.
+                body_action[env_idx] = obs["q"][env_idx, : self.num_lower_body]
+
+        return {"body_action": body_action}
+
+    def _build_onnx_inputs(self, obs: dict[str, Any], env_idx: int) -> dict[str, np.ndarray]:
+        """Build the ONNX input dict for a single environment.
+
+        Args:
+            obs: Observation dictionary from prepare_observations().
+            env_idx: Environment index.
+
+        Returns:
+            Dictionary mapping ONNX input names to numpy arrays.
+        """
+        # Quaternion (w, x, y, z) from floating base pose
+        root_link_quat_w = obs["floating_base_pose"][env_idx : env_idx + 1, 3:7].astype(np.float32)
+
+        # Angular velocity in body frame
+        root_ang_vel_b = obs["floating_base_vel"][env_idx : env_idx + 1, 3:6].astype(np.float32)
+
+        # Velocity commands
+        velocity_commands = self.cmd[env_idx : env_idx + 1].astype(np.float32)
+
+        # Joint positions and velocities: select 29 body joints and reorder to agile order
+        joint_pos = obs["q"][env_idx : env_idx + 1, self.wbc_to_agile_input].astype(np.float32)
+        joint_vel = obs["dq"][env_idx : env_idx + 1, self.wbc_to_agile_input].astype(np.float32)
+
+        state = self.states[env_idx]
+
+        return {
+            "root_link_quat_w": root_link_quat_w,
+            "root_ang_vel_b": root_ang_vel_b,
+            "velocity_commands": velocity_commands,
+            "joint_pos": joint_pos,
+            "joint_vel": joint_vel,
+            # Feedback tensors produced by the previous step for this environment.
+            **{key: state[key] for key in self._STATE_KEYS},
+        }
diff --git a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/wbc_policy_factory.py b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/wbc_policy_factory.py
index c71c983cb..790a42f9b 100644
--- a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/wbc_policy_factory.py
+++ b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/policy/wbc_policy_factory.py
@@ -9,6 +9,7 @@
 from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.g1_decoupled_whole_body_policy import (
     G1DecoupledWholeBodyPolicy,
 )
+from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.g1_agile_policy import G1AgilePolicy
 from
isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.g1_homie_policy import G1HomiePolicyV2
 from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.identity_policy import IdentityPolicy
 
@@ -39,9 +40,16 @@ def get_wbc_policy(robot_type: str, robot_model: RobotModel, wbc_config: BaseCon
                 model_path=wbc_config.wbc_model_path,
                 num_envs=num_envs,
             )
+        elif lower_body_policy_type == "agile":
+            lower_body_policy = G1AgilePolicy(
+                robot_model=robot_model,
+                config_path=wbc_config.policy_config_path,
+                model_path=wbc_config.wbc_model_path,
+                num_envs=num_envs,
+            )
         else:
             raise ValueError(
-                f"Invalid lower body policy type: {lower_body_policy_type}, Supported lower body policy types: homie_v2"
+                f"Invalid lower body policy type: {lower_body_policy_type}, Supported lower body policy types: homie_v2, agile"
             )
 
         wbc_policy = G1DecoupledWholeBodyPolicy(
diff --git a/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/tests/test_g1_agile_policy.py b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/tests/test_g1_agile_policy.py
new file mode 100644
index 000000000..e063cef44
--- /dev/null
+++ b/isaaclab_arena_g1/g1_whole_body_controller/wbc_policy/tests/test_g1_agile_policy.py
@@ -0,0 +1,307 @@
+# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+
+"""Unit tests for the G1AgilePolicy.
+
+These tests run without the full Isaac Lab simulation environment by mocking
+the RobotModel and feeding synthetic observations.
+"""
+
+import numpy as np
+import pathlib
+import pytest
+import yaml
+
+import onnxruntime as ort
+
+
+# ---------------------------------------------------------------------------
+# Paths
+# ---------------------------------------------------------------------------
+WBC_POLICY_DIR = pathlib.Path(__file__).parent.parent
+ONNX_MODEL_PATH = WBC_POLICY_DIR / "models" / "agile" / "unitree_g1_velocity_e2e.onnx"
+AGILE_CONFIG_PATH = WBC_POLICY_DIR / "config" / "g1_agile.yaml"
+WBC_JOINTS_ORDER_PATH = WBC_POLICY_DIR.parent.parent / "g1_env" / "config" / "loco_manip_g1_joints_order_43dof.yaml"
+
+# Tests that load the ONNX model are skipped (rather than erroring) when it has not been downloaded.
+requires_onnx_model = pytest.mark.skipif(
+    not ONNX_MODEL_PATH.exists(),
+    reason=f"AGILE ONNX model not found at {ONNX_MODEL_PATH}; run docker/setup/download_wbc_models.sh",
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+class MockRobotModel:
+    """Minimal mock of RobotModel providing only what G1AgilePolicy needs."""
+
+    def __init__(self):
+        with open(WBC_JOINTS_ORDER_PATH) as f:
+            self.wbc_g1_joints_order = yaml.safe_load(f)
+
+        # Lower body = left_leg(0-5) + right_leg(6-11) + waist(12-14)
+        self._lower_body_indices = list(range(15))
+
+    def get_joint_group_indices(self, group_name):
+        if group_name == "lower_body":
+            return self._lower_body_indices
+        raise ValueError(f"MockRobotModel: unsupported group '{group_name}'")
+
+
+def make_observation(num_envs: int = 1, num_joints: int = 43) -> dict:
+    """Create a synthetic observation dict matching prepare_observations() output."""
+    return {
+        "q": np.zeros((num_envs, num_joints), dtype=np.float32),
+        "dq": np.zeros((num_envs, num_joints), dtype=np.float32),
+        "floating_base_pose": np.tile(
+            np.array([0.0, 0.0, 0.75, 1.0, 0.0, 0.0, 0.0], dtype=np.float32),
+            (num_envs, 1),
+        ),  # pos + quat (w,x,y,z)
+        "floating_base_vel": np.zeros((num_envs, 6), dtype=np.float32),
+    }
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+@requires_onnx_model
+class TestOnnxModelDirect:
+    """Test the ONNX model directly (no policy wrapper)."""
+
+    def test_model_loads(self):
+        session = ort.InferenceSession(str(ONNX_MODEL_PATH))
+        assert len(session.get_inputs()) == 12
+        assert len(session.get_outputs()) == 10
+
+    def test_model_input_names(self):
+        session = ort.InferenceSession(str(ONNX_MODEL_PATH))
+        input_names = {inp.name for inp in session.get_inputs()}
+        expected = {
+            "root_link_quat_w",
+            "root_ang_vel_b",
+            "velocity_commands",
+            "joint_pos",
+            "joint_vel",
+            "last_actions",
+            "base_ang_vel_history",
+            "projected_gravity_history",
+            "velocity_commands_history",
+            "controlled_joint_pos_history",
+            "controlled_joint_vel_history",
+            "actions_history",
+        }
+        assert input_names == expected
+
+    def test_model_output_names(self):
+        session = ort.InferenceSession(str(ONNX_MODEL_PATH))
+        output_names = {out.name for out in session.get_outputs()}
+        expected = {
+            "action_joint_pos",
+            "action_joint_pos_kp_gains",
+            "action_joint_pos_kd_gains",
+            "last_actions_out",
+            "base_ang_vel_history_out",
+            "projected_gravity_history_out",
+            "velocity_commands_history_out",
+            "controlled_joint_pos_history_out",
+            "controlled_joint_vel_history_out",
+            "actions_history_out",
+        }
+        assert output_names == expected
+
+    def test_model_inference_with_zeros(self):
+        """Run the model with all-zero inputs and verify output shapes."""
+        session = ort.InferenceSession(str(ONNX_MODEL_PATH))
+        inputs = {
+            "root_link_quat_w": np.array([[1.0, 0.0, 0.0, 0.0]], dtype=np.float32),
+            "root_ang_vel_b": np.zeros((1, 3), dtype=np.float32),
+            "velocity_commands": np.zeros((1, 3), dtype=np.float32),
+            "joint_pos": np.zeros((1, 29), dtype=np.float32),
+            "joint_vel": np.zeros((1, 29), dtype=np.float32),
+            "last_actions": np.zeros((1, 14), dtype=np.float32),
+            "base_ang_vel_history": np.zeros((1, 5, 3), dtype=np.float32),
+            "projected_gravity_history": np.zeros((1, 5, 3), dtype=np.float32),
+            "velocity_commands_history": np.zeros((1, 5, 3), dtype=np.float32),
+            "controlled_joint_pos_history": np.zeros((1, 5, 14), dtype=np.float32),
+            "controlled_joint_vel_history": np.zeros((1, 5, 14), dtype=np.float32),
+            "actions_history": np.zeros((1, 5, 14), dtype=np.float32),
+        }
+
+        output_names = [out.name for out in session.get_outputs()]
+        outputs = session.run(output_names, inputs)
+        result = dict(zip(output_names, outputs))
+
+        assert result["action_joint_pos"].shape == (1, 14)
+        assert result["action_joint_pos_kp_gains"].shape == (1, 14)
+        assert result["action_joint_pos_kd_gains"].shape == (1, 14)
+        assert result["last_actions_out"].shape == (1, 14)
+        assert result["base_ang_vel_history_out"].shape == (1, 5, 3)
+        assert result["actions_history_out"].shape == (1, 5, 14)
+
+        # Actions should be finite
+        assert np.all(np.isfinite(result["action_joint_pos"]))
+        # Gains should be positive
+        assert np.all(result["action_joint_pos_kp_gains"] > 0)
+        assert np.all(result["action_joint_pos_kd_gains"] > 0)
+
+
+class TestJointMappings:
+    """Test that the joint ordering mappings are correct."""
+
+    def test_agile_config_loads(self):
+        with open(AGILE_CONFIG_PATH) as f:
+            config = yaml.safe_load(f)
+        assert len(config["onnx_input_joint_names"]) == 29
+        assert len(config["controlled_joint_names"]) == 14
+
+    def test_wbc_to_agile_input_mapping(self):
+        """Verify the input mapping selects correct joints from WBC order."""
+        with open(WBC_JOINTS_ORDER_PATH) as f:
+            wbc_order = yaml.safe_load(f)
+        with open(AGILE_CONFIG_PATH) as f:
+            config = yaml.safe_load(f)
+
+        mapping = [wbc_order[name] for name in config["onnx_input_joint_names"]]
+        # All indices should be valid (0-42)
+        assert all(0 <= idx <= 42 for idx in mapping)
+        # Should have 29 unique indices (one per body joint)
+        assert len(set(mapping)) == 29
+        # First entry is left_hip_pitch at WBC index 0
+        assert mapping[0] == 0
+        # Second entry is right_hip_pitch at WBC index 6
+        assert mapping[1] == 6
+
+    def test_agile_output_to_lower_body_mapping(self):
+        """Verify the output mapping covers all lower body joints except waist_yaw."""
+        with open(WBC_JOINTS_ORDER_PATH) as f:
+            wbc_order = yaml.safe_load(f)
+        with open(AGILE_CONFIG_PATH) as f:
+            config = yaml.safe_load(f)
+
+        lower_body_indices = list(range(15))
+        output_mapping = []
+        for name in config["controlled_joint_names"]:
+            wbc_idx = wbc_order[name]
+            lb_pos = lower_body_indices.index(wbc_idx)
+            output_mapping.append(lb_pos)
+
+        # Should cover 14 of the 15 lower body positions
+        assert len(set(output_mapping)) == 14
+        # waist_yaw is at lower_body position 12 and should NOT be in the mapping
+        assert 12 not in output_mapping  # waist_yaw_joint = WBC index 12 = lb pos 12
+
+
+@requires_onnx_model
+class TestG1AgilePolicy:
+    """Test the full G1AgilePolicy class."""
+
+    @pytest.fixture
+    def policy(self):
+        robot_model = MockRobotModel()
+        return _create_policy(robot_model, num_envs=1)
+
+    @pytest.fixture
+    def policy_multi_env(self):
+        robot_model = MockRobotModel()
+        return _create_policy(robot_model, num_envs=3)
+
+    def test_init(self, policy):
+        assert policy.num_envs == 1
+        assert policy.num_lower_body == 15
+        assert len(policy.wbc_to_agile_input) == 29
+        assert len(policy.agile_output_to_lower_body) == 14
+
+    def test_get_action_shape(self, policy):
+        obs = make_observation(num_envs=1)
+        policy.set_observation(obs)
+        action = policy.get_action()
+        assert "body_action" in action
+        assert action["body_action"].shape == (1, 15)
+
+    def test_get_action_finite(self, policy):
+        obs = make_observation(num_envs=1)
+        policy.set_observation(obs)
+        action = policy.get_action()
+        assert np.all(np.isfinite(action["body_action"]))
+
+    def test_get_action_waist_yaw_zero(self, policy):
+        """waist_yaw (lower_body position 12) should always be 0."""
+        obs = make_observation(num_envs=1)
+        policy.set_observation(obs)
+        action = policy.get_action()
+        assert action["body_action"][0, 12] == 0.0
+
+    def test_get_action_multi_env(self, policy_multi_env):
+        obs = make_observation(num_envs=3)
+        policy_multi_env.set_observation(obs)
+        action = policy_multi_env.get_action()
+        assert action["body_action"].shape == (3, 15)
+        assert np.all(np.isfinite(action["body_action"]))
+
+    def test_set_goal_navigate_cmd(self, policy):
+        cmd = np.array([[0.5, 0.0, 0.1]], dtype=np.float32)
+        policy.set_goal({"navigate_cmd": cmd})
+        np.testing.assert_array_equal(policy.cmd, cmd)
+
+    def test_reset(self, policy):
+        """After reset, state should be zeroed and action still valid."""
+        import torch
+
+        obs = make_observation(num_envs=1)
+        policy.set_observation(obs)
+        policy.get_action()  # populate state
+
+        policy.reset(torch.tensor([0]))
+        for name, value in policy.states[0].items():
+            assert not np.any(value), f"state '{name}' was not zeroed by reset"
+
+        policy.set_observation(obs)
+        action = policy.get_action()
+        assert np.all(np.isfinite(action["body_action"]))
+
+    def test_multiple_steps(self, policy):
+        """Run multiple steps to verify feedback state propagation."""
+        obs = make_observation(num_envs=1)
+        policy.set_goal({"navigate_cmd": np.array([[0.3, 0.0, 0.0]], dtype=np.float32)})
+
+        for _ in range(5):
+            policy.set_observation(obs)
+            action = policy.get_action()
+            assert np.all(np.isfinite(action["body_action"]))
+
+        # After several steps with non-zero command, state should be non-trivial
+        state = policy.states[0]
+        assert not np.allclose(state["last_actions"], 0.0)
+
+    def test_no_observation_raises(self, policy):
+        with pytest.raises(ValueError, match="No observation set"):
+            policy.get_action()
+
+
+# ---------------------------------------------------------------------------
+# Helper to create policy (defers the isaaclab_arena_g1 import until a policy is needed)
+# ---------------------------------------------------------------------------
+def _create_policy(robot_model, num_envs):
+    """Create a G1AgilePolicy, importing its module lazily."""
+    import sys
+
+    # Put the directory containing the isaaclab_arena_g1 package on sys.path
+    arena_root = WBC_POLICY_DIR.parent.parent.parent
+    if str(arena_root) not in sys.path:
+        sys.path.insert(0, str(arena_root))
+
+    from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.g1_agile_policy import G1AgilePolicy
+
+    return G1AgilePolicy(
+        robot_model=robot_model,
+        config_path="config/g1_agile.yaml",
+        model_path="models/agile/unitree_g1_velocity_e2e.onnx",
+        num_envs=num_envs,
+    )
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])