From 98d5fc37a0a09a360e27d67b3bd3139ff1f2b3ef Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Mon, 9 Mar 2026 01:03:30 -0700 Subject: [PATCH 1/7] Add warp manager-based env infrastructure Add experimental warp-compatible manager implementations, MDP terms, utilities (buffers, modifiers, noise, warp kernels), ManagerCallSwitch for eager/captured dispatch, and manager-based env orchestration. Includes RL library wrapper updates (rsl_rl, rl_games, sb3, skrl) to accept warp env types, and minor stable fixes (settings_manager RuntimeError handling, observation_manager comment cleanup). --- .../isaaclab/managers/observation_manager.py | 6 +- .../isaaclab_experimental/envs/__init__.py | 4 + .../envs/manager_based_env_warp.py | 723 +++++++++++++++ .../envs/manager_based_rl_env_warp.py | 625 +++++++++++++ .../envs/mdp/__init__.py | 26 + .../envs/mdp/actions/__init__.py | 13 + .../envs/mdp/actions/actions_cfg.py | 42 + .../envs/mdp/actions/joint_actions.py | 282 ++++++ .../isaaclab_experimental/envs/mdp/events.py | 212 +++++ .../envs/mdp/observations.py | 106 +++ .../isaaclab_experimental/envs/mdp/rewards.py | 95 ++ .../envs/mdp/terminations.py | 82 ++ .../envs/utils/io_descriptors.py | 301 ++++++ .../managers/__init__.py | 22 + .../managers/action_manager.py | 506 ++++++++++ .../managers/command_manager.py | 599 ++++++++++++ .../managers/event_manager.py | 499 ++++++++++ .../managers/manager_base.py | 446 +++++++++ .../managers/manager_term_cfg.py | 94 ++ .../managers/observation_manager.py | 862 ++++++++++++++++++ .../managers/reward_manager.py | 419 +++++++++ .../managers/scene_entity_cfg.py | 54 ++ .../managers/termination_manager.py | 355 ++++++++ .../utils/buffers/__init__.py | 12 + .../utils/buffers/circular_buffer.py | 194 ++++ .../utils/manager_call_switch.py | 230 +++++ .../utils/modifiers/__init__.py | 21 + .../utils/modifiers/modifier.py | 80 ++ .../utils/modifiers/modifier_base.py | 63 ++ .../utils/modifiers/modifier_cfg.py | 39 + .../utils/noise/__init__.py | 
22 + .../utils/noise/noise_cfg.py | 72 ++ .../utils/noise/noise_model.py | 200 ++++ .../utils/torch_utils.py | 32 + .../utils/warp/__init__.py | 9 + .../utils/warp/kernels.py | 45 + .../isaaclab_experimental/utils/warp/utils.py | 192 ++++ .../isaaclab_rl/rl_games/rl_games.py | 23 +- .../isaaclab_rl/rsl_rl/vecenv_wrapper.py | 17 +- source/isaaclab_rl/isaaclab_rl/sb3.py | 23 +- source/isaaclab_rl/isaaclab_rl/skrl.py | 18 +- 41 files changed, 7652 insertions(+), 13 deletions(-) create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/command_manager.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py create mode 100644 
source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/manager_term_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/managers/termination_manager.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/buffers/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/buffers/circular_buffer.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_base.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/noise/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_cfg.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_model.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/torch_utils.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/warp/kernels.py create mode 100644 source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py diff --git a/source/isaaclab/isaaclab/managers/observation_manager.py 
b/source/isaaclab/isaaclab/managers/observation_manager.py index a1bde0266f4..8c4c6996873 100644 --- a/source/isaaclab/isaaclab/managers/observation_manager.py +++ b/source/isaaclab/isaaclab/managers/observation_manager.py @@ -512,6 +512,7 @@ def _prepare_terms(self): # read common config for the group self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms + # to account for the batch dimension self._group_obs_concatenate_dim[group_name] = ( group_cfg.concatenate_dim + 1 if group_cfg.concatenate_dim >= 0 else group_cfg.concatenate_dim ) @@ -550,7 +551,7 @@ def _prepare_terms(self): if group_cfg.history_length is not None: term_cfg.history_length = group_cfg.history_length term_cfg.flatten_history_dim = group_cfg.flatten_history_dim - # add term config to list to list + # add term config to list self._group_obs_term_names[group_name].append(term_name) self._group_obs_term_cfgs[group_name].append(term_cfg) @@ -604,6 +605,9 @@ def _prepare_terms(self): f" Received: {mod_cfg.func}" ) + # TODO(jichuanh): improvement can be made in two ways: + # 1. modifier specific check can be done in the modifier class + # 2. 
general param vs function matching check can be a common utility # check if term's arguments are matched by params term_params = list(mod_cfg.params.keys()) args = inspect.signature(mod_cfg.func).parameters diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py index 81c59dda7d5..fef4091748a 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/__init__.py @@ -44,8 +44,12 @@ from .direct_rl_env_warp import DirectRLEnvWarp # noqa: F401 from .interactive_scene_warp import InteractiveSceneWarp # noqa: F401 +from .manager_based_env_warp import ManagerBasedEnvWarp # noqa: F401 +from .manager_based_rl_env_warp import ManagerBasedRLEnvWarp # noqa: F401 __all__ = [ "DirectRLEnvWarp", "InteractiveSceneWarp", + "ManagerBasedEnvWarp", + "ManagerBasedRLEnvWarp", ] diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py new file mode 100644 index 00000000000..7f747205e96 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -0,0 +1,723 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental manager-based base environment. + +This is a local copy of :class:`isaaclab.envs.ManagerBasedEnv` placed under +``isaaclab_experimental`` so we can evolve the manager-based workflow for Warp-first +pipelines without depending on (or subclassing) the stable env implementation. + +Behavior is intended to match the stable environment initially. 
@wp.kernel
def initialize_rng_state(
    # input
    seed: wp.int32,
    # output
    state: wp.array(dtype=wp.uint32),
):
    """Fill ``state`` with freshly initialized Warp RNG states.

    Each thread writes the entry of ``state`` at its own thread index, using
    ``seed`` together with that index as the arguments of ``wp.rand_init``.

    Args:
        seed: Base seed shared by all threads.
        state: Output array receiving one RNG state per thread index.
    """
    tid = wp.tid()
    # the thread index doubles as the per-entry offset of the random stream
    state[tid] = wp.rand_init(seed, wp.int32(tid))
+ """ + # check that the config is valid + cfg.validate() + # store inputs to class + self.cfg = cfg + # initialize internal variables + self._is_closed = False + self._manager_call_switch = ManagerCallSwitch() + self._apply_manager_term_cfg_profile() + + # set the seed for the environment + if self.cfg.seed is not None: + self.cfg.seed = self.seed(self.cfg.seed) + else: + logger.warning("Seed not set for the environment. The environment creation may not be deterministic.") + + # create a simulation context to control the simulator + if SimulationContext.instance() is None: + # the type-annotation is required to avoid a type-checking error + # since it gets confused with Isaac Sim's SimulationContext class + self.sim: SimulationContext = SimulationContext(self.cfg.sim) + else: + # simulation context should only be created before the environment + # when in extension mode + # if not builtins.ISAAC_LAUNCHED_FROM_TERMINAL: + # raise RuntimeError("Simulation context already exists. Cannot create a new one.") + self.sim: SimulationContext = SimulationContext.instance() + + # make sure torch is running on the correct device + if "cuda" in self.device: + torch.cuda.set_device(self.device) + + # print useful information + print("[INFO]: Base environment:") + print(f"\tEnvironment device : {self.device}") + print(f"\tEnvironment seed : {self.cfg.seed}") + print(f"\tPhysics step-size : {self.physics_dt}") + print(f"\tRendering step-size : {self.physics_dt * self.cfg.sim.render_interval}") + print(f"\tEnvironment step-size : {self.step_dt}") + + if self.cfg.sim.render_interval < self.cfg.decimation: + msg = ( + f"The render interval ({self.cfg.sim.render_interval}) is smaller than the decimation " + f"({self.cfg.decimation}). Multiple render calls will happen for each environment step. " + "If this is not intended, set the render interval to be equal to the decimation." 
+ ) + logger.warning(msg) + + # counter for simulation steps + self._sim_step_counter = 0 + + # allocate dictionary to store metrics + self.extras = {} + + # generate scene + with Timer("[INFO]: Time taken for scene creation", "scene_creation"): + # set the stage context for scene creation steps which use the stage + with use_stage(self.sim.stage): + self.scene = InteractiveScene(self.cfg.scene) + # attach_stage_to_usd_context() + print("[INFO]: Scene manager: ", self.scene) + + # Shared per-env Warp RNG state (accessible to all managers/terms via `env`). + # This is a single stream per env (no lookup) and is initialized once when `num_envs` is known. + self.rng_state_wp = wp.zeros((self.num_envs,), dtype=wp.uint32, device=self.device) + seed_val = int(self.cfg.seed) if self.cfg.seed is not None else -1 + wp.launch( + kernel=initialize_rng_state, + dim=self.num_envs, + inputs=[seed_val, self.rng_state_wp], + device=self.device, + ) + + # TODO(jichuanh): this is problematic as warp capture requires stable pointers, + # using different masks for different managers/terms will cause problems. + # Pre-allocated env masks (shared across managers/terms via `env`). + self.ALL_ENV_MASK = wp.ones((self.num_envs,), dtype=wp.bool, device=self.device) + self.ENV_MASK = wp.zeros((self.num_envs,), dtype=wp.bool, device=self.device) + + # Persistent scalar buffer for global env step count (stable pointer for capture). + self._global_env_step_count_wp = wp.zeros((1,), dtype=wp.int32, device=self.device) + + # set up camera viewport controller + # viewport is not available in other rendering modes so the function will throw a warning + # FIXME: This needs to be fixed in the future when we unify the UI functionalities even for + # non-rendering modes. 
+ viz_str = self.sim.get_setting("/isaaclab/visualizer") or "" + available_visualizers = [v.strip() for v in viz_str.split(",") if v.strip()] + if "kit" in available_visualizers and bool(viz_str): + self.viewport_camera_controller = ViewportCameraController(self, self.cfg.viewer) + else: + self.viewport_camera_controller = None + + # create event manager + # note: this is needed here (rather than after simulation play) to allow USD-related randomization events + # that must happen before the simulation starts. Example: randomizing mesh scale + self.event_manager = self._manager_call_switch.resolve_manager_class("EventManager")(self.cfg.events, self) + + # apply USD-related randomization events + if "prestartup" in self.event_manager.available_modes: + self.event_manager.apply(mode="prestartup") + + # play the simulator to activate physics handles + # note: this activates the physics simulation view that exposes TensorAPIs + # note: when started in extension mode, first call sim.reset_async() and then initialize the managers + # if builtins.ISAAC_LAUNCHED_FROM_TERMINAL is False: + print("[INFO]: Starting the simulation. This may take a few seconds. Please wait...") + with Timer("[INFO]: Time taken for simulation start", "simulation_start"): + # since the reset can trigger callbacks which use the stage, + # we need to set the stage context here + with use_stage(self.sim.stage): + self.sim.reset() + # update scene to pre populate data buffers for assets and sensors. + # this is needed for the observation manager to get valid tensors for initialization. + # this shouldn't cause an issue since later on, users do a reset over all the environments so the lazy + # buffers would be reset. + self.scene.update(dt=self.physics_dt) + + # TODO(jichuanh): This is a temporary solution for event_manager only, but it should be general for all managers + # Resolve SceneEntityCfg-dependent term params once before any captured event paths. 
+ if (not self.event_manager._is_scene_entities_resolved) and self.sim.is_playing(): + self.event_manager._resolve_terms_callback(None) + + # add timeline event to load managers + self.load_managers() + + # extend UI elements + # we need to do this here after all the managers are initialized + # this is because they dictate the sensors and commands right now + if self.sim.has_gui and self.cfg.ui_window_class_type is not None: + # setup live visualizers + self.setup_manager_visualizers() + self._window = self.cfg.ui_window_class_type(self, window_name="IsaacLab") + else: + # if no window, then we don't need to store the window + self._window = None + + # initialize observation buffers + self.obs_buf = {} + + # export IO descriptors if requested + if self.cfg.export_io_descriptors: + self.export_IO_descriptors() + + # show deprecation message for rerender_on_reset + if self.cfg.rerender_on_reset: + msg = ( + "\033[93m\033[1m[DEPRECATION WARNING] ManagerBasedEnvCfg.rerender_on_reset is deprecated. Use" + " ManagerBasedEnvCfg.num_rerenders_on_reset instead.\033[0m" + ) + warnings.warn( + msg, + FutureWarning, + stacklevel=2, + ) + if self.cfg.num_rerenders_on_reset == 0: + self.cfg.num_rerenders_on_reset = 1 + + def __del__(self): + """Cleanup for the environment.""" + # Suppress errors during Python shutdown to avoid noisy tracebacks + # Note: contextlib may be None during interpreter shutdown + if contextlib is not None: + with contextlib.suppress(ImportError, AttributeError, TypeError): + self.close() + + """ + Properties. + """ + + @property + def num_envs(self) -> int: + """The number of instances of the environment that are running.""" + return self.scene.num_envs + + @property + def physics_dt(self) -> float: + """The physics time-step (in s). + + This is the lowest time-decimation at which the simulation is happening. + """ + return self.cfg.sim.dt + + @property + def step_dt(self) -> float: + """The environment stepping time-step (in s). 
+ + This is the time-step at which the environment steps forward. + """ + return self.cfg.sim.dt * self.cfg.decimation + + @property + def device(self): + """The device on which the environment is running.""" + return self.sim.device + + def resolve_env_mask( + self, + *, + env_ids: Sequence[int] | slice | wp.array | torch.Tensor | None = None, + env_mask: wp.array | torch.Tensor | None = None, + ) -> wp.array: + """Resolve environment ids/mask into a Warp boolean mask of shape ``(num_envs,)``.""" + return resolve_1d_mask( + ids=env_ids, + mask=env_mask, + all_mask=self.ALL_ENV_MASK, + scratch_mask=self.ENV_MASK, + device=self.device, + ) + + @property + def get_IO_descriptors(self): + """Get the IO descriptors for the environment. + + Returns: + A dictionary with keys as the group names and values as the IO descriptors. + """ + return { + "observations": self.observation_manager.get_IO_descriptors, + "actions": self.action_manager.get_IO_descriptors, + "articulations": export_articulations_data(self), + "scene": export_scene_data(self), + } + + def export_IO_descriptors(self, output_dir: str | None = None): + """Export the IO descriptors for the environment. + + Args: + output_dir: The directory to export the IO descriptors to. + """ + import os + + import yaml + + IO_descriptors = self.get_IO_descriptors + + if output_dir is None: + if self.cfg.log_dir is not None: + output_dir = os.path.join(self.cfg.log_dir, "io_descriptors") + else: + raise ValueError( + "Output directory is not set. Please set the log directory using the `log_dir`" + " configuration or provide an explicit output_dir parameter." + ) + + if not os.path.exists(output_dir): + os.makedirs(output_dir, exist_ok=True) + + with open(os.path.join(output_dir, "IO_descriptors.yaml"), "w") as f: + print(f"[INFO]: Exporting IO descriptors to {os.path.join(output_dir, 'IO_descriptors.yaml')}") + yaml.safe_dump(IO_descriptors, f) + + """ + Operations - Setup. 
def load_managers(self):
    """Create the recorder, action and observation managers of the environment.

    The managers require access to physics handles, so they can only be created
    after the simulator has been reset (i.e. played for the first time). The event
    manager is not created here; it is only printed so that the log lists all managers.

    .. note::
        In a standalone application (simulator run from Python) this function is called
        automatically when the class is initialized.

        In extension mode the user must call this function manually after the simulator
        is reset, because the reset happens through
        :meth:`SimulationContext.reset_async` and async functions cannot be called
        in the constructor.
    """
    resolve_cls = self._manager_call_switch.resolve_manager_class

    # -- event manager (printed here to keep the logging consistent)
    print("[INFO] Event Manager: ", self.event_manager)

    # -- recorder manager
    recorder_cls = resolve_cls("RecorderManager")
    self.recorder_manager = recorder_cls(self.cfg.recorders, self)
    print("[INFO] Recorder Manager: ", self.recorder_manager)

    # -- action manager
    action_cls = resolve_cls("ActionManager")
    self.action_manager = action_cls(self.cfg.actions, self)
    print("[INFO] Action Manager: ", self.action_manager)

    # -- observation manager
    observation_cls = resolve_cls("ObservationManager")
    self.observation_manager = observation_cls(self.cfg.observations, self)
    print("[INFO] Observation Manager:", self.observation_manager)

    # startup events are applied here only for this exact class: a child
    # implementation may create further managers, and the randomization
    # should happen once all of them exist
    is_exact_base_class = self.__class__ == ManagerBasedEnvWarp
    if is_exact_base_class and "startup" in self.event_manager.available_modes:
        self.event_manager.apply(mode="startup")
self.manager_visualizers = { + "action_manager": ManagerLiveVisualizer(manager=self.action_manager), + "observation_manager": ManagerLiveVisualizer(manager=self.observation_manager), + } + + """ + Operations - MDP. + """ + + def reset( + self, seed: int | None = None, env_ids: Sequence[int] | None = None, options: dict[str, Any] | None = None + ) -> tuple[VecEnvObs, dict]: + """Resets the specified environments and returns observations. + + This function calls the :meth:`_reset_idx` function to reset the specified environments. + However, certain operations, such as procedural terrain generation, that happened during initialization + are not repeated. + + Args: + seed: The seed to use for randomization. Defaults to None, in which case the seed is not set. + env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset. + options: Additional information to specify how the environment is reset. Defaults to None. + + Note: + This argument is used for compatibility with Gymnasium environment definition. + + Returns: + A tuple containing the observations and extras. 
+ """ + if env_ids is None: + env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device) + + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(env_ids) + + # set the seed + if seed is not None: + used_seed = self.seed(seed) + # keep cfg seed in sync for downstream users + self.cfg.seed = used_seed + # re-initialize per-env Warp RNG state without reallocating (stable pointer for capture) + wp.launch( + kernel=initialize_rng_state, + dim=self.num_envs, + inputs=[int(used_seed), self.rng_state_wp], + device=self.device, + ) + + # reset state of scene + self._reset_idx(env_ids) + + # update articulation kinematics + self.scene.write_data_to_sim() + self.sim.forward() + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.settings.get("/isaaclab/render/rtx_sensors") and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(env_ids) + + # compute observations + self.obs_buf = self.observation_manager.compute(update_history=True) + + # return observations + return self.obs_buf, self.extras + + def reset_to( + self, + state: dict[str, dict[str, dict[str, torch.Tensor]]], + env_ids: Sequence[int] | None, + seed: int | None = None, + is_relative: bool = False, + ): + """Resets specified environments to provided states. + + This function resets the environments to the provided states. The state is a dictionary + containing the state of the scene entities. Please refer to :meth:`InteractiveScene.get_state` + for the format. + + The function is different from the :meth:`reset` function as it resets the environments to specific states, + instead of using the randomization events for resetting the environments. + + Args: + state: The state to reset the specified environments to. 
Please refer to + :meth:`InteractiveScene.get_state` for the format. + env_ids: The environment ids to reset. Defaults to None, in which case all environments are reset. + seed: The seed to use for randomization. Defaults to None, in which case the seed is not set. + is_relative: If set to True, the state is considered relative to the environment origins. + Defaults to False. + """ + # reset all envs in the scene if env_ids is None + if env_ids is None: + env_ids = torch.arange(self.num_envs, dtype=torch.int64, device=self.device) + + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(env_ids) + + # set the seed + if seed is not None: + self.seed(seed) + + self._reset_idx(env_ids) + + # set the state + self.scene.reset_to(state, env_ids, is_relative=is_relative) + + # update articulation kinematics + self.sim.forward() + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.settings.get("/isaaclab/render/rtx_sensors") and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(env_ids) + + # compute observations + self.obs_buf = self.observation_manager.compute(update_history=True) + + # return observations + return self.obs_buf, self.extras + + def step(self, action: torch.Tensor) -> tuple[VecEnvObs, dict]: + """Execute one time-step of the environment's dynamics. + + The environment steps forward at a fixed time-step, while the physics simulation is + decimated at a lower time-step. This is to ensure that the simulation is stable. These two + time-steps can be configured independently using the :attr:`ManagerBasedEnvCfg.decimation` (number of + simulation steps per environment step) and the :attr:`ManagerBasedEnvCfg.sim.dt` (physics time-step) + parameters. 
Based on these parameters, the environment time-step is computed as the product of the two. + + Args: + action: The actions to apply on the environment. Shape is (num_envs, action_dim). + + Returns: + A tuple containing the observations and extras. + """ + # process actions + action_device = action.to(self.device) + if action_device.dtype != torch.float32: + action_device = action_device.float() + if not action_device.is_contiguous(): + action_device = action_device.contiguous() + action_wp = wp.from_torch(action_device, dtype=wp.float32) + self.action_manager.process_action(action_wp) + + self.recorder_manager.record_pre_step() + + # check if we need to do rendering within the physics loop + # note: checked here once to avoid multiple checks within the loop + is_rendering = bool(self.sim.settings.get("/isaaclab/visualizer")) or self.sim.settings.get( + "/isaaclab/render/rtx_sensors" + ) + + # perform physics stepping + for _ in range(self.cfg.decimation): + self._sim_step_counter += 1 + # set actions into buffers + self.action_manager.apply_action() + # set actions into simulator + self.scene.write_data_to_sim() + # simulate + self.sim.step(render=False) + # render between steps only if the GUI or an RTX sensor needs it + # note: we assume the render interval to be the shortest accepted rendering interval. + # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior. 
def close(self):
    """Cleanup for the environment.

    Releases the managers and the scene, clears the simulation context instance and
    drops the UI window. Calling it again afterwards is a no-op.

    The method tolerates a partially constructed environment (e.g. when the constructor
    raised before all managers were created): attributes that were never set are skipped,
    so the simulation instance is still cleared and the environment is still marked closed.
    """
    # nothing to do if already closed, or if construction failed before any state was set
    if getattr(self, "_is_closed", True):
        return

    # destructor is order-sensitive
    for attr_name in (
        "viewport_camera_controller",
        "action_manager",
        "observation_manager",
        "event_manager",
        "recorder_manager",
        "scene",
    ):
        if hasattr(self, attr_name):
            delattr(self, attr_name)

    # self.sim.clear_all_callbacks()
    sim = getattr(self, "sim", None)
    if sim is not None:
        sim.clear_instance()

    # destroy the window
    self._window = None
    # update closing status
    self._is_closed = True
def _apply_manager_term_cfg_profile(self) -> None:
    """Replace term configs with their stable counterparts for STABLE managers.

    For every manager whose call mode is ``ManagerCallMode.STABLE``, the matching
    attribute of ``self.cfg`` is overwritten with a deep copy of the same attribute
    taken from the stable task config counterpart. Nothing is changed when no manager
    is in STABLE mode, or when no stable counterpart can be resolved (a warning is
    logged in the latter case).
    """
    cfg_attr_by_manager = {
        "ActionManager": "actions",
        "ObservationManager": "observations",
        "EventManager": "events",
        "RecorderManager": "recorders",
        "CommandManager": "commands",
        "TerminationManager": "terminations",
        "RewardManager": "rewards",
        "CurriculumManager": "curriculum",
    }
    mode_of = self._manager_call_switch.get_mode_for_manager

    requested = [name for name in cfg_attr_by_manager if mode_of(name) == ManagerCallMode.STABLE]
    if len(requested) == 0:
        return

    stable_cfg = self._resolve_stable_cfg_counterpart()
    if stable_cfg is None:
        logger.warning(
            "Stable managers requested (%s), but no stable cfg counterpart could be resolved."
            " Keeping experimental term configs.",
            ", ".join(requested),
        )
        return

    swapped: list[str] = []
    for name, attr in cfg_attr_by_manager.items():
        # the mode is queried again per manager, as in the first pass
        if mode_of(name) != ManagerCallMode.STABLE:
            continue
        # only swap sub-configs that exist on both the experimental and the stable cfg
        if hasattr(self.cfg, attr) and hasattr(stable_cfg, attr):
            setattr(self.cfg, attr, deepcopy(getattr(stable_cfg, attr)))
            swapped.append(f"{name} -> cfg.{attr}")

    if not swapped:
        return
    print("[INFO] Applied stable term config profile for managers:")
    for entry in swapped:
        print(f" - {entry}")
+ self.extras["log"] = dict() + env_mask = self.resolve_env_mask(env_ids=env_ids) + # -- observation manager + info = self.observation_manager.reset(env_mask=env_mask) + self.extras["log"].update(info) + # -- action manager + info = self.action_manager.reset(env_mask=env_mask) + self.extras["log"].update(info) + # -- event manager + info = self.event_manager.reset(env_mask=env_mask) + self.extras["log"].update(info) + # -- recorder manager + info = self.recorder_manager.reset(env_ids) + self.extras["log"].update(info) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py new file mode 100644 index 00000000000..8589a47dbda --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -0,0 +1,625 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental manager-based RL environment (Warp entry point). + +This module provides an experimental fork of the stable manager-based RL environment +so it can diverge (Warp-first / graph-friendly) without inheriting from the stable +`isaaclab.envs.ManagerBasedRLEnv` implementation. 
+""" + +# needed to import for allowing type-hinting: np.ndarray | None +from __future__ import annotations + +import math +import os +from collections.abc import Sequence +from typing import Any, ClassVar + +import gymnasium as gym +import numpy as np +import torch +import warp as wp + +from isaaclab.envs.common import VecEnvStepReturn +from isaaclab.envs.manager_based_rl_env_cfg import ManagerBasedRLEnvCfg +from isaaclab.ui.widgets import ManagerLiveVisualizer +from isaaclab.utils.timer import Timer + +from isaaclab_experimental.utils.manager_call_switch import ManagerCallMode +from isaaclab_experimental.utils.torch_utils import clone_obs_buffer + +from .manager_based_env_warp import ManagerBasedEnvWarp + +DEBUG_TIMER_STEP = os.environ.get("DEBUG_TIMER_STEP", "0") == "1" +"""Enable outer step() timer. Set DEBUG_TIMER_STEP=1 env var to enable.""" + +DEBUG_TIMERS = os.environ.get("DEBUG_TIMERS", "0") == "1" +"""Enable all fine-grained inner timers. Set DEBUG_TIMERS=1 env var to enable.""" + +TIMER_ENABLED_STEP = DEBUG_TIMER_STEP or DEBUG_TIMERS +TIMER_ENABLED_RESET_IDX = DEBUG_TIMERS + + +class ManagerBasedRLEnvWarp(ManagerBasedEnvWarp, gym.Env): + """The superclass for the manager-based workflow reinforcement learning-based environments. + + This class inherits from :class:`ManagerBasedEnvWarp` and implements the core functionality for + reinforcement learning-based environments. It is designed to be used with any RL + library. The class is designed to be used with vectorized environments, i.e., the + environment is expected to be run in parallel with multiple sub-environments. The + number of sub-environments is specified using the ``num_envs``. + + Each observation from the environment is a batch of observations, one for each + sub-environment. The method :meth:`step` is also expected to receive a batch of actions + for each sub-environment.
+ + While the environment itself is implemented as a vectorized environment, we do not + inherit from :class:`gym.vector.VectorEnv`. This is mainly because the class adds + various methods (for wait and asynchronous updates) which are not required. + Additionally, each RL library typically has its own definition for a vectorized + environment. Thus, to reduce complexity, we directly use the :class:`gym.Env` over + here and leave it up to library-defined wrappers to take care of wrapping this + environment for their agents. + + Note: + For vectorized environments, it is recommended to **only** call the :meth:`reset` + method once before the first call to :meth:`step`, i.e. after the environment is created. + After that, the :meth:`step` function handles the reset of terminated sub-environments. + This is because the simulator does not support resetting individual sub-environments + in a vectorized environment. + + """ + + is_vector_env: ClassVar[bool] = True + """Whether the environment is a vectorized environment.""" + metadata: ClassVar[dict[str, Any]] = { + "render_modes": [None, "human", "rgb_array"], + # "isaac_sim_version": get_version(), + } + """Metadata for the environment.""" + + cfg: ManagerBasedRLEnvCfg + """Configuration for the environment.""" + + def __init__(self, cfg: ManagerBasedRLEnvCfg, render_mode: str | None = None, **kwargs): + """Initialize the environment. + + Args: + cfg: The configuration for the environment. + render_mode: The render mode for the environment. Defaults to None, which + is similar to ``"human"``. + """ + # -- counter for curriculum + self.common_step_counter = 0 + + # initialize the episode length buffer BEFORE loading the managers to use it in mdp functions. + # Warp array is the source of truth; torch view is zero-copy for += and indexed assignment. 
+ self._episode_length_buf_wp = wp.zeros(cfg.scene.num_envs, dtype=wp.int64, device=cfg.sim.device) + self._episode_length_buf = wp.to_torch(self._episode_length_buf_wp) + + # initialize the base class to setup the scene. + super().__init__(cfg=cfg) + # store the render mode + self.render_mode = render_mode + + # The persistent reset mask needed for warp capture + # The intended use is to copy into this mask whenever capture is needed + # TODO: termination manager provides the same mask, investigate whether this can be replaced. + self.reset_mask_wp = wp.zeros(cfg.scene.num_envs, dtype=wp.bool, device=cfg.sim.device) + + # Persistent action input buffer to keep pointer stable for captured graphs. + self._action_in_wp: wp.array = wp.zeros( + (self.num_envs, self.action_manager.total_action_dim), dtype=wp.float32, device=self.device + ) + + # initialize data and constants + # -- set the framerate of the gym video recorder wrapper so that the playback speed + # of the produced video matches the simulation + self.metadata["render_fps"] = 1 / self.step_dt + + print("[INFO]: Completed setting up the environment...") + + """ + Properties. + """ + + @property + def episode_length_buf(self) -> torch.Tensor: + """Episode length buffer (torch view of the underlying warp array).""" + return self._episode_length_buf + + @episode_length_buf.setter + def episode_length_buf(self, value: torch.Tensor): + """Copy into the existing buffer to preserve the warp array linkage.""" + self._episode_length_buf[:] = value + + @property + def max_episode_length_s(self) -> float: + """Maximum episode length in seconds.""" + return self.cfg.episode_length_s + + @property + def max_episode_length(self) -> int: + """Maximum episode length in environment steps.""" + return math.ceil(self.max_episode_length_s / self.step_dt) + + """ + Operations - Setup. 
+ """ + + def load_managers(self): + # note: this order is important since observation manager needs to know the command and action managers + # and the reward manager needs to know the termination manager + # -- command manager + self.command_manager = self._manager_call_switch.resolve_manager_class("CommandManager")( + self.cfg.commands, self + ) + print("[INFO] Command Manager: ", self.command_manager) + + # call the parent class to load the managers for observations and actions. + super().load_managers() + + # prepare the managers + # -- termination manager + self.termination_manager = self._manager_call_switch.resolve_manager_class("TerminationManager")( + self.cfg.terminations, self + ) + print("[INFO] Termination Manager: ", self.termination_manager) + # -- reward manager (experimental fork; Warp-compatible rewards) + self.reward_manager = self._manager_call_switch.resolve_manager_class("RewardManager")(self.cfg.rewards, self) + print("[INFO] Reward Manager: ", self.reward_manager) + # -- curriculum manager + self.curriculum_manager = self._manager_call_switch.resolve_manager_class("CurriculumManager")( + self.cfg.curriculum, self + ) + print("[INFO] Curriculum Manager: ", self.curriculum_manager) + + # setup the action and observation spaces for Gym + self._configure_gym_env_spaces() + + # perform events at the start of the simulation + if "startup" in self.event_manager.available_modes: + self.event_manager.apply(mode="startup") + + def setup_manager_visualizers(self): + """Creates live visualizers for manager terms.""" + + self.manager_visualizers = { + "action_manager": ManagerLiveVisualizer(manager=self.action_manager), + "observation_manager": ManagerLiveVisualizer(manager=self.observation_manager), + "command_manager": ManagerLiveVisualizer(manager=self.command_manager), + "termination_manager": ManagerLiveVisualizer(manager=self.termination_manager), + "reward_manager": ManagerLiveVisualizer(manager=self.reward_manager), + "curriculum_manager": 
ManagerLiveVisualizer(manager=self.curriculum_manager), + } + + """ + Operations - MDP + """ + + def invalidate_wp_graphs(self) -> None: + """Invalidate all cached Warp graphs. + + Call this if the captured launch topology changes (e.g. different term list, shapes, etc.). + """ + self._manager_call_switch.invalidate_graphs() + + def step_warp_termination_compute(self) -> None: + """Captured stage: compute terminations (env-step frequency).""" + self.reset_buf = self.termination_manager.compute() + self.reset_terminated = self.termination_manager.terminated + self.reset_time_outs = self.termination_manager.time_outs + + @Timer(name="env_step", msg="Step took:", enable=TIMER_ENABLED_STEP, time_unit="us") + def step(self, action: torch.Tensor) -> VecEnvStepReturn: + """Execute one time-step of the environment's dynamics and reset terminated environments. + + Unlike the :meth:`ManagerBasedEnv.step` method, this function performs the following operations: + + 1. Process the actions. + 2. Perform physics stepping. + 3. Perform rendering if a visualizer or RTX sensors are enabled. + 4. Update the environment counters and compute the rewards and terminations. + 5. Reset the environments that terminated. + 6. Compute the observations. + 7. Return the observations, rewards, resets and extras. + + Args: + action: The actions to apply on the environment. Shape is (num_envs, action_dim). + + Returns: + A tuple containing the observations, rewards, resets (terminated and truncated) and extras. + """ + # process actions + # NOTE: keep a persistent action input buffer for graph pointer stability. + # IMPORTANT: Do NOT re-wrap/replace the `wp.array` used by captured graphs each step. + # Instead, copy the latest actions into the persistent buffer.
+ with Timer( + name="action_preprocess", msg="Action preprocessing took:", enable=TIMER_ENABLED_STEP, time_unit="us" + ): + assert self._action_in_wp is not None + action_device = action.to(self.device) + wp.copy(self._action_in_wp, wp.from_torch(action_device, dtype=wp.float32)) + + self._manager_call_switch.call_stage( + stage="ActionManager_process_action", + warp_call={"fn": self.action_manager.process_action, "kwargs": {"action": self._action_in_wp}}, + timer=TIMER_ENABLED_STEP, + ) + + self.recorder_manager.record_pre_step() + + # check if we need to do rendering within the physics loop + # note: checked here once to avoid multiple checks within the loop + is_rendering = bool(self.sim.settings.get("/isaaclab/visualizer")) or self.sim.settings.get( + "/isaaclab/render/rtx_sensors" + ) + + # perform physics stepping + for _ in range(self.cfg.decimation): + self._sim_step_counter += 1 + # set actions into buffers + self._manager_call_switch.call_stage( + stage="ActionManager_apply_action", + warp_call={"fn": self.action_manager.apply_action}, + timer=TIMER_ENABLED_STEP, + ) + self._manager_call_switch.call_stage( + stage="Scene_write_data_to_sim", + warp_call={"fn": self.scene.write_data_to_sim}, + timer=TIMER_ENABLED_STEP, + ) + + # simulate + with Timer(name="simulate", msg="Newton simulation step took:", enable=TIMER_ENABLED_STEP, time_unit="us"): + self.sim.step(render=False) + self.recorder_manager.record_post_physics_decimation_step() + # render between steps only if the GUI or an RTX sensor needs it + # note: we assume the render interval to be the shortest accepted rendering interval. + # If a camera needs rendering at a faster frequency, this will lead to unexpected behavior. 
+ if self._sim_step_counter % self.cfg.sim.render_interval == 0 and is_rendering: + self.sim.render() + # update buffers at sim dt + with Timer( + name="scene.update", + msg="Scene.update took:", + enable=TIMER_ENABLED_STEP, + time_unit="us", + ): + self.scene.update(dt=self.physics_dt) + + # post-step: + # -- update env counters (used for curriculum generation) + self.episode_length_buf += 1 # step in current episode (per env) + self.common_step_counter += 1 # total step (common for all envs) + + # -- post-processing (termination + reward) as independently configurable stages + self._manager_call_switch.call_stage( + stage="TerminationManager_compute", + warp_call={"fn": self.step_warp_termination_compute}, + timer=TIMER_ENABLED_STEP, + ) + self.reward_buf = self._manager_call_switch.call_stage( + stage="RewardManager_compute", + warp_call={"fn": self.reward_manager.compute, "kwargs": {"dt": float(self.step_dt)}}, + timer=TIMER_ENABLED_STEP, + ) + + if len(self.recorder_manager.active_terms) > 0: + # update observations for recording if needed + self._manager_call_switch.call_stage( + stage="ObservationManager_compute_no_history", + warp_call={"fn": self.observation_manager.compute, "kwargs": {"return_cloned_output": False}}, + timer=TIMER_ENABLED_STEP, + ) + self.recorder_manager.record_post_step() + + # -- reset envs that terminated/timed-out and log the episode information + # NOTE: Interim path (intentional). + # We still compact `reset_buf` into `env_ids` here because several reset-time managers/recorders + # are still `env_ids`-based. Do NOT remove/replace this until mask-based reset is end-to-end. + with Timer( + name="reset_selection", + msg="Reset selection took:", + enable=TIMER_ENABLED_STEP, + time_unit="us", + ): + # Keep the reset-mask handoff fully in Warp when experimental termination buffers exist. + # Stable termination manager path exposes torch-only dones/reset buffers. 
+ termination_manager_mode = self._manager_call_switch.get_mode_for_manager("TerminationManager") + if termination_manager_mode == ManagerCallMode.STABLE: + # copy still needed as mask will be used if manager is set to mode > 0 + wp.copy(self.reset_mask_wp, wp.from_torch(self.reset_buf, dtype=wp.bool)) + else: + wp.copy(self.reset_mask_wp, self.termination_manager.dones_wp) + reset_env_ids = self.reset_buf.nonzero(as_tuple=False).squeeze(-1) + if len(reset_env_ids) > 0: + # trigger recorder terms for pre-reset calls + self.recorder_manager.record_pre_reset(reset_env_ids) + + with Timer( + name="reset_idx", + msg="Reset idx took:", + enable=TIMER_ENABLED_STEP, + time_unit="us", + ): + self._reset_idx(env_ids=reset_env_ids, env_mask=self.reset_mask_wp) + + # if sensors are added to the scene, make sure we render to reflect changes in reset + if self.sim.settings.get("/isaaclab/render/rtx_sensors") and self.cfg.num_rerenders_on_reset > 0: + for _ in range(self.cfg.num_rerenders_on_reset): + self.sim.render() + + # trigger recorder terms for post-reset calls + self.recorder_manager.record_post_reset(reset_env_ids) + + # -- update command + self._manager_call_switch.call_stage( + stage="CommandManager_compute", + warp_call={"fn": self.command_manager.compute, "kwargs": {"dt": float(self.step_dt)}}, + timer=TIMER_ENABLED_STEP, + ) + + # -- step interval events + if "interval" in self.event_manager.available_modes: + self._manager_call_switch.call_stage( + stage="EventManager_apply_interval", + warp_call={"fn": self.event_manager.apply, "kwargs": {"mode": "interval", "dt": float(self.step_dt)}}, + timer=TIMER_ENABLED_STEP, + ) + + # -- compute observations + # note: done after reset to get the correct observations for reset envs + self.obs_buf = self._manager_call_switch.call_stage( + stage="ObservationManager_compute_update_history", + warp_call={ + "fn": self.observation_manager.compute, + "kwargs": {"update_history": True, "return_cloned_output": False}, + "output": 
lambda r: clone_obs_buffer(r), + }, + timer=TIMER_ENABLED_STEP, + ) + # return observations, rewards, resets and extras + return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras + + def render(self, recompute: bool = False) -> np.ndarray | None: + """Run rendering without stepping through the physics. + + By convention, if mode is: + + - **human**: Render to the current display and return nothing. Usually for human consumption. + - **rgb_array**: Return a numpy.ndarray with shape (x, y, 3), representing RGB values for an + x-by-y pixel image, suitable for turning into a video. + + Args: + recompute: Whether to force a render even if the simulator has already rendered the scene. + Defaults to False. + + Returns: + The rendered image as a numpy array if mode is "rgb_array". Otherwise, returns None. + + Raises: + RuntimeError: If mode is set to "rgb_array" and simulation render mode does not support it. + In this case, the simulation render mode must be set to ``RenderMode.PARTIAL_RENDERING`` + or ``RenderMode.FULL_RENDERING``. + NotImplementedError: If an unsupported rendering mode is specified. + """ + # run a rendering step of the simulator + # if we have rtx sensors, we do not need to render again (NOTE(review): also skipped when `recompute` is True -- confirm intent) + if not self.sim.settings.get("/isaaclab/render/rtx_sensors") and not recompute: + self.sim.render() + # decide the rendering mode + if self.render_mode == "human" or self.render_mode is None: + return None + elif self.render_mode == "rgb_array": + # check whether any render could have happened + has_gui = bool(self.sim.get_setting("/isaaclab/has_gui")) + offscreen_render = bool(self.sim.get_setting("/isaaclab/render/offscreen")) + if not (has_gui or offscreen_render): + raise RuntimeError( + f"Cannot render '{self.render_mode}' when the simulation render mode does not support" + " rendering. Please set the simulation render mode to 'PARTIAL_RENDERING' or" + " 'FULL_RENDERING'.
If running headless, make sure --enable_cameras is set." + ) + # create the annotator if it does not exist + if not hasattr(self, "_rgb_annotator"): + import omni.replicator.core as rep + + # create render product + self._render_product = rep.create.render_product( + self.cfg.viewer.cam_prim_path, self.cfg.viewer.resolution + ) + # create rgb annotator -- used to read data from the render product + self._rgb_annotator = rep.AnnotatorRegistry.get_annotator("rgb", device="cpu") + self._rgb_annotator.attach([self._render_product]) + # obtain the rgb data + rgb_data = self._rgb_annotator.get_data() + # convert to numpy array + rgb_data = np.frombuffer(rgb_data, dtype=np.uint8).reshape(*rgb_data.shape) + # return the rgb data + # note: initially the renderer is warming up and returns empty data + if rgb_data.size == 0: + return np.zeros((self.cfg.viewer.resolution[1], self.cfg.viewer.resolution[0], 3), dtype=np.uint8) + else: + return rgb_data[:, :, :3] + else: + raise NotImplementedError( + f"Render mode '{self.render_mode}' is not supported. Please use: {self.metadata['render_modes']}." + ) + + def close(self): + if not self._is_closed: + # destructor is order-sensitive + del self.command_manager + del self.reward_manager + del self.termination_manager + del self.curriculum_manager + # call the parent class to close the environment + super().close() + + """ + Helper functions.
+ """ + + def _configure_gym_env_spaces(self): + """Configure the action and observation spaces for the Gym environment.""" + # observation space (unbounded since we don't impose any limits) + self.single_observation_space = gym.spaces.Dict() + for group_name, group_term_names in self.observation_manager.active_terms.items(): + # extract quantities about the group + has_concatenated_obs = self.observation_manager.group_obs_concatenate[group_name] + group_dim = self.observation_manager.group_obs_dim[group_name] + # check if group is concatenated or not + # if not concatenated, then we need to add each term separately as a dictionary + if has_concatenated_obs: + self.single_observation_space[group_name] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=group_dim) + else: + group_term_cfgs = self.observation_manager._group_obs_term_cfgs[group_name] + term_dict = {} + for term_name, term_dim, term_cfg in zip(group_term_names, group_dim, group_term_cfgs): + low = -np.inf if term_cfg.clip is None else term_cfg.clip[0] + high = np.inf if term_cfg.clip is None else term_cfg.clip[1] + term_dict[term_name] = gym.spaces.Box(low=low, high=high, shape=term_dim) + self.single_observation_space[group_name] = gym.spaces.Dict(term_dict) + # action space (unbounded since we don't impose any limits) + action_dim = sum(self.action_manager.action_term_dim) + self.single_action_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(action_dim,)) + + # batch the spaces for vectorized environments + self.observation_space = gym.vector.utils.batch_space(self.single_observation_space, self.num_envs) + self.action_space = gym.vector.utils.batch_space(self.single_action_space, self.num_envs) + + def _reset_idx( + self, + env_ids: Sequence[int] | slice | torch.Tensor, + *, + env_mask: wp.array | None = None, + ): + """Reset environments based on specified indices. 
+ + IMPORTANT: + This function uses the **persistent Warp env mask** (`self.reset_mask_wp`) to select + which envs to reset. In `step()` that mask is filled from the termination results before calling this function. + + In other words: + - If `env_mask` is provided, it is expected to be `self.reset_mask_wp` (Warp bool mask) + - If `env_mask` is not provided, this function will populate `self.reset_mask_wp` from `env_ids` + - When `env_mask` is provided, `env_ids` **must** correspond to the same mask + + Args: + env_ids: Environment indices to reset. + env_mask: Warp boolean env mask selecting envs to reset. Expected to be `self.reset_mask_wp`. + If None, uses and populates `self.reset_mask_wp` from `env_ids`. + """ + if env_mask is None: + # Base `reset()` / `reset_to()` call-path provides only `env_ids`. + # Populate the persistent env-owned reset mask (`self.reset_mask_wp`) from ids. + env_mask = self.reset_mask_wp + # Use the centralized env-id/mask resolution from the base Warp env, then copy into the + # persistent reset mask (`self.reset_mask_wp`) used by captured graphs.
+ resolved_mask = self.resolve_env_mask(env_ids=env_ids) + wp.copy(env_mask, resolved_mask) + + if not isinstance(env_mask, wp.array): + raise TypeError(f"env_mask must be a wp.array (got {type(env_mask)}).") + + # update the curriculum for environments that need a reset + with Timer( + name="curriculum_manager.compute_reset", + msg="CurriculumManager.compute (reset) took:", + enable=TIMER_ENABLED_RESET_IDX, + time_unit="us", + ): + self.curriculum_manager.compute(env_ids=env_ids) + + # reset the internal buffers of the scene elements + self._manager_call_switch.call_stage( + stage="Scene_reset", + warp_call={"fn": self.scene.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + + if "reset" in self.event_manager.available_modes: + self._global_env_step_count_wp.fill_(self._sim_step_counter // self.cfg.decimation) + self._manager_call_switch.call_stage( + stage="EventManager_apply_reset", + warp_call={ + "fn": self.event_manager.apply, + "kwargs": { + "mode": "reset", + "env_mask_wp": env_mask, + "global_env_step_count": self._global_env_step_count_wp, + }, + }, + timer=TIMER_ENABLED_RESET_IDX, + ) + + # iterate over all managers and reset them + # this returns a dictionary of information which is stored in the extras + # note: This is order-sensitive! Certain things need be reset before others. 
+ # -- observation manager + action + reward managers + obs_info = self._manager_call_switch.call_stage( + stage="ObservationManager_reset", + warp_call={"fn": self.observation_manager.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + action_info = self._manager_call_switch.call_stage( + stage="ActionManager_reset", + warp_call={"fn": self.action_manager.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + reward_info = self._manager_call_switch.call_stage( + stage="RewardManager_reset", + warp_call={"fn": self.reward_manager.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + + # -- curriculum manager + with Timer( + name="curriculum_manager.reset", + msg="CurriculumManager.reset took:", + enable=TIMER_ENABLED_RESET_IDX, + time_unit="us", + ): + curriculum_info = self.curriculum_manager.reset(env_ids=env_ids) + + # -- command + event + termination managers + command_info = self._manager_call_switch.call_stage( + stage="CommandManager_reset", + warp_call={"fn": self.command_manager.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + event_info = self._manager_call_switch.call_stage( + stage="EventManager_reset", + warp_call={"fn": self.event_manager.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + termination_info = self._manager_call_switch.call_stage( + stage="TerminationManager_reset", + warp_call={"fn": self.termination_manager.reset, "kwargs": {"env_mask": env_mask}}, + timer=TIMER_ENABLED_RESET_IDX, + ) + + # -- recorder manager + recorder_info = self.recorder_manager.reset(env_ids=env_ids) + + # reset the episode length buffer + self.episode_length_buf[env_ids] = 0 + + # aggregate logging info + log: dict[str, Any] = {} + for info in ( + obs_info, + action_info, + reward_info, + curriculum_info, + command_info, + event_info, + termination_info, + recorder_info, + ): + log.update(info) + self.extras["log"] = log 
diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py new file mode 100644 index 00000000000..1476dc82879 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental MDP terms. + +This package forwards the stable MDP terms from :mod:`isaaclab.envs.mdp` (except rewards and actions), and overrides +actions, events, observations, rewards and terminations with the experimental implementations from this package. +""" + +# Forward stable MDP terms (commands/observations/terminations/etc.) but *exclude* rewards and actions. +# Rewards and actions are provided by this experimental package to keep Warp-first execution. +from isaaclab.envs.mdp.commands import * # noqa: F401, F403 +from isaaclab.envs.mdp.curriculums import * # noqa: F401, F403 +from isaaclab.envs.mdp.events import * # noqa: F401, F403 +from isaaclab.envs.mdp.observations import * # noqa: F401, F403 +from isaaclab.envs.mdp.recorders import * # noqa: F401, F403 +from isaaclab.envs.mdp.terminations import * # noqa: F401, F403 + +# Override terms with experimental implementations (later star-imports shadow same-named stable terms).
+from .actions import * # noqa: F401, F403 +from .events import * # noqa: F401, F403 +from .observations import * # noqa: F401, F403 +from .rewards import * # noqa: F401, F403 +from .terminations import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py new file mode 100644 index 00000000000..283805a279f --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental action terms (minimal). + +Only the action configs/terms currently required by the experimental manager-based Cartpole task +are provided here. +""" + +from .actions_cfg import * # noqa: F401, F403 +from .joint_actions import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py new file mode 100644 index 00000000000..d8826602dbe --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Action term configuration (experimental, minimal). + +This module mirrors the stable :mod:`isaaclab.envs.mdp.actions.actions_cfg` but only keeps what +the experimental Cartpole task needs. +""" + +from dataclasses import MISSING + +from isaaclab.utils import configclass + +from isaaclab_experimental.managers.action_manager import ActionTerm, ActionTermCfg + +from . 
import joint_actions + + +@configclass +class JointActionCfg(ActionTermCfg): + """Configuration for the base joint action term.""" + + joint_names: list[str] = MISSING + """List of joint names or regex expressions that the action will be mapped to.""" + + scale: float | dict[str, float] = 1.0 + """Scale factor for the action (float or dict of regex expressions). Defaults to 1.0.""" + + offset: float | dict[str, float] = 0.0 + """Offset factor for the action (float or dict of regex expressions). Defaults to 0.0.""" + + preserve_order: bool = False + """Whether to preserve the order of the joint names in the action output. Defaults to False.""" + + +@configclass +class JointEffortActionCfg(JointActionCfg): + """Configuration for the joint effort action term.""" + + class_type: type[ActionTerm] = joint_actions.JointEffortAction diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py new file mode 100644 index 00000000000..183cb2cfe49 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py @@ -0,0 +1,282 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import numpy as np +import warp as wp + +import isaaclab.utils.string as string_utils +from isaaclab.assets.articulation import Articulation + +from isaaclab_experimental.managers.action_manager import ActionTerm +from isaaclab_experimental.utils.warp import resolve_1d_mask + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + from isaaclab.envs.utils.io_descriptors import GenericActionIODescriptor + + from . 
# import logger
logger = logging.getLogger(__name__)


# One thread per (env, joint): copies the incoming action into the raw buffer and
# writes the scaled, offset and clipped value into the processed buffer.
@wp.kernel
def _process_joint_actions_kernel(
    # input
    actions: wp.array(dtype=wp.float32, ndim=2),
    action_offset: int,
    # params
    scale: wp.array(dtype=wp.float32),
    offset: wp.array(dtype=wp.float32),
    clip: wp.array(dtype=wp.float32, ndim=2),
    # output
    raw_out: wp.array(dtype=wp.float32, ndim=2),
    processed_out: wp.array(dtype=wp.float32, ndim=2),
):
    env_id, j = wp.tid()
    # column of this term's j-th action inside the full action buffer
    col = action_offset + j

    a = actions[env_id, col]
    raw_out[env_id, j] = a

    # affine transform, then clamp to [clip[j, 0], clip[j, 1]]
    x = a * scale[j] + offset[j]
    low = clip[j, 0]
    high = clip[j, 1]
    if x < low:
        x = low
    if x > high:
        x = high
    processed_out[env_id, j] = x


# One thread per (env, column): zeroes every column of the rows selected by ``mask``.
@wp.kernel
def _zero_masked_2d(mask: wp.array(dtype=wp.bool), values: wp.array(dtype=wp.float32, ndim=2)):
    env_id, j = wp.tid()
    if mask[env_id]:
        values[env_id, j] = 0.0


class JointAction(ActionTerm):
    r"""Base class for joint actions.

    This action term performs pre-processing of the raw actions using affine transformations (scale and offset).
    These transformations can be configured to be applied to a subset of the articulation's joints.

    Mathematically, the action term is defined as:

    .. math::

       \text{action} = \text{offset} + \text{scaling} \times \text{input action}

    where :math:`\text{action}` is the action that is sent to the articulation's actuated joints, :math:`\text{offset}`
    is the offset applied to the input action, :math:`\text{scaling}` is the scaling applied to the input
    action, and :math:`\text{input action}` is the input action from the user.

    Based on above, this kind of action transformation ensures that the input and output actions are in the same
    units and dimensions. The child classes of this action term can then map the output action to a specific
    desired command of the articulation's joints (e.g. position, velocity, etc.).
    """

    cfg: actions_cfg.JointActionCfg
    """The configuration of the action term."""
    _asset: Articulation
    """The articulation asset on which the action term is applied."""
    _scale: wp.array
    """The scaling factor applied to the input action."""
    _offset: wp.array
    """The offset applied to the input action."""
    _clip: wp.array
    """The clip applied to the input action."""
    _joint_mask: wp.array
    """A persistent joint mask for capturable action application."""

    def __init__(self, cfg: actions_cfg.JointActionCfg, env: ManagerBasedEnv) -> None:
        """Resolve the target joints and pre-compute the per-joint Warp parameter arrays.

        Args:
            cfg: The configuration of the action term.
            env: The environment the term belongs to.

        Raises:
            ValueError: If ``cfg.scale`` / ``cfg.offset`` is neither a number nor a dict,
                or if ``cfg.clip`` is set but is not a dict.
        """
        # initialize the action term
        super().__init__(cfg, env)

        # resolve the joints over which the action term is applied
        self._joint_ids, self._joint_names = self._asset.find_joints(
            self.cfg.joint_names, preserve_order=self.cfg.preserve_order
        )
        self._num_joints = len(self._joint_ids)
        # log the resolved joint names for debugging
        logger.info(
            f"Resolved joint names for the action term {self.__class__.__name__}:"
            f" {self._joint_names} [{self._joint_ids}]"
        )

        # Avoid indexing across all joints for efficiency
        if self._num_joints == self._asset.num_joints and not self.cfg.preserve_order:
            self._joint_ids = slice(None)

        # FIXME: ArticulationData.resolve_joint_mask is not available on this branch.
        # Port resolve_*_mask methods from dev/newton when articulation_data is aligned.
        _all_joint_mask = wp.ones((self._asset.num_joints,), dtype=wp.bool, device=self.device)
        _scratch_joint_mask = wp.zeros((self._asset.num_joints,), dtype=wp.bool, device=self.device)
        # clone so the persistent mask does not alias the temporary all/scratch buffers
        self._joint_mask = wp.clone(
            resolve_1d_mask(
                ids=self._joint_ids,
                mask=None,
                all_mask=_all_joint_mask,
                scratch_mask=_scratch_joint_mask,
                device=self.device,
            )
        )

        # create tensors for raw and processed actions (Warp)
        self._raw_actions = wp.zeros((self.num_envs, self.action_dim), dtype=wp.float32, device=self.device)
        self._processed_actions = wp.zeros_like(self.raw_actions)
        # FIXME: dev/newton set_joint_effort_target accepts partial data + joint_mask. Our branch
        # has separate _index (partial data) and _mask (full data) variants. Pre-compute joint_ids
        # as warp array for the _index variant.
        if self._joint_ids == slice(None):
            self._joint_ids_wp = None  # None means all joints
        else:
            self._joint_ids_wp = wp.array(list(self._joint_ids), dtype=wp.int32, device=self.device)

        # parse scale and offset (number -> broadcast, dict -> per-joint regex resolution)
        self._scale = self._parse_affine_param(cfg.scale, default=1.0, label="scale")
        self._offset = self._parse_affine_param(cfg.offset, default=0.0, label="offset")

        # parse clip: joints without an entry stay unbounded (-inf, +inf)
        clip_low = [-float("inf")] * self.action_dim
        clip_high = [float("inf")] * self.action_dim
        if self.cfg.clip is not None:
            if isinstance(cfg.clip, dict):
                index_list, _, value_list = string_utils.resolve_matching_names_values(self.cfg.clip, self._joint_names)
                for idx, value in zip(index_list, value_list):
                    clip_low[idx] = float(value[0])
                    clip_high[idx] = float(value[1])
            else:
                raise ValueError(f"Unsupported clip type: {type(cfg.clip)}. Supported types are dict.")

        # layout expected by the kernel: clip[j, 0] = low, clip[j, 1] = high
        clip_np = np.column_stack([clip_low, clip_high]).astype(np.float32)
        self._clip = wp.array(clip_np, dtype=wp.float32, device=self.device)

    def _parse_affine_param(self, value, default: float, label: str) -> wp.array:
        """Expand a scalar or regex-dict config value into a per-joint float32 Warp array.

        Args:
            value: The configured value (number, or dict resolved against the joint names).
            default: Value used for joints that a dict config does not match.
            label: Name of the parameter, used in the error message.

        Returns:
            A Warp array with ``action_dim`` entries on the term's device.

        Raises:
            ValueError: If ``value`` is neither a number nor a dict.
        """
        if isinstance(value, (float, int)):
            per_joint = [float(value)] * self.action_dim
        elif isinstance(value, dict):
            per_joint = [default] * self.action_dim
            # resolve the dictionary config
            index_list, _, value_list = string_utils.resolve_matching_names_values(value, self._joint_names)
            for idx, entry in zip(index_list, value_list):
                per_joint[idx] = float(entry)
        else:
            raise ValueError(f"Unsupported {label} type: {type(value)}. Supported types are float and dict.")
        return wp.array(per_joint, dtype=wp.float32, device=self.device)

    """
    Properties.
    """

    @property
    def action_dim(self) -> int:
        """Number of joints (and therefore action columns) handled by this term."""
        return self._num_joints

    @property
    def raw_actions(self) -> wp.array:
        """Buffer holding the last actions received, before scale/offset/clip."""
        return self._raw_actions

    @property
    def processed_actions(self) -> wp.array:
        """Buffer holding the actions after scale, offset and clip."""
        return self._processed_actions

    @property
    def IO_descriptor(self) -> GenericActionIODescriptor:
        """The IO descriptor of the action term.

        This descriptor is used to describe the action term of the joint action.
        It adds the following information to the base descriptor:
        - joint_names: The names of the joints.
        - scale: The scale of the action term.
        - offset: The offset of the action term.
        - clip: The clip of the action term.

        Returns:
            The IO descriptor of the action term.
        """
        # evaluated for its side effect on ``self._IO_descriptor``
        super().IO_descriptor
        self._IO_descriptor.shape = (self.action_dim,)
        self._IO_descriptor.dtype = str(self.raw_actions.dtype)
        self._IO_descriptor.action_type = "JointAction"
        self._IO_descriptor.joint_names = self._joint_names
        # Export scale as plain Python data, like offset and clip below, instead of
        # leaking the device array into the descriptor.
        if isinstance(self._scale, wp.array):
            self._IO_descriptor.scale = self._scale.numpy().tolist()
        else:
            self._IO_descriptor.scale = self._scale
        # This seems to be always [4xNum_joints] IDK why. Need to check.
        if isinstance(self._offset, wp.array):
            self._IO_descriptor.offset = self._offset.numpy().tolist()
        else:
            self._IO_descriptor.offset = None
        # FIXME: This is not correct. Add list support.
        if self.cfg.clip is not None and isinstance(self._clip, wp.array):
            self._IO_descriptor.clip = self._clip.numpy().tolist()
        else:
            self._IO_descriptor.clip = None
        return self._IO_descriptor

    """
    Operations.
    """

    def process_actions(self, actions: wp.array, action_offset: int = 0):
        """Fill the raw and processed action buffers from ``actions``.

        Args:
            actions: 2D float32 action buffer; this term reads ``action_dim`` columns
                starting at ``action_offset``.
            action_offset: First column of ``actions`` that belongs to this term.
        """
        wp.launch(
            kernel=_process_joint_actions_kernel,
            dim=(self.num_envs, self.action_dim),
            inputs=[
                actions,
                int(action_offset),
                self._scale,
                self._offset,
                self._clip,
                self._raw_actions,
                self._processed_actions,
            ],
            device=self.device,
        )

    def reset(self, env_mask: wp.array | None = None) -> None:
        """Resets the action term (mask-based).

        Args:
            env_mask: Boolean per-environment mask; raw actions of selected environments
                are zeroed. ``None`` zeroes the raw actions of every environment.
        """
        if env_mask is None:
            self._raw_actions.fill_(0.0)
            return
        wp.launch(
            kernel=_zero_masked_2d,
            dim=(self.num_envs, self.action_dim),
            inputs=[env_mask, self._raw_actions],
            device=self.device,
        )


class JointEffortAction(JointAction):
    """Joint action term that applies the processed actions to the articulation's joints as effort commands."""

    cfg: actions_cfg.JointEffortActionCfg
    """The configuration of the action term."""

    def __init__(self, cfg: actions_cfg.JointEffortActionCfg, env: ManagerBasedEnv):
        super().__init__(cfg, env)

    def apply_actions(self):
        """Send the processed actions to the asset as joint effort targets."""
        # set joint effort targets
        # FIXME: dev/newton uses set_joint_effort_target(data, joint_mask=) which accepts
        # partial data. Our branch uses the separate _index variant for partial data.
        self._asset.set_joint_effort_target_index(target=self.processed_actions, joint_ids=self._joint_ids_wp)
# One thread per environment: for every selected joint, sample a position and a velocity
# offset, add them to the defaults, clamp to the soft limits and write the joint state.
@wp.kernel
def _reset_joints_by_offset_kernel(
    env_mask: wp.array(dtype=wp.bool),
    joint_ids: wp.array(dtype=wp.int32),
    rng_state: wp.array(dtype=wp.uint32),
    default_joint_pos: wp.array(dtype=wp.float32, ndim=2),
    default_joint_vel: wp.array(dtype=wp.float32, ndim=2),
    joint_pos: wp.array(dtype=wp.float32, ndim=2),
    joint_vel: wp.array(dtype=wp.float32, ndim=2),
    soft_joint_pos_limits: wp.array(dtype=wp.vec2f, ndim=2),
    soft_joint_vel_limits: wp.array(dtype=wp.float32, ndim=2),
    pos_lo: float,
    pos_hi: float,
    vel_lo: float,
    vel_hi: float,
):
    env_id = wp.tid()
    if not env_mask[env_id]:
        return

    # 1 thread per env so per-env RNG state updates are race-free.
    state = rng_state[env_id]
    for joint_i in range(joint_ids.shape[0]):
        joint_id = joint_ids[joint_i]

        # offset samples in the provided ranges (Warp RNG state pattern)
        pos_off = wp.randf(state, pos_lo, pos_hi)
        vel_off = wp.randf(state, vel_lo, vel_hi)

        pos = default_joint_pos[env_id, joint_id] + pos_off
        vel = default_joint_vel[env_id, joint_id] + vel_off

        # clamp to soft limits
        lim = soft_joint_pos_limits[env_id, joint_id]
        pos = wp.clamp(pos, lim.x, lim.y)
        vmax = soft_joint_vel_limits[env_id, joint_id]
        vel = wp.clamp(vel, -vmax, vmax)

        # write into sim-bound state buffers
        joint_pos[env_id, joint_id] = pos
        joint_vel[env_id, joint_id] = vel

    # persist the RNG state for this environment
    rng_state[env_id] = state


# Same structure as the offset kernel, but the samples multiply the defaults.
@wp.kernel
def _reset_joints_by_scale_kernel(
    env_mask: wp.array(dtype=wp.bool),
    joint_ids: wp.array(dtype=wp.int32),
    rng_state: wp.array(dtype=wp.uint32),
    default_joint_pos: wp.array(dtype=wp.float32, ndim=2),
    default_joint_vel: wp.array(dtype=wp.float32, ndim=2),
    joint_pos: wp.array(dtype=wp.float32, ndim=2),
    joint_vel: wp.array(dtype=wp.float32, ndim=2),
    soft_joint_pos_limits: wp.array(dtype=wp.vec2f, ndim=2),
    soft_joint_vel_limits: wp.array(dtype=wp.float32, ndim=2),
    pos_lo: float,
    pos_hi: float,
    vel_lo: float,
    vel_hi: float,
):
    env_id = wp.tid()
    if not env_mask[env_id]:
        return

    state = rng_state[env_id]
    for joint_i in range(joint_ids.shape[0]):
        joint_id = joint_ids[joint_i]

        # scale samples in the provided ranges
        pos_scale = wp.randf(state, pos_lo, pos_hi)
        vel_scale = wp.randf(state, vel_lo, vel_hi)

        pos = default_joint_pos[env_id, joint_id] * pos_scale
        vel = default_joint_vel[env_id, joint_id] * vel_scale

        lim = soft_joint_pos_limits[env_id, joint_id]
        pos = wp.clamp(pos, lim.x, lim.y)
        vmax = soft_joint_vel_limits[env_id, joint_id]
        vel = wp.clamp(vel, -vmax, vmax)

        # write into sim
        joint_pos[env_id, joint_id] = pos
        joint_vel[env_id, joint_id] = vel

    rng_state[env_id] = state


def _launch_joint_reset(kernel, term_name, env, env_mask, position_range, velocity_range, asset_cfg) -> None:
    """Validate the Warp-first preconditions shared by the joint reset terms and launch ``kernel``.

    Args:
        kernel: The reset kernel to launch (offset or scale variant).
        term_name: Public name of the calling term, used in error messages.
        env: The environment; must expose ``scene``, ``num_envs``, ``device`` and ``rng_state_wp``.
        env_mask: Boolean per-environment mask selecting the environments to reset.
        position_range: ``(low, high)`` sampling range for the position sample.
        velocity_range: ``(low, high)`` sampling range for the velocity sample.
        asset_cfg: Scene entity configuration with resolved ``joint_ids_wp``.

    Raises:
        ValueError: If ``asset_cfg.joint_ids_wp`` is ``None``.
        AttributeError: If ``env.rng_state_wp`` is missing or ``None``.
    """
    asset: Articulation = env.scene[asset_cfg.name]

    # Assume cfg params are already resolved by the manager stack (Warp-first workflow).
    if asset_cfg.joint_ids_wp is None:
        raise ValueError(
            f"{term_name} requires an experimental SceneEntityCfg with resolved joint_ids_wp, "
            f"but got None for asset '{asset_cfg.name}'. "
            "Use isaaclab_experimental.managers.SceneEntityCfg and ensure joint_names are set."
        )
    if getattr(env, "rng_state_wp", None) is None:
        raise AttributeError(
            f"{term_name} requires env.rng_state_wp to be initialized. "
            "Use ManagerBasedEnvWarp or ManagerBasedRLEnvWarp as the base environment."
        )

    wp.launch(
        kernel=kernel,
        dim=env.num_envs,
        inputs=[
            env_mask,
            asset_cfg.joint_ids_wp,
            env.rng_state_wp,
            asset.data.default_joint_pos,
            asset.data.default_joint_vel,
            asset.data.joint_pos,
            asset.data.joint_vel,
            asset.data.soft_joint_pos_limits,
            asset.data.soft_joint_vel_limits,
            float(position_range[0]),
            float(position_range[1]),
            float(velocity_range[0]),
            float(velocity_range[1]),
        ],
        device=env.device,
    )


def reset_joints_by_offset(
    env,
    env_mask: wp.array,
    position_range: tuple[float, float],
    velocity_range: tuple[float, float],
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
    """Warp-first reset of joint state by random offsets around defaults.

    This overrides the stable `isaaclab.envs.mdp.events.reset_joints_by_offset` when importing
    via `isaaclab_experimental.envs.mdp`.

    Args:
        env: The environment (must provide ``rng_state_wp``).
        env_mask: Boolean per-environment mask selecting the environments to reset.
        position_range: ``(low, high)`` range of the offset added to the default joint positions.
        velocity_range: ``(low, high)`` range of the offset added to the default joint velocities.
        asset_cfg: Scene entity configuration with resolved ``joint_ids_wp``.

    Raises:
        ValueError: If ``asset_cfg.joint_ids_wp`` is ``None``.
        AttributeError: If ``env.rng_state_wp`` is missing or ``None``.
    """
    _launch_joint_reset(
        _reset_joints_by_offset_kernel,
        "reset_joints_by_offset",
        env,
        env_mask,
        position_range,
        velocity_range,
        asset_cfg,
    )


def reset_joints_by_scale(
    env,
    env_mask: wp.array,
    position_range: tuple[float, float],
    velocity_range: tuple[float, float],
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
    """Warp-first reset of joint state by scaling defaults with random factors.

    Args:
        env: The environment (must provide ``rng_state_wp``).
        env_mask: Boolean per-environment mask selecting the environments to reset.
        position_range: ``(low, high)`` range of the factor applied to the default joint positions.
        velocity_range: ``(low, high)`` range of the factor applied to the default joint velocities.
        asset_cfg: Scene entity configuration with resolved ``joint_ids_wp``.

    Raises:
        ValueError: If ``asset_cfg.joint_ids_wp`` is ``None``.
        AttributeError: If ``env.rng_state_wp`` is missing or ``None``.
    """
    _launch_joint_reset(
        _reset_joints_by_scale_kernel,
        "reset_joints_by_scale",
        env,
        env_mask,
        position_range,
        velocity_range,
        asset_cfg,
    )
# Gather kernel: one thread per (env, selected joint) writes position minus default.
@wp.kernel
def _joint_pos_rel_gather_kernel(
    joint_pos: wp.array(dtype=wp.float32, ndim=2),
    default_joint_pos: wp.array(dtype=wp.float32, ndim=2),
    joint_ids: wp.array(dtype=wp.int32),
    out: wp.array(dtype=wp.float32, ndim=2),
):
    row, col = wp.tid()
    # map the output column to the joint index inside the asset buffers
    src = joint_ids[col]
    current = joint_pos[row, src]
    reference = default_joint_pos[row, src]
    out[row, col] = current - reference


@generic_io_descriptor_warp(
    observation_type="JointState",
    on_inspect=[record_joint_names, record_joint_shape, record_joint_pos_offsets],
    units="rad",
)
def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None:
    """Write the selected joint positions, relative to their defaults, into ``out``.

    Args:
        env: The environment providing the scene.
        out: Pre-allocated output buffer; its second dimension sets the number of columns written.
        asset_cfg: Scene entity configuration with a resolved ``joint_ids_wp``.

    Raises:
        RuntimeError: If ``asset_cfg`` carries no resolved ``joint_ids_wp``.
    """
    robot: Articulation = env.scene[asset_cfg.name]

    # The gather kernel needs the joint subset as a Warp index array.
    selected_joints = getattr(asset_cfg, "joint_ids_wp", None)
    if selected_joints is None:
        raise RuntimeError(
            "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. "
            "Pass `asset_cfg` via term cfg params so it is resolved at manager init."
        )

    launch_dim = (env.num_envs, out.shape[1])
    kernel_inputs = [robot.data.joint_pos, robot.data.default_joint_pos, selected_joints, out]
    wp.launch(kernel=_joint_pos_rel_gather_kernel, dim=launch_dim, inputs=kernel_inputs, device=env.device)


# Gather kernel: one thread per (env, selected joint) writes velocity minus default.
@wp.kernel
def _joint_vel_rel_gather_kernel(
    joint_vel: wp.array(dtype=wp.float32, ndim=2),
    default_joint_vel: wp.array(dtype=wp.float32, ndim=2),
    joint_ids: wp.array(dtype=wp.int32),
    out: wp.array(dtype=wp.float32, ndim=2),
):
    row, col = wp.tid()
    # map the output column to the joint index inside the asset buffers
    src = joint_ids[col]
    current = joint_vel[row, src]
    reference = default_joint_vel[row, src]
    out[row, col] = current - reference


@generic_io_descriptor_warp(
    observation_type="JointState",
    on_inspect=[record_joint_names, record_joint_shape, record_joint_vel_offsets],
    units="rad/s",
)
def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None:
    """Write the selected joint velocities, relative to their defaults, into ``out``.

    Args:
        env: The environment providing the scene.
        out: Pre-allocated output buffer; its second dimension sets the number of columns written.
        asset_cfg: Scene entity configuration with a resolved ``joint_ids_wp``.

    Raises:
        RuntimeError: If ``asset_cfg`` carries no resolved ``joint_ids_wp``.
    """
    robot: Articulation = env.scene[asset_cfg.name]

    # The gather kernel needs the joint subset as a Warp index array.
    selected_joints = getattr(asset_cfg, "joint_ids_wp", None)
    if selected_joints is None:
        raise RuntimeError(
            "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. "
            "Pass `asset_cfg` via term cfg params so it is resolved at manager init."
        )

    launch_dim = (env.num_envs, out.shape[1])
    kernel_inputs = [robot.data.joint_vel, robot.data.default_joint_vel, selected_joints, out]
    wp.launch(kernel=_joint_vel_rel_gather_kernel, dim=launch_dim, inputs=kernel_inputs, device=env.device)
# General.


# One thread per environment: 1.0 while the episode has not terminated, else 0.0.
@wp.kernel
def _is_alive_kernel(terminated: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32)):
    env_id = wp.tid()
    if terminated[env_id]:
        out[env_id] = 0.0
    else:
        out[env_id] = 1.0


def is_alive(env: ManagerBasedRLEnv, out: wp.array(dtype=wp.float32)) -> None:
    """Reward every environment that has not terminated.

    Reads the termination manager's ``terminated_wp`` flags and writes ``1.0`` / ``0.0``
    into ``out`` (shape: (num_envs,)).
    """
    terminated_flags = env.termination_manager.terminated_wp
    wp.launch(kernel=_is_alive_kernel, dim=env.num_envs, inputs=[terminated_flags, out], device=env.device)


# One thread per environment: 1.0 when the episode terminated, else 0.0.
@wp.kernel
def _is_terminated_kernel(terminated: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32)):
    env_id = wp.tid()
    if terminated[env_id]:
        out[env_id] = 1.0
    else:
        out[env_id] = 0.0


def is_terminated(env: ManagerBasedRLEnv, out) -> None:
    """Flag terminated episodes so they can be penalized.

    Reads the termination manager's ``terminated_wp`` flags and writes ``1.0`` / ``0.0``
    into ``out``.
    """
    terminated_flags = env.termination_manager.terminated_wp
    wp.launch(kernel=_is_terminated_kernel, dim=env.num_envs, inputs=[terminated_flags, out], device=env.device)
# One thread per environment: sums |x[i, j]| over the joints enabled in ``joint_mask``.
@wp.kernel
def _sum_abs_masked_kernel(
    x: wp.array(dtype=wp.float32, ndim=2), joint_mask: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32)
):
    i = wp.tid()
    s = float(0.0)
    for j in range(x.shape[1]):
        if joint_mask[j]:
            s += wp.abs(x[i, j])
    out[i] = s


def joint_vel_l1(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg) -> None:
    """Penalize joint velocities on the articulation using an L1-kernel. Writes into ``out``.

    Args:
        env: The environment providing the scene.
        out: Pre-allocated per-environment output buffer.
        asset_cfg: Scene entity configuration with a resolved ``joint_mask``.

    Raises:
        ValueError: If ``asset_cfg.joint_mask`` is ``None`` or its length differs from the
            number of joint columns of the asset's joint velocity buffer.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    joint_vel = asset.data.joint_vel
    joint_mask = asset_cfg.joint_mask

    # The kernel indexes the mask with every joint column, so it must exist and match in length.
    if joint_mask is None:
        raise ValueError(
            f"joint_vel_l1 requires a resolved joint_mask, but got None for asset '{asset_cfg.name}'."
        )
    if joint_mask.shape[0] != joint_vel.shape[1]:
        raise ValueError(
            f"joint_vel_l1 expected a joint_mask of length {joint_vel.shape[1]} for asset"
            f" '{asset_cfg.name}', but got length {joint_mask.shape[0]}."
        )

    wp.launch(
        kernel=_sum_abs_masked_kernel,
        dim=env.num_envs,
        inputs=[joint_vel, joint_mask, out],
        device=env.device,
    )
# One thread per environment: flags environments whose step counter reached the maximum.
@wp.kernel
def _time_out_kernel(
    episode_length: wp.array(dtype=wp.int64), max_episode_length: wp.int64, out: wp.array(dtype=wp.bool)
):
    i = wp.tid()
    out[i] = episode_length[i] >= max_episode_length


def time_out(env: ManagerBasedRLEnv, out) -> None:
    """Terminate the episode when its length reaches or exceeds the maximum episode length.

    Writes the per-environment flags into ``out``.
    """
    wp.launch(
        kernel=_time_out_kernel,
        dim=env.num_envs,
        inputs=[env._episode_length_buf_wp, env.max_episode_length, out],
        device=env.device,
    )


# One thread per environment: flags the environment as soon as one masked joint
# position lies outside [lower, upper].
@wp.kernel
def _joint_pos_out_of_manual_limit_kernel(
    joint_pos: wp.array(dtype=wp.float32, ndim=2),
    joint_mask: wp.array(dtype=wp.bool),
    lower: float,
    upper: float,
    out: wp.array(dtype=wp.bool),
):
    i = wp.tid()
    violated = bool(False)
    for j in range(joint_pos.shape[1]):
        if joint_mask[j]:
            v = joint_pos[i, j]
            if v < lower or v > upper:
                violated = True
                break
    out[i] = violated


def joint_pos_out_of_manual_limit(
    env: ManagerBasedRLEnv, out, bounds: tuple[float, float], asset_cfg: SceneEntityCfg
) -> None:
    """Terminate when joint positions are outside configured bounds. Writes into ``out``.

    Args:
        env: The environment providing the scene.
        out: Pre-allocated per-environment boolean output buffer.
        bounds: ``(lower, upper)`` limits applied to every masked joint.
        asset_cfg: Scene entity configuration with a resolved ``joint_mask``.

    Raises:
        ValueError: If ``asset_cfg.joint_mask`` is ``None`` or its length differs from the
            number of joint columns of the asset's joint position buffer.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    joint_pos = asset.data.joint_pos
    joint_mask = asset_cfg.joint_mask

    # Explicit checks instead of ``assert`` so they still run under ``python -O``.
    if joint_mask is None:
        raise ValueError(
            f"joint_pos_out_of_manual_limit requires a resolved joint_mask, but got None for asset"
            f" '{asset_cfg.name}'."
        )
    if joint_pos.shape[1] != joint_mask.shape[0]:
        raise ValueError(
            f"joint_pos_out_of_manual_limit expected a joint_mask of length {joint_pos.shape[1]} for asset"
            f" '{asset_cfg.name}', but got length {joint_mask.shape[0]}."
        )

    wp.launch(
        kernel=_joint_pos_out_of_manual_limit_kernel,
        dim=env.num_envs,
        # cast so integer bounds are passed to the kernel's float parameters as floats
        inputs=[joint_pos, joint_mask, float(bounds[0]), float(bounds[1]), out],
        device=env.device,
    )
+from isaaclab.envs.utils.io_descriptors import GenericObservationIODescriptor + +if TYPE_CHECKING: + from isaaclab.assets.articulation import Articulation + from isaaclab.envs import ManagerBasedEnv + +import dataclasses +import functools +import inspect + +# These are defined to help with type hinting +P = ParamSpec("P") +R = TypeVar("R") + + +# --------------------------------------------------------------------------- +# Decorator +# --------------------------------------------------------------------------- + + +# Automatically builds a descriptor from the kwargs +def _make_descriptor(**kwargs: Any) -> GenericObservationIODescriptor: + """Split *kwargs* into (known dataclass fields) and (extras).""" + field_names = {f.name for f in dataclasses.fields(GenericObservationIODescriptor)} + known = {k: v for k, v in kwargs.items() if k in field_names} + extras = {k: v for k, v in kwargs.items() if k not in field_names} + + desc = GenericObservationIODescriptor(**known) + # User defined extras are stored in the descriptor under the `extras` field + desc.extras = extras + return desc + + +# TODO(jichuanh): The exact usage is unclear and this need revisit +# Decorator factory for Warp-first IO descriptors. +def generic_io_descriptor_warp( + _func: Callable[Concatenate[ManagerBasedEnv, P], R] | None = None, + *, + on_inspect: Callable[..., Any] | list[Callable[..., Any]] | None = None, + **descriptor_kwargs: Any, +) -> Callable[[Callable[Concatenate[ManagerBasedEnv, P], R]], Callable[Concatenate[ManagerBasedEnv, P], R]]: + """IO descriptor decorator for Warp-first observation terms. + + Works like the stable :func:`generic_io_descriptor` but adapted to the + ``func(env, out, **params) -> None`` signature: + + * On **normal calls** the decorator passes through to the wrapped function. + * On **inspection** (``inspect=True`` keyword argument) the wrapped function + is *not* called. 
Instead, the registered hooks are invoked with the same + ``(output, descriptor, **kwargs)`` contract as the stable hooks, except + ``output`` is always ``None``. + + This decorator can be used in the same ways as the stable decorator: + + 1. With keyword arguments:: + + @generic_io_descriptor_warp(observation_type="JointState", units="rad") + def my_func(env, out, asset_cfg=SceneEntityCfg("robot")) -> None: ... + + 2. With a pre-built descriptor:: + + @generic_io_descriptor_warp(GenericObservationIODescriptor(description="..")) + def my_func(env, out, asset_cfg=SceneEntityCfg("robot")) -> None: ... + + 3. With inspection hooks:: + + @generic_io_descriptor_warp( + observation_type="JointState", + on_inspect=[record_joint_names, record_joint_shape, record_joint_pos_offsets], + units="rad", + ) + def joint_pos_rel(env, out, asset_cfg=SceneEntityCfg("robot")) -> None: ... + + Args: + _func: The function to decorate (or a pre-built descriptor). + on_inspect: Hook(s) called during inspection. + **descriptor_kwargs: Keyword arguments to pass to the descriptor. + + Returns: + A decorator that can be used to decorate a function. + """ + # If the decorator is used with a descriptor, use it as the descriptor. 
+ if _func is not None and isinstance(_func, GenericObservationIODescriptor): + descriptor = _func + _func = None + else: + descriptor = _make_descriptor(**descriptor_kwargs) + + # Ensures the hook is a list + if callable(on_inspect): + inspect_hooks: list[Callable[..., Any]] = [on_inspect] + else: + inspect_hooks: list[Callable[..., Any]] = list(on_inspect or []) # handles None + + def _apply(func: Callable[Concatenate[ManagerBasedEnv, P], R]) -> Callable[Concatenate[ManagerBasedEnv, P], R]: + # Capture the signature of the function + sig = inspect.signature(func) + + @functools.wraps(func) + def wrapper(env: ManagerBasedEnv, *args: P.args, **kwargs: P.kwargs) -> R: + inspect_flag: bool = kwargs.pop("inspect", False) + if inspect_flag: + # Warp-first: do NOT call the function (it requires a pre-allocated + # ``out`` buffer that does not exist at inspection time). + # Use bind_partial (tolerates missing ``out``) and apply_defaults so + # that hooks see resolved default values (e.g. ``asset_cfg``). + bound = sig.bind_partial(env, **kwargs) + bound.apply_defaults() + call_kwargs = { + "output": None, + "descriptor": descriptor, + **bound.arguments, + } + for hook in inspect_hooks: + hook(**call_kwargs) + return # noqa: R502 + return func(env, *args, **kwargs) + + # --- Descriptor bookkeeping --- + descriptor.name = func.__name__ + descriptor.full_path = f"{func.__module__}.{func.__name__}" + # Warp-first terms always operate in float32. + descriptor.dtype = str(descriptor.dtype) if descriptor.dtype is not None else "float32" + # Check if description is set in the descriptor + if descriptor.description is None and func.__doc__: + descriptor.description = " ".join(func.__doc__.split()) + + # Adds the descriptor to the wrapped function as an attribute + wrapper._descriptor = descriptor + wrapper._has_descriptor = True + # Alters the signature of the wrapped function to make it match the original function. 
def record_shape(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None:
    """Copy the trailing dimension of ``output`` into ``descriptor.shape``.

    Does nothing when ``output`` is ``None``, which is what the Warp-first
    decorator passes to its hooks during inspection. Use a type-specific hook
    such as :func:`record_joint_shape` when the shape has to be derived from
    the configuration instead.

    Args:
        output: The pre-allocated output buffer, or ``None`` during inspection.
        descriptor: The descriptor whose ``shape`` field is updated.
        **kwargs: Unused; accepted for hook-signature compatibility.
    """
    if output is not None:
        last_dim = output.shape[-1]
        descriptor.shape = (last_dim,)
def record_joint_shape(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None:
    """Derive the observation shape from the number of selected joints.

    This is the Warp-first alternative to :func:`record_shape` for joint-based
    observations. It ignores ``output`` and reads the shape from the asset
    configuration instead.

    Args:
        output: Ignored; kept for hook signature compatibility.
        descriptor: The descriptor whose ``shape`` field is updated.
        **kwargs: Must contain ``env`` and ``asset_cfg``. ``asset_cfg.joint_ids``
            may be a ``slice`` (full or partial) or a sized collection of indices.

    Raises:
        KeyError: If ``env`` or ``asset_cfg`` is missing from ``kwargs``.
    """
    asset: Articulation = kwargs["env"].scene[kwargs["asset_cfg"].name]
    joint_ids = kwargs["asset_cfg"].joint_ids
    if isinstance(joint_ids, slice):
        # Resolve the slice against the asset's joint count. A type check (rather
        # than ``== slice(None)``) also covers partial slices, which have no
        # ``len()``, and avoids element-wise ``==`` on array-like index containers.
        num_selected = len(range(*joint_ids.indices(len(asset.joint_names))))
    else:
        num_selected = len(joint_ids)
    descriptor.shape = (num_selected,)
def record_joint_pos_offsets(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None:
    """Record the default joint-position offsets (first env instance).

    Expects the ``env`` and ``asset_cfg`` keyword arguments to be set.

    Args:
        output: Ignored; kept for hook signature compatibility.
        descriptor: The descriptor to record the joint position offsets to.
        **kwargs: Must contain ``env`` and ``asset_cfg``.
    """
    asset: Articulation = kwargs["env"].scene[kwargs["asset_cfg"].name]
    ids = kwargs["asset_cfg"].joint_ids
    # Get the offsets of the joints for the first robot in the scene.
    # This assumes that all robots have the same joint offsets.
    # Select row 0 *before* copying so that only the selected joints of one
    # environment are duplicated instead of the whole per-environment buffer.
    descriptor.joint_pos_offsets = wp.to_torch(asset.data.default_joint_pos)[0, ids].clone()
+ descriptor.joint_vel_offsets = wp.to_torch(asset.data.default_joint_vel).clone()[:, ids][0] diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py new file mode 100644 index 00000000000..62d3171d32a --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental manager implementations. + +This package is intended for experimental forks of manager implementations while +keeping stable task configs and the stable `isaaclab.managers` package intact. +""" + +from isaaclab.managers import * # noqa: F401,F403 + +# Override the stable implementation with the experimental fork. +from .action_manager import ActionManager # noqa: F401 +from .command_manager import CommandManager # noqa: F401 +from .event_manager import EventManager # noqa: F401 +from .manager_term_cfg import ObservationTermCfg, RewardTermCfg, TerminationTermCfg # noqa: F401 +from .observation_manager import ObservationManager # noqa: F401 +from .reward_manager import RewardManager # noqa: F401 +from .scene_entity_cfg import SceneEntityCfg # noqa: F401 +from .termination_manager import TerminationManager # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py new file mode 100644 index 00000000000..78ef70a99b1 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py @@ -0,0 +1,506 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
@wp.kernel
def _zero_masked_2d(
    # input
    mask: wp.array(dtype=wp.bool),
    # input/output
    data: wp.array(dtype=wp.float32, ndim=2),
):
    """Zero rows of a 2D buffer where ``mask`` is True.

    Launched with dim = (num_envs, data.shape[1]), i.e. one thread per element
    of ``data``. Rows whose mask entry is False are left untouched.

    Args:
        mask: Per-row boolean selector, indexed by the first thread index.
        data: Buffer modified in place.
    """

    # first thread index selects the row (environment), second the column
    env_id, j = wp.tid()
    if mask[env_id]:
        data[env_id, j] = 0.0
def __del__(self):
    """Unsubscribe from the debug-visualization callback, if one is registered.

    The handle is looked up defensively because the finalizer also runs on
    instances whose ``__init__`` raised before ``_debug_vis_handle`` was
    assigned (the asset lookup in the constructor happens earlier). A plain
    attribute access would then raise ``AttributeError`` during finalization.
    """
    handle = getattr(self, "_debug_vis_handle", None)
    if handle:
        handle.unsubscribe()
        self._debug_vis_handle = None
def set_debug_vis(self, debug_vis: bool) -> bool:
    """Sets whether to visualize the action term data.

    Args:
        debug_vis: Whether to visualize the action term data.

    Returns:
        Whether the debug visualization was successfully set. False if the action term does
        not support debug visualization.
    """
    # check if debug visualization is supported
    if not self.has_debug_vis_implementation:
        return False

    # toggle debug visualization objects
    self._set_debug_vis_impl(debug_vis)
    # toggle debug visualization handles
    if debug_vis:
        # create a subscriber for the post update event if it doesn't exist
        if self._debug_vis_handle is None:
            # Imported lazily: the Kit application interface is only needed to
            # create the subscription, so disabling the visualization (or
            # re-enabling an existing one) must not require it to be importable.
            import omni.kit.app

            app_interface = omni.kit.app.get_app_interface()
            # a weak proxy keeps the subscription from extending the term's lifetime
            self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
                lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
            )
    elif self._debug_vis_handle is not None:
        # remove the subscriber if it exists
        self._debug_vis_handle.unsubscribe()
        self._debug_vis_handle = None
    # return success
    return True
def _debug_vis_callback(self, event):
    """Push the current data into the visualization objects.

    Invoked with the event delivered by the subscription that
    ``set_debug_vis`` creates. The base implementation only reports that the
    term does not provide a visualization.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    message = f"Debug visualization is not implemented for {self.__class__.__name__}."
    raise NotImplementedError(message)
@property
def total_action_dim(self) -> int:
    """Sum of the action dimensions of all terms (see ``action_term_dim``)."""
    per_term_dims = self.action_term_dim
    return sum(per_term_dims)
@property
def get_IO_descriptors(self) -> list[dict[str, Any]]:
    """Get the IO descriptors for the action manager.

    Returns:
        A list with one dictionary per exported action term. Each dictionary
        holds the term ``name``, an ``extras`` dictionary (the descriptor's own
        extras merged with its ``description`` and ``units`` fields, when
        present) and every remaining descriptor field. Tuples are converted to
        lists. Terms whose descriptor has ``export`` disabled are omitted.
    """
    formatted_data = []
    for term_name, term in self._terms.items():
        try:
            item = term.IO_descriptor.__dict__.copy()
        except Exception as e:
            # best-effort: one failing term must not prevent exporting the others
            print(f"Error getting IO descriptor for term '{term_name}': {e}")
            continue

        name = item.pop("name")
        # ``item`` is only a shallow copy, so ``extras`` would otherwise be the
        # descriptor's own dictionary and the merge below would leak into it on
        # every call. A missing (None) ``extras`` is treated as empty.
        extras = dict(item.pop("extras") or {})
        if not item.pop("export"):
            continue

        formatted_item = {"name": name, "extras": extras}
        for k, v in item.items():
            # Check if v is a tuple and convert to list
            if isinstance(v, tuple):
                v = list(v)
            if k in ("description", "units"):
                extras[k] = v
            else:
                formatted_item[k] = v
        formatted_data.append(formatted_item)

    return formatted_data
def reset(
    self,
    env_ids: Sequence[int] | torch.Tensor | None = None,
    *,
    env_mask: wp.array | None = None,
) -> dict[str, Any]:
    """Resets the action history.

    Args:
        env_ids: The specific environment indices to reset.
            If None, all environments are considered.
        env_mask: Boolean Warp mask of shape (num_envs,) indicating which envs to reset.
            If provided, takes precedence over ``env_ids``.

    Returns:
        An empty dictionary.

    Raises:
        RuntimeError: If no Warp ``env_mask`` is given while the manager's device is capturing.
    """
    # Mask-first path: captured callers must provide env_mask.
    if env_mask is None or not isinstance(env_mask, wp.array):
        # Query the device that owns the manager's buffers; the default Warp
        # device may be a different one and would report the wrong capture state.
        if wp.get_device(self.device).is_capturing:
            raise RuntimeError(
                "ActionManager.reset requires env_mask(wp.array[bool]) during capture. "
                "Do not pass env_ids on captured paths."
            )
        env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask)

    # reset the action history
    if env_mask is None:
        self._prev_action.fill_(0.0)
        self._action.fill_(0.0)
    else:
        # zero only the selected rows of both history buffers
        for buffer in (self._prev_action, self._action):
            wp.launch(
                kernel=_zero_masked_2d,
                dim=(self.num_envs, self.total_action_dim),
                inputs=[env_mask, buffer],
                device=self.device,
            )

    # reset all action terms
    for term in self._terms.values():
        term.reset(env_mask=env_mask)
    # nothing to log here
    return {}
def _prepare_terms(self):
    """Instantiate the action terms declared in the configuration.

    Fills ``self._term_names`` and ``self._terms`` in configuration order.
    Entries whose configuration is ``None`` are skipped.

    Raises:
        TypeError: If a term configuration is not an :class:`ActionTermCfg`, or if
            its ``class_type`` does not produce an :class:`ActionTerm`.
    """
    # create buffers to parse and store terms
    self._term_names: list[str] = []
    self._terms: dict[str, ActionTerm] = {}

    # the configuration may be a plain dictionary or an object with attributes
    cfg_items = self.cfg.items() if isinstance(self.cfg, dict) else self.cfg.__dict__.items()
    # parse action terms from the config
    for term_name, term_cfg in cfg_items:
        # check if term config is None
        if term_cfg is None:
            continue
        # check valid type
        if not isinstance(term_cfg, ActionTermCfg):
            raise TypeError(
                f"Configuration for the term '{term_name}' is not of type ActionTermCfg."
                f" Received: '{type(term_cfg)}'."
            )
        # create the action term
        term = term_cfg.class_type(term_cfg, self._env)
        # sanity check if term is valid type
        if not isinstance(term, ActionTerm):
            raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionTerm.")
        # add term name and parameters
        self._term_names.append(term_name)
        self._terms[term_name] = term
@wp.kernel
def _sum_and_zero_masked(
    mask: wp.array(dtype=wp.bool),
    scale: wp.array(dtype=wp.float32),
    metric: wp.array(dtype=wp.float32),
    out_mean: wp.array(dtype=wp.float32),
):
    """Accumulate the scaled metric of masked entries into ``out_mean[0]`` and clear them.

    One thread per entry of ``metric``. Entries whose mask is False are neither
    accumulated nor cleared. ``out_mean`` is only added to, so the caller has to
    zero it before the launch.

    Args:
        mask: Boolean selector, indexed like ``metric``.
        scale: Single-element array; ``scale[0]`` multiplies every accumulated value.
            NOTE(review): presumably the reciprocal of the selected count, which
            would make the result a mean; confirm against ``compute_reset_scale``.
        metric: Per-entry metric values; selected entries are reset to zero.
        out_mean: Single-element accumulator updated atomically.
    """
    env_id = wp.tid()
    if mask[env_id]:
        # atomic because every selected thread adds into the same element
        wp.atomic_add(out_mean, 0, metric[env_id] * scale[0])
        metric[env_id] = 0.0
def __init__(self, cfg: CommandTermCfg, env: ManagerBasedRLEnv):
    """Set up the per-environment buffers shared by all command terms.

    Args:
        cfg: The configuration parameters for the command generator.
        env: The environment object.
    """
    super().__init__(cfg, env)

    num_envs = self.num_envs
    device = self.device

    # metrics that can be used for logging (metric_name -> wp.array(num_envs,))
    self.metrics = {}
    # time left before resampling, one entry per environment
    self.time_left_wp = wp.zeros((num_envs,), dtype=wp.float32, device=device)
    # number of times the command has been resampled within the current episode
    self.command_counter_wp = wp.zeros((num_envs,), dtype=wp.int32, device=device)

    # scratch buffers used by reset() and compute()
    self._reset_count_wp = wp.zeros((1,), dtype=wp.int32, device=device)
    self._reset_scale_wp = wp.zeros((1,), dtype=wp.float32, device=device)
    self._resample_mask_wp = wp.zeros((num_envs,), dtype=wp.bool, device=device)

    # the handle becomes valid inside set_debug_vis once a subscription is created
    self._debug_vis_handle = None
    # apply the configured initial state of the debug visualization
    self.set_debug_vis(self.cfg.debug_vis)

    # reset logging extras; populated by _prepare_reset_extras
    self._reset_metric_mean_wp: dict[str, wp.array] = {}
    self._reset_extras: dict[str, torch.Tensor] = {}
@property
def has_debug_vis_implementation(self) -> bool:
    """Whether the command generator has a debug visualization implemented.

    Decided by inspecting the source text of the bound ``_set_debug_vis_impl``:
    an implementation counts as present when that text does not mention
    ``NotImplementedError``.
    """
    impl_source = inspect.getsource(self._set_debug_vis_impl)
    return impl_source.find("NotImplementedError") == -1
def reset(
    self,
    env_ids: Sequence[int] | torch.Tensor | None = None,
    *,
    env_mask: wp.array | None = None,
) -> dict[str, torch.Tensor]:
    """Reset the command generator and log metrics.

    This function resets the command counter and resamples the command. It should be called
    at the beginning of each episode.

    Args:
        env_ids: The specific environment indices to reset.
            If None, all environments are considered.
        env_mask: Boolean Warp mask of shape (num_envs,) selecting reset environments.
            If provided, takes precedence over ``env_ids``.

    Returns:
        A dictionary containing the information to log under the "{name}" key. This is
        the pre-allocated ``reset_extras`` dictionary, i.e. the same object on every call.

    Raises:
        RuntimeError: If no Warp ``env_mask`` is given while the term's device is capturing.
    """
    # Mask-first path: captured callers must provide env_mask.
    if env_mask is None or not isinstance(env_mask, wp.array):
        # Query the device that owns the term's buffers; the default Warp
        # device may be a different one and would report the wrong capture state.
        if wp.get_device(self.device).is_capturing:
            raise RuntimeError(
                "CommandTerm.reset requires env_mask(wp.array[bool]) during capture. "
                "Do not pass env_ids on captured paths."
            )
        env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask)

    # compute selected count and reset scale
    self._reset_count_wp.zero_()
    self._reset_scale_wp.zero_()
    wp.launch(kernel=count_masked, dim=self.num_envs, inputs=[env_mask, self._reset_count_wp], device=self.device)
    wp.launch(
        kernel=compute_reset_scale,
        dim=1,
        inputs=[self._reset_count_wp, 1.0, self._reset_scale_wp],
        device=self.device,
    )

    # update pre-allocated reset extras and clear selected metric rows
    for metric_name, metric_value_wp in self.metrics.items():
        out_mean_wp = self._reset_metric_mean_wp[metric_name]
        # the kernel only accumulates, so start from zero
        out_mean_wp.zero_()
        wp.launch(
            kernel=_sum_and_zero_masked,
            dim=self.num_envs,
            inputs=[env_mask, self._reset_scale_wp, metric_value_wp, out_mean_wp],
            device=self.device,
        )

    # set the command counter to zero
    wp.launch(
        kernel=_zero_counter_masked,
        dim=self.num_envs,
        inputs=[env_mask, self.command_counter_wp],
        device=self.device,
    )
    # resample the command
    self._resample(env_mask=env_mask)

    return self._reset_extras
def _resample(self, env_mask: wp.array):
    """Resample the command for the masked environments.

    Draws a new resampling time and increments the command counter for every selected
    environment, then delegates to :meth:`_resample_command` for the command values.

    Args:
        env_mask: The boolean environment mask to resample.

    Raises:
        TypeError: If ``env_mask`` is not a one-dimensional ``wp.bool`` Warp array.
        RuntimeError: If the environment's ``rng_state_wp`` is None.
    """
    # validate the mask before it reaches any kernel
    if not isinstance(env_mask, wp.array):
        raise TypeError(f"env_mask must be a wp.array (got {type(env_mask)}).")
    if env_mask.dtype != wp.bool or env_mask.ndim != 1:
        raise TypeError(f"env_mask must be wp.bool 1D (got dtype={env_mask.dtype}, ndim={env_mask.ndim}).")

    rng_state = self._env.rng_state_wp
    if rng_state is None:
        raise RuntimeError("Environment rng_state_wp is not initialized.")

    time_range = self.cfg.resampling_time_range
    kernel_inputs = [
        env_mask,
        self.time_left_wp,
        self.command_counter_wp,
        rng_state,
        float(time_range[0]),
        float(time_range[1]),
    ]
    # resample time-left and increment command-counter for masked envs
    wp.launch(
        kernel=_resample_time_left_and_increment_counter,
        dim=self.num_envs,
        inputs=kernel_inputs,
        device=self.device,
    )
    # resample command values for masked envs
    self._resample_command(env_mask)
+ """ + + @abstractmethod + def _update_metrics(self): + """Update the metrics based on the current state.""" + raise NotImplementedError + + @abstractmethod + def _resample_command(self, env_mask: wp.array): + """Resample the command for the specified masked environments.""" + raise NotImplementedError + + @abstractmethod + def _update_command(self): + """Update the command based on the current state.""" + raise NotImplementedError + + def _set_debug_vis_impl(self, debug_vis: bool): + """Set debug visualization into visualization objects. + + This function is responsible for creating the visualization objects if they don't exist + and input ``debug_vis`` is True. If the visualization objects exist, the function should + set their visibility into the stage. + """ + raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.") + + def _debug_vis_callback(self, event): + """Callback for debug visualization. + + This function calls the visualization objects and sets the data to visualize into them. + """ + raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.") + + +class CommandManager(ManagerBase): + """Manager for generating commands. + + The command manager is used to generate commands for an agent to execute. It makes it convenient to switch + between different command generation strategies within the same environment. For instance, in an environment + consisting of a quadrupedal robot, the command to it could be a velocity command or position command. + By keeping the command generation logic separate from the environment, it is easy to switch between different + command generation strategies. + + The command terms are implemented as classes that inherit from the :class:`CommandTerm` class. + Each command generator term should also have a corresponding configuration class that inherits from the + :class:`CommandTermCfg` class. 
+ """ + + _env: ManagerBasedRLEnv + """The environment instance.""" + + def __init__(self, cfg: object, env: ManagerBasedRLEnv): + """Initialize the command manager. + + Args: + cfg: The configuration object or dictionary (``dict[str, CommandTermCfg]``). + env: The environment instance. + """ + # create buffers to parse and store terms + self._terms: dict[str, CommandTerm] = dict() + + # call the base class constructor (this prepares the terms) + super().__init__(cfg, env) + # store the commands + self._commands = dict() + if self.cfg: + self.cfg.debug_vis = False + for term in self._terms.values(): + self.cfg.debug_vis |= term.cfg.debug_vis + + # reset logging extras (persistent holder for orchestrator aggregation) + self._reset_extras: dict[str, torch.Tensor] = {} + for term_name, term in self._terms.items(): + for metric_name, metric_value in term.reset_extras.items(): + self._reset_extras[f"Metrics/{term_name}/{metric_name}"] = metric_value + + def __str__(self) -> str: + """Returns: A string representation for the command manager.""" + msg = f" contains {len(self._terms.values())} active terms.\n" + + # create table for term information + table = PrettyTable() + table.title = "Active Command Terms" + table.field_names = ["Index", "Name", "Type"] + # set alignment of table columns + table.align["Name"] = "l" + # add info on each term + for index, (name, term) in enumerate(self._terms.items()): + table.add_row([index, name, term.__class__.__name__]) + # convert table to string + msg += table.get_string() + msg += "\n" + + return msg + + """ + Properties. 
+ """ + + @property + def active_terms(self) -> list[str]: + """Name of active command terms.""" + return list(self._terms.keys()) + + @property + def has_debug_vis_implementation(self) -> bool: + """Whether the command terms have debug visualization implemented.""" + # check if function raises NotImplementedError + has_debug_vis = False + for term in self._terms.values(): + has_debug_vis |= term.has_debug_vis_implementation + return has_debug_vis + + @property + def reset_extras(self) -> dict[str, torch.Tensor]: + """Persistent reset logging extras for command terms.""" + return self._reset_extras + + """ + Operations. + """ + + def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]: + """Returns the active terms as iterable sequence of tuples. + + The first element of the tuple is the name of the term and the second element is the raw value(s) of the term. + + Args: + env_idx: The specific environment to pull the active terms from. + + Returns: + The active terms. + """ + + terms = [] + for name, term in self._terms.items(): + command = term.command + if isinstance(command, wp.array): + command = wp.to_torch(command) + terms.append((name, command[env_idx].cpu().tolist())) + return terms + + def set_debug_vis(self, debug_vis: bool): + """Sets whether to visualize the command data. + + Args: + debug_vis: Whether to visualize the command data. + + Returns: + Whether the debug visualization was successfully set. False if the command + generator does not support debug visualization. + """ + for term in self._terms.values(): + term.set_debug_vis(debug_vis) + + def reset( + self, + env_ids: Sequence[int] | torch.Tensor | None = None, + *, + env_mask: wp.array | None = None, + ) -> dict[str, torch.Tensor]: + """Reset the command terms and log their metrics. + + This function resets the command counter and resamples the command for each term. It should be called + at the beginning of each episode. 
+ + Args: + env_ids: The specific environment indices to reset. + If None, all environments are considered. + env_mask: Boolean Warp mask of shape (num_envs,) selecting reset environments. + If provided, takes precedence over ``env_ids``. + + Returns: + A dictionary containing the information to log under the "Metrics/{term_name}/{metric_name}" key. + """ + # Mask-first path: captured callers must provide env_mask. + if env_mask is None or not isinstance(env_mask, wp.array): + if wp.get_device().is_capturing: + raise RuntimeError( + "CommandManager.reset requires env_mask(wp.array[bool]) during capture. " + "Do not pass env_ids on captured paths." + ) + env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + + for term in self._terms.values(): + # reset the command term + term.reset(env_mask=env_mask) + + return self._reset_extras + + def compute(self, dt: float): + """Updates the commands. + + This function calls each command term managed by the class. + + Args: + dt: The time-step interval of the environment. + + """ + # iterate over all the command terms + for term in self._terms.values(): + # compute term's value + term.compute(dt) + + def get_command(self, name: str) -> torch.Tensor: + """Returns the command for the specified command term. + + Args: + name: The name of the command term. + + Returns: + The command tensor of the specified command term. + """ + command = self._terms[name].command + if isinstance(command, wp.array): + return wp.to_torch(command) + return command + + def get_term(self, name: str) -> CommandTerm: + """Returns the command term with the specified name. + + Args: + name: The name of the command term. + + Returns: + The command term with the specified name. + """ + return self._terms[name] + + """ + Helper functions. 
+ """ + + def _prepare_terms(self): + # check if config is dict already + if isinstance(self.cfg, dict): + cfg_items = self.cfg.items() + else: + cfg_items = self.cfg.__dict__.items() + # iterate over all the terms + for term_name, term_cfg in cfg_items: + # check for non config + if term_cfg is None: + continue + # check for valid config type + if not isinstance(term_cfg, CommandTermCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type CommandTermCfg." + f" Received: '{type(term_cfg)}'." + ) + # create the action term + term = term_cfg.class_type(term_cfg, self._env) + # sanity check if term is valid type + if not isinstance(term, CommandTerm): + raise TypeError(f"Returned object for the term '{term_name}' is not of type CommandType.") + # pre-build reset extras once for capture-friendly reset logging + term._prepare_reset_extras() + # add class to dict + self._terms[term_name] = term diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py new file mode 100644 index 00000000000..6adc9f055d5 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py @@ -0,0 +1,499 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Event manager for orchestrating operations based on different simulation events (experimental, Warp-first). + +This module mirrors :mod:`isaaclab.managers.event_manager` but removes torch ops from hot paths to enable +CUDA-graph-friendly execution for modes that can be captured (notably ``interval``). + +Key differences from the stable manager: +- ``interval`` and ``reset`` modes are **mask-based** internally and implemented using Warp kernels. 
# Advance the per-environment interval timers by one step (one thread per array index).
#
# For each index the timer is decreased by ``dt``. When the result drops below 1e-6 the
# index is flagged in ``trigger_mask`` and its timer is re-drawn with ``wp.randf`` using
# ``lower`` and ``upper``; otherwise the flag is cleared and the decreased time is stored.
# Every entry of ``trigger_mask`` is therefore overwritten on each launch.
@wp.kernel
def _interval_step_per_env(
    time_left: wp.array(dtype=wp.float32),  # remaining time per index, updated in place
    rng_state: wp.array(dtype=wp.uint32),  # RNG state per index, read and written back on trigger
    trigger_mask: wp.array(dtype=wp.bool),  # output: True where the timer expired this step
    dt: wp.float32,  # time elapsed since the previous launch
    lower: wp.float32,  # lower bound passed to wp.randf
    upper: wp.float32,  # upper bound passed to wp.randf
):
    env_id = wp.tid()
    t = time_left[env_id] - dt
    # small positive threshold instead of zero (presumably to absorb float round-off; confirm)
    if t < wp.float32(1.0e-6):
        trigger_mask[env_id] = True
        # load the state into a local, sample, then store the local back
        s = rng_state[env_id]
        time_left[env_id] = wp.randf(s, lower, upper)
        rng_state[env_id] = s
    else:
        trigger_mask[env_id] = False
        time_left[env_id] = t
# Re-draw the interval timer of the selected indices (one thread per array index).
#
# Indices whose ``env_mask`` entry is False are left untouched: neither their timer nor
# their RNG state is written.
@wp.kernel
def _interval_reset_selected(
    env_mask: wp.array(dtype=wp.bool),  # True for the indices whose timer is re-drawn
    time_left: wp.array(dtype=wp.float32),  # remaining time per index, overwritten where selected
    rng_state: wp.array(dtype=wp.uint32),  # RNG state per index, read and written back where selected
    lower: wp.float32,  # lower bound passed to wp.randf
    upper: wp.float32,  # upper bound passed to wp.randf
):
    env_id = wp.tid()
    if env_mask[env_id]:
        # load the state into a local, sample, then store the local back
        s = rng_state[env_id]
        time_left[env_id] = wp.randf(s, lower, upper)
        rng_state[env_id] = s
class EventManager(ManagerBase):
    """Manager for orchestrating operations based on different simulation events (Warp-first for interval/reset).

    Terms are grouped by their ``mode``. Terms of the ``interval`` and ``reset`` modes are called as
    ``func(env, env_mask_wp, **params)`` with a boolean Warp mask; terms of every other mode are called as
    ``func(env, env_ids, **params)``.

    The per-term interval buffers (time left, sampling range, global flag) are parallel lists that follow the
    order of the ``interval`` terms; the per-term reset buffers follow the order of the ``reset`` terms.
    """

    def __init__(self, cfg: object, env):
        """Initialize the event manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, EventTermCfg]``).
            env: The environment instance.
        """
        # create buffers to parse and store terms
        self._mode_term_names: dict[str, list[str]] = {}
        self._mode_term_cfgs: dict[str, list[EventTermCfg]] = {}
        self._mode_class_term_cfgs: dict[str, list[EventTermCfg]] = {}

        # Warp buffers for interval/reset modes (populated in _prepare_terms)
        self._interval_term_time_left_wp: list[wp.array] = []
        self._interval_term_ranges: list[tuple[float, float]] = []
        self._interval_term_is_global: list[bool] = []
        # Scalar RNG state for global interval timers (allocated lazily if needed).
        self._interval_global_rng_state_wp: wp.array | None = None

        self._reset_term_last_triggered_step_wp: list[wp.array] = []
        self._reset_term_triggered_once_wp: list[wp.array] = []

        super().__init__(cfg, env)

        # persistent scratch mask for per-term interval/reset triggering (must be stable pointer for capture)
        self._scratch_term_mask_wp = wp.zeros((self.num_envs,), dtype=wp.bool, device=self.device)

        # scratch scalar flag & broadcast view for global interval triggering (no per-term masks)
        # the view has stride 0, so each of its num_envs entries aliases the single flag element
        self._scratch_interval_trigger_flag_wp = wp.zeros((1,), dtype=wp.bool, device=self.device)
        self._scratch_interval_trigger_mask_view_wp = wp.array(
            ptr=self._scratch_interval_trigger_flag_wp.ptr,
            dtype=wp.bool,
            shape=(self.num_envs,),
            strides=(0,),
            capacity=self._scratch_interval_trigger_flag_wp.capacity,
            device=self._scratch_interval_trigger_flag_wp.device,
            copy=False,
        )

    def __str__(self) -> str:
        """Returns: A string representation for the event manager (one table per mode)."""
        # count the terms themselves; the dictionary length would only give the number of modes
        num_terms = sum(len(names) for names in self._mode_term_names.values())
        msg = f"<EventManager> contains {num_terms} active terms.\n"
        for mode in self._mode_term_names:
            table = PrettyTable()
            table.title = f"Active Event Terms in Mode: '{mode}'"
            if mode == "interval":
                table.field_names = ["Index", "Name", "Interval time range (s)"]
                table.align["Name"] = "l"
                for index, (name, cfg) in enumerate(zip(self._mode_term_names[mode], self._mode_term_cfgs[mode])):
                    table.add_row([index, name, cfg.interval_range_s])
            else:
                table.field_names = ["Index", "Name"]
                table.align["Name"] = "l"
                for index, name in enumerate(self._mode_term_names[mode]):
                    table.add_row([index, name])
            msg += table.get_string()
            msg += "\n"
        return msg

    @property
    def active_terms(self) -> dict[str, list[str]]:
        """Name of active terms, keyed by event mode."""
        return self._mode_term_names

    @property
    def available_modes(self) -> list[str]:
        """Modes for which at least one term is registered."""
        return list(self._mode_term_names.keys())

    def set_term_cfg(self, term_name: str, cfg: EventTermCfg):
        """Replace the configuration of an existing term.

        For ``interval`` terms the cached sampling range is refreshed from the new configuration, so that
        subsequently drawn intervals use it. The timer buffer and the global/per-env layout chosen when the
        term was prepared are kept.

        Args:
            term_name: The name of the event term.
            cfg: The new configuration of the event term.

        Raises:
            ValueError: If the term name is not found.
        """
        for mode, terms in self._mode_term_names.items():
            if term_name not in terms:
                continue
            index = terms.index(term_name)
            self._mode_term_cfgs[mode][index] = cfg
            # keep the cached range in sync; the kernels sample from the cache, not from the cfg
            if mode == "interval" and cfg.interval_range_s is not None:
                lower, upper = cfg.interval_range_s
                self._interval_term_ranges[index] = (float(lower), float(upper))
            return
        raise ValueError(f"Event term '{term_name}' not found.")

    def get_term_cfg(self, term_name: str) -> EventTermCfg:
        """Return the configuration of an existing term.

        Args:
            term_name: The name of the event term.

        Raises:
            ValueError: If the term name is not found.
        """
        for mode, terms in self._mode_term_names.items():
            if term_name in terms:
                return self._mode_term_cfgs[mode][terms.index(term_name)]
        raise ValueError(f"Event term '{term_name}' not found.")

    def reset(
        self,
        env_ids: Sequence[int] | slice | torch.Tensor | wp.array | None = None,
        *,
        env_mask: wp.array | torch.Tensor | None = None,
    ) -> dict[str, float]:
        """Reset the class-based terms and re-draw the per-environment interval timers.

        Args:
            env_ids: The environment indices to reset. Only consulted when ``env_mask`` is not a Warp array.
            env_mask: Boolean Warp mask selecting the environments to reset.

        Returns:
            An empty dictionary (this manager logs nothing on reset).

        Raises:
            RuntimeError: If no Warp mask is given while the current Warp device is capturing.
        """
        # Mask-first path: captured callers must provide env_mask.
        if env_mask is None or not isinstance(env_mask, wp.array):
            # Keep all id->mask resolution strictly outside capture.
            if wp.get_device().is_capturing:
                raise RuntimeError(
                    "EventManager.reset requires env_mask(wp.array[bool]) during capture. "
                    "Do not pass env_ids on captured paths."
                )
            env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask)

        # reset class terms (mask-based)
        for mode_cfg in self._mode_class_term_cfgs.values():
            for term_cfg in mode_cfg:
                term_cfg.func.reset(env_mask=env_mask)

        # reset interval timers for non-global interval events
        # the cached flag decides, because it is the one the timer buffer was sized with:
        # a global timer holds a single element and must never see a num_envs-wide launch
        interval_buffers = zip(
            self._interval_term_is_global, self._interval_term_ranges, self._interval_term_time_left_wp
        )
        for is_global, (lower, upper), time_left in interval_buffers:
            if is_global:
                continue
            wp.launch(
                kernel=_interval_reset_selected,
                dim=self.num_envs,
                inputs=[
                    env_mask,
                    time_left,
                    self._env.rng_state_wp,
                    float(lower),
                    float(upper),
                ],
                device=self.device,
            )
        return {}

    def apply(
        self,
        mode: str,
        env_ids: Sequence[int] | slice | torch.Tensor | wp.array | None = None,
        dt: float | None = None,
        global_env_step_count: wp.array | None = None,
        *,
        env_mask_wp: wp.array | None = None,
    ):
        """Call every term registered for the given mode.

        An unknown mode only logs a warning.

        Args:
            mode: The event mode to apply.
            env_ids: The environment indices. Must be None for ``interval``; forwarded to the terms for
                modes other than ``interval`` and ``reset``.
            dt: The environment time-step. Required for ``interval``.
            global_env_step_count: Warp array whose first element is the environment step count.
                Required for ``reset``.
            env_mask_wp: Boolean Warp mask of the environments being reset. Required for ``reset``
                while capturing; otherwise resolved from ``env_ids`` when missing.

        Raises:
            RuntimeError: If the terms are unresolved while the current Warp device is capturing.
            ValueError: If an argument required by the mode is missing, or if ``env_ids`` is given
                for ``interval``.
        """
        if mode not in self._mode_term_names:
            logger.warning(f"Event mode '{mode}' is not defined. Skipping event.")
            return

        # SceneEntityCfg-dependent term params should be resolved before entering captured event paths.
        if not self._is_scene_entities_resolved:
            if wp.get_device().is_capturing:
                raise RuntimeError(
                    "EventManager terms are unresolved during CUDA graph capture. "
                    "Resolve terms before entering captured event paths."
                )
            if self._env.sim.is_playing():
                self._resolve_terms_callback(None)

        if mode == "interval":
            if dt is None:
                raise ValueError(f"Event mode '{mode}' requires the time-step of the environment.")
            if env_ids is not None:
                raise ValueError(
                    f"Event mode '{mode}' does not require environment indices. This is an undefined behavior."
                )
            self._apply_interval(float(dt))
            return

        if mode == "reset":
            if global_env_step_count is None:
                raise ValueError(f"Event mode '{mode}' requires the total number of environment steps to be provided.")
            if env_mask_wp is None:
                if wp.get_device().is_capturing:
                    raise ValueError(
                        f"Event mode '{mode}' requires the environment mask to be provided when capturing."
                    )
                env_mask_wp = self._env.resolve_env_mask(env_ids=env_ids)
            self._apply_reset(env_mask_wp, global_env_step_count)
            return

        # other modes keep the stable convention (env_ids forwarded)
        for term_cfg in self._mode_term_cfgs[mode]:
            term_cfg.func(self._env, env_ids, **term_cfg.params)

    def _apply_interval(self, dt: float) -> None:
        """Step every interval timer and call its term with the resulting trigger mask.

        Each term is called on every step; the mask tells it which environments triggered.

        Args:
            dt: The time elapsed since the previous call.

        Raises:
            RuntimeError: If a required RNG state buffer is None.
        """
        if self._env.rng_state_wp is None:
            raise RuntimeError("EventManager._apply_interval: env.rng_state_wp is not initialized.")

        # iterate over all the interval terms (fixed list; captured graph-friendly)
        for i, term_cfg in enumerate(self._mode_term_cfgs["interval"]):
            lower, upper = self._interval_term_ranges[i]
            if self._interval_term_is_global[i]:
                if self._interval_global_rng_state_wp is None:
                    raise RuntimeError(
                        "EventManager._apply_interval: _interval_global_rng_state_wp is not initialized."
                    )
                # update scalar time_left and scalar flag (mask is a broadcast view of the flag)
                wp.launch(
                    kernel=_interval_step_global,
                    dim=1,
                    inputs=[
                        self._interval_term_time_left_wp[i],
                        self._interval_global_rng_state_wp,
                        self._scratch_interval_trigger_flag_wp,
                        float(dt),
                        float(lower),
                        float(upper),
                    ],
                    device=self.device,
                )
                term_cfg.func(self._env, self._scratch_interval_trigger_mask_view_wp, **term_cfg.params)
            else:
                wp.launch(
                    kernel=_interval_step_per_env,
                    dim=self.num_envs,
                    inputs=[
                        self._interval_term_time_left_wp[i],
                        self._env.rng_state_wp,
                        self._scratch_term_mask_wp,
                        float(dt),
                        float(lower),
                        float(upper),
                    ],
                    device=self.device,
                )
                term_cfg.func(self._env, self._scratch_term_mask_wp, **term_cfg.params)

    def _apply_reset(self, env_mask_wp: wp.array, global_env_step_count_wp: wp.array) -> None:
        """Call every reset term with the subset of the reset mask that is valid for it.

        The subset is computed per term by ``_reset_compute_valid_mask`` from the term's
        ``min_step_count_between_reset`` and its per-environment trigger bookkeeping, and is written
        into the shared scratch mask before the term is called.

        Args:
            env_mask_wp: Boolean Warp mask of the environments being reset.
            global_env_step_count_wp: Warp array whose first element is the environment step count.

        Raises:
            RuntimeError: If the scratch mask is None.
        """
        if self._scratch_term_mask_wp is None:
            raise RuntimeError("EventManager._apply_reset: _scratch_term_mask_wp is not initialized.")

        # iterate over all the reset terms
        for index, term_cfg in enumerate(self._mode_term_cfgs["reset"]):
            min_step_count = int(term_cfg.min_step_count_between_reset)
            wp.launch(
                kernel=_reset_compute_valid_mask,
                dim=self.num_envs,
                inputs=[
                    env_mask_wp,
                    self._reset_term_last_triggered_step_wp[index],
                    self._reset_term_triggered_once_wp[index],
                    self._scratch_term_mask_wp,
                    global_env_step_count_wp,
                    min_step_count,
                ],
                device=self.device,
            )
            term_cfg.func(self._env, self._scratch_term_mask_wp, **term_cfg.params)

    def _prepare_terms(self):
        """Parse the configuration, register every term under its mode and allocate its Warp buffers.

        Raises:
            TypeError: If a configuration entry is not an :class:`EventTermCfg`.
            RuntimeError: If a ``prestartup`` term is used while ``replicate_physics`` is enabled, or if
                the environment RNG state is None when a global interval timer is seeded.
            ValueError: If an ``interval`` term has no ``interval_range_s`` or a ``reset`` term has a
                negative ``min_step_count_between_reset``.
        """
        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()

        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            if term_cfg is None:
                continue
            if not isinstance(term_cfg, EventTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type EventTermCfg. Received:"
                    f" '{type(term_cfg)}'."
                )

            if term_cfg.mode != "reset" and term_cfg.min_step_count_between_reset != 0:
                logger.warning(
                    f"Event term '{term_name}' has 'min_step_count_between_reset' set to a non-zero value"
                    " but the mode is not 'reset'. Ignoring the 'min_step_count_between_reset' value."
                )

            # resolve common parameters
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)

            # check if mode is pre-startup and scene replication is enabled
            if term_cfg.mode == "prestartup" and self._env.scene.cfg.replicate_physics:
                raise RuntimeError(
                    "Scene replication is enabled, which may affect USD-level randomization."
                    " When assets are replicated, their properties are shared across instances,"
                    " potentially leading to unintended behavior."
                    " For stable USD-level randomization, please disable scene replication"
                    " by setting 'replicate_physics' to False in 'InteractiveSceneCfg'."
                )

            # for prestartup callable class terms, initialize early (stable behavior)
            if inspect.isclass(term_cfg.func) and term_cfg.mode == "prestartup":
                logger.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.")
                term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)

            # ensure mode buckets exist
            if term_cfg.mode not in self._mode_term_names:
                self._mode_term_names[term_cfg.mode] = []
                self._mode_term_cfgs[term_cfg.mode] = []
                self._mode_class_term_cfgs[term_cfg.mode] = []
            # add term name and parameters
            self._mode_term_names[term_cfg.mode].append(term_name)
            self._mode_term_cfgs[term_cfg.mode].append(term_cfg)

            # NOTE(review): a term whose class was already instantiated above (or during the common
            # resolution) is no longer a class here and is therefore not tracked for reset - confirm intent.
            if inspect.isclass(term_cfg.func):
                self._mode_class_term_cfgs[term_cfg.mode].append(term_cfg)

            # per-mode Warp buffers
            if term_cfg.mode == "interval":
                if term_cfg.interval_range_s is None:
                    raise ValueError(
                        f"Event term '{term_name}' has mode 'interval' but 'interval_range_s' is not specified."
                    )
                lower, upper = term_cfg.interval_range_s
                self._interval_term_ranges.append((float(lower), float(upper)))

                if term_cfg.is_global_time:
                    # allocate and seed scalar global RNG state if needed (avoid consuming env0 RNG stream)
                    if self._interval_global_rng_state_wp is None:
                        if self._env.rng_state_wp is None:
                            raise RuntimeError("EventManager._prepare_terms: env.rng_state_wp is not initialized.")
                        self._interval_global_rng_state_wp = wp.zeros((1,), dtype=wp.uint32, device=self.device)
                        wp.launch(
                            kernel=_seed_global_rng_from_env_rng,
                            dim=1,
                            inputs=[self._env.rng_state_wp, self._interval_global_rng_state_wp],
                            device=self.device,
                        )
                    # a global timer is a single element shared by all environments
                    time_left = wp.zeros((1,), dtype=wp.float32, device=self.device)
                    wp.launch(
                        kernel=_interval_init_global,
                        dim=1,
                        inputs=[time_left, self._interval_global_rng_state_wp, float(lower), float(upper)],
                        device=self.device,
                    )
                    self._interval_term_time_left_wp.append(time_left)
                    self._interval_term_is_global.append(True)
                else:
                    time_left = wp.zeros((self.num_envs,), dtype=wp.float32, device=self.device)
                    wp.launch(
                        kernel=_interval_init_per_env,
                        dim=self.num_envs,
                        inputs=[time_left, self._env.rng_state_wp, float(lower), float(upper)],
                        device=self.device,
                    )
                    self._interval_term_time_left_wp.append(time_left)
                    self._interval_term_is_global.append(False)

            elif term_cfg.mode == "reset":
                if term_cfg.min_step_count_between_reset < 0:
                    raise ValueError(
                        f"Event term '{term_name}' has mode 'reset' but 'min_step_count_between_reset' is"
                        f" negative: {term_cfg.min_step_count_between_reset}. Please provide a non-negative value."
                    )
                # per-env last-trigger bookkeeping (Warp)
                self._reset_term_last_triggered_step_wp.append(
                    wp.zeros((self.num_envs,), dtype=wp.int32, device=self.device)
                )
                self._reset_term_triggered_once_wp.append(wp.zeros((self.num_envs,), dtype=wp.bool, device=self.device))
def reset(self, env_mask: wp.array | None = None) -> None:
    """Resets the manager term (mask-based).

    The base implementation does nothing; terms that keep state override it.

    Args:
        env_mask: Boolean mask of shape (num_envs,) indicating which envs to reset.
            If None, all envs are considered.
    """
    return None
Includes the configuration dict.""" + return {"cfg": class_to_dict(self.cfg)} + + def __call__(self, *args) -> Any: + """Returns the value of the term required by the manager. + + In case of a class implementation, this function is called by the manager + to get the value of the term. The arguments passed to this function are + the ones specified in the term configuration (see :attr:`ManagerTermBaseCfg.params`). + + .. attention:: + To be consistent with memory-less implementation of terms with functions, it is + recommended to ensure that the returned mutable quantities are cloned before + returning them. For instance, if the term returns a tensor, it is recommended + to ensure that the returned tensor is a clone of the original tensor. This prevents + the manager from storing references to the tensors and altering the original tensors. + + Args: + *args: Variable length argument list. + + Returns: + The value of the term. + """ + raise NotImplementedError("The method '__call__' should be implemented by the subclass.") + + +class ManagerBase(ABC): + """Base class for all managers.""" + + def __init__(self, cfg: object, env: ManagerBasedEnv): + """Initialize the manager. + + This function is responsible for parsing the configuration object and creating the terms. + + If the simulation is not playing, the scene entities are not resolved immediately. + Instead, the resolution is deferred until the simulation starts. This is done to ensure + that the scene entities are resolved even if the manager is created after the simulation + has already started. + + Args: + cfg: The configuration object. If None, the manager is initialized without any terms. + env: The environment instance. 
def __del__(self):
    """Release the term-resolution subscription held by the manager, if there is one.

    Cleanup is best-effort: ``ImportError``, ``AttributeError`` and ``TypeError`` raised while
    unsubscribing are suppressed so that object finalization never propagates them.
    """
    # Note: contextlib may be None during interpreter shutdown; nothing can be done then.
    if contextlib is None:
        return
    with contextlib.suppress(ImportError, AttributeError, TypeError):
        # getattr with a default covers instances whose __init__ never set the attribute
        subscription = getattr(self, "_resolve_terms_handle", None)
        if subscription:
            subscription.unsubscribe()
            self._resolve_terms_handle = None
def find_terms(self, name_keys: str | Sequence[str]) -> list[str]:
    """Return the active term names that match the given name expression(s).

    The candidates are taken from :attr:`active_terms`. When that property is a dictionary
    (terms organized per group), the per-group lists are flattened into a single list, in
    dictionary order, before matching.

    Please check the :meth:`~isaaclab.utils.string_utils.resolve_matching_names` function for more
    information on the name matching.

    Args:
        name_keys: A regular expression or a list of regular expressions to match the term names.

    Returns:
        A list of term names that match the input keys.
    """
    active = self.active_terms
    if isinstance(active, dict):
        candidates = [term for group_terms in active.values() for term in group_terms]
    else:
        candidates = active

    # only the second element of the matcher's result (the names) is returned
    matches = string_utils.resolve_matching_names(name_keys, candidates)
    return matches[1]
def _resolve_terms_callback(self, event):
    """Run the deferred play-time processing of every configured term, at most once.

    Each non-``None`` entry of the manager configuration is forwarded to
    :meth:`_process_term_cfg_at_play`; please check that method for more information.
    The call is a no-op once the scene entities have been marked as resolved.

    Args:
        event: The event that triggered the callback. It is not used by this method.
    """
    if not self._is_scene_entities_resolved:
        # the configuration is either a plain dictionary or an object holding the terms as attributes
        term_table = self.cfg if isinstance(self.cfg, dict) else self.cfg.__dict__
        for name, entry in term_table.items():
            # entries set to None are terms that are not configured
            if entry is not None:
                # these properties are only resolvable once the simulation starts playing
                self._process_term_cfg_at_play(name, entry)
        # remember that the resolution happened so later events are ignored
        self._is_scene_entities_resolved = True
+ + By default, all term functions are expected to have at least one argument, which is the + environment object. Some other managers may expect functions to take more arguments, for + instance, the environment indices as the second argument. In such cases, the + ``min_argc`` argument can be used to specify the minimum number of arguments + required by the term function to be called correctly by the manager. + + Args: + term_name: The name of the term. + term_cfg: The term configuration. + min_argc: The minimum number of arguments required by the term function to be called correctly + by the manager. + + Raises: + TypeError: If the term configuration is not of type :class:`ManagerTermBaseCfg`. + ValueError: If the scene entity defined in the term configuration does not exist. + AttributeError: If the term function is not callable. + ValueError: If the term function's arguments are not matched by the parameters. + """ + # check if the term is a valid term config + if not isinstance(term_cfg, ManagerTermBaseCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ManagerTermBaseCfg." + f" Received: '{type(term_cfg)}'." + ) + + # get the corresponding function or functional class + if isinstance(term_cfg.func, str): + term_cfg.func = string_to_callable(term_cfg.func) + # check if function is callable + if not callable(term_cfg.func): + raise AttributeError(f"The term '{term_name}' is not callable. Received: {term_cfg.func}") + + # check if the term is a class of valid type + if inspect.isclass(term_cfg.func): + if not issubclass(term_cfg.func, ManagerTermBase): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ManagerTermBase." + f" Received: '{type(term_cfg.func)}'." 
def _process_term_cfg_at_play(self, term_name: str, term_cfg: ManagerTermBaseCfg):
    """Process the term configuration at runtime.

    This function is called when the simulation starts playing. It is used to process the term
    configuration at runtime. This includes:

    * Resolving the scene entity configuration for the term.
    * Initializing the term if it is a class.

    Since the above steps rely on PhysX to parse over the simulation scene, they are deferred
    until the simulation starts playing.

    Args:
        term_name: The name of the term.
        term_cfg: The term configuration. Its ``func`` attribute is replaced by an instance
            when it refers to a class.

    Raises:
        ValueError: If resolving a scene entity parameter fails. The original error is chained
            as the cause so the underlying traceback is preserved.
    """
    for key, value in term_cfg.params.items():
        # only scene-entity parameters need to be resolved against the scene
        if not isinstance(value, SceneEntityCfg):
            continue
        # load the entity
        try:
            value.resolve(self._env.scene)
        except ValueError as e:
            # chain the cause: the message alone hides where the resolution failed
            raise ValueError(f"Error while parsing '{term_name}:{key}'. {e}") from e
        # log the entity for checking later
        msg = f"[{term_cfg.__class__.__name__}:{term_name}] Found entity '{value.name}'."
        if value.joint_ids is not None:
            msg += f"\n\tJoint names: {value.joint_names} [{value.joint_ids}]"
        if value.body_ids is not None:
            msg += f"\n\tBody names: {value.body_names} [{value.body_ids}]"
        # print the information
        logger.info(msg)
        # store the entity (same key, so the dictionary size does not change while iterating)
        term_cfg.params[key] = value

    # initialize the term if it is a class
    if inspect.isclass(term_cfg.func):
        logger.info(f"Initializing term '{term_name}' with class '{term_cfg.func.__name__}'.")
        term_cfg.func = term_cfg.func(cfg=term_cfg, env=self._env)
@configclass
class ObservationTermCfg(_ManagerTermBaseCfg):
    """Configuration for an observation term (experimental, Warp-first).

    The function is expected to write observation values into a pre-allocated Warp buffer provided
    by the observation manager.

    Expected signature:

    - ``func(env, out, **params) -> None``

    where ``out`` is a Warp array of shape ``(num_envs, obs_term_dim)`` with float32 dtype.

    Notes:
        - The stable fields (noise/modifiers/history) are kept for config compatibility, but the
          experimental Warp-first observation manager may not support all of them initially.
    """

    func: Callable[..., None] = MISSING
    """The function to be called to fill the pre-allocated observation buffer."""

    # Keep stable configuration fields for compatibility with existing task configs.
    # Custom modifiers; the observation manager calls each one, in list order, on the term's output buffer.
    modifiers: list[ModifierCfg] | None = None  # noqa: F405
    # Noise configuration; the observation manager invokes its ``func`` on the output buffer after the modifiers.
    noise: NoiseCfg | NoiseModelCfg | None = None  # noqa: F405
    # Clipping bounds as (lower, upper); applied after noise and before scaling. None disables clipping.
    clip: tuple[float, float] | None = None
    # Multiplicative scale, either one value or one value per observation column; applied after clipping.
    # None disables scaling.
    scale: tuple[float, ...] | float | None = None
    # Number of past observations to keep for the term; 0 presumably disables history
    # (NOTE(review): confirm against the history buffer setup in the observation manager).
    history_length: int = 0
    # Whether the history dimension is flattened into the observation dimension
    # (NOTE(review): semantics inherited from the stable config; confirm support in the Warp-first manager).
    flatten_history_dim: bool = True
+ +Each observation group should inherit from the :class:`ObservationGroupCfg` class. Within each group, each +observation term should instantiate the :class:`ObservationTermCfg` class. Based on the configuration, the +observations in a group can be concatenated into a single tensor or returned as a dictionary with keys +corresponding to the term's name. + +If the observations in a group are concatenated, the shape of the concatenated tensor is computed based on the +shapes of the individual observation terms. This information is stored in the :attr:`group_obs_dim` dictionary +with keys as the group names and values as the shape of the observation tensor. When the terms in a group are not +concatenated, the attribute stores a list of shapes for each term in the group. + +.. note:: + When the observation terms in a group do not have the same shape, the observation terms cannot be + concatenated. In this case, please set the :attr:`ObservationGroupCfg.concatenate_terms` attribute in the + group configuration to False. + +Observations can also have history. This means a running history is updated per sim step. History can be controlled +per :class:`ObservationTermCfg` (See the :attr:`ObservationTermCfg.history_length` and +:attr:`ObservationTermCfg.flatten_history_dim`). History can also be controlled via :class:`ObservationGroupCfg` +where group configuration overwrites per term configuration if set. History follows an oldest to newest ordering. + +The observation manager can be used to compute observations for all the groups or for a specific group. The +observations are computed by calling the registered functions for each term in the group. The functions are +called in the order of the terms in the group. The functions are expected to return a tensor with shape +(num_envs, ...). + +If a noise model or custom modifier is registered for a term, the function is called to corrupt +the observation. 
def _resolve_scale_vector(value: Any, dim: int, device: str) -> torch.Tensor:
    """Resolve a scale specification into a ``(dim,)`` float32 tensor.

    Supported inputs:

    * ``None``: a vector of ones (no scaling).
    * Python or NumPy scalar: the value repeated ``dim`` times.
    * :class:`torch.Tensor` or :class:`numpy.ndarray`: broadcast when it holds one element,
      flattened when it holds exactly ``dim`` elements.
    * tuple or list: must describe exactly ``dim`` scalar values.

    Args:
        value: The scale specification.
        dim: The number of observation columns the scale applies to.
        device: The device on which the returned tensor is placed.

    Returns:
        A float32 tensor of shape ``(dim,)`` on ``device``.

    Raises:
        ValueError: If a tensor, array or sequence does not describe 1 or ``dim`` values
            (sequences must describe exactly ``dim`` values).
        TypeError: If ``value`` is of an unsupported type.
    """
    if value is None:
        return torch.ones((dim,), device=device, dtype=torch.float32)
    # NumPy arrays share the tensor code path (same broadcast / size rules)
    if isinstance(value, np.ndarray):
        value = torch.as_tensor(value)
    if isinstance(value, torch.Tensor):
        t = value.to(device=device, dtype=torch.float32)
        if t.numel() == 1:
            return t.reshape(1).repeat(dim)
        if t.numel() == dim:
            return t.reshape(dim)
        raise ValueError(f"Expected scale tensor with numel=1 or numel={dim}, got {t.numel()}.")
    # np.float32 / np.int64 etc. are not subclasses of float / int, so they are listed explicitly
    if isinstance(value, (float, int, np.floating, np.integer)):
        return torch.full((dim,), float(value), device=device, dtype=torch.float32)
    if isinstance(value, (tuple, list)):
        if len(value) != dim:
            raise ValueError(f"Expected scale length {dim}, got {len(value)}.")
        t = torch.tensor(value, device=device, dtype=torch.float32)
        # a nested sequence passes the length check but does not produce a flat vector
        if tuple(t.shape) != (dim,):
            raise ValueError(f"Expected scale of shape ({dim},), got {tuple(t.shape)}.")
        return t
    raise TypeError(f"Unsupported scale type: {type(value)}")
def __init__(self, cfg: object, env: ManagerBasedEnv):
    """Initialize observation manager.

    After the base class has parsed the term configuration, the combined observation shape of
    every group is computed: one tuple for groups whose terms are concatenated, and the list
    of per-term shapes otherwise.

    Args:
        cfg: The configuration object or dictionary (``dict[str, ObservationGroupCfg]``).
        env: The environment instance.

    Raises:
        ValueError: If the configuration is None.
        RuntimeError: If the shapes of the observation terms in a group are not compatible for concatenation
            and the :attr:`~ObservationGroupCfg.concatenate_terms` attribute is set to True. The
            underlying error is chained as the cause.
    """
    if cfg is None:
        raise ValueError("Observation manager configuration is None. Please provide a valid configuration.")

    # call the base class constructor (this will parse the terms config)
    super().__init__(cfg, env)

    # compute combined vector for obs group (matches stable semantics)
    self._group_obs_dim: dict[str, tuple[int, ...] | list[tuple[int, ...]]] = dict()
    for group_name, group_term_dims in self._group_obs_term_dim.items():
        # if terms are not concatenated, keep the list of shapes as is
        if not self._group_obs_concatenate[group_name]:
            self._group_obs_dim[group_name] = group_term_dims
            continue
        # otherwise, compute the combined shape into a single tuple
        try:
            term_dims = torch.stack([torch.tensor(dims, device="cpu") for dims in group_term_dims], dim=0)
            if len(term_dims.shape) > 1:
                if self._group_obs_concatenate_dim[group_name] >= 0:
                    dim = self._group_obs_concatenate_dim[group_name] - 1  # account for the batch offset
                else:
                    dim = self._group_obs_concatenate_dim[group_name]
                # sizes add up along the concatenation dimension; the other sizes come from the first term
                dim_sum = torch.sum(term_dims[:, dim], dim=0)
                term_dims[0, dim] = dim_sum
                term_dims = term_dims[0]
            else:
                term_dims = torch.sum(term_dims, dim=0)
            self._group_obs_dim[group_name] = tuple(term_dims.tolist())
        except RuntimeError as err:
            # keep the torch error as the cause so the offending operation stays visible
            raise RuntimeError(
                f"Unable to concatenate observation terms in group '{group_name}'."
                f" The shapes of the terms are: {group_term_dims}."
                " Please ensure that the shapes are compatible for concatenation."
                " Otherwise, set 'concatenate_terms' to False in the group configuration."
            ) from err

    # Stores the latest observations.
    self._obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] | None = None
    # Note: Persistent Warp output buffers (`_group_out_wp` / `_group_out_torch`) and per-term post-processing
    # buffers are allocated during `_prepare_terms()` since they are per-term/per-group setup.
def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
    """Returns the active terms as iterable sequence of tuples.

    The first element of the tuple is the name of the term, prefixed with its group name as
    ``"<group>-<term>"``, and the second element is the raw value(s) of the term.

    If no observations have been computed yet, :meth:`compute` is called first.

    Args:
        env_idx: The specific environment to pull the active terms from.

    Returns:
        The active terms.
    """
    terms = []

    if self._obs_buffer is None:
        self.compute()
    obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]] = self._obs_buffer

    for group_name in self._group_obs_dim:
        data = obs_buffer[group_name]

        # non-concatenated groups already store one entry per term
        if not self.group_obs_concatenate[group_name]:
            for name, term in data.items():
                terms.append((group_name + "-" + name, term[env_idx].cpu().tolist()))
            continue

        # concatenated groups: slice the row of the environment term by term
        idx = 0
        for name, shape in zip(
            self._group_obs_term_names[group_name],
            self._group_obs_term_dim[group_name],
        ):
            # np.prod returns a NumPy scalar (a float for an empty shape); slicing needs a plain int
            data_length = int(np.prod(shape))
            term = data[env_idx, idx : idx + data_length]
            terms.append((group_name + "-" + name, term.cpu().tolist()))
            idx += data_length

    return terms
@property
def get_IO_descriptors(self, group_names_to_export: Sequence[str] = ("policy",)) -> dict[str, list[dict[str, Any]]]:
    """Get the IO descriptors for the observation manager.

    Only terms whose function carries a truthy ``_has_descriptor`` attribute are exported. For
    those, the function is called with ``inspect=True`` and its ``_descriptor`` attributes are
    combined with the term configuration values for ``modifiers``, ``clip``, ``scale``,
    ``history_length`` and ``flatten_history_dim``.

    A failure while collecting the descriptor of a term is printed and the term is skipped.

    Args:
        group_names_to_export: Names of the groups to include in the result. Since this is
            accessed as a property, the default is always used.

    Returns:
        A dictionary with keys as the group names and values as the IO descriptors. Every
        descriptor has the keys ``name``, ``overloads`` and ``extras`` plus the remaining
        descriptor attributes; tuples are converted to lists and tensors to nested lists.
    """
    overload_keys = ("modifiers", "clip", "scale", "history_length", "flatten_history_dim")
    group_data: dict[str, list[dict[str, Any]]] = {}

    # Collect raw descriptor dicts (plus overloads).
    for group_name, term_names in self._group_obs_term_names.items():
        group_data[group_name] = []
        for term_name, term_cfg in zip(term_names, self._group_obs_term_cfgs[group_name]):
            func = term_cfg.func
            if not getattr(func, "_has_descriptor", False):
                continue
            try:
                # Both stable-style and Warp-first decorated terms support
                # the ``inspect=True`` keyword. Warp-first terms (decorated
                # with ``generic_io_descriptor_warp``) will NOT execute the
                # underlying function; their hooks derive metadata from
                # env/config objects instead.
                func(self._env, **term_cfg.params, inspect=True)
                desc = func._descriptor.__dict__.copy()
                desc.update({k: getattr(term_cfg, k) for k in overload_keys if hasattr(term_cfg, k)})
                group_data[group_name].append(desc)
            except Exception as e:
                # best-effort export: report the failing term and keep going
                print(f"Error getting IO descriptor for term '{term_name}' in group '{group_name}': {e}")

    formatted_data: dict[str, list[dict[str, Any]]] = {}
    for group_name, data in group_data.items():
        if group_name not in group_names_to_export:
            continue
        formatted_data[group_name] = []
        for item in data:
            name = item.pop("name")
            # copy the extras: the raw dict is shared with the term's descriptor (the descriptor
            # copy above is shallow) and must not be mutated; a None value is treated as empty
            extras = dict(item.pop("extras", None) or {})
            formatted_item = {"name": name, "overloads": {}, "extras": extras}
            for k, v in item.items():
                # convert values to plain containers
                if isinstance(v, tuple):
                    v = list(v)
                if isinstance(v, torch.Tensor):
                    v = v.detach().cpu().numpy().tolist()
                if k in ["scale", "clip", "history_length", "flatten_history_dim"]:
                    formatted_item["overloads"][k] = v
                elif k in ["modifiers", "description", "units"]:
                    formatted_item["extras"][k] = v
                else:
                    formatted_item[k] = v
            formatted_data[group_name].append(formatted_item)
    return formatted_data
+ ) + env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask) + + # call all terms that are classes + for group_name, group_cfg in self._group_obs_class_term_cfgs.items(): + for term_cfg in group_cfg: + term_cfg.func.reset(env_mask=env_mask) + # reset terms with history + for term_name in self._group_obs_term_names[group_name]: + if term_name in self._group_obs_term_history_buffer[group_name]: + self._group_obs_term_history_buffer[group_name][term_name].reset(env_mask=env_mask) + # call all modifiers/noise models that are classes + for mod in self._group_obs_class_instances: + mod.reset(env_mask=env_mask) + + # nothing to log here + return {} + + def compute( + self, update_history: bool = False, return_cloned_output: bool = True + ) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """Compute the observations per group for all groups. + + The method computes the observations for all the groups handled by the observation manager. + Please check the :meth:`compute_group` on the processing of observations per group. + + Args: + update_history: The boolean indicator without return obs should be appended to observation history. + Default to False, in which case calling compute_group does not modify history. This input is no-ops + if the group's history_length == 0. + return_cloned_output: Whether to return a cloned snapshot of the observation buffer. + Set to False to return the persistent internal buffer by reference. + + Returns: + A dictionary with keys as the group names and values as the computed observations. + The observations are either concatenated into a single tensor or returned as a dictionary + with keys corresponding to the term's name. + """ + # Launch kernels for every group (writes into persistent buffers in-place). + for group_name in self._group_obs_term_names: + self.compute_group(group_name, update_history=update_history) + # Build the obs buffer once (persistent refs to in-place-updated tensors/dicts). 
def compute_group(self, group_name: str, update_history: bool = False) -> torch.Tensor | dict[str, torch.Tensor]:
    """Computes the observations for a given group.

    The terms of the group are evaluated in the order in which they were registered. Each term
    function is called as ``func(env, out_wp, **params)`` and writes its result in-place into the
    term's pre-allocated Warp buffer ``term_cfg.out_wp``; the return value of the term is not used.

    The following steps are performed for each observation term:

    1. Compute observation term by calling the function
    2. Apply custom modifiers in the order specified in :attr:`ObservationTermCfg.modifiers`
    3. Apply corruption/noise model based on :attr:`ObservationTermCfg.noise`
    4. Apply clipping based on :attr:`ObservationTermCfg.clip`
    5. Apply scaling based on :attr:`ObservationTermCfg.scale`

    We apply noise to the computed term first to maintain the integrity of how noise affects the data
    as it truly exists in the real world. If the noise is applied after clipping or scaling, the noise
    could be artificially constrained or amplified, which might misrepresent how noise naturally occurs
    in the data.

    Terms with ``history_length > 0`` are additionally pushed through a torch-side circular buffer and
    the group's dictionary entry for that term is replaced by the (optionally flattened) history.
    This path is marked as not yet migrated in the code below.

    Args:
        group_name: The name of the group for which to compute the observations.
        update_history: Whether the freshly computed observation should be appended to the term's
            history buffer. Defaults to False, in which case the history is left untouched, except
            when the buffer has never been filled: it is then re-created and seeded with the current
            observation. This input has no effect for terms with ``history_length == 0``.

    Returns:
        For groups that use the concatenated Warp output, the persistent concatenated tensor of the
        group. Otherwise, the persistent dictionary of the group with keys corresponding to the
        term's name. Both objects are reused across calls and updated in-place; they are not clones.

    Raises:
        ValueError: If input ``group_name`` is not a valid group handled by the manager.
    """
    # check if group name is valid
    if group_name not in self._group_obs_term_names:
        raise ValueError(
            f"Unable to find the group '{group_name}' in the observation manager."
            f" Available groups are: {list(self._group_obs_term_names.keys())}"
        )
    # iterate over all the terms in each group
    group_term_names = self._group_obs_term_names[group_name]

    # Persistent per-term obs dict (pre-allocated in _prepare_terms).
    group_obs = self._group_obs_dict[group_name]

    # evaluate terms: compute, custom modifiers, noise, clip, scale
    for term_name, term_cfg in zip(group_term_names, self._group_obs_term_cfgs[group_name]):
        # compute term's value into pre-allocated Warp output
        term_cfg.func(self._env, term_cfg.out_wp, **term_cfg.params)

        # apply custom modifiers (in-place on out_wp)
        if term_cfg.modifiers is not None:
            for modifier in term_cfg.modifiers:
                modifier.func(term_cfg.out_wp, **modifier.params)

        # apply noise (Warp in-place on out_wp)
        # function-style noise receives its cfg; model-style noise is a callable instance
        if isinstance(term_cfg.noise, noise.NoiseCfg):
            term_cfg.noise.func(term_cfg.out_wp, term_cfg.noise)
        elif isinstance(term_cfg.noise, noise.NoiseModelCfg) and term_cfg.noise.func is not None:
            term_cfg.noise.func(term_cfg.out_wp)

        # clip then scale (stable semantics); implementation may use Warp kernels
        if term_cfg.clip is not None:
            wp.launch(
                kernel=_apply_clip,
                dim=self.num_envs,
                inputs=[term_cfg.out_wp, float(term_cfg.clip[0]), float(term_cfg.clip[1])],
                device=self.device,
            )
        if term_cfg.scale is not None:
            wp.launch(
                kernel=_apply_scale,
                dim=self.num_envs,
                inputs=[term_cfg.out_wp, term_cfg.scale_wp],
                device=self.device,
            )

        # TODO(jichuanh): This is not migrated yet. Need revisit.
        # Update the history buffer if observation term has history enabled
        if term_cfg.history_length > 0:
            circular_buffer = self._group_obs_term_history_buffer[group_name][term_name]
            if update_history:
                circular_buffer.append(wp.to_torch(term_cfg.out_wp))
            elif circular_buffer._buffer is None:
                # because circular buffer only exists after the simulation steps,
                # this guards history buffer from corruption by external calls before simulation start
                circular_buffer = CircularBuffer(
                    max_len=circular_buffer.max_length,
                    batch_size=circular_buffer.batch_size,
                    device=circular_buffer.device,
                )
                # keep the manager's registry pointing at the freshly created buffer
                self._group_obs_term_history_buffer[group_name][term_name] = circular_buffer
                circular_buffer.append(wp.to_torch(term_cfg.out_wp))

            # expose the history (flattened per env if requested) instead of the raw term output
            if term_cfg.flatten_history_dim:
                group_obs[term_name] = circular_buffer.buffer.reshape(self._env.num_envs, -1)
            else:
                group_obs[term_name] = circular_buffer.buffer

    # return persistent output (updated in-place by kernels above)
    if self._group_use_warp_concat[group_name]:
        return self._group_out_torch[group_name]
    return group_obs
+ self._group_obs_term_names: dict[str, list[str]] = dict() + self._group_obs_term_dim: dict[str, list[tuple[int, ...]]] = dict() + self._group_obs_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() + self._group_obs_class_term_cfgs: dict[str, list[ObservationTermCfg]] = dict() + self._group_obs_concatenate: dict[str, bool] = dict() + self._group_obs_concatenate_dim: dict[str, int] = dict() + + self._group_obs_term_history_buffer: dict[str, dict] = dict() + # create a list to store classes instances, e.g., for modifiers and noise models + # we store it as a separate list to only call reset on them and prevent unnecessary calls + self._group_obs_class_instances: list[modifiers.ModifierBase | noise.NoiseModel] = list() + + # Persistent Warp output buffers for concatenated 2D groups (optional fast-path). + # For other cases (non-concat groups, history outputs, non-2D concat dims), we allocate per-term outputs. + self._group_out_wp: dict[str, wp.array] = {} + self._group_out_torch: dict[str, torch.Tensor] = {} + self._group_use_warp_concat: dict[str, bool] = {} + self._group_obs_dict: dict[str, dict[str, torch.Tensor]] = {} + + # make sure the simulation is playing since we compute obs dims which needs asset quantities + if not self._env.sim.is_playing(): + raise RuntimeError( + "Simulation is not playing. Observation manager requires the simulation to be playing" + " to compute observation dimensions. Please start the simulation before using the" + " observation manager." + ) + + # check if config is dict already + if isinstance(self.cfg, dict): + group_cfg_items = self.cfg.items() + else: + group_cfg_items = self.cfg.__dict__.items() + # iterate over all the groups + for group_name, group_cfg in group_cfg_items: + # check for non config + if group_cfg is None: + continue + # check if the term is a curriculum term + if not isinstance(group_cfg, ObservationGroupCfg): + raise TypeError( + f"Observation group '{group_name}' is not of type 'ObservationGroupCfg'." 
+ f" Received: '{type(group_cfg)}'." + ) + # initialize list for the group settings + # group name to list of group term names + self._group_obs_term_names[group_name] = list() + + self._group_obs_term_dim[group_name] = list() + self._group_obs_term_cfgs[group_name] = list() + self._group_obs_class_term_cfgs[group_name] = list() + group_entry_history_buffer: dict[str, CircularBuffer] = dict() + # read common config for the group + self._group_obs_concatenate[group_name] = group_cfg.concatenate_terms + # to account for the batch dimension + self._group_obs_concatenate_dim[group_name] = ( + group_cfg.concatenate_dim + 1 if group_cfg.concatenate_dim >= 0 else group_cfg.concatenate_dim + ) + # check if config is dict already + if isinstance(group_cfg, dict): + term_cfg_items = group_cfg.items() + else: + term_cfg_items = group_cfg.__dict__.items() + # iterate over all the terms in each group + # (we also track raw term dims for Warp output allocation) + group_term_cfgs: list[ObservationTermCfg] = [] + group_term_raw_dims: list[int] = [] + for term_name, term_cfg in term_cfg_items: + # skip non-obs settings + if term_name in [ + "enable_corruption", + "concatenate_terms", + "history_length", + "flatten_history_dim", + "concatenate_dim", + ]: + continue + # check for non config + if term_cfg is None: + continue + if not isinstance(term_cfg, ObservationTermCfg): + raise TypeError( + f"Configuration for the term '{term_name}' is not of type ObservationTermCfg." + f" Received: '{type(term_cfg)}'." 
+ ) + # resolve common terms in the config + # Warp-first signature is (env, out, **params) + self._resolve_common_term_cfg(f"{group_name}/{term_name}", term_cfg, min_argc=2) + + # check noise settings + if not group_cfg.enable_corruption: + term_cfg.noise = None + # check group history params and override terms + if group_cfg.history_length is not None: + term_cfg.history_length = group_cfg.history_length + term_cfg.flatten_history_dim = group_cfg.flatten_history_dim + # add term config to list + self._group_obs_term_names[group_name].append(term_name) + self._group_obs_term_cfgs[group_name].append(term_cfg) + + # infer dimensions (Warp-first: terms write to out; we infer dim from resolved scene info) + term_dim = self._infer_term_dim_scalar(term_cfg) + # Cache the "raw" term output dimension (before history reshaping) for Warp buffer allocation. + # This matches the tensor shape produced directly by the term into `out`: (num_envs, term_dim). + term_cfg._term_dim = int(term_dim) + group_term_cfgs.append(term_cfg) + group_term_raw_dims.append(int(term_dim)) + obs_dims = (self._env.num_envs, term_dim) + + # if scale is set, check if single float or tuple + if term_cfg.scale is not None: + if not isinstance(term_cfg.scale, (float, int, tuple)): + raise TypeError( + f"Scale for observation term '{term_name}' in group '{group_name}'" + f" is not of type float, int or tuple. Received: '{type(term_cfg.scale)}'." + ) + if isinstance(term_cfg.scale, tuple) and len(term_cfg.scale) != obs_dims[1]: + raise ValueError( + f"Scale for observation term '{term_name}' in group '{group_name}'" + f" does not match the dimensions of the observation. Expected: {obs_dims[1]}" + f" but received: {len(term_cfg.scale)}." 
+ ) + + # cast the scale into torch tensor + term_cfg.scale = torch.tensor(term_cfg.scale, dtype=torch.float, device=self._env.device) + term_cfg.scale_wp = wp.from_torch(term_cfg.scale, dtype=wp.float32) + + # prepare modifiers for each observation + if term_cfg.modifiers is not None: + # initialize list of modifiers for term + for mod_cfg in term_cfg.modifiers: + # check if class modifier and initialize with observation size when adding + if isinstance(mod_cfg, modifiers.ModifierCfg): + # to list of modifiers + if inspect.isclass(mod_cfg.func): + if not issubclass(mod_cfg.func, modifiers.ModifierBase): + raise TypeError( + f"Modifier function '{mod_cfg.func}' for observation term '{term_name}'" + f" is not a subclass of 'ModifierBase'. Received: '{type(mod_cfg.func)}'." + ) + mod_cfg.func = mod_cfg.func(cfg=mod_cfg, data_dim=obs_dims, device=self._env.device) + + # add to list of class modifiers + self._group_obs_class_instances.append(mod_cfg.func) + else: + raise TypeError( + f"Modifier configuration '{mod_cfg}' of observation term '{term_name}' is not of" + f" required type ModifierCfg, Received: '{type(mod_cfg)}'" + ) + + # check if function is callable + if not callable(mod_cfg.func): + raise AttributeError( + f"Modifier '{mod_cfg}' of observation term '{term_name}' is not callable." 
+ f" Received: {mod_cfg.func}" + ) + + # check if term's arguments are matched by params + term_params = list(mod_cfg.params.keys()) + args = inspect.signature(mod_cfg.func).parameters + args_with_defaults = [arg for arg in args if args[arg].default is not inspect.Parameter.empty] + args_without_defaults = [arg for arg in args if args[arg].default is inspect.Parameter.empty] + args = args_without_defaults + args_with_defaults + # ignore first argument for data + if len(args) > 1: + if set(args[1:]) != set(term_params + args_with_defaults): + raise ValueError( + f"Modifier '{mod_cfg}' of observation term '{term_name}' expects" + f" mandatory parameters: {args_without_defaults[1:]}" + f" and optional parameters: {args_with_defaults}, but received: {term_params}." + ) + + # prepare noise model classes + if term_cfg.noise is not None and isinstance(term_cfg.noise, noise.NoiseModelCfg): + # plumb the shared per-env RNG state so Warp noise kernels can consume it + term_cfg.noise.rng_state_wp = self._env.rng_state_wp + noise_model_cls = term_cfg.noise.class_type + if not issubclass(noise_model_cls, noise.NoiseModel): + raise TypeError( + f"Class type for observation term '{term_name}' NoiseModelCfg" + f" is not a subclass of 'NoiseModel'. Received: '{type(noise_model_cls)}'." 
+ ) + # initialize func to be the noise model class instance + term_cfg.noise.func = noise_model_cls( + term_cfg.noise, num_envs=self._env.num_envs, device=self._env.device + ) + self._group_obs_class_instances.append(term_cfg.noise.func) + + # create history buffers and calculate history term dimensions + if term_cfg.history_length > 0: + group_entry_history_buffer[term_name] = CircularBuffer( + max_len=term_cfg.history_length, batch_size=self._env.num_envs, device=self._env.device + ) + old_dims = list(obs_dims) + old_dims.insert(1, term_cfg.history_length) + obs_dims = tuple(old_dims) + if term_cfg.flatten_history_dim: + obs_dims = (obs_dims[0], np.prod(obs_dims[1:])) + raise NotImplementedError("History reshaping is not implemented yet for warp.") + + self._group_obs_term_dim[group_name].append(obs_dims[1:]) + + # add term in a separate list if term is a class + if isinstance(term_cfg.func, ManagerTermBase): + self._group_obs_class_term_cfgs[group_name].append(term_cfg) + # call reset (in-case internal state should be reset at init) + term_cfg.func.reset() + + # Allocate persistent outputs for this group. + # - If group is concatenated into a flat 2D vector (N, D) with no history terms, allocate a single group + # buffer and map term outputs to contiguous slices (fast-path). + # - Otherwise allocate per-term outputs. 
+ can_use_group_buffer = ( + self._group_obs_concatenate[group_name] + and self._group_obs_concatenate_dim[group_name] in (1, -1) + and all(cfg.history_length == 0 for cfg in group_term_cfgs) + ) + + if can_use_group_buffer: + total = int(sum(group_term_raw_dims)) + out_wp = wp.zeros((self.num_envs, total), dtype=wp.float32, device=self.device) + self._group_out_wp[group_name] = out_wp + self._group_out_torch[group_name] = wp.to_torch(out_wp) + + base_ptr = out_wp.ptr + row_stride = out_wp.strides[0] + col_stride = out_wp.strides[1] + start = 0 + for term_cfg, d in zip(group_term_cfgs, group_term_raw_dims): + out_view = wp.array( + ptr=base_ptr + start * col_stride, + dtype=wp.float32, + shape=(self.num_envs, int(d)), + strides=(row_stride, col_stride), + device=self.device, + ) + term_cfg.out_wp = out_view + term_cfg.out_torch = wp.to_torch(term_cfg.out_wp) + start += int(d) + else: + for term_cfg, d in zip(group_term_cfgs, group_term_raw_dims): + term_cfg.out_wp = wp.zeros((self.num_envs, int(d)), dtype=wp.float32, device=self.device) + term_cfg.out_torch = wp.to_torch(term_cfg.out_wp) + + # Guard: concat groups must use the Warp fast-path (standard concat dim, no history). + if self._group_obs_concatenate[group_name] and not can_use_group_buffer: + raise ValueError( + f"Observation group '{group_name}' is concatenated but cannot use the Warp" + " fast-path (requires concatenate_dim 0 or -1, and all terms history_length == 0)." + ) + + # Precompute fast-path flag and persistent per-term obs dict. 
+ self._group_use_warp_concat[group_name] = can_use_group_buffer + self._group_obs_dict[group_name] = { + term_name: cfg.out_torch + for term_name, cfg in zip(self._group_obs_term_names[group_name], group_term_cfgs) + } + + # add history buffers for each group + self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer + + def _infer_term_dim_scalar(self, term_cfg: ObservationTermCfg) -> int: + """Infer (D,) using scalar scene info (no term execution).""" + # allow explicit override + for k in ("term_dim", "out_dim", "obs_dim"): + if k in term_cfg.params: + return int(term_cfg.params[k]) + # try explicit param first + asset_cfg = term_cfg.params.get("asset_cfg") + if asset_cfg is None: + raise ValueError(f"Observation term '{term_cfg.params}' has no asset_cfg parameter.") + # resolve selection + asset = self._env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is not None: + return int(joint_ids_wp.shape[0]) + joint_ids = getattr(asset_cfg, "joint_ids", slice(None)) + if isinstance(joint_ids, slice): + return int(getattr(asset, "num_joints", wp.to_torch(asset.data.joint_pos).shape[1])) + return int(len(joint_ids)) diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py new file mode 100644 index 00000000000..67dcbc055c2 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/reward_manager.py @@ -0,0 +1,419 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Reward manager for computing reward signals for a given world. + +This file is a copy of `isaaclab.managers.reward_manager` placed under +`isaaclab_experimental` so it can evolve independently. 
@wp.kernel
def _reward_finalize(
    # input
    term_outs: wp.array(dtype=wp.float32, ndim=2),
    term_weights: wp.array(dtype=wp.float32),
    dt: float,
    # input/output
    reward_buf: wp.array(dtype=wp.float32),
    episode_sums: wp.array(dtype=wp.float32, ndim=2),
    step_reward: wp.array(dtype=wp.float32, ndim=2),
):
    """Combine the raw term outputs of one environment into its step reward.

    One thread handles one environment (``env_id = wp.tid()``) and loops over all terms.
    Terms whose weight is exactly zero are skipped entirely, so their ``step_reward`` and
    ``episode_sums`` entries are left untouched by this kernel.

    Args:
        term_outs: Raw (unweighted) term values, indexed as ``[term_idx, env_id]``.
        term_weights: Weight of each term, indexed by ``term_idx``.
        dt: Factor applied to the weighted value before it is accumulated.
        reward_buf: Per-environment total; overwritten with the sum of ``raw * weight * dt``.
        episode_sums: Running per-term sums, indexed as ``[term_idx, env_id]``; incremented
            by ``raw * weight * dt``.
        step_reward: Per-term weighted value *without* the ``dt`` factor, indexed as
            ``[env_id, term_idx]`` (note the transposed layout relative to ``term_outs``).
    """
    env_id = wp.tid()

    total = wp.float32(0.0)
    for term_idx in range(term_outs.shape[0]):
        weight = term_weights[term_idx]
        if weight != 0.0:
            raw = term_outs[term_idx, env_id]
            weighted = raw * weight
            # store weighted reward rate, i.e. the value before multiplying by dt
            # (matches old: value/dt)
            step_reward[env_id, term_idx] = weighted
            val = weighted * dt
            total += val
            episode_sums[term_idx, env_id] += val

    reward_buf[env_id] = total
def __init__(self, cfg: object, env: ManagerBasedRLEnv):
    """Initialize the reward manager.

    After the base class has parsed the terms, all buffers used by :meth:`compute` and
    :meth:`reset` are allocated once here and reused afterwards:

    - a ``(num_terms, num_envs)`` raw output buffer, with one 1D row view per term that is
      handed to the term through ``term_cfg.out``
    - a ``(num_terms, num_envs)`` episodic-sum buffer with one named row view per term
    - a ``(num_envs,)`` total reward buffer and a ``(num_envs, num_terms)`` step reward buffer
    - the per-term weights and the reset-time logging buffers
    - torch views of these Warp arrays and the dictionary returned by :meth:`reset`

    Args:
        cfg: The configuration object or dictionary (``dict[str, RewardTermCfg]``).
        env: The environment instance.
    """

    # create buffers to parse and store terms
    self._term_names: list[str] = list()
    self._term_cfgs: list[RewardTermCfg] = list()
    self._class_term_cfgs: list[RewardTermCfg] = list()

    # call the base class constructor (this will parse the terms config)
    super().__init__(cfg, env)
    self._term_name_to_term_idx = {name: i for i, name in enumerate(self._term_names)}

    num_terms = len(self._term_names)
    self._num_terms = num_terms

    # persistent term output buffer (raw, unweighted) laid out as (term, env) for contiguous per-term ops
    self._term_outs_wp = wp.zeros((num_terms, self.num_envs), dtype=wp.float32, device=self.device)
    # per-term output buffers are views into rows of `_term_outs_wp` (Warp)
    self._term_out_views_wp: list[wp.array] = []
    if num_terms > 0:
        row_stride = self._term_outs_wp.strides[0]
        col_stride = self._term_outs_wp.strides[1]
        base_ptr = self._term_outs_wp.ptr
        for term_idx, term_cfg in enumerate(self._term_cfgs):
            # row `term_idx` starts `term_idx * row_stride` past the base pointer
            out_view = wp.array(
                ptr=base_ptr + term_idx * row_stride,
                dtype=wp.float32,
                shape=(self.num_envs,),
                strides=(col_stride,),
                device=self.device,
            )
            self._term_out_views_wp.append(out_view)
            # the term function receives this view as its output argument in `compute`
            term_cfg.out = out_view

    # prepare extra info to store individual reward term information (warp buffers)
    self._episode_sums_wp = wp.zeros((num_terms, self.num_envs), dtype=wp.float32, device=self.device)
    self._episode_sum_views_wp: dict[str, wp.array] = {}
    if num_terms > 0:
        row_stride = self._episode_sums_wp.strides[0]
        col_stride = self._episode_sums_wp.strides[1]
        base_ptr = self._episode_sums_wp.ptr
        for term_idx, term_name in enumerate(self._term_names):
            # same row-view construction as above, keyed by term name
            sum_view = wp.array(
                ptr=base_ptr + term_idx * row_stride,
                dtype=wp.float32,
                shape=(self.num_envs,),
                strides=(col_stride,),
                device=self.device,
            )
            self._episode_sum_views_wp[term_name] = sum_view
    # per-env reward buffer (Warp)
    self._reward_wp = wp.zeros((self.num_envs,), dtype=wp.float32, device=self.device)

    # buffer which stores the current step reward rate for each term for each environment (warp buffer)
    # NOTE: laid out as (env, term), i.e. transposed relative to `_term_outs_wp`
    self._step_reward_wp = wp.zeros((self.num_envs, num_terms), dtype=wp.float32, device=self.device)

    # per-term weights stored on-device for single-kernel accumulation
    self._term_weights_wp = wp.array(
        [float(term_cfg.weight) for term_cfg in self._term_cfgs], dtype=wp.float32, device=self.device
    )

    # persistent reset-time logging buffers (warp buffers)
    self._episode_sum_avg_wp = wp.zeros((num_terms,), dtype=wp.float32, device=self.device)
    self._reset_count_wp = wp.zeros((1,), dtype=wp.int32, device=self.device)
    self._reset_scale_wp = wp.zeros((1,), dtype=wp.float32, device=self.device)

    # persistent torch tensor views (helpful for CUDA graph capture)
    self._reward_tensor_view = wp.to_torch(self._reward_wp)
    self._step_reward_tensor_view = wp.to_torch(self._step_reward_wp)
    self._term_weights_tensor_view = wp.to_torch(self._term_weights_wp)
    self._episode_sum_avg_tensor_view = wp.to_torch(self._episode_sum_avg_wp)
    # dictionary returned by `reset`; built once so the same objects are handed out every time
    self._reset_extras = {
        "Episode_Reward/" + term_name: self._episode_sum_avg_tensor_view[term_idx]
        for term_idx, term_name in enumerate(self._term_names)
    }
def reset(
    self,
    env_ids: Sequence[int] | torch.Tensor | None = None,
    *,
    env_mask: wp.array | None = None,
) -> dict[str, torch.Tensor]:
    """Computes/reset episodic reward sums for masked envs (capturable core).

    The logging buffers are cleared first, then the ``count_masked`` and
    ``compute_reset_scale`` kernels fill the reset scale, and finally the per-term episodic
    sums of the masked environments are accumulated (multiplied by that scale) and zeroed.
    Class-based reward terms are reset with the same mask afterwards.

    Args:
        env_ids: The specific environment indices to reset.
            If None, all environments are considered. Only used when no usable
            ``env_mask`` is given.
        env_mask: Boolean Warp mask of shape (num_envs,) selecting reset environments.
            If provided, takes precedence over ``env_ids``.

    Returns:
        A dictionary containing the information to log under the
        "Episode_Reward/{term_name}" key. The same dictionary object is returned on every
        call and its values are views of a persistent buffer that is cleared at the start
        of the next reset; copy them if they must outlive that call.

    Raises:
        RuntimeError: If no ``wp.array`` mask is supplied while the current Warp device is
            capturing.
    """
    # Mask-first path: captured callers must provide env_mask.
    if env_mask is None or not isinstance(env_mask, wp.array):
        if wp.get_device().is_capturing:
            raise RuntimeError(
                "RewardManager.reset requires env_mask(wp.array[bool]) during capture. "
                "Do not pass env_ids on captured paths."
            )
        env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask)

    # clear the accumulators before the kernels below add into them
    self._episode_sum_avg_wp.zero_()
    self._reset_count_wp.zero_()
    self._reset_scale_wp.zero_()

    wp.launch(kernel=count_masked, dim=self.num_envs, inputs=[env_mask, self._reset_count_wp], device=self.device)
    wp.launch(
        kernel=compute_reset_scale,
        dim=1,
        inputs=[self._reset_count_wp, float(self._env.max_episode_length_s), self._reset_scale_wp],
        device=self.device,
    )

    # skip the 2D launch when there are no terms
    if self._num_terms > 0:
        wp.launch(
            kernel=_sum_and_zero_masked,
            dim=(self._num_terms, self.num_envs),
            inputs=[env_mask, self._reset_scale_wp, self._episode_sums_wp, self._episode_sum_avg_wp],
            device=self.device,
        )

    # reset all the reward terms
    for term_cfg in self._class_term_cfgs:
        term_cfg.func.reset(env_mask=env_mask)

    return self._reset_extras
+ + This function calls each reward term managed by the class and adds them to compute the net + reward signal. It also updates the episodic sums corresponding to individual reward terms. + + Args: + dt: The time-step interval of the environment. + + Returns: + The net reward signal of shape (num_envs,). + """ + # TODO: Investigate performance diff between two .fill_ and kernel launch + # reset computation (Warp buffers) in a single kernel launch + wp.launch( + kernel=_reward_pre_compute_reset, + dim=self.num_envs, + inputs=[self._reward_wp, self._step_reward_wp, self._term_outs_wp], + device=self.device, + ) + # iterate over all the reward terms (Python loop; per-term math is warp) + for term_cfg in self._term_cfgs: + # skip if weight is zero (kind of a micro-optimization) + if term_cfg.weight == 0.0: + continue + # compute term into the persistent warp buffer (raw, unweighted) + # NOTE: `out` is pre-zeroed every step by `_reward_pre_compute_reset`. + term_cfg.func(self._env, term_cfg.out, **term_cfg.params) + + # update total reward, episodic sums and step rewards in a single kernel launch + wp.launch( + kernel=_reward_finalize, + dim=self.num_envs, + inputs=[ + self._term_outs_wp, + self._term_weights_wp, + float(dt), + self._reward_wp, + self._episode_sums_wp, + self._step_reward_wp, + ], + device=self.device, + ) + + return self._reward_tensor_view + + """ + Operations - Term settings. + """ + + def set_term_cfg(self, term_name: str, cfg: RewardTermCfg): + """Sets the configuration of the specified term into the manager. + + Args: + term_name: The name of the reward term. + cfg: The configuration for the reward term. + + Raises: + ValueError: If the term name is not found. + """ + if term_name not in self._term_names: + raise ValueError(f"Reward term '{term_name}' not found.") + # TODO(jichuanh): it's not guaranteed that the pre-allocated output view is still valid. + # Review this in curriculum manager migration. 
def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
    """Returns the active terms as iterable sequence of tuples.

    The first element of the tuple is the name of the term and the second element is a
    single-item list holding the term's current step reward value for the environment.

    Args:
        env_idx: The specific environment to pull the active terms from.

    Returns:
        The active terms, in registration order.
    """
    # Convert the whole row in one call so that there is a single device-to-host transfer
    # instead of one `.cpu().item()` round trip per term.
    row_values = self._step_reward_tensor_view[env_idx].tolist()
    return [(name, [value]) for name, value in zip(self._term_names, row_values)]
class SceneEntityCfg(_SceneEntityCfg):
    """Scene entity configuration with optional Warp-side joint selections.

    Notes:
        - `joint_mask` is intended for Warp kernels only.
        - Both Warp attributes stay ``None`` until :meth:`resolve` is called, and remain
          ``None`` for entities that are not articulations.
    """

    joint_mask: wp.array | None = None
    """Boolean mask over all joints — used by warp kernels for masked writes."""

    joint_ids_wp: wp.array | None = None
    """Integer indices of selected joints — used for subset-sized gathers where a boolean mask
    cannot provide the mapping from output index k to joint index."""

    def resolve(self, scene: InteractiveScene):
        """Resolve the entity and cache the Warp joint mask and joint index array.

        Args:
            scene: The scene containing the entity named :attr:`name`.
        """
        # run the stable resolution first (fills joint_ids/body_ids from names/regex)
        super().resolve(scene)

        # Build a Warp joint mask for articulations only.
        entity = scene[self.name]
        if not isinstance(entity, BaseArticulation):
            return

        num_joints = entity.num_joints
        if isinstance(self.joint_ids, slice):
            # A slice is not iterable: expand it against the joint count. The default
            # `slice(None)` expands to every joint, partial slices to their subset.
            joint_ids_list = list(range(*self.joint_ids.indices(num_joints)))
        else:
            joint_ids_list = list(self.joint_ids)

        # full-length mask with True at every selected joint
        mask_list = [False] * num_joints
        for idx in joint_ids_list:
            mask_list[idx] = True

        self.joint_mask = wp.array(mask_list, dtype=wp.bool, device=scene.device)
        self.joint_ids_wp = wp.array(joint_ids_list, dtype=wp.int32, device=scene.device)
@wp.kernel
def _termination_pre_compute_reset(
    # output
    term_dones: wp.array(dtype=wp.bool, ndim=2),
    truncated: wp.array(dtype=wp.bool),
    terminated: wp.array(dtype=wp.bool),
    dones: wp.array(dtype=wp.bool),
):
    """Reset per-step termination buffers.

    Launched with dim = (num_envs,) to reset per-env flags and clear the corresponding row in `term_dones`.
    This works even when `term_dones.shape[1] == 0` (no terms).

    Args:
        term_dones: Per-term flags indexed as ``[env_id, term_idx]``; the row of this thread's env is cleared.
        truncated: Per-env flag, set to ``False``.
        terminated: Per-env flag, set to ``False``.
        dones: Per-env flag, set to ``False``.
    """
    # one thread per environment
    env_id = wp.tid()
    truncated[env_id] = False
    terminated[env_id] = False
    dones[env_id] = False
    # clear every term column of this environment's row (loop is empty when there are no terms)
    for term_idx in range(term_dones.shape[1]):
        term_dones[env_id, term_idx] = False
# TODO(jichuanh): Look into wp.tile for better performance
@wp.kernel
def _termination_reset_mean_all_2d(
    last_episode_dones: wp.array(dtype=wp.bool, ndim=2),
    term_done_avg: wp.array(dtype=wp.float32),
):
    """Compute mean(done) per term with 2D parallel accumulation.

    Each thread handles one ``(env_id, term_idx)`` pair and, when that flag is set,
    atomically adds ``1 / num_envs`` to the term's entry in ``term_done_avg``.
    The kernel only accumulates, so ``term_done_avg`` must be zeroed by the caller
    before the launch for the result to be a mean.

    Args:
        last_episode_dones: Per-term flags indexed as ``[env_id, term_idx]``.
        term_done_avg: Per-term accumulator indexed by ``term_idx``.
    """
    env_id, term_idx = wp.tid()
    # the first dimension of the flag array is the number of environments
    num_envs = last_episode_dones.shape[0]
    if num_envs > 0 and last_episode_dones[env_id, term_idx]:
        # several env threads contribute to the same term entry, hence the atomic add
        wp.atomic_add(term_done_avg, term_idx, 1.0 / float(num_envs))
def __init__(self, cfg: object, env: ManagerBasedRLEnv):
    """Initialize the termination manager and allocate its persistent buffers.

    Args:
        cfg: The configuration object or dictionary (``dict[str, TerminationTermCfg]``).
        env: The environment instance.
    """
    # create buffers to parse and store terms
    # note: these lists must exist before the base constructor runs, because term parsing appends to them
    self._term_names: list[str] = list()
    self._term_cfgs: list[TerminationTermCfg] = list()
    self._class_term_cfgs: list[TerminationTermCfg] = list()

    # call the base class constructor (this will parse the terms config)
    super().__init__(cfg, env)

    # name -> column index lookup used by `get_term`
    self._term_name_to_term_idx = {name: i for i, name in enumerate(self._term_names)}

    # persistent buffers (Warp)
    num_terms = len(self._term_names)
    self._term_dones_wp = wp.zeros((self.num_envs, num_terms), dtype=wp.bool, device=self.device)
    self._term_done_avg_wp = wp.zeros((num_terms,), dtype=wp.float32, device=self.device)
    self._last_episode_dones_wp = wp.zeros((self.num_envs, num_terms), dtype=wp.bool, device=self.device)
    self._truncated_wp = wp.zeros((self.num_envs,), dtype=wp.bool, device=self.device)
    self._terminated_wp = wp.zeros((self.num_envs,), dtype=wp.bool, device=self.device)
    self._dones_wp = wp.zeros((self.num_envs,), dtype=wp.bool, device=self.device)

    # per-term flags indicating if a term is a timeout (Warp)
    self._term_is_time_out_wp = wp.array(
        [bool(term_cfg.time_out) for term_cfg in self._term_cfgs], dtype=wp.bool, device=self.device
    )

    # per-term output buffers are views into the columns of `_term_dones_wp` (Warp).
    # This avoids per-term temporary outputs and a per-term "store" kernel.
    # TODO: Investigate the performance difference between laying out rows per env versus per term
    self._term_out_views_wp: list[wp.array] = []
    if num_terms > 0:
        row_stride = self._term_dones_wp.strides[0]
        col_stride = self._term_dones_wp.strides[1]
        base_ptr = self._term_dones_wp.ptr
        for term_idx, term_cfg in enumerate(self._term_cfgs):
            # strided 1-D view over column `term_idx`: starts at the column offset and
            # steps by one full row per environment
            out_view = wp.array(
                ptr=base_ptr + term_idx * col_stride,
                dtype=wp.bool,
                shape=(self.num_envs,),
                strides=(row_stride,),
                device=self.device,
            )
            self._term_out_views_wp.append(out_view)
            # the term function receives this view as its output argument in `compute`
            term_cfg.out = out_view

    # torch tensor views (persistent)
    self._term_dones_tensor_view = wp.to_torch(self._term_dones_wp)
    self._last_episode_dones_tensor_view = wp.to_torch(self._last_episode_dones_wp)
    self._truncated_tensor_view = wp.to_torch(self._truncated_wp)
    self._terminated_tensor_view = wp.to_torch(self._terminated_wp)
    self._dones_tensor_view = wp.to_torch(self._dones_wp)
    self._term_done_avg_tensor_view = wp.to_torch(self._term_done_avg_wp)
    # logging dictionary built once; each value indexes into the per-term average tensor view
    self._reset_extras = {
        "Episode_Termination/" + term_name: self._term_done_avg_tensor_view[term_idx]
        for term_idx, term_name in enumerate(self._term_names)
    }
def reset(
    self,
    env_ids: Sequence[int] | torch.Tensor | None = None,
    *,
    env_mask: wp.array | None = None,
) -> dict[str, torch.Tensor]:
    """Reset termination stats and class terms; return pre-allocated extras.

    The per-term statistic is recomputed from the last-episode flags of all
    environments (the kernel is launched over every environment, independent
    of the selection). The selection is only forwarded to class-based terms.

    Args:
        env_ids: The specific environment indices to reset.
            If None, all environments are considered.
        env_mask: Boolean Warp mask of shape (num_envs,) selecting reset environments.
            If provided, takes precedence over ``env_ids``.

    Returns:
        A dictionary containing the information to log under the
        "Episode_Termination/{term_name}" key.

    Raises:
        RuntimeError: If no Warp ``env_mask`` is given while the manager's device
            is capturing a CUDA graph.
    """
    # Mask-first path: captured callers must provide env_mask.
    if not isinstance(env_mask, wp.array):
        # query the capture state of the device this manager launches on,
        # not the process-wide default device
        if wp.get_device(self.device).is_capturing:
            raise RuntimeError(
                "TerminationManager.reset requires env_mask(wp.array[bool]) during capture. "
                "Do not pass env_ids on captured paths."
            )
        env_mask = self._env.resolve_env_mask(env_ids=env_ids, env_mask=env_mask)

    num_terms = len(self._term_names)
    if num_terms > 0:
        # the kernel only accumulates, so clear the averages first
        self._term_done_avg_wp.zero_()
        wp.launch(
            kernel=_termination_reset_mean_all_2d,
            dim=(self.num_envs, num_terms),
            inputs=[self._last_episode_dones_wp, self._term_done_avg_wp],
            device=self.device,
        )

    # reset stateful (class-based) terms for the selected environments
    for term_cfg in self._class_term_cfgs:
        term_cfg.func.reset(env_mask=env_mask)
    return self._reset_extras
def _prepare_terms(self):
    """Parse the termination configuration into the manager's term lists.

    Entries whose value is ``None`` are skipped. Every other entry must be a
    :class:`TerminationTermCfg`; it is resolved and recorded, and terms whose
    function is a :class:`ManagerTermBase` instance are tracked separately.

    Raises:
        TypeError: If an entry is not a :class:`TerminationTermCfg`.
    """
    # the configuration is either a plain dict or an object carrying terms as attributes
    cfg_items = self.cfg.items() if isinstance(self.cfg, dict) else self.cfg.__dict__.items()

    for term_name, term_cfg in cfg_items:
        # disabled term
        if term_cfg is None:
            continue
        # anything else must be a termination term configuration
        if isinstance(term_cfg, TerminationTermCfg) is False:
            raise TypeError(
                f"Configuration for the term '{term_name}' is not of type TerminationTermCfg."
                f" Received: '{type(term_cfg)}'."
            )
        # resolve common parameters (env, out)
        self._resolve_common_term_cfg(term_name, term_cfg, min_argc=2)
        # record the term
        self._term_names.append(term_name)
        self._term_cfgs.append(term_cfg)
        # class-based terms are kept aside so they can be reset
        if isinstance(term_cfg.func, ManagerTermBase):
            self._class_term_cfgs.append(term_cfg)
class CircularBuffer:
    """Circular buffer for storing a history of batched tensor data.

    This class implements a circular buffer for storing a history of batched tensor data. The buffer is
    initialized with a maximum length and a batch size. The data is stored in a circular fashion, and the
    data can be retrieved in a LIFO (Last-In-First-Out) fashion. The buffer is designed to be used in
    multi-environment settings, where each environment has its own data.

    The shape of the appended data is expected to be (batch_size, ...), where the first dimension is the
    batch dimension. Correspondingly, the shape of the ring buffer is (max_len, batch_size, ...).
    """

    def __init__(self, max_len: int, batch_size: int, device: str):
        """Initialize the circular buffer.

        Args:
            max_len: The maximum length of the circular buffer. The minimum allowed value is 1.
            batch_size: The batch dimension of the data.
            device: The device used for processing.

        Raises:
            ValueError: If the buffer size is less than one.
        """
        if max_len < 1:
            raise ValueError(f"The buffer size should be greater than zero. However, it is set to {max_len}!")
        # set the parameters
        self._batch_size = batch_size
        self._device = device
        self._ALL_INDICES = torch.arange(batch_size, device=device)

        # host-side copy of the maximum length, so reading it never needs a device-to-host transfer
        self._max_length = int(max_len)
        # max length tensor for comparisons
        self._max_len = torch.full((batch_size,), max_len, dtype=torch.int, device=device)
        # number of data pushes passed since the last call to :meth:`reset`
        self._num_pushes = torch.zeros(batch_size, dtype=torch.long, device=device)
        # the pointer to the current head of the circular buffer (-1 means not initialized)
        self._pointer: int = -1
        # the actual buffer for data storage
        # note: this is initialized on the first call to :meth:`append`
        self._buffer: torch.Tensor | None = None

    """
    Properties.
    """

    @property
    def batch_size(self) -> int:
        """The batch size of the ring buffer."""
        return self._batch_size

    @property
    def device(self) -> str:
        """The device used for processing."""
        return self._device

    @property
    def max_length(self) -> int:
        """The maximum length of the ring buffer."""
        return self._max_length

    @property
    def current_length(self) -> torch.Tensor:
        """The current length of the buffer. Shape is (batch_size,).

        Since the buffer is circular, the current length is the minimum of the number of pushes
        and the maximum length.
        """
        return torch.minimum(self._num_pushes, self._max_len)

    @property
    def buffer(self) -> torch.Tensor:
        """Complete circular buffer with most recent entry at the end and oldest entry at the beginning.

        Returns:
            Complete circular buffer with most recent entry at the end and oldest entry at the beginning
            of dimension 1. The shape is [batch_size, max_length, data.shape[1:]].

        Raises:
            RuntimeError: If no data has been appended yet (the storage is not allocated).
        """
        if self._buffer is None:
            raise RuntimeError("Attempting to retrieve data on an empty circular buffer. Please append data first.")
        buf = self._buffer.clone()
        # rotate so that the slot under the head pointer ends up last
        buf = torch.roll(buf, shifts=self.max_length - self._pointer - 1, dims=0)
        return torch.transpose(buf, dim0=0, dim1=1)

    """
    Operations.
    """

    def reset(
        self,
        batch_ids: Sequence[int] | torch.Tensor | wp.array | None = None,
        *,
        env_mask: torch.Tensor | wp.array | None = None,
    ):
        """Reset the circular buffer at the specified batch indices.

        Args:
            batch_ids: Elements to reset in the batch dimension. Default is None, which resets all the batch indices.
            env_mask: Boolean mask of shape (batch_size,) selecting elements to reset.
                If provided, it takes precedence over ``batch_ids``.

        Raises:
            TypeError: If ``env_mask`` is neither a torch tensor nor a Warp array.
            ValueError: If ``env_mask`` does not have shape (batch_size,).
        """
        if env_mask is not None:
            if not isinstance(env_mask, torch.Tensor):
                if not isinstance(env_mask, wp.array):
                    raise TypeError(f"Unsupported env_mask type: {type(env_mask)}")
                env_mask = wp.to_torch(env_mask)
            # `.to` is a no-op when dtype and device already match
            env_mask = env_mask.to(device=self._device, dtype=torch.bool)
            if env_mask.ndim != 1 or env_mask.shape[0] != self._batch_size:
                raise ValueError(f"Expected env_mask shape ({self._batch_size},), received {tuple(env_mask.shape)}.")
            index = env_mask
        elif batch_ids is None:
            # resolve all indices
            index = slice(None)
        elif isinstance(batch_ids, torch.Tensor):
            index = batch_ids
        elif isinstance(batch_ids, wp.array):
            index = wp.to_torch(batch_ids)
        else:
            index = batch_ids

        # reset the number of pushes for the specified batch indices
        self._num_pushes[index] = 0
        if self._buffer is not None:
            # set buffer at batch_id reset indices to 0.0 so that the buffer() getter
            # returns the cleared circular buffer after reset.
            # note: only the first two dimensions are indexed so that data of any rank works,
            # including 1-D per-batch data where the buffer has no third dimension.
            self._buffer[:, index] = 0.0

    def append(self, data: torch.Tensor):
        """Append the data to the circular buffer.

        Args:
            data: The data to append to the circular buffer. The first dimension should be the batch dimension.
                Shape is (batch_size, ...).

        Raises:
            ValueError: If the input data has a different batch size than the buffer.
        """
        # check the batch size
        if data.shape[0] != self.batch_size:
            raise ValueError(f"The input data has '{data.shape[0]}' batch size while expecting '{self.batch_size}'")

        # move the data to the device
        data = data.to(self._device)
        # at the first call, initialize the buffer size
        if self._buffer is None:
            self._pointer = -1
            self._buffer = torch.empty((self.max_length, *data.shape), dtype=data.dtype, device=self._device)
        # move the head to the next slot
        self._pointer = (self._pointer + 1) % self.max_length
        # add the new data to the last layer
        self._buffer[self._pointer] = data
        # Check for batches with zero pushes and initialize all values in batch to first append
        is_first_push = self._num_pushes == 0
        if torch.any(is_first_push):
            self._buffer[:, is_first_push] = data[is_first_push]
        # increment the number of pushes for all batches
        self._num_pushes += 1

    def __getitem__(self, key: torch.Tensor) -> torch.Tensor:
        """Retrieve the data from the circular buffer in last-in-first-out (LIFO) fashion.

        If the requested index is larger than the number of pushes since the last call to :meth:`reset`,
        the oldest stored data is returned.

        Args:
            key: The index to retrieve from the circular buffer. The index should be less than the number of pushes
                since the last call to :meth:`reset`. Shape is (batch_size,).

        Returns:
            The data from the circular buffer. Shape is (batch_size, ...).

        Raises:
            ValueError: If the input key has a different batch size than the buffer.
            RuntimeError: If the buffer is empty.
        """
        # check the batch size
        if len(key) != self.batch_size:
            raise ValueError(f"The argument 'key' has length {key.shape[0]}, while expecting {self.batch_size}")
        # check if the buffer is empty (the host-side check comes first because it is free)
        if self._buffer is None or torch.any(self._num_pushes == 0):
            raise RuntimeError("Attempting to retrieve data on an empty circular buffer. Please append data first.")

        # admissible lag
        valid_keys = torch.minimum(key, self._num_pushes - 1)
        # the index in the circular buffer (the pointer is the slot of the most recent entry)
        index_in_buffer = torch.remainder(self._pointer - valid_keys, self.max_length)
        # return output
        return self._buffer[index_in_buffer, self._ALL_INDICES]
def call_stage(
    self,
    *,
    stage: str,
    warp_call: dict[str, Any],
    stable_call: dict[str, Any] | None = None,
    timer: bool = False,
) -> Any:
    """Run the stage according to configured mode, optionally wrapped in a :class:`Timer`.

    A call spec dict supports the following keys:

    * ``fn`` (required): The callable to invoke.
    * ``args`` (optional): Positional arguments tuple.
    * ``kwargs`` (optional): Keyword arguments dict.
    * ``output`` (optional): A ``Callable[[Any], Any]`` that transforms the raw
      return value into the final output. For captured stages the raw value is
      the one returned while the graph was captured, which is cached and reused
      on every replay. When omitted, the raw return value is used as-is.

    Args:
        stage: Stage identifier in the form ``"ManagerName_function_name"``. The part before
            the first underscore selects the manager whose mode is used.
        warp_call: Call spec for the warp path (eager or captured).
        stable_call: Call spec for the stable (torch) path. Defaults to ``None``.
        timer: Whether to enable the :class:`Timer` wrapped around the call. Defaults to ``False``.
            The value is forwarded as the timer's ``enable`` argument.
            Pass a module-level flag like ``TIMER_ENABLED_STEP`` to make timing
            conditional on that flag.

    Returns:
        The (possibly transformed) return value of the stage.
    """
    # the timer context is always entered; `enable=timer` decides whether it is active
    with Timer(name=stage, msg=f"{stage} took:", enable=timer, time_unit="us"):
        return self._dispatch(stage, stable_call, warp_call)
def get_mode_for_manager(self, manager_name: str) -> ManagerCallMode:
    """Return the resolved execution mode for the given manager.

    Args:
        manager_name: Name of the manager (e.g. ``"RewardManager"``).

    Returns:
        The mode configured for ``manager_name``, or the configured default
        mode when the manager has no explicit entry.
    """
    # use the declared default key rather than relying on the iteration order of DEFAULT_CONFIG
    mode_value = self._cfg.get(manager_name, self._cfg[self.DEFAULT_KEY])
    return ManagerCallMode(mode_value)
def _load_cfg(self, cfg_source: str | None) -> dict[str, int]:
    """Parse, validate, and cap the per-manager mode configuration.

    Args:
        cfg_source: JSON object string mapping manager names (and optionally the
            default key) to integer modes. ``None`` or a blank string selects
            the built-in default configuration.

    Returns:
        The resolved configuration. It always contains the default key, and
        managers listed in ``MAX_MODE_OVERRIDES`` are capped to their maximum mode.

    Raises:
        TypeError: If ``cfg_source`` is not a string or None, if the JSON does not
            decode to a dict, or if a mode value is not an integer (booleans are rejected).
        ValueError: If the JSON is malformed or a mode value is not a valid mode.
    """
    if cfg_source is not None and not isinstance(cfg_source, str):
        raise TypeError(f"cfg_source must be a string or None, got: {type(cfg_source)}")
    if cfg_source is None or cfg_source.strip() == "":
        cfg = dict(self.DEFAULT_CONFIG)
    else:
        parsed = json.loads(cfg_source)
        if not isinstance(parsed, dict):
            raise TypeError("manager_call_config must decode to a dict.")

        cfg = dict(parsed)
        # fall back to the built-in default mode when the user did not provide one
        cfg.setdefault(self.DEFAULT_KEY, self.DEFAULT_CONFIG[self.DEFAULT_KEY])

    # validation
    for manager_name, mode_value in cfg.items():
        # bool is a subclass of int: JSON true/false would otherwise be silently read as mode 1/0
        if isinstance(mode_value, bool) or not isinstance(mode_value, int):
            raise TypeError(
                f"manager_call_config value for '{manager_name}' must be int (0/1/2), got: {type(mode_value)}"
            )
        try:
            ManagerCallMode(mode_value)
        except ValueError as exc:
            raise ValueError(
                f"Invalid manager_call_config value for '{manager_name}': {mode_value}. Expected 0/1/2."
            ) from exc

    # Apply MAX_MODE_OVERRIDES: bake caps into the resolved config so
    # get_mode_for_manager never needs per-call branching.
    default_mode = cfg[self.DEFAULT_KEY]
    for name, max_mode in self.MAX_MODE_OVERRIDES.items():
        if cfg.get(name, default_mode) > max_mode:
            # store a plain int so every value in the resolved config has the same type
            cfg[name] = int(max_mode)

    return cfg
+ +Each modifier takes a ``wp.array`` as its first argument and operates **in-place** +via ``wp.launch``. The calling convention mirrors Warp MDP terms:: + + modifier.func(data_wp, **params) -> None +""" + +from __future__ import annotations + +import warp as wp + +# -- scale -------------------------------------------------------------------- + + +@wp.kernel +def _scale_kernel(data: wp.array(dtype=wp.float32, ndim=2), multiplier: wp.float32): + i, j = wp.tid() + data[i, j] = data[i, j] * multiplier + + +def scale(data: wp.array, multiplier: float) -> None: + """Scale all elements of *data* by *multiplier* in-place. + + Warp-native drop-in replacement for :func:`isaaclab.utils.modifiers.scale`. + + Args: + data: The observation buffer to modify. Shape ``(num_envs, D)``. + multiplier: Scalar multiplier. + """ + wp.launch(_scale_kernel, dim=data.shape, inputs=[data, float(multiplier)], device=data.device) + + +# -- bias --------------------------------------------------------------------- + + +@wp.kernel +def _bias_kernel(data: wp.array(dtype=wp.float32, ndim=2), value: wp.float32): + i, j = wp.tid() + data[i, j] = data[i, j] + value + + +def bias(data: wp.array, value: float) -> None: + """Add a uniform *value* to all elements of *data* in-place. + + Warp-native drop-in replacement for :func:`isaaclab.utils.modifiers.bias`. + + Args: + data: The observation buffer to modify. Shape ``(num_envs, D)``. + value: Scalar bias to add. + """ + wp.launch(_bias_kernel, dim=data.shape, inputs=[data, float(value)], device=data.device) + + +# -- clip --------------------------------------------------------------------- + + +@wp.kernel +def _clip_kernel(data: wp.array(dtype=wp.float32, ndim=2), lo: wp.float32, hi: wp.float32): + i, j = wp.tid() + data[i, j] = wp.clamp(data[i, j], lo, hi) + + +def clip(data: wp.array, bounds: tuple[float | None, float | None]) -> None: + """Clamp all elements of *data* to [lo, hi] in-place. 
+ + Warp-native drop-in replacement for :func:`isaaclab.utils.modifiers.clip`. + + Args: + data: The observation buffer to modify. Shape ``(num_envs, D)``. + bounds: ``(min, max)`` tuple. ``None`` means no bound on that side. + """ + lo = float(bounds[0]) if bounds[0] is not None else float(-1e38) + hi = float(bounds[1]) if bounds[1] is not None else float(1e38) + wp.launch(_clip_kernel, dim=data.shape, inputs=[data, lo, hi], device=data.device) diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_base.py b/source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_base.py new file mode 100644 index 00000000000..083e1576e89 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_base.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-native modifier base class (experimental).""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import warp as wp + +if TYPE_CHECKING: + from .modifier_cfg import ModifierCfg + + +class ModifierBase(ABC): + """Base class for Warp-native class-based modifiers. + + Experimental fork of :class:`isaaclab.utils.modifiers.ModifierBase` adapted for the + Warp-first calling convention. Subclasses operate **in-place** on ``wp.array`` + buffers and return ``None``. + + A class implementation of a modifier can be used to store state information between + calls. This is useful for modifiers that require stateful operations, such as + rolling averages, delays, or decaying filters. + """ + + def __init__(self, cfg: ModifierCfg, data_dim: tuple[int, ...], device: str) -> None: + """Initializes the modifier class. + + Args: + cfg: Configuration parameters. + data_dim: The dimensions of the data to be modified. 
First element is the + batch size (number of environments). + device: The device to run the modifier on. + """ + self._cfg = cfg + self._data_dim = data_dim + self._device = device + + @abstractmethod + def reset(self, env_mask: wp.array | None = None): + """Resets the modifier. + + Args: + env_mask: Boolean env mask of shape ``(num_envs,)`` selecting environments + to reset. Defaults to None, in which case all environments are + considered. + """ + raise NotImplementedError + + @abstractmethod + def __call__(self, data: wp.array) -> None: + """Apply the modification in-place. + + Args: + data: The ``wp.array`` buffer to modify. Shape should match the + *data_dim* passed during initialization. + """ + raise NotImplementedError diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_cfg.py new file mode 100644 index 00000000000..7cc22eaee6a --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/modifiers/modifier_cfg.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-native modifier configuration (experimental).""" + +from collections.abc import Callable +from dataclasses import MISSING +from typing import Any + +from isaaclab.utils import configclass + + +@configclass +class ModifierCfg: + """Configuration parameters for Warp-native modifiers. + + Experimental fork of :class:`isaaclab.utils.modifiers.ModifierCfg` adapted for the + Warp-first calling convention where modifier functions operate **in-place** on a + ``wp.array`` buffer and return ``None``. + """ + + func: Callable[..., None] = MISSING + """Function or callable class used by modifier. 
+ + The function must take a ``wp.array`` as the first argument and operate on it + **in-place** (no return value). The remaining arguments are specified in the + :attr:`params` attribute. + + It also supports callable classes that implement ``__call__()``. In this case the + class should inherit from :class:`ModifierBase` and implement the required methods. + """ + + params: dict[str, Any] = dict() + """The parameters to be passed to the function or callable class as keyword arguments. + + Defaults to an empty dictionary. + """ diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/noise/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/utils/noise/__init__.py new file mode 100644 index 00000000000..16aa5ab0733 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/noise/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-native noise implementations (experimental). + +Re-exports stable configs and classes, then overrides the function-based +noise models (``constant_noise``, ``uniform_noise``, ``gaussian_noise``) +and their configs with Warp-native versions that operate **in-place** on +``wp.array``. 
+ +Calling convention (matches Warp MDP terms):: + + noise_cfg.func(data_wp, noise_cfg) -> None # in-place on wp.array +""" + +from isaaclab.utils.noise import * # noqa: F401,F403 + +# Override with Warp-native implementations +from .noise_cfg import ConstantNoiseCfg, GaussianNoiseCfg, NoiseCfg, UniformNoiseCfg # noqa: F401 +from .noise_model import NoiseModel, constant_noise, gaussian_noise, uniform_noise # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_cfg.py new file mode 100644 index 00000000000..5c7045e65a6 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_cfg.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-native noise configuration (experimental).""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import MISSING +from typing import Literal + +import warp as wp + +from isaaclab.utils import configclass + +from . import noise_model + + +@configclass +class NoiseCfg: + """Configuration for a Warp-native noise term. + + Experimental fork of :class:`isaaclab.utils.noise.NoiseCfg` adapted for the + Warp-first calling convention where noise functions operate **in-place** on a + ``wp.array`` buffer and return ``None``. + """ + + func: Callable[[wp.array, NoiseCfg], None] = MISSING + """The function to be called for applying the noise. + + The function must take a ``wp.array`` as the first argument and the noise + configuration as the second argument. It operates **in-place** (no return value). + """ + + operation: Literal["add", "scale", "abs"] = "add" + """The operation to apply the noise on the data. 
Defaults to ``"add"``.""" + + +@configclass +class ConstantNoiseCfg(NoiseCfg): + """Configuration for a constant noise term (Warp-native).""" + + func = noise_model.constant_noise + + bias: float = 0.0 + """The bias to add. Defaults to 0.0.""" + + +@configclass +class UniformNoiseCfg(NoiseCfg): + """Configuration for a uniform noise term (Warp-native).""" + + func = noise_model.uniform_noise + + n_min: float = -1.0 + """The minimum value of the noise. Defaults to -1.0.""" + n_max: float = 1.0 + """The maximum value of the noise. Defaults to 1.0.""" + + +@configclass +class GaussianNoiseCfg(NoiseCfg): + """Configuration for a gaussian noise term (Warp-native).""" + + func = noise_model.gaussian_noise + + mean: float = 0.0 + """The mean of the noise. Defaults to 0.0.""" + std: float = 1.0 + """The standard deviation of the noise. Defaults to 1.0.""" diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_model.py b/source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_model.py new file mode 100644 index 00000000000..acba926ea48 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/noise/noise_model.py @@ -0,0 +1,200 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp-native noise functions and models (experimental). + +Each noise function takes a ``wp.array`` as its first argument and operates **in-place** +via ``wp.launch``. The calling convention mirrors the stable noise interface:: + + noise_cfg.func(data_wp, noise_cfg) -> None + +Random noise kernels (gaussian, uniform) consume the shared per-env Warp RNG state +(``rng_state_wp``) that is set on the config at manager prep time from +``env.rng_state_wp``. See :func:`initialize_rng_state` in +``isaaclab_experimental.envs.manager_based_env_warp`` for the initialization pattern. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp + +if TYPE_CHECKING: + from . import noise_cfg + +## +# Operation mode mapping. +## + +_OPERATION_MAP: dict[str, int] = {"add": 0, "scale": 1, "abs": 2} + +## +# Noise as functions. +## + + +# -- constant ----------------------------------------------------------------- + + +@wp.kernel +def _apply_constant_noise( + out: wp.array(dtype=wp.float32, ndim=2), + bias: wp.float32, + operation: wp.int32, +): + env_id = wp.tid() + for j in range(out.shape[1]): + if operation == 0: + out[env_id, j] = out[env_id, j] + bias + elif operation == 1: + out[env_id, j] = out[env_id, j] * bias + else: + out[env_id, j] = bias + + +def constant_noise(data: wp.array, cfg: noise_cfg.ConstantNoiseCfg) -> None: + """Applies a constant noise bias to a given data set in-place. + + Warp-native drop-in replacement for :func:`isaaclab.utils.noise.constant_noise`. + + Args: + data: The data buffer to modify. Shape ``(num_envs, D)``. + cfg: The configuration parameters for constant noise. + """ + wp.launch( + _apply_constant_noise, + dim=data.shape[0], + inputs=[data, float(cfg.bias), _OPERATION_MAP[cfg.operation]], + device=data.device, + ) + + +# -- uniform ------------------------------------------------------------------ + + +@wp.kernel +def _apply_uniform_noise( + out: wp.array(dtype=wp.float32, ndim=2), + rng_state: wp.array(dtype=wp.uint32), + n_min: wp.float32, + n_max: wp.float32, + operation: wp.int32, +): + env_id = wp.tid() + state = rng_state[env_id] + for j in range(out.shape[1]): + n = wp.randf(state, n_min, n_max) + if operation == 0: + out[env_id, j] = out[env_id, j] + n + elif operation == 1: + out[env_id, j] = out[env_id, j] * n + else: + out[env_id, j] = n + rng_state[env_id] = state + + +def uniform_noise(data: wp.array, cfg: noise_cfg.UniformNoiseCfg) -> None: + """Applies uniform noise to a given data set in-place. 
+ + Warp-native drop-in replacement for :func:`isaaclab.utils.noise.uniform_noise`. + + Args: + data: The data buffer to modify. Shape ``(num_envs, D)``. + cfg: The configuration parameters for uniform noise. + """ + wp.launch( + _apply_uniform_noise, + dim=data.shape[0], + inputs=[data, cfg.rng_state_wp, float(cfg.n_min), float(cfg.n_max), _OPERATION_MAP[cfg.operation]], + device=data.device, + ) + + +# -- gaussian ----------------------------------------------------------------- + + +@wp.kernel +def _apply_gaussian_noise( + out: wp.array(dtype=wp.float32, ndim=2), + rng_state: wp.array(dtype=wp.uint32), + mean: wp.float32, + std: wp.float32, + operation: wp.int32, +): + env_id = wp.tid() + state = rng_state[env_id] + for j in range(out.shape[1]): + n = mean + std * wp.randn(state) + if operation == 0: + out[env_id, j] = out[env_id, j] + n + elif operation == 1: + out[env_id, j] = out[env_id, j] * n + else: + out[env_id, j] = n + rng_state[env_id] = state + + +def gaussian_noise(data: wp.array, cfg: noise_cfg.GaussianNoiseCfg) -> None: + """Applies gaussian noise to a given data set in-place. + + Warp-native drop-in replacement for :func:`isaaclab.utils.noise.gaussian_noise`. + + Args: + data: The data buffer to modify. Shape ``(num_envs, D)``. + cfg: The configuration parameters for gaussian noise. + """ + wp.launch( + _apply_gaussian_noise, + dim=data.shape[0], + inputs=[data, cfg.rng_state_wp, float(cfg.mean), float(cfg.std), _OPERATION_MAP[cfg.operation]], + device=data.device, + ) + + +## +# Noise models as classes. +## + + +class NoiseModel: + """Warp-native base class for noise models. + + Experimental fork of :class:`isaaclab.utils.noise.NoiseModel` adapted for the + Warp-first calling convention where noise is applied **in-place** on ``wp.array``. + """ + + def __init__(self, noise_model_cfg: noise_cfg.NoiseModelCfg, num_envs: int, device: str): + """Initialize the noise model. + + Args: + noise_model_cfg: The noise configuration to use. 
+ num_envs: The number of environments. + device: The device to use for the noise model. + """ + self._noise_model_cfg = noise_model_cfg + self._num_envs = num_envs + self._device = device + + def reset(self, env_mask: wp.array | None = None): + """Reset the noise model. + + This method can be implemented by derived classes to reset the noise model. + This is useful when implementing temporal noise models such as random walk. + + Args: + env_mask: Boolean env mask of shape ``(num_envs,)`` selecting environments + to reset. Defaults to None, in which case all environments are + considered. + """ + pass + + def __call__(self, data: wp.array) -> None: + """Apply the noise to the data in-place. + + Args: + data: The data to apply the noise to. Shape is ``(num_envs, ...)``. + """ + self._noise_model_cfg.noise_cfg.func(data, self._noise_model_cfg.noise_cfg) diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/torch_utils.py b/source/isaaclab_experimental/isaaclab_experimental/utils/torch_utils.py new file mode 100644 index 00000000000..1e611ec3fef --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/torch_utils.py @@ -0,0 +1,32 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import torch + + +def clone_obs_buffer( + obs_buffer: dict[str, torch.Tensor | dict[str, torch.Tensor]], +) -> dict[str, torch.Tensor | dict[str, torch.Tensor]]: + """Clone a nested observation buffer, using :meth:`torch.Tensor.clone` for every leaf tensor. + + This avoids the overhead of :func:`copy.deepcopy` while still producing an independent + snapshot of the buffer (new dict objects + cloned tensor storage). + + Args: + obs_buffer: Observation buffer mapping group names to either a single concatenated + tensor or a dict of per-term tensors. 
+ + Returns: + A new dictionary with the same structure whose tensors are clones of the originals. + """ + result: dict[str, torch.Tensor | dict[str, torch.Tensor]] = {} + for key, value in obs_buffer.items(): + if isinstance(value, torch.Tensor): + result[key] = value.clone() + else: # dict[str, torch.Tensor] + result[key] = {k: v.clone() for k, v in value.items()} + return result diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py new file mode 100644 index 00000000000..30f0fb17d6d --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp utility functions and shared kernels for isaaclab_experimental.""" + +from .kernels import compute_reset_scale, count_masked +from .utils import WarpCapturable, resolve_1d_mask, resolve_asset_cfg, wrap_to_pi diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/kernels.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/kernels.py new file mode 100644 index 00000000000..8d3f9e65d49 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/kernels.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Shared Warp kernels used across multiple managers.""" + +from __future__ import annotations + +import warp as wp + + +@wp.kernel +def count_masked( + mask: wp.array(dtype=wp.bool), + out_count: wp.array(dtype=wp.int32), +): + """Count the number of True entries in a boolean mask. + + ``out_count`` must be zeroed before launch. 
Result is stored in ``out_count[0]``. + Launched with ``dim = num_envs``. + """ + env_id = wp.tid() + if mask[env_id]: + wp.atomic_add(out_count, 0, 1) + + +@wp.kernel +def compute_reset_scale( + reset_count: wp.array(dtype=wp.int32), + divisor: wp.float32, + out_scale: wp.array(dtype=wp.float32), +): + """Compute ``1 / (count * divisor)`` scaling factor from a reset count. + + Pass ``divisor = 1.0`` for plain ``1 / count`` (e.g. command manager). + Pass ``divisor = max_episode_length_s`` for reward-style normalization. + + Launched with ``dim = 1``. + """ + count = reset_count[0] + if count > 0: + out_scale[0] = 1.0 / (wp.float32(count) * divisor) + else: + out_scale[0] = 0.0 diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py new file mode 100644 index 00000000000..0d3a85f8020 --- /dev/null +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py @@ -0,0 +1,192 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +import torch +from isaaclab_experimental.managers.scene_entity_cfg import SceneEntityCfg + +import warp as wp + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedEnv + + +@wp.kernel +def _set_mask_from_ids( + mask: wp.array(dtype=wp.bool), + ids: wp.array(dtype=wp.int32), +): + """Set ``mask[ids[i]] = True`` for each thread *i*.""" + i = wp.tid() + mask[ids[i]] = True + + +def resolve_1d_mask( + *, + ids: Sequence[int] | slice | wp.array | torch.Tensor | None, + mask: wp.array | torch.Tensor | None, + all_mask: wp.array, + scratch_mask: wp.array, + device: str, +) -> wp.array: + """Resolve ids/mask into a warp boolean mask. 
+ + Matches the contract of ``ArticulationData._resolve_1d_mask`` on dev/newton. + Callers must provide pre-allocated ``all_mask`` (all-True) and ``scratch_mask`` + (reusable working buffer). ``scratch_mask`` is reused, but converting torch/Python inputs may allocate. + + Args: + ids: Indices to set to ``True``. ``None`` or ``slice(None)`` means all. + mask: Explicit boolean mask. If provided, returned directly (after + torch→warp normalization if needed). Takes precedence over *ids*. + all_mask: Pre-allocated all-True mask of shape ``(size,)``, returned + when both *ids* and *mask* are ``None``. + scratch_mask: Pre-allocated scratch mask of shape ``(size,)``, filled + in-place when *ids* are provided. + device: Warp device string. + + Returns: + A ``wp.array(dtype=wp.bool)`` — ``mask``, ``all_mask``, or ``scratch_mask``. + """ + # Fast path: explicit mask provided. + if mask is not None: + if isinstance(mask, torch.Tensor): + if mask.dtype != torch.bool: + mask = mask.to(dtype=torch.bool) + if str(mask.device) != device: + mask = mask.to(device) + return wp.from_torch(mask, dtype=wp.bool) + return mask + + # Fast path: all ids. + if ids is None or (isinstance(ids, slice) and ids == slice(None)): + return all_mask + + # Normalize slice into explicit indices. + if isinstance(ids, slice): + start, stop, step = ids.indices(scratch_mask.shape[0]) + ids = list(range(start, stop, step)) + elif not isinstance(ids, (torch.Tensor, wp.array)): + ids = list(ids) + + # Prepare output mask. + scratch_mask.fill_(False) + + # Normalize ids to wp.int32 array and launch kernel.
+ if isinstance(ids, torch.Tensor): + if ids.numel() == 0: + return scratch_mask + if str(ids.device) != device: + ids = ids.to(device) + if ids.dtype != torch.int32: + ids = ids.to(dtype=torch.int32) + if not ids.is_contiguous(): + ids = ids.contiguous() + ids_wp = wp.from_torch(ids, dtype=wp.int32) + elif isinstance(ids, wp.array): + if ids.shape[0] == 0: + return scratch_mask + ids_wp = ids + else: + if len(ids) == 0: + return scratch_mask + ids_wp = wp.array(ids, dtype=wp.int32, device=device) + + wp.launch(kernel=_set_mask_from_ids, dim=ids_wp.shape[0], inputs=[scratch_mask, ids_wp], device=device) + return scratch_mask + + +@wp.func +def wrap_to_pi(angle: float) -> float: + """Wrap input angle (in radians) to the range [-pi, pi].""" + two_pi = 2.0 * wp.pi + wrapped_angle = angle + wp.pi + # NOTE: Use floor-based remainder semantics to match torch's `%` for negative inputs. + wrapped_angle = wrapped_angle - wp.floor(wrapped_angle / two_pi) * two_pi + return wp.where((wrapped_angle == 0) and (angle > 0), wp.pi, wrapped_angle - wp.pi) + + +class WarpCapturable: + """CUDA graph capture safety: decorator, annotation checker, and runtime guard. + + Decorator usage:: + + @WarpCapturable(False) + def reset_root_state_uniform(env, env_mask, ...): + ... + + @WarpCapturable(False, reason="calls write_root_pose_to_sim") + def push_by_setting_velocity(env, env_mask, ...): + ... + + - ``@WarpCapturable(True)`` or no decorator: capturable, returned unwrapped. + - ``@WarpCapturable(False)``: sets ``func._warp_capturable = False``, wraps with + runtime guard that raises if ``wp.get_device().is_capturing`` is ``True``.
+ """ + + def __init__(self, capturable: bool, *, reason: str | None = None): + self._capturable = capturable + self._reason = reason + + def __call__(self, func): + """Decorate *func* with capture safety annotation and optional runtime guard.""" + import functools + + func._warp_capturable = self._capturable + if self._capturable: + return func + + reason = self._reason + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if wp.get_device().is_capturing: + msg = f"'{func.__qualname__}' is marked @WarpCapturable(False) but called during CUDA graph capture." + if reason: + msg = f"{msg} {reason}" + raise RuntimeError(msg) + return func(*args, **kwargs) + + wrapper._warp_capturable = False + return wrapper + + @staticmethod + def is_capturable(func) -> bool: + """Check capturability annotation. Default: ``True``. + + Checks ``__wrapped__`` for decorated functions to handle stacked decorators. + """ + for f in (func, getattr(func, "__wrapped__", None)): + if f is not None: + val = getattr(f, "_warp_capturable", None) + if val is not None: + return val + return True + + +def resolve_asset_cfg(cfg: dict, env: ManagerBasedEnv) -> SceneEntityCfg: + asset_cfg = None + + for value in cfg.values(): + # If it exists, the SceneEntityCfg should have been resolved by the base manager. + if isinstance(value, SceneEntityCfg): + asset_cfg = value + # Check if the joint ids are not set, and if so, resolve them. + if asset_cfg.joint_ids is None or asset_cfg.joint_ids == slice(None): + asset_cfg.resolve_for_warp(env.scene) + if asset_cfg.body_ids is None or asset_cfg.body_ids == slice(None): + asset_cfg.resolve_for_warp(env.scene) + break + + # If it doesn't exist, use the default robot entity. 
+ if asset_cfg is None: + asset_cfg = SceneEntityCfg("robot") + asset_cfg.resolve_for_warp(env.scene) + + return asset_cfg diff --git a/source/isaaclab_rl/isaaclab_rl/rl_games/rl_games.py b/source/isaaclab_rl/isaaclab_rl/rl_games/rl_games.py index 0fcc217977d..ccc16d29aba 100644 --- a/source/isaaclab_rl/isaaclab_rl/rl_games/rl_games.py +++ b/source/isaaclab_rl/isaaclab_rl/rl_games/rl_games.py @@ -34,6 +34,7 @@ # needed to import for allowing type-hinting:gym.spaces.Box | None from __future__ import annotations +import contextlib from collections.abc import Callable from typing import TYPE_CHECKING @@ -51,6 +52,9 @@ ManagerBasedRLEnv, ) + with contextlib.suppress(ImportError): + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp + """ Vectorized environment wrapper. """ @@ -118,9 +122,22 @@ def __init__( # NOTE: import here (not at module level) to avoid loading heavy env classes before Isaac Sim is initialized. from isaaclab.envs import DirectMARLEnv, DirectRLEnv, ManagerBasedRLEnv - if not isinstance(env.unwrapped, (ManagerBasedRLEnv, DirectRLEnv, DirectMARLEnv)): + try: + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp + except ImportError: + DirectRLEnvWarp = None + ManagerBasedRLEnvWarp = None + + allowed_types = (ManagerBasedRLEnv, DirectRLEnv, DirectMARLEnv) + if DirectRLEnvWarp is not None: + allowed_types += (DirectRLEnvWarp,) + if ManagerBasedRLEnvWarp is not None: + allowed_types += (ManagerBasedRLEnvWarp,) + + if not isinstance(env.unwrapped, allowed_types): raise ValueError( - "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + " ManagerBasedRLEnvWarp. 
Environment type:" f" {type(env)}" ) # initialize the wrapper @@ -220,7 +237,7 @@ def class_name(cls) -> str: return cls.__name__ @property - def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv: + def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp: """Returns the base environment of the wrapper. This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. diff --git a/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py b/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py index 7e753c0f7a2..5a3fdcd4716 100644 --- a/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py +++ b/source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause from __future__ import annotations +import contextlib from typing import TYPE_CHECKING import gymnasium as gym @@ -17,6 +18,9 @@ ManagerBasedRLEnv, ) + with contextlib.suppress(ImportError): + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp + class RslRlVecEnvWrapper(VecEnv): """Wraps around Isaac Lab environment for the RSL-RL library @@ -48,17 +52,24 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, clip_actions: float | N from isaaclab.envs import DirectRLEnv, ManagerBasedEnv, ManagerBasedRLEnv try: - from isaaclab_experimental.envs import DirectRLEnvWarp + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedEnvWarp, ManagerBasedRLEnvWarp except ImportError: DirectRLEnvWarp = None + ManagerBasedEnvWarp = None + ManagerBasedRLEnvWarp = None allowed_types = (ManagerBasedRLEnv, ManagerBasedEnv, DirectRLEnv) if DirectRLEnvWarp is not None: allowed_types += (DirectRLEnvWarp,) + if ManagerBasedEnvWarp is not None: + allowed_types += (ManagerBasedEnvWarp,) + if ManagerBasedRLEnvWarp is not None: + allowed_types += (ManagerBasedRLEnvWarp,) if not isinstance(env.unwrapped, allowed_types): raise ValueError( - "The environment must be inherited from 
ManagerBasedRLEnv or DirectRLEnv. Environment type:" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + " ManagerBasedRLEnvWarp. Environment type:" f" {type(env)}" ) @@ -121,7 +132,7 @@ def class_name(cls) -> str: return cls.__name__ @property - def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv: + def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp: """Returns the base environment of the wrapper. This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. diff --git a/source/isaaclab_rl/isaaclab_rl/sb3.py b/source/isaaclab_rl/isaaclab_rl/sb3.py index 1b50f8be9c7..90ca0fd3b85 100644 --- a/source/isaaclab_rl/isaaclab_rl/sb3.py +++ b/source/isaaclab_rl/isaaclab_rl/sb3.py @@ -18,6 +18,7 @@ # needed to import for allowing type-hinting: torch.Tensor | dict[str, torch.Tensor] from __future__ import annotations +import contextlib import warnings from typing import TYPE_CHECKING, Any @@ -32,6 +33,9 @@ if TYPE_CHECKING: from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv + with contextlib.suppress(ImportError): + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp + # remove SB3 warnings because PPO with bigger net actually benefits from GPU warnings.filterwarnings("ignore", message="You are trying to run PPO on the GPU") @@ -150,9 +154,22 @@ def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv, fast_variant: bool = Tr # NOTE: import here (not at module level) to avoid loading heavy env classes before Isaac Sim is initialized. 
from isaaclab.envs import DirectRLEnv, ManagerBasedRLEnv - if not isinstance(env.unwrapped, (ManagerBasedRLEnv, DirectRLEnv)): + try: + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp + except ImportError: + DirectRLEnvWarp = None + ManagerBasedRLEnvWarp = None + + allowed_types = (ManagerBasedRLEnv, DirectRLEnv) + if DirectRLEnvWarp is not None: + allowed_types += (DirectRLEnvWarp,) + if ManagerBasedRLEnvWarp is not None: + allowed_types += (ManagerBasedRLEnvWarp,) + + if not isinstance(env.unwrapped, allowed_types): raise ValueError( - "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:" + "The environment must be inherited from ManagerBasedRLEnv / DirectRLEnv / DirectRLEnvWarp /" + " ManagerBasedRLEnvWarp. Environment type:" f" {type(env)}" ) # initialize the wrapper @@ -186,7 +203,7 @@ def class_name(cls) -> str: return cls.__name__ @property - def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv: + def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv | DirectRLEnvWarp | ManagerBasedRLEnvWarp: """Returns the base environment of the wrapper. This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers. diff --git a/source/isaaclab_rl/isaaclab_rl/skrl.py b/source/isaaclab_rl/isaaclab_rl/skrl.py index 6ae5eb28f9b..55802c2a591 100644 --- a/source/isaaclab_rl/isaaclab_rl/skrl.py +++ b/source/isaaclab_rl/isaaclab_rl/skrl.py @@ -70,10 +70,22 @@ def SkrlVecEnvWrapper( # NOTE: import here (not at module level) to avoid loading heavy env classes before Isaac Sim is initialized. 
from isaaclab.envs import DirectMARLEnv, DirectRLEnv, ManagerBasedRLEnv - if not isinstance(env.unwrapped, (ManagerBasedRLEnv, DirectRLEnv, DirectMARLEnv)): + try: + from isaaclab_experimental.envs import DirectRLEnvWarp, ManagerBasedRLEnvWarp + except ImportError: + DirectRLEnvWarp = None + ManagerBasedRLEnvWarp = None + + allowed_types = (ManagerBasedRLEnv, DirectRLEnv, DirectMARLEnv) + if DirectRLEnvWarp is not None: + allowed_types += (DirectRLEnvWarp,) + if ManagerBasedRLEnvWarp is not None: + allowed_types += (ManagerBasedRLEnvWarp,) + + if not isinstance(env.unwrapped, allowed_types): raise ValueError( - "The environment must be inherited from ManagerBasedRLEnv, DirectRLEnv or DirectMARLEnv. Environment type:" - f" {type(env)}" + "The environment must be inherited from ManagerBasedRLEnv, DirectRLEnv, DirectMARLEnv," + f" DirectRLEnvWarp or ManagerBasedRLEnvWarp. Environment type: {type(env)}" ) # import statements according to the ML framework From 5108ca85feb991c35119e8d03b35a4a23dea027e Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Mon, 9 Mar 2026 01:03:49 -0700 Subject: [PATCH 2/7] Add warp Cartpole task configuration Add an experimental manager-based Cartpole environment using the warp manager infrastructure as a reference task for testing and benchmarking. 
--- .../envs/manager_based_rl_env_warp.py | 23 +- .../envs/mdp/terminations.py | 12 +- .../utils/manager_call_switch.py | 136 +++++++----- .../utils/warp/__init__.py | 2 +- .../isaaclab_experimental/utils/warp/utils.py | 29 +-- .../manager_based/__init__.py | 10 + .../manager_based/classic/__init__.py | 6 + .../classic/cartpole/__init__.py | 29 +++ .../classic/cartpole/cartpole_env_cfg.py | 199 ++++++++++++++++++ .../classic/cartpole/mdp/__init__.py | 10 + .../classic/cartpole/mdp/rewards.py | 46 ++++ 11 files changed, 403 insertions(+), 99 deletions(-) create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 8589a47dbda..0f9f237610d 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -153,10 +153,10 @@ def max_episode_length(self) -> int: def load_managers(self): # note: this order is important since observation manager needs to know the command and action managers # and the reward manager needs to know the termination manager - # -- command manager - 
self.command_manager = self._manager_call_switch.resolve_manager_class("CommandManager")( - self.cfg.commands, self - ) + # -- command manager (stable impl — not routed through ManagerCallSwitch) + from isaaclab.managers import CommandManager + + self.command_manager = CommandManager(self.cfg.commands, self) print("[INFO] Command Manager: ", self.command_manager) # call the parent class to load the managers for observations and actions. @@ -240,7 +240,8 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: with Timer( name="action_preprocess", msg="Action preprocessing took:", enable=TIMER_ENABLED_STEP, time_unit="us" ): - assert self._action_in_wp is not None + if self._action_in_wp is None: + raise RuntimeError("Action buffer not initialized. Call reset() before step().") action_device = action.to(self.device) wp.copy(self._action_in_wp, wp.from_torch(action_device, dtype=wp.float32)) @@ -357,11 +358,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self.recorder_manager.record_post_reset(reset_env_ids) # -- update command - self._manager_call_switch.call_stage( - stage="CommandManager_compute", - warp_call={"fn": self.command_manager.compute, "kwargs": {"dt": float(self.step_dt)}}, - timer=TIMER_ENABLED_STEP, - ) + self.command_manager.compute(dt=float(self.step_dt)) # -- step interval events if "interval" in self.event_manager.available_modes: @@ -587,11 +584,7 @@ def _reset_idx( curriculum_info = self.curriculum_manager.reset(env_ids=env_ids) # -- command + event + termination managers - command_info = self._manager_call_switch.call_stage( - stage="CommandManager_reset", - warp_call={"fn": self.command_manager.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, - ) + command_info = self.command_manager.reset(env_ids=env_ids) event_info = self._manager_call_switch.call_stage( stage="EventManager_reset", warp_call={"fn": self.event_manager.reset, "kwargs": {"env_mask": env_mask}}, diff --git 
a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py index f8ca65ba980..a6b0cea4375 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py @@ -72,8 +72,16 @@ def joint_pos_out_of_manual_limit( ) -> None: """Terminate when joint positions are outside configured bounds. Writes into ``out``.""" asset: Articulation = env.scene[asset_cfg.name] - assert asset_cfg.joint_mask is not None - assert asset.data.joint_pos.shape[1] == asset_cfg.joint_mask.shape[0] + if asset_cfg.joint_mask is None: + raise ValueError( + f"joint_pos_out_of_manual_limit requires SceneEntityCfg with resolved joint_mask, " + f"but got None for asset '{asset_cfg.name}'." + ) + if asset.data.joint_pos.shape[1] != asset_cfg.joint_mask.shape[0]: + raise ValueError( + f"joint_mask length ({asset_cfg.joint_mask.shape[0]}) does not match " + f"joint_pos dim ({asset.data.joint_pos.shape[1]}) for asset '{asset_cfg.name}'." 
+ ) wp.launch( kernel=_joint_pos_out_of_manual_limit_kernel, dim=env.num_envs, diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py b/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py index c07cea7982b..2f37b1087ca 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py @@ -13,7 +13,7 @@ from enum import IntEnum from typing import Any -import warp as wp +from isaaclab_experimental.utils.warp_graph_cache import WarpGraphCache from isaaclab.utils.timer import Timer @@ -46,7 +46,6 @@ class ManagerCallSwitch: "ObservationManager", "EventManager", "RecorderManager", - "CommandManager", "TerminationManager", "RewardManager", "CurriculumManager", @@ -65,14 +64,28 @@ class ManagerCallSwitch: MANAGER_CALL_CONFIG='{"RewardManager": 0, "default": 2}' python train.py ... """ - def __init__(self): - self._wp_graphs: dict[str, Any] = {} - self._wp_results: dict[str, Any] = {} - self._cfg = self._load_cfg(os.environ.get(self.ENV_VAR)) + def __init__( + self, + cfg_source: dict | str | None = None, + *, + max_modes: dict[str, int] | None = None, + ): + self._graph_cache = WarpGraphCache() + # Merge caller-supplied max_modes with the class-level MAX_MODE_OVERRIDES. + self._max_modes = dict(self.MAX_MODE_OVERRIDES) + if max_modes is not None: + self._max_modes.update(max_modes) + # Resolve config: prefer explicit cfg_source, fall back to env var. 
+ if cfg_source is None: + cfg_source = os.environ.get(self.ENV_VAR) + self._cfg = self._load_cfg(cfg_source) print("[INFO] ManagerCallSwitch configuration:") print(f" - {self.DEFAULT_KEY}: {self._cfg[self.DEFAULT_KEY]}") for manager_name in self.MANAGER_NAMES: - print(f" - {manager_name}: {int(self.get_mode_for_manager(manager_name))}") + mode = int(self.get_mode_for_manager(manager_name)) + cap = self._max_modes.get(manager_name) + cap_str = f" (cap={cap})" if cap is not None else "" + print(f" - {manager_name}: {mode}{cap_str}") # ------------------------------------------------------------------ # Graph management @@ -80,8 +93,7 @@ def __init__(self): def invalidate_graphs(self) -> None: """Invalidate cached capture graphs and their cached return values.""" - self._wp_graphs.clear() - self._wp_results.clear() + self._graph_cache.invalidate() # ------------------------------------------------------------------ # Stage dispatch @@ -151,23 +163,38 @@ def _manager_name_from_stage(self, stage: str) -> str: return stage.split("_", 1)[0] def get_mode_for_manager(self, manager_name: str) -> ManagerCallMode: - """Return the resolved execution mode for the given manager.""" - default_key = next(iter(self.DEFAULT_CONFIG)) - mode_value = self._cfg.get(manager_name, self._cfg[default_key]) + """Return the resolved execution mode for the given manager. + + Looks up the manager in the config dict, falls back to the default, + then caps by the per-manager maximum mode (:attr:`MAX_MODE_OVERRIDES` plus any runtime overrides). 
+ """ + mode_value = self._cfg.get(manager_name, self._cfg[self.DEFAULT_KEY]) + cap = self._max_modes.get(manager_name) + if cap is not None: + mode_value = min(mode_value, cap) return ManagerCallMode(mode_value) def resolve_manager_class(self, manager_name: str) -> type: """Import and return the manager class for the configured mode.""" - module_name = ( - "isaaclab.managers" - if self.get_mode_for_manager(manager_name) == ManagerCallMode.STABLE - else "isaaclab_experimental.managers" - ) + mode = self.get_mode_for_manager(manager_name) + module_name = "isaaclab.managers" if mode == ManagerCallMode.STABLE else "isaaclab_experimental.managers" module = importlib.import_module(module_name) if not hasattr(module, manager_name): raise AttributeError(f"Manager '{manager_name}' not found in module '{module_name}'.") return getattr(module, manager_name) + def register_manager_capturability(self, manager_name: str, capturable: bool) -> None: + """Register that a manager has non-capturable terms, capping its mode. + + Called by :class:`ManagerBase` during term preparation when a term + is decorated with ``@warp_capturable(False)``. + """ + if not capturable: + self._max_modes[manager_name] = min( + self._max_modes.get(manager_name, ManagerCallMode.WARP_CAPTURED), + ManagerCallMode.WARP_NOT_CAPTURED, + ) + # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @@ -179,50 +206,53 @@ def _run_call(self, call: dict[str, Any]) -> Any: def _wp_capture_or_launch(self, stage: str, call: dict[str, Any]) -> Any: """Capture Warp CUDA graph on first call, then replay. - The return value from the first (capture) run is cached and returned - on every subsequent replay. This ensures captured stages return the - same references (e.g. tensor views) as eager stages. 
+ Delegates to :class:`WarpGraphCache` which caches the return value + and replays immediately after the first capture for validation. """ - graph = self._wp_graphs.get(stage) - if graph is None: - with wp.ScopedCapture() as capture: - result = call["fn"](*call.get("args", ()), **call.get("kwargs", {})) - self._wp_graphs[stage] = capture.graph - self._wp_results[stage] = result - wp.capture_launch(self._wp_graphs[stage]) - return self._wp_results[stage] - - def _load_cfg(self, cfg_source: str | None) -> dict[str, int]: - if cfg_source is not None and not isinstance(cfg_source, str): - raise TypeError(f"cfg_source must be a string or None, got: {type(cfg_source)}") - if cfg_source is None or cfg_source.strip() == "": - cfg = dict(self.DEFAULT_CONFIG) - else: - parsed = json.loads(cfg_source) - if not isinstance(parsed, dict): - raise TypeError("manager_call_config must decode to a dict.") + return self._graph_cache.capture_or_replay( + stage, + call["fn"], + args=call.get("args", ()), + kwargs=call.get("kwargs", {}), + ) - cfg = dict(parsed) + def _load_cfg(self, cfg_source: dict | str | None) -> dict[str, int]: + if cfg_source is None: + cfg = dict(self.DEFAULT_CONFIG) + elif isinstance(cfg_source, dict): + cfg = dict(cfg_source) if self.DEFAULT_KEY not in cfg: cfg[self.DEFAULT_KEY] = self.DEFAULT_CONFIG[self.DEFAULT_KEY] - - # validation - for manager_name, mode_value in cfg.items(): - if not isinstance(mode_value, int): - raise TypeError( - f"manager_call_config value for '{manager_name}' must be int (0/1/2), got: {type(mode_value)}" - ) - try: - ManagerCallMode(mode_value) - except ValueError as exc: - raise ValueError( - f"Invalid manager_call_config value for '{manager_name}': {mode_value}. Expected 0/1/2." 
- ) from exc + elif isinstance(cfg_source, str): + if cfg_source.strip() == "": + cfg = dict(self.DEFAULT_CONFIG) + else: + parsed = json.loads(cfg_source) + if not isinstance(parsed, dict): + raise TypeError("manager_call_config must decode to a dict.") + cfg = dict(parsed) + if self.DEFAULT_KEY not in cfg: + cfg[self.DEFAULT_KEY] = self.DEFAULT_CONFIG[self.DEFAULT_KEY] + else: + raise TypeError(f"cfg_source must be a dict, string, or None, got: {type(cfg_source)}") + + # Validation + for manager_name, mode_value in cfg.items(): + if not isinstance(mode_value, int): + raise TypeError( + f"manager_call_config value for '{manager_name}' must be int (0/1/2), got: {type(mode_value)}" + ) + try: + ManagerCallMode(mode_value) + except ValueError as exc: + raise ValueError( + f"Invalid manager_call_config value for '{manager_name}': {mode_value}. Expected 0/1/2." + ) from exc # Apply MAX_MODE_OVERRIDES: bake caps into the resolved config so # get_mode_for_manager never needs per-call branching. 
default_mode = cfg[self.DEFAULT_KEY] - for name, max_mode in self.MAX_MODE_OVERRIDES.items(): + for name, max_mode in self._max_modes.items(): resolved = cfg.get(name, default_mode) if resolved > max_mode: cfg[name] = max_mode diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py index 30f0fb17d6d..a7e71ac4688 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py @@ -6,4 +6,4 @@ """Warp utility functions and shared kernels for isaaclab_experimental.""" from .kernels import compute_reset_scale, count_masked -from .utils import WarpCapturable, resolve_1d_mask, resolve_asset_cfg, wrap_to_pi +from .utils import WarpCapturable, resolve_1d_mask, wrap_to_pi diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py index 0d3a85f8020..c7a2c63d959 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py @@ -6,16 +6,11 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING import torch -from isaaclab_experimental.managers.scene_entity_cfg import SceneEntityCfg import warp as wp -if TYPE_CHECKING: - from isaaclab.envs import ManagerBasedEnv - @wp.kernel def _set_mask_from_ids( @@ -104,7 +99,7 @@ def resolve_1d_mask( @wp.func def wrap_to_pi(angle: float) -> float: - """Wrap input angle (in radians) to the range [-pi, pi).""" + """Wrap input angle (in radians) to the range [-pi, pi].""" two_pi = 2.0 * wp.pi wrapped_angle = angle + wp.pi # NOTE: Use floor-based remainder semantics to match torch's `%` for negative inputs. 
@@ -168,25 +163,3 @@ def is_capturable(func) -> bool: if val is not None: return val return True - - -def resolve_asset_cfg(cfg: dict, env: ManagerBasedEnv) -> SceneEntityCfg: - asset_cfg = None - - for value in cfg.values(): - # If it exists, the SceneEntityCfg should have been resolved by the base manager. - if isinstance(value, SceneEntityCfg): - asset_cfg = value - # Check if the joint ids are not set, and if so, resolve them. - if asset_cfg.joint_ids is None or asset_cfg.joint_ids == slice(None): - asset_cfg.resolve_for_warp(env.scene) - if asset_cfg.body_ids is None or asset_cfg.body_ids == slice(None): - asset_cfg.resolve_for_warp(env.scene) - break - - # If it doesn't exist, use the default robot entity. - if asset_cfg is None: - asset_cfg = SceneEntityCfg("robot") - asset_cfg.resolve_for_warp(env.scene) - - return asset_cfg diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py new file mode 100644 index 00000000000..7f23883e633 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Experimental registrations for manager-based tasks. + +We intentionally only register new Gym IDs pointing at experimental entry points. +Agent configs stay in `isaaclab_tasks`; env configs/MDP terms live here only where a task needs experimental variants. 
+""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py new file mode 100644 index 00000000000..4781f141af4 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Classic experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py new file mode 100644 index 00000000000..17a4c5c03cd --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Cartpole balancing environment (experimental manager-based entry point). +""" + +import gymnasium as gym + +gym.register( + id="Isaac-Cartpole-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + # Use experimental Cartpole cfg (allows isolated modifications). + "env_cfg_entry_point": ( + "isaaclab_tasks_experimental.manager_based.classic.cartpole.cartpole_env_cfg:CartpoleEnvCfg" + ), + # Point agent configs to the existing task package. 
+ "rl_games_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": ( + "isaaclab_tasks.manager_based.classic.cartpole.agents.rsl_rl_ppo_cfg:CartpolePPORunnerCfg" + ), + "skrl_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py new file mode 100644 index 00000000000..898ac8be4fd --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/cartpole_env_cfg.py @@ -0,0 +1,199 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +import math + +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +import isaaclab.sim as sim_utils +from isaaclab.assets import ArticulationCfg, AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.classic.cartpole.mdp as mdp + +## +# Pre-defined configs +## +from isaaclab_assets.robots.cartpole import CARTPOLE_CFG # isort:skip + + +## +# Scene definition +## + + +@configclass +class CartpoleSceneCfg(InteractiveSceneCfg): + """Configuration for a cart-pole scene.""" + + # ground plane + # ground = AssetBaseCfg( + # prim_path="/World/ground", + # spawn=sim_utils.GroundPlaneCfg(size=(100.0, 100.0)), + # ) + + # cartpole + robot: ArticulationCfg = CARTPOLE_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + + # lights + dome_light = AssetBaseCfg( + prim_path="/World/DomeLight", + spawn=sim_utils.DomeLightCfg(color=(0.9, 0.9, 0.9), intensity=500.0), + ) + + +## +# MDP settings +## + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + joint_effort = mdp.JointEffortActionCfg(asset_name="robot", joint_names=["slider_to_cart"], scale=100.0) + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for policy group.""" + + # observation terms (order preserved) + joint_pos_rel = ObsTerm(func=mdp.joint_pos_rel) + 
joint_vel_rel = ObsTerm(func=mdp.joint_vel_rel) + + def __post_init__(self) -> None: + self.enable_corruption = False + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + # reset + reset_cart_position = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"]), + "position_range": (-1.0, 1.0), + "velocity_range": (-0.5, 0.5), + }, + ) + + reset_pole_position = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "asset_cfg": SceneEntityCfg("robot", joint_names=["cart_to_pole"]), + "position_range": (-0.25 * math.pi, 0.25 * math.pi), + "velocity_range": (-0.25 * math.pi, 0.25 * math.pi), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # (1) Constant running reward + alive = RewTerm(func=mdp.is_alive, weight=1.0) + # (2) Failure penalty + terminating = RewTerm(func=mdp.is_terminated, weight=-2.0) + # (3) Primary task: keep pole upright + pole_pos = RewTerm( + func=mdp.joint_pos_target_l2, + weight=-1.0, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["cart_to_pole"]), "target": 0.0}, + ) + # (4) Shaping tasks: lower cart velocity + cart_vel = RewTerm( + func=mdp.joint_vel_l1, + weight=-0.01, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"])}, + ) + # (5) Shaping tasks: lower pole angular velocity + pole_vel = RewTerm( + func=mdp.joint_vel_l1, + weight=-0.005, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["cart_to_pole"])}, + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + # (1) Time out + time_out = DoneTerm(func=mdp.time_out, time_out=True) + # (2) Cart out of bounds + cart_out_of_bounds = DoneTerm( + func=mdp.joint_pos_out_of_manual_limit, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["slider_to_cart"]), 
"bounds": (-3.0, 3.0)}, + ) + + +## +# Environment configuration +## + + +@configclass +class CartpoleEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the cartpole environment.""" + + # Scene settings + scene: CartpoleSceneCfg = CartpoleSceneCfg(num_envs=4096, env_spacing=4.0, clone_in_fabric=True) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + events: EventCfg = EventCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + # Simulation settings + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=5, + nconmax=3, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + use_cuda_graph=True, + ) + ) + + # Post initialization + def __post_init__(self) -> None: + """Post initialization.""" + # general settings + self.decimation = 2 + self.episode_length_s = 5 + # viewer settings + self.viewer.eye = (8.0, 0.0, 5.0) + # simulation settings + self.sim.dt = 1 / 120 + self.sim.render_interval = self.decimation diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py new file mode 100644 index 00000000000..73b1cf4fb2c --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""This sub-module contains the functions that are specific to the cartpole environments.""" + +from isaaclab_experimental.envs.mdp import * # noqa: F401, F403 + +from .rewards import * # noqa: F401, F403 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py new file mode 100644 index 00000000000..fbb426751fa --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/mdp/rewards.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import warp as wp +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.utils.warp.utils import wrap_to_pi + +from isaaclab.assets import Articulation + +if TYPE_CHECKING: + from isaaclab.envs import ManagerBasedRLEnv + + +@wp.kernel +def _joint_pos_target_l2_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + joint_mask: wp.array(dtype=wp.bool), + out: wp.array(dtype=wp.float32), + target: float, +): + i = wp.tid() + s = float(0.0) + for j in range(joint_pos.shape[1]): + if joint_mask[j]: + a = wrap_to_pi(joint_pos[i, j]) + d = a - target + s += d * d + out[i] = s + + +def joint_pos_target_l2(env: ManagerBasedRLEnv, out, target: float, asset_cfg: SceneEntityCfg) -> None: + """Penalize joint position deviation from a target value. 
Writes into ``out``.""" + asset: Articulation = env.scene[asset_cfg.name] + assert asset.data.joint_pos.shape[1] == asset_cfg.joint_mask.shape[0] + wp.launch( + kernel=_joint_pos_target_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset_cfg.joint_mask, out, target], + device=env.device, + ) From b281975d2549885f3cb5134ca99f241a49b2f558 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Thu, 12 Mar 2026 00:29:17 -0700 Subject: [PATCH 3/7] Add warp MDP terms and manager infrastructure updates Add warp-first MDP terms (observations, rewards, events, terminations, actions) for manager-based envs. Update manager infrastructure with ManagerCallSwitch max_mode cap, Scene capture config, body_ids_wp resolution, and capture safety annotations. Add newton state kernels for body-frame computations used by MDP terms. --- .../reinforcement_learning/rsl_rl/train.py | 18 + .../envs/manager_based_env_warp.py | 16 +- .../envs/manager_based_rl_env_warp.py | 67 ++- .../envs/mdp/actions/__init__.py | 6 +- .../envs/mdp/actions/actions_cfg.py | 41 +- .../envs/mdp/actions/joint_actions.py | 35 +- .../isaaclab_experimental/envs/mdp/events.py | 370 +++++++++++++++- .../envs/mdp/observations.py | 293 ++++++++++++- .../isaaclab_experimental/envs/mdp/rewards.py | 399 +++++++++++++++++- .../envs/mdp/terminations.py | 97 ++++- .../envs/utils/io_descriptors.py | 34 +- .../managers/__init__.py | 4 +- .../managers/action_manager.py | 18 +- .../managers/event_manager.py | 16 +- .../managers/manager_base.py | 8 + .../managers/observation_manager.py | 95 ++++- .../managers/scene_entity_cfg.py | 38 +- .../utils/manager_call_switch.py | 60 ++- .../utils/warp/__init__.py | 9 +- .../isaaclab_experimental/utils/warp/utils.py | 44 +- .../isaaclab_newton/kernels/state_kernels.py | 30 ++ 21 files changed, 1514 insertions(+), 184 deletions(-) create mode 100644 source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py diff --git a/scripts/reinforcement_learning/rsl_rl/train.py 
b/scripts/reinforcement_learning/rsl_rl/train.py index 9c4c95da230..087cdbde0d7 100644 --- a/scripts/reinforcement_learning/rsl_rl/train.py +++ b/scripts/reinforcement_learning/rsl_rl/train.py @@ -65,6 +65,24 @@ parser.add_argument( "--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None." ) +parser.add_argument( + "--step-timer", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable granular timer sections in environment step().", +) +parser.add_argument( + "--reset-timer", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable granular timer sections in environment reset().", +) +parser.add_argument( + "--manager_call_config", + type=str, + default=None, + help='Manager mode JSON only: \'{"RewardManager": 0, "ActionManager": 2, "default": 2}\'.', +) cli_args.add_rsl_rl_args(parser) add_launcher_args(parser) args_cli, hydra_args = parser.parse_known_args() diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py index 7f747205e96..8c558b92594 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_env_warp.py @@ -76,7 +76,10 @@ def __init__(self, cfg: ManagerBasedEnvCfg): self.cfg = cfg # initialize internal variables self._is_closed = False - self._manager_call_switch = ManagerCallSwitch() + # temporary debug runtime config for manager source/call switching. 
+ cfg_source: dict | str | None = getattr(self.cfg, "manager_call_config", None) + max_modes: dict[str, int] | None = getattr(self.cfg, "manager_call_max_mode", None) + self._manager_call_switch = ManagerCallSwitch(cfg_source, max_modes=max_modes) self._apply_manager_term_cfg_profile() # set the seed for the environment @@ -265,6 +268,17 @@ def device(self): """The device on which the environment is running.""" return self.sim.device + @property + def env_origins_wp(self) -> wp.array: + """Scene env origins as a warp ``vec3f`` array. Cached on first access.""" + if not hasattr(self, "_env_origins_wp"): + origins = self.scene.env_origins + if isinstance(origins, wp.array): + self._env_origins_wp = origins + else: + self._env_origins_wp = wp.from_torch(origins, dtype=wp.vec3f) + return self._env_origins_wp + def resolve_env_mask( self, *, diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py index 0f9f237610d..69569d375dd 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/manager_based_rl_env_warp.py @@ -25,6 +25,7 @@ from isaaclab.envs.common import VecEnvStepReturn from isaaclab.envs.manager_based_rl_env_cfg import ManagerBasedRLEnvCfg +from isaaclab.managers import CommandManager from isaaclab.ui.widgets import ManagerLiveVisualizer from isaaclab.utils.timer import Timer @@ -33,14 +34,14 @@ from .manager_based_env_warp import ManagerBasedEnvWarp -DEBUG_TIMER_STEP = os.environ.get("DEBUG_TIMER_STEP", "0") == "1" -"""Enable outer step() timer. Set DEBUG_TIMER_STEP=1 env var to enable.""" - DEBUG_TIMERS = os.environ.get("DEBUG_TIMERS", "0") == "1" -"""Enable all fine-grained inner timers. Set DEBUG_TIMERS=1 env var to enable.""" +"""Enable outer step() timer. 
Set DEBUG_TIMERS=1 env var to enable.""" + +DEBUG_TIMER_STEP = os.environ.get("DEBUG_TIMER_STEP", "0") == "1" +"""Enable step sub-phase timers. Set DEBUG_TIMER_STEP=1 env var to enable.""" -TIMER_ENABLED_STEP = DEBUG_TIMER_STEP or DEBUG_TIMERS -TIMER_ENABLED_RESET_IDX = DEBUG_TIMERS +DEBUG_TIMER_RESET = os.environ.get("DEBUG_TIMER_RESET", "0") == "1" +"""Enable reset sub-phase timers. Set DEBUG_TIMER_RESET=1 env var to enable.""" class ManagerBasedRLEnvWarp(ManagerBasedEnvWarp, gym.Env): @@ -154,8 +155,6 @@ def load_managers(self): # note: this order is important since observation manager needs to know the command and action managers # and the reward manager needs to know the termination manager # -- command manager (stable impl — not routed through ManagerCallSwitch) - from isaaclab.managers import CommandManager - self.command_manager = CommandManager(self.cfg.commands, self) print("[INFO] Command Manager: ", self.command_manager) @@ -213,7 +212,7 @@ def step_warp_termination_compute(self) -> None: self.reset_terminated = self.termination_manager.terminated self.reset_time_outs = self.termination_manager.time_outs - @Timer(name="env_step", msg="Step took:", enable=TIMER_ENABLED_STEP, time_unit="us") + @Timer(name="env_step", msg="Step took:", enable=DEBUG_TIMER_STEP, time_unit="us") def step(self, action: torch.Tensor) -> VecEnvStepReturn: """Execute one time-step of the environment's dynamics and reset terminated environments. @@ -237,9 +236,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: # NOTE: keep a persistent action input buffer for graph pointer stability. # IMPORTANT: Do NOT re-wrap/replace the `wp.array` used by captured graphs each step. # Instead, copy the latest actions into the persistent buffer. 
- with Timer( - name="action_preprocess", msg="Action preprocessing took:", enable=TIMER_ENABLED_STEP, time_unit="us" - ): + with Timer(name="action_preprocess", msg="Action preprocessing took:", enable=DEBUG_TIMER_STEP, time_unit="us"): if self._action_in_wp is None: raise RuntimeError("Action buffer not initialized. Call reset() before step().") action_device = action.to(self.device) @@ -248,7 +245,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self._manager_call_switch.call_stage( stage="ActionManager_process_action", warp_call={"fn": self.action_manager.process_action, "kwargs": {"action": self._action_in_wp}}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) self.recorder_manager.record_pre_step() @@ -266,16 +263,16 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self._manager_call_switch.call_stage( stage="ActionManager_apply_action", warp_call={"fn": self.action_manager.apply_action}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) self._manager_call_switch.call_stage( stage="Scene_write_data_to_sim", warp_call={"fn": self.scene.write_data_to_sim}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) # simulate - with Timer(name="simulate", msg="Newton simulation step took:", enable=TIMER_ENABLED_STEP, time_unit="us"): + with Timer(name="simulate", msg="Newton simulation step took:", enable=DEBUG_TIMER_STEP, time_unit="us"): self.sim.step(render=False) self.recorder_manager.record_post_physics_decimation_step() # render between steps only if the GUI or an RTX sensor needs it @@ -287,7 +284,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: with Timer( name="scene.update", msg="Scene.update took:", - enable=TIMER_ENABLED_STEP, + enable=DEBUG_TIMER_STEP, time_unit="us", ): self.scene.update(dt=self.physics_dt) @@ -301,12 +298,12 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self._manager_call_switch.call_stage( stage="TerminationManager_compute", warp_call={"fn": 
self.step_warp_termination_compute}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) self.reward_buf = self._manager_call_switch.call_stage( stage="RewardManager_compute", warp_call={"fn": self.reward_manager.compute, "kwargs": {"dt": float(self.step_dt)}}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) if len(self.recorder_manager.active_terms) > 0: @@ -314,7 +311,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self._manager_call_switch.call_stage( stage="ObservationManager_compute_no_history", warp_call={"fn": self.observation_manager.compute, "kwargs": {"return_cloned_output": False}}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) self.recorder_manager.record_post_step() @@ -325,7 +322,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: with Timer( name="reset_selection", msg="Reset selection took:", - enable=TIMER_ENABLED_STEP, + enable=DEBUG_TIMER_STEP, time_unit="us", ): # Keep the reset-mask handoff fully in Warp when experimental termination buffers exist. 
@@ -344,7 +341,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: with Timer( name="reset_idx", msg="Reset idx took:", - enable=TIMER_ENABLED_STEP, + enable=DEBUG_TIMER_STEP, time_unit="us", ): self._reset_idx(env_ids=reset_env_ids, env_mask=self.reset_mask_wp) @@ -365,7 +362,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: self._manager_call_switch.call_stage( stage="EventManager_apply_interval", warp_call={"fn": self.event_manager.apply, "kwargs": {"mode": "interval", "dt": float(self.step_dt)}}, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) # -- compute observations @@ -377,7 +374,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: "kwargs": {"update_history": True, "return_cloned_output": False}, "output": lambda r: clone_obs_buffer(r), }, - timer=TIMER_ENABLED_STEP, + timer=DEBUG_TIMER_STEP, ) # return observations, rewards, resets and extras return self.obs_buf, self.reward_buf, self.reset_terminated, self.reset_time_outs, self.extras @@ -527,7 +524,7 @@ def _reset_idx( with Timer( name="curriculum_manager.compute_reset", msg="CurriculumManager.compute (reset) took:", - enable=TIMER_ENABLED_RESET_IDX, + enable=DEBUG_TIMER_RESET, time_unit="us", ): self.curriculum_manager.compute(env_ids=env_ids) @@ -535,8 +532,8 @@ def _reset_idx( # reset the internal buffers of the scene elements self._manager_call_switch.call_stage( stage="Scene_reset", - warp_call={"fn": self.scene.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, + warp_call={"fn": self.scene.reset, "kwargs": {"env_ids": env_ids, "env_mask": env_mask}}, + timer=DEBUG_TIMER_RESET, ) if "reset" in self.event_manager.available_modes: @@ -551,7 +548,7 @@ def _reset_idx( "global_env_step_count": self._global_env_step_count_wp, }, }, - timer=TIMER_ENABLED_RESET_IDX, + timer=DEBUG_TIMER_RESET, ) # iterate over all managers and reset them @@ -561,24 +558,24 @@ def _reset_idx( obs_info = self._manager_call_switch.call_stage( 
stage="ObservationManager_reset", warp_call={"fn": self.observation_manager.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, + timer=DEBUG_TIMER_RESET, ) action_info = self._manager_call_switch.call_stage( stage="ActionManager_reset", warp_call={"fn": self.action_manager.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, + timer=DEBUG_TIMER_RESET, ) reward_info = self._manager_call_switch.call_stage( stage="RewardManager_reset", warp_call={"fn": self.reward_manager.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, + timer=DEBUG_TIMER_RESET, ) # -- curriculum manager with Timer( name="curriculum_manager.reset", msg="CurriculumManager.reset took:", - enable=TIMER_ENABLED_RESET_IDX, + enable=DEBUG_TIMER_RESET, time_unit="us", ): curriculum_info = self.curriculum_manager.reset(env_ids=env_ids) @@ -588,12 +585,14 @@ def _reset_idx( event_info = self._manager_call_switch.call_stage( stage="EventManager_reset", warp_call={"fn": self.event_manager.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, + stable_call={"fn": self.event_manager.reset, "kwargs": {"env_ids": env_ids}}, + timer=DEBUG_TIMER_RESET, ) termination_info = self._manager_call_switch.call_stage( stage="TerminationManager_reset", warp_call={"fn": self.termination_manager.reset, "kwargs": {"env_mask": env_mask}}, - timer=TIMER_ENABLED_RESET_IDX, + stable_call={"fn": self.termination_manager.reset, "kwargs": {"env_ids": env_ids}}, + timer=DEBUG_TIMER_RESET, ) # -- recorder manager diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py index 283805a279f..d295384149d 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/__init__.py @@ -3,10 +3,10 @@ # # SPDX-License-Identifier: 
BSD-3-Clause -"""Experimental action terms (minimal). +"""Experimental action terms (Warp-first). -Only the action configs/terms currently required by the experimental manager-based Cartpole task -are provided here. +Provides Warp-first action term implementations overriding the stable +:mod:`isaaclab.envs.mdp.actions` module. """ from .actions_cfg import * # noqa: F401, F403 diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py index d8826602dbe..39d5b29c6fd 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/actions_cfg.py @@ -3,12 +3,6 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Action term configuration (experimental, minimal). - -This module mirrors the stable :mod:`isaaclab.envs.mdp.actions.actions_cfg` but only keeps what -the experimental Cartpole task needs. -""" - from dataclasses import MISSING from isaaclab.utils import configclass @@ -17,26 +11,51 @@ from . import joint_actions +## +# Joint actions. +## + @configclass class JointActionCfg(ActionTermCfg): - """Configuration for the base joint action term.""" + """Configuration for the base joint action term. + + See :class:`JointAction` for more details. + """ joint_names: list[str] = MISSING """List of joint names or regex expressions that the action will be mapped to.""" - scale: float | dict[str, float] = 1.0 """Scale factor for the action (float or dict of regex expressions). Defaults to 1.0.""" - offset: float | dict[str, float] = 0.0 """Offset factor for the action (float or dict of regex expressions). Defaults to 0.0.""" - preserve_order: bool = False """Whether to preserve the order of the joint names in the action output. 
Defaults to False.""" +@configclass +class JointPositionActionCfg(JointActionCfg): + """Configuration for the joint position action term. + + See :class:`JointPositionAction` for more details. + """ + + class_type: type[ActionTerm] = joint_actions.JointPositionAction + + use_default_offset: bool = True + """Whether to use default joint positions configured in the articulation asset as offset. + Defaults to True. + + If True, this flag results in overwriting the values of :attr:`offset` to the default joint positions + from the articulation asset. + """ + + @configclass class JointEffortActionCfg(JointActionCfg): - """Configuration for the joint effort action term.""" + """Configuration for the joint effort action term. + + See :class:`JointEffortAction` for more details. + """ class_type: type[ActionTerm] = joint_actions.JointEffortAction diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py index 183cb2cfe49..441bde86b4d 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/actions/joint_actions.py @@ -15,7 +15,7 @@ from isaaclab.assets.articulation import Articulation from isaaclab_experimental.managers.action_manager import ActionTerm -from isaaclab_experimental.utils.warp import resolve_1d_mask +from isaaclab_experimental.utils.warp import resolve_1d_mask, zero_masked_2d if TYPE_CHECKING: from isaaclab.envs import ManagerBasedEnv @@ -56,13 +56,6 @@ def _process_joint_actions_kernel( processed_out[env_id, j] = x -@wp.kernel -def _zero_masked_2d(mask: wp.array(dtype=wp.bool), values: wp.array(dtype=wp.float32, ndim=2)): - env_id, j = wp.tid() - if mask[env_id]: - values[env_id, j] = 0.0 - - class JointAction(ActionTerm): r"""Base class for joint actions. 
@@ -259,13 +252,37 @@ def reset(self, env_mask: wp.array | None = None) -> None: self._raw_actions.fill_(0.0) return wp.launch( - kernel=_zero_masked_2d, + kernel=zero_masked_2d, dim=(self.num_envs, self.action_dim), inputs=[env_mask, self._raw_actions], device=self.device, ) +class JointPositionAction(JointAction): + """Joint action term that applies the processed actions to the articulation's joints as position commands. + + Warp-first override of :class:`isaaclab.envs.mdp.actions.JointPositionAction`. + """ + + cfg: actions_cfg.JointPositionActionCfg + """The configuration of the action term.""" + + def __init__(self, cfg: actions_cfg.JointPositionActionCfg, env: ManagerBasedEnv): + super().__init__(cfg, env) + # use default joint positions as offset + if cfg.use_default_offset: + defaults_np = self._asset.data.default_joint_pos.numpy() + if isinstance(self._joint_ids, slice): + offset_vals = defaults_np[0, :].tolist() + else: + offset_vals = [float(defaults_np[0, jid]) for jid in self._joint_ids] + self._offset = wp.array(offset_vals, dtype=wp.float32, device=self.device) + + def apply_actions(self): + self._asset.set_joint_position_target_index(target=self.processed_actions, joint_ids=self._joint_ids_wp) + + class JointEffortAction(JointAction): """Joint action term that applies the processed actions to the articulation's joints as effort commands.""" diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py index 053da8db722..69145c17394 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/events.py @@ -12,7 +12,7 @@ - Stable event terms (e.g. `isaaclab.envs.mdp.events.reset_joints_by_offset`) often build torch tensors and then call into Newton articulation writers with partial indices (env_ids/joint_ids). 
- On the Newton backend, passing torch tensors triggers expensive torch->warp conversions that currently allocate - full `(num_envs, num_joints)` buffers (see `isaaclab.utils.warp.utils.make_complete_data_from_torch_dual_index`). + full `(num_envs, num_joints)` buffers. These Warp-first implementations avoid that by writing directly into the sim-bound Warp state buffers (`asset.data.joint_pos` / `asset.data.joint_vel`) for the selected envs/joints. @@ -29,6 +29,366 @@ from isaaclab.assets import Articulation from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.utils.warp import WarpCapturable + +# --------------------------------------------------------------------------- +# Randomize rigid body center of mass +# --------------------------------------------------------------------------- + + +@wp.kernel +def _randomize_com_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + body_com_pos_b: wp.array(dtype=wp.vec3f, ndim=2), + body_ids: wp.array(dtype=wp.int32), + com_lo: wp.vec3f, + com_hi: wp.vec3f, +): + """Add random offset to center of mass positions for selected bodies.""" + env_id = wp.tid() + if not env_mask[env_id]: + return + + state = rng_state[env_id] + for k in range(body_ids.shape[0]): + b = body_ids[k] + v = body_com_pos_b[env_id, b] + dx = wp.randf(state, com_lo[0], com_hi[0]) + dy = wp.randf(state, com_lo[1], com_hi[1]) + dz = wp.randf(state, com_lo[2], com_hi[2]) + body_com_pos_b[env_id, b] = wp.vec3f(v[0] + dx, v[1] + dy, v[2] + dz) + rng_state[env_id] = state + + +@WarpCapturable(False, reason="set_coms_mask calls SimulationManager.add_model_change") +def randomize_rigid_body_com( + env, + env_mask: wp.array, + com_range: dict[str, tuple[float, float]], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Randomize the center of mass (CoM) of rigid bodies by adding random offsets. + + Warp-first override of :func:`isaaclab.envs.mdp.events.randomize_rigid_body_com`. 
+ Writes directly into the sim-bound ``body_com_pos_b`` buffer, then notifies the solver + via :meth:`set_coms_mask` so it recomputes inertial properties. + """ + asset: Articulation = env.scene[asset_cfg.name] + + fn = randomize_rigid_body_com + if not hasattr(fn, "_com_lo") or fn._asset_name != asset_cfg.name: + fn._asset_name = asset_cfg.name + r = [com_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z"]] + fn._com_lo = wp.vec3f(r[0][0], r[1][0], r[2][0]) + fn._com_hi = wp.vec3f(r[0][1], r[1][1], r[2][1]) + + wp.launch( + kernel=_randomize_com_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + asset.data.body_com_pos_b, + asset_cfg.body_ids_wp, + fn._com_lo, + fn._com_hi, + ], + device=env.device, + ) + + # Notify the solver that inertial properties changed (COM position affects inertia). + asset.set_coms_mask(coms=asset.data.body_com_pos_b, env_mask=env_mask) + + +# --------------------------------------------------------------------------- +# Apply external force and torque +# --------------------------------------------------------------------------- + + +@wp.kernel +def _apply_external_force_torque_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + force_out: wp.array(dtype=wp.vec3f, ndim=2), + torque_out: wp.array(dtype=wp.vec3f, ndim=2), + force_lo: float, + force_hi: float, + torque_lo: float, + torque_hi: float, +): + env_id = wp.tid() + if not env_mask[env_id]: + # zero out unmasked envs so they don't accumulate stale forces + for b in range(force_out.shape[1]): + force_out[env_id, b] = wp.vec3f(0.0, 0.0, 0.0) + torque_out[env_id, b] = wp.vec3f(0.0, 0.0, 0.0) + return + + state = rng_state[env_id] + for b in range(force_out.shape[1]): + force_out[env_id, b] = wp.vec3f( + wp.randf(state, force_lo, force_hi), + wp.randf(state, force_lo, force_hi), + wp.randf(state, force_lo, force_hi), + ) + torque_out[env_id, b] = wp.vec3f( + wp.randf(state, torque_lo, torque_hi), + wp.randf(state, torque_lo, 
torque_hi), + wp.randf(state, torque_lo, torque_hi), + ) + rng_state[env_id] = state + + +def apply_external_force_torque( + env, + env_mask: wp.array, + force_range: tuple[float, float], + torque_range: tuple[float, float], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Randomize external forces and torques applied to the asset's bodies. + + Warp-first override of :func:`isaaclab.envs.mdp.events.apply_external_force_torque`. + """ + asset: Articulation = env.scene[asset_cfg.name] + + # First-call: allocate scratch and pre-convert constant arguments. + if not hasattr(apply_external_force_torque, "_scratch_forces"): + apply_external_force_torque._scratch_forces = wp.zeros( + (env.num_envs, asset.num_bodies), dtype=wp.vec3f, device=env.device + ) + apply_external_force_torque._scratch_torques = wp.zeros( + (env.num_envs, asset.num_bodies), dtype=wp.vec3f, device=env.device + ) + + wp.launch( + kernel=_apply_external_force_torque_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + apply_external_force_torque._scratch_forces, + apply_external_force_torque._scratch_torques, + force_range[0], + force_range[1], + torque_range[0], + torque_range[1], + ], + device=env.device, + ) + + asset.permanent_wrench_composer.set_forces_and_torques_mask( + forces=apply_external_force_torque._scratch_forces, + torques=apply_external_force_torque._scratch_torques, + env_mask=env_mask, + ) + + +# --------------------------------------------------------------------------- +# Push by velocity +# --------------------------------------------------------------------------- + + +@wp.kernel +def _push_by_setting_velocity_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + vel_out: wp.array(dtype=wp.spatial_vectorf), + lin_lo: wp.vec3f, + lin_hi: wp.vec3f, + ang_lo: wp.vec3f, + ang_hi: wp.vec3f, +): + env_id = wp.tid() + if not env_mask[env_id]: + return + + vel = 
root_vel_w[env_id] + state = rng_state[env_id] + + vel_out[env_id] = wp.spatial_vectorf( + vel[0] + wp.randf(state, lin_lo[0], lin_hi[0]), + vel[1] + wp.randf(state, lin_lo[1], lin_hi[1]), + vel[2] + wp.randf(state, lin_lo[2], lin_hi[2]), + vel[3] + wp.randf(state, ang_lo[0], ang_hi[0]), + vel[4] + wp.randf(state, ang_lo[1], ang_hi[1]), + vel[5] + wp.randf(state, ang_lo[2], ang_hi[2]), + ) + + rng_state[env_id] = state + + +def push_by_setting_velocity( + env, + env_mask: wp.array, + velocity_range: dict[str, tuple[float, float]], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Push the asset by setting the root velocity to a random value within the given ranges. + + Warp-first override of :func:`isaaclab.envs.mdp.events.push_by_setting_velocity`. + """ + asset: Articulation = env.scene[asset_cfg.name] + + # First-call: allocate scratch and pre-parse constant range arguments. + if not hasattr(push_by_setting_velocity, "_scratch_vel"): + push_by_setting_velocity._scratch_vel = wp.zeros((env.num_envs,), dtype=wp.spatial_vectorf, device=env.device) + r = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]] + push_by_setting_velocity._lin_lo = wp.vec3f(r[0][0], r[1][0], r[2][0]) + push_by_setting_velocity._lin_hi = wp.vec3f(r[0][1], r[1][1], r[2][1]) + push_by_setting_velocity._ang_lo = wp.vec3f(r[3][0], r[4][0], r[5][0]) + push_by_setting_velocity._ang_hi = wp.vec3f(r[3][1], r[4][1], r[5][1]) + + wp.launch( + kernel=_push_by_setting_velocity_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + asset.data.root_vel_w, + push_by_setting_velocity._scratch_vel, + push_by_setting_velocity._lin_lo, + push_by_setting_velocity._lin_hi, + push_by_setting_velocity._ang_lo, + push_by_setting_velocity._ang_hi, + ], + device=env.device, + ) + + asset.write_root_velocity_to_sim_mask(root_velocity=push_by_setting_velocity._scratch_vel, env_mask=env_mask) + + +# 
--------------------------------------------------------------------------- +# Reset root state uniform +# --------------------------------------------------------------------------- + + +@wp.kernel +def _reset_root_state_uniform_kernel( + env_mask: wp.array(dtype=wp.bool), + rng_state: wp.array(dtype=wp.uint32), + default_root_pose: wp.array(dtype=wp.transformf), + default_root_vel: wp.array(dtype=wp.spatial_vectorf), + env_origins: wp.array(dtype=wp.vec3f), + pose_out: wp.array(dtype=wp.transformf), + vel_out: wp.array(dtype=wp.spatial_vectorf), + pos_lo: wp.vec3f, + pos_hi: wp.vec3f, + rot_lo: wp.vec3f, + rot_hi: wp.vec3f, + vel_lin_lo: wp.vec3f, + vel_lin_hi: wp.vec3f, + vel_ang_lo: wp.vec3f, + vel_ang_hi: wp.vec3f, +): + env_id = wp.tid() + if not env_mask[env_id]: + return + + state = rng_state[env_id] + + # --- Pose --- + default_pose = default_root_pose[env_id] + default_pos = wp.transform_get_translation(default_pose) + default_q = wp.transform_get_rotation(default_pose) + origin = env_origins[env_id] + + # position = default + env_origin + random offset + pos = wp.vec3f( + default_pos[0] + origin[0] + wp.randf(state, pos_lo[0], pos_hi[0]), + default_pos[1] + origin[1] + wp.randf(state, pos_lo[1], pos_hi[1]), + default_pos[2] + origin[2] + wp.randf(state, pos_lo[2], pos_hi[2]), + ) + + # orientation = default * delta(euler_xyz) + roll = wp.randf(state, rot_lo[0], rot_hi[0]) + pitch = wp.randf(state, rot_lo[1], rot_hi[1]) + yaw = wp.randf(state, rot_lo[2], rot_hi[2]) + qx = wp.quat_from_axis_angle(wp.vec3f(1.0, 0.0, 0.0), roll) + qy = wp.quat_from_axis_angle(wp.vec3f(0.0, 1.0, 0.0), pitch) + qz = wp.quat_from_axis_angle(wp.vec3f(0.0, 0.0, 1.0), yaw) + # ZYX extrinsic = XYZ intrinsic: delta = qz * qy * qx + delta_q = wp.mul(wp.mul(qz, qy), qx) + final_q = wp.mul(default_q, delta_q) + + pose_out[env_id] = wp.transformf(pos, final_q) + + # --- Velocity --- + default_vel = default_root_vel[env_id] + vel_out[env_id] = wp.spatial_vectorf( + default_vel[0] + 
wp.randf(state, vel_lin_lo[0], vel_lin_hi[0]), + default_vel[1] + wp.randf(state, vel_lin_lo[1], vel_lin_hi[1]), + default_vel[2] + wp.randf(state, vel_lin_lo[2], vel_lin_hi[2]), + default_vel[3] + wp.randf(state, vel_ang_lo[0], vel_ang_hi[0]), + default_vel[4] + wp.randf(state, vel_ang_lo[1], vel_ang_hi[1]), + default_vel[5] + wp.randf(state, vel_ang_lo[2], vel_ang_hi[2]), + ) + + rng_state[env_id] = state + + +def reset_root_state_uniform( + env, + env_mask: wp.array, + pose_range: dict[str, tuple[float, float]], + velocity_range: dict[str, tuple[float, float]], + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +): + """Reset the asset root state to a random position and velocity uniformly within the given ranges. + + Warp-first override of :func:`isaaclab.envs.mdp.events.reset_root_state_uniform`. + """ + asset: Articulation = env.scene[asset_cfg.name] + + # First-call: allocate scratch and pre-parse range dicts. + if not hasattr(reset_root_state_uniform, "_scratch_pose"): + reset_root_state_uniform._scratch_pose = wp.zeros((env.num_envs,), dtype=wp.transformf, device=env.device) + reset_root_state_uniform._scratch_vel = wp.zeros((env.num_envs,), dtype=wp.spatial_vectorf, device=env.device) + # Pre-parse pose_range dict + p = [pose_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]] + reset_root_state_uniform._pos_lo = wp.vec3f(p[0][0], p[1][0], p[2][0]) + reset_root_state_uniform._pos_hi = wp.vec3f(p[0][1], p[1][1], p[2][1]) + reset_root_state_uniform._rot_lo = wp.vec3f(p[3][0], p[4][0], p[5][0]) + reset_root_state_uniform._rot_hi = wp.vec3f(p[3][1], p[4][1], p[5][1]) + # Pre-parse velocity_range dict + v = [velocity_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z", "roll", "pitch", "yaw"]] + reset_root_state_uniform._vel_lin_lo = wp.vec3f(v[0][0], v[1][0], v[2][0]) + reset_root_state_uniform._vel_lin_hi = wp.vec3f(v[0][1], v[1][1], v[2][1]) + reset_root_state_uniform._vel_ang_lo = wp.vec3f(v[3][0], v[4][0], v[5][0]) + 
reset_root_state_uniform._vel_ang_hi = wp.vec3f(v[3][1], v[4][1], v[5][1]) + + wp.launch( + kernel=_reset_root_state_uniform_kernel, + dim=env.num_envs, + inputs=[ + env_mask, + env.rng_state_wp, + asset.data.default_root_pose, + asset.data.default_root_vel, + env.env_origins_wp, + reset_root_state_uniform._scratch_pose, + reset_root_state_uniform._scratch_vel, + reset_root_state_uniform._pos_lo, + reset_root_state_uniform._pos_hi, + reset_root_state_uniform._rot_lo, + reset_root_state_uniform._rot_hi, + reset_root_state_uniform._vel_lin_lo, + reset_root_state_uniform._vel_lin_hi, + reset_root_state_uniform._vel_ang_lo, + reset_root_state_uniform._vel_ang_hi, + ], + device=env.device, + ) + + asset.write_root_pose_to_sim_mask(root_pose=reset_root_state_uniform._scratch_pose, env_mask=env_mask) + asset.write_root_velocity_to_sim_mask(root_velocity=reset_root_state_uniform._scratch_vel, env_mask=env_mask) + + +# --------------------------------------------------------------------------- +# Reset joints by offset +# --------------------------------------------------------------------------- @wp.kernel @@ -124,6 +484,10 @@ def reset_joints_by_offset( device=env.device, ) + # Sync derived buffers (_previous_joint_vel, joint_acc) for reset envs. + asset.write_joint_position_to_sim_mask(position=asset.data.joint_pos, env_mask=env_mask) + asset.write_joint_velocity_to_sim_mask(velocity=asset.data.joint_vel, env_mask=env_mask) + @wp.kernel def _reset_joints_by_scale_kernel( @@ -210,3 +574,7 @@ def reset_joints_by_scale( ], device=env.device, ) + + # Sync derived buffers (_previous_joint_vel, joint_acc) for reset envs. 
+ asset.write_joint_position_to_sim_mask(position=asset.data.joint_pos, env_mask=env_mask) + asset.write_joint_velocity_to_sim_mask(velocity=asset.data.joint_vel, env_mask=env_mask) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py index ae6b1998588..f23f8d71473 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/observations.py @@ -3,30 +3,39 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Warp-first observation terms (experimental, Cartpole-focused). +"""Warp-first observation terms (experimental). All functions in this file follow the Warp-compatible observation signature expected by the experimental Warp-first observation manager: - ``func(env, out, **params) -> None`` -where ``out`` is a pre-allocated Warp array with float32 dtype and shape ``(num_envs, term_dim)``. +where ``out`` is a pre-allocated Warp array with float32 dtype and shape ``(num_envs, D)``. +Output dimension ``D`` is inferred from decorator metadata: ``axes`` for root-state terms, +``out_dim`` for body/command/action/time terms, or ``joint_ids`` count for joint terms. 
""" from __future__ import annotations from typing import TYPE_CHECKING +import torch import warp as wp +from isaaclab_newton.kernels.state_kernels import ( + body_ang_vel_from_root, + body_lin_vel_from_root, + rotate_vec_to_body_frame, +) from isaaclab.assets import Articulation from isaaclab_experimental.envs.utils.io_descriptors import ( generic_io_descriptor_warp, + record_dtype, record_joint_names, record_joint_pos_offsets, - record_joint_shape, record_joint_vel_offsets, + record_shape, ) from isaaclab_experimental.managers import SceneEntityCfg @@ -34,21 +43,191 @@ from isaaclab.envs import ManagerBasedEnv +# --------------------------------------------------------------------------- +# Shared kernels +# --------------------------------------------------------------------------- + + @wp.kernel -def _joint_pos_rel_gather_kernel( - joint_pos: wp.array(dtype=wp.float32, ndim=2), - default_joint_pos: wp.array(dtype=wp.float32, ndim=2), +def _vec3_to_out3_kernel( + src: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id = wp.tid() + v = src[env_id] + out[env_id, 0] = v[0] + out[env_id, 1] = v[1] + out[env_id, 2] = v[2] + + +@wp.kernel +def _joint_gather_kernel( + src: wp.array(dtype=wp.float32, ndim=2), + joint_ids: wp.array(dtype=wp.int32), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id, k = wp.tid() + j = joint_ids[k] + out[env_id, k] = src[env_id, j] + + +""" +Root state. 
+""" + + +@wp.kernel +def _base_pos_z_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32, ndim=2), +): + env_id = wp.tid() + out[env_id, 0] = root_pos_w[env_id][2] + + +@generic_io_descriptor_warp( + units="m", axes=["Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def base_pos_z(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Root height in the simulation world frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_pos_z_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, out], + device=env.device, + ) + + +# Inline Tier 1 access: these observations derive body-frame quantities directly from +# root_link_pose_w (transformf) and root_com_vel_w (spatial_vectorf), avoiding the lazy +# TimestampedWarpBuffer properties which are not CUDA-graph-capturable. +# See GRAPH_CAPTURE_MIGRATION.md in isaaclab_newton for background. +# If ArticulationData Tier 2 lazy update is made graph-safe in the future, these can +# revert to reading the pre-computed .data buffers (simpler, avoids redundant rotations). 
+ + +@wp.kernel +def _base_lin_vel_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + v = body_lin_vel_from_root(root_pose_w[i], root_vel_w[i]) + out[i, 0] = v[0] + out[i, 1] = v[1] + out[i, 2] = v[2] + + +@generic_io_descriptor_warp( + units="m/s", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def base_lin_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Root linear velocity in the asset's root frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_lin_vel_kernel, + dim=env.num_envs, + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], + device=env.device, + ) + + +@wp.kernel +def _base_ang_vel_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + v = body_ang_vel_from_root(root_pose_w[i], root_vel_w[i]) + out[i, 0] = v[0] + out[i, 1] = v[1] + out[i, 2] = v[2] + + +@generic_io_descriptor_warp( + units="rad/s", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def base_ang_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Root angular velocity in the asset's root frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_base_ang_vel_kernel, + dim=env.num_envs, + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], + device=env.device, + ) + + +@wp.kernel +def _projected_gravity_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + gravity_w: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32, ndim=2), +): + i = wp.tid() + g = rotate_vec_to_body_frame(gravity_w[0], root_pose_w[i]) + out[i, 0] = g[0] + out[i, 1] = g[1] + out[i, 2] = g[2] + + 
+@generic_io_descriptor_warp( + units="m/s^2", axes=["X", "Y", "Z"], observation_type="RootState", on_inspect=[record_shape, record_dtype] +) +def projected_gravity(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Gravity projection on the asset's root frame.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_projected_gravity_kernel, + dim=env.num_envs, + inputs=[asset.data.root_link_pose_w, asset.data.GRAVITY_VEC_W, out], + device=env.device, + ) + + +""" +Joint state. +""" + + +@generic_io_descriptor_warp( + observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape], units="rad" +) +def joint_pos(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """The joint positions of the asset.""" + asset: Articulation = env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is None: + raise RuntimeError( + "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. " + "Pass `asset_cfg` via term cfg params so it is resolved at manager init." 
+ ) + wp.launch( + kernel=_joint_gather_kernel, + dim=(env.num_envs, out.shape[1]), + inputs=[asset.data.joint_pos, joint_ids_wp, out], + device=env.device, + ) + + +@wp.kernel +def _joint_rel_gather_kernel( + values: wp.array(dtype=wp.float32, ndim=2), + defaults: wp.array(dtype=wp.float32, ndim=2), joint_ids: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.float32, ndim=2), ): env_id, k = wp.tid() j = joint_ids[k] - out[env_id, k] = joint_pos[env_id, j] - default_joint_pos[env_id, j] + out[env_id, k] = values[env_id, j] - defaults[env_id, j] @generic_io_descriptor_warp( observation_type="JointState", - on_inspect=[record_joint_names, record_joint_shape, record_joint_pos_offsets], + on_inspect=[record_joint_names, record_dtype, record_shape, record_joint_pos_offsets], units="rad", ) def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: @@ -63,7 +242,7 @@ def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn "Pass `asset_cfg` via term cfg params so it is resolved at manager init." 
) wp.launch( - kernel=_joint_pos_rel_gather_kernel, + kernel=_joint_rel_gather_kernel, dim=(env.num_envs, out.shape[1]), inputs=[asset.data.joint_pos, asset.data.default_joint_pos, joint_ids_wp, out], device=env.device, @@ -71,20 +250,62 @@ def joint_pos_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn @wp.kernel -def _joint_vel_rel_gather_kernel( - joint_vel: wp.array(dtype=wp.float32, ndim=2), - default_joint_vel: wp.array(dtype=wp.float32, ndim=2), +def _joint_pos_limit_normalized_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + soft_joint_pos_limits: wp.array(dtype=wp.vec2f, ndim=2), joint_ids: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.float32, ndim=2), ): env_id, k = wp.tid() j = joint_ids[k] - out[env_id, k] = joint_vel[env_id, j] - default_joint_vel[env_id, j] + pos = joint_pos[env_id, j] + lim = soft_joint_pos_limits[env_id, j] + lower = lim.x + upper = lim.y + out[env_id, k] = 2.0 * (pos - (lower + upper) * 0.5) / (upper - lower) + + +@generic_io_descriptor_warp(observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape]) +def joint_pos_limit_normalized(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """The joint positions of the asset normalized with the asset's joint limits.""" + asset: Articulation = env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is None: + raise RuntimeError( + "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. " + "Pass `asset_cfg` via term cfg params so it is resolved at manager init." 
+ ) + wp.launch( + kernel=_joint_pos_limit_normalized_kernel, + dim=(env.num_envs, out.shape[1]), + inputs=[asset.data.joint_pos, asset.data.soft_joint_pos_limits, joint_ids_wp, out], + device=env.device, + ) + + +@generic_io_descriptor_warp( + observation_type="JointState", on_inspect=[record_joint_names, record_dtype, record_shape], units="rad/s" +) +def joint_vel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """The joint velocities of the asset.""" + asset: Articulation = env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is None: + raise RuntimeError( + "SceneEntityCfg.joint_ids_wp is required for subset joint observations in Warp-first observations. " + "Pass `asset_cfg` via term cfg params so it is resolved at manager init." + ) + wp.launch( + kernel=_joint_gather_kernel, + dim=(env.num_envs, out.shape[1]), + inputs=[asset.data.joint_vel, joint_ids_wp, out], + device=env.device, + ) @generic_io_descriptor_warp( observation_type="JointState", - on_inspect=[record_joint_names, record_joint_shape, record_joint_vel_offsets], + on_inspect=[record_joint_names, record_dtype, record_shape, record_joint_vel_offsets], units="rad/s", ) def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: @@ -99,8 +320,50 @@ def joint_vel_rel(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEn "Pass `asset_cfg` via term cfg params so it is resolved at manager init." ) wp.launch( - kernel=_joint_vel_rel_gather_kernel, + kernel=_joint_rel_gather_kernel, dim=(env.num_envs, out.shape[1]), inputs=[asset.data.joint_vel, asset.data.default_joint_vel, joint_ids_wp, out], device=env.device, ) + + +""" +Actions. 
+""" + + +@generic_io_descriptor_warp(out_dim="action", dtype=torch.float32, observation_type="Action", on_inspect=[record_shape]) +def last_action(env: ManagerBasedEnv, out, action_name: str | None = None) -> None: + """The last input action to the environment.""" + # TODO(warp-migration): Cross-manager access (observation → action). Currently works + # because experimental ActionManager.action is already a warp array. No from_torch needed. + if action_name is not None: + raise NotImplementedError("Named action support is not yet implemented for Warp-first last_action observation.") + wp.copy(out, env.action_manager.action) + + +""" +Commands. +""" + + +@generic_io_descriptor_warp( + out_dim="command", dtype=torch.float32, observation_type="Command", on_inspect=[record_shape] +) +def generated_commands(env: ManagerBasedEnv, out, command_name: str) -> None: + """The generated command from the command manager. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.observations.generated_commands`. + Uses ``wp.from_torch`` to create a zero-copy warp view of the command tensor on first call. + """ + # TODO(warp-migration): Cross-manager access (observation → command). Replace with direct + # warp getter once all managers are guaranteed to be warp-native. 
+ fn = generated_commands + if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + if isinstance(cmd, wp.array): + fn._cmd_wp = cmd + else: + fn._cmd_wp = wp.from_torch(cmd) + fn._cmd_name = command_name + wp.copy(out, fn._cmd_wp) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py index ef34627eb54..424e7f28541 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/rewards.py @@ -5,9 +5,6 @@ """Common functions that can be used to enable reward functions (experimental). -This module is intentionally minimal: it only contains reward terms that are currently -used by the experimental manager-based Cartpole task. - All functions in this file follow the Warp-compatible reward signature expected by `isaaclab_experimental.managers.RewardManager`: @@ -21,6 +18,11 @@ from typing import TYPE_CHECKING import warp as wp +from isaaclab_newton.kernels.state_kernels import ( + body_ang_vel_from_root, + body_lin_vel_from_root, + rotate_vec_to_body_frame, +) from isaaclab.assets import Articulation @@ -67,11 +69,116 @@ def is_terminated(env: ManagerBasedRLEnv, out) -> None: ) +""" +Root penalties. +""" + + +# Inline Tier 1 access: these rewards derive body-frame quantities directly from +# root_link_pose_w (transformf) and root_com_vel_w (spatial_vectorf), avoiding the lazy +# TimestampedWarpBuffer properties which are not CUDA-graph-capturable. +# See GRAPH_CAPTURE_MIGRATION.md in isaaclab_newton for background. +# If ArticulationData Tier 2 lazy update is made graph-safe in the future, these can +# revert to reading the pre-computed .data buffers (simpler, avoids redundant rotations). 
+ + +@wp.kernel +def _lin_vel_z_l2_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + vz = body_lin_vel_from_root(root_pose_w[i], root_vel_w[i])[2] + out[i] = vz * vz + + +def lin_vel_z_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize z-axis base linear velocity using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_lin_vel_z_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], + device=env.device, + ) + + +@wp.kernel +def _ang_vel_xy_l2_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + v = body_ang_vel_from_root(root_pose_w[i], root_vel_w[i]) + out[i] = v[0] * v[0] + v[1] * v[1] + + +def ang_vel_xy_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize xy-axis base angular velocity using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_ang_vel_xy_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.root_link_pose_w, asset.data.root_com_vel_w, out], + device=env.device, + ) + + +@wp.kernel +def _flat_orientation_l2_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + gravity_w: wp.array(dtype=wp.vec3f), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + g = rotate_vec_to_body_frame(gravity_w[0], root_pose_w[i]) + out[i] = g[0] * g[0] + g[1] * g[1] + + +def flat_orientation_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize non-flat base orientation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_flat_orientation_l2_kernel, + dim=env.num_envs, + inputs=[asset.data.root_link_pose_w, 
asset.data.GRAVITY_VEC_W, out], + device=env.device, + ) + + """ Joint penalties. """ +# TODO(warp-migration): Revisit whether 2D kernel + wp.atomic_add is faster than 1D with inner loop +# for the following masked reduction kernels. Profile with typical joint counts (12-30). +@wp.kernel +def _sum_sq_masked_kernel( + x: wp.array(dtype=wp.float32, ndim=2), joint_mask: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32) +): + i = wp.tid() + s = float(0.0) + for j in range(x.shape[1]): + if joint_mask[j]: + s += x[i, j] * x[i, j] + out[i] = s + + +def joint_torques_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint torques applied on the articulation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_sq_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.applied_torque, asset_cfg.joint_mask, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. 
@wp.kernel def _sum_abs_masked_kernel( x: wp.array(dtype=wp.float32, ndim=2), joint_mask: wp.array(dtype=wp.bool), out: wp.array(dtype=wp.float32) @@ -93,3 +200,289 @@ def joint_vel_l1(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg) -> None inputs=[asset.data.joint_vel, asset_cfg.joint_mask, out], device=env.device, ) + + +def joint_vel_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint velocities on the articulation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_sq_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_vel, asset_cfg.joint_mask, out], + device=env.device, + ) + + +def joint_acc_l2(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint accelerations on the articulation using L2 squared kernel.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_sq_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_acc, asset_cfg.joint_mask, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. 
+@wp.kernel +def _sum_abs_diff_masked_kernel( + a: wp.array(dtype=wp.float32, ndim=2), + b: wp.array(dtype=wp.float32, ndim=2), + joint_mask: wp.array(dtype=wp.bool), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for j in range(a.shape[1]): + if joint_mask[j]: + s += wp.abs(a[i, j] - b[i, j]) + out[i] = s + + +def joint_deviation_l1(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint positions that deviate from the default one.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_sum_abs_diff_masked_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset.data.default_joint_pos, asset_cfg.joint_mask, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. +@wp.kernel +def _joint_pos_limits_kernel( + joint_pos: wp.array(dtype=wp.float32, ndim=2), + soft_joint_pos_limits: wp.array(dtype=wp.vec2f, ndim=2), + joint_mask: wp.array(dtype=wp.bool), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for j in range(joint_pos.shape[1]): + if joint_mask[j]: + pos = joint_pos[i, j] + lim = soft_joint_pos_limits[i, j] + lower = lim.x + upper = lim.y + # penalty for exceeding lower limit + below = lower - pos + if below > 0.0: + s += below + # penalty for exceeding upper limit + above = pos - upper + if above > 0.0: + s += above + out[i] = s + + +def joint_pos_limits(env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None: + """Penalize joint positions if they cross the soft limits.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_joint_pos_limits_kernel, + dim=env.num_envs, + inputs=[asset.data.joint_pos, asset.data.soft_joint_pos_limits, asset_cfg.joint_mask, out], + device=env.device, + ) + + +""" +Action penalties. +""" + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. 
+@wp.kernel +def _sum_sq_diff_2d_kernel( + a: wp.array(dtype=wp.float32, ndim=2), + b: wp.array(dtype=wp.float32, ndim=2), + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + s = float(0.0) + for j in range(a.shape[1]): + d = a[i, j] - b[i, j] + s += d * d + out[i] = s + + +def action_rate_l2(env: ManagerBasedRLEnv, out) -> None: + """Penalize the rate of change of the actions using L2 squared kernel.""" + wp.launch( + kernel=_sum_sq_diff_2d_kernel, + dim=env.num_envs, + inputs=[env.action_manager.action, env.action_manager.prev_action, out], + device=env.device, + ) + + +# TODO(warp-migration): Revisit 2D kernel + wp.atomic_add vs 1D inner loop. +@wp.kernel +def _sum_sq_2d_kernel(x: wp.array(dtype=wp.float32, ndim=2), out: wp.array(dtype=wp.float32)): + i = wp.tid() + s = float(0.0) + for j in range(x.shape[1]): + s += x[i, j] * x[i, j] + out[i] = s + + +def action_l2(env: ManagerBasedRLEnv, out) -> None: + """Penalize the actions using L2 squared kernel.""" + wp.launch( + kernel=_sum_sq_2d_kernel, + dim=env.num_envs, + inputs=[env.action_manager.action, out], + device=env.device, + ) + + +""" +Contact sensor. +""" + + +@wp.kernel +def _undesired_contacts_kernel( + forces: wp.array(dtype=wp.vec3f, ndim=3), + body_ids: wp.array(dtype=wp.int32), + threshold: float, + out: wp.array(dtype=wp.float32), +): + """Count bodies where max-over-history contact force norm exceeds threshold.""" + i = wp.tid() + count = float(0.0) + for k in range(body_ids.shape[0]): + b = body_ids[k] + max_force = float(0.0) + for h in range(forces.shape[1]): + f = forces[i, h, b] + norm = wp.sqrt(f[0] * f[0] + f[1] * f[1] + f[2] * f[2]) + if norm > max_force: + max_force = norm + if max_force > threshold: + count += 1.0 + out[i] = count + + +def undesired_contacts(env: ManagerBasedRLEnv, out, threshold: float, sensor_cfg: SceneEntityCfg) -> None: + """Penalize undesired contacts as the number of violations above a threshold. Writes into ``out``. 
+ + Warp-first override of :func:`isaaclab.envs.mdp.rewards.undesired_contacts`. + """ + contact_sensor = env.scene.sensors[sensor_cfg.name] + wp.launch( + kernel=_undesired_contacts_kernel, + dim=env.num_envs, + inputs=[contact_sensor.data.net_forces_w_history, sensor_cfg.body_ids_wp, threshold, out], + device=env.device, + ) + + +""" +Velocity-tracking rewards. +""" + + +@wp.kernel +def _track_lin_vel_xy_exp_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + command: wp.array(dtype=wp.float32, ndim=2), + std_sq_inv: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + v = body_lin_vel_from_root(root_pose_w[i], root_vel_w[i]) + dx = command[i, 0] - v[0] + dy = command[i, 1] - v[1] + error = dx * dx + dy * dy + out[i] = wp.exp(-error * std_sq_inv) + + +def track_lin_vel_xy_exp( + env: ManagerBasedRLEnv, + out, + std: float, + command_name: str, + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Reward tracking of linear velocity commands (xy axes) using exponential kernel. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.rewards.track_lin_vel_xy_exp`. + """ + asset: Articulation = env.scene[asset_cfg.name] + # cache the warp view of the command tensor on first call (zero-copy) + # TODO(warp-migration): Cross-manager access (reward → command). Replace with direct + # warp getter once all managers are guaranteed to be warp-native. 
+ if not hasattr(track_lin_vel_xy_exp, "_cmd_wp") or track_lin_vel_xy_exp._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + if isinstance(cmd, wp.array): + track_lin_vel_xy_exp._cmd_wp = cmd + else: + track_lin_vel_xy_exp._cmd_wp = wp.from_torch(cmd) + track_lin_vel_xy_exp._cmd_name = command_name + wp.launch( + kernel=_track_lin_vel_xy_exp_kernel, + dim=env.num_envs, + inputs=[ + asset.data.root_link_pose_w, + asset.data.root_com_vel_w, + track_lin_vel_xy_exp._cmd_wp, + 1.0 / (std * std), + out, + ], + device=env.device, + ) + + +@wp.kernel +def _track_ang_vel_z_exp_kernel( + root_pose_w: wp.array(dtype=wp.transformf), + root_vel_w: wp.array(dtype=wp.spatial_vectorf), + command: wp.array(dtype=wp.float32, ndim=2), + cmd_col: int, + std_sq_inv: float, + out: wp.array(dtype=wp.float32), +): + i = wp.tid() + dz = command[i, cmd_col] - body_ang_vel_from_root(root_pose_w[i], root_vel_w[i])[2] + out[i] = wp.exp(-dz * dz * std_sq_inv) + + +def track_ang_vel_z_exp( + env: ManagerBasedRLEnv, + out, + std: float, + command_name: str, + asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), +) -> None: + """Reward tracking of angular velocity commands (yaw) using exponential kernel. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.rewards.track_ang_vel_z_exp`. + """ + asset: Articulation = env.scene[asset_cfg.name] + # TODO(warp-migration): Cross-manager access (reward → command). Replace with direct + # warp getter once all managers are guaranteed to be warp-native. 
+ if not hasattr(track_ang_vel_z_exp, "_cmd_wp") or track_ang_vel_z_exp._cmd_name != command_name: + cmd = env.command_manager.get_command(command_name) + if isinstance(cmd, wp.array): + track_ang_vel_z_exp._cmd_wp = cmd + else: + track_ang_vel_z_exp._cmd_wp = wp.from_torch(cmd) + track_ang_vel_z_exp._cmd_name = command_name + wp.launch( + kernel=_track_ang_vel_z_exp_kernel, + dim=env.num_envs, + inputs=[ + asset.data.root_link_pose_w, + asset.data.root_com_vel_w, + track_ang_vel_z_exp._cmd_wp, + 2, + 1.0 / (std * std), + out, + ], + device=env.device, + ) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py index a6b0cea4375..9ea143e4e8f 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/mdp/terminations.py @@ -5,9 +5,6 @@ """Common functions that can be used to activate terminations (experimental). -This module is intentionally minimal: it only contains termination terms that are currently -used by the experimental manager-based Cartpole task. - All functions in this file follow the Warp-compatible termination signature expected by `isaaclab_experimental.managers.TerminationManager`: @@ -30,6 +27,11 @@ from isaaclab.envs import ManagerBasedRLEnv +""" +MDP terminations. +""" + + @wp.kernel def _time_out_kernel( episode_length: wp.array(dtype=wp.int64), max_episode_length: wp.int64, out: wp.array(dtype=wp.bool) @@ -48,6 +50,39 @@ def time_out(env: ManagerBasedRLEnv, out) -> None: ) +""" +Root terminations. 
+""" + + +@wp.kernel +def _root_height_below_min_kernel( + root_pos_w: wp.array(dtype=wp.vec3f), + minimum_height: float, + out: wp.array(dtype=wp.bool), +): + i = wp.tid() + out[i] = root_pos_w[i][2] < minimum_height + + +def root_height_below_minimum( + env: ManagerBasedRLEnv, out, minimum_height: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot") +) -> None: + """Terminate when the asset's root height is below the minimum height.""" + asset: Articulation = env.scene[asset_cfg.name] + wp.launch( + kernel=_root_height_below_min_kernel, + dim=env.num_envs, + inputs=[asset.data.root_pos_w, minimum_height, out], + device=env.device, + ) + + +""" +Joint terminations. +""" + + @wp.kernel def _joint_pos_out_of_manual_limit_kernel( joint_pos: wp.array(dtype=wp.float32, ndim=2), @@ -56,15 +91,12 @@ def _joint_pos_out_of_manual_limit_kernel( upper: float, out: wp.array(dtype=wp.bool), ): - i = wp.tid() - violated = bool(False) - for j in range(joint_pos.shape[1]): - if joint_mask[j]: - v = joint_pos[i, j] - if v < lower or v > upper: - violated = True - break - out[i] = violated + """2D kernel (num_envs, num_joints). ``out`` is pre-zeroed; only writes True.""" + i, j = wp.tid() + if joint_mask[j]: + v = joint_pos[i, j] + if v < lower or v > upper: + out[i] = True def joint_pos_out_of_manual_limit( @@ -84,7 +116,46 @@ def joint_pos_out_of_manual_limit( ) wp.launch( kernel=_joint_pos_out_of_manual_limit_kernel, - dim=env.num_envs, + dim=(env.num_envs, asset.data.joint_pos.shape[1]), inputs=[asset.data.joint_pos, asset_cfg.joint_mask, bounds[0], bounds[1], out], device=env.device, ) + + +""" +Contact sensor. 
+""" + + +@wp.kernel +def _illegal_contact_kernel( + forces: wp.array(dtype=wp.vec3f, ndim=3), + body_ids: wp.array(dtype=wp.int32), + threshold: float, + out: wp.array(dtype=wp.bool), +): + """Terminate when any selected body's max-over-history contact force norm exceeds threshold.""" + i = wp.tid() + violated = bool(False) + for k in range(body_ids.shape[0]): + b = body_ids[k] + for h in range(forces.shape[1]): + f = forces[i, h, b] + norm = wp.sqrt(f[0] * f[0] + f[1] * f[1] + f[2] * f[2]) + if norm > threshold: + violated = True + out[i] = violated + + +def illegal_contact(env: ManagerBasedRLEnv, out, threshold: float, sensor_cfg: SceneEntityCfg) -> None: + """Terminate when the contact force on the sensor exceeds the force threshold. Writes into ``out``. + + Warp-first override of :func:`isaaclab.envs.mdp.terminations.illegal_contact`. + """ + contact_sensor = env.scene.sensors[sensor_cfg.name] + wp.launch( + kernel=_illegal_contact_kernel, + dim=env.num_envs, + inputs=[contact_sensor.data.net_forces_w_history, sensor_cfg.body_ids_wp, threshold, out], + device=env.device, + ) diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py b/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py index 47cbad37063..19b6530b0b0 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/utils/io_descriptors.py @@ -59,10 +59,12 @@ def _make_descriptor(**kwargs: Any) -> GenericObservationIODescriptor: desc = GenericObservationIODescriptor(**known) # User defined extras are stored in the descriptor under the `extras` field desc.extras = extras + # ``out_dim`` is kept as a top-level attribute (not in extras) so the + # observation manager can read it without inspecting extras. 
+ desc.out_dim = extras.pop("out_dim", None) return desc -# TODO(jichuanh): The exact usage is unclear and this need revisit # Decorator factory for Warp-first IO descriptors. def generic_io_descriptor_warp( _func: Callable[Concatenate[ManagerBasedEnv, P], R] | None = None, @@ -183,18 +185,38 @@ def wrapper(env: ManagerBasedEnv, *args: P.args, **kwargs: P.kwargs) -> R: def record_shape(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None: """Record the shape of the output buffer. - No-op when ``output`` is ``None`` (the typical case during Warp-first - inspection). Use a type-specific hook such as :func:`record_joint_shape` - to derive shape from config instead. + When ``output`` is not ``None`` (eager path), shape is read directly. + When ``output`` is ``None`` (Warp-first inspection), shape is derived from: + - ``descriptor.extras["axes"]`` for RootState observations, or + - ``asset_cfg.joint_ids`` for JointState observations. + + BodyState shape cannot be derived without calling the function (the per-body + feature size varies). In that case shape is left unset. Args: output: The pre-allocated output buffer, or ``None`` during inspection. descriptor: The descriptor to record the shape to. **kwargs: Additional keyword arguments. 
""" - if output is None: + if output is not None: + descriptor.shape = (output.shape[-1],) + return + # --- Warp-first fallback: derive shape without output --- + # 1) From axes metadata (RootState) + axes = descriptor.extras.get("axes") if descriptor.extras else None + if axes: + descriptor.shape = (len(axes),) return - descriptor.shape = (output.shape[-1],) + # 2) From asset_cfg for JointState + if descriptor.observation_type == "JointState": + asset_cfg = kwargs.get("asset_cfg") + if asset_cfg is not None: + asset: Articulation = kwargs["env"].scene[asset_cfg.name] + joint_ids = asset_cfg.joint_ids + if joint_ids == slice(None): + descriptor.shape = (len(asset.joint_names),) + else: + descriptor.shape = (len(joint_ids),) def record_dtype(output: wp.array | None, descriptor: GenericObservationIODescriptor, **kwargs) -> None: diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py index 62d3171d32a..b4521b98434 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/__init__.py @@ -11,10 +11,12 @@ from isaaclab.managers import * # noqa: F401,F403 -# Override the stable implementation with the experimental fork. from .action_manager import ActionManager # noqa: F401 from .command_manager import CommandManager # noqa: F401 from .event_manager import EventManager # noqa: F401 + +# Override the stable implementation with the experimental fork. 
+from .manager_base import ManagerTermBase # noqa: F401 from .manager_term_cfg import ObservationTermCfg, RewardTermCfg, TerminationTermCfg # noqa: F401 from .observation_manager import ObservationManager # noqa: F401 from .reward_manager import RewardManager # noqa: F401 diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py index 78ef70a99b1..672af82bc8a 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/action_manager.py @@ -27,22 +27,8 @@ if TYPE_CHECKING: from isaaclab.envs import ManagerBasedEnv - -@wp.kernel -def _zero_masked_2d( - # input - mask: wp.array(dtype=wp.bool), - # input/output - data: wp.array(dtype=wp.float32, ndim=2), -): - """Zero rows of a 2D buffer where ``mask`` is True. - - Launched with dim = (num_envs, data.shape[1]). - """ - - env_id, j = wp.tid() - if mask[env_id]: - data[env_id, j] = 0.0 +# Shared kernel – imported from utils to avoid duplication. 
+from isaaclab_experimental.utils.warp.utils import zero_masked_2d as _zero_masked_2d class ActionTerm(ManagerTermBase): diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py index 6adc9f055d5..3b36613a2ac 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/event_manager.py @@ -311,21 +311,21 @@ def apply( self._apply_interval(float(dt)) return + # resolve the environment mask + if env_mask_wp is None: + if wp.get_device().is_capturing: + raise ValueError(f"Event mode '{mode}' requires the environment mask to be provided when capturing.") + env_mask_wp = self._env.resolve_env_mask(env_ids=env_ids) + if mode == "reset": if global_env_step_count is None: raise ValueError(f"Event mode '{mode}' requires the total number of environment steps to be provided.") - if env_mask_wp is None: - if wp.get_device().is_capturing: - raise ValueError( - f"Event mode '{mode}' requires the environment mask to be provided when capturing." 
- ) - env_mask_wp = self._env.resolve_env_mask(env_ids=env_ids) self._apply_reset(env_mask_wp, global_env_step_count) return - # other modes keep the stable convention (env_ids forwarded) + # other modes (startup, prestartup, custom) — env_mask forwarded for term_cfg in self._mode_term_cfgs[mode]: - term_cfg.func(self._env, env_ids, **term_cfg.params) + term_cfg.func(self._env, env_mask_wp, **term_cfg.params) def _apply_interval(self, dt: float) -> None: if self._env.rng_state_wp is None: diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py index e4abb1fa2c7..d40920b23fe 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/manager_base.py @@ -28,6 +28,8 @@ import isaaclab.utils.string as string_utils from isaaclab.utils import class_to_dict, string_to_callable +from isaaclab_experimental.utils.warp import is_warp_capturable + from .manager_term_cfg import ManagerTermBaseCfg from .scene_entity_cfg import SceneEntityCfg @@ -401,6 +403,12 @@ def _resolve_common_term_cfg(self, term_name: str, term_cfg: ManagerTermBaseCfg, f" and optional parameters: {args_with_defaults}, but received: {term_params}." 
) + # register non-capturable terms with the call switch for mode=2 fallback + if not is_warp_capturable(term_cfg.func): + switch = getattr(self._env, "_manager_call_switch", None) + if switch is not None: + switch.register_manager_capturability(type(self).__name__, False) + # process attributes at runtime # these properties are only resolvable once the simulation starts playing if self._env.sim.is_playing(): diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py b/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py index f73f30e5a9b..07287eb03fb 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/observation_manager.py @@ -526,6 +526,13 @@ def compute_group(self, group_name: str, update_history: bool = False) -> torch. # TODO(jichuanh): This is not migrated yet. Need revisit. # Update the history buffer if observation term has history enabled if term_cfg.history_length > 0: + # circular buffer is not capture safe + if wp.get_device().is_capturing: + raise RuntimeError( + "Observation terms with history (circular buffer) are not CUDA-graph-capture-safe yet. " + "Disable history for observation terms used inside a captured graph, or restructure " + "the graph to exclude history-buffered terms." + ) circular_buffer = self._group_obs_term_history_buffer[group_name][term_name] if update_history: circular_buffer.append(wp.to_torch(term_cfg.out_wp)) @@ -703,9 +710,10 @@ def _prepare_terms(self): # noqa: C901 f" but received: {len(term_cfg.scale)}." 
) - # cast the scale into torch tensor - term_cfg.scale = torch.tensor(term_cfg.scale, dtype=torch.float, device=self._env.device) - term_cfg.scale_wp = wp.from_torch(term_cfg.scale, dtype=wp.float32) + scale_vals = ( + term_cfg.scale if isinstance(term_cfg.scale, tuple) else [float(term_cfg.scale)] * obs_dims[1] + ) + term_cfg.scale_wp = wp.array(scale_vals, dtype=wp.float32, device=self._env.device) # prepare modifiers for each observation if term_cfg.modifiers is not None: @@ -842,16 +850,43 @@ def _prepare_terms(self): # noqa: C901 self._group_obs_term_history_buffer[group_name] = group_entry_history_buffer def _infer_term_dim_scalar(self, term_cfg: ObservationTermCfg) -> int: - """Infer (D,) using scalar scene info (no term execution).""" - # allow explicit override + """Infer observation output dimension (D,) using decorator metadata, scene info, or manager state. + + Resolution order: + 1. ``out_dim`` on the function's ``@generic_io_descriptor_warp`` decorator. + 2. ``axes`` on the decorator (e.g. ``axes=["X","Y","Z"]`` → dim 3). + 3. Explicit ``term_dim`` / ``out_dim`` / ``obs_dim`` in ``term_cfg.params`` (legacy). + 4. ``asset_cfg.joint_ids`` count (joint-based observations). + """ + # --- 1-2. Decorator metadata (preferred) --- + func = term_cfg.func + # Check for descriptor on the (possibly wrapped) function first, + # then fall back to unwrapping for class-based terms. + descriptor = getattr(func, "_descriptor", None) + if descriptor is None and hasattr(func, "__wrapped__"): + descriptor = getattr(func.__wrapped__, "_descriptor", None) + if descriptor is not None: + # 1. Explicit out_dim on decorator + out_dim = getattr(descriptor, "out_dim", None) + if out_dim is not None: + return self._resolve_out_dim(out_dim, term_cfg) + # 2. Derive from axes metadata + axes = descriptor.extras.get("axes") if descriptor.extras else None + if axes is not None: + return len(axes) + + # --- 3. 
Legacy explicit override in params --- for k in ("term_dim", "out_dim", "obs_dim"): if k in term_cfg.params: return int(term_cfg.params[k]) - # try explicit param first + + # --- 4. Joint-based fallback via asset_cfg --- asset_cfg = term_cfg.params.get("asset_cfg") if asset_cfg is None: - raise ValueError(f"Observation term '{term_cfg.params}' has no asset_cfg parameter.") - # resolve selection + raise ValueError( + f"Cannot infer output dimension for observation term '{getattr(func, '__name__', func)}'. " + "Add `out_dim=` to its @generic_io_descriptor_warp decorator." + ) asset = self._env.scene[asset_cfg.name] joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) if joint_ids_wp is not None: @@ -860,3 +895,47 @@ def _infer_term_dim_scalar(self, term_cfg: ObservationTermCfg) -> int: if isinstance(joint_ids, slice): return int(getattr(asset, "num_joints", wp.to_torch(asset.data.joint_pos).shape[1])) return int(len(joint_ids)) + + def _resolve_out_dim(self, out_dim: int | str, term_cfg: ObservationTermCfg) -> int: + """Resolve an ``out_dim`` value from a decorator into a concrete integer. + + Supports: + - ``int``: returned as-is (fixed dimension). + - ``"joint"``: number of selected joints from ``asset_cfg``. + - ``"body:N"``: ``N`` components per selected body from ``asset_cfg``. + - ``"command"``: query ``command_manager.get_command(name).shape[-1]``. + - ``"action"``: query ``action_manager.action.shape[-1]``. 
+ """ + if isinstance(out_dim, int): + return out_dim + + if out_dim == "joint": + asset_cfg = term_cfg.params.get("asset_cfg") + asset = self._env.scene[asset_cfg.name] + joint_ids_wp = getattr(asset_cfg, "joint_ids_wp", None) + if joint_ids_wp is not None: + return int(joint_ids_wp.shape[0]) + joint_ids = getattr(asset_cfg, "joint_ids", slice(None)) + if isinstance(joint_ids, slice): + return int(getattr(asset, "num_joints", wp.to_torch(asset.data.joint_pos).shape[1])) + return int(len(joint_ids)) + + if isinstance(out_dim, str) and out_dim.startswith("body:"): + per_body = int(out_dim.split(":")[1]) + asset_cfg = term_cfg.params.get("asset_cfg") + body_ids = getattr(asset_cfg, "body_ids", None) + if body_ids is None or body_ids == slice(None): + asset = self._env.scene[asset_cfg.name] + return per_body * len(asset.body_names) + return per_body * len(body_ids) + + if out_dim == "command": + command_name = term_cfg.params.get("command_name") + cmd = self._env.command_manager.get_command(command_name) + return int(cmd.shape[-1]) + + if out_dim == "action": + action = self._env.action_manager.action + return int(action.shape[-1]) + + raise ValueError(f"Unknown out_dim sentinel: {out_dim!r}") diff --git a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py index 9f58cbe8ddf..1930f6ce1cb 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py +++ b/source/isaaclab_experimental/isaaclab_experimental/managers/scene_entity_cfg.py @@ -25,30 +25,36 @@ class SceneEntityCfg(_SceneEntityCfg): - `joint_mask` is intended for Warp kernels only. 
""" - """Boolean mask over all joints — used by warp kernels for masked writes.""" joint_mask: wp.array | None = None """Integer indices of selected joints — used for subset-sized gathers where a boolean mask cannot provide the mapping from output index k to joint index.""" joint_ids_wp: wp.array | None = None + """Integer indices of selected bodies — used for subset-sized body gathers.""" + body_ids_wp: wp.array | None = None + def resolve(self, scene: InteractiveScene): # run the stable resolution first (fills joint_ids/body_ids from names/regex) super().resolve(scene) - # Build a Warp joint mask for articulations only. entity = scene[self.name] - if not isinstance(entity, BaseArticulation): - return - - # Pre-allocate a full-length mask (all True for default selection). - if self.joint_ids == slice(None): - joint_ids_list = list(range(entity.num_joints)) - mask_list = [True] * entity.num_joints - else: - joint_ids_list = list(self.joint_ids) - mask_list = [False] * entity.num_joints - for idx in joint_ids_list: - mask_list[idx] = True - self.joint_mask = wp.array(mask_list, dtype=wp.bool, device=scene.device) - self.joint_ids_wp = wp.array(joint_ids_list, dtype=wp.int32, device=scene.device) + + # -- Warp joint mask / ids for articulations + if isinstance(entity, BaseArticulation): + if self.joint_ids == slice(None): + joint_ids_list = list(range(entity.num_joints)) + mask_list = [True] * entity.num_joints + else: + joint_ids_list = list(self.joint_ids) + mask_list = [False] * entity.num_joints + for idx in joint_ids_list: + mask_list[idx] = True + self.joint_mask = wp.array(mask_list, dtype=wp.bool, device=scene.device) + self.joint_ids_wp = wp.array(joint_ids_list, dtype=wp.int32, device=scene.device) + + # -- Warp body ids + if self.body_ids is not None and self.body_ids != slice(None): + self.body_ids_wp = wp.array(list(self.body_ids), dtype=wp.int32, device=scene.device) + elif hasattr(entity, "num_bodies"): + self.body_ids_wp = 
wp.array(list(range(entity.num_bodies)), dtype=wp.int32, device=scene.device) diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py b/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py index 2f37b1087ca..d9236203ee4 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/manager_call_switch.py @@ -13,7 +13,7 @@ from enum import IntEnum from typing import Any -from isaaclab_experimental.utils.warp_graph_cache import WarpGraphCache +import warp as wp from isaaclab.utils.timer import Timer @@ -46,6 +46,7 @@ class ManagerCallSwitch: "ObservationManager", "EventManager", "RecorderManager", + "CommandManager", "TerminationManager", "RewardManager", "CurriculumManager", @@ -70,7 +71,8 @@ def __init__( *, max_modes: dict[str, int] | None = None, ): - self._graph_cache = WarpGraphCache() + self._wp_graphs: dict[str, Any] = {} + self._wp_results: dict[str, Any] = {} # Merge caller-supplied max_modes with the class-level MAX_MODE_OVERRIDES. self._max_modes = dict(self.MAX_MODE_OVERRIDES) if max_modes is not None: @@ -93,7 +95,8 @@ def __init__( def invalidate_graphs(self) -> None: """Invalidate cached capture graphs and their cached return values.""" - self._graph_cache.invalidate() + self._wp_graphs.clear() + self._wp_results.clear() # ------------------------------------------------------------------ # Stage dispatch @@ -163,38 +166,20 @@ def _manager_name_from_stage(self, stage: str) -> str: return stage.split("_", 1)[0] def get_mode_for_manager(self, manager_name: str) -> ManagerCallMode: - """Return the resolved execution mode for the given manager. - - Looks up the manager in the config dict, falls back to the default, - then caps by :attr:`MAX_MODE_OVERRIDES`. 
- """ - mode_value = self._cfg.get(manager_name, self._cfg[self.DEFAULT_KEY]) - cap = self._max_modes.get(manager_name) - if cap is not None: - mode_value = min(mode_value, cap) + """Return the resolved execution mode for the given manager.""" + default_key = next(iter(self.DEFAULT_CONFIG)) + mode_value = self._cfg.get(manager_name, self._cfg[default_key]) return ManagerCallMode(mode_value) - def resolve_manager_class(self, manager_name: str) -> type: + def resolve_manager_class(self, manager_name: str, mode_override: ManagerCallMode | int | None = None) -> type: """Import and return the manager class for the configured mode.""" - mode = self.get_mode_for_manager(manager_name) + mode = self.get_mode_for_manager(manager_name) if mode_override is None else ManagerCallMode(mode_override) module_name = "isaaclab.managers" if mode == ManagerCallMode.STABLE else "isaaclab_experimental.managers" module = importlib.import_module(module_name) if not hasattr(module, manager_name): raise AttributeError(f"Manager '{manager_name}' not found in module '{module_name}'.") return getattr(module, manager_name) - def register_manager_capturability(self, manager_name: str, capturable: bool) -> None: - """Register that a manager has non-capturable terms, capping its mode. - - Called by :class:`ManagerBase` during term preparation when a term - is decorated with ``@warp_capturable(False)``. - """ - if not capturable: - self._max_modes[manager_name] = min( - self._max_modes.get(manager_name, ManagerCallMode.WARP_CAPTURED), - ManagerCallMode.WARP_NOT_CAPTURED, - ) - # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @@ -206,15 +191,18 @@ def _run_call(self, call: dict[str, Any]) -> Any: def _wp_capture_or_launch(self, stage: str, call: dict[str, Any]) -> Any: """Capture Warp CUDA graph on first call, then replay. 
- Delegates to :class:`WarpGraphCache` which caches the return value - and replays immediately after the first capture for validation. + The return value from the first (capture) run is cached and returned + on every subsequent replay. This ensures captured stages return the + same references (e.g. tensor views) as eager stages. """ - return self._graph_cache.capture_or_replay( - stage, - call["fn"], - args=call.get("args", ()), - kwargs=call.get("kwargs", {}), - ) + graph = self._wp_graphs.get(stage) + if graph is None: + with wp.ScopedCapture() as capture: + result = call["fn"](*call.get("args", ()), **call.get("kwargs", {})) + self._wp_graphs[stage] = capture.graph + self._wp_results[stage] = result + wp.capture_launch(self._wp_graphs[stage]) + return self._wp_results[stage] def _load_cfg(self, cfg_source: dict | str | None) -> dict[str, int]: if cfg_source is None: @@ -236,7 +224,7 @@ def _load_cfg(self, cfg_source: dict | str | None) -> dict[str, int]: else: raise TypeError(f"cfg_source must be a dict, string, or None, got: {type(cfg_source)}") - # Validation + # validation for manager_name, mode_value in cfg.items(): if not isinstance(mode_value, int): raise TypeError( @@ -249,7 +237,7 @@ def _load_cfg(self, cfg_source: dict | str | None) -> dict[str, int]: f"Invalid manager_call_config value for '{manager_name}': {mode_value}. Expected 0/1/2." ) from exc - # Apply MAX_MODE_OVERRIDES: bake caps into the resolved config so + # Apply max mode caps: bake caps into the resolved config so # get_mode_for_manager never needs per-call branching. 
default_mode = cfg[self.DEFAULT_KEY] for name, max_mode in self._max_modes.items(): diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py index a7e71ac4688..2d071b823a4 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/__init__.py @@ -6,4 +6,11 @@ """Warp utility functions and shared kernels for isaaclab_experimental.""" from .kernels import compute_reset_scale, count_masked -from .utils import WarpCapturable, resolve_1d_mask, wrap_to_pi +from .utils import ( + WarpCapturable, + is_warp_capturable, + resolve_1d_mask, + warp_capturable, + wrap_to_pi, + zero_masked_2d, +) diff --git a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py index c7a2c63d959..b8eb145e21b 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py +++ b/source/isaaclab_experimental/isaaclab_experimental/utils/warp/utils.py @@ -39,7 +39,7 @@ def resolve_1d_mask( Args: ids: Indices to set to ``True``. ``None`` or ``slice(None)`` means all. mask: Explicit boolean mask. If provided, returned directly (after - torch→warp normalization if needed). Takes precedence over *ids*. + torch->warp normalization if needed). Takes precedence over *ids*. all_mask: Pre-allocated all-True mask of shape ``(size,)``, returned when both *ids* and *mask* are ``None``. scratch_mask: Pre-allocated scratch mask of shape ``(size,)``, filled @@ -47,7 +47,7 @@ def resolve_1d_mask( device: Warp device string. Returns: - A ``wp.array(dtype=wp.bool)`` — ``mask``, ``all_mask``, or ``scratch_mask``. + A ``wp.array(dtype=wp.bool)`` -- ``mask``, ``all_mask``, or ``scratch_mask``. """ # Fast path: explicit mask provided. 
if mask is not None: @@ -97,6 +97,37 @@ def resolve_1d_mask( return scratch_mask +def warp_capturable(capturable: bool): + """Annotate an MDP term's CUDA-graph capturability. + + No-wrapper decorator: sets ``_warp_capturable`` directly on the function + and returns it unchanged. Safe to stack with any other decorator in any order. + + By default all MDP terms are assumed capturable (True). Use + ``@warp_capturable(False)`` on terms that call non-capturable external APIs. + """ + + def decorator(func): + func._warp_capturable = capturable + return func + + return decorator + + +def is_warp_capturable(func) -> bool: + """Check if a term function is CUDA-graph-capturable. + + Checks ``_warp_capturable`` on the function and its ``__wrapped__`` target. + Returns True (capturable) by default if no annotation is found. + """ + for f in (func, getattr(func, "__wrapped__", None)): + if f is not None: + val = getattr(f, "_warp_capturable", None) + if val is not None: + return val + return True + + @wp.func def wrap_to_pi(angle: float) -> float: """Wrap input angle (in radians) to the range [-pi, pi].""" @@ -163,3 +194,12 @@ def is_capturable(func) -> bool: if val is not None: return val return True + + + +@wp.kernel +def zero_masked_2d(mask: wp.array(dtype=wp.bool), values: wp.array(dtype=wp.float32, ndim=2)): + """Zero out rows of a 2D float32 array where mask is True.""" + env_id, j = wp.tid() + if mask[env_id]: + values[env_id, j] = 0.0 diff --git a/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py b/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py new file mode 100644 index 00000000000..9bb7a39e3b2 --- /dev/null +++ b/source/isaaclab_newton/isaaclab_newton/kernels/state_kernels.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp helper functions for body-frame state computations. + +These ``@wp.func`` helpers are used by warp-first MDP terms (observations, +rewards) that need to project root-frame quantities into body frames. +""" + +import warp as wp + + +@wp.func +def rotate_vec_to_body_frame(vec_w: wp.vec3f, pose_w: wp.transformf) -> wp.vec3f: + """Rotate a world-frame vector into the body frame defined by pose_w.""" + return wp.quat_rotate_inv(wp.transform_get_rotation(pose_w), vec_w) + + +@wp.func +def body_lin_vel_from_root(pose_w: wp.transformf, vel_w: wp.spatial_vectorf) -> wp.vec3f: + """Extract body-frame linear velocity from root pose and spatial velocity.""" + return wp.quat_rotate_inv(wp.transform_get_rotation(pose_w), wp.spatial_top(vel_w)) + + +@wp.func +def body_ang_vel_from_root(pose_w: wp.transformf, vel_w: wp.spatial_vectorf) -> wp.vec3f: + """Extract body-frame angular velocity from root pose and spatial velocity.""" + return wp.quat_rotate_inv(wp.transform_get_rotation(pose_w), wp.spatial_bottom(vel_w)) From dc432494345b4fa6c435460de65762d6027b23d4 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Thu, 12 Mar 2026 00:29:48 -0700 Subject: [PATCH 4/7] Add capture safety guards and fix WrenchComposer stale COM pose Fix WrenchComposer to use Tier 1 sim-bind buffers (body_link_pose_w, body_com_pos_b) instead of caching the lazy Tier 2 body_com_pose_w property, which became stale after the first step. Add capture_unsafe decorator for lazy-evaluated derived properties in articulation and rigid object data. Update wrench kernels to compute COM pose inline from link pose and body-frame COM offset. 
--- .../isaaclab/isaaclab/utils/warp/kernels.py | 204 ++++++++++-------- source/isaaclab/isaaclab/utils/warp/utils.py | 151 +++++++++++++ .../isaaclab/utils/wrench_composer.py | 65 +++--- .../assets/articulation/articulation_data.py | 15 ++ .../assets/rigid_object/rigid_object_data.py | 15 ++ 5 files changed, 317 insertions(+), 133 deletions(-) create mode 100644 source/isaaclab/isaaclab/utils/warp/utils.py diff --git a/source/isaaclab/isaaclab/utils/warp/kernels.py b/source/isaaclab/isaaclab/utils/warp/kernels.py index da2d9123db4..e044c5224ee 100644 --- a/source/isaaclab/isaaclab/utils/warp/kernels.py +++ b/source/isaaclab/isaaclab/utils/warp/kernels.py @@ -307,54 +307,68 @@ def reshape_tiled_image( @wp.func -def cast_to_link_frame(position: wp.vec3f, link_position: wp.vec3f, is_global: bool) -> wp.vec3f: - """Casts a position to the link frame of the body. +def _com_pose_from_link(link_pose_w: wp.transformf, com_pos_b: wp.vec3f) -> wp.transformf: + """Compute the world-frame COM pose from a link pose and a body-frame COM offset. + + The COM frame shares the link's orientation; only the position is offset. + Equivalent to ``combine_frame_transforms(link_pose_w, com_pos_b, identity_quat)`` + but avoids the full transform multiply. + """ + q = wp.transform_get_rotation(link_pose_w) + p = wp.transform_get_translation(link_pose_w) + wp.quat_rotate(q, com_pos_b) + return wp.transformf(p, q) + + +@wp.func +def cast_to_com_frame(position: wp.vec3f, com_pose_w: wp.transformf, is_global: bool) -> wp.vec3f: + """Casts a position to the COM frame of the body. Args: position: The position to cast. - link_position: The link frame position. + com_pose_w: The COM frame pose in world frame. is_global: Whether the position is in the global frame. Returns: - The position in the link frame of the body. + The position in the COM frame of the body. 
""" if is_global: - return position - link_position + return position - wp.transform_get_translation(com_pose_w) else: return position @wp.func -def cast_force_to_link_frame(force: wp.vec3f, link_quat: wp.quatf, is_global: bool) -> wp.vec3f: - """Casts a force to the link frame of the body. +def cast_force_to_com_frame(force: wp.vec3f, com_pose_w: wp.transformf, is_global: bool) -> wp.vec3f: + """Casts a force to the COM frame of the body. Args: force: The force to cast. - link_quat: The link frame quaternion. + com_pose_w: The COM frame pose in world frame. is_global: Whether the force is applied in the global frame. + Returns: - The force in the link frame of the body. + The force in the COM frame of the body. """ if is_global: - return wp.quat_rotate_inv(link_quat, force) + return wp.quat_rotate_inv(wp.transform_get_rotation(com_pose_w), force) else: return force @wp.func -def cast_torque_to_link_frame(torque: wp.vec3f, link_quat: wp.quatf, is_global: bool) -> wp.vec3f: - """Casts a torque to the link frame of the body. +def cast_torque_to_com_frame(torque: wp.vec3f, com_pose_w: wp.transformf, is_global: bool) -> wp.vec3f: + """Casts a torque to the COM frame of the body. Args: torque: The torque to cast. - link_quat: The link frame quaternion. + com_pose_w: The COM frame pose in world frame. is_global: Whether the torque is applied in the global frame. Returns: - The torque in the link frame of the body. + The torque in the COM frame of the body. 
""" if is_global: - return wp.quat_rotate_inv(link_quat, torque) + return wp.quat_rotate_inv(wp.transform_get_rotation(com_pose_w), torque) else: return torque @@ -366,7 +380,8 @@ def add_forces_and_torques_at_position_index( forces: wp.array2d(dtype=wp.vec3f), torques: wp.array2d(dtype=wp.vec3f), positions: wp.array2d(dtype=wp.vec3f), - link_poses: wp.array2d(dtype=wp.transformf), + body_link_pose_w: wp.array2d(dtype=wp.transformf), + body_com_pos_b: wp.array2d(dtype=wp.vec3f), is_global: bool, composed_forces_b: wp.array2d(dtype=wp.vec3f), composed_torques_b: wp.array2d(dtype=wp.vec3f), @@ -374,7 +389,7 @@ def add_forces_and_torques_at_position_index( """Add forces and torques to the composed wrench at user-provided positions using index selection. When is_global is False, the user-provided positions offset the force application relative to - the link frame. When is_global is True, positions are in the global frame. Results are + the COM frame. When is_global is True, positions are in the global frame. Results are accumulated (added) into the composed buffers. .. note:: @@ -389,43 +404,40 @@ def add_forces_and_torques_at_position_index( Can be None if not provided. positions: Input array of position offsets for force application. Shape is (num_selected_envs, num_selected_bodies). Can be None if not provided. - link_poses: Input array of link frame poses in world frame. + body_link_pose_w: Body link poses in world frame (Tier 1 sim-bind). + Shape is (num_envs, num_bodies). + body_com_pos_b: Body COM offsets in link frame (static). Shape is (num_envs, num_bodies). is_global: Input flag indicating whether forces/torques/positions are in the global frame. - composed_forces_b: Output array where forces in the link frame are accumulated. + composed_forces_b: Output array where forces in the COM frame are accumulated. Shape is (num_envs, num_bodies). - composed_torques_b: Output array where torques in the link frame are accumulated. 
+ composed_torques_b: Output array where torques in the COM frame are accumulated. Shape is (num_envs, num_bodies). """ # get the thread id tid_env, tid_body = wp.tid() + # compute COM pose from link pose and body-frame COM offset + com_pose = _com_pose_from_link( + body_link_pose_w[env_ids[tid_env], body_ids[tid_body]], + body_com_pos_b[env_ids[tid_env], body_ids[tid_body]], + ) + # add the forces to the composed force, if the positions are provided, also adds a torque to the composed torque. if forces: # add the forces to the composed force - composed_forces_b[env_ids[tid_env], body_ids[tid_body]] += cast_force_to_link_frame( - forces[tid_env, tid_body], - wp.transform_get_rotation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, + composed_forces_b[env_ids[tid_env], body_ids[tid_body]] += cast_force_to_com_frame( + forces[tid_env, tid_body], com_pose, is_global ) # if there is a position offset, add a torque to the composed torque. if positions: - composed_torques_b[env_ids[tid_env], body_ids[tid_body]] += wp.skew( - cast_to_link_frame( - positions[tid_env, tid_body], - wp.transform_get_translation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, - ) - ) @ cast_force_to_link_frame( - forces[tid_env, tid_body], - wp.transform_get_rotation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, + composed_torques_b[env_ids[tid_env], body_ids[tid_body]] += wp.cross( + cast_to_com_frame(positions[tid_env, tid_body], com_pose, is_global), + cast_force_to_com_frame(forces[tid_env, tid_body], com_pose, is_global), ) if torques: - composed_torques_b[env_ids[tid_env], body_ids[tid_body]] += cast_torque_to_link_frame( - torques[tid_env, tid_body], - wp.transform_get_rotation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, + composed_torques_b[env_ids[tid_env], body_ids[tid_body]] += cast_torque_to_com_frame( + torques[tid_env, tid_body], com_pose, is_global ) @@ -436,7 +448,8 @@ def 
set_forces_and_torques_at_position_index( forces: wp.array2d(dtype=wp.vec3f), torques: wp.array2d(dtype=wp.vec3f), positions: wp.array2d(dtype=wp.vec3f), - link_poses: wp.array2d(dtype=wp.transformf), + body_link_pose_w: wp.array2d(dtype=wp.transformf), + body_com_pos_b: wp.array2d(dtype=wp.vec3f), is_global: bool, composed_forces_b: wp.array2d(dtype=wp.vec3f), composed_torques_b: wp.array2d(dtype=wp.vec3f), @@ -444,7 +457,7 @@ def set_forces_and_torques_at_position_index( """Set forces and torques to the composed wrench at user-provided positions using index selection. When is_global is False, the user-provided positions offset the force application relative to - the link frame. When is_global is True, positions are in the global frame. Results are + the COM frame. When is_global is True, positions are in the global frame. Results are overwritten (set) in the composed buffers. .. note:: @@ -459,45 +472,45 @@ def set_forces_and_torques_at_position_index( Can be None if not provided. positions: Input array of position offsets for force application. Shape is (num_selected_envs, num_selected_bodies). Can be None if not provided. - link_poses: Input array of link frame poses in world frame. + body_link_pose_w: Body link poses in world frame (Tier 1 sim-bind). + Shape is (num_envs, num_bodies). + body_com_pos_b: Body COM offsets in link frame (static). Shape is (num_envs, num_bodies). is_global: Input flag indicating whether forces/torques/positions are in the global frame. - composed_forces_b: Output array where forces in the link frame are written. + composed_forces_b: Output array where forces in the COM frame are written. Shape is (num_envs, num_bodies). - composed_torques_b: Output array where torques in the link frame are written. + composed_torques_b: Output array where torques in the COM frame are written. Shape is (num_envs, num_bodies). 
""" # get the thread id tid_env, tid_body = wp.tid() + # compute COM pose from link pose and body-frame COM offset + com_pose = _com_pose_from_link( + body_link_pose_w[env_ids[tid_env], body_ids[tid_body]], + body_com_pos_b[env_ids[tid_env], body_ids[tid_body]], + ) + # set the torques to the composed torque if torques: - composed_torques_b[env_ids[tid_env], body_ids[tid_body]] = cast_torque_to_link_frame( - torques[tid_env, tid_body], - wp.transform_get_rotation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, + composed_torques_b[env_ids[tid_env], body_ids[tid_body]] = cast_torque_to_com_frame( + torques[tid_env, tid_body], com_pose, is_global ) # set the forces to the composed force, if the positions are provided, adds a torque to the composed torque # from the force at that position. if forces: # set the forces to the composed force - composed_forces_b[env_ids[tid_env], body_ids[tid_body]] = cast_force_to_link_frame( - forces[tid_env, tid_body], - wp.transform_get_rotation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, + composed_forces_b[env_ids[tid_env], body_ids[tid_body]] = cast_force_to_com_frame( + forces[tid_env, tid_body], com_pose, is_global ) # if there is a position offset, set the torque from the force at that position. + # NOTE: this overwrites any explicit torque set above. If both torques and + # forces+positions are provided, the correct result should be τ_explicit + r × F + # (i.e. += instead of =). Pre-existing behavior — no caller currently passes both. 
if positions: - composed_torques_b[env_ids[tid_env], body_ids[tid_body]] = wp.skew( - cast_to_link_frame( - positions[tid_env, tid_body], - wp.transform_get_translation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, - ) - ) @ cast_force_to_link_frame( - forces[tid_env, tid_body], - wp.transform_get_rotation(link_poses[env_ids[tid_env], body_ids[tid_body]]), - is_global, + composed_torques_b[env_ids[tid_env], body_ids[tid_body]] = wp.cross( + cast_to_com_frame(positions[tid_env, tid_body], com_pose, is_global), + cast_force_to_com_frame(forces[tid_env, tid_body], com_pose, is_global), ) @@ -508,7 +521,8 @@ def add_forces_and_torques_at_position_mask( forces: wp.array2d(dtype=wp.vec3f), torques: wp.array2d(dtype=wp.vec3f), positions: wp.array2d(dtype=wp.vec3f), - link_poses: wp.array2d(dtype=wp.transformf), + body_link_pose_w: wp.array2d(dtype=wp.transformf), + body_com_pos_b: wp.array2d(dtype=wp.vec3f), is_global: bool, composed_forces_b: wp.array2d(dtype=wp.vec3f), composed_torques_b: wp.array2d(dtype=wp.vec3f), @@ -516,7 +530,7 @@ def add_forces_and_torques_at_position_mask( """Add forces and torques to the composed wrench at user-provided positions using mask selection. When is_global is False, the user-provided positions offset the force application relative to - the link frame. When is_global is True, positions are in the global frame. Results are + the COM frame. When is_global is True, positions are in the global frame. Results are accumulated (added) into the composed buffers. Only entries where both env_mask and body_mask are True are processed. @@ -532,39 +546,38 @@ def add_forces_and_torques_at_position_mask( Can be None if not provided. positions: Input array of position offsets for force application. Shape is (num_envs, num_bodies). Can be None if not provided. - link_poses: Input array of link frame poses in world frame. + body_link_pose_w: Body link poses in world frame (Tier 1 sim-bind). + Shape is (num_envs, num_bodies). 
+ body_com_pos_b: Body COM offsets in link frame (static). Shape is (num_envs, num_bodies). is_global: Input flag indicating whether forces/torques/positions are in the global frame. - composed_forces_b: Output array where forces in the link frame are accumulated. + composed_forces_b: Output array where forces in the COM frame are accumulated. Shape is (num_envs, num_bodies). - composed_torques_b: Output array where torques in the link frame are accumulated. + composed_torques_b: Output array where torques in the COM frame are accumulated. Shape is (num_envs, num_bodies). """ # get the thread id tid_env, tid_body = wp.tid() if env_mask[tid_env] and body_mask[tid_body]: + # compute COM pose from link pose and body-frame COM offset + com_pose = _com_pose_from_link(body_link_pose_w[tid_env, tid_body], body_com_pos_b[tid_env, tid_body]) # add the forces to the composed force, if the positions are provided, also adds a torque to the composed # torque. if forces: # add the forces to the composed force - composed_forces_b[tid_env, tid_body] += cast_force_to_link_frame( - forces[tid_env, tid_body], wp.transform_get_rotation(link_poses[tid_env, tid_body]), is_global + composed_forces_b[tid_env, tid_body] += cast_force_to_com_frame( + forces[tid_env, tid_body], com_pose, is_global ) # if there is a position offset, add a torque to the composed torque. 
if positions: - composed_torques_b[tid_env, tid_body] += wp.skew( - cast_to_link_frame( - positions[tid_env, tid_body], - wp.transform_get_translation(link_poses[tid_env, tid_body]), - is_global, - ) - ) @ cast_force_to_link_frame( - forces[tid_env, tid_body], wp.transform_get_rotation(link_poses[tid_env, tid_body]), is_global + composed_torques_b[tid_env, tid_body] += wp.cross( + cast_to_com_frame(positions[tid_env, tid_body], com_pose, is_global), + cast_force_to_com_frame(forces[tid_env, tid_body], com_pose, is_global), ) if torques: - composed_torques_b[tid_env, tid_body] += cast_torque_to_link_frame( - torques[tid_env, tid_body], wp.transform_get_rotation(link_poses[tid_env, tid_body]), is_global + composed_torques_b[tid_env, tid_body] += cast_torque_to_com_frame( + torques[tid_env, tid_body], com_pose, is_global ) @@ -575,7 +588,8 @@ def set_forces_and_torques_at_position_mask( forces: wp.array2d(dtype=wp.vec3f), torques: wp.array2d(dtype=wp.vec3f), positions: wp.array2d(dtype=wp.vec3f), - link_poses: wp.array2d(dtype=wp.transformf), + body_link_pose_w: wp.array2d(dtype=wp.transformf), + body_com_pos_b: wp.array2d(dtype=wp.vec3f), is_global: bool, composed_forces_b: wp.array2d(dtype=wp.vec3f), composed_torques_b: wp.array2d(dtype=wp.vec3f), @@ -583,7 +597,7 @@ def set_forces_and_torques_at_position_mask( """Set forces and torques to the composed wrench at user-provided positions using mask selection. When is_global is False, the user-provided positions offset the force application relative to - the link frame. When is_global is True, positions are in the global frame. Results are + the COM frame. When is_global is True, positions are in the global frame. Results are overwritten (set) in the composed buffers. Only entries where both env_mask and body_mask are True are processed. @@ -599,12 +613,14 @@ def set_forces_and_torques_at_position_mask( Can be None if not provided. positions: Input array of position offsets for force application. 
Shape is (num_envs, num_bodies). Can be None if not provided. - link_poses: Input array of link frame poses in world frame. + body_link_pose_w: Body link poses in world frame (Tier 1 sim-bind). + Shape is (num_envs, num_bodies). + body_com_pos_b: Body COM offsets in link frame (static). Shape is (num_envs, num_bodies). is_global: Input flag indicating whether forces/torques/positions are in the global frame. - composed_forces_b: Output array where forces in the link frame are written. + composed_forces_b: Output array where forces in the COM frame are written. Shape is (num_envs, num_bodies). - composed_torques_b: Output array where torques in the link frame are written. + composed_torques_b: Output array where torques in the COM frame are written. Shape is (num_envs, num_bodies). """ # get the thread id @@ -612,27 +628,25 @@ def set_forces_and_torques_at_position_mask( # set the torques to the composed torque if env_mask[tid_env] and body_mask[tid_body]: + # compute COM pose from link pose and body-frame COM offset + com_pose = _com_pose_from_link(body_link_pose_w[tid_env, tid_body], body_com_pos_b[tid_env, tid_body]) if torques: - composed_torques_b[tid_env, tid_body] = cast_torque_to_link_frame( - torques[tid_env, tid_body], wp.transform_get_rotation(link_poses[tid_env, tid_body]), is_global + composed_torques_b[tid_env, tid_body] = cast_torque_to_com_frame( + torques[tid_env, tid_body], com_pose, is_global ) # set the forces to the composed force, if the positions are provided, adds a torque to the composed torque # from the force at that position. if forces: # set the forces to the composed force - composed_forces_b[tid_env, tid_body] = cast_force_to_link_frame( - forces[tid_env, tid_body], wp.transform_get_rotation(link_poses[tid_env, tid_body]), is_global + composed_forces_b[tid_env, tid_body] = cast_force_to_com_frame( + forces[tid_env, tid_body], com_pose, is_global ) # if there is a position offset, set the torque from the force at that position. 
+ # NOTE: same overwrite caveat as the _index variant — see comment there. if positions: - composed_torques_b[tid_env, tid_body] = wp.skew( - cast_to_link_frame( - positions[tid_env, tid_body], - wp.transform_get_translation(link_poses[tid_env, tid_body]), - is_global, - ) - ) @ cast_force_to_link_frame( - forces[tid_env, tid_body], wp.transform_get_rotation(link_poses[tid_env, tid_body]), is_global + composed_torques_b[tid_env, tid_body] = wp.cross( + cast_to_com_frame(positions[tid_env, tid_body], com_pose, is_global), + cast_force_to_com_frame(forces[tid_env, tid_body], com_pose, is_global), ) diff --git a/source/isaaclab/isaaclab/utils/warp/utils.py b/source/isaaclab/isaaclab/utils/warp/utils.py new file mode 100644 index 00000000000..2df288dfba2 --- /dev/null +++ b/source/isaaclab/isaaclab/utils/warp/utils.py @@ -0,0 +1,151 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import functools +from collections.abc import Sequence + +import torch +import warp as wp + +## +# Mask resolution - ids/mask to warp boolean mask. +## + + +@wp.kernel +def _populate_mask_from_ids( + mask: wp.array(dtype=wp.bool), + ids: wp.array(dtype=wp.int32), +): + i = wp.tid() + mask[ids[i]] = True + + +def resolve_1d_mask( + *, + ids: Sequence[int] | slice | torch.Tensor | wp.array | None = None, + mask: wp.array | torch.Tensor | None = None, + all_mask: wp.array, + scratch_mask: wp.array, + device: str, +) -> wp.array: + """Resolve ids/mask into a warp boolean mask. + + Callers provide pre-allocated ``all_mask`` (all-True) and ``scratch_mask`` (reusable + work buffer) so resolving *ids* (or ``None``) never allocates a new mask buffer. + + Args: + ids: Index ids. Accepts ``Sequence[int]``, ``slice``, ``torch.Tensor``, + ``wp.array(dtype=wp.int32)``, or ``None`` (all elements). + mask: Direct boolean mask.
``wp.array`` is returned as-is; + ``torch.Tensor`` is converted. + all_mask: Pre-allocated all-True mask returned when both *ids* and *mask* + are ``None``. + scratch_mask: Pre-allocated scratch buffer populated in-place when *ids* + are provided. Not re-entrant (shared buffer). + device: Warp device string (e.g. ``"cuda:0"``). + + Returns: + A ``wp.array(dtype=wp.bool)`` mask. + """ + # Normalize slice(None) to None so the capture guard treats it identically to ids=None. + if isinstance(ids, slice) and ids == slice(None): + ids = None + + if wp.get_device().is_capturing: + if ids is not None or (mask is not None and not isinstance(mask, wp.array)): + raise RuntimeError( + "resolve_1d_mask is only capturable when mask is a wp.array or both ids and mask are None." + ) + + # --- Direct mask input --- + if mask is not None: + if isinstance(mask, wp.array): + return mask + if isinstance(mask, torch.Tensor): + if mask.dtype != torch.bool: + mask = mask.to(dtype=torch.bool) + if str(mask.device) != device: + mask = mask.to(device) + return wp.from_torch(mask, dtype=wp.bool) + raise TypeError(f"Unsupported mask type: {type(mask)}") + + # --- Fast path: all elements --- + if ids is None: + return all_mask + + # --- Normalize slice to list --- + if isinstance(ids, slice): + start, stop, step = ids.indices(scratch_mask.shape[0]) + ids = list(range(start, stop, step)) + + # --- Normalize to concrete type --- + if not isinstance(ids, (torch.Tensor, wp.array)): + ids = list(ids) + + # --- Populate scratch mask --- + scratch_mask.fill_(False) + + if isinstance(ids, torch.Tensor): + if ids.numel() == 0: + return scratch_mask + if str(ids.device) != device: + ids = ids.to(device) + if ids.dtype != torch.int32: + ids = ids.to(dtype=torch.int32) + if not ids.is_contiguous(): + ids = ids.contiguous() + ids_wp = wp.from_torch(ids, dtype=wp.int32) + elif isinstance(ids, wp.array): + if ids.shape[0] == 0: + return scratch_mask + if ids.dtype != wp.int32: + raise TypeError(f"Unsupported 
wp.array dtype for ids: {ids.dtype}. Expected wp.int32 index array.") + ids_wp = ids + else: + if len(ids) == 0: + return scratch_mask + ids_wp = wp.array(ids, dtype=wp.int32, device=device) + + wp.launch(_populate_mask_from_ids, dim=ids_wp.shape[0], inputs=[scratch_mask, ids_wp], device=device) + return scratch_mask + + +## +# Capture safety — property guard. +## + + +def capture_unsafe(reason: str | None = None): + """Mark a callable as not CUDA-graph-capture-safe. + + Raises ``RuntimeError`` if the decorated callable is invoked while + ``wp.get_device().is_capturing`` is ``True``. + + Args: + reason: Optional explanation appended to the error message. + + Usage:: + + @property + @capture_unsafe("Relies on a Python timestamp guard.") + def projected_gravity_b(self) -> wp.array: ... + """ + + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if wp.get_device().is_capturing: + msg = f"'{func.__qualname__}' cannot be called during CUDA graph capture." + if reason: + msg = f"{msg} {reason}" + raise RuntimeError(msg) + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/source/isaaclab/isaaclab/utils/wrench_composer.py b/source/isaaclab/isaaclab/utils/wrench_composer.py index 5ad966a6e4e..cecfbf933f1 100644 --- a/source/isaaclab/isaaclab/utils/wrench_composer.py +++ b/source/isaaclab/isaaclab/utils/wrench_composer.py @@ -45,9 +45,18 @@ def __init__(self, asset: BaseArticulation | BaseRigidObject | BaseRigidObjectCo self._asset = asset self._active = False - # Avoid isinstance here due to potential circular import issues; check by attribute presence instead. - if hasattr(self._asset.data, "body_link_pose_w"): - self._get_link_pose_fn = lambda a=self._asset: a.data.body_link_pose_w + # Store references to Tier 1 (sim-bind) buffers for COM pose computation. + # We intentionally avoid caching body_com_pose_w (a Tier 2 derived property) because + # it is lazily computed via a Python timestamp guard. 
Saving the .data pointer at init + # time would freeze it at the initial value — subsequent steps would read stale COM + # world poses since nothing triggers the lazy recomputation. Instead, we keep the two + # Tier 1 inputs (body_link_pose_w and body_com_pos_b) and let the wrench kernels + # compute the COM pose inline. This is both correct in eager mode and CUDA-graph- + # capture safe (Tier 1 buffers are stable sim-bind pointers updated by the solver). + data = self._asset.data + if hasattr(data, "body_link_pose_w") and hasattr(data, "body_com_pos_b"): + self._body_link_pose_w = data.body_link_pose_w + self._body_com_pos_b = data.body_com_pos_b else: raise ValueError(f"Unsupported asset type: {self._asset.__class__.__name__}") @@ -68,9 +77,6 @@ def __init__(self, asset: BaseArticulation | BaseRigidObject | BaseRigidObjectCo self._temp_forces_wp = wp.zeros((self.num_envs, self.num_bodies), dtype=wp.vec3f, device=self.device) self._temp_torques_wp = wp.zeros((self.num_envs, self.num_bodies), dtype=wp.vec3f, device=self.device) - # Flag to check if the link poses have been updated. 
- self._link_poses_updated = False - @property def active(self) -> bool: """Whether the wrench composer is active.""" @@ -148,11 +154,6 @@ def add_forces_and_torques_index( stacklevel=2, ) return - # Get the link poses - if not self._link_poses_updated: - self._link_poses = self._get_link_pose_fn() - self._link_poses_updated = True - # Set the active flag to true self._active = True @@ -165,7 +166,8 @@ def add_forces_and_torques_index( forces, torques, positions, - self._link_poses, + self._body_link_pose_w, + self._body_com_pos_b, is_global, ], outputs=[ @@ -219,11 +221,6 @@ def set_forces_and_torques_index( stacklevel=2, ) return - # Get the link poses - if not self._link_poses_updated: - self._link_poses = self._get_link_pose_fn() - self._link_poses_updated = True - # Set the active flag to true self._active = True @@ -236,7 +233,8 @@ def set_forces_and_torques_index( forces, torques, positions, - self._link_poses, + self._body_link_pose_w, + self._body_com_pos_b, is_global, ], outputs=[ @@ -290,10 +288,6 @@ def add_forces_and_torques_mask( stacklevel=2, ) return - # Get the link poses - if not self._link_poses_updated: - self._link_poses = self._get_link_pose_fn() - self._link_poses_updated = True # Set the active flag to true self._active = True @@ -307,7 +301,8 @@ def add_forces_and_torques_mask( forces, torques, positions, - self._link_poses, + self._body_link_pose_w, + self._body_com_pos_b, is_global, ], outputs=[ @@ -357,10 +352,6 @@ def set_forces_and_torques_mask( stacklevel=2, ) return - # Get the link poses - if not self._link_poses_updated: - self._link_poses = self._get_link_pose_fn() - self._link_poses_updated = True # Set the active flag to true self._active = True @@ -374,7 +365,8 @@ def set_forces_and_torques_mask( forces, torques, positions, - self._link_poses, + self._body_link_pose_w, + self._body_com_pos_b, is_global, ], outputs=[ @@ -388,8 +380,6 @@ def reset(self, env_ids: wp.array | torch.Tensor | None = None, env_mask: wp.arr """Reset the 
composed force and torque. This function will reset the composed force and torque to zero. - It will also make sure the link positions and quaternions are updated in the next call of the - `add_forces_and_torques` or `set_forces_and_torques` functions. .. note:: This function should be called after every simulation step / reset to ensure no force is carried over to the next step. @@ -401,11 +391,7 @@ def reset(self, env_ids: wp.array | torch.Tensor | None = None, env_mask: wp.arr env_ids: Environment indices. Defaults to None (all environments). env_mask: Environment mask. Defaults to None (all environments). """ - if env_ids is None and env_mask is None: - self._composed_force_b.zero_() - self._composed_torque_b.zero_() - self._active = False - elif env_mask is not None: + if env_mask is not None: wp.launch( reset_wrench_composer_mask, dim=(self.num_envs, self.num_bodies), @@ -418,8 +404,8 @@ def reset(self, env_ids: wp.array | torch.Tensor | None = None, env_mask: wp.arr ], device=self.device, ) - else: - if env_ids is None or env_ids == slice(None): + elif env_ids is not None: + if env_ids == slice(None): env_ids = self._ALL_ENV_INDICES elif isinstance(env_ids, list): env_ids = wp.array(env_ids, dtype=wp.int32, device=self.device) @@ -437,7 +423,10 @@ def reset(self, env_ids: wp.array | torch.Tensor | None = None, env_mask: wp.arr ], device=self.device, ) - self._link_poses_updated = False + else: + self._composed_force_b.zero_() + self._composed_torque_b.zero_() + self._active = False """ Deprecated functions. 
diff --git a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py index 4dc9507dbe5..8d90bf81928 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py +++ b/source/isaaclab_newton/isaaclab_newton/assets/articulation/articulation_data.py @@ -15,6 +15,7 @@ from isaaclab.assets.articulation.base_articulation_data import BaseArticulationData from isaaclab.utils.buffers import TimestampedBufferWarp as TimestampedBuffer from isaaclab.utils.math import normalize +from isaaclab.utils.warp.utils import capture_unsafe from isaaclab_newton.assets import kernels as shared_kernels from isaaclab_newton.assets.articulation import kernels as articulation_kernels @@ -26,6 +27,13 @@ # import logger logger = logging.getLogger(__name__) +_LAZY_CAPTURE_REASON = ( + "This is a lazily-computed derived property guarded by a Python timestamp check " + "that is invisible during graph replay. Use Tier 1 base data (root_link_pose_w, " + "root_com_vel_w, body_link_pose_w, body_com_vel_w, joint_pos, joint_vel) and " + "inline the computation in your warp kernel. See GRAPH_CAPTURE_MIGRATION.md." +) + class ArticulationData(BaseArticulationData): """Data container for an articulation. @@ -553,6 +561,7 @@ def root_link_pose_w(self) -> wp.array: return self._sim_bind_root_link_pose_w @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def root_link_vel_w(self) -> wp.array: """Root link velocity ``[lin_vel, ang_vel]`` in simulation world frame. @@ -580,6 +589,7 @@ def root_link_vel_w(self) -> wp.array: return self._root_link_vel_w.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def root_com_pose_w(self) -> wp.array: """Root center of mass pose ``[pos, quat]`` in simulation world frame. 
@@ -654,6 +664,7 @@ def body_link_pose_w(self) -> wp.array: return self._sim_bind_body_link_pose_w @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def body_link_vel_w(self) -> wp.array: """Body link velocity ``[lin_vel, ang_vel]`` in simulation world frame. @@ -682,6 +693,7 @@ def body_link_vel_w(self) -> wp.array: return self._body_link_vel_w.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def body_com_pose_w(self) -> wp.array: """Body center of mass pose ``[pos, quat]`` in simulation world frame. @@ -860,6 +872,7 @@ def joint_acc(self) -> wp.array: """ @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def projected_gravity_b(self): """Projection of the gravity direction on base frame. @@ -877,6 +890,7 @@ def projected_gravity_b(self): return self._projected_gravity_b.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def heading_w(self): """Yaw heading of the base frame (in radians). @@ -898,6 +912,7 @@ def heading_w(self): return self._heading_w.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def root_link_lin_vel_b(self) -> wp.array: """Root link linear velocity in base frame. 
diff --git a/source/isaaclab_newton/isaaclab_newton/assets/rigid_object/rigid_object_data.py b/source/isaaclab_newton/isaaclab_newton/assets/rigid_object/rigid_object_data.py index 7a8c85a5775..87c916974c2 100644 --- a/source/isaaclab_newton/isaaclab_newton/assets/rigid_object/rigid_object_data.py +++ b/source/isaaclab_newton/isaaclab_newton/assets/rigid_object/rigid_object_data.py @@ -15,6 +15,7 @@ from isaaclab.assets.rigid_object.base_rigid_object_data import BaseRigidObjectData from isaaclab.utils.buffers import TimestampedBufferWarp as TimestampedBuffer from isaaclab.utils.math import normalize +from isaaclab.utils.warp.utils import capture_unsafe from isaaclab_newton.assets import kernels as shared_kernels from isaaclab_newton.physics import NewtonManager as SimulationManager @@ -26,6 +27,13 @@ # import logger logger = logging.getLogger(__name__) +_LAZY_CAPTURE_REASON = ( + "This is a lazily-computed derived property guarded by a Python timestamp check " + "that is invisible during graph replay. Use Tier 1 base data (root_link_pose_w, " + "root_com_vel_w, body_link_pose_w, body_com_vel_w) and inline the computation " + "in your warp kernel. See GRAPH_CAPTURE_MIGRATION.md." +) + class RigidObjectData(BaseRigidObjectData): """Data container for a rigid object. @@ -188,6 +196,7 @@ def root_link_pose_w(self) -> wp.array: return self._sim_bind_root_link_pose_w @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def root_link_vel_w(self) -> wp.array: """Root link velocity ``[lin_vel, ang_vel]`` in simulation world frame. @@ -215,6 +224,7 @@ def root_link_vel_w(self) -> wp.array: return self._root_link_vel_w.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def root_com_pose_w(self) -> wp.array: """Root center of mass pose ``[pos, quat]`` in simulation world frame. 
@@ -283,6 +293,7 @@ def body_link_pose_w(self) -> wp.array: return self._sim_bind_body_link_pose_w @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def body_link_vel_w(self) -> wp.array: """Body link velocity ``[lin_vel, ang_vel]`` in simulation world frame. @@ -293,6 +304,7 @@ def body_link_vel_w(self) -> wp.array: return self.root_link_vel_w.reshape((self._num_instances, 1)) @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def body_com_pose_w(self) -> wp.array: """Body center of mass pose ``[pos, quat]`` in simulation world frame. @@ -382,6 +394,7 @@ def body_com_pose_b(self) -> wp.array: """ @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def projected_gravity_b(self) -> wp.array: """Projection of the gravity direction on base frame. @@ -399,6 +412,7 @@ def projected_gravity_b(self) -> wp.array: return self._projected_gravity_b.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def heading_w(self) -> wp.array: """Yaw heading of the base frame (in radians). @@ -420,6 +434,7 @@ def heading_w(self) -> wp.array: return self._heading_w.data @property + @capture_unsafe(_LAZY_CAPTURE_REASON) def root_link_lin_vel_b(self) -> wp.array: """Root link linear velocity in base frame. From 2e3f6ade165a2ee1cebcf98d68ebc153f67a3b1c Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Thu, 12 Mar 2026 00:30:39 -0700 Subject: [PATCH 5/7] Add warp env configs and task-specific MDP terms Add manager-based warp env configs for classic (Ant, Humanoid), locomotion velocity (A1, AnymalB/C/D, Cassie, G1, Go1/2, H1), and manipulation reach (Franka, UR10). Include task-specific MDP terms for humanoid observations/rewards and locomotion rewards/terminations/ curriculums. 
--- .../manager_based/classic/__init__.py | 8 +- .../manager_based/classic/ant/__init__.py | 30 ++ .../manager_based/classic/ant/ant_env_cfg.py | 198 +++++++++++ .../classic/cartpole/__init__.py | 23 +- .../classic/humanoid/__init__.py | 30 ++ .../classic/humanoid/humanoid_env_cfg.py | 233 +++++++++++++ .../classic/humanoid/mdp/__init__.py | 11 + .../classic/humanoid/mdp/observations.py | 179 ++++++++++ .../classic/humanoid/mdp/rewards.py | 314 ++++++++++++++++++ .../manager_based/locomotion/__init__.py | 6 + .../locomotion/velocity/__init__.py | 6 + .../locomotion/velocity/config/__init__.py | 9 + .../locomotion/velocity/config/a1/__init__.py | 37 +++ .../velocity/config/a1/flat_env_cfg.py | 60 ++++ .../velocity/config/a1/rough_env_cfg.py | 91 +++++ .../velocity/config/anymal_b/__init__.py | 37 +++ .../velocity/config/anymal_b/flat_env_cfg.py | 60 ++++ .../velocity/config/anymal_b/rough_env_cfg.py | 34 ++ .../velocity/config/anymal_c/__init__.py | 39 +++ .../velocity/config/anymal_c/flat_env_cfg.py | 52 +++ .../velocity/config/anymal_c/rough_env_cfg.py | 40 +++ .../velocity/config/anymal_d/__init__.py | 60 ++++ .../velocity/config/anymal_d/flat_env_cfg.py | 46 +++ .../velocity/config/anymal_d/rough_env_cfg.py | 37 +++ .../velocity/config/cassie/__init__.py | 35 ++ .../velocity/config/cassie/flat_env_cfg.py | 45 +++ .../velocity/config/cassie/rough_env_cfg.py | 96 ++++++ .../locomotion/velocity/config/g1/__init__.py | 58 ++++ .../velocity/config/g1/flat_env_cfg.py | 58 ++++ .../velocity/config/g1/rough_env_cfg.py | 178 ++++++++++ .../velocity/config/g1_29_dofs/__init__.py | 35 ++ .../config/g1_29_dofs/flat_env_cfg.py | 58 ++++ .../config/g1_29_dofs/rough_env_cfg.py | 131 ++++++++ .../velocity/config/go1/__init__.py | 35 ++ .../velocity/config/go1/flat_env_cfg.py | 46 +++ .../velocity/config/go1/rough_env_cfg.py | 61 ++++ .../velocity/config/go2/__init__.py | 35 ++ .../velocity/config/go2/flat_env_cfg.py | 46 +++ .../velocity/config/go2/rough_env_cfg.py | 60 ++++ 
.../locomotion/velocity/config/h1/__init__.py | 57 ++++ .../velocity/config/h1/flat_env_cfg.py | 46 +++ .../velocity/config/h1/rough_env_cfg.py | 131 ++++++++ .../locomotion/velocity/mdp/__init__.py | 12 + .../locomotion/velocity/mdp/curriculums.py | 40 +++ .../locomotion/velocity/mdp/rewards.py | 307 +++++++++++++++++ .../locomotion/velocity/mdp/terminations.py | 66 ++++ .../locomotion/velocity/velocity_env_cfg.py | 296 +++++++++++++++++ .../manager_based/manipulation/__init__.py | 6 + .../manipulation/reach/__init__.py | 6 + .../manipulation/reach/config/__init__.py | 4 + .../reach/config/franka/__init__.py | 41 +++ .../reach/config/franka/joint_pos_env_cfg.py | 74 +++++ .../reach/config/ur_10/__init__.py | 36 ++ .../reach/config/ur_10/joint_pos_env_cfg.py | 74 +++++ .../manipulation/reach/mdp/__init__.py | 10 + .../manipulation/reach/mdp/rewards.py | 166 +++++++++ .../manipulation/reach/reach_env_cfg.py | 206 ++++++++++++ 57 files changed, 4183 insertions(+), 12 deletions(-) create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/humanoid_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/observations.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/mdp/rewards.py create mode 100644 
source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py create mode 100644 
source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py 
create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/flat_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/rough_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/rewards.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/terminations.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/velocity_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py create mode 100644 
source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py create mode 100644 source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py index 4781f141af4..79c13e2aa8f 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/__init__.py @@ -3,4 +3,10 @@ # # SPDX-License-Identifier: BSD-3-Clause -"""Classic experimental task registrations (manager-based).""" +"""Classic environments for control. + +These environments are based on the MuJoCo environments provided by OpenAI. 
+ +Reference: + https://github.com/openai/gym/tree/master/gym/envs/mujoco +""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py new file mode 100644 index 00000000000..5f123abaa75 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +""" +Ant locomotion environment (experimental manager-based entry point). +""" + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.classic.ant import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Ant-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.ant_env_cfg:AntEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AntPPORunnerCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py new file mode 100644 index 00000000000..106d1a78ba0 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/ant/ant_env_cfg.py @@ -0,0 +1,198 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). 
+# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# Ant reuses humanoid's experimental MDP (mirrors stable pattern). +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +import isaaclab.sim as sim_utils +from isaaclab.assets import AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.sim import SimulationCfg +from isaaclab.terrains import TerrainImporterCfg +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.classic.humanoid.mdp as mdp + +## +# Pre-defined configs +## +from isaaclab_assets.robots.ant import ANT_CFG # isort: skip + + +@configclass +class MySceneCfg(InteractiveSceneCfg): + """Configuration for the terrain scene with an ant robot.""" + + # terrain + terrain = TerrainImporterCfg( + prim_path="/World/ground", + terrain_type="plane", + collision_group=-1, + physics_material=sim_utils.RigidBodyMaterialCfg( + friction_combine_mode="average", + restitution_combine_mode="average", + static_friction=1.0, + dynamic_friction=1.0, + restitution=0.0, + ), + debug_vis=False, + ) + + # robot + robot = ANT_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + + # lights + light = AssetBaseCfg( + prim_path="/World/light", + spawn=sim_utils.DistantLightCfg(color=(0.75, 0.75, 0.75), intensity=3000.0), + ) + + +## +# MDP settings +## + + +@configclass +class ActionsCfg: + """Action specifications for the MDP.""" + + joint_effort = mdp.JointEffortActionCfg(asset_name="robot", joint_names=[".*"], scale=7.5) + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + 
@configclass
class RewardsCfg:
    """Reward terms for the MDP.

    Positive weights are bonuses, negative weights are penalties. The ``progress`` and
    ``move_to_target`` terms share the same ``target_pos`` of ``(1000.0, 0.0, 0.0)``.
    """

    # (1) Reward for moving forward
    progress = RewTerm(func=mdp.progress_reward, weight=1.0, params={"target_pos": (1000.0, 0.0, 0.0)})
    # (2) Stay alive bonus
    alive = RewTerm(func=mdp.is_alive, weight=0.5)
    # (3) Reward for upright posture (bonus is granted while the up projection exceeds the threshold)
    upright = RewTerm(func=mdp.upright_posture_bonus, weight=0.1, params={"threshold": 0.93})
    # (4) Reward for moving in the right direction
    move_to_target = RewTerm(
        func=mdp.move_to_target_bonus, weight=0.5, params={"threshold": 0.8, "target_pos": (1000.0, 0.0, 0.0)}
    )
    # (5) Penalty for large action commands
    action_l2 = RewTerm(func=mdp.action_l2, weight=-0.005)
    # (6) Penalty for energy consumption
    energy = RewTerm(func=mdp.power_consumption, weight=-0.05, params={"gear_ratio": {".*": 15.0}})
    # (7) Penalty for reaching close to joint limits
    joint_pos_limits = RewTerm(
        func=mdp.joint_pos_limits_penalty_ratio, weight=-0.1, params={"threshold": 0.99, "gear_ratio": {".*": 15.0}}
    )
b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py index 17a4c5c03cd..4e332426494 100644 --- a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/cartpole/__init__.py @@ -9,21 +9,22 @@ import gymnasium as gym +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.classic.cartpole import agents + +## +# Register Gym environments. +## + gym.register( id="Isaac-Cartpole-Warp-v0", entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", disable_env_checker=True, kwargs={ - # Use experimental Cartpole cfg (allows isolated modifications). - "env_cfg_entry_point": ( - "isaaclab_tasks_experimental.manager_based.classic.cartpole.cartpole_env_cfg:CartpoleEnvCfg" - ), - # Point agent configs to the existing task package. - "rl_games_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:rl_games_ppo_cfg.yaml", - "rsl_rl_cfg_entry_point": ( - "isaaclab_tasks.manager_based.classic.cartpole.agents.rsl_rl_ppo_cfg:CartpolePPORunnerCfg" - ), - "skrl_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:skrl_ppo_cfg.yaml", - "sb3_cfg_entry_point": "isaaclab_tasks.manager_based.classic.cartpole.agents:sb3_ppo_cfg.yaml", + "env_cfg_entry_point": f"{__name__}.cartpole_env_cfg:CartpoleEnvCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CartpolePPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", }, ) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/classic/humanoid/__init__.py new file 
##
# Register Gym environments.
##

# Registers the Warp manager-based Humanoid task under the id "Isaac-Humanoid-Warp-v0".
# The environment configuration entry point refers to this package's
# ``humanoid_env_cfg`` module (built from ``__name__``); the agent configuration
# entry points are built from ``agents.__name__``, i.e. the ``agents`` package
# imported from the stable ``isaaclab_tasks`` humanoid task.
gym.register(
    id="Isaac-Humanoid-Warp-v0",
    entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": f"{__name__}.humanoid_env_cfg:HumanoidEnvCfg",
        "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:HumanoidPPORunnerCfg",
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
        "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
        "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
    },
)
@configclass
class ActionsCfg:
    """Action specifications for the MDP."""

    # Joint-effort action on the "robot" asset, selecting joints with the catch-all
    # pattern ".*". The scale is a table keyed by joint-name pattern; the same
    # pattern/value table is passed as "gear_ratio" to the energy and joint-limit
    # reward terms of this file's RewardsCfg.
    # NOTE(review): keys look like regular expressions matched against joint names --
    # confirm against JointEffortActionCfg.
    joint_effort = mdp.JointEffortActionCfg(
        asset_name="robot",
        joint_names=[".*"],
        scale={
            ".*_waist.*": 67.5,
            ".*_upper_arm.*": 67.5,
            "pelvis": 67.5,
            ".*_lower_arm": 45.0,
            ".*_thigh:0": 45.0,
            ".*_thigh:1": 135.0,
            ".*_thigh:2": 45.0,
            ".*_shin": 90.0,
            ".*_foot.*": 22.5,
        },
    )
"""Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for the policy.""" + + base_height = ObsTerm(func=mdp.base_pos_z) + base_lin_vel = ObsTerm(func=mdp.base_lin_vel) + base_ang_vel = ObsTerm(func=mdp.base_ang_vel, scale=0.25) + base_yaw_roll = ObsTerm(func=mdp.base_yaw_roll) + base_angle_to_target = ObsTerm(func=mdp.base_angle_to_target, params={"target_pos": (1000.0, 0.0, 0.0)}) + base_up_proj = ObsTerm(func=mdp.base_up_proj) + base_heading_proj = ObsTerm(func=mdp.base_heading_proj, params={"target_pos": (1000.0, 0.0, 0.0)}) + joint_pos_norm = ObsTerm(func=mdp.joint_pos_limit_normalized) + joint_vel_rel = ObsTerm(func=mdp.joint_vel_rel, scale=0.1) + actions = ObsTerm(func=mdp.last_action) + + def __post_init__(self): + self.enable_corruption = False + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + reset_base = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={"pose_range": {}, "velocity_range": {}}, + ) + + reset_robot_joints = EventTerm( + func=mdp.reset_joints_by_offset, + mode="reset", + params={ + "position_range": (-0.2, 0.2), + "velocity_range": (-0.1, 0.1), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # (1) Reward for moving forward + progress = RewTerm(func=mdp.progress_reward, weight=1.0, params={"target_pos": (1000.0, 0.0, 0.0)}) + # (2) Stay alive bonus + alive = RewTerm(func=mdp.is_alive, weight=2.0) + # (3) Reward for non-upright posture + upright = RewTerm(func=mdp.upright_posture_bonus, weight=0.1, params={"threshold": 0.93}) + # (4) Reward for moving in the right direction + move_to_target = RewTerm( + func=mdp.move_to_target_bonus, weight=0.5, params={"threshold": 0.8, "target_pos": (1000.0, 0.0, 0.0)} + ) + # (5) Penalty for large action commands + action_l2 = RewTerm(func=mdp.action_l2, weight=-0.01) + # (6) 
@configclass
class HumanoidEnvCfg(ManagerBasedRLEnvCfg):
    """Configuration for the MuJoCo-style Humanoid walking environment.

    The simulation settings are declared as a class-level field (the same layout as
    ``AntEnvCfg``) rather than being rebuilt inside :meth:`__post_init__`. Rebuilding
    ``self.sim`` there would silently discard any ``sim`` value supplied through the
    constructor or a subclass, because ``__post_init__`` runs after field
    initialization.
    """

    # Simulation settings (Newton physics with the MJWarp solver)
    sim: SimulationCfg = SimulationCfg(
        physics=NewtonCfg(
            solver_cfg=MJWarpSolverCfg(
                njmax=80,
                nconmax=25,
                ls_iterations=15,
                cone="pyramidal",
                ls_parallel=True,
                update_data_interval=2,
                impratio=1,
                integrator="implicitfast",
            ),
            num_substeps=2,
            debug_mode=False,
        )
    )

    # Scene settings
    scene: MySceneCfg = MySceneCfg(num_envs=4096, env_spacing=5.0, clone_in_fabric=True)
    # Basic settings
    observations: ObservationsCfg = ObservationsCfg()
    actions: ActionsCfg = ActionsCfg()
    # MDP settings
    rewards: RewardsCfg = RewardsCfg()
    terminations: TerminationsCfg = TerminationsCfg()
    events: EventCfg = EventCfg()

    def __post_init__(self):
        """Post initialization."""
        # general settings
        self.decimation = 2
        self.episode_length_s = 16.0
        # simulation settings
        self.sim.dt = 1 / 120.0
        # render once per environment step (every `decimation` physics steps)
        self.sim.render_interval = self.decimation
        # default friction material
        self.sim.physics_material.static_friction = 1.0
        self.sim.physics_material.dynamic_friction = 1.0
@wp.kernel
def _base_yaw_roll_kernel(
    root_quat_w: wp.array(dtype=wp.quatf),
    out: wp.array(dtype=wp.float32, ndim=2),
):
    """Write yaw into ``out[:, 0]`` and roll into ``out[:, 1]`` for every environment.

    The root quaternion components are read in (x, y, z, w) order.
    """
    env_id = wp.tid()
    quat = root_quat_w[env_id]
    x = quat[0]
    y = quat[1]
    z = quat[2]
    w = quat[3]
    # yaw = atan2(2*(w*z + x*y), 1 - 2*(y^2 + z^2))
    out[env_id, 0] = wp.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))
    # roll = atan2(2*(w*x + y*z), 1 - 2*(x^2 + y^2))
    out[env_id, 1] = wp.atan2(2.0 * (w * x + y * z), 1.0 - 2.0 * (x * x + y * y))


@generic_io_descriptor_warp(out_dim=2, observation_type="RootState")
def base_yaw_roll(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None:
    """Yaw and roll of the base in the simulation world frame. Shape: (num_envs, 2).

    Args:
        env: Environment providing the scene, the number of environments and the device.
        out: Pre-allocated output array the kernel writes into (column 0: yaw, column 1: roll).
        asset_cfg: Scene entity whose root orientation is read. Defaults to ``"robot"``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    kernel_inputs = [robot.data.root_quat_w, out]
    # one thread per environment
    wp.launch(_base_yaw_roll_kernel, dim=env.num_envs, inputs=kernel_inputs, device=env.device)
@wp.kernel
def _base_up_proj_kernel(
    root_pose_w: wp.array(dtype=wp.transformf),
    gravity_w: wp.array(dtype=wp.vec3f),
    out: wp.array(dtype=wp.float32, ndim=2),
):
    """Write the negated z component of the body-frame gravity vector into ``out[:, 0]``.

    Only the first entry of ``gravity_w`` is read; it is rotated into each
    environment's body frame using that environment's root pose.
    """
    env_id = wp.tid()
    gravity_b = rotate_vec_to_body_frame(gravity_w[0], root_pose_w[env_id])
    out[env_id, 0] = -gravity_b[2]


@generic_io_descriptor_warp(out_dim=1, observation_type="RootState")
def base_up_proj(env: ManagerBasedEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> None:
    """Projection of the base up vector onto the world up vector. Shape: (num_envs, 1).

    Args:
        env: Environment providing the scene, the number of environments and the device.
        out: Pre-allocated output array the kernel writes into.
        asset_cfg: Scene entity whose root link pose is read. Defaults to ``"robot"``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    kernel_inputs = [robot.data.root_link_pose_w, robot.data.GRAVITY_VEC_W, out]
    # one thread per environment
    wp.launch(_base_up_proj_kernel, dim=env.num_envs, inputs=kernel_inputs, device=env.device)
@wp.kernel
def _base_angle_to_target_kernel(
    root_pos_w: wp.array(dtype=wp.vec3f),
    root_quat_w: wp.array(dtype=wp.quatf),
    target_x: float,
    target_y: float,
    out: wp.array(dtype=wp.float32, ndim=2),
):
    """Write the wrapped angle between the base yaw and the bearing to the target into ``out[:, 0]``.

    The result is wrapped with ``atan2(sin(a), cos(a))``; quaternion components are
    read in (x, y, z, w) order.
    """
    env_id = wp.tid()
    root_pos = root_pos_w[env_id]
    quat = root_quat_w[env_id]
    # bearing from the base to the target in the world xy-plane
    bearing = wp.atan2(target_y - root_pos[1], target_x - root_pos[0])
    # yaw extracted from the root quaternion
    x = quat[0]
    y = quat[1]
    z = quat[2]
    w = quat[3]
    yaw = wp.atan2(2.0 * (w * z + x * y), 1.0 - 2.0 * (y * y + z * z))
    # wrap the difference through sin/cos
    delta = bearing - yaw
    out[env_id, 0] = wp.atan2(wp.sin(delta), wp.cos(delta))


@generic_io_descriptor_warp(out_dim=1, observation_type="RootState")
def base_angle_to_target(
    env: ManagerBasedEnv,
    out,
    target_pos: tuple[float, float, float],
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> None:
    """Angle between the base forward vector and the vector to the target. Shape: (num_envs, 1).

    Args:
        env: Environment providing the scene, the number of environments and the device.
        out: Pre-allocated output array the kernel writes into.
        target_pos: Target position; only the x and y components are passed to the kernel.
        asset_cfg: Scene entity whose root position and orientation are read. Defaults to ``"robot"``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    goal_x, goal_y = target_pos[0], target_pos[1]
    kernel_inputs = [robot.data.root_pos_w, robot.data.root_quat_w, goal_x, goal_y, out]
    # one thread per environment
    wp.launch(_base_angle_to_target_kernel, dim=env.num_envs, inputs=kernel_inputs, device=env.device)
@wp.kernel
def _upright_posture_bonus_kernel(
    root_pose_w: wp.array(dtype=wp.transformf),
    gravity_w: wp.array(dtype=wp.vec3f),
    threshold: float,
    out: wp.array(dtype=wp.float32),
):
    """Write 1.0 where the up projection exceeds ``threshold`` and 0.0 elsewhere.

    The up projection is the negated z component of ``gravity_w[0]`` rotated into the
    body frame given by each environment's root pose.
    """
    env_id = wp.tid()
    gravity_b = rotate_vec_to_body_frame(gravity_w[0], root_pose_w[env_id])
    up_proj = -gravity_b[2]
    if up_proj > threshold:
        out[env_id] = 1.0
    else:
        out[env_id] = 0.0
@wp.kernel
def _move_to_target_bonus_kernel(
    root_pos_w: wp.array(dtype=wp.vec3f),
    root_quat_w: wp.array(dtype=wp.quatf),
    target_x: float,
    target_y: float,
    threshold: float,
    out: wp.array(dtype=wp.float32),
):
    """Write the heading bonus: 1.0 above ``threshold``, otherwise ``heading_proj / threshold``.

    ``heading_proj`` is the xy dot product between the base forward axis and the unit
    direction from the base to the target.
    """
    env_id = wp.tid()
    root_pos = root_pos_w[env_id]
    quat = root_quat_w[env_id]
    # planar offset from the base to the target
    dx = target_x - root_pos[0]
    dy = target_y - root_pos[1]
    dist = wp.sqrt(dx * dx + dy * dy)
    # the direction collapses to zero when the base is (numerically) on the target
    inv_dist = float(0.0)
    if dist > 1.0e-6:
        inv_dist = 1.0 / dist
    dir_x = dx * inv_dist
    dir_y = dy * inv_dist
    # base forward axis: the body x-axis rotated into the world frame
    forward = wp.quat_rotate(quat, wp.vec3f(1.0, 0.0, 0.0))
    heading_proj = forward[0] * dir_x + forward[1] * dir_y
    if heading_proj > threshold:
        out[env_id] = 1.0
    else:
        out[env_id] = heading_proj / threshold


def move_to_target_bonus(
    env: ManagerBasedRLEnv,
    out,
    threshold: float,
    target_pos: tuple[float, float, float],
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> None:
    """Reward for heading towards the target.

    Args:
        env: Environment providing the scene, the number of environments and the device.
        out: Pre-allocated output array the kernel writes into (one value per environment).
        threshold: Heading projection above which the full bonus of 1.0 is written.
        target_pos: Target position; only the x and y components are passed to the kernel.
        asset_cfg: Scene entity whose root position and orientation are read. Defaults to ``"robot"``.
    """
    robot: Articulation = env.scene[asset_cfg.name]
    goal_x, goal_y = target_pos[0], target_pos[1]
    kernel_inputs = [robot.data.root_pos_w, robot.data.root_quat_w, goal_x, goal_y, threshold, out]
    # one thread per environment
    wp.launch(_move_to_target_bonus_kernel, dim=env.num_envs, inputs=kernel_inputs, device=env.device)
class progress_reward(ManagerTermBase):
    """Reward for making progress towards the target (potential-based).

    One potential per environment is kept in :attr:`potentials`. Each call computes the
    new potential ``-distance_to_target / step_dt`` and writes the difference to the
    previously stored potential into ``out``.

    NOTE(review): ``_progress_reward_reset_kernel`` seeds the potential from the 3D
    distance (x, y, z), whereas ``_progress_reward_kernel`` uses the xy distance only.
    The first reward after a masked reset therefore contains the difference between
    both distances. Confirm whether this is intended.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        """Allocate the potential buffer and cache the target and asset configuration.

        Args:
            cfg: Reward term configuration; ``cfg.params["target_pos"]`` is required and
                ``cfg.params["asset_cfg"]`` is optional (defaults to the ``"robot"`` entity).
            env: Environment the term belongs to.
        """
        super().__init__(cfg, env)
        self.potentials = wp.zeros(env.num_envs, dtype=wp.float32, device=env.device)
        self._target_pos = cfg.params["target_pos"]
        # Use the configured asset in reset() as well, so that reset() and __call__()
        # read the same articulation when "asset_cfg" is overridden.
        self._asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))

    def reset(self, env_mask: wp.array | None = None) -> None:
        """Re-seed the potentials of the masked environments.

        Args:
            env_mask: Boolean mask over environments. When ``None``, all potentials are
                set to zero instead of being recomputed.
        """
        if env_mask is None:
            # NOTE(review): with zeroed potentials the next call writes the full
            # potential (-distance / step_dt) as reward -- confirm that callers never
            # rely on reset(None) between steps.
            self.potentials.zero_()
            return
        asset: Articulation = self._env.scene[self._asset_cfg.name]
        inv_dt = 1.0 / self._env.step_dt
        wp.launch(
            kernel=_progress_reward_reset_kernel,
            dim=self.num_envs,
            inputs=[
                env_mask,
                asset.data.root_pos_w,
                self._target_pos[0],
                self._target_pos[1],
                self._target_pos[2],
                inv_dt,
                self.potentials,
            ],
            device=self.device,
        )

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        out,
        target_pos: tuple[float, float, float],
        asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    ) -> None:
        """Write the change in potential since the previous call into ``out``.

        Args:
            env: Environment providing the scene, step duration and device.
            out: Pre-allocated output array (one value per environment).
            target_pos: Target position; only the x and y components are passed to the kernel.
            asset_cfg: Scene entity whose root position is read. Defaults to ``"robot"``.
        """
        asset: Articulation = env.scene[asset_cfg.name]
        inv_dt = 1.0 / env.step_dt
        wp.launch(
            kernel=_progress_reward_kernel,
            dim=env.num_envs,
            inputs=[asset.data.root_pos_w, target_pos[0], target_pos[1], inv_dt, self.potentials, out],
            device=env.device,
        )
class joint_pos_limits_penalty_ratio(ManagerTermBase):
    """Penalty for violating joint position limits weighted by the gear ratio.

    The per-joint gear ratios are resolved once at construction, normalized by their
    maximum and stored as a Warp array that is handed to the kernel on every call.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        """Resolve and cache the normalized gear ratio of every joint.

        Args:
            cfg: Reward term configuration; ``cfg.params["gear_ratio"]`` maps joint-name
                patterns to gear ratios and ``cfg.params["asset_cfg"]`` is optional
                (defaults to the ``"robot"`` entity).
            env: Environment the term belongs to.
        """
        # Initialize the base class like the other class-based terms in this module do,
        # so that the state it sets up is available on this term as well.
        super().__init__(cfg, env)

        asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
        asset: Articulation = env.scene[asset_cfg.name]

        # resolve the gear ratio for each joint (torch in __init__ is fine)
        gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device)
        index_list, _, value_list = string_utils.resolve_matching_names_values(
            cfg.params["gear_ratio"], asset.joint_names
        )
        gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device)
        # normalize so that the largest gear ratio becomes 1.0
        gear_ratio_scaled = gear_ratio / torch.max(gear_ratio)
        self._gear_ratio_scaled_wp = wp.from_torch(gear_ratio_scaled)

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        out,
        threshold: float,
        gear_ratio: dict[str, float],
        asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    ) -> None:
        """Write the gear-ratio-weighted joint-limit violation of every environment into ``out``.

        Args:
            env: Environment providing the scene, the number of environments and the device.
            out: Pre-allocated output array (one value per environment).
            threshold: Normalized joint position above which a violation is counted.
                Must differ from 1.0, otherwise ``1.0 / (1.0 - threshold)`` raises
                ``ZeroDivisionError``.
            gear_ratio: Not read here; the ratios cached at construction are used.
            asset_cfg: Scene entity whose joint state is read. Defaults to ``"robot"``.
        """
        asset: Articulation = env.scene[asset_cfg.name]
        wp.launch(
            kernel=_joint_pos_limits_penalty_ratio_kernel,
            dim=env.num_envs,
            inputs=[
                asset.data.joint_pos,
                asset.data.soft_joint_pos_limits,
                self._gear_ratio_scaled_wp,
                threshold,
                # scales the excess over the threshold by the remaining range
                1.0 / (1.0 - threshold),
                out,
            ],
            device=env.device,
        )
class power_consumption(ManagerTermBase):
    """Penalty for the power consumed by the actions to the environment.

    The per-joint gear ratios are resolved once at construction, normalized by their
    maximum and stored as a Warp array that is handed to the kernel on every call.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        """Resolve and cache the normalized gear ratio of every joint.

        Args:
            cfg: Reward term configuration; ``cfg.params["gear_ratio"]`` maps joint-name
                patterns to gear ratios and ``cfg.params["asset_cfg"]`` is optional
                (defaults to the ``"robot"`` entity).
            env: Environment the term belongs to.
        """
        # Initialize the base class like the other class-based terms in this module do,
        # so that the state it sets up is available on this term as well.
        super().__init__(cfg, env)

        asset_cfg = cfg.params.get("asset_cfg", SceneEntityCfg("robot"))
        asset: Articulation = env.scene[asset_cfg.name]

        # resolve the gear ratio for each joint (torch in __init__ is fine)
        gear_ratio = torch.ones(env.num_envs, asset.num_joints, device=env.device)
        index_list, _, value_list = string_utils.resolve_matching_names_values(
            cfg.params["gear_ratio"], asset.joint_names
        )
        gear_ratio[:, index_list] = torch.tensor(value_list, device=env.device)
        # normalize so that the largest gear ratio becomes 1.0
        gear_ratio_scaled = gear_ratio / torch.max(gear_ratio)
        self._gear_ratio_scaled_wp = wp.from_torch(gear_ratio_scaled)

    def __call__(
        self,
        env: ManagerBasedRLEnv,
        out,
        gear_ratio: dict[str, float],
        asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
    ) -> None:
        """Write the summed ``|action * joint_vel * gear_ratio_scaled|`` of every environment into ``out``.

        Args:
            env: Environment providing the scene, the action manager and the device.
            out: Pre-allocated output array (one value per environment).
            gear_ratio: Not read here; the ratios cached at construction are used.
            asset_cfg: Scene entity whose joint velocities are read. Defaults to ``"robot"``.
        """
        asset: Articulation = env.scene[asset_cfg.name]
        wp.launch(
            kernel=_power_consumption_kernel,
            dim=env.num_envs,
            inputs=[env.action_manager.action, asset.data.joint_vel, self._gear_ratio_scaled_wp, out],
            device=env.device,
        )
(https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Locomotion experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py new file mode 100644 index 00000000000..0857176a3fc --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Velocity locomotion experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py new file mode 100644 index 00000000000..26f3257daef --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Configurations for velocity-based locomotion environments.""" + +# We leave this file empty since we don't want to expose any configs in this package directly. +# We still need this file to import the "config" module in the parent package. 
diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py new file mode 100644 index 00000000000..6c79524e853 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.a1 import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Unitree-A1-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Unitree-A1-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeA1FlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeA1FlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py 
b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py new file mode 100644 index 00000000000..b27f8098d62 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/flat_env_cfg.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import UnitreeA1RoughEnvCfg + + +@configclass +class UnitreeA1FlatEnvCfg(UnitreeA1RoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=60, + nconmax=30, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # override rewards + self.rewards.flat_orientation_l2.weight = -2.5 + self.rewards.feet_air_time.weight = 0.25 + + # change terrain to flat + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + # no height scan + # self.scene.height_scanner = None + # self.observations.policy.height_scan = None + # no terrain curriculum + self.curriculum.terrain_levels = None + + +class UnitreeA1FlatEnvCfg_PLAY(UnitreeA1FlatEnvCfg): + def __post_init__(self) -> None: + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing event + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git 
a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py new file mode 100644 index 00000000000..03aec88a401 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/a1/rough_env_cfg.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm + +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + TerminationsCfg, +) + +from isaaclab_assets.robots.unitree import UNITREE_A1_CFG # isort: skip + + +class TerminationsCfg_A1(TerminationsCfg): + base_too_low = DoneTerm(func=mdp.root_height_below_minimum, params={"minimum_height": 0.2}) + + +@configclass +class UnitreeA1RoughEnvCfg(LocomotionVelocityRoughEnvCfg): + terminations: TerminationsCfg_A1 = TerminationsCfg_A1() + + def __post_init__(self): + # post init of parent + super().__post_init__() + + self.scene.robot = UNITREE_A1_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.scene.terrain.terrain_generator.sub_terrains["boxes"].grid_height_range = (0.025, 0.1) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_range = (0.01, 0.06) + self.scene.terrain.terrain_generator.sub_terrains["random_rough"].noise_step = 0.01 + + # reduce action scale + self.actions.joint_pos.scale = 0.25 + + # event + self.events.push_robot = None + # TODO: TEMPORARILY DISABLED - adding this causes NaNs in the simulation + # 
self.events.add_base_mass.params["mass_distribution_params"] = (-1.0, 3.0) + # self.events.add_base_mass.params["asset_cfg"].body_names = "trunk" + self.events.base_external_force_torque.params["asset_cfg"].body_names = "trunk" + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.reset_base.params = { + "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + + # rewards + self.rewards.feet_air_time.params["sensor_cfg"].body_names = ".*_foot" + self.rewards.feet_air_time.weight = 0.01 + self.rewards.undesired_contacts.params["sensor_cfg"].body_names = ".*thigh" + self.rewards.dof_torques_l2.weight = -0.0002 + self.rewards.track_lin_vel_xy_exp.weight = 1.5 + self.rewards.track_ang_vel_z_exp.weight = 0.75 + self.rewards.dof_acc_l2.weight = -2.5e-7 + self.terminations.base_contact.params["sensor_cfg"].body_names = "trunk" + + +@configclass +class UnitreeA1RoughEnvCfg_PLAY(UnitreeA1RoughEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # spawn the robot randomly in the grid (instead of their terrain levels) + self.scene.terrain.max_init_terrain_level = None + # reduce the number of terrains to save memory + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + + # disable randomization for play + self.observations.policy.enable_corruption = False + # remove random pushing event + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git 
a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py new file mode 100644 index 00000000000..cbbf5290e82 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.anymal_b import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Anymal-B-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalBFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerWithSymmetryCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-B-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalBFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalBFlatPPORunnerWithSymmetryCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git 
a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py new file mode 100644 index 00000000000..28c0dc5c26b --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/flat_env_cfg.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import AnymalBRoughEnvCfg + + +@configclass +class AnymalBFlatEnvCfg(AnymalBRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=75, + nconmax=15, + cone="elliptic", + impratio=100, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # override rewards + self.rewards.flat_orientation_l2.weight = -5.0 + self.rewards.dof_torques_l2.weight = -2.5e-5 + self.rewards.feet_air_time.weight = 0.5 + # change terrain to flat + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + # no height scan + self.scene.height_scanner = None + self.observations.policy.height_scan = None + # no terrain curriculum + self.curriculum.terrain_levels = None + + +class AnymalBFlatEnvCfg_PLAY(AnymalBFlatEnvCfg): + def __post_init__(self) -> None: + # post init of parent + super().__post_init__() + + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + 
self.observations.policy.enable_corruption = False + # remove random pushing event + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py new file mode 100644 index 00000000000..9811356ef22 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_b/rough_env_cfg.py @@ -0,0 +1,34 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.utils import configclass + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +from isaaclab_assets import ANYMAL_B_CFG # isort: skip + + +@configclass +class AnymalBRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.robot = ANYMAL_B_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.manager_call_max_mode = {"Scene": 1} + + +@configclass +class AnymalBRoughEnvCfg_PLAY(AnymalBRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git 
a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py new file mode 100644 index 00000000000..318b13cc470 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.anymal_c import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Anymal-C-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalCFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerWithSymmetryCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_flat_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-C-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalCFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerCfg", + "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalCFlatPPORunnerWithSymmetryCfg", + "rl_games_cfg_entry_point": 
f"{agents.__name__}:rl_games_flat_ppo_cfg.yaml", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py new file mode 100644 index 00000000000..e82cf9559d6 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/flat_env_cfg.py @@ -0,0 +1,52 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import AnymalCRoughEnvCfg + + +@configclass +class AnymalCFlatEnvCfg(AnymalCRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=120, + nconmax=15, + cone="elliptic", + impratio=100, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # override rewards + self.rewards.flat_orientation_l2.weight = -5.0 + self.rewards.dof_torques_l2.weight = -2.5e-5 + self.rewards.feet_air_time.weight = 0.5 + # change terrain to flat + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + # no terrain curriculum + self.curriculum.terrain_levels = None + + +class AnymalCFlatEnvCfg_PLAY(AnymalCFlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = 
None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py new file mode 100644 index 00000000000..36af75613c0 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_c/rough_env_cfg.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.utils import configclass + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +## +# Pre-defined configs +## +from isaaclab_assets.robots.anymal import ANYMAL_C_CFG # isort: skip + + +@configclass +class AnymalCRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + # switch robot to anymal-c + self.scene.robot = ANYMAL_C_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.scene.robot.actuators["legs"].armature = 0.01 + self.manager_call_max_mode = {"Scene": 1} + + +@configclass +class AnymalCRoughEnvCfg_PLAY(AnymalCRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git 
a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py new file mode 100644 index 00000000000..e5e75d19dc2 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/__init__.py @@ -0,0 +1,60 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.anymal_d import agents + +## +# Register Gym environments. +## + +# Rough env disabled: requires isaaclab_physx which is not yet available on dev/newton. +# The package exists on upstream/develop (commit 308400f1d35) but has not been merged. +# Re-enable once dev/newton picks up isaaclab_physx. 
+# gym.register( +# id="Isaac-Velocity-Rough-Anymal-D-Warp-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:AnymalDRoughEnvCfg", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDRoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +# gym.register( +# id="Isaac-Velocity-Rough-Anymal-D-Warp-Play-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.rough_env_cfg:AnymalDRoughEnvCfg_PLAY", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDRoughPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml", +# }, +# ) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-D-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Anymal-D-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:AnymalDFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:AnymalDFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py new file mode 100644 index 
00000000000..1f98c1b5612 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/flat_env_cfg.py @@ -0,0 +1,46 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import AnymalDRoughEnvCfg + + +@configclass +class AnymalDFlatEnvCfg(AnymalDRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=60, + nconmax=25, + cone="elliptic", + impratio=100.0, + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.rewards.flat_orientation_l2.weight = -5.0 + self.rewards.dof_torques_l2.weight = -2.5e-5 + self.rewards.feet_air_time.weight = 0.5 + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + + +class AnymalDFlatEnvCfg_PLAY(AnymalDFlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py new file mode 100644 index 00000000000..1cafa006ee6 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/anymal_d/rough_env_cfg.py @@ -0,0 +1,37 @@ +# Copyright (c) 
2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab.utils import configclass + +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg + +## +# Pre-defined configs +## +from isaaclab_assets.robots.anymal import ANYMAL_D_CFG # isort: skip + + +@configclass +class AnymalDRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.robot = ANYMAL_D_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.manager_call_max_mode = {"Scene": 1} + + +@configclass +class AnymalDRoughEnvCfg_PLAY(AnymalDRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py new file mode 100644 index 00000000000..4d9d4a77883 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/__init__.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.locomotion.velocity.config.cassie import agents + +## +# Register Gym environments. +## + +gym.register( + id="Isaac-Velocity-Flat-Cassie-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:CassieFlatEnvCfg", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CassieFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Velocity-Flat-Cassie-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.flat_env_cfg:CassieFlatEnvCfg_PLAY", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CassieFlatPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py new file mode 100644 index 00000000000..bcc670cf378 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/flat_env_cfg.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +from .rough_env_cfg import CassieRoughEnvCfg + + +@configclass +class CassieFlatEnvCfg(CassieRoughEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=52, + nconmax=15, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + super().__post_init__() + self.rewards.flat_orientation_l2.weight = -2.5 + self.rewards.feet_air_time.weight = 5.0 + self.rewards.joint_deviation_hip.params["asset_cfg"].joint_names = ["hip_rotation_.*"] + self.scene.terrain.terrain_type = "plane" + self.scene.terrain.terrain_generator = None + self.curriculum.terrain_levels = None + + +class CassieFlatEnvCfg_PLAY(CassieFlatEnvCfg): + def __post_init__(self) -> None: + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.observations.policy.enable_corruption = False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py new file mode 100644 index 00000000000..43341babeb3 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/cassie/rough_env_cfg.py @@ -0,0 +1,96 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg + +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp +from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import ( + LocomotionVelocityRoughEnvCfg, + RewardsCfg, +) + +from isaaclab_assets.robots.cassie import CASSIE_CFG # isort: skip + + +@configclass +class CassieRewardsCfg(RewardsCfg): + termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0) + feet_air_time = RewTerm( + func=mdp.feet_air_time_positive_biped, + weight=2.5, + params={ + "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*toe"), + "command_name": "base_velocity", + "threshold": 0.3, + }, + ) + joint_deviation_hip = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["hip_abduction_.*", "hip_rotation_.*"])}, + ) + joint_deviation_toes = RewTerm( + func=mdp.joint_deviation_l1, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", joint_names=["toe_joint_.*"])}, + ) + dof_pos_limits = RewTerm( + func=mdp.joint_pos_limits, + weight=-1.0, + params={"asset_cfg": SceneEntityCfg("robot", joint_names="toe_joint_.*")}, + ) + + +@configclass +class CassieRoughEnvCfg(LocomotionVelocityRoughEnvCfg): + rewards: CassieRewardsCfg = CassieRewardsCfg() + + def __post_init__(self): + super().__post_init__() + self.scene.robot = CASSIE_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + self.actions.joint_pos.scale = 0.5 + self.events.push_robot = None + # TODO: TEMPORARILY DISABLED - adding this causes NaNs in the simulation + # self.events.add_base_mass = None + self.events.reset_robot_joints.params["position_range"] = (1.0, 1.0) + self.events.base_external_force_torque.params["asset_cfg"].body_names = [".*pelvis"] + self.events.reset_base.params = { + 
"pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)}, + "velocity_range": { + "x": (0.0, 0.0), + "y": (0.0, 0.0), + "z": (0.0, 0.0), + "roll": (0.0, 0.0), + "pitch": (0.0, 0.0), + "yaw": (0.0, 0.0), + }, + } + self.events.base_com = None + self.terminations.base_contact.params["sensor_cfg"].body_names = [".*pelvis"] + self.rewards.undesired_contacts = None + self.rewards.dof_torques_l2.weight = -5.0e-6 + self.rewards.track_lin_vel_xy_exp.weight = 2.0 + self.rewards.track_ang_vel_z_exp.weight = 1.0 + self.rewards.action_rate_l2.weight *= 1.5 + self.rewards.dof_acc_l2.weight *= 1.5 + + +@configclass +class CassieRoughEnvCfg_PLAY(CassieRoughEnvCfg): + def __post_init__(self): + super().__post_init__() + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + self.scene.terrain.max_init_terrain_level = None + if self.scene.terrain.terrain_generator is not None: + self.scene.terrain.terrain_generator.num_rows = 5 + self.scene.terrain.terrain_generator.num_cols = 5 + self.scene.terrain.terrain_generator.curriculum = False + self.observations.policy.enable_corruption = False + self.events.base_external_force_torque = None + self.events.push_robot = None diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py new file mode 100644 index 00000000000..83bdf047a48 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. 
from isaaclab_tasks.manager_based.locomotion.velocity.config.g1 import agents

##
# Register Gym environments.
##

# The rough-terrain registrations below are kept commented out, so only the flat variants exist.
# gym.register(
#     id="Isaac-Velocity-Rough-G1-Warp-v0",
#     entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
#     disable_env_checker=True,
#     kwargs={
#         "env_cfg_entry_point": f"{__name__}.rough_env_cfg:G1RoughEnvCfg",
#         "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1RoughPPORunnerCfg",
#         "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml",
#     },
# )


# gym.register(
#     id="Isaac-Velocity-Rough-G1-Warp-Play-v0",
#     entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
#     disable_env_checker=True,
#     kwargs={
#         "env_cfg_entry_point": f"{__name__}.rough_env_cfg:G1RoughEnvCfg_PLAY",
#         "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1RoughPPORunnerCfg",
#         "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml",
#     },
# )

# Values shared by both flat registrations.
_ENTRY_POINT = "isaaclab_experimental.envs:ManagerBasedRLEnvWarp"
_AGENT_CFGS = {
    "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1FlatPPORunnerCfg",
    "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
}

# Training environment.
gym.register(
    id="Isaac-Velocity-Flat-G1-Warp-v0",
    entry_point=_ENTRY_POINT,
    disable_env_checker=True,
    kwargs={"env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1FlatEnvCfg", **_AGENT_CFGS},
)

# Play environment (same agent configs, ``_PLAY`` env config).
gym.register(
    id="Isaac-Velocity-Flat-G1-Warp-Play-v0",
    entry_point=_ENTRY_POINT,
    disable_env_checker=True,
    kwargs={"env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1FlatEnvCfg_PLAY", **_AGENT_CFGS},
)

# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py
# new
# file mode 100644
# index 00000000000..3d9047a9850
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/flat_env_cfg.py
# @@ -0,0 +1,58 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab_experimental.managers import SceneEntityCfg
from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg

from isaaclab.sim import SimulationCfg
from isaaclab.utils import configclass

from .rough_env_cfg import G1RoughEnvCfg


@configclass
class G1FlatEnvCfg(G1RoughEnvCfg):
    """Flat-terrain variant of :class:`G1RoughEnvCfg` running on Newton with the MJWarp solver."""

    # NOTE(review): njmax / nconmax are presumably solver capacity limits sized for this robot;
    # confirm against the MJWarpSolverCfg documentation before changing them.
    sim: SimulationCfg = SimulationCfg(
        physics=NewtonCfg(
            solver_cfg=MJWarpSolverCfg(
                njmax=95,
                nconmax=10,
                cone="pyramidal",
                impratio=1,
                integrator="implicitfast",
            ),
            num_substeps=1,
            debug_mode=False,
        )
    )

    def __post_init__(self):
        """Run the rough-terrain post-init, then switch to a plane and retune rewards/commands."""
        super().__post_init__()

        # Replace the generated terrain with a plane and drop the terrain curriculum.
        terrain = self.scene.terrain
        terrain.terrain_type = "plane"
        terrain.terrain_generator = None
        self.curriculum.terrain_levels = None

        # Reward weights for the flat task.
        rewards = self.rewards
        rewards.track_ang_vel_z_exp.weight = 1.0
        rewards.lin_vel_z_l2.weight = -0.2
        rewards.action_rate_l2.weight = -0.005
        rewards.dof_acc_l2.weight = -1.0e-7
        rewards.feet_air_time.weight = 0.75
        rewards.feet_air_time.params["threshold"] = 0.4
        rewards.dof_torques_l2.weight = -2.0e-6
        rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint"]
        )

        # Command ranges.
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (0.0, 1.0)
        ranges.lin_vel_y = (-0.5, 0.5)
        ranges.ang_vel_z = (-1.0, 1.0)


# Decorated with ``@configclass`` so this class is processed the same way as the
# rough-terrain ``*_PLAY`` configs, which all carry the decorator.
@configclass
class G1FlatEnvCfg_PLAY(G1FlatEnvCfg):
    """Play variant of :class:`G1FlatEnvCfg` with a smaller scene and fewer perturbations."""

    def __post_init__(self) -> None:
        """Run the parent post-init, then shrink the scene and disable corruption and pushes."""
        super().__post_init__()
        self.scene.num_envs = 50
        self.scene.env_spacing = 2.5
        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py
# new file mode 100644
# index 00000000000..db4cc159e4c
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1/rough_env_cfg.py
# @@ -0,0 +1,178 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab_experimental.managers import RewardTermCfg as RewTerm
from isaaclab_experimental.managers import SceneEntityCfg

from isaaclab.utils import configclass

import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp
from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import (
    LocomotionVelocityRoughEnvCfg,
    RewardsCfg,
)

##
# Pre-defined configs
##
from isaaclab_assets import G1_MINIMAL_CFG  # isort: skip


@configclass
class G1Rewards(RewardsCfg):
    """Reward terms for the MDP."""

    termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)

    # Velocity tracking terms driven by the "base_velocity" command.
    track_lin_vel_xy_exp = RewTerm(
        func=mdp.track_lin_vel_xy_yaw_frame_exp,
        weight=1.0,
        params={"command_name": "base_velocity", "std": 0.5},
    )
    track_ang_vel_z_exp = RewTerm(
        func=mdp.track_ang_vel_z_world_exp,
        weight=2.0,
        params={"command_name": "base_velocity", "std": 0.5},
    )

    # Foot terms, evaluated on bodies matching ".*_ankle_roll_link".
    feet_air_time = RewTerm(
        func=mdp.feet_air_time_positive_biped,
        weight=0.25,
        params={
            "command_name": "base_velocity",
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            "threshold": 0.4,
        },
    )
    feet_slide = RewTerm(
        func=mdp.feet_slide,
        weight=-0.1,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            "asset_cfg": SceneEntityCfg("robot", body_names=".*_ankle_roll_link"),
        },
    )

    # Joint position limit term restricted to the ankle joints.
    dof_pos_limits = RewTerm(
        func=mdp.joint_pos_limits,
        weight=-1.0,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"])},
    )

    # L1 joint-deviation terms, one per joint group.
    joint_deviation_hip = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"])},
    )
    joint_deviation_arms = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={
            "asset_cfg": SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_shoulder_pitch_joint",
                    ".*_shoulder_roll_joint",
                    ".*_shoulder_yaw_joint",
                    ".*_elbow_pitch_joint",
                    ".*_elbow_roll_joint",
                ],
            )
        },
    )
    joint_deviation_fingers = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.05,
        params={
            "asset_cfg": SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_five_joint",
                    ".*_three_joint",
                    ".*_six_joint",
                    ".*_four_joint",
                    ".*_zero_joint",
                    ".*_one_joint",
                    ".*_two_joint",
                ],
            )
        },
    )
    joint_deviation_torso = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names="torso_joint")},
    )


@configclass
class G1RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
    """Rough-terrain velocity-tracking configuration using the G1 (minimal) robot asset."""

    rewards: G1Rewards = G1Rewards()

    def __post_init__(self):
        """Run the parent post-init, then apply the G1-specific overrides."""
        super().__post_init__()

        # Scene
        self.scene.robot = G1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")

        # Events
        events = self.events
        events.push_robot = None
        # TODO: TEMPORARILY DISABLED - adding this causes NaNs in the simulation
        # self.events.add_base_mass = None
        events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        events.base_external_force_torque.params["asset_cfg"].body_names = ["torso_link"]
        events.reset_base.params = {
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            "velocity_range": {
                "x": (0.0, 0.0),
                "y": (0.0, 0.0),
                "z": (0.0, 0.0),
                "roll": (0.0, 0.0),
                "pitch": (0.0, 0.0),
                "yaw": (0.0, 0.0),
            },
        }
        events.base_com = None

        # Rewards
        rewards = self.rewards
        rewards.lin_vel_z_l2.weight = 0.0
        rewards.undesired_contacts = None
        rewards.flat_orientation_l2.weight = -1.0
        rewards.action_rate_l2.weight = -0.005
        rewards.dof_acc_l2.weight = -1.25e-7
        rewards.dof_acc_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint"]
        )
        rewards.dof_torques_l2.weight = -1.5e-7
        rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint", ".*_ankle_.*"]
        )

        # Commands
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (0.0, 1.0)
        ranges.lin_vel_y = (-0.0, 0.0)
        ranges.ang_vel_z = (-1.0, 1.0)

        # Terminations
        self.terminations.base_contact.params["sensor_cfg"].body_names = "torso_link"


@configclass
class G1RoughEnvCfg_PLAY(G1RoughEnvCfg):
    """Play variant of :class:`G1RoughEnvCfg`."""

    def __post_init__(self):
        """Run the parent post-init, then apply the play-mode overrides."""
        # post init of parent
        super().__post_init__()

        # make a smaller scene for play
        scene = self.scene
        scene.num_envs = 50
        scene.env_spacing = 2.5
        self.episode_length_s = 40.0
        # spawn the robot randomly in the grid (instead of their terrain levels)
        scene.terrain.max_init_terrain_level = None
        # reduce the number of terrains to save memory
        generator = scene.terrain.terrain_generator
        if generator is not None:
            generator.num_rows = 5
            generator.num_cols = 5
            generator.curriculum = False

        # fixed forward speed and heading; only the yaw rate command varies
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (1.0, 1.0)
        ranges.lin_vel_y = (0.0, 0.0)
        ranges.ang_vel_z = (-1.0, 1.0)
        ranges.heading = (0.0, 0.0)
        # disable randomization for play
        self.observations.policy.enable_corruption = False
        # remove random pushing
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py
# new file mode 100644
# index 00000000000..a0ae516387a
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/__init__.py
# @@ -0,0 +1,35 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import gymnasium as gym

# Reuse agent configs from the stable task package.
from isaaclab_tasks.manager_based.locomotion.velocity.config.g1_29_dofs import agents

##
# Register Gym environments.
##

gym.register(
    id="Isaac-Velocity-Flat-G1-Warp-v1",
    entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1_29_DOFs_FlatEnvCfg",
        "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1_29_DOFs_FlatPPORunnerCfg",
        "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
    },
)

gym.register(
    id="Isaac-Velocity-Flat-G1-Warp-Play-v1",
    entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": f"{__name__}.flat_env_cfg:G1_29_DOFs_FlatEnvCfg_PLAY",
        "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:G1_29_DOFs_FlatPPORunnerCfg",
        "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
    },
)

# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py
# b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py
# new file mode 100644
# index 00000000000..d02b2a35e75
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/flat_env_cfg.py
# @@ -0,0 +1,58 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab_experimental.managers import SceneEntityCfg
from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg

from isaaclab.sim import SimulationCfg
from isaaclab.utils import configclass

from .rough_env_cfg import G1_29_DOFs_RoughEnvCfg


@configclass
class G1_29_DOFs_FlatEnvCfg(G1_29_DOFs_RoughEnvCfg):
    """Flat-terrain variant of :class:`G1_29_DOFs_RoughEnvCfg` on Newton with the MJWarp solver."""

    # NOTE(review): njmax / nconmax are presumably solver capacity limits sized for this robot;
    # confirm against the MJWarpSolverCfg documentation before changing them.
    sim: SimulationCfg = SimulationCfg(
        physics=NewtonCfg(
            solver_cfg=MJWarpSolverCfg(
                njmax=210,
                nconmax=35,
                cone="pyramidal",
                impratio=1,
                integrator="implicitfast",
            ),
            num_substeps=1,
            debug_mode=False,
        )
    )

    def __post_init__(self):
        """Run the rough-terrain post-init, then switch to a plane and retune rewards/commands."""
        super().__post_init__()

        # Replace the generated terrain with a plane and drop the terrain curriculum.
        terrain = self.scene.terrain
        terrain.terrain_type = "plane"
        terrain.terrain_generator = None
        self.curriculum.terrain_levels = None

        # Reward weights for the flat task.
        rewards = self.rewards
        rewards.track_ang_vel_z_exp.weight = 1.0
        rewards.lin_vel_z_l2.weight = -0.2
        rewards.action_rate_l2.weight = -0.005
        rewards.dof_acc_l2.weight = -1.0e-7
        rewards.feet_air_time.weight = 0.75
        rewards.feet_air_time.params["threshold"] = 0.4
        rewards.dof_torques_l2.weight = -2.0e-6
        rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint"]
        )

        # Command ranges.
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (-1.0, 1.0)
        ranges.lin_vel_y = (-0.5, 0.5)
        ranges.ang_vel_z = (-1.0, 1.0)


# Decorated with ``@configclass`` so this class is processed the same way as the
# other decorated config classes in this package.
@configclass
class G1_29_DOFs_FlatEnvCfg_PLAY(G1_29_DOFs_FlatEnvCfg):
    """Play variant of :class:`G1_29_DOFs_FlatEnvCfg` with a smaller scene and fewer perturbations."""

    def __post_init__(self) -> None:
        """Run the parent post-init, then shrink the scene and disable corruption and pushes."""
        super().__post_init__()
        self.scene.num_envs = 50
        self.scene.env_spacing = 2.5
        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py
# new file mode 100644
# index 00000000000..e1ca70d7e05
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/g1_29_dofs/rough_env_cfg.py
# @@ -0,0 +1,131 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab_experimental.managers import RewardTermCfg as RewTerm
from isaaclab_experimental.managers import SceneEntityCfg

from isaaclab.utils import configclass

import isaaclab_tasks_experimental.manager_based.locomotion.velocity.mdp as mdp
from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import (
    LocomotionVelocityRoughEnvCfg,
    RewardsCfg,
)

##
# Pre-defined configs
##
from isaaclab_assets import G1_29DOF_CFG  # isort: skip


@configclass
class G1_29_DOFs_Rewards(RewardsCfg):
    """Reward terms for the MDP."""

    termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)

    # Velocity tracking terms driven by the "base_velocity" command.
    track_lin_vel_xy_exp = RewTerm(
        func=mdp.track_lin_vel_xy_yaw_frame_exp,
        weight=1.0,
        params={"command_name": "base_velocity", "std": 0.5},
    )
    track_ang_vel_z_exp = RewTerm(
        func=mdp.track_ang_vel_z_world_exp,
        weight=2.0,
        params={"command_name": "base_velocity", "std": 0.5},
    )

    # Foot terms, evaluated on bodies matching ".*_ankle_roll_link".
    feet_air_time = RewTerm(
        func=mdp.feet_air_time_positive_biped,
        weight=0.25,
        params={
            "command_name": "base_velocity",
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            "threshold": 0.4,
        },
    )
    feet_slide = RewTerm(
        func=mdp.feet_slide,
        weight=-0.1,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*_ankle_roll_link"),
            "asset_cfg": SceneEntityCfg("robot", body_names=".*_ankle_roll_link"),
        },
    )

    # Joint position limit term restricted to the ankle joints.
    dof_pos_limits = RewTerm(
        func=mdp.joint_pos_limits,
        weight=-1.0,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_ankle_pitch_joint", ".*_ankle_roll_joint"])},
    )

    # L1 joint-deviation terms, one per joint group.
    joint_deviation_hip = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw_joint", ".*_hip_roll_joint"])},
    )
    joint_deviation_arms = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={
            "asset_cfg": SceneEntityCfg(
                "robot",
                joint_names=[
                    ".*_shoulder_pitch_joint",
                    ".*_shoulder_roll_joint",
                    ".*_shoulder_yaw_joint",
                    ".*.*_elbow_joint",
                    ".*_wrist_.*_joint",
                ],
            )
        },
    )
    joint_deviation_torso = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.2,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names="waist_.*_joint")},
    )


@configclass
class G1_29_DOFs_RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
    """Rough-terrain velocity-tracking configuration using the 29-DoF G1 robot asset."""

    rewards: G1_29_DOFs_Rewards = G1_29_DOFs_Rewards()
    # Joint-name patterns used below for the joint-position action and for the
    # joint_pos / joint_vel policy observations.
    observed_joint_names: list[str] = ["waist.*", ".*_hip.*", ".*_knee.*", ".*_ankle.*"]

    def __post_init__(self):
        """Run the parent post-init, then apply the 29-DoF G1 overrides."""
        super().__post_init__()

        # Scene
        self.scene.robot = G1_29DOF_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")

        # Events
        events = self.events
        events.push_robot = None
        events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        events.base_external_force_torque.params["asset_cfg"].body_names = ["torso_link"]
        events.reset_base.params = {
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            "velocity_range": {
                "x": (0.0, 0.0),
                "y": (0.0, 0.0),
                "z": (0.0, 0.0),
                "roll": (0.0, 0.0),
                "pitch": (0.0, 0.0),
                "yaw": (0.0, 0.0),
            },
        }
        events.base_com = None

        # Rewards
        rewards = self.rewards
        rewards.lin_vel_z_l2.weight = 0.0
        rewards.undesired_contacts = None
        rewards.flat_orientation_l2.weight = -1.0
        rewards.action_rate_l2.weight = -0.005
        rewards.dof_acc_l2.weight = -1.25e-7
        rewards.dof_acc_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint"]
        )
        rewards.dof_torques_l2.weight = -1.5e-7
        rewards.dof_torques_l2.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=[".*_hip_.*", ".*_knee_joint", ".*_ankle_.*"]
        )

        # Commands
        ranges = self.commands.base_velocity.ranges
        ranges.lin_vel_x = (0.0, 1.0)
        ranges.lin_vel_y = (-0.0, 0.0)
        ranges.ang_vel_z = (-1.0, 1.0)

        # Terminations
        self.terminations.base_contact.params["sensor_cfg"].body_names = "torso_link"

        # Restrict the action and the joint observations to the same joint subset.
        self.actions.joint_pos.joint_names = self.observed_joint_names
        policy_obs = self.observations.policy
        policy_obs.joint_pos.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=self.observed_joint_names
        )
        policy_obs.joint_vel.params["asset_cfg"] = SceneEntityCfg(
            "robot", joint_names=self.observed_joint_names
        )


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py
# new file mode 100644
# index 00000000000..038f5574072
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/__init__.py
# @@ -0,0 +1,35 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import gymnasium as gym

# Reuse agent configs from the stable task package.
from isaaclab_tasks.manager_based.locomotion.velocity.config.go1 import agents

##
# Register Gym environments.
##

# Values shared by both flat registrations.
_ENTRY_POINT = "isaaclab_experimental.envs:ManagerBasedRLEnvWarp"
_AGENT_CFGS = {
    "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeGo1FlatPPORunnerCfg",
    "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
}

# Training environment.
gym.register(
    id="Isaac-Velocity-Flat-Unitree-Go1-Warp-v0",
    entry_point=_ENTRY_POINT,
    disable_env_checker=True,
    kwargs={"env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo1FlatEnvCfg", **_AGENT_CFGS},
)

# Play environment (same agent configs, ``_PLAY`` env config).
gym.register(
    id="Isaac-Velocity-Flat-Unitree-Go1-Warp-Play-v0",
    entry_point=_ENTRY_POINT,
    disable_env_checker=True,
    kwargs={"env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo1FlatEnvCfg_PLAY", **_AGENT_CFGS},
)

# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py
# new file mode 100644
# index 00000000000..e4fbc73e1d0
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/flat_env_cfg.py
# @@ -0,0 +1,46 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg

from isaaclab.sim import SimulationCfg
from isaaclab.utils import configclass

from .rough_env_cfg import UnitreeGo1RoughEnvCfg


@configclass
class UnitreeGo1FlatEnvCfg(UnitreeGo1RoughEnvCfg):
    """Flat-terrain variant of :class:`UnitreeGo1RoughEnvCfg` on Newton with the MJWarp solver."""

    # NOTE(review): njmax / nconmax are presumably solver capacity limits sized for this robot;
    # confirm against the MJWarpSolverCfg documentation before changing them.
    sim: SimulationCfg = SimulationCfg(
        physics=NewtonCfg(
            solver_cfg=MJWarpSolverCfg(
                njmax=60,
                nconmax=25,
                cone="pyramidal",
                impratio=1,
                integrator="implicitfast",
            ),
            num_substeps=1,
            debug_mode=False,
        )
    )

    def __post_init__(self):
        """Run the rough-terrain post-init, then retune two rewards and switch to a plane."""
        super().__post_init__()

        # Reward weights for the flat task.
        self.rewards.flat_orientation_l2.weight = -2.5
        self.rewards.feet_air_time.weight = 0.25

        # Replace the generated terrain with a plane and drop the terrain curriculum.
        terrain = self.scene.terrain
        terrain.terrain_type = "plane"
        terrain.terrain_generator = None
        self.curriculum.terrain_levels = None


# Decorated with ``@configclass`` so this class is processed the same way as the
# other decorated config classes in this package.
@configclass
class UnitreeGo1FlatEnvCfg_PLAY(UnitreeGo1FlatEnvCfg):
    """Play variant of :class:`UnitreeGo1FlatEnvCfg` with a smaller scene and fewer perturbations."""

    def __post_init__(self) -> None:
        """Run the parent post-init, then shrink the scene and disable corruption and pushes."""
        super().__post_init__()
        self.scene.num_envs = 50
        self.scene.env_spacing = 2.5
        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py
# new file mode 100644
# index 00000000000..2864bbba313
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go1/rough_env_cfg.py
# @@ -0,0 +1,61 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.utils import configclass

from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg

from isaaclab_assets.robots.unitree import UNITREE_GO1_CFG  # isort: skip


@configclass
class UnitreeGo1RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
    """Rough-terrain velocity-tracking configuration using the Unitree Go1 robot asset."""

    def __post_init__(self):
        """Run the parent post-init, then apply the Go1-specific overrides."""
        super().__post_init__()

        # Scene
        self.scene.robot = UNITREE_GO1_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")
        # NOTE(review): the sibling Go2 config in this patch does not set this;
        # confirm the override is intentional for Go1 only.
        self.manager_call_max_mode = {"Scene": 1}

        # Sub-terrain parameter overrides.
        sub_terrains = self.scene.terrain.terrain_generator.sub_terrains
        sub_terrains["boxes"].grid_height_range = (0.025, 0.1)
        random_rough = sub_terrains["random_rough"]
        random_rough.noise_range = (0.01, 0.06)
        random_rough.noise_step = 0.01

        # Actions
        self.actions.joint_pos.scale = 0.25

        # Events
        events = self.events
        events.push_robot = None
        events.base_external_force_torque.params["asset_cfg"].body_names = "trunk"
        events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        events.reset_base.params = {
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            # Every velocity component is pinned to the (0.0, 0.0) range.
            "velocity_range": dict.fromkeys(("x", "y", "z", "roll", "pitch", "yaw"), (0.0, 0.0)),
        }
        events.base_com = None

        # Rewards
        rewards = self.rewards
        rewards.feet_air_time.params["sensor_cfg"].body_names = ".*_foot"
        rewards.feet_air_time.weight = 0.01
        rewards.undesired_contacts = None
        rewards.dof_torques_l2.weight = -0.0002
        rewards.track_lin_vel_xy_exp.weight = 1.5
        rewards.track_ang_vel_z_exp.weight = 0.75
        rewards.dof_acc_l2.weight = -2.5e-7

        # Terminations
        self.terminations.base_contact.params["sensor_cfg"].body_names = "trunk"


@configclass
class UnitreeGo1RoughEnvCfg_PLAY(UnitreeGo1RoughEnvCfg):
    """Play variant of :class:`UnitreeGo1RoughEnvCfg` with a smaller scene and fewer perturbations."""

    def __post_init__(self):
        """Run the parent post-init, then shrink the scene and switch off corruption and pushes."""
        super().__post_init__()

        scene = self.scene
        scene.num_envs = 50
        scene.env_spacing = 2.5
        scene.terrain.max_init_terrain_level = None

        generator = scene.terrain.terrain_generator
        if generator is not None:
            generator.num_rows = 5
            generator.num_cols = 5
            generator.curriculum = False

        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py
# new file mode 100644
# index 00000000000..7e124029c68
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/__init__.py
# @@ -0,0 +1,35 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import gymnasium as gym

# Reuse agent configs from the stable task package.
from isaaclab_tasks.manager_based.locomotion.velocity.config.go2 import agents

##
# Register Gym environments.
##

# Values shared by both flat registrations.
_ENTRY_POINT = "isaaclab_experimental.envs:ManagerBasedRLEnvWarp"
_AGENT_CFGS = {
    "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UnitreeGo2FlatPPORunnerCfg",
    "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
}

# Training environment.
gym.register(
    id="Isaac-Velocity-Flat-Unitree-Go2-Warp-v0",
    entry_point=_ENTRY_POINT,
    disable_env_checker=True,
    kwargs={"env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo2FlatEnvCfg", **_AGENT_CFGS},
)

# Play environment (same agent configs, ``_PLAY`` env config).
gym.register(
    id="Isaac-Velocity-Flat-Unitree-Go2-Warp-Play-v0",
    entry_point=_ENTRY_POINT,
    disable_env_checker=True,
    kwargs={"env_cfg_entry_point": f"{__name__}.flat_env_cfg:UnitreeGo2FlatEnvCfg_PLAY", **_AGENT_CFGS},
)

# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py
# new file mode 100644
# index 00000000000..ad8a8aa862f
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/flat_env_cfg.py
# @@ -0,0 +1,46 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg

from isaaclab.sim import SimulationCfg
from isaaclab.utils import configclass

from .rough_env_cfg import UnitreeGo2RoughEnvCfg


@configclass
class UnitreeGo2FlatEnvCfg(UnitreeGo2RoughEnvCfg):
    """Flat-terrain variant of :class:`UnitreeGo2RoughEnvCfg` on Newton with the MJWarp solver."""

    # NOTE(review): njmax / nconmax are presumably solver capacity limits sized for this robot;
    # confirm against the MJWarpSolverCfg documentation before changing them.
    sim: SimulationCfg = SimulationCfg(
        physics=NewtonCfg(
            solver_cfg=MJWarpSolverCfg(
                njmax=65,
                nconmax=35,
                cone="pyramidal",
                impratio=1,
                integrator="implicitfast",
            ),
            num_substeps=1,
            debug_mode=False,
        )
    )

    def __post_init__(self):
        """Run the rough-terrain post-init, then retune two rewards and switch to a plane."""
        super().__post_init__()

        # Reward weights for the flat task.
        self.rewards.flat_orientation_l2.weight = -2.5
        self.rewards.feet_air_time.weight = 0.25

        # Replace the generated terrain with a plane and drop the terrain curriculum.
        terrain = self.scene.terrain
        terrain.terrain_type = "plane"
        terrain.terrain_generator = None
        self.curriculum.terrain_levels = None


# Decorated with ``@configclass`` so this class is processed the same way as the
# other decorated config classes in this package.
@configclass
class UnitreeGo2FlatEnvCfg_PLAY(UnitreeGo2FlatEnvCfg):
    """Play variant of :class:`UnitreeGo2FlatEnvCfg` with a smaller scene and fewer perturbations."""

    def __post_init__(self) -> None:
        """Run the parent post-init, then shrink the scene and disable corruption and pushes."""
        super().__post_init__()
        self.scene.num_envs = 50
        self.scene.env_spacing = 2.5
        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py
# new file mode 100644
# index 00000000000..ff13b7e8617
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/go2/rough_env_cfg.py
# @@ -0,0 +1,60 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from isaaclab.utils import configclass

from isaaclab_tasks_experimental.manager_based.locomotion.velocity.velocity_env_cfg import LocomotionVelocityRoughEnvCfg

from isaaclab_assets.robots.unitree import UNITREE_GO2_CFG  # isort: skip


@configclass
class UnitreeGo2RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
    """Rough-terrain velocity-tracking configuration using the Unitree Go2 robot asset."""

    def __post_init__(self):
        """Run the parent post-init, then apply the Go2-specific overrides."""
        super().__post_init__()

        # Scene
        self.scene.robot = UNITREE_GO2_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")

        # Sub-terrain parameter overrides.
        sub_terrains = self.scene.terrain.terrain_generator.sub_terrains
        sub_terrains["boxes"].grid_height_range = (0.025, 0.1)
        random_rough = sub_terrains["random_rough"]
        random_rough.noise_range = (0.01, 0.06)
        random_rough.noise_step = 0.01

        # Actions
        self.actions.joint_pos.scale = 0.25

        # Events
        events = self.events
        events.push_robot = None
        events.base_external_force_torque.params["asset_cfg"].body_names = "base"
        events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        events.reset_base.params = {
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            # Every velocity component is pinned to the (0.0, 0.0) range.
            "velocity_range": dict.fromkeys(("x", "y", "z", "roll", "pitch", "yaw"), (0.0, 0.0)),
        }
        events.base_com = None

        # Rewards
        rewards = self.rewards
        rewards.feet_air_time.params["sensor_cfg"].body_names = ".*_foot"
        rewards.feet_air_time.weight = 0.01
        rewards.undesired_contacts = None
        rewards.dof_torques_l2.weight = -0.0002
        rewards.track_lin_vel_xy_exp.weight = 1.5
        rewards.track_ang_vel_z_exp.weight = 0.75
        rewards.dof_acc_l2.weight = -2.5e-7

        # Terminations
        self.terminations.base_contact.params["sensor_cfg"].body_names = "base"


@configclass
class UnitreeGo2RoughEnvCfg_PLAY(UnitreeGo2RoughEnvCfg):
    """Play variant of :class:`UnitreeGo2RoughEnvCfg` with a smaller scene and fewer perturbations."""

    def __post_init__(self):
        """Run the parent post-init, then shrink the scene and switch off corruption and pushes."""
        super().__post_init__()

        scene = self.scene
        scene.num_envs = 50
        scene.env_spacing = 2.5
        scene.terrain.max_init_terrain_level = None

        generator = scene.terrain.terrain_generator
        if generator is not None:
            generator.num_rows = 5
            generator.num_cols = 5
            generator.curriculum = False

        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None


# diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py
# new file mode 100644
# index 00000000000..95a1e8f29e3
# --- /dev/null
# +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/config/h1/__init__.py
# @@ -0,0 +1,57 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import gymnasium as gym

# Reuse agent configs from the stable task package.
from isaaclab_tasks.manager_based.locomotion.velocity.config.h1 import agents

##
# Register Gym environments.
# NOTE: the rough-terrain H1 variants below are kept commented out, i.e. they are not registered.
# gym.register(
#     id="Isaac-Velocity-Rough-H1-Warp-v0",
#     entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
#     disable_env_checker=True,
#     kwargs={
#         "env_cfg_entry_point": f"{__name__}.rough_env_cfg:H1RoughEnvCfg",
#         "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1RoughPPORunnerCfg",
#         "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml",
#     },
# )

# gym.register(
#     id="Isaac-Velocity-Rough-H1-Warp-Play-v0",
#     entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
#     disable_env_checker=True,
#     kwargs={
#         "env_cfg_entry_point": f"{__name__}.rough_env_cfg:H1RoughEnvCfg_PLAY",
#         "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1RoughPPORunnerCfg",
#         "skrl_cfg_entry_point": f"{agents.__name__}:skrl_rough_ppo_cfg.yaml",
#     },
# )

# Flat-terrain train and play variants: they differ only in the id and the env-config class,
# and share the entry point and the agent configs.
for _env_id, _env_cfg_name in (
    ("Isaac-Velocity-Flat-H1-Warp-v0", "H1FlatEnvCfg"),
    ("Isaac-Velocity-Flat-H1-Warp-Play-v0", "H1FlatEnvCfg_PLAY"),
):
    gym.register(
        id=_env_id,
        entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp",
        disable_env_checker=True,
        kwargs={
            "env_cfg_entry_point": f"{__name__}.flat_env_cfg:{_env_cfg_name}",
            "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:H1FlatPPORunnerCfg",
            "skrl_cfg_entry_point": f"{agents.__name__}:skrl_flat_ppo_cfg.yaml",
        },
    )
# keep the module namespace free of the loop temporaries
del _env_id, _env_cfg_name
@configclass
class H1FlatEnvCfg(H1RoughEnvCfg):
    """Flat-plane variant of the H1 velocity-tracking environment."""

    # Simulation settings: Newton physics with the MJWarp solver, one substep per physics step.
    sim: SimulationCfg = SimulationCfg(
        physics=NewtonCfg(
            solver_cfg=MJWarpSolverCfg(
                njmax=65,
                nconmax=15,
                cone="pyramidal",
                impratio=1,
                integrator="implicitfast",
            ),
            num_substeps=1,
            debug_mode=False,
        )
    )

    def __post_init__(self):
        """Replace the generated terrain with a plane and retune the air-time reward."""
        super().__post_init__()
        # a plane has no terrain generator, so the terrain-level curriculum is dropped with it
        self.scene.terrain.terrain_type = "plane"
        self.scene.terrain.terrain_generator = None
        self.curriculum.terrain_levels = None
        # reward overrides for the flat task
        self.rewards.feet_air_time.weight = 1.0
        self.rewards.feet_air_time.params["threshold"] = 0.6


# The decorator was missing here while every other config class in this change
# (including the other ``*_PLAY`` variants) is declared with ``@configclass``.
@configclass
class H1FlatEnvCfg_PLAY(H1FlatEnvCfg):
    """Play variant of :class:`H1FlatEnvCfg`: 50 envs, no corruption, no force/push events."""

    def __post_init__(self) -> None:
        """Shrink the scene and switch off the randomization set up by the parent config."""
        super().__post_init__()
        self.scene.num_envs = 50
        self.scene.env_spacing = 2.5
        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None
@configclass
class H1Rewards(RewardsCfg):
    """Reward terms for the H1 biped; extends and overrides the base :class:`RewardsCfg` terms."""

    # large penalty on termination
    termination_penalty = RewTerm(func=mdp.is_terminated, weight=-200.0)
    # the base vertical-velocity penalty is disabled for H1
    lin_vel_z_l2 = None
    # -- velocity tracking
    track_lin_vel_xy_exp = RewTerm(
        func=mdp.track_lin_vel_xy_yaw_frame_exp,
        weight=1.0,
        params={"command_name": "base_velocity", "std": 0.5},
    )
    track_ang_vel_z_exp = RewTerm(
        func=mdp.track_ang_vel_z_world_exp,
        weight=1.0,
        params={"command_name": "base_velocity", "std": 0.5},
    )
    # -- feet terms (both read the "contact_forces" sensor on the ankle links)
    feet_air_time = RewTerm(
        func=mdp.feet_air_time_positive_biped,
        weight=0.25,
        params={
            "command_name": "base_velocity",
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*ankle_link"),
            "threshold": 0.4,
        },
    )
    feet_slide = RewTerm(
        func=mdp.feet_slide,
        weight=-0.25,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*ankle_link"),
            "asset_cfg": SceneEntityCfg("robot", body_names=".*ankle_link"),
        },
    )
    # -- joint terms
    dof_pos_limits = RewTerm(
        func=mdp.joint_pos_limits,
        weight=-1.0,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=".*_ankle")},
    )
    joint_deviation_hip = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.2,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_hip_yaw", ".*_hip_roll"])},
    )
    joint_deviation_arms = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.2,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names=[".*_shoulder_.*", ".*_elbow"])},
    )
    joint_deviation_torso = RewTerm(
        func=mdp.joint_deviation_l1,
        weight=-0.1,
        params={"asset_cfg": SceneEntityCfg("robot", joint_names="torso")},
    )


@configclass
class H1RoughEnvCfg(LocomotionVelocityRoughEnvCfg):
    """Velocity-tracking environment config for the H1 humanoid on the generated rough terrain."""

    rewards: H1Rewards = H1Rewards()

    def __post_init__(self):
        """Layer the H1-specific overrides on top of the base locomotion config."""
        super().__post_init__()

        # -- scene
        self.scene.robot = H1_MINIMAL_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot")

        # -- events
        events = self.events
        events.push_robot = None
        events.reset_robot_joints.params["position_range"] = (1.0, 1.0)
        events.base_external_force_torque.params["asset_cfg"].body_names = [".*torso_link"]
        events.reset_base.params = {
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            # every velocity component is reset to exactly zero
            "velocity_range": {axis: (0.0, 0.0) for axis in ("x", "y", "z", "roll", "pitch", "yaw")},
        }
        events.base_com = None

        # -- rewards
        rewards = self.rewards
        rewards.undesired_contacts = None
        rewards.flat_orientation_l2.weight = -1.0
        rewards.dof_torques_l2.weight = 0.0
        rewards.action_rate_l2.weight = -0.005
        rewards.dof_acc_l2.weight = -1.25e-7

        # -- commands: forward-only linear velocity, no lateral command
        command_ranges = self.commands.base_velocity.ranges
        command_ranges.lin_vel_x = (0.0, 1.0)
        command_ranges.lin_vel_y = (0.0, 0.0)
        command_ranges.ang_vel_z = (-1.0, 1.0)

        # -- terminations
        self.terminations.base_contact.params["sensor_cfg"].body_names = ".*torso_link"


@configclass
class H1RoughEnvCfg_PLAY(H1RoughEnvCfg):
    """Play variant of :class:`H1RoughEnvCfg`: small scene, fixed forward command, no randomization."""

    def __post_init__(self):
        """Apply the play-time overrides after the parent config has been resolved."""
        super().__post_init__()

        # smaller scene and a longer episode for play
        scene = self.scene
        scene.num_envs = 50
        scene.env_spacing = 2.5
        self.episode_length_s = 40.0

        # spawn robots randomly in the grid instead of by terrain level
        scene.terrain.max_init_terrain_level = None
        # fewer terrain tiles to save memory (only when a generator is configured)
        generator = scene.terrain.terrain_generator
        if generator is not None:
            generator.num_rows = 5
            generator.num_cols = 5
            generator.curriculum = False

        # fixed forward velocity and zero heading; yaw rate is still sampled
        command_ranges = self.commands.base_velocity.ranges
        command_ranges.lin_vel_x = (1.0, 1.0)
        command_ranges.lin_vel_y = (0.0, 0.0)
        command_ranges.ang_vel_z = (-1.0, 1.0)
        command_ranges.heading = (0.0, 0.0)

        # no observation corruption and no external disturbances during play
        self.observations.policy.enable_corruption = False
        self.events.base_external_force_torque = None
        self.events.push_robot = None
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""This sub-module contains the functions that are specific to the locomotion environments.""" + +from isaaclab_experimental.envs.mdp import * # noqa: F401, F403 + +from .curriculums import * # noqa: F401, F403 +from .rewards import * # noqa: F401, F403 +from .terminations import * # noqa: F401, F403 diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py new file mode 100644 index 00000000000..1ed0a4c6f33 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/locomotion/velocity/mdp/curriculums.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Curriculum functions for the velocity locomotion environment. + +Curriculum terms are not warp-managed (they run at reset time, not per-step), +so they remain torch-based. 
def terrain_levels_vel(
    env: ManagerBasedRLEnv, env_ids: Sequence[int], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
    """Curriculum based on the distance the robot walked when commanded to move at a desired velocity.

    For the given environments, the terrain level is raised when the planar distance from
    the environment origin exceeds half of the terrain generator's ``size[0]``, and lowered
    when it is below half of the distance implied by the commanded planar speed over a full
    episode. Raising takes precedence over lowering.

    Args:
        env: The environment instance.
        env_ids: Indices of the environments to update.
        asset_cfg: Scene entity whose root position is measured. Defaults to ``"robot"``.

    Returns:
        The mean terrain level over all environments, as a scalar tensor.
    """
    asset: Articulation = env.scene[asset_cfg.name]
    terrain: TerrainImporter = env.scene.terrain

    command = env.command_manager.get_command("base_velocity")
    # The command may be exposed as a warp array (the warp reward terms handle both cases);
    # view it as a torch tensor so the indexing and norm below work either way.
    if isinstance(command, wp.array):
        command = wp.to_torch(command)

    # planar distance walked from the environment origin
    root_pos_w = wp.to_torch(asset.data.root_pos_w)
    distance = torch.norm(root_pos_w[env_ids, :2] - env.scene.env_origins[env_ids, :2], dim=1)
    # robots that walked far enough progress to harder terrains
    move_up = distance > terrain.cfg.terrain_generator.size[0] / 2
    # robots that walked less than half of the commanded distance go to simpler terrains
    move_down = distance < torch.norm(command[env_ids, :2], dim=1) * env.max_episode_length_s * 0.5
    move_down *= ~move_up
    terrain.update_env_origins(env_ids, move_up, move_down)
    return torch.mean(terrain.terrain_levels.float())
def _command_as_warp(fn, env: ManagerBasedRLEnv, command_name: str) -> wp.array:
    """Return the command buffer ``command_name`` as a warp array, cached on ``fn``.

    The array is stored as ``fn._cmd_wp`` (with its name in ``fn._cmd_name``) and is only
    re-resolved when a different command name is requested. Torch tensors are wrapped with
    ``wp.from_torch``; warp arrays are used as they are.

    NOTE(review): the cache is keyed by command name only, not by ``env``. If several
    environments are created in one process the first resolved buffer would be reused —
    confirm this is intended.
    """
    if not hasattr(fn, "_cmd_wp") or fn._cmd_name != command_name:
        cmd = env.command_manager.get_command(command_name)
        fn._cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd)
        fn._cmd_name = command_name
    return fn._cmd_wp


# ---------------------------------------------------------------------------
# feet_air_time
# ---------------------------------------------------------------------------


@wp.kernel
def _feet_air_time_kernel(
    last_air_time: wp.array(dtype=wp.float32, ndim=2),
    first_contact: wp.array(dtype=wp.float32, ndim=2),
    body_ids: wp.array(dtype=wp.int32),
    cmd_xy: wp.array(dtype=wp.float32, ndim=2),
    threshold: float,
    out: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    # sum of (air time - threshold) over the selected bodies, counted only on first contact
    s = float(0.0)
    for k in range(body_ids.shape[0]):
        b = body_ids[k]
        s += (last_air_time[i, b] - threshold) * first_contact[i, b]
    # gate by command magnitude
    cx = cmd_xy[i, 0]
    cy = cmd_xy[i, 1]
    cmd_norm = wp.sqrt(cx * cx + cy * cy)
    out[i] = wp.where(cmd_norm > 0.1, s, 0.0)


def feet_air_time(
    env: ManagerBasedRLEnv, out, command_name: str, sensor_cfg: SceneEntityCfg, threshold: float
) -> None:
    """Reward long steps taken by the feet.

    For every environment, sums ``(last_air_time - threshold) * first_contact`` over the
    bodies selected by ``sensor_cfg`` and writes the result into ``out``. The reward is
    zero when the xy norm of the command is not above 0.1.

    Args:
        env: The environment instance.
        out: Per-environment output array written in place by the kernel.
        command_name: Name of the command term whose first two columns gate the reward.
        sensor_cfg: Contact sensor entity; its ``body_ids_wp`` selects the feet.
        threshold: Air time subtracted from each foot's last air time.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    cmd_wp = _command_as_warp(feet_air_time, env, command_name)
    # Newton contact sensor returns persistent wp.arrays — use directly, no wp.from_torch needed
    first_contact = contact_sensor.compute_first_contact(env.step_dt)
    wp.launch(
        kernel=_feet_air_time_kernel,
        dim=env.num_envs,
        inputs=[contact_sensor.data.last_air_time, first_contact, sensor_cfg.body_ids_wp, cmd_wp, threshold, out],
        device=env.device,
    )


# ---------------------------------------------------------------------------
# feet_air_time_positive_biped
# ---------------------------------------------------------------------------


@wp.kernel
def _feet_air_time_positive_biped_kernel(
    air_time: wp.array(dtype=wp.float32, ndim=2),
    contact_time: wp.array(dtype=wp.float32, ndim=2),
    body_ids: wp.array(dtype=wp.int32),
    cmd_xy: wp.array(dtype=wp.float32, ndim=2),
    threshold: float,
    out: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    n_feet = body_ids.shape[0]
    # count feet in contact and find single-stance min mode time
    n_contact = int(0)
    for k in range(n_feet):
        b = body_ids[k]
        if contact_time[i, b] > 0.0:
            n_contact += 1
    single_stance = n_contact == 1
    min_val = threshold  # clamp upper bound
    for k in range(n_feet):
        b = body_ids[k]
        in_contact = contact_time[i, b] > 0.0
        # time spent in the current mode: contact time if touching, air time otherwise
        mode_time = wp.where(in_contact, contact_time[i, b], air_time[i, b])
        val = wp.where(single_stance, mode_time, 0.0)
        min_val = wp.min(min_val, val)
    # gate by command magnitude
    cx = cmd_xy[i, 0]
    cy = cmd_xy[i, 1]
    cmd_norm = wp.sqrt(cx * cx + cy * cy)
    out[i] = wp.where(cmd_norm > 0.1, min_val, 0.0)


def feet_air_time_positive_biped(
    env: ManagerBasedRLEnv, out, command_name: str, threshold: float, sensor_cfg: SceneEntityCfg
) -> None:
    """Reward long steps taken by the feet for bipeds.

    When exactly one selected foot is in contact, the reward is the minimum over the feet
    of the time spent in the current mode (contact or air), clamped to ``threshold``;
    otherwise it is zero. The reward is also zero when the xy norm of the command is not
    above 0.1.

    Args:
        env: The environment instance.
        out: Per-environment output array written in place by the kernel.
        command_name: Name of the command term whose first two columns gate the reward.
        threshold: Upper bound of the reward.
        sensor_cfg: Contact sensor entity; its ``body_ids_wp`` selects the feet.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    cmd_wp = _command_as_warp(feet_air_time_positive_biped, env, command_name)
    wp.launch(
        kernel=_feet_air_time_positive_biped_kernel,
        dim=env.num_envs,
        inputs=[
            contact_sensor.data.current_air_time,
            contact_sensor.data.current_contact_time,
            sensor_cfg.body_ids_wp,
            cmd_wp,
            threshold,
            out,
        ],
        device=env.device,
    )


# ---------------------------------------------------------------------------
# feet_slide
# ---------------------------------------------------------------------------


@wp.kernel
def _feet_slide_kernel(
    body_lin_vel_w: wp.array(dtype=wp.vec3f, ndim=2),
    net_forces_w: wp.array(dtype=wp.vec3f, ndim=3),
    body_ids: wp.array(dtype=wp.int32),
    n_history: int,
    out: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    s = float(0.0)
    for k in range(body_ids.shape[0]):
        b = body_ids[k]
        # check if in contact: max force norm over history > 1.0
        max_force = float(0.0)
        for h in range(n_history):
            f = net_forces_w[i, h, b]
            f_norm = wp.sqrt(f[0] * f[0] + f[1] * f[1] + f[2] * f[2])
            max_force = wp.max(max_force, f_norm)
        in_contact = wp.where(max_force > 1.0, 1.0, 0.0)
        # planar velocity norm
        vx = body_lin_vel_w[i, b][0]
        vy = body_lin_vel_w[i, b][1]
        vel_norm = wp.sqrt(vx * vx + vy * vy)
        s += vel_norm * in_contact
    out[i] = s


def feet_slide(
    env: ManagerBasedRLEnv, out, sensor_cfg: SceneEntityCfg, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> None:
    """Penalize feet sliding.

    Sums the planar (xy) speed of every selected body that is in contact, where "in
    contact" means the largest net-force norm over the sensor history exceeds 1.0.

    Args:
        env: The environment instance.
        out: Per-environment output array written in place by the kernel.
        sensor_cfg: Contact sensor entity; its ``body_ids_wp`` selects the feet.
        asset_cfg: Articulation whose body velocities are read. Defaults to ``"robot"``.
    """
    contact_sensor: ContactSensor = env.scene.sensors[sensor_cfg.name]
    asset = env.scene[asset_cfg.name]
    # NOTE(review): only ``sensor_cfg.body_ids_wp`` is passed, so the kernel indexes the
    # asset's body velocities with the sensor's body ids; ``asset_cfg.body_ids`` is unused.
    # This presumably relies on sensor and articulation sharing body ordering — confirm.
    force_history = contact_sensor.data.net_forces_w_history
    wp.launch(
        kernel=_feet_slide_kernel,
        dim=env.num_envs,
        inputs=[
            asset.data.body_lin_vel_w,
            force_history,
            sensor_cfg.body_ids_wp,
            force_history.shape[1],
            out,
        ],
        device=env.device,
    )


# ---------------------------------------------------------------------------
# track_lin_vel_xy_yaw_frame_exp
# ---------------------------------------------------------------------------


@wp.kernel
def _track_lin_vel_xy_yaw_frame_exp_kernel(
    root_quat_w: wp.array(dtype=wp.quatf),
    root_lin_vel_w: wp.array(dtype=wp.vec3f),
    cmd: wp.array(dtype=wp.float32, ndim=2),
    inv_std_sq: float,
    out: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    q = root_quat_w[i]
    # extract yaw-only quaternion
    qx = q[0]
    qy = q[1]
    qz = q[2]
    qw = q[3]
    sin_yaw = 2.0 * (qw * qz + qx * qy)
    cos_yaw = 1.0 - 2.0 * (qy * qy + qz * qz)
    yaw_half = wp.atan2(sin_yaw, cos_yaw) * 0.5
    yaw_q = wp.quatf(0.0, 0.0, wp.sin(yaw_half), wp.cos(yaw_half))
    # rotate world velocity into yaw frame (inverse = conjugate for unit quat)
    vel_w = root_lin_vel_w[i]
    vel_yaw = wp.quat_rotate(wp.quat_inverse(yaw_q), vel_w)
    # error
    ex = cmd[i, 0] - vel_yaw[0]
    ey = cmd[i, 1] - vel_yaw[1]
    err_sq = ex * ex + ey * ey
    out[i] = wp.exp(-err_sq * inv_std_sq)


def track_lin_vel_xy_yaw_frame_exp(
    env: ManagerBasedRLEnv, out, std: float, command_name: str, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> None:
    """Reward tracking of linear velocity commands (xy axes) in the gravity aligned robot frame.

    Writes ``exp(-|cmd_xy - v_xy|^2 / std^2)`` into ``out``, where ``v_xy`` is the root
    linear velocity rotated into the yaw-only frame of the root orientation.

    Args:
        env: The environment instance.
        out: Per-environment output array written in place by the kernel.
        std: Width of the exponential kernel; must be non-zero.
        command_name: Name of the command term providing the xy velocity target.
        asset_cfg: Articulation whose root state is read. Defaults to ``"robot"``.
    """
    asset = env.scene[asset_cfg.name]
    cmd_wp = _command_as_warp(track_lin_vel_xy_yaw_frame_exp, env, command_name)
    wp.launch(
        kernel=_track_lin_vel_xy_yaw_frame_exp_kernel,
        dim=env.num_envs,
        inputs=[asset.data.root_quat_w, asset.data.root_lin_vel_w, cmd_wp, 1.0 / (std * std), out],
        device=env.device,
    )


# ---------------------------------------------------------------------------
# track_ang_vel_z_world_exp
# ---------------------------------------------------------------------------


@wp.kernel
def _track_ang_vel_z_world_exp_kernel(
    root_ang_vel_w: wp.array(dtype=wp.vec3f),
    cmd: wp.array(dtype=wp.float32, ndim=2),
    inv_std_sq: float,
    out: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    # yaw-rate error between the command (column 2) and the world-frame angular velocity
    err = cmd[i, 2] - root_ang_vel_w[i][2]
    out[i] = wp.exp(-(err * err) * inv_std_sq)


def track_ang_vel_z_world_exp(
    env: ManagerBasedRLEnv, out, command_name: str, std: float, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> None:
    """Reward tracking of angular velocity commands (yaw) in world frame.

    Writes ``exp(-(cmd_z - w_z)^2 / std^2)`` into ``out``.

    Args:
        env: The environment instance.
        out: Per-environment output array written in place by the kernel.
        command_name: Name of the command term; its third column is the yaw-rate target.
        std: Width of the exponential kernel; must be non-zero.
        asset_cfg: Articulation whose root angular velocity is read. Defaults to ``"robot"``.
    """
    asset = env.scene[asset_cfg.name]
    cmd_wp = _command_as_warp(track_ang_vel_z_world_exp, env, command_name)
    wp.launch(
        kernel=_track_ang_vel_z_world_exp_kernel,
        dim=env.num_envs,
        inputs=[asset.data.root_ang_vel_w, cmd_wp, 1.0 / (std * std), out],
        device=env.device,
    )


# ---------------------------------------------------------------------------
# stand_still_joint_deviation_l1
# ---------------------------------------------------------------------------


@wp.kernel
def _stand_still_joint_deviation_l1_kernel(
    joint_pos: wp.array(dtype=wp.float32, ndim=2),
    default_joint_pos: wp.array(dtype=wp.float32, ndim=2),
    cmd: wp.array(dtype=wp.float32, ndim=2),
    command_threshold: float,
    out: wp.array(dtype=wp.float32),
):
    i = wp.tid()
    n_joints = joint_pos.shape[1]
    # L1 deviation from the default pose over all joints
    dev = float(0.0)
    for j in range(n_joints):
        dev += wp.abs(joint_pos[i, j] - default_joint_pos[i, j])
    # gate: only penalize when command is small
    cx = cmd[i, 0]
    cy = cmd[i, 1]
    cmd_norm = wp.sqrt(cx * cx + cy * cy)
    out[i] = wp.where(cmd_norm < command_threshold, dev, 0.0)


def stand_still_joint_deviation_l1(
    env: ManagerBasedRLEnv,
    out,
    command_name: str,
    command_threshold: float = 0.06,
    asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
) -> None:
    """Penalize offsets from the default joint positions when the command is very small.

    Writes the summed absolute joint deviation into ``out`` for environments whose xy
    command norm is below ``command_threshold``, and zero otherwise. All joints of the
    asset are included.

    Args:
        env: The environment instance.
        out: Per-environment output array written in place by the kernel.
        command_name: Name of the command term whose first two columns gate the penalty.
        command_threshold: Command norm below which the penalty is active.
        asset_cfg: Articulation whose joint positions are read. Defaults to ``"robot"``.
    """
    asset = env.scene[asset_cfg.name]
    cmd_wp = _command_as_warp(stand_still_joint_deviation_l1, env, command_name)
    wp.launch(
        kernel=_stand_still_joint_deviation_l1_kernel,
        dim=env.num_envs,
        inputs=[asset.data.joint_pos, asset.data.default_joint_pos, cmd_wp, command_threshold, out],
        device=env.device,
    )


# ---------------------------------------------------------------------------
# terminations: terrain_out_of_bounds
# ---------------------------------------------------------------------------


@wp.kernel
def _terrain_out_of_bounds_kernel(
    root_pos_w: wp.array(dtype=wp.vec3f),
    half_width: float,
    half_height: float,
    distance_buffer: float,
    out: wp.array(dtype=wp.bool),
):
    i = wp.tid()
    px = wp.abs(root_pos_w[i][0])
    py = wp.abs(root_pos_w[i][1])
    # out of bounds once either axis is within ``distance_buffer`` of the map edge
    out[i] = px > half_width - distance_buffer or py > half_height - distance_buffer


def terrain_out_of_bounds(
    env: ManagerBasedRLEnv, out, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"), distance_buffer: float = 3.0
) -> None:
    """Terminate when the actor moves too close to the edge of the terrain.

    For a ``"plane"`` terrain the output is always cleared. For a ``"generator"`` terrain
    the map extents are computed once from the generator config (tile size, rows/columns
    and border width) and cached on the function.

    Args:
        env: The environment instance.
        out: Per-environment boolean output array written in place.
        asset_cfg: Articulation whose root position is checked. Defaults to ``"robot"``.
        distance_buffer: Margin from the map edge at which termination triggers.

    Raises:
        ValueError: If the terrain type is neither ``"plane"`` nor ``"generator"``.
    """
    fn = terrain_out_of_bounds
    if not hasattr(fn, "_terrain_resolved"):
        terrain_type = env.scene.cfg.terrain.terrain_type
        if terrain_type == "plane":
            fn._is_plane = True
        elif terrain_type == "generator":
            terrain_gen_cfg = env.scene.terrain.cfg.terrain_generator
            grid_width, grid_length = terrain_gen_cfg.size
            n_rows, n_cols = terrain_gen_cfg.num_rows, terrain_gen_cfg.num_cols
            border_width = terrain_gen_cfg.border_width
            fn._half_width = 0.5 * (n_rows * grid_width + 2 * border_width)
            fn._half_height = 0.5 * (n_cols * grid_length + 2 * border_width)
            fn._is_plane = False
        else:
            raise ValueError("Received unsupported terrain type, must be either 'plane' or 'generator'.")
        # Mark the cache as resolved only after everything above succeeded, so a failed
        # first call is retried (and re-raises) instead of leaving a half-initialised cache.
        fn._terrain_resolved = True

    if fn._is_plane:
        out.zero_()
        return

    asset: Articulation = env.scene[asset_cfg.name]
    wp.launch(
        kernel=_terrain_out_of_bounds_kernel,
        dim=env.num_envs,
        inputs=[asset.data.root_pos_w, fn._half_width, fn._half_height, distance_buffer, out],
        device=env.device,
    )
##
# Scene definition
##


@configclass
class MySceneCfg(InteractiveSceneCfg):
    """Configuration for the terrain scene with a legged robot."""

    # ground terrain: generated rough terrain by default; the same physics material is
    # also applied to the simulation by LocomotionVelocityRoughEnvCfg.__post_init__
    terrain = TerrainImporterCfg(
        prim_path="/World/ground",
        terrain_type="generator",
        terrain_generator=ROUGH_TERRAINS_CFG,
        max_init_terrain_level=5,
        collision_group=-1,
        physics_material=sim_utils.RigidBodyMaterialCfg(
            friction_combine_mode="multiply",
            restitution_combine_mode="multiply",
            static_friction=1.0,
            dynamic_friction=1.0,
        ),
        visual_material=sim_utils.MdlFileCfg(
            mdl_path=f"{ISAACLAB_NUCLEUS_DIR}/Materials/TilesMarbleSpiderWhiteBrickBondHoned/TilesMarbleSpiderWhiteBrickBondHoned.mdl",
            project_uvw=True,
            texture_scale=(0.25, 0.25),
        ),
        debug_vis=False,
    )
    # robots: left MISSING here, must be set by the robot-specific configs
    robot: ArticulationCfg = MISSING
    # sensors: contact sensor over all prims under the robot, with air-time tracking enabled
    contact_forces = ContactSensorCfg(
        prim_path="{ENV_REGEX_NS}/Robot/.*",
        filter_prim_paths_expr=[],
        history_length=3,
        track_air_time=True,
    )
    # lights
    sky_light = AssetBaseCfg(
        prim_path="/World/skyLight",
        spawn=sim_utils.DomeLightCfg(
            intensity=750.0,
            texture_file=f"{ISAAC_NUCLEUS_DIR}/Materials/Textures/Skies/PolyHaven/kloofendal_43d_clear_puresky_4k.hdr",
        ),
    )


##
# MDP settings
##


@configclass
class CommandsCfg:
    """Command specifications for the MDP."""

    # uniform velocity command with heading control; read by the reward and observation
    # terms under the name "base_velocity"
    base_velocity = mdp.UniformVelocityCommandCfg(
        asset_name="robot",
        resampling_time_range=(10.0, 10.0),
        rel_standing_envs=0.02,
        rel_heading_envs=1.0,
        heading_command=True,
        heading_control_stiffness=0.5,
        debug_vis=True,
        ranges=mdp.UniformVelocityCommandCfg.Ranges(
            lin_vel_x=(-1.0, 1.0), lin_vel_y=(-1.0, 1.0), ang_vel_z=(-1.0, 1.0), heading=(-math.pi, math.pi)
        ),
    )


@configclass
class ActionsCfg:
    """Action specifications for the MDP."""

    # joint position action over all joints (".*"), scaled by 0.5 with the default offset
    joint_pos = mdp.JointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True)


@configclass
class PolicyCfg(ObsGroup):
    """Observations for policy group."""

    # observation terms (order preserved)
    base_lin_vel = ObsTerm(func=mdp.base_lin_vel, noise=Unoise(n_min=-0.1, n_max=0.1))
    base_ang_vel = ObsTerm(func=mdp.base_ang_vel, noise=Unoise(n_min=-0.2, n_max=0.2))
    projected_gravity = ObsTerm(
        func=mdp.projected_gravity,
        noise=Unoise(n_min=-0.05, n_max=0.05),
    )
    velocity_commands = ObsTerm(func=mdp.generated_commands, params={"command_name": "base_velocity"})
    joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01))
    joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-1.5, n_max=1.5))
    actions = ObsTerm(func=mdp.last_action)

    def __post_init__(self):
        """Enable observation corruption and term concatenation for this group."""
        self.enable_corruption = True
        self.concatenate_terms = True


@configclass
class ObservationsCfg:
    """Observation specifications for the MDP."""

    # observation groups
    policy: PolicyCfg = PolicyCfg()


@configclass
class EventCfg:
    """Configuration for events."""

    # FIXME(warp-migration): COM randomization in exp manager-based locomotion currently causes
    # NaNs and is temporarily disabled.
    # base_com = EventTerm(
    #     func=mdp.randomize_rigid_body_com,
    #     mode="startup",
    #     params={
    #         "asset_cfg": SceneEntityCfg("robot", body_names="base"),
    #         "com_range": {"x": (-0.05, 0.05), "y": (-0.05, 0.05), "z": (-0.01, 0.01)},
    #     },
    # )
    # kept as an explicit ``None`` field so robot configs can still assign ``events.base_com``
    base_com = None

    # reset
    # external force/torque on the base; both ranges are zero in this base config
    base_external_force_torque = EventTerm(
        func=mdp.apply_external_force_torque,
        mode="reset",
        params={
            "asset_cfg": SceneEntityCfg("robot", body_names="base"),
            "force_range": (0.0, 0.0),
            "torque_range": (-0.0, 0.0),
        },
    )

    # uniform sampling of the root pose and velocity on reset
    reset_base = EventTerm(
        func=mdp.reset_root_state_uniform,
        mode="reset",
        params={
            "pose_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5), "yaw": (-3.14, 3.14)},
            "velocity_range": {
                "x": (-0.5, 0.5),
                "y": (-0.5, 0.5),
                "z": (-0.5, 0.5),
                "roll": (-0.5, 0.5),
                "pitch": (-0.5, 0.5),
                "yaw": (-0.5, 0.5),
            },
        },
    )

    # joint reset through ``reset_joints_by_scale`` with a (0.5, 1.5) position range and zero velocity range
    reset_robot_joints = EventTerm(
        func=mdp.reset_joints_by_scale,
        mode="reset",
        params={
            "position_range": (0.5, 1.5),
            "velocity_range": (0.0, 0.0),
        },
    )

    # interval
    # push event via ``push_by_setting_velocity``, sampled at an interval in (10.0, 15.0) s
    push_robot = EventTerm(
        func=mdp.push_by_setting_velocity,
        mode="interval",
        interval_range_s=(10.0, 15.0),
        params={"velocity_range": {"x": (-0.5, 0.5), "y": (-0.5, 0.5)}},
    )


@configclass
class RewardsCfg:
    """Reward terms for the MDP."""

    # -- task
    track_lin_vel_xy_exp = RewTerm(
        func=mdp.track_lin_vel_xy_exp, weight=1.0, params={"command_name": "base_velocity", "std": math.sqrt(0.25)}
    )
    track_ang_vel_z_exp = RewTerm(
        func=mdp.track_ang_vel_z_exp, weight=0.5, params={"command_name": "base_velocity", "std": math.sqrt(0.25)}
    )
    # -- penalties
    lin_vel_z_l2 = RewTerm(func=mdp.lin_vel_z_l2, weight=-2.0)
    ang_vel_xy_l2 = RewTerm(func=mdp.ang_vel_xy_l2, weight=-0.05)
    dof_torques_l2 = RewTerm(func=mdp.joint_torques_l2, weight=-1.0e-5)
    dof_acc_l2 = RewTerm(func=mdp.joint_acc_l2, weight=-2.5e-7)
    action_rate_l2 = RewTerm(func=mdp.action_rate_l2, weight=-0.01)
    # body-name patterns below (".*FOOT", ".*THIGH") are defaults that robot configs override
    feet_air_time = RewTerm(
        func=mdp.feet_air_time,
        weight=0.125,
        params={
            "sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*FOOT"),
            "command_name": "base_velocity",
            "threshold": 0.5,
        },
    )
    undesired_contacts = RewTerm(
        func=mdp.undesired_contacts,
        weight=-1.0,
        params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names=".*THIGH"), "threshold": 1.0},
    )
    # -- optional penalties (weight 0.0 here; robot configs may enable them)
    flat_orientation_l2 = RewTerm(func=mdp.flat_orientation_l2, weight=0.0)
    dof_pos_limits = RewTerm(func=mdp.joint_pos_limits, weight=0.0)


@configclass
class TerminationsCfg:
    """Termination terms for the MDP."""

    # episode timeout, flagged as a time-out rather than a failure
    time_out = DoneTerm(func=mdp.time_out, time_out=True)
    # contact on the "base" body above the force threshold
    base_contact = DoneTerm(
        func=mdp.illegal_contact,
        params={"sensor_cfg": SceneEntityCfg("contact_forces", body_names="base"), "threshold": 1.0},
    )


@configclass
class CurriculumCfg:
    """Curriculum terms for the MDP."""

    # terrain difficulty curriculum; setting this to None disables the generator curriculum
    # in LocomotionVelocityRoughEnvCfg.__post_init__
    terrain_levels = CurrTerm(func=mdp.terrain_levels_vel)


##
# Environment configuration
##


@configclass
class LocomotionVelocityRoughEnvCfg(ManagerBasedRLEnvCfg):
    """Configuration for the locomotion velocity-tracking environment."""

    # Scene settings
    scene: MySceneCfg = MySceneCfg(num_envs=4096, env_spacing=2.5, replicate_physics=True)
    # Basic settings
    observations: ObservationsCfg = ObservationsCfg()
    actions: ActionsCfg = ActionsCfg()
    commands: CommandsCfg = CommandsCfg()
    # MDP settings
    rewards: RewardsCfg = RewardsCfg()
    terminations: TerminationsCfg = TerminationsCfg()
    events: EventCfg = EventCfg()
    curriculum: CurriculumCfg = CurriculumCfg()

    def __post_init__(self):
        """Post initialization."""
        # general settings
        self.decimation = 4
        self.episode_length_s = 20.0
        # simulation settings
        self.sim.dt = 1.0 / 200.0
        self.sim.render_interval = self.decimation
        self.sim.physics_material = self.scene.terrain.physics_material
        # update sensor update periods
        # the contact sensor is updated at every physics step
        if self.scene.contact_forces is not None:
            self.scene.contact_forces.update_period = self.sim.dt
        # check if terrain levels curriculum is enabled
        # the generator's curriculum flag mirrors whether the ``terrain_levels`` term is set;
        # nothing is touched when no terrain generator is configured
        if getattr(self.curriculum, "terrain_levels", None) is not None:
            if self.scene.terrain.terrain_generator is not None:
                self.scene.terrain.terrain_generator.curriculum = True
        else:
            if self.scene.terrain.terrain_generator is not None:
                self.scene.terrain.terrain_generator.curriculum = False
+# +# SPDX-License-Identifier: BSD-3-Clause + +"""Reach experimental task registrations (manager-based).""" diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py new file mode 100644 index 00000000000..460a3056908 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py new file mode 100644 index 00000000000..b08612ccc74 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import gymnasium as gym + +# Reuse agent configs from the stable task package. +from isaaclab_tasks.manager_based.manipulation.reach.config.franka import agents + +## +# Register Gym environments. 
+## + +## +# Joint Position Control +## + +gym.register( + id="Isaac-Reach-Franka-Warp-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:FrankaReachEnvCfg", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:FrankaReachPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + }, +) + +gym.register( + id="Isaac-Reach-Franka-Warp-Play-v0", + entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", + disable_env_checker=True, + kwargs={ + "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:FrankaReachEnvCfg_PLAY", + "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", + "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:FrankaReachPPORunnerCfg", + "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", + }, +) diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py new file mode 100644 index 00000000000..5b3c7ec0c2f --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/franka/joint_pos_env_cfg.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+# +# SPDX-License-Identifier: BSD-3-Clause + +import math + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.manipulation.reach.mdp as mdp +from isaaclab_tasks_experimental.manager_based.manipulation.reach.reach_env_cfg import ReachEnvCfg + +## +# Pre-defined configs +## +from isaaclab_assets import FRANKA_PANDA_CFG # isort: skip + + +## +# Environment configuration +## + + +@configclass +class FrankaReachEnvCfg(ReachEnvCfg): + sim: SimulationCfg = SimulationCfg( + dt=1 / 60, + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=50, + nconmax=20, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ), + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # switch robot to franka + self.scene.robot = FRANKA_PANDA_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + # override rewards + self.rewards.end_effector_position_tracking.params["asset_cfg"].body_names = ["panda_hand"] + self.rewards.end_effector_position_tracking_fine_grained.params["asset_cfg"].body_names = ["panda_hand"] + self.rewards.end_effector_orientation_tracking.params["asset_cfg"].body_names = ["panda_hand"] + + # override actions + self.actions.arm_action = mdp.JointPositionActionCfg( + asset_name="robot", joint_names=["panda_joint.*"], scale=0.5, use_default_offset=True + ) + # override command generator body + # end-effector is along z-direction + self.commands.ee_pose.body_name = "panda_hand" + self.commands.ee_pose.ranges.pitch = (math.pi, math.pi) + + +@configclass +class FrankaReachEnvCfg_PLAY(FrankaReachEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption 
= False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py new file mode 100644 index 00000000000..85908c15805 --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +# UR10 env disabled: USD asset has composition errors (broken asset file). +# Fails on both torch baseline and warp with: +# RuntimeError: USD stage has composition errors while loading provided stage +# Re-enable once the UR10 USD asset is fixed. + +# import gymnasium as gym +# from isaaclab_tasks.manager_based.manipulation.reach.config.ur_10 import agents + +# gym.register( +# id="Isaac-Reach-UR10-Warp-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:UR10ReachEnvCfg", +# "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UR10ReachPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", +# }, +# ) + +# gym.register( +# id="Isaac-Reach-UR10-Warp-Play-v0", +# entry_point="isaaclab_experimental.envs:ManagerBasedRLEnvWarp", +# disable_env_checker=True, +# kwargs={ +# "env_cfg_entry_point": f"{__name__}.joint_pos_env_cfg:UR10ReachEnvCfg_PLAY", +# "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml", +# "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:UR10ReachPPORunnerCfg", +# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml", +# }, +# ) 
diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py new file mode 100644 index 00000000000..7eddda91f5a --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/config/ur_10/joint_pos_env_cfg.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +import math + +from isaaclab_newton.physics import MJWarpSolverCfg, NewtonCfg + +from isaaclab.sim import SimulationCfg +from isaaclab.utils import configclass + +import isaaclab_tasks_experimental.manager_based.manipulation.reach.mdp as mdp +from isaaclab_tasks_experimental.manager_based.manipulation.reach.reach_env_cfg import ReachEnvCfg + +## +# Pre-defined configs +## +from isaaclab_assets import UR10_CFG # isort: skip + + +## +# Environment configuration +## + + +@configclass +class UR10ReachEnvCfg(ReachEnvCfg): + sim: SimulationCfg = SimulationCfg( + physics=NewtonCfg( + solver_cfg=MJWarpSolverCfg( + njmax=50, + nconmax=20, + cone="pyramidal", + impratio=1, + integrator="implicitfast", + ), + num_substeps=1, + debug_mode=False, + ) + ) + + def __post_init__(self): + # post init of parent + super().__post_init__() + + # switch robot to ur10 + self.scene.robot = UR10_CFG.replace(prim_path="{ENV_REGEX_NS}/Robot") + # override events + self.events.reset_robot_joints.params["position_range"] = (0.75, 1.25) + # override rewards + self.rewards.end_effector_position_tracking.params["asset_cfg"].body_names = ["ee_link"] + self.rewards.end_effector_position_tracking_fine_grained.params["asset_cfg"].body_names = ["ee_link"] + 
self.rewards.end_effector_orientation_tracking.params["asset_cfg"].body_names = ["ee_link"] + # override actions + self.actions.arm_action = mdp.JointPositionActionCfg( + asset_name="robot", joint_names=[".*"], scale=0.5, use_default_offset=True + ) + # override command generator body + # end-effector is along x-direction + self.commands.ee_pose.body_name = "ee_link" + self.commands.ee_pose.ranges.pitch = (math.pi / 2, math.pi / 2) + + +@configclass +class UR10ReachEnvCfg_PLAY(UR10ReachEnvCfg): + def __post_init__(self): + # post init of parent + super().__post_init__() + # make a smaller scene for play + self.scene.num_envs = 50 + self.scene.env_spacing = 2.5 + # disable randomization for play + self.observations.policy.enable_corruption = False diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py new file mode 100644 index 00000000000..b0845d6735b --- /dev/null +++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""This sub-module contains the functions that are specific to the reach environments."""
+
+from isaaclab_experimental.envs.mdp import *  # noqa: F401, F403
+
+from .rewards import *  # noqa: F401, F403
diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py
new file mode 100644
index 00000000000..811163ec973
--- /dev/null
+++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/mdp/rewards.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Warp-first reward terms for the reach task.
+
+All functions follow the ``func(env, out, **params) -> None`` signature.
+Command tensors are cached per environment as zero-copy warp views on first call.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import warp as wp
+
+from isaaclab.assets import Articulation
+from isaaclab.managers import SceneEntityCfg
+
+if TYPE_CHECKING:
+    from isaaclab.envs import ManagerBasedRLEnv
+
+
+# ---------------------------------------------------------------------------
+# command view cache
+# ---------------------------------------------------------------------------
+
+# Attribute set on the environment to hold its ``{command_name: wp.array}`` views.
+_CMD_VIEW_CACHE_ATTR = "_reach_reward_cmd_views"
+
+
+def _command_as_warp(env: ManagerBasedRLEnv, command_name: str) -> wp.array:
+    """Return the command ``command_name`` as a warp array, cached on ``env``.
+
+    The view is stored on the environment instance (one entry per command name)
+    instead of on the reward function, so a second environment created in the same
+    process, or a second term that tracks a different command, never reuses a view
+    that was resolved for something else.
+
+    Args:
+        env: Environment whose command manager provides the command.
+        command_name: Name of the command term to look up.
+
+    Returns:
+        The object returned by the command manager if it already is a ``wp.array``,
+        otherwise the result of ``wp.from_torch`` applied to it.
+    """
+    cache = getattr(env, _CMD_VIEW_CACHE_ATTR, None)
+    if cache is None:
+        cache = {}
+        setattr(env, _CMD_VIEW_CACHE_ATTR, cache)
+    cmd_wp = cache.get(command_name)
+    if cmd_wp is None:
+        cmd = env.command_manager.get_command(command_name)
+        cmd_wp = cmd if isinstance(cmd, wp.array) else wp.from_torch(cmd)
+        cache[command_name] = cmd_wp
+    return cmd_wp
+
+
+# ---------------------------------------------------------------------------
+# position_command_error
+# ---------------------------------------------------------------------------
+
+
+@wp.kernel
+def _position_command_error_kernel(
+    root_pos_w: wp.array(dtype=wp.vec3f),
+    root_quat_w: wp.array(dtype=wp.quatf),
+    body_pos_w: wp.array(dtype=wp.vec3f, ndim=2),
+    cmd: wp.array(dtype=wp.float32, ndim=2),
+    body_idx: int,
+    out: wp.array(dtype=wp.float32),
+):
+    i = wp.tid()
+    # desired position in body frame -> world frame
+    des_b = wp.vec3f(cmd[i, 0], cmd[i, 1], cmd[i, 2])
+    des_w = root_pos_w[i] + wp.quat_rotate(root_quat_w[i], des_b)
+    # current end-effector position
+    cur_w = body_pos_w[i, body_idx]
+    dx = cur_w[0] - des_w[0]
+    dy = cur_w[1] - des_w[1]
+    dz = cur_w[2] - des_w[2]
+    out[i] = wp.sqrt(dx * dx + dy * dy + dz * dz)
+
+
+def position_command_error(env: ManagerBasedRLEnv, out, command_name: str, asset_cfg: SceneEntityCfg) -> None:
+    """Penalize tracking of the position error using L2-norm.
+
+    Writes, for every environment, the Euclidean distance between the commanded
+    position (first three command components, rotated by the root orientation and
+    offset by the root position) and the position of body ``asset_cfg.body_ids[0]``.
+
+    Args:
+        env: Environment providing the scene and the command manager.
+        out: Warp float array with one entry per environment; overwritten in place.
+        command_name: Name of the pose command term.
+        asset_cfg: Asset to read; only the first resolved body id is used.
+    """
+    asset: Articulation = env.scene[asset_cfg.name]
+    wp.launch(
+        kernel=_position_command_error_kernel,
+        dim=env.num_envs,
+        inputs=[
+            asset.data.root_pos_w,
+            asset.data.root_quat_w,
+            asset.data.body_pos_w,
+            _command_as_warp(env, command_name),
+            asset_cfg.body_ids[0],
+            out,
+        ],
+        device=env.device,
+    )
+
+
+# ---------------------------------------------------------------------------
+# position_command_error_tanh
+# ---------------------------------------------------------------------------
+
+
+@wp.kernel
+def _position_command_error_tanh_kernel(
+    root_pos_w: wp.array(dtype=wp.vec3f),
+    root_quat_w: wp.array(dtype=wp.quatf),
+    body_pos_w: wp.array(dtype=wp.vec3f, ndim=2),
+    cmd: wp.array(dtype=wp.float32, ndim=2),
+    body_idx: int,
+    inv_std: float,
+    out: wp.array(dtype=wp.float32),
+):
+    i = wp.tid()
+    # same world-frame position error as the L2 term above
+    des_b = wp.vec3f(cmd[i, 0], cmd[i, 1], cmd[i, 2])
+    des_w = root_pos_w[i] + wp.quat_rotate(root_quat_w[i], des_b)
+    cur_w = body_pos_w[i, body_idx]
+    dx = cur_w[0] - des_w[0]
+    dy = cur_w[1] - des_w[1]
+    dz = cur_w[2] - des_w[2]
+    dist = wp.sqrt(dx * dx + dy * dy + dz * dz)
+    out[i] = 1.0 - wp.tanh(dist * inv_std)
+
+
+def position_command_error_tanh(
+    env: ManagerBasedRLEnv, out, std: float, command_name: str, asset_cfg: SceneEntityCfg
+) -> None:
+    """Reward tracking of the position using the tanh kernel.
+
+    Writes ``1 - tanh(distance / std)`` per environment, where ``distance`` is the same
+    world-frame position error that :func:`position_command_error` computes.
+
+    Args:
+        env: Environment providing the scene and the command manager.
+        out: Warp float array with one entry per environment; overwritten in place.
+        std: Length scale dividing the distance; must be non-zero (it is inverted before launch).
+        command_name: Name of the pose command term.
+        asset_cfg: Asset to read; only the first resolved body id is used.
+    """
+    asset: Articulation = env.scene[asset_cfg.name]
+    wp.launch(
+        kernel=_position_command_error_tanh_kernel,
+        dim=env.num_envs,
+        inputs=[
+            asset.data.root_pos_w,
+            asset.data.root_quat_w,
+            asset.data.body_pos_w,
+            _command_as_warp(env, command_name),
+            asset_cfg.body_ids[0],
+            1.0 / std,
+            out,
+        ],
+        device=env.device,
+    )
+
+
+# ---------------------------------------------------------------------------
+# orientation_command_error
+# ---------------------------------------------------------------------------
+
+
+@wp.kernel
+def _orientation_command_error_kernel(
+    root_quat_w: wp.array(dtype=wp.quatf),
+    body_quat_w: wp.array(dtype=wp.quatf, ndim=2),
+    cmd: wp.array(dtype=wp.float32, ndim=2),
+    body_idx: int,
+    out: wp.array(dtype=wp.float32),
+):
+    i = wp.tid()
+    # desired quat in body frame -> world frame: q_des_w = q_root * q_des_b
+    des_b = wp.quatf(cmd[i, 3], cmd[i, 4], cmd[i, 5], cmd[i, 6])
+    des_w = root_quat_w[i] * des_b
+    # current ee orientation
+    cur_w = body_quat_w[i, body_idx]
+    # shortest-path error: angle of q_err = cur^-1 * des
+    q_err = wp.quat_inverse(cur_w) * des_w
+    # error magnitude = 2 * acos(|w|) (w component of the error quaternion)
+    qw = wp.abs(q_err[3])
+    qw = wp.clamp(qw, 0.0, 1.0)
+    out[i] = 2.0 * wp.acos(qw)
+
+
+def orientation_command_error(env: ManagerBasedRLEnv, out, command_name: str, asset_cfg: SceneEntityCfg) -> None:
+    """Penalize tracking orientation error using shortest path.
+
+    Writes, per environment, the rotation angle in radians between the commanded
+    orientation (command components 3..6, composed with the root orientation) and
+    the orientation of body ``asset_cfg.body_ids[0]``.
+
+    Args:
+        env: Environment providing the scene and the command manager.
+        out: Warp float array with one entry per environment; overwritten in place.
+        command_name: Name of the pose command term.
+        asset_cfg: Asset to read; only the first resolved body id is used.
+    """
+    asset: Articulation = env.scene[asset_cfg.name]
+    wp.launch(
+        kernel=_orientation_command_error_kernel,
+        dim=env.num_envs,
+        inputs=[asset.data.root_quat_w, asset.data.body_quat_w, _command_as_warp(env, command_name), asset_cfg.body_ids[0], out],
+        device=env.device,
+    )
diff --git a/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py
new file mode 100644
index 00000000000..d019b053140
--- /dev/null
+++ b/source/isaaclab_tasks_experimental/isaaclab_tasks_experimental/manager_based/manipulation/reach/reach_env_cfg.py
@@ -0,0 +1,206 @@
+# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+# +# SPDX-License-Identifier: BSD-3-Clause + +from dataclasses import MISSING + +from isaaclab_experimental.managers import ObservationTermCfg as ObsTerm +from isaaclab_experimental.managers import RewardTermCfg as RewTerm +from isaaclab_experimental.managers import SceneEntityCfg +from isaaclab_experimental.managers import TerminationTermCfg as DoneTerm + +import isaaclab.sim as sim_utils +from isaaclab.assets import ArticulationCfg, AssetBaseCfg +from isaaclab.envs import ManagerBasedRLEnvCfg +from isaaclab.managers import ActionTermCfg as ActionTerm +from isaaclab.managers import CurriculumTermCfg as CurrTerm +from isaaclab.managers import EventTermCfg as EventTerm +from isaaclab.managers import ObservationGroupCfg as ObsGroup +from isaaclab.scene import InteractiveSceneCfg +from isaaclab.utils import configclass +from isaaclab.utils.noise import UniformNoiseCfg as Unoise + +import isaaclab_tasks_experimental.manager_based.manipulation.reach.mdp as mdp + +## +# Scene definition +## + + +@configclass +class ReachSceneCfg(InteractiveSceneCfg): + """Configuration for the scene with a robotic arm.""" + + # world + ground = AssetBaseCfg( + prim_path="/World/ground", + spawn=sim_utils.GroundPlaneCfg(), + init_state=AssetBaseCfg.InitialStateCfg(pos=(0.0, 0.0, -1.05)), + ) + + # robots + robot: ArticulationCfg = MISSING + + # lights + light = AssetBaseCfg( + prim_path="/World/light", + spawn=sim_utils.DomeLightCfg(color=(0.75, 0.75, 0.75), intensity=2500.0), + ) + + +## +# MDP settings +## + + +@configclass +class CommandsCfg: + """Command terms for the MDP.""" + + ee_pose = mdp.UniformPoseCommandCfg( + asset_name="robot", + body_name=MISSING, + resampling_time_range=(4.0, 4.0), + debug_vis=True, + ranges=mdp.UniformPoseCommandCfg.Ranges( + pos_x=(0.35, 0.65), + pos_y=(-0.2, 0.2), + pos_z=(0.15, 0.5), + roll=(0.0, 0.0), + pitch=MISSING, # depends on end-effector axis + yaw=(-3.14, 3.14), + ), + ) + + +@configclass +class ActionsCfg: + """Action specifications for the 
MDP.""" + + arm_action: ActionTerm = MISSING + gripper_action: ActionTerm | None = None + + +@configclass +class ObservationsCfg: + """Observation specifications for the MDP.""" + + @configclass + class PolicyCfg(ObsGroup): + """Observations for policy group.""" + + # observation terms (order preserved) + joint_pos = ObsTerm(func=mdp.joint_pos_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) + joint_vel = ObsTerm(func=mdp.joint_vel_rel, noise=Unoise(n_min=-0.01, n_max=0.01)) + pose_command = ObsTerm(func=mdp.generated_commands, params={"command_name": "ee_pose"}) + actions = ObsTerm(func=mdp.last_action) + + def __post_init__(self): + self.enable_corruption = True + self.concatenate_terms = True + + # observation groups + policy: PolicyCfg = PolicyCfg() + + +@configclass +class EventCfg: + """Configuration for events.""" + + reset_base = EventTerm( + func=mdp.reset_root_state_uniform, + mode="reset", + params={"pose_range": {}, "velocity_range": {}}, + ) + + reset_robot_joints = EventTerm( + func=mdp.reset_joints_by_scale, + mode="reset", + params={ + "position_range": (0.5, 1.5), + "velocity_range": (0.0, 0.0), + }, + ) + + +@configclass +class RewardsCfg: + """Reward terms for the MDP.""" + + # task terms + end_effector_position_tracking = RewTerm( + func=mdp.position_command_error, + weight=-0.2, + params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "command_name": "ee_pose"}, + ) + end_effector_position_tracking_fine_grained = RewTerm( + func=mdp.position_command_error_tanh, + weight=0.1, + params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "std": 0.1, "command_name": "ee_pose"}, + ) + end_effector_orientation_tracking = RewTerm( + func=mdp.orientation_command_error, + weight=-0.1, + params={"asset_cfg": SceneEntityCfg("robot", body_names=MISSING), "command_name": "ee_pose"}, + ) + + # action penalty + action_rate = RewTerm(func=mdp.action_rate_l2, weight=-0.0001) + joint_vel = RewTerm( + func=mdp.joint_vel_l2, + weight=-0.0001, + 
params={"asset_cfg": SceneEntityCfg("robot")}, + ) + + +@configclass +class TerminationsCfg: + """Termination terms for the MDP.""" + + time_out = DoneTerm(func=mdp.time_out, time_out=True) + + +@configclass +class CurriculumCfg: + """Curriculum terms for the MDP.""" + + action_rate = CurrTerm( + func=mdp.modify_reward_weight, params={"term_name": "action_rate", "weight": -0.005, "num_steps": 4500} + ) + + joint_vel = CurrTerm( + func=mdp.modify_reward_weight, params={"term_name": "joint_vel", "weight": -0.001, "num_steps": 4500} + ) + + +## +# Environment configuration +## + + +@configclass +class ReachEnvCfg(ManagerBasedRLEnvCfg): + """Configuration for the reach end-effector pose tracking environment.""" + + # Scene settings + scene: ReachSceneCfg = ReachSceneCfg(num_envs=4096, env_spacing=2.5) + # Basic settings + observations: ObservationsCfg = ObservationsCfg() + actions: ActionsCfg = ActionsCfg() + commands: CommandsCfg = CommandsCfg() + # MDP settings + rewards: RewardsCfg = RewardsCfg() + terminations: TerminationsCfg = TerminationsCfg() + events: EventCfg = EventCfg() + curriculum: CurriculumCfg = CurriculumCfg() + + def __post_init__(self): + """Post initialization.""" + # general settings + self.decimation = 2 + self.sim.render_interval = self.decimation + self.episode_length_s = 12.0 + self.viewer.eye = (3.5, 3.5, 3.5) + # simulation settings + self.sim.dt = 1.0 / 60.0 From a9ece43e1ed230c7d7ff5f1b6b3fa7693c3626a9 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Thu, 12 Mar 2026 00:31:39 -0700 Subject: [PATCH 6/7] Add warp MDP parity tests Add parity test infrastructure and tests organized by MDP category: observations, rewards, terminations, events, and actions. Tests verify warp-first implementations match stable torch baselines across three execution modes (stable, warp uncaptured, warp CUDA-graph captured). Extract shared fixtures and mock objects to parity_helpers.py. 
---
 .../test/envs/mdp/parity_helpers.py           | 590 ++++++++++++++++++
 .../test/envs/mdp/test_actions_warp_parity.py | 227 +++++++
 .../test/envs/mdp/test_events_warp_parity.py  | 340 ++++++++++
 .../envs/mdp/test_observations_warp_parity.py | 458 ++++++++++++++
 .../test/envs/mdp/test_rewards_warp_parity.py | 567 +++++++++++++++++
 .../envs/mdp/test_terminations_warp_parity.py | 349 +++++++++++
 6 files changed, 2531 insertions(+)
 create mode 100644 source/isaaclab_experimental/test/envs/mdp/parity_helpers.py
 create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_actions_warp_parity.py
 create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_events_warp_parity.py
 create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_observations_warp_parity.py
 create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_rewards_warp_parity.py
 create mode 100644 source/isaaclab_experimental/test/envs/mdp/test_terminations_warp_parity.py

diff --git a/source/isaaclab_experimental/test/envs/mdp/parity_helpers.py b/source/isaaclab_experimental/test/envs/mdp/parity_helpers.py
new file mode 100644
index 00000000000..607039bbd85
--- /dev/null
+++ b/source/isaaclab_experimental/test/envs/mdp/parity_helpers.py
@@ -0,0 +1,590 @@
+# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
+# All rights reserved.
+#
+# SPDX-License-Identifier: BSD-3-Clause
+
+"""Shared test utilities for MDP warp-vs-stable parity tests.
+
+Contains constants, assertion helpers, warp kernel runners, mock classes,
+numpy math utilities, and mutation helpers used by the observation, reward,
+termination, event, and action parity test files.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+import warp as wp
+
+# ---------------------------------------------------------------------------
+# Constants (shared across all MDP parity test files)
+# ---------------------------------------------------------------------------
+NUM_ENVS = 64
+NUM_JOINTS = 12
+NUM_ACTIONS = 6
+DEVICE = "cuda:0"
+ATOL = 1e-5
+RTOL = 1e-5
+
+# Body/sensor-level defaults shared by observation, reward, and termination tests
+NUM_BODIES = 4
+NUM_HISTORY = 3
+CMD_DIM = 3
+BODY_IDS = [0, 2]
+
+# Gravity direction constant (normalized, same as ArticulationData.GRAVITY_VEC_W)
+GRAVITY_DIR_NP = np.array([[0.0, 0.0, -1.0]], dtype=np.float32)
+
+
+# ---------------------------------------------------------------------------
+# Numpy math utilities
+# ---------------------------------------------------------------------------
+
+
+def quat_rotate_inv_np(q_xyzw: np.ndarray, v: np.ndarray) -> np.ndarray:
+    """Apply inverse quaternion rotation to vectors (numpy, batch).
+
+    Equivalent to ``wp.quat_rotate_inv`` — rotates *v* by the conjugate of *q*.
+
+    Args:
+        q_xyzw: (N, 4) quaternion array in [x, y, z, w] order (warp convention).
+        v: (N, 3) vector array.
+
+    Returns:
+        (N, 3) rotated vectors in float32.
+    """
+    qv = -q_xyzw[..., :3]  # conjugate xyz
+    qw = q_xyzw[..., 3:4]
+    t = 2.0 * np.cross(qv, v)
+    return (v + qw * t + np.cross(qv, t)).astype(np.float32)
+
+
+# ---------------------------------------------------------------------------
+# Warp / numpy utilities
+# ---------------------------------------------------------------------------
+
+
+def copy_np_to_wp(dest: wp.array, src_np: np.ndarray):
+    """In-place overwrite of a warp array's contents from numpy (preserves pointer)."""
+    tmp = wp.array(src_np, dtype=dest.dtype, device=str(dest.device))
+    wp.copy(dest, tmp)
+
+
+# ---------------------------------------------------------------------------
+# Test runner helpers
+# ---------------------------------------------------------------------------
+
+
+def run_warp_obs(func, env, shape, device=DEVICE, **kwargs):
+    """Run a warp observation function and return the result as a torch tensor."""
+    out = wp.zeros(shape, dtype=wp.float32, device=device)
+    func(env, out, **kwargs)
+    return wp.to_torch(out).clone()
+
+
+def run_warp_obs_captured(func, env, shape, device=DEVICE, **kwargs):
+    """Run a warp observation function under CUDA graph capture on ``device`` and return the result."""
+    out = wp.zeros(shape, dtype=wp.float32, device=device)
+    func(env, out, **kwargs)  # warm-up
+    with wp.ScopedCapture(device=device) as capture:
+        func(env, out, **kwargs)
+    wp.capture_launch(capture.graph)
+    return wp.to_torch(out).clone()
+
+
+def run_warp_rew(func, env, device=DEVICE, **kwargs):
+    """Run a warp reward function and return the result as a torch tensor."""
+    out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device)
+    func(env, out, **kwargs)
+    return wp.to_torch(out).clone()
+
+
+def run_warp_rew_captured(func, env, device=DEVICE, **kwargs):
+    """Run a warp reward function under CUDA graph capture on ``device``."""
+    out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=device)
+    func(env, out, **kwargs)  # warm-up
+    with wp.ScopedCapture(device=device) as capture:
+        func(env, out, **kwargs)
+    wp.capture_launch(capture.graph)
+    return wp.to_torch(out).clone()
+
+
+def run_warp_term(func, env, device=DEVICE, **kwargs):
+    """Run a warp termination function and return the result as a torch tensor."""
+    out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device)
+    func(env, out, **kwargs)
+    return wp.to_torch(out).clone()
+
+
+def run_warp_term_captured(func, env, device=DEVICE, **kwargs):
+    """Run a warp termination function under CUDA graph capture on ``device``."""
+    out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=device)
+    func(env, out, **kwargs)  # warm-up
+    with wp.ScopedCapture(device=device) as capture:
+        func(env, out, **kwargs)
+    wp.capture_launch(capture.graph)
+    return wp.to_torch(out).clone()
+
+
+# ---------------------------------------------------------------------------
+# Assertion helpers
+# ---------------------------------------------------------------------------
+
+
+def assert_close(actual: torch.Tensor, expected: torch.Tensor, atol: float = ATOL, rtol: float = RTOL):
+    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
+
+
+def assert_equal(actual: torch.Tensor, expected: torch.Tensor):
+    assert torch.equal(actual, expected), f"Mismatch:\n actual: {actual}\n expected: {expected}"
+
+
+# ---------------------------------------------------------------------------
+# Mock classes (shared across parity test files)
+# ---------------------------------------------------------------------------
+
+
+class MockArticulationData:
+    """Mock articulation data backed by Warp arrays (same storage Newton uses).
+
+    Args:
+        num_envs: Number of environments.
+        num_joints: Number of joints.
+        device: Warp device string.
+        seed: Random seed for reproducibility.
+        num_bodies: Number of bodies. When > 0, generates body-level arrays
+            (body_pose_w, body_lin_acc_w, body_com_pos_b) and multi-body
+            projected_gravity_b. When 0, projected_gravity_b is root-level
+            (derived from root quaternion).
+ """ + + def __init__(self, num_envs=NUM_ENVS, num_joints=NUM_JOINTS, device=DEVICE, seed=42, num_bodies=0): + rng = np.random.RandomState(seed) + + # --- Joint state (float32 2D) --- + self.joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32), device=device) + self.joint_vel = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 2.0, device=device) + self.joint_acc = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.5, device=device) + self.default_joint_pos = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 0.01, device=device) + self.default_joint_vel = wp.array(np.zeros((num_envs, num_joints), dtype=np.float32), device=device) + self.applied_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) + self.computed_torque = wp.array(rng.randn(num_envs, num_joints).astype(np.float32) * 10.0, device=device) + + # --- Soft joint limits --- + limits_np = np.zeros((num_envs, num_joints, 2), dtype=np.float32) + limits_np[:, :, 0] = -3.14 + limits_np[:, :, 1] = 3.14 + self.soft_joint_pos_limits = wp.array(limits_np, dtype=wp.vec2f, device=device) + self.soft_joint_vel_limits = wp.array(np.full((num_envs, num_joints), 10.0, dtype=np.float32), device=device) + + # --- Root state --- + root_pos_np = rng.randn(num_envs, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.1 # positive heights + self.root_pos_w = wp.array(root_pos_np, dtype=wp.vec3f, device=device) + + # Unit quaternions + quat_np = rng.randn(num_envs, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + self.root_quat_w = wp.array(quat_np, dtype=wp.quatf, device=device) + + # Tier 1 compound: root_link_pose_w (transformf = pos + quat) + pose_np = np.zeros((num_envs, 7), dtype=np.float32) + pose_np[:, :3] = root_pos_np + pose_np[:, 3:] = quat_np + self.root_link_pose_w = wp.array(pose_np, dtype=wp.transformf, device=device) + + # World-frame velocities + lin_vel_w_np 
= rng.randn(num_envs, 3).astype(np.float32) + ang_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + self.root_lin_vel_w = wp.array(lin_vel_w_np, dtype=wp.vec3f, device=device) + self.root_ang_vel_w = wp.array(ang_vel_w_np, dtype=wp.vec3f, device=device) + + # Tier 1 compound: root_com_vel_w (spatial_vectorf: top=linear, bottom=angular) + vel_np = np.zeros((num_envs, 6), dtype=np.float32) + vel_np[:, :3] = lin_vel_w_np + vel_np[:, 3:] = ang_vel_w_np + self.root_com_vel_w = wp.array(vel_np, dtype=wp.spatial_vectorf, device=device) + + # Gravity direction constant + self.GRAVITY_VEC_W = wp.vec3f(0.0, 0.0, -1.0) + + # Derived body-frame quantities (consistent with Tier 1 compounds) + self.root_lin_vel_b = wp.array(quat_rotate_inv_np(quat_np, lin_vel_w_np), dtype=wp.vec3f, device=device) + self.root_ang_vel_b = wp.array(quat_rotate_inv_np(quat_np, ang_vel_w_np), dtype=wp.vec3f, device=device) + + # --- projected_gravity_b and body-level data --- + if num_bodies > 0: + # Multi-body projected_gravity_b: (num_envs, num_bodies) vec3f + grav_np = rng.randn(num_envs, num_bodies, 3).astype(np.float32) + grav_np[:, :, 2] = -1.0 + grav_np /= np.linalg.norm(grav_np, axis=2, keepdims=True) + self.projected_gravity_b = wp.array(grav_np, dtype=wp.vec3f, device=device) + + # body_pose_w: (num_envs, num_bodies) transformf + bpose_np = np.zeros((num_envs, num_bodies, 7), dtype=np.float32) + bpose_np[:, :, :3] = rng.randn(num_envs, num_bodies, 3).astype(np.float32) + bpose_np[:, :, 3:7] = [0.0, 0.0, 0.0, 1.0] + self.body_pose_w = wp.array(bpose_np, dtype=wp.transformf, device=device) + + # body_lin_acc_w: (num_envs, num_bodies) vec3f + self.body_lin_acc_w = wp.array( + rng.randn(num_envs, num_bodies, 3).astype(np.float32), dtype=wp.vec3f, device=device + ) + + # body_com_pos_b: (num_envs, num_bodies) vec3f + self.body_com_pos_b = wp.array( + rng.randn(num_envs, num_bodies, 3).astype(np.float32) * 0.01, dtype=wp.vec3f, device=device + ) + else: + # Root-level projected_gravity_b: 
def resolve_joint_mask(self, joint_ids=None):
    """Return a boolean warp mask with one entry per joint.

    ``None`` or any ``slice`` selects every joint (the slice bounds are not
    inspected). Otherwise each index listed in ``joint_ids`` is switched on.
    The mask is allocated on the same device as ``joint_pos``.
    """
    joint_count = self.joint_pos.shape[1]
    select_all = joint_ids is None or isinstance(joint_ids, slice)
    if select_all:
        flags = [True] * joint_count
    else:
        flags = [False] * joint_count
        for joint_idx in joint_ids:
            flags[joint_idx] = True
    target_device = str(self.joint_pos.device)
    return wp.array(flags, dtype=wp.bool, device=target_device)
def find_joints(self, names, preserve_order=False):
    """Resolve joint name patterns to ``(indices, names)``.

    Matching is a plain substring test against the mock joint names, with
    ``".*"`` treated as "match everything". Results follow the order in
    which patterns are given and never contain duplicates. When nothing
    matches, every joint is returned. ``preserve_order`` is accepted for
    signature compatibility and is not consulted.
    """
    every_index = list(range(self.num_joints))
    # Fast path for the common "all joints" request.
    if isinstance(names, list) and names == [".*"]:
        return every_index, list(self._joint_names)

    patterns = names if isinstance(names, list) else [names]
    matched_ids = []
    matched_names = []
    for pattern in patterns:
        for index, joint_name in enumerate(self._joint_names):
            hit = pattern in joint_name or pattern == joint_name or pattern == ".*"
            if hit and index not in matched_ids:
                matched_ids.append(index)
                matched_names.append(joint_name)

    if matched_ids:
        return matched_ids, matched_names
    # No pattern matched anything: fall back to the full joint set.
    return every_index, list(self._joint_names)
class MockScene:
    """Mock scene with asset lookup, env origins, and optional sensors.

    Args:
        assets: Mapping from asset name to mock asset. Items are returned by
            ``scene[name]``; a shallow copy is exposed as ``articulations``.
        env_origins: Per-environment origins, stored unchanged.
        sensors: Optional mapping from sensor name to mock sensor. ``None``
            yields a new empty dict; any mapping passed in is kept as-is.
        num_envs: Number of environments the scene reports. ``None`` (the
            default) falls back to the module-level ``NUM_ENVS`` so existing
            callers are unaffected; test modules that use a different
            environment count can pass it explicitly.
    """

    def __init__(self, assets: dict, env_origins, sensors=None, num_envs=None):
        self._assets = assets
        self.env_origins = env_origins
        # Only substitute a dict when none was given, so a caller-owned
        # (possibly still empty) dict keeps its identity.
        self.sensors = {} if sensors is None else sensors
        self.articulations = dict(assets)
        self.rigid_objects = {}
        self.num_envs = NUM_ENVS if num_envs is None else num_envs

    def __getitem__(self, name: str):
        """Return the asset registered under ``name`` (``KeyError`` if absent)."""
        return self._assets[name]
+ """ + root_pos_np = rng.randn(num_envs, 3).astype(np.float32) + root_pos_np[:, 2] = np.abs(root_pos_np[:, 2]) + 0.05 + copy_np_to_wp(art_data.root_pos_w, root_pos_np) + + quat_np = rng.randn(num_envs, 4).astype(np.float32) + quat_np /= np.linalg.norm(quat_np, axis=1, keepdims=True) + copy_np_to_wp(art_data.root_quat_w, quat_np) + + pose_np = np.zeros((num_envs, 7), dtype=np.float32) + pose_np[:, :3] = root_pos_np + pose_np[:, 3:] = quat_np + copy_np_to_wp(art_data.root_link_pose_w, pose_np) + + lin_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + ang_vel_w_np = rng.randn(num_envs, 3).astype(np.float32) + copy_np_to_wp(art_data.root_lin_vel_w, lin_vel_w_np) + copy_np_to_wp(art_data.root_ang_vel_w, ang_vel_w_np) + + vel_np = np.zeros((num_envs, 6), dtype=np.float32) + vel_np[:, :3] = lin_vel_w_np + vel_np[:, 3:] = ang_vel_w_np + copy_np_to_wp(art_data.root_com_vel_w, vel_np) + + copy_np_to_wp(art_data.root_lin_vel_b, quat_rotate_inv_np(quat_np, lin_vel_w_np)) + copy_np_to_wp(art_data.root_ang_vel_b, quat_rotate_inv_np(quat_np, ang_vel_w_np)) + + # Root-level projected_gravity_b (1D) is derived from quat. + # Multi-body (2D) is mutated separately by callers. 
class MockSceneEntityCfg:
    """Scene-entity cfg stand-in usable by stable and experimental terms alike.

    ``joint_ids`` is stored as given (the stable-side field). The same
    selection is also exposed as warp arrays: ``joint_mask`` (one boolean per
    joint) and ``joint_ids_wp`` (int32 indices), both on ``device``.
    """

    def __init__(self, name: str, joint_ids: list[int], num_joints: int, device: str):
        self.name = name
        self.joint_ids = joint_ids

        # Warp-side views of the same joint selection.
        selected = [False] * num_joints
        for joint_idx in joint_ids:
            selected[joint_idx] = True
        self.joint_mask = wp.array(selected, dtype=wp.bool, device=device)
        self.joint_ids_wp = wp.array(joint_ids, dtype=wp.int32, device=device)
class MockBodyCfg:
    """SceneEntityCfg-like object for body-level reward/termination terms.

    Carries only the asset ``name`` and the ``body_ids`` selection. When no
    ids are supplied, a new list built from ``BODY_IDS`` is used.
    """

    def __init__(self, name="robot", body_ids=None):
        if body_ids is None:
            # Build a fresh list so instances do not share the default.
            body_ids = list(BODY_IDS)
        self.body_ids = body_ids
        self.name = name
def mutate_art_data(
    art_data: MockArticulationData,
    warp_env,
    num_envs: int = NUM_ENVS,
    num_joints: int = NUM_JOINTS,
    num_actions: int = NUM_ACTIONS,
    rng_seed: int = 200,
):
    """Overwrite joint, root, action and episode-length buffers in-place.

    All writes reuse the existing buffers, so CUDA graphs captured against
    them observe the new values on replay. Every value, including the
    episode lengths, is drawn from a single ``RandomState(rng_seed)``, so
    two calls with the same seed produce identical data.

    Args:
        art_data: Articulation data whose joint and root arrays are rewritten.
        warp_env: Env whose ``action_manager._action`` / ``_prev_action`` and
            ``episode_length_buf`` are rewritten.
        num_envs: Number of environments (leading dimension of every array).
        num_joints: Number of joints per environment.
        num_actions: Action dimension per environment.
        rng_seed: Seed for the NumPy generator driving all draws.
    """
    rng = np.random.RandomState(rng_seed)

    # Joint-level arrays, each with its own scale. The order is significant
    # because every entry consumes draws from the shared generator.
    joint_scales = (
        ("joint_pos", 1.5),
        ("joint_vel", 3.0),
        ("joint_acc", 0.8),
        ("default_joint_pos", 0.02),
        ("applied_torque", 12.0),
        ("computed_torque", 12.0),
    )
    for field_name, scale in joint_scales:
        values = rng.randn(num_envs, num_joints).astype(np.float32) * scale
        copy_np_to_wp(getattr(art_data, field_name), values)

    mutate_root_state(rng, art_data, num_envs)

    copy_np_to_wp(warp_env.action_manager._action, rng.randn(num_envs, num_actions).astype(np.float32))
    copy_np_to_wp(warp_env.action_manager._prev_action, rng.randn(num_envs, num_actions).astype(np.float32))

    # Draw episode lengths from the same seeded generator (previously the
    # unseeded global torch RNG) and write them on the buffer's own device
    # rather than a hard-coded one. Slice assignment keeps the storage alive.
    episode_buf = warp_env.episode_length_buf
    episode_lengths = rng.randint(0, 500, size=(num_envs,)).astype(np.int64)
    episode_buf[:] = torch.from_numpy(episode_lengths).to(device=episode_buf.device, dtype=episode_buf.dtype)

    wp.synchronize()
@pytest.fixture()
def art_data():
    """Articulation data with fixed per-joint defaults and identity body quaternions."""
    mock_data = MockArticulationData(num_envs=NUM_ENVS, num_joints=NUM_JOINTS, num_bodies=NUM_BODIES)

    # Give every joint a distinct, known default position so that
    # default-offset arithmetic can be checked exactly.
    per_joint_defaults = np.tile([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], (NUM_ENVS, 1)).astype(np.float32)
    copy_np_to_wp(mock_data.default_joint_pos, per_joint_defaults)

    # Identity body orientation in xyzw layout (w is the last component).
    # The original comment attributes this to NonHolonomicAction, which is
    # not exercised in this module.
    identity_quats = np.zeros((NUM_ENVS, NUM_BODIES, 4), dtype=np.float32)
    identity_quats[..., 3] = 1.0
    mock_data.body_quat_w = wp.array(identity_quats, dtype=wp.quatf, device=DEVICE)
    mock_data._num_joints = NUM_JOINTS
    return mock_data
def assert_close(actual, expected, atol=ATOL, rtol=RTOL):
    """Assert two arrays match within tolerance.

    Either argument may be a warp array or a torch tensor. Warp arrays are
    converted with ``wp.to_torch`` and both sides are cast to float before
    being handed to ``torch.testing.assert_close``.
    """
    lhs, rhs = (wp.to_torch(value) if isinstance(value, wp.array) else value for value in (actual, expected))
    torch.testing.assert_close(lhs.float(), rhs.float(), atol=atol, rtol=rtol)
class TestJointActionMathParity:
    """Verify warp processed_actions match the affine formula raw * scale + offset.

    The stable ``JointAction.process_actions`` computes
    ``processed = raw * scale + offset``. These tests check that the warp
    implementation yields the same numbers for several scale/offset
    configurations, which establishes mathematical parity without having to
    instantiate the stable classes (they require a full env).
    """

    @staticmethod
    def _process(term_cls, cfg, env, actions_wp):
        """Build the term, process ``actions_wp`` once, return ``(processed, raw_torch)``."""
        term = term_cls(cfg, env)
        term.process_actions(actions_wp, action_offset=0)
        return term.processed_actions, wp.to_torch(actions_wp)

    def test_effort_identity(self, env, actions_wp):
        """scale=1, offset=0 -> processed == raw."""
        cfg = JointEffortActionCfg(asset_name="robot", joint_names=[".*"])
        processed, raw = self._process(JointEffortAction, cfg, env, actions_wp)
        assert_close(processed, raw * 1.0 + 0.0)

    def test_effort_with_scale(self, env, actions_wp):
        """scale=3.0, offset=0 -> processed == raw * 3."""
        cfg = JointEffortActionCfg(asset_name="robot", joint_names=[".*"], scale=3.0)
        processed, raw = self._process(JointEffortAction, cfg, env, actions_wp)
        assert_close(processed, raw * 3.0)

    def test_position_with_default_offset(self, env, art_data, actions_wp):
        """use_default_offset=True -> processed == raw + defaults[0]."""
        cfg = JointPositionActionCfg(asset_name="robot", joint_names=[".*"], use_default_offset=True)
        processed, raw = self._process(JointPositionAction, cfg, env, actions_wp)
        # The offset is taken from the first environment's default joint positions.
        env0_defaults = wp.to_torch(art_data.default_joint_pos)[0]
        assert_close(processed, raw * 1.0 + env0_defaults.unsqueeze(0))

    def test_position_scale_and_offset(self, env, art_data, actions_wp):
        """scale=2, use_default_offset=True -> processed == raw * 2 + defaults[0]."""
        cfg = JointPositionActionCfg(asset_name="robot", joint_names=[".*"], scale=2.0, use_default_offset=True)
        processed, raw = self._process(JointPositionAction, cfg, env, actions_wp)
        env0_defaults = wp.to_torch(art_data.default_joint_pos)[0]
        assert_close(processed, raw * 2.0 + env0_defaults.unsqueeze(0))
@pytest.fixture()
def warp_env(scene, action_wp, episode_length_buf):
    """Env with warp action manager (for experimental functions).

    A bare attribute holder populated with the fields that the event
    functions in this module read from the environment.
    """

    class _Env:
        pass

    current_action, previous_action = action_wp

    env = _Env()
    env.scene = scene
    env.action_manager = MockActionManagerWarp(current_action, previous_action)
    env.num_envs = NUM_ENVS
    env.device = DEVICE
    env.episode_length_buf = episode_length_buf
    env.step_dt = 0.02
    env.max_episode_length_s = 10.0
    # Per-env RNG state for events: consecutive uint32 seeds starting at 42,
    # so runs are deterministic.
    per_env_seeds = np.arange(NUM_ENVS, dtype=np.uint32) + 42
    env.rng_state_wp = wp.array(per_env_seeds, device=DEVICE)
    return env
class TestEventParity:
    """Check warp joint-reset events against the closed-form result of the stable terms.

    Warp and stable use different RNG implementations, so parity is tested
    with deterministic (zero-width) ranges where randomness has no effect:
    the result must be ``default + 0`` (offset) or ``default * 1`` (scale),
    clamped to the soft limits.

    The stable functions are not executed here, because the mock articulation
    does not implement the simulation write they call. The expected values
    are computed directly in torch instead.
    """

    @staticmethod
    def _expected_clamped_defaults(art_data):
        """Return ``(expected_pos, expected_vel)``: defaults clamped to the soft limits."""
        defaults_t = wp.to_torch(art_data.default_joint_pos).clone()
        limits_t = wp.to_torch(art_data.soft_joint_pos_limits)
        vel_limits_t = wp.to_torch(art_data.soft_joint_vel_limits)
        expected_pos = defaults_t.clamp(limits_t[..., 0], limits_t[..., 1])
        expected_vel = wp.to_torch(art_data.default_joint_vel).clone().clamp(-vel_limits_t, vel_limits_t)
        return expected_pos, expected_vel

    def test_reset_joints_by_offset_parity(self, warp_env, art_data, all_joints_cfg):
        """Zero offset: the warp term must write the clamped defaults."""
        mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE)

        # Set known defaults
        copy_np_to_wp(art_data.default_joint_pos, np.full((NUM_ENVS, NUM_JOINTS), 0.5, dtype=np.float32))

        warp_evt.reset_joints_by_offset(
            warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=all_joints_cfg
        )
        wp.synchronize()
        warp_pos = wp.to_torch(art_data.joint_pos).clone()
        warp_vel = wp.to_torch(art_data.joint_vel).clone()

        expected_pos, expected_vel = self._expected_clamped_defaults(art_data)
        assert_close(warp_pos, expected_pos)
        assert_close(warp_vel, expected_vel)

    def test_reset_joints_by_scale_parity(self, warp_env, art_data, all_joints_cfg):
        """Scale of 1.0: the warp term must write the clamped defaults."""
        mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE)

        # Set known defaults
        copy_np_to_wp(art_data.default_joint_pos, np.full((NUM_ENVS, NUM_JOINTS), 0.25, dtype=np.float32))

        warp_evt.reset_joints_by_scale(
            warp_env, mask, position_range=(1.0, 1.0), velocity_range=(1.0, 1.0), asset_cfg=all_joints_cfg
        )
        wp.synchronize()
        warp_pos = wp.to_torch(art_data.joint_pos).clone()
        warp_vel = wp.to_torch(art_data.joint_vel).clone()

        expected_pos, expected_vel = self._expected_clamped_defaults(art_data)
        assert_close(warp_pos, expected_pos)
        assert_close(warp_vel, expected_vel)
Mutate defaults -> result tracks.""" + cfg = all_joints_cfg + mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE) + + # Warm-up + warp_evt.reset_joints_by_offset( + warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=cfg + ) + + # Capture + with wp.ScopedCapture() as cap: + warp_evt.reset_joints_by_offset( + warp_env, mask, position_range=(0.0, 0.0), velocity_range=(0.0, 0.0), asset_cfg=cfg + ) + + # Mutate defaults in-place + new_defaults = np.full((NUM_ENVS, NUM_JOINTS), 0.5, dtype=np.float32) + copy_np_to_wp(art_data.default_joint_pos, new_defaults) + + # Replay + wp.capture_launch(cap.graph) + wp.synchronize() + + # With zero offset, joint_pos should equal new defaults (clamped to limits [-3.14, 3.14]) + result = wp.to_torch(art_data.joint_pos) + expected = torch.full((NUM_ENVS, NUM_JOINTS), 0.5, device=DEVICE) + assert_close(result, expected) + + # -- reset_joints_by_scale -------------------------------------------------- + + def test_reset_joints_by_scale(self, warp_env, art_data, all_joints_cfg): + """With scale=1.0, result == defaults. 
def test_apply_external_force_torque(self, warp_env):
    """With zero-width ranges, the captured event leaves zero forces and torques.

    Only the zero-range case is covered here; non-zero ranges are not
    exercised by this test.
    """
    mask = wp.array([True] * NUM_ENVS, dtype=wp.bool, device=DEVICE)

    # Warm-up call outside capture, so that any first-call setup the
    # function does happens before the graph is recorded.
    warp_evt.apply_external_force_torque(warp_env, mask, force_range=(0.0, 0.0), torque_range=(0.0, 0.0))
    with wp.ScopedCapture() as cap:
        warp_evt.apply_external_force_torque(warp_env, mask, force_range=(0.0, 0.0), torque_range=(0.0, 0.0))
    wp.capture_launch(cap.graph)
    wp.synchronize()

    # The function exposes its scratch buffers as attributes; read them back.
    forces = wp.to_torch(warp_evt.apply_external_force_torque._scratch_forces)
    torques = wp.to_torch(warp_evt.apply_external_force_torque._scratch_torques)
    assert_close(forces, torch.zeros_like(forces))
    assert_close(torques, torch.zeros_like(torques))
device=DEVICE)) diff --git a/source/isaaclab_experimental/test/envs/mdp/test_observations_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_observations_warp_parity.py new file mode 100644 index 00000000000..9f1eea23bf8 --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/test_observations_warp_parity.py @@ -0,0 +1,458 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Parity tests for warp-first observation MDP terms.""" + +from __future__ import annotations + +import numpy as np +import pytest +import torch +import warp as wp + +# Skip entire module if no CUDA device available +wp.init() +pytestmark = pytest.mark.skipif(not wp.is_cuda_available(), reason="CUDA device required") + +import isaaclab_experimental.envs.mdp.observations as warp_obs +from parity_helpers import ( + CMD_DIM, + DEVICE, + NUM_ACTIONS, + NUM_BODIES, + NUM_ENVS, + NUM_JOINTS, + MockActionManagerTorch, + MockActionManagerWarp, + MockArticulation, + MockArticulationData, + MockCommandManager, + MockCommandTerm, + MockContactSensor, + MockContactSensorData, + MockScene, + MockSceneEntityCfg, + assert_close, + mutate_art_data, + run_warp_obs, + run_warp_obs_captured, +) + +import isaaclab.envs.mdp.observations as stable_obs + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture(autouse=True) +def _clear_caches(): + yield + for fn in [warp_obs.generated_commands]: + for attr in list(vars(fn)): + if attr.startswith("_"): + delattr(fn, attr) + + +@pytest.fixture() +def art_data(): + return MockArticulationData(NUM_ENVS, NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def art_data_bodies(): + return MockArticulationData(num_bodies=NUM_BODIES) + + +@pytest.fixture() +def env_origins(): + rng 
= np.random.RandomState(77) + origins_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + return wp.array(origins_np, dtype=wp.vec3f, device=DEVICE) + + +@pytest.fixture() +def contact_data(): + return MockContactSensorData() + + +@pytest.fixture() +def cmd_tensor(): + rng = np.random.RandomState(99) + return torch.tensor(rng.randn(NUM_ENVS, CMD_DIM).astype(np.float32), device=DEVICE) + + +@pytest.fixture() +def cmd_term(): + return MockCommandTerm() + + +@pytest.fixture() +def scene(art_data, env_origins): + return MockScene({"robot": MockArticulation(art_data)}, env_origins) + + +@pytest.fixture() +def scene_bodies(art_data_bodies, env_origins, contact_data): + art = MockArticulation(art_data_bodies, num_bodies=NUM_BODIES) + sensor = MockContactSensor(contact_data) + return MockScene({"robot": art}, env_origins, sensors={"contact_sensor": sensor}) + + +@pytest.fixture() +def action_wp(): + rng = np.random.RandomState(99) + a = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + b = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + return a, b # (action, prev_action) + + +@pytest.fixture() +def episode_length_buf(): + torch.manual_seed(55) + return torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + + +@pytest.fixture() +def warp_env(scene, action_wp, episode_length_buf): + """Env with warp action manager (for experimental functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env(scene, action_wp, episode_length_buf): + """Env with torch action manager (for stable functions).""" + + class _Env: + pass + + env = _Env() + 
env.scene = scene + env.action_manager = MockActionManagerTorch(action_wp[0], action_wp[1]) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + return env + + +@pytest.fixture() +def warp_env_bodies(scene_bodies, action_wp, episode_length_buf, cmd_tensor, cmd_term): + """Env with body-level data and command manager (for new-terms observation tests).""" + + class _Env: + pass + + env = _Env() + env.scene = scene_bodies + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env_bodies(scene_bodies, action_wp, episode_length_buf, cmd_tensor, cmd_term): + """Env with body-level data and command manager (for stable new-terms observation tests).""" + + class _Env: + pass + + env = _Env() + env.scene = scene_bodies + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + return env + + +@pytest.fixture() +def all_joints_cfg(): + return MockSceneEntityCfg("robot", list(range(NUM_JOINTS)), NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def subset_cfg(): + return MockSceneEntityCfg("robot", [0, 2, 5, 8], NUM_JOINTS, DEVICE) + + +# ============================================================================ +# Observation parity tests (from test_mdp_warp_parity.py) +# 
============================================================================ + + +class TestObservationParity: + """Verify experimental observation Warp kernels match stable torch implementations.""" + + # -- Root state observations ------------------------------------------------ + + def test_base_pos_z(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.base_pos_z(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.base_pos_z, warp_env, (NUM_ENVS, 1), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_base_lin_vel(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.base_lin_vel(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.base_lin_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_base_ang_vel(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.base_ang_vel(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.base_ang_vel, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_projected_gravity(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.projected_gravity(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.projected_gravity, warp_env, (NUM_ENVS, 3), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint observations (all joints) 
---------------------------------------- + + def test_joint_pos_all(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_vel_all(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_vel(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint observations (subset) ------------------------------------------- + + def test_joint_pos_subset(self, warp_env, stable_env, subset_cfg): + cfg = subset_cfg + n_selected = len(cfg.joint_ids) + expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_pos, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_vel_subset(self, warp_env, stable_env, subset_cfg): + cfg = subset_cfg + n_selected = len(cfg.joint_ids) + expected = stable_obs.joint_vel(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + actual_cap = run_warp_obs_captured(warp_obs.joint_vel, warp_env, (NUM_ENVS, n_selected), asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Normalized joint position ---------------------------------------------- + + def 
test_joint_pos_limit_normalized(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_obs.joint_pos_limit_normalized(stable_env, asset_cfg=cfg) + actual = run_warp_obs(warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg) + actual_cap = run_warp_obs_captured( + warp_obs.joint_pos_limit_normalized, warp_env, (NUM_ENVS, NUM_JOINTS), asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Action observation ----------------------------------------------------- + + def test_last_action(self, warp_env, stable_env, action_wp): + # Stable last_action returns env.action_manager.action (torch tensor) + expected = stable_obs.last_action(stable_env) + actual = run_warp_obs(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) + actual_cap = run_warp_obs_captured(warp_obs.last_action, warp_env, (NUM_ENVS, NUM_ACTIONS)) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# Observation parity tests (from test_mdp_warp_parity_new_terms.py) +# ============================================================================ + + +class TestObservationParityNewTerms: + """Verify observation Warp kernels for newly migrated terms match stable torch implementations.""" + + def test_generated_commands(self, warp_env_bodies, stable_env_bodies): + expected = stable_obs.generated_commands(stable_env_bodies, command_name="vel") + actual = run_warp_obs(warp_obs.generated_commands, warp_env_bodies, (NUM_ENVS, CMD_DIM), command_name="vel") + actual_cap = run_warp_obs_captured( + warp_obs.generated_commands, warp_env_bodies, (NUM_ENVS, CMD_DIM), command_name="vel" + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# Capture-then-mutate-then-replay observation tests (from 
test_mdp_warp_parity.py) +# ============================================================================ + + +def _mutate_art_data(art_data: MockArticulationData, warp_env, rng_seed: int = 200): + """Mutate every data array in-place so captured graphs see fresh values.""" + mutate_art_data(art_data, warp_env, rng_seed=rng_seed) + + +class TestCapturedDataMutationObservations: + """Capture a graph, mutate buffer data in-place, replay -- results must match stable on the *new* data. + + This verifies observation MDP functions are truly capture-safe. + """ + + def _capture_mutate_check_obs(self, warp_fn, stable_fn, warp_env, stable_env, art_data, shape, **kwargs): + out = wp.zeros(shape, dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) # warm-up + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + assert_close(wp.to_torch(out).clone(), stable_fn(stable_env, **kwargs)) + + def test_base_pos_z(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.base_pos_z, + stable_obs.base_pos_z, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 1), + asset_cfg=all_joints_cfg, + ) + + def test_base_lin_vel(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.base_lin_vel, + stable_obs.base_lin_vel, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_base_ang_vel(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.base_ang_vel, + stable_obs.base_ang_vel, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + asset_cfg=all_joints_cfg, + ) + + def test_projected_gravity(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.projected_gravity, + stable_obs.projected_gravity, + warp_env, + stable_env, + art_data, + (NUM_ENVS, 3), + 
asset_cfg=all_joints_cfg, + ) + + def test_joint_pos(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_pos, + stable_obs.joint_pos, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_vel, + stable_obs.joint_vel, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_joint_pos_limit_normalized(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_obs( + warp_obs.joint_pos_limit_normalized, + stable_obs.joint_pos_limit_normalized, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_JOINTS), + asset_cfg=all_joints_cfg, + ) + + def test_last_action(self, warp_env, stable_env, art_data): + self._capture_mutate_check_obs( + warp_obs.last_action, + stable_obs.last_action, + warp_env, + stable_env, + art_data, + (NUM_ENVS, NUM_ACTIONS), + ) + + +# ============================================================================ +# Capture-mutate-replay observation tests (from test_mdp_warp_parity_new_terms.py) +# ============================================================================ + + +class TestCapturedDataMutationObservationsNewTerms: + """Capture graph, mutate buffer data, replay -- verify new-terms observation results match stable.""" + + def test_generated_commands(self, warp_env_bodies, stable_env_bodies, art_data_bodies, cmd_tensor): + """Mutate command tensor, replay captured graph, verify new commands are read.""" + out = wp.zeros((NUM_ENVS, CMD_DIM), dtype=wp.float32, device=DEVICE) + warp_obs.generated_commands(warp_env_bodies, out, command_name="vel") + with wp.ScopedCapture() as cap: + warp_obs.generated_commands(warp_env_bodies, out, command_name="vel") + # Mutate the command tensor in-place (zero-copy view picks it up) + cmd_tensor[:] = 
torch.randn_like(cmd_tensor) + wp.capture_launch(cap.graph) + expected = stable_obs.generated_commands(stable_env_bodies, command_name="vel") + assert_close(wp.to_torch(out).clone(), expected) diff --git a/source/isaaclab_experimental/test/envs/mdp/test_rewards_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_rewards_warp_parity.py new file mode 100644 index 00000000000..69cabcd14a3 --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/test_rewards_warp_parity.py @@ -0,0 +1,567 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Parity tests for warp-first reward MDP terms.""" + +from __future__ import annotations + +import numpy as np +import pytest +import torch +import warp as wp + +# Skip entire module if no CUDA device available +wp.init() +pytestmark = pytest.mark.skipif(not wp.is_cuda_available(), reason="CUDA device required") + +import isaaclab_experimental.envs.mdp.rewards as warp_rew +from parity_helpers import ( + BODY_IDS, + CMD_DIM, + DEVICE, + NUM_ACTIONS, + NUM_BODIES, + NUM_ENVS, + NUM_JOINTS, + MockActionManagerTorch, + MockActionManagerWarp, + MockArticulation, + MockArticulationData, + MockBodyCfg, + MockCommandManager, + MockCommandTerm, + MockContactSensor, + MockContactSensorData, + MockScene, + MockSceneEntityCfg, + MockSensorCfg, + MockTerminationManager, + assert_close, + mutate_art_data, + mutate_body_data, + run_warp_rew, + run_warp_rew_captured, +) + +import isaaclab.envs.mdp.rewards as stable_rew + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture(autouse=True) +def _clear_caches(): + yield + for fn in [warp_rew.track_lin_vel_xy_exp, warp_rew.track_ang_vel_z_exp, warp_rew.undesired_contacts]: + for attr in 
list(vars(fn)): + if attr.startswith("_"): + delattr(fn, attr) + + +@pytest.fixture() +def art_data(): + return MockArticulationData(NUM_ENVS, NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def art_data_bodies(): + return MockArticulationData(num_bodies=NUM_BODIES) + + +@pytest.fixture() +def env_origins(): + rng = np.random.RandomState(77) + origins_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + return wp.array(origins_np, dtype=wp.vec3f, device=DEVICE) + + +@pytest.fixture() +def contact_data(): + return MockContactSensorData() + + +@pytest.fixture() +def cmd_tensor(): + rng = np.random.RandomState(99) + return torch.tensor(rng.randn(NUM_ENVS, CMD_DIM).astype(np.float32), device=DEVICE) + + +@pytest.fixture() +def cmd_term(): + return MockCommandTerm() + + +@pytest.fixture() +def scene(art_data, env_origins): + return MockScene({"robot": MockArticulation(art_data)}, env_origins) + + +@pytest.fixture() +def scene_bodies(art_data_bodies, env_origins, contact_data): + art = MockArticulation(art_data_bodies, num_bodies=NUM_BODIES) + sensor = MockContactSensor(contact_data) + return MockScene({"robot": art}, env_origins, sensors={"contact_sensor": sensor}) + + +@pytest.fixture() +def action_wp(): + rng = np.random.RandomState(99) + a = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + b = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + return a, b + + +@pytest.fixture() +def episode_length_buf(): + torch.manual_seed(55) + return torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + + +@pytest.fixture() +def term_mgr(): + return MockTerminationManager() + + +@pytest.fixture() +def warp_env(scene, action_wp, episode_length_buf, term_mgr): + """Env with warp action manager (for experimental functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.termination_manager = term_mgr + env.num_envs = 
NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env(scene, action_wp, episode_length_buf, term_mgr): + """Env with torch action manager (for stable functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerTorch(action_wp[0], action_wp[1]) + env.termination_manager = term_mgr + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + return env + + +@pytest.fixture() +def warp_env_bodies(scene_bodies, action_wp, episode_length_buf, cmd_tensor, cmd_term): + """Env with body-level data and command manager (for new-terms reward tests).""" + + class _Env: + pass + + env = _Env() + env.scene = scene_bodies + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env_bodies(scene_bodies, action_wp, episode_length_buf, cmd_tensor, cmd_term): + """Env with body-level data and command manager (for stable new-terms reward tests).""" + + class _Env: + pass + + env = _Env() + env.scene = scene_bodies + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + return env 
+ + +@pytest.fixture() +def all_joints_cfg(): + return MockSceneEntityCfg("robot", list(range(NUM_JOINTS)), NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def body_cfg(): + return MockBodyCfg("robot", BODY_IDS) + + +@pytest.fixture() +def sensor_cfg(): + return MockSensorCfg("contact_sensor", BODY_IDS) + + +# ============================================================================ +# Reward parity tests (from test_mdp_warp_parity.py) +# ============================================================================ + + +class TestRewardParity: + """Verify experimental reward Warp kernels match stable torch implementations.""" + + # -- General rewards -------------------------------------------------------- + + def test_is_alive(self, warp_env, stable_env, term_mgr): + # Set some envs as terminated so the reward is non-trivial + term_mgr.terminated[::2] = True + expected = stable_rew.is_alive(stable_env) + actual = run_warp_rew(warp_rew.is_alive, warp_env) + actual_cap = run_warp_rew_captured(warp_rew.is_alive, warp_env) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_is_terminated(self, warp_env, stable_env, term_mgr): + term_mgr.terminated[::3] = True + expected = stable_rew.is_terminated(stable_env) + actual = run_warp_rew(warp_rew.is_terminated, warp_env) + actual_cap = run_warp_rew_captured(warp_rew.is_terminated, warp_env) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Root penalties --------------------------------------------------------- + + def test_lin_vel_z_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.lin_vel_z_l2(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.lin_vel_z_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_ang_vel_xy_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = 
all_joints_cfg + expected = stable_rew.ang_vel_xy_l2(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.ang_vel_xy_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_flat_orientation_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.flat_orientation_l2(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.flat_orientation_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint L2 penalties (masked) -------------------------------------------- + + def test_joint_vel_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_vel_l2(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_vel_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_acc_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_acc_l2(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_acc_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_joint_torques_l2(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_torques_l2(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_torques_l2, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Joint L1 
penalties (masked) -------------------------------------------- + + def test_joint_vel_l1(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_vel_l1(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_vel_l1, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_vel_l1, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Action penalties ------------------------------------------------------- + + def test_action_l2(self, warp_env, stable_env): + expected = stable_rew.action_l2(stable_env) + actual = run_warp_rew(warp_rew.action_l2, warp_env) + actual_cap = run_warp_rew_captured(warp_rew.action_l2, warp_env) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_action_rate_l2(self, warp_env, stable_env): + expected = stable_rew.action_rate_l2(stable_env) + actual = run_warp_rew(warp_rew.action_rate_l2, warp_env) + actual_cap = run_warp_rew_captured(warp_rew.action_rate_l2, warp_env) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Limit penalties -------------------------------------------------------- + + def test_joint_pos_limits(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_pos_limits(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) + actual_cap = run_warp_rew_captured(warp_rew.joint_pos_limits, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + # -- Additional penalties --------------------------------------------------- + + def test_joint_deviation_l1(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + expected = stable_rew.joint_deviation_l1(stable_env, asset_cfg=cfg) + actual = run_warp_rew(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) + actual_cap = 
run_warp_rew_captured(warp_rew.joint_deviation_l1, warp_env, asset_cfg=cfg) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + +# ============================================================================ +# New reward parity tests (from test_mdp_warp_parity_new_terms.py) +# ============================================================================ + + +class TestNewRewardParity: + """Verify newly migrated reward Warp kernels match stable torch implementations.""" + + def test_track_lin_vel_xy_exp(self, warp_env_bodies, stable_env_bodies, body_cfg): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) # needed for stable + std = 0.25 + expected = stable_rew.track_lin_vel_xy_exp(stable_env_bodies, std=std, command_name="vel", asset_cfg=cfg) + actual = run_warp_rew( + warp_rew.track_lin_vel_xy_exp, warp_env_bodies, std=std, command_name="vel", asset_cfg=cfg + ) + actual_cap = run_warp_rew_captured( + warp_rew.track_lin_vel_xy_exp, warp_env_bodies, std=std, command_name="vel", asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_track_ang_vel_z_exp(self, warp_env_bodies, stable_env_bodies, body_cfg): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) + std = 0.25 + expected = stable_rew.track_ang_vel_z_exp(stable_env_bodies, std=std, command_name="vel", asset_cfg=cfg) + actual = run_warp_rew(warp_rew.track_ang_vel_z_exp, warp_env_bodies, std=std, command_name="vel", asset_cfg=cfg) + actual_cap = run_warp_rew_captured( + warp_rew.track_ang_vel_z_exp, warp_env_bodies, std=std, command_name="vel", asset_cfg=cfg + ) + assert_close(actual, expected) + assert_close(actual_cap, expected) + + def test_undesired_contacts(self, warp_env_bodies, stable_env_bodies, sensor_cfg): + threshold = 1.0 + expected = stable_rew.undesired_contacts(stable_env_bodies, threshold=threshold, sensor_cfg=sensor_cfg) + actual = run_warp_rew(warp_rew.undesired_contacts, warp_env_bodies, 
threshold=threshold, sensor_cfg=sensor_cfg) + actual_cap = run_warp_rew_captured( + warp_rew.undesired_contacts, warp_env_bodies, threshold=threshold, sensor_cfg=sensor_cfg + ) + assert_close(actual, expected.float()) + assert_close(actual_cap, expected.float()) + + +# ============================================================================ +# Capture-then-mutate-then-replay reward tests (from test_mdp_warp_parity.py) +# ============================================================================ + + +def _mutate_art_data(art_data: MockArticulationData, warp_env, rng_seed: int = 200): + """Mutate every data array in-place so captured graphs see fresh values.""" + mutate_art_data(art_data, warp_env, rng_seed=rng_seed) + + +def _mutate_body_data(art_data: MockArticulationData, rng_seed=200): + """Mutate body-level and root-level data in-place so captured graphs see fresh values.""" + mutate_body_data(art_data, rng_seed=rng_seed) + + +class TestCapturedDataMutationRewards: + """Capture a graph, mutate buffer data in-place, replay -- results must match stable on the *new* data. + + This verifies reward MDP functions are truly capture-safe. 
+ """ + + def _capture_mutate_check_rew(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + assert_close(wp.to_torch(out).clone(), stable_fn(stable_env, **kwargs)) + + def test_lin_vel_z_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.lin_vel_z_l2, + stable_rew.lin_vel_z_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_ang_vel_xy_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.ang_vel_xy_l2, + stable_rew.ang_vel_xy_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_flat_orientation_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.flat_orientation_l2, + stable_rew.flat_orientation_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_vel_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_vel_l2, + stable_rew.joint_vel_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_acc_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_acc_l2, + stable_rew.joint_acc_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_torques_l2(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_torques_l2, + stable_rew.joint_torques_l2, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_action_l2(self, warp_env, stable_env, art_data): + self._capture_mutate_check_rew( + warp_rew.action_l2, + 
stable_rew.action_l2, + warp_env, + stable_env, + art_data, + ) + + def test_action_rate_l2(self, warp_env, stable_env, art_data): + self._capture_mutate_check_rew( + warp_rew.action_rate_l2, + stable_rew.action_rate_l2, + warp_env, + stable_env, + art_data, + ) + + def test_joint_pos_limits(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_pos_limits, + stable_rew.joint_pos_limits, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + def test_joint_deviation_l1(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_rew( + warp_rew.joint_deviation_l1, + stable_rew.joint_deviation_l1, + warp_env, + stable_env, + art_data, + asset_cfg=all_joints_cfg, + ) + + +# ============================================================================ +# Capture-mutate-replay reward tests for new terms (from test_mdp_warp_parity_new_terms.py) +# ============================================================================ + + +class TestCapturedDataMutationRewardsNewTerms: + """Capture graph, mutate buffer data, replay -- verify new-terms reward results match stable.""" + + def _capture_mutate_check_rew(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.float32, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_body_data(art_data) + wp.capture_launch(cap.graph) + expected = stable_fn(stable_env, **kwargs) + assert_close(wp.to_torch(out).clone(), expected) + + def test_track_lin_vel_xy_exp(self, warp_env_bodies, stable_env_bodies, art_data_bodies): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) + self._capture_mutate_check_rew( + warp_rew.track_lin_vel_xy_exp, + stable_rew.track_lin_vel_xy_exp, + warp_env_bodies, + stable_env_bodies, + art_data_bodies, + std=0.25, + command_name="vel", + asset_cfg=cfg, + ) + + def 
test_track_ang_vel_z_exp(self, warp_env_bodies, stable_env_bodies, art_data_bodies): + cfg = MockBodyCfg("robot") + cfg.joint_ids = list(range(NUM_JOINTS)) + self._capture_mutate_check_rew( + warp_rew.track_ang_vel_z_exp, + stable_rew.track_ang_vel_z_exp, + warp_env_bodies, + stable_env_bodies, + art_data_bodies, + std=0.25, + command_name="vel", + asset_cfg=cfg, + ) diff --git a/source/isaaclab_experimental/test/envs/mdp/test_terminations_warp_parity.py b/source/isaaclab_experimental/test/envs/mdp/test_terminations_warp_parity.py new file mode 100644 index 00000000000..d82339e0122 --- /dev/null +++ b/source/isaaclab_experimental/test/envs/mdp/test_terminations_warp_parity.py @@ -0,0 +1,349 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Parity tests for warp-first termination MDP terms.""" + +from __future__ import annotations + +import numpy as np +import pytest +import torch +import warp as wp + +# Skip entire module if no CUDA device available +wp.init() +pytestmark = pytest.mark.skipif(not wp.is_cuda_available(), reason="CUDA device required") + +import isaaclab_experimental.envs.mdp.terminations as warp_term +from parity_helpers import ( + BODY_IDS, + CMD_DIM, + DEVICE, + NUM_ACTIONS, + NUM_BODIES, + NUM_ENVS, + NUM_JOINTS, + MockActionManagerTorch, + MockActionManagerWarp, + MockArticulation, + MockArticulationData, + MockCommandManager, + MockCommandTerm, + MockContactSensor, + MockContactSensorData, + MockScene, + MockSceneEntityCfg, + MockSensorCfg, + MockTerminationManager, + assert_equal, + mutate_art_data, + mutate_body_data, + run_warp_term, + run_warp_term_captured, +) + +import isaaclab.envs.mdp.terminations as stable_term + +# ============================================================================ +# Fixtures +# ============================================================================ + 
+ +@pytest.fixture(autouse=True) +def _clear_caches(): + yield + for fn in [warp_term.illegal_contact]: + for attr in list(vars(fn)): + if attr.startswith("_"): + delattr(fn, attr) + + +@pytest.fixture() +def art_data(): + return MockArticulationData(NUM_ENVS, NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def art_data_bodies(): + return MockArticulationData(num_bodies=NUM_BODIES) + + +@pytest.fixture() +def env_origins(): + rng = np.random.RandomState(77) + origins_np = rng.randn(NUM_ENVS, 3).astype(np.float32) + return wp.array(origins_np, dtype=wp.vec3f, device=DEVICE) + + +@pytest.fixture() +def contact_data(): + return MockContactSensorData() + + +@pytest.fixture() +def cmd_tensor(): + rng = np.random.RandomState(99) + return torch.tensor(rng.randn(NUM_ENVS, CMD_DIM).astype(np.float32), device=DEVICE) + + +@pytest.fixture() +def cmd_term(): + return MockCommandTerm() + + +@pytest.fixture() +def scene(art_data, env_origins): + return MockScene({"robot": MockArticulation(art_data)}, env_origins) + + +@pytest.fixture() +def scene_bodies(art_data_bodies, env_origins, contact_data): + art = MockArticulation(art_data_bodies, num_bodies=NUM_BODIES) + sensor = MockContactSensor(contact_data) + return MockScene({"robot": art}, env_origins, sensors={"contact_sensor": sensor}) + + +@pytest.fixture() +def action_wp(): + rng = np.random.RandomState(99) + a = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + b = wp.array(rng.randn(NUM_ENVS, NUM_ACTIONS).astype(np.float32), device=DEVICE) + return a, b + + +@pytest.fixture() +def episode_length_buf(): + torch.manual_seed(55) + return torch.randint(0, 500, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + + +@pytest.fixture() +def warp_env(scene, action_wp, episode_length_buf): + """Env with warp action manager (for experimental functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.num_envs = NUM_ENVS + 
env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env(scene, action_wp, episode_length_buf): + """Env with torch action manager (for stable functions).""" + + class _Env: + pass + + env = _Env() + env.scene = scene + env.action_manager = MockActionManagerTorch(action_wp[0], action_wp[1]) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env.step_dt = 0.02 + env.max_episode_length_s = 10.0 + return env + + +@pytest.fixture() +def warp_env_bodies(scene_bodies, action_wp, episode_length_buf, cmd_tensor, cmd_term): + """Env with body-level data and command manager (for new-terms termination tests).""" + + class _Env: + pass + + env = _Env() + env.scene = scene_bodies + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env._episode_length_buf_wp = wp.from_torch(episode_length_buf) + env.step_dt = 0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + env.rng_state_wp = wp.array(np.arange(NUM_ENVS, dtype=np.uint32) + 42, device=DEVICE) + return env + + +@pytest.fixture() +def stable_env_bodies(scene_bodies, action_wp, episode_length_buf, cmd_tensor, cmd_term): + """Env with body-level data and command manager (for stable new-terms termination tests).""" + + class _Env: + pass + + env = _Env() + env.scene = scene_bodies + env.action_manager = MockActionManagerWarp(action_wp[0], action_wp[1]) + env.command_manager = MockCommandManager(cmd_tensor, cmd_term) + env.num_envs = NUM_ENVS + env.device = DEVICE + env.episode_length_buf = episode_length_buf + env._episode_length_buf_wp = wp.from_torch(episode_length_buf) + env.step_dt = 
0.02 + env.max_episode_length = 500 + env.max_episode_length_s = 10.0 + # stable termination_manager needed for time_out + env.termination_manager = MockTerminationManager() + return env + + +@pytest.fixture() +def all_joints_cfg(): + return MockSceneEntityCfg("robot", list(range(NUM_JOINTS)), NUM_JOINTS, DEVICE) + + +@pytest.fixture() +def sensor_cfg(): + return MockSensorCfg("contact_sensor", BODY_IDS) + + +# ============================================================================ +# Termination parity tests (from test_mdp_warp_parity.py) +# ============================================================================ + + +class TestTerminationParity: + """Verify experimental termination Warp kernels match stable torch implementations.""" + + def test_root_height_below_minimum(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + min_h = 0.5 + expected = stable_term.root_height_below_minimum(stable_env, minimum_height=min_h, asset_cfg=cfg) + actual = run_warp_term(warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg) + actual_cap = run_warp_term_captured( + warp_term.root_height_below_minimum, warp_env, minimum_height=min_h, asset_cfg=cfg + ) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_joint_pos_out_of_manual_limit(self, warp_env, stable_env, all_joints_cfg): + cfg = all_joints_cfg + bounds = (-1.0, 1.0) + expected = stable_term.joint_pos_out_of_manual_limit(stable_env, bounds=bounds, asset_cfg=cfg) + actual = run_warp_term(warp_term.joint_pos_out_of_manual_limit, warp_env, bounds=bounds, asset_cfg=cfg) + actual_cap = run_warp_term_captured( + warp_term.joint_pos_out_of_manual_limit, warp_env, bounds=bounds, asset_cfg=cfg + ) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + +# ============================================================================ +# Termination parity tests (from test_mdp_warp_parity_new_terms.py) +# 
============================================================================ + + +class TestTerminationParityNewTerms: + """Verify termination Warp kernels for newly migrated terms match stable torch implementations.""" + + def test_time_out(self, warp_env_bodies, stable_env_bodies): + expected = stable_term.time_out(stable_env_bodies) + actual = run_warp_term(warp_term.time_out, warp_env_bodies) + actual_cap = run_warp_term_captured(warp_term.time_out, warp_env_bodies) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + def test_illegal_contact(self, warp_env_bodies, stable_env_bodies, sensor_cfg): + threshold = 1.0 + expected = stable_term.illegal_contact(stable_env_bodies, threshold=threshold, sensor_cfg=sensor_cfg) + actual = run_warp_term(warp_term.illegal_contact, warp_env_bodies, threshold=threshold, sensor_cfg=sensor_cfg) + actual_cap = run_warp_term_captured( + warp_term.illegal_contact, warp_env_bodies, threshold=threshold, sensor_cfg=sensor_cfg + ) + assert_equal(actual, expected) + assert_equal(actual_cap, expected) + + +# ============================================================================ +# Capture-then-mutate-then-replay termination tests (from test_mdp_warp_parity.py) +# ============================================================================ + + +def _mutate_art_data(art_data: MockArticulationData, warp_env, rng_seed: int = 200): + """Mutate every data array in-place so captured graphs see fresh values.""" + mutate_art_data(art_data, warp_env, rng_seed=rng_seed) + + +def _mutate_body_data(art_data: MockArticulationData, rng_seed=200): + """Mutate body-level and root-level data in-place so captured graphs see fresh values.""" + mutate_body_data(art_data, rng_seed=rng_seed) + + +class TestCapturedDataMutationTerminations: + """Capture a graph, mutate buffer data in-place, replay -- results must match stable on the *new* data. + + This verifies termination MDP functions are truly capture-safe. 
+ """ + + def _capture_mutate_check_term(self, warp_fn, stable_fn, warp_env, stable_env, art_data, **kwargs): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + warp_fn(warp_env, out, **kwargs) + with wp.ScopedCapture() as cap: + warp_fn(warp_env, out, **kwargs) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + assert_equal(wp.to_torch(out).clone(), stable_fn(stable_env, **kwargs)) + + def test_root_height_below_minimum(self, warp_env, stable_env, art_data, all_joints_cfg): + self._capture_mutate_check_term( + warp_term.root_height_below_minimum, + stable_term.root_height_below_minimum, + warp_env, + stable_env, + art_data, + minimum_height=0.5, + asset_cfg=all_joints_cfg, + ) + + def test_joint_pos_out_of_manual_limit(self, warp_env, stable_env, art_data, all_joints_cfg): + # joint_pos_out_of_manual_limit uses a 2D kernel that only writes True + # (never clears to False), so the output must be zeroed before each call. + # We include the zeroing inside the captured graph. 
+ bounds = (-1.0, 1.0) + cfg = all_joints_cfg + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + # warm-up + out.zero_() + warp_term.joint_pos_out_of_manual_limit(warp_env, out, bounds=bounds, asset_cfg=cfg) + # capture (including the zero) + with wp.ScopedCapture() as cap: + out.zero_() + warp_term.joint_pos_out_of_manual_limit(warp_env, out, bounds=bounds, asset_cfg=cfg) + _mutate_art_data(art_data, warp_env) + wp.capture_launch(cap.graph) + expected = stable_term.joint_pos_out_of_manual_limit(stable_env, bounds=bounds, asset_cfg=cfg) + assert_equal(wp.to_torch(out).clone(), expected) + + +# ============================================================================ +# Capture-mutate-replay termination tests for new terms (from test_mdp_warp_parity_new_terms.py) +# ============================================================================ + + +class TestCapturedDataMutationTerminationsNewTerms: + """Capture graph, mutate buffer data, replay -- verify new-terms termination results match stable.""" + + def test_time_out(self, warp_env_bodies, stable_env_bodies, art_data_bodies): + out = wp.zeros((NUM_ENVS,), dtype=wp.bool, device=DEVICE) + warp_term.time_out(warp_env_bodies, out) + with wp.ScopedCapture() as cap: + warp_term.time_out(warp_env_bodies, out) + # Mutate episode length in-place + warp_env_bodies.episode_length_buf[:] = torch.randint(0, 600, (NUM_ENVS,), dtype=torch.int64, device=DEVICE) + wp.capture_launch(cap.graph) + expected = stable_term.time_out(stable_env_bodies) + assert_equal(wp.to_torch(out).clone(), expected) From 7ee1f431a307d484a1751b51df94b3f458aa7bc9 Mon Sep 17 00:00:00 2001 From: Jichuan Hu Date: Fri, 13 Mar 2026 02:25:40 -0700 Subject: [PATCH 7/7] Add warp environment docs and align step timer across all env variants --- .../newton-physics-integration/index.rst | 4 + .../warp-env-migration.rst | 283 ++++++++++++++++++ .../warp-environments.rst | 204 +++++++++++++ source/isaaclab/isaaclab/cli/__init__.py | 6 +- 
.../isaaclab/isaaclab/envs/direct_rl_env.py | 5 + .../isaaclab/envs/manager_based_rl_env.py | 6 + .../envs/direct_rl_env_warp.py | 24 +- 7 files changed, 519 insertions(+), 13 deletions(-) create mode 100644 docs/source/experimental-features/newton-physics-integration/warp-env-migration.rst create mode 100644 docs/source/experimental-features/newton-physics-integration/warp-environments.rst diff --git a/docs/source/experimental-features/newton-physics-integration/index.rst b/docs/source/experimental-features/newton-physics-integration/index.rst index 731e74e4a6b..e4def0fca37 100644 --- a/docs/source/experimental-features/newton-physics-integration/index.rst +++ b/docs/source/experimental-features/newton-physics-integration/index.rst @@ -38,6 +38,10 @@ For an overview of how the multi-backend architecture works, including how to ad :titlesonly: installation + warp-environments + warp-env-migration + training-environments + visualization limitations-and-known-bugs solver-transitioning sim-to-sim diff --git a/docs/source/experimental-features/newton-physics-integration/warp-env-migration.rst b/docs/source/experimental-features/newton-physics-integration/warp-env-migration.rst new file mode 100644 index 00000000000..24b631b9ffe --- /dev/null +++ b/docs/source/experimental-features/newton-physics-integration/warp-env-migration.rst @@ -0,0 +1,283 @@ +.. _warp-env-migration: + +Warp Environment Guide +====================== + +This guide covers the key conventions and patterns used by the warp-first environment +infrastructure, useful for migrating existing stable environments or creating new ones natively. + +.. note:: + + The warp environment infrastructure lives in ``isaaclab_experimental`` and + ``isaaclab_tasks_experimental``. It's an experimental feature. + + +Design Rationale +~~~~~~~~~~~~~~~~ + +The warp environment path is built around `CUDA graph capture +`_. 
+A CUDA graph records a sequence of GPU operations (kernel launches, memory copies) during a +capture phase, then replays the entire sequence with a single launch. This eliminates per-kernel +CPU overhead — the parameter validation, kernel selection, and buffer setup that normally costs +20–200 μs per operation is performed once during graph instantiation and reused on every replay +(~10 μs total). All CPU-side code (Python logic, torch dispatching) executed during capture is +completely bypassed during replay. See the `Warp concurrency documentation +`_ for Warp's graph capture API +(``wp.ScopedCapture``). + +All design decisions in the warp infrastructure follow from this constraint: every operation in the +step loop must be a GPU kernel launch with stable memory pointers so that the captured graph can +be replayed without modification. + +Key consequences: + +- All buffers are **pre-allocated** — no dynamic allocation inside the step loop +- Data flows through **persistent ``wp.array`` pointers** — never replaced, only overwritten +- MDP terms are **pure ``@wp.kernel`` functions** — no Python branching on GPU data +- Reset uses **boolean masks** (``env_mask``) instead of index lists (``env_ids``) to avoid + variable-length indexing that changes graph topology + + +Project Structure +~~~~~~~~~~~~~~~~~ + +Warp-specific implementations that deviate from stable live in the ``_experimental`` packages: + +- ``isaaclab_experimental`` — warp managers, base env classes, warp MDP terms +- ``isaaclab_tasks_experimental`` — warp task configs and task-specific MDP terms + +Any new warp implementation that differs from the stable API belongs in these packages. +Warp task configs reference Newton physics directly (no ``PresetCfg``) since the warp path +is Newton-only. + + +Writing Warp MDP Terms +~~~~~~~~~~~~~~~~~~~~~~ + +Imports +^^^^^^^ + +Warp task configs import from the experimental packages: + +.. 
code-block:: python + + # Warp + from isaaclab_experimental.managers import ObservationTermCfg, RewardTermCfg, SceneEntityCfg + import isaaclab_experimental.envs.mdp as mdp + +The term config classes have the same interface — only the import path changes. + + +Common Pattern +^^^^^^^^^^^^^^ + +All warp MDP terms (observations, rewards, terminations, events, actions) follow the same +**kernel + launch** pattern. Stable terms use torch tensors and return results; warp terms +write into pre-allocated ``wp.array`` output buffers via ``@wp.kernel`` functions: + +.. code-block:: python + + # Stable — returns a tensor + def lin_vel_z_l2(env, asset_cfg) -> torch.Tensor: + return torch.square(asset.data.root_lin_vel_b[:, 2]) + + # Warp — writes into pre-allocated output + @wp.kernel + def _lin_vel_z_l2_kernel(vel: wp.array(...), out: wp.array(dtype=wp.float32)): + i = wp.tid() + out[i] = vel[i][2] * vel[i][2] + + def lin_vel_z_l2(env, out, asset_cfg) -> None: + wp.launch(_lin_vel_z_l2_kernel, dim=env.num_envs, inputs=[..., out]) + +The output buffer shapes differ by term type: + +- **Observations**: ``(num_envs, D)`` where D is the observation dimension +- **Rewards**: ``(num_envs,)`` +- **Terminations**: ``(num_envs,)`` with dtype ``bool`` +- **Events**: ``(num_envs,)`` mask — events don't produce output, they modify sim state + + +Observation Terms +^^^^^^^^^^^^^^^^^ + +Since warp terms write into pre-allocated buffers, the observation manager must know each +term's output dimension at initialization to allocate the correct ``(num_envs, D)`` output +array. This is resolved via a fallback chain (see +``ObservationManager._infer_term_dim_scalar`` in +``isaaclab_experimental/managers/observation_manager.py``): + +1. **Explicit ``out_dim`` in decorator** (preferred): + + .. code-block:: python + + @generic_io_descriptor_warp(out_dim=3, observation_type="RootState") + def base_lin_vel(env, out, asset_cfg) -> None: ... 
+ + ``out_dim`` can be an integer, or a string that resolves at initialization: + + - ``"joint"`` — number of selected joints from ``asset_cfg`` + - ``"body:N"`` — N components per selected body from ``asset_cfg`` + - ``"command"`` — dimension from command manager + - ``"action"`` — dimension from action manager + +2. **``axes`` metadata**: Dimension equals the number of axes listed: + + .. code-block:: python + + @generic_io_descriptor_warp(axes=["X", "Y", "Z"], observation_type="RootState") + def projected_gravity(env, out, asset_cfg) -> None: ... + # → dimension = 3 + +3. **Legacy params**: ``term_dim``, ``out_dim``, or ``obs_dim`` keys in ``term_cfg.params``. + +4. **Asset config fallback**: Count of ``asset_cfg.joint_ids`` (or ``joint_ids_wp``) for + joint-level terms. + + +Event Terms +^^^^^^^^^^^ + +Events use ``env_mask`` (boolean ``wp.array``) instead of ``env_ids``, and each kernel +checks the mask to skip non-selected environments: + +.. code-block:: python + + def reset_joints_by_offset(env, env_mask, ...): + wp.launch(_kernel, dim=env.num_envs, inputs=[env_mask, ...]) + + @wp.kernel + def _kernel(env_mask: wp.array(dtype=wp.bool), ...): + i = wp.tid() + if not env_mask[i]: + return + # ... modify state for selected envs only + +- RNG uses per-env ``env.rng_state_wp`` (``wp.uint32``) instead of ``torch.rand`` +- **Startup/prestartup** events use the stable convention ``(env, env_ids, **params)`` +- **Reset/interval** events use the warp convention ``(env, env_mask, **params)`` + + +Action Terms +^^^^^^^^^^^^ + +Actions follow a **two-stage execution**: ``process_actions`` (called once per env step) scales +and clips raw actions, and ``apply_actions`` (called once per sim step) writes targets to the +asset. Both stages use warp kernels with pre-allocated ``_raw_actions`` and ``_processed_actions`` +buffers. 
+ + +Capture Safety +^^^^^^^^^^^^^^ + +When writing terms that run inside the captured step loop, keep in mind: + +- **No ``wp.to_torch``** or torch arithmetic — stay in warp throughout +- **No lazy-evaluated properties** — use sim-bound (Tier 1) data directly; if a derived + quantity is needed, compute it inline in the kernel +- **No dynamic allocation** — all buffers must be pre-allocated in ``__init__`` + + +Parity Testing +~~~~~~~~~~~~~~ + +Two levels of parity testing are used to validate warp terms: + +**1. Implementation parity (stable vs warp)** — verifies that the warp kernel produces the +same result as the stable torch implementation. This is optional for terms that have no stable +counterpart (e.g. new terms written directly in warp). + +.. code-block:: python + + import isaaclab.envs.mdp.observations as stable_obs + import isaaclab_experimental.envs.mdp.observations as warp_obs + + # Stable baseline + expected = stable_obs.joint_pos(stable_env, asset_cfg=cfg) + + # Warp (uncaptured) + out = wp.zeros((num_envs, num_joints), dtype=wp.float32, device=device) + warp_obs.joint_pos(warp_env, out, asset_cfg=cfg) + actual = wp.to_torch(out) + + torch.testing.assert_close(actual, expected) + +**2. Capture parity (warp vs warp-captured)** — verifies that the term produces identical +results when replayed from a CUDA graph vs launched directly. A mismatch here indicates capture-unsafe +code (e.g. stale pointers, dynamic allocation, or lazy property access that doesn't replay). +This test should always be run, even for terms without a stable counterpart. + +.. 
code-block:: python + + # Warp uncaptured + out_uncaptured = wp.zeros((num_envs, num_joints), dtype=wp.float32, device=device) + warp_obs.joint_pos(warp_env, out_uncaptured, asset_cfg=cfg) + + # Warp captured (graph replay) + out_captured = wp.zeros((num_envs, num_joints), dtype=wp.float32, device=device) + with wp.ScopedCapture() as cap: + warp_obs.joint_pos(warp_env, out_captured, asset_cfg=cfg) + wp.capture_launch(cap.graph) + + torch.testing.assert_close(wp.to_torch(out_captured), wp.to_torch(out_uncaptured)) + +See ``source/isaaclab_experimental/test/envs/mdp/`` for complete parity test examples. + + +Available Warp MDP Terms +~~~~~~~~~~~~~~~~~~~~~~~~ + +.. list-table:: + :header-rows: 1 + :widths: 20 80 + + * - Category + - Available Terms + * - Observations (11) + - | ``base_pos_z`` + | ``base_lin_vel`` + | ``base_ang_vel`` + | ``projected_gravity`` + | ``joint_pos`` + | ``joint_pos_rel`` + | ``joint_pos_limit_normalized`` + | ``joint_vel`` + | ``joint_vel_rel`` + | ``last_action`` + | ``generated_commands`` + * - Rewards (16) + - | ``is_alive`` + | ``is_terminated`` + | ``lin_vel_z_l2`` + | ``ang_vel_xy_l2`` + | ``flat_orientation_l2`` + | ``joint_torques_l2`` + | ``joint_vel_l1`` + | ``joint_vel_l2`` + | ``joint_acc_l2`` + | ``joint_deviation_l1`` + | ``joint_pos_limits`` + | ``action_rate_l2`` + | ``action_l2`` + | ``undesired_contacts`` + | ``track_lin_vel_xy_exp`` + | ``track_ang_vel_z_exp`` + * - Events (6) + - | ``reset_joints_by_offset`` + | ``reset_joints_by_scale`` + | ``reset_root_state_uniform`` + | ``push_by_setting_velocity`` + | ``apply_external_force_torque`` + | ``randomize_rigid_body_com`` + * - Terminations (4) + - | ``time_out`` + | ``root_height_below_minimum`` + | ``joint_pos_out_of_manual_limit`` + | ``illegal_contact`` + * - Actions (2) + - | ``JointPositionAction`` + | ``JointEffortAction`` + +Terms not listed here remain in stable only. When using an env that requires unlisted terms, +those terms must be implemented in warp first. 
diff --git a/docs/source/experimental-features/newton-physics-integration/warp-environments.rst b/docs/source/experimental-features/newton-physics-integration/warp-environments.rst new file mode 100644 index 00000000000..88a404b4a9f --- /dev/null +++ b/docs/source/experimental-features/newton-physics-integration/warp-environments.rst @@ -0,0 +1,204 @@ +.. _warp-environments: + +Warp Experimental Environments +============================== + +.. note:: + + The warp environment infrastructure lives in ``isaaclab_experimental`` and + ``isaaclab_tasks_experimental``. It's an experimental feature. + +The experimental extensions introduce **warp-first** environment infrastructure with CUDA graph capture +support. All environment-side computation (observations, rewards, resets, actions) runs as pure Warp +kernels, eliminating Python overhead and enabling CUDA graph capture for maximum throughput. + + +Workflows +~~~~~~~~~ + +Two environment workflows are supported: + +**Direct workflow** — ``DirectRLEnvWarp`` base class. You implement the step loop, observations, +rewards, and resets directly in your env class using Warp kernels. + +**Manager-based workflow** — ``ManagerBasedRLEnvWarp`` base class. You define MDP terms as +standalone Warp-kernel functions and compose them via configuration. 
+ + +Available Environments +~~~~~~~~~~~~~~~~~~~~~~ + +Direct Warp Environments +^^^^^^^^^^^^^^^^^^^^^^^^ + +- ``Isaac-Cartpole-Direct-Warp-v0`` — Cartpole balance +- ``Isaac-Ant-Direct-Warp-v0`` — Ant locomotion +- ``Isaac-Humanoid-Direct-Warp-v0`` — Humanoid locomotion +- ``Isaac-Repose-Cube-Allegro-Direct-Warp-v0`` — Allegro hand cube repose + + +Manager-Based Warp Environments +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +**Classic** + +- ``Isaac-Cartpole-Warp-v0`` +- ``Isaac-Ant-Warp-v0`` +- ``Isaac-Humanoid-Warp-v0`` + +**Locomotion (Flat)** + +- ``Isaac-Velocity-Flat-Anymal-B-Warp-v0`` +- ``Isaac-Velocity-Flat-Anymal-C-Warp-v0`` +- ``Isaac-Velocity-Flat-Anymal-D-Warp-v0`` +- ``Isaac-Velocity-Flat-Cassie-Warp-v0`` +- ``Isaac-Velocity-Flat-G1-Warp-v0`` +- ``Isaac-Velocity-Flat-G1-Warp-v1`` +- ``Isaac-Velocity-Flat-H1-Warp-v0`` +- ``Isaac-Velocity-Flat-Unitree-A1-Warp-v0`` +- ``Isaac-Velocity-Flat-Unitree-Go1-Warp-v0`` +- ``Isaac-Velocity-Flat-Unitree-Go2-Warp-v0`` + +**Manipulation** + +- ``Isaac-Reach-Franka-Warp-v0`` +- ``Isaac-Reach-UR10-Warp-v0`` + + +Quick Start +~~~~~~~~~~~ + +.. code-block:: bash + + # Direct workflow + ./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py \ + --task Isaac-Cartpole-Direct-Warp-v0 --num_envs 4096 --headless + + # Manager-based workflow + ./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py \ + --task Isaac-Velocity-Flat-Anymal-C-Warp-v0 --num_envs 4096 --headless + +All RL libraries with warp-compatible wrappers are supported: RSL-RL, RL Games, SKRL, and +Stable-Baselines3. + + +Performance Comparison +~~~~~~~~~~~~~~~~~~~~~~ + +Step time comparison between the stable (torch/manager) and warp (CUDA graph captured) variants, +both running on the Newton physics backend. Measured over 300 iterations with 4096 environments. + +.. note:: + + The warp migration is an ongoing effort. Several components (e.g. scene write, actuator models) + have not yet been migrated to Warp kernels and still run through torch. 
Further performance + improvements are expected as these components are migrated. + +.. list-table:: + :header-rows: 1 + :widths: 30 12 15 15 12 + + * - Env + - Type + - Stable Step (us) + - Warp Step (us) + - Change + * - Cartpole-Direct + - Direct + - 5,274 + - 4,331 + - -17.88% + * - Ant-Direct + - Direct + - 6,368 + - 3,128 + - -50.88% + * - Humanoid-Direct + - Direct + - 13,937 + - 10,783 + - -22.63% + * - Allegro-Direct + - Direct + - 82,950 + - 74,570 + - -10.10% + * - Cartpole + - Manager + - 7,971 + - 3,642 + - -54.31% + * - Ant + - Manager + - 9,781 + - 4,672 + - -52.23% + * - Humanoid + - Manager + - 17,653 + - 12,505 + - -29.16% + * - Reach-Franka + - Manager + - 11,458 + - 7,813 + - -31.83% + * - Anymal-B + - Manager + - 29,188 + - 21,781 + - -25.38% + * - Anymal-C + - Manager + - 30,938 + - 22,228 + - -28.15% + * - Anymal-D + - Manager + - 32,294 + - 23,977 + - -25.75% + * - Cassie + - Manager + - 17,320 + - 10,706 + - -38.19% + * - G1 + - Manager + - 34,487 + - 27,300 + - -20.84% + * - H1 + - Manager + - 22,202 + - 15,864 + - -28.55% + * - A1 + - Manager + - 15,257 + - 9,907 + - -35.07% + * - Go1 + - Manager + - 16,515 + - 11,869 + - -28.13% + * - Go2 + - Manager + - 15,221 + - 9,966 + - -34.52% + + +Adding New Warp Environments +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +To add a new warp environment: + +1. Create a task config in ``isaaclab_tasks_experimental`` mirroring the stable config structure. +2. Import MDP terms from ``isaaclab_experimental.envs.mdp`` instead of ``isaaclab.envs.mdp``. +3. Configure Newton physics with ``use_cuda_graph=True``. +4. Register the task with a ``-Warp-`` suffix in the gym ID. + +For a detailed guide on converting each component (observations, rewards, events, actions), +see :doc:`warp-env-migration`. 
diff --git a/source/isaaclab/isaaclab/cli/__init__.py b/source/isaaclab/isaaclab/cli/__init__.py index 833bc6945cf..c99500215f5 100644 --- a/source/isaaclab/isaaclab/cli/__init__.py +++ b/source/isaaclab/isaaclab/cli/__init__.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause import argparse +import sys from .commands.envs import command_setup_conda, command_setup_uv from .commands.format import command_format @@ -143,9 +144,10 @@ def cli() -> None: elif args.python is not None: if args.python: - run_python_command(args.python[0], args.python[1:]) + result = run_python_command(args.python[0], args.python[1:]) else: - run_python_command("-i", []) + result = run_python_command("-i", []) + sys.exit(result.returncode) elif args.sim is not None: command_run_isaacsim(args.sim) diff --git a/source/isaaclab/isaaclab/envs/direct_rl_env.py b/source/isaaclab/isaaclab/envs/direct_rl_env.py index b362ac72bc2..39ad429358c 100644 --- a/source/isaaclab/isaaclab/envs/direct_rl_env.py +++ b/source/isaaclab/isaaclab/envs/direct_rl_env.py @@ -8,6 +8,7 @@ import inspect import logging import math +import os import warnings import weakref from abc import abstractmethod @@ -34,6 +35,9 @@ from .ui import ViewportCameraController from .utils.spaces import sample_space, spec_to_gym_space +DEBUG_TIMER_STEP = os.environ.get("DEBUG_TIMER_STEP", "0") == "1" +DEBUG_TIMERS = os.environ.get("DEBUG_TIMERS", "0") == "1" + if has_kit(): import omni.kit.app @@ -352,6 +356,7 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) # return observations return self._get_observations(), self.extras + @Timer(name="env_step", msg="Step took:", enable=DEBUG_TIMERS, time_unit="us") def step(self, action: torch.Tensor) -> VecEnvStepReturn: """Execute one time-step of the environment's dynamics. 
diff --git a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py index d08b7e3be3a..7ab4ec1f813 100644 --- a/source/isaaclab/isaaclab/envs/manager_based_rl_env.py +++ b/source/isaaclab/isaaclab/envs/manager_based_rl_env.py @@ -7,6 +7,7 @@ from __future__ import annotations import math +import os from collections.abc import Sequence from typing import Any, ClassVar @@ -16,8 +17,12 @@ from isaaclab.managers import CommandManager, CurriculumManager, RewardManager, TerminationManager from isaaclab.ui.widgets import ManagerLiveVisualizer +from isaaclab.utils.timer import Timer from .common import VecEnvStepReturn + +DEBUG_TIMER_STEP = os.environ.get("DEBUG_TIMER_STEP", "0") == "1" +DEBUG_TIMERS = os.environ.get("DEBUG_TIMERS", "0") == "1" from .manager_based_env import ManagerBasedEnv from .manager_based_rl_env_cfg import ManagerBasedRLEnvCfg @@ -150,6 +155,7 @@ def setup_manager_visualizers(self): Operations - MDP """ + @Timer(name="env_step", msg="Step took:", enable=DEBUG_TIMERS, time_unit="us") def step(self, action: torch.Tensor) -> VecEnvStepReturn: """Execute one time-step of the environment's dynamics and reset terminated environments. 
diff --git a/source/isaaclab_experimental/isaaclab_experimental/envs/direct_rl_env_warp.py b/source/isaaclab_experimental/isaaclab_experimental/envs/direct_rl_env_warp.py index 9bb0b7ed371..a5733051ff8 100644 --- a/source/isaaclab_experimental/isaaclab_experimental/envs/direct_rl_env_warp.py +++ b/source/isaaclab_experimental/isaaclab_experimental/envs/direct_rl_env_warp.py @@ -377,7 +377,7 @@ def reset(self, seed: int | None = None, options: dict[str, Any] | None = None) self._get_observations() return {"policy": self.torch_obs_buf.clone()}, self.extras - @Timer(name="env_step", msg="Step took:", enable=DEBUG_TIMER_STEP or DEBUG_TIMERS) + @Timer(name="env_step", msg="Step took:", enable=DEBUG_TIMERS, time_unit="us") def step(self, action: torch.Tensor) -> VecEnvStepReturn: """Execute one time-step of the environment's dynamics. @@ -409,7 +409,7 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: action = self._action_noise_model(action) # process actions, #TODO pass the torch tensor directly. - with Timer(name="pre_physics", msg="Pre-physics step took:", enable=DEBUG_TIMERS): + with Timer(name="pre_physics", msg="Pre-physics step took:", enable=DEBUG_TIMER_STEP): self._pre_physics_step( wp.from_torch(action) ) # Creates a tensor and discards it. Not graphable unless training loop reuses the same pointer. 
@@ -420,20 +420,22 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: is_rendering = bool(self.sim.settings.get("/isaaclab/visualizer")) or _has_rtx # perform physics stepping - with Timer(name="physics_loop", msg="Physics loop took:", enable=DEBUG_TIMERS): + with Timer(name="physics_loop", msg="Physics loop took:", enable=DEBUG_TIMER_STEP): for _ in range(self.cfg.decimation): self._sim_step_counter += 1 # set actions into buffers # simulate - with Timer(name="apply_action", msg="Action processing step took:", enable=DEBUG_TIMERS): + with Timer(name="apply_action", msg="Action processing step took:", enable=DEBUG_TIMER_STEP): self._graph_cache.capture_or_replay("action", self.step_warp_action) # write_data_to_sim runs outside the CUDA graph because _apply_actuator_model # uses torch ops (wp.to_torch + torch arithmetic) that cross CUDA streams. - with Timer(name="write_data_to_sim_loop", msg="Write data to sim (loop) took:", enable=DEBUG_TIMERS): + with Timer( + name="write_data_to_sim_loop", msg="Write data to sim (loop) took:", enable=DEBUG_TIMER_STEP + ): self.scene.write_data_to_sim() - with Timer(name="simulate", msg="Newton simulation step took:", enable=DEBUG_TIMERS): + with Timer(name="simulate", msg="Newton simulation step took:", enable=DEBUG_TIMER_STEP): self.sim.step(render=False) # render between steps only if the GUI or an RTX sensor needs it # note: we assume the render interval to be the shortest accepted rendering interval. 
@@ -441,21 +443,21 @@ def step(self, action: torch.Tensor) -> VecEnvStepReturn: if self._sim_step_counter % self.cfg.sim.render_interval == 0 and is_rendering: self.sim.render() # update buffers at sim dt - with Timer(name="scene_update", msg="Scene update took:", enable=DEBUG_TIMERS): + with Timer(name="scene_update", msg="Scene update took:", enable=DEBUG_TIMER_STEP): self.scene.update(dt=self.physics_dt) self.common_step_counter += 1 # total step (common for all envs) - with Timer(name="end_pre_graph", msg="End pre-graph took:", enable=DEBUG_TIMERS): + with Timer(name="end_pre_graph", msg="End pre-graph took:", enable=DEBUG_TIMER_STEP): self._graph_cache.capture_or_replay("end_pre", self._step_warp_end_pre) # write_data_to_sim runs uncaptured — it uses torch ops that cross CUDA streams. - with Timer(name="write_data_to_sim_post", msg="Write data to sim (post-reset) took:", enable=DEBUG_TIMERS): + with Timer(name="write_data_to_sim_post", msg="Write data to sim (post-reset) took:", enable=DEBUG_TIMER_STEP): self.scene.write_data_to_sim() - with Timer(name="end_post_graph", msg="End post-graph took:", enable=DEBUG_TIMERS): + with Timer(name="end_post_graph", msg="End post-graph took:", enable=DEBUG_TIMER_STEP): self._graph_cache.capture_or_replay("end_post", self._step_warp_end_post) # Visualization hook — runs after CUDA graph scope. Override in subclass # to update markers or other non-graphable visual elements. - with Timer(name="visualize", msg="Visualize took:", enable=DEBUG_TIMERS): + with Timer(name="visualize", msg="Visualize took:", enable=DEBUG_TIMER_STEP): self._post_step_visualize() # return observations, rewards, resets and extras