diff --git a/scripts/demos/sensors/raycaster_sensor.py b/scripts/demos/sensors/raycaster_sensor.py index 23fceaa55df..f532bc6c0a6 100644 --- a/scripts/demos/sensors/raycaster_sensor.py +++ b/scripts/demos/sensors/raycaster_sensor.py @@ -125,13 +125,13 @@ def run_simulator(sim: sim_utils.SimulationContext, scene: InteractiveScene): # print information from the sensors print("-------------------------------") print(scene["ray_caster"]) - print("Ray cast hit results: ", scene["ray_caster"].data.ray_hits_w) + print("Ray cast hit results: ", wp.to_torch(scene["ray_caster"].data.ray_hits_w)) if not triggered: if countdown > 0: countdown -= 1 continue - data = scene["ray_caster"].data.ray_hits_w.cpu().numpy() + data = wp.to_torch(scene["ray_caster"].data.ray_hits_w).cpu().numpy() np.save("cast_data.npy", data) triggered = True else: diff --git a/scripts/tutorials/04_sensors/add_sensors_on_robot.py b/scripts/tutorials/04_sensors/add_sensors_on_robot.py index 31f9a2bcefc..f5e3a19c0be 100644 --- a/scripts/tutorials/04_sensors/add_sensors_on_robot.py +++ b/scripts/tutorials/04_sensors/add_sensors_on_robot.py @@ -150,7 +150,10 @@ def run_simulator(sim: sim_utils.SimulationContext, scene: InteractiveScene): print("Received shape of depth image: ", scene["camera"].data.output["distance_to_image_plane"].shape) print("-------------------------------") print(scene["height_scanner"]) - print("Received max height value: ", torch.max(scene["height_scanner"].data.ray_hits_w[..., -1]).item()) + print( + "Received max height value: ", + torch.max(wp.to_torch(scene["height_scanner"].data.ray_hits_w)[..., -1]).item(), + ) print("-------------------------------") print(scene["contact_forces"]) print("Received max contact force of: ", torch.max(scene["contact_forces"].data.net_forces_w).item()) diff --git a/scripts/tutorials/04_sensors/run_ray_caster.py b/scripts/tutorials/04_sensors/run_ray_caster.py index 3e46ef1a08f..35b1c8a32a4 100644 --- a/scripts/tutorials/04_sensors/run_ray_caster.py +++ b/scripts/tutorials/04_sensors/run_ray_caster.py @@ -35,6 +35,7 @@ import warp as wp import isaaclab.sim as sim_utils + from isaaclab.assets import RigidObject, RigidObjectCfg from isaaclab.sensors.ray_caster import RayCaster, RayCasterCfg, patterns from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR @@ -120,7 +121,7 @@ def run_simulator(sim: sim_utils.SimulationContext, scene_entities: dict): # Update the ray-caster with Timer( f"Ray-caster update with {4} x {ray_caster.num_rays} rays with max height of" - f" {torch.max(ray_caster.data.pos_w).item():.2f}" + f" {torch.max(wp.to_torch(ray_caster.data.pos_w)).item():.2f}" ): ray_caster.update(dt=sim.get_physics_dt(), force_recompute=True) # Update counter diff --git a/source/isaaclab/isaaclab/envs/mdp/observations.py b/source/isaaclab/isaaclab/envs/mdp/observations.py index 5bae5ec2d54..f4c50e6ab79 100644 --- a/source/isaaclab/isaaclab/envs/mdp/observations.py +++ b/source/isaaclab/isaaclab/envs/mdp/observations.py @@ -304,7 +304,7 @@ def height_scan(env: ManagerBasedEnv, sensor_cfg: SceneEntityCfg, offset: float # extract the used quantities (to enable type-hinting) sensor: RayCaster = env.scene.sensors[sensor_cfg.name] # height scan: height = sensor_height - hit_point_z - offset - return sensor.data.pos_w[:, 2].unsqueeze(1) - sensor.data.ray_hits_w[..., 2] - offset + return wp.to_torch(sensor.data.pos_w)[:, 2].unsqueeze(1) - wp.to_torch(sensor.data.ray_hits_w)[..., 2] - offset def body_incoming_wrench(env: ManagerBasedEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor: diff --git a/source/isaaclab/isaaclab/envs/mdp/rewards.py b/source/isaaclab/isaaclab/envs/mdp/rewards.py index 5a53583f5e4..74bea7ee786 100644 --- a/source/isaaclab/isaaclab/envs/mdp/rewards.py +++ b/source/isaaclab/isaaclab/envs/mdp/rewards.py @@ -116,7 +116,7 @@ def base_height_l2( if sensor_cfg is not None: sensor: RayCaster = env.scene[sensor_cfg.name] # Adjust the target height using the sensor data - adjusted_target_height = target_height + torch.mean(sensor.data.ray_hits_w[..., 2], dim=1) + adjusted_target_height = target_height + torch.mean(wp.to_torch(sensor.data.ray_hits_w)[..., 2], dim=1) else: # Use the provided target height directly for flat terrain adjusted_target_height = target_height diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/kernels.py b/source/isaaclab/isaaclab/sensors/ray_caster/kernels.py new file mode 100644 index 00000000000..7b9a9537b49 --- /dev/null +++ b/source/isaaclab/isaaclab/sensors/ray_caster/kernels.py @@ -0,0 +1,182 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Warp kernels for the ray caster sensor.""" + +import warp as wp + +ALIGNMENT_WORLD = wp.constant(0) +ALIGNMENT_YAW = wp.constant(1) +ALIGNMENT_BASE = wp.constant(2) + + +@wp.func +def quat_yaw_only(q: wp.quatf) -> wp.quatf: + """Extract yaw-only quaternion by zeroing x,y components and renormalizing.""" + z = q[2] + w = q[3] + length = wp.sqrt(z * z + w * w) + if length > 0.0: + return wp.quatf(0.0, 0.0, z / length, w / length) + else: + return wp.quatf(0.0, 0.0, 0.0, 1.0) + + +@wp.kernel(enable_backward=False) +def update_ray_caster_kernel( + transforms: wp.array(dtype=wp.transformf), + env_mask: wp.array(dtype=wp.bool), + offset_pos: wp.array(dtype=wp.vec3f), + offset_quat: wp.array(dtype=wp.quatf), + drift: wp.array(dtype=wp.vec3f), + ray_cast_drift: wp.array(dtype=wp.vec3f), + ray_starts_local: wp.array2d(dtype=wp.vec3f), + ray_directions_local: wp.array2d(dtype=wp.vec3f), + alignment_mode: int, + pos_w: wp.array(dtype=wp.vec3f), + quat_w: wp.array(dtype=wp.quatf), + ray_starts_w: wp.array2d(dtype=wp.vec3f), + ray_directions_w: wp.array2d(dtype=wp.vec3f), +): + """Compute sensor world poses and transform rays into world frame. + + Combines the PhysX view transform with the sensor offset, applies drift, + and transforms local ray starts/directions according to the alignment mode. + + Launch with dim=(num_envs, num_rays). + + Args: + transforms: World transforms from PhysX view. Shape is (num_envs,). + env_mask: Boolean mask for which environments to update. Shape is (num_envs,). + offset_pos: Per-env position offset [m] from view to sensor. Shape is (num_envs,). + offset_quat: Per-env quaternion offset from view to sensor. Shape is (num_envs,). + drift: Per-env position drift [m]. Shape is (num_envs,). + ray_cast_drift: Per-env ray cast drift [m]. Shape is (num_envs,). + ray_starts_local: Per-env local ray start positions. Shape is (num_envs, num_rays). + ray_directions_local: Per-env local ray directions. Shape is (num_envs, num_rays). + alignment_mode: 0=world, 1=yaw, 2=base. + pos_w: Output sensor position in world frame. Shape is (num_envs,). + quat_w: Output sensor orientation in world frame. Shape is (num_envs,). + ray_starts_w: Output world-frame ray starts. Shape is (num_envs, num_rays). + ray_directions_w: Output world-frame ray directions. Shape is (num_envs, num_rays). + """ + env_id, ray_id = wp.tid() + if not env_mask[env_id]: + return + + t = transforms[env_id] + view_pos = wp.transform_get_translation(t) + view_quat = wp.transform_get_rotation(t) + + # combine_frame_transforms: q02 = q01 * q12, t02 = t01 + quat_rotate(q01, t12) + combined_quat = view_quat * offset_quat[env_id] + combined_pos = view_pos + wp.quat_rotate(view_quat, offset_pos[env_id]) + + combined_pos = combined_pos + drift[env_id] + + if ray_id == 0: + pos_w[env_id] = combined_pos + quat_w[env_id] = combined_quat + + local_start = ray_starts_local[env_id, ray_id] + local_dir = ray_directions_local[env_id, ray_id] + rcd = ray_cast_drift[env_id] + + if alignment_mode == ALIGNMENT_WORLD: + pos_drifted = wp.vec3f(combined_pos[0] + rcd[0], combined_pos[1] + rcd[1], combined_pos[2]) + ray_starts_w[env_id, ray_id] = local_start + pos_drifted + ray_directions_w[env_id, ray_id] = local_dir + elif alignment_mode == ALIGNMENT_YAW: + yaw_q = quat_yaw_only(combined_quat) + rot_drift = wp.quat_rotate(yaw_q, rcd) + pos_drifted = wp.vec3f(combined_pos[0] + rot_drift[0], combined_pos[1] + rot_drift[1], combined_pos[2]) + ray_starts_w[env_id, ray_id] = wp.quat_rotate(yaw_q, local_start) + pos_drifted + ray_directions_w[env_id, ray_id] = local_dir + else: + rot_drift = wp.quat_rotate(combined_quat, rcd) + pos_drifted = wp.vec3f(combined_pos[0] + rot_drift[0], combined_pos[1] + rot_drift[1], combined_pos[2]) + ray_starts_w[env_id, ray_id] = wp.quat_rotate(combined_quat, local_start) + pos_drifted + ray_directions_w[env_id, ray_id] = wp.quat_rotate(combined_quat, local_dir) + + +@wp.kernel(enable_backward=False) +def fill_vec3_inf_kernel( + env_mask: wp.array(dtype=wp.bool), + data: wp.array2d(dtype=wp.vec3f), + inf_val: wp.float32, +): + """Fill a 2D vec3f array with a given value for masked environments. + + Launch with dim=(num_envs, num_rays). + + Args: + env_mask: Boolean mask for which environments to update. Shape is (num_envs,). + data: Array to fill. Shape is (num_envs, num_rays). + inf_val: Value to fill with (typically inf). + """ + env, ray = wp.tid() + if not env_mask[env]: + return + data[env, ray] = wp.vec3f(inf_val, inf_val, inf_val) + + +@wp.kernel(enable_backward=False) +def raycast_mesh_masked_kernel( + mesh: wp.uint64, + env_mask: wp.array(dtype=wp.bool), + ray_starts: wp.array2d(dtype=wp.vec3f), + ray_directions: wp.array2d(dtype=wp.vec3f), + ray_hits: wp.array2d(dtype=wp.vec3f), + max_dist: wp.float32, +): + """Ray-cast against a single static mesh for masked environments. + + Launch with dim=(num_envs, num_rays). + + Args: + mesh: The warp mesh id to ray-cast against. + env_mask: Boolean mask for which environments to update. Shape is (num_envs,). + ray_starts: World-frame ray start positions. Shape is (num_envs, num_rays). + ray_directions: World-frame ray directions. Shape is (num_envs, num_rays). + ray_hits: Output ray hit positions [m]. Shape is (num_envs, num_rays). + Pre-filled with inf for missed hits. + max_dist: Maximum ray-cast distance [m]. + """ + env, ray = wp.tid() + if not env_mask[env]: + return + + t = float(0.0) + u = float(0.0) + v = float(0.0) + sign = float(0.0) + n = wp.vec3() + f = int(0) + + hit = wp.mesh_query_ray(mesh, ray_starts[env, ray], ray_directions[env, ray], max_dist, t, u, v, sign, n, f) + if hit: + ray_hits[env, ray] = ray_starts[env, ray] + t * ray_directions[env, ray] + + +@wp.kernel(enable_backward=False) +def apply_z_drift_kernel( + env_mask: wp.array(dtype=wp.bool), + ray_cast_drift: wp.array(dtype=wp.vec3f), + ray_hits: wp.array2d(dtype=wp.vec3f), +): + """Apply vertical (z) drift to ray hit positions for masked environments. + + Launch with dim=(num_envs, num_rays). + + Args: + env_mask: Boolean mask for which environments to update. Shape is (num_envs,). + ray_cast_drift: Per-env drift vector [m]; only z-component is used. Shape is (num_envs,). + ray_hits: Ray hit positions to modify in-place. Shape is (num_envs, num_rays). + """ + env, ray = wp.tid() + if not env_mask[env]: + return + hit = ray_hits[env, ray] + ray_hits[env, ray] = wp.vec3f(hit[0], hit[1], hit[2] + ray_cast_drift[env][2]) diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py index 06ce2183e2f..1cf54887197 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster.py @@ -29,7 +29,6 @@ if TYPE_CHECKING: from .multi_mesh_ray_caster_cfg import MultiMeshRayCasterCfg -# import logger logger = logging.getLogger(__name__) @@ -90,33 +89,24 @@ def __init__(self, cfg: MultiMeshRayCasterCfg): Args: cfg: The configuration parameters. """ - # Initialize base class super().__init__(cfg) - # Create empty variables for storing output data self._num_meshes_per_env: dict[str, int] = {} - """Keeps track of the number of meshes per env for each ray_cast target. - Since we allow regex indexing (e.g. env_*/object_*) they can differ - """ self._raycast_targets_cfg: list[MultiMeshRayCasterCfg.RaycastTargetCfg] = [] for target in self.cfg.mesh_prim_paths: - # Legacy support for string targets. Treat them as global targets. if isinstance(target, str): self._raycast_targets_cfg.append(cfg.RaycastTargetCfg(prim_expr=target, track_mesh_transforms=False)) else: self._raycast_targets_cfg.append(target) - # Resolve regex namespace if set for cfg in self._raycast_targets_cfg: cfg.prim_expr = cfg.prim_expr.format(ENV_REGEX_NS="/World/envs/env_.*") - # overwrite the data class self._data = MultiMeshRayCasterData() def __str__(self) -> str: """Returns: A string containing information about the instance.""" - return ( f"Ray-caster @ '{self.cfg.prim_path}': \n" f"\tview type : {self._view.__class__}\n" @@ -133,9 +123,7 @@ def __str__(self) -> str: @property def data(self) -> MultiMeshRayCasterData: - # update sensors if needed self._update_outdated_buffers() - # return the data return self._data """ @@ -163,9 +151,7 @@ def _initialize_warp_meshes(self): """ multi_mesh_ids: dict[str, list[list[int]]] = {} for target_cfg in self._raycast_targets_cfg: - # target prim path to ray cast against target_prim_path = target_cfg.prim_expr - # # check if mesh already casted into warp mesh and skip if so. if target_prim_path in multi_mesh_ids: logger.warning( f"Mesh at target prim path '{target_prim_path}' already exists in the mesh cache. Duplicate entries" @@ -173,30 +159,22 @@ def _initialize_warp_meshes(self): ) continue - # find all matching prim paths to provided expression of the target target_prims = sim_utils.find_matching_prims(target_prim_path) if len(target_prims) == 0: raise RuntimeError(f"Failed to find a prim at path expression: {target_prim_path}") - # If only one prim is found, treat it as a global prim. - # Either it's a single global object (e.g. ground) or we are only using one env. is_global_prim = len(target_prims) == 1 loaded_vertices: list[np.ndarray | None] = [] wp_mesh_ids = [] for target_prim in target_prims: - # Reuse previously parsed shared mesh instance if possible. if target_cfg.is_shared and len(wp_mesh_ids) > 0: - # Verify if this mesh has already been registered in an earlier environment. - # Note, this check may fail, if the prim path is not following the env_.* pattern - # Which (worst case) leads to parsing the mesh and skipping registering it at a later stage - curr_prim_base_path = re.sub(r"env_\d+", "env_0", str(target_prim.GetPath())) # + curr_prim_base_path = re.sub(r"env_\d+", "env_0", str(target_prim.GetPath())) if curr_prim_base_path in MultiMeshRayCaster.meshes: MultiMeshRayCaster.meshes[str(target_prim.GetPath())] = MultiMeshRayCaster.meshes[ curr_prim_base_path ] - # Reuse mesh imported by another ray-cast sensor (global cache). if str(target_prim.GetPath()) in MultiMeshRayCaster.meshes: wp_mesh_ids.append(MultiMeshRayCaster.meshes[str(target_prim.GetPath())].id) loaded_vertices.append(None) @@ -219,7 +197,6 @@ def _initialize_warp_meshes(self): trimesh_meshes = [] for mesh_prim in mesh_prims: - # check if valid if mesh_prim is None or not mesh_prim.IsValid(): raise RuntimeError(f"Invalid mesh prim path: {target_prim}") @@ -240,13 +217,11 @@ def _initialize_warp_meshes(self): transform[:3, 3] = relative_pos.numpy() mesh.apply_transform(transform) - # add to list of parsed meshes trimesh_meshes.append(mesh) if len(trimesh_meshes) == 1: trimesh_mesh = trimesh_meshes[0] elif target_cfg.merge_prim_meshes: - # combine all trimesh meshes into a single mesh trimesh_mesh = trimesh.util.concatenate(trimesh_meshes) else: raise RuntimeError( @@ -254,11 +229,9 @@ def _initialize_warp_meshes(self): " enable `merge_prim_meshes` in the configuration or specify each mesh separately." ) - # check if the mesh is already registered, if so only reference the mesh registered_idx = _registered_points_idx(trimesh_mesh.vertices, loaded_vertices) if registered_idx != -1 and self.cfg.reference_meshes: logger.info("Found a duplicate mesh, only reference the mesh.") - # Found a duplicate mesh, only reference the mesh. loaded_vertices.append(None) wp_mesh_ids.append(wp_mesh_ids[registered_idx]) else: @@ -267,7 +240,6 @@ def _initialize_warp_meshes(self): MultiMeshRayCaster.meshes[str(target_prim.GetPath())] = wp_mesh wp_mesh_ids.append(wp_mesh.id) - # print info if registered_idx != -1: logger.info(f"Found duplicate mesh for mesh prims under path '{target_prim.GetPath()}'.") else: @@ -277,12 +249,9 @@ def _initialize_warp_meshes(self): ) if is_global_prim: - # reference the mesh for each environment to ray cast against multi_mesh_ids[target_prim_path] = [wp_mesh_ids] * self._num_envs self._num_meshes_per_env[target_prim_path] = len(wp_mesh_ids) else: - # split up the meshes for each environment. Little bit ugly, since - # the current order is interleaved (env1_obj1, env1_obj2, env2_obj1, env2_obj2, ...) multi_mesh_ids[target_prim_path] = [] mesh_idx = 0 n_meshes_per_env = len(wp_mesh_ids) // self._num_envs @@ -296,7 +265,6 @@ def _initialize_warp_meshes(self): self._obtain_trackable_prim_view(target_prim_path) ) - # throw an error if no meshes are found if all([target_cfg.prim_expr not in multi_mesh_ids for target_cfg in self._raycast_targets_cfg]): raise RuntimeError( f"No meshes found for ray-casting! Please check the mesh prim paths: {self.cfg.mesh_prim_paths}" @@ -306,12 +274,10 @@ def _initialize_warp_meshes(self): self._mesh_positions_w = torch.zeros(self._num_envs, total_n_meshes_per_env, 3, device=self.device) self._mesh_orientations_w = torch.zeros(self._num_envs, total_n_meshes_per_env, 4, device=self.device) - # Update the mesh positions and rotations mesh_idx = 0 for target_cfg in self._raycast_targets_cfg: n_meshes = self._num_meshes_per_env[target_cfg.prim_expr] - # update position of the target meshes pos_w, ori_w = [], [] for prim in sim_utils.find_matching_prims(target_cfg.prim_expr): translation, quat = sim_utils.resolve_prim_pose(prim) @@ -324,7 +290,6 @@ def _initialize_warp_meshes(self): self._mesh_orientations_w[:, mesh_idx : mesh_idx + n_meshes] = ori_w mesh_idx += n_meshes - # flatten the list of meshes that are included in mesh_prim_paths of the specific ray caster multi_mesh_ids_flattened = [] for env_idx in range(self._num_envs): meshes_in_env = [] @@ -337,7 +302,6 @@ def _initialize_warp_meshes(self): for target_cfg in self._raycast_targets_cfg ] - # save a warp array with mesh ids that is passed to the raycast function self._mesh_ids_wp = wp.array2d(multi_mesh_ids_flattened, dtype=wp.uint64, device=self.device) def _initialize_rays_impl(self): @@ -353,7 +317,7 @@ def _update_buffers_impl(self, env_mask: wp.array): if len(env_ids) == 0: return - self._update_ray_infos(env_ids) + self._update_ray_infos(env_mask) # Update the mesh positions and rotations mesh_idx = 0 @@ -362,7 +326,6 @@ def _update_buffers_impl(self, env_mask: wp.array): mesh_idx += self._num_meshes_per_env[target_cfg.prim_expr] continue - # update position of the target meshes pos_w, ori_w = obtain_world_pose_from_view(view, None) pos_w = pos_w.squeeze(0) if len(pos_w.shape) == 3 else pos_w ori_w = ori_w.squeeze(0) if len(ori_w.shape) == 3 else ori_w @@ -373,7 +336,7 @@ def _update_buffers_impl(self, env_mask: wp.array): ori_w = quat_mul(ori_offset.expand(ori_w.shape[0], -1), ori_w) count = view.count - if count != 1: # Mesh is not global, i.e. we have different meshes for each env + if count != 1: count = count // self._num_envs pos_w = pos_w.view(self._num_envs, count, 3) ori_w = ori_w.view(self._num_envs, count, 4) @@ -382,10 +345,11 @@ def _update_buffers_impl(self, env_mask: wp.array): self._mesh_orientations_w[:, mesh_idx : mesh_idx + count] = ori_w mesh_idx += count - self._data.ray_hits_w[env_ids], _, _, _, mesh_ids = raycast_dynamic_meshes( - self._ray_starts_w[env_ids], - self._ray_directions_w[env_ids], - mesh_ids_wp=self._mesh_ids_wp, # list with shape num_envs x num_meshes_per_env + # Use torch views of warp arrays for the torch-based raycast_dynamic_meshes + self._data._ray_hits_w_torch[env_ids], _, _, _, mesh_ids = raycast_dynamic_meshes( + self._ray_starts_w_torch[env_ids], + self._ray_directions_w_torch[env_ids], + mesh_ids_wp=self._mesh_ids_wp, max_dist=self.cfg.max_distance, mesh_positions_w=self._mesh_positions_w[env_ids], mesh_orientations_w=self._mesh_orientations_w[env_ids], diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera_data.py b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera_data.py index d2f26abdbf4..21338f0a061 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera_data.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/multi_mesh_ray_caster_camera_data.py @@ -9,11 +9,15 @@ from isaaclab.sensors.camera import CameraData -from .ray_caster_data import RayCasterData +class MultiMeshRayCasterCameraData(CameraData): + """Data container for the multi-mesh ray-cast camera sensor. -class MultiMeshRayCasterCameraData(CameraData, RayCasterData): - """Data container for the multi-mesh ray-cast sensor.""" + This class extends :class:`CameraData` with additional mesh-id information. + It does not inherit from :class:`RayCasterData` because the camera variant + manages its own torch-based pose and hit buffers independently from the + warp-native :class:`RayCasterData`. + """ image_mesh_ids: torch.Tensor = None """The mesh ids of the image pixels. diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py index 55a9d141132..cdb85ee24db 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster.py @@ -15,22 +15,27 @@ from pxr import Gf, Usd, UsdGeom, UsdPhysics +import omni.physics.tensors.impl.api as physx + import isaaclab.sim as sim_utils import isaaclab.utils.math as math_utils from isaaclab.markers import VisualizationMarkers from isaaclab.sim.views import XformPrimView from isaaclab.terrains.trimesh.utils import make_plane -from isaaclab.utils.math import quat_apply, quat_apply_yaw -from isaaclab.utils.warp import convert_to_warp_mesh, raycast_mesh +from isaaclab.utils.warp import convert_to_warp_mesh from ..sensor_base import SensorBase -from .ray_cast_utils import obtain_world_pose_from_view +from .kernels import ( + apply_z_drift_kernel, + fill_vec3_inf_kernel, + raycast_mesh_masked_kernel, + update_ray_caster_kernel, +) from .ray_caster_data import RayCasterData if TYPE_CHECKING: from .ray_caster_cfg import RayCasterCfg -# import logger logger = logging.getLogger(__name__) @@ -53,7 +58,6 @@ class RayCaster(SensorBase): cfg: RayCasterCfg """The configuration parameters.""" - # Class variables to share meshes across instances meshes: ClassVar[dict[str, wp.Mesh]] = {} """A dictionary to store warp meshes for raycasting, shared across all instances. @@ -68,9 +72,7 @@ def __init__(self, cfg: RayCasterCfg): cfg: The configuration parameters. """ RayCaster._instance_count += 1 - # Initialize base class super().__init__(cfg) - # Create empty variables for storing output data self._data = RayCasterData() def __str__(self) -> str: @@ -116,10 +118,10 @@ def reset(self, env_ids: Sequence[int] | None = None, env_mask: wp.array | None else: env_ids = slice(None) num_envs_ids = self._view.count - # resample the drift + # resample drift (uses torch views for indexing) r = torch.empty(num_envs_ids, 3, device=self.device) self.drift[env_ids] = r.uniform_(*self.cfg.drift_range) - # resample the height drift + # resample the ray cast drift range_list = [self.cfg.ray_cast_drift_range.get(key, (0.0, 0.0)) for key in ["x", "y", "z"]] ranges = torch.tensor(range_list, device=self.device) self.ray_cast_drift[env_ids] = math_utils.sample_uniform( @@ -132,7 +134,6 @@ def reset(self, env_ids: Sequence[int] | None = None, env_mask: wp.array | None def _initialize_impl(self): super()._initialize_impl() - # obtain global simulation view self._physics_sim_view = sim_utils.SimulationContext.instance().physics_manager.get_physics_sim_view() prim = sim_utils.find_first_matching_prim(self.cfg.prim_path) @@ -144,9 +145,32 @@ def _initialize_impl(self): self._view, self._offset = self._obtain_trackable_prim_view(self.cfg.prim_path) - # load the meshes by parsing the stage + # Convert offsets to warp (zero-copy from existing torch tensors) + self._offset_pos_wp = wp.from_torch(self._offset[0].contiguous(), dtype=wp.vec3f) + self._offset_quat_wp = wp.from_torch(self._offset[1].contiguous(), dtype=wp.quatf) + + # Handle deprecated attach_yaw_only at init time + if self.cfg.attach_yaw_only is not None: + msg = ( + "Raycaster attribute 'attach_yaw_only' property will be deprecated in a future release." + " Please use the parameter 'ray_alignment' instead." + ) + if self.cfg.attach_yaw_only: + self.cfg.ray_alignment = "yaw" + msg += " Setting ray_alignment to 'yaw'." + else: + self.cfg.ray_alignment = "base" + msg += " Setting ray_alignment to 'base'." + logger.warning(msg) + self.cfg.attach_yaw_only = None + + # Resolve alignment mode to integer constant for kernel dispatch + alignment_map = {"world": 0, "yaw": 1, "base": 2} + if self.cfg.ray_alignment not in alignment_map: + raise RuntimeError(f"Unsupported ray_alignment type: {self.cfg.ray_alignment}.") + self._alignment_mode = alignment_map[self.cfg.ray_alignment] + self._initialize_warp_meshes() - # initialize the ray start and directions self._initialize_rays_impl() def _initialize_warp_meshes(self): @@ -158,29 +182,20 @@ def _initialize_warp_meshes(self): # read prims to ray-cast for mesh_prim_path in self.cfg.mesh_prim_paths: - # check if mesh already casted into warp mesh if mesh_prim_path in RayCaster.meshes: continue - # check if the prim is a plane - handle PhysX plane as a special case - # if a plane exists then we need to create an infinite mesh that is a plane mesh_prim = sim_utils.get_first_matching_child_prim( mesh_prim_path, lambda prim: prim.GetTypeName() == "Plane" ) - # if we did not find a plane then we need to read the mesh if mesh_prim is None: - # obtain the mesh prim mesh_prim = sim_utils.get_first_matching_child_prim( mesh_prim_path, lambda prim: prim.GetTypeName() == "Mesh" ) - # check if valid if mesh_prim is None or not mesh_prim.IsValid(): raise RuntimeError(f"Invalid mesh prim path: {mesh_prim_path}") - # cast into UsdGeomMesh mesh_prim = UsdGeom.Mesh(mesh_prim) - # read the vertices and faces points = np.asarray(mesh_prim.GetPointsAttr().Get()) - # Get world transform using pure USD (UsdGeom.Xformable) xformable = UsdGeom.Xformable(mesh_prim) world_transform: Gf.Matrix4d = xformable.ComputeLocalToWorldTransform(Usd.TimeCode.Default()) transform_matrix = np.array(world_transform).T @@ -188,140 +203,157 @@ def _initialize_warp_meshes(self): points += transform_matrix[:3, 3] indices = np.asarray(mesh_prim.GetFaceVertexIndicesAttr().Get()) wp_mesh = convert_to_warp_mesh(points, indices, device=self.device) - # print info logger.info( f"Read mesh prim: {mesh_prim.GetPath()} with {len(points)} vertices and {len(indices)} faces." ) else: mesh = make_plane(size=(2e6, 2e6), height=0.0, center_zero=True) wp_mesh = convert_to_warp_mesh(mesh.vertices, mesh.faces, device=self.device) - # print info logger.info(f"Created infinite plane mesh prim: {mesh_prim.GetPath()}.") - # add the warp mesh to the list RayCaster.meshes[mesh_prim_path] = wp_mesh - # throw an error if no meshes are found if all([mesh_prim_path not in RayCaster.meshes for mesh_prim_path in self.cfg.mesh_prim_paths]): raise RuntimeError( f"No meshes found for ray-casting! Please check the mesh prim paths: {self.cfg.mesh_prim_paths}" ) def _initialize_rays_impl(self): - # compute ray stars and directions - self.ray_starts, self.ray_directions = self.cfg.pattern_cfg.func(self.cfg.pattern_cfg, self._device) - self.num_rays = len(self.ray_directions) - # apply offset transformation to the rays + # Compute ray starts and directions from pattern (torch, init-time only) + ray_starts_torch, ray_directions_torch = self.cfg.pattern_cfg.func(self.cfg.pattern_cfg, self._device) + self.num_rays = len(ray_directions_torch) + + # Apply sensor offset rotation/position to local ray pattern offset_pos = torch.tensor(list(self.cfg.offset.pos), device=self._device) offset_quat = torch.tensor(list(self.cfg.offset.rot), device=self._device) - self.ray_directions = quat_apply(offset_quat.repeat(len(self.ray_directions), 1), self.ray_directions) - self.ray_starts += offset_pos - # repeat the rays for each sensor - self.ray_starts = self.ray_starts.repeat(self._view.count, 1, 1) - self.ray_directions = self.ray_directions.repeat(self._view.count, 1, 1) - # prepare drift - self.drift = torch.zeros(self._view.count, 3, device=self.device) - self.ray_cast_drift = torch.zeros(self._view.count, 3, device=self.device) - # fill the data buffer - self._data.pos_w = torch.zeros(self._view.count, 3, device=self.device) - self._data.quat_w = torch.zeros(self._view.count, 4, device=self.device) - self._data.ray_hits_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) - self._ray_starts_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) - self._ray_directions_w = torch.zeros(self._view.count, self.num_rays, 3, device=self.device) - - def _update_ray_infos(self, env_ids: Sequence[int]): - """Updates the ray information buffers.""" - - pos_w, quat_w = obtain_world_pose_from_view(self._view, env_ids) - pos_w, quat_w = math_utils.combine_frame_transforms( - pos_w, quat_w, self._offset[0][env_ids], self._offset[1][env_ids] + ray_directions_torch = math_utils.quat_apply( + offset_quat.repeat(len(ray_directions_torch), 1), ray_directions_torch ) - # apply drift to ray starting position in world frame - pos_w += self.drift[env_ids] - # store the poses - self._data.pos_w[env_ids] = pos_w - self._data.quat_w[env_ids] = quat_w + ray_starts_torch += offset_pos - # check if user provided attach_yaw_only flag - if self.cfg.attach_yaw_only is not None: - msg = ( - "Raycaster attribute 'attach_yaw_only' property will be deprecated in a future release." - " Please use the parameter 'ray_alignment' instead." - ) - # set ray alignment to yaw - if self.cfg.attach_yaw_only: - self.cfg.ray_alignment = "yaw" - msg += " Setting ray_alignment to 'yaw'." - else: - self.cfg.ray_alignment = "base" - msg += " Setting ray_alignment to 'base'." - # log the warning - logger.warning(msg) - # ray cast based on the sensor poses - if self.cfg.ray_alignment == "world": - # apply horizontal drift to ray starting position in ray caster frame - pos_w[:, 0:2] += self.ray_cast_drift[env_ids, 0:2] - # no rotation is considered and directions are not rotated - ray_starts_w = self.ray_starts[env_ids] - ray_starts_w += pos_w.unsqueeze(1) - ray_directions_w = self.ray_directions[env_ids] - elif self.cfg.ray_alignment == "yaw": - # apply horizontal drift to ray starting position in ray caster frame - pos_w[:, 0:2] += quat_apply_yaw(quat_w, self.ray_cast_drift[env_ids])[:, 0:2] - # only yaw orientation is considered and directions are not rotated - ray_starts_w = quat_apply_yaw(quat_w.repeat(1, self.num_rays), self.ray_starts[env_ids]) - ray_starts_w += pos_w.unsqueeze(1) - ray_directions_w = self.ray_directions[env_ids] - elif self.cfg.ray_alignment == "base": - # apply horizontal drift to ray starting position in ray caster frame - pos_w[:, 0:2] += quat_apply(quat_w, self.ray_cast_drift[env_ids])[:, 0:2] - # full orientation is considered - ray_starts_w = quat_apply(quat_w.repeat(1, self.num_rays), self.ray_starts[env_ids]) - ray_starts_w += pos_w.unsqueeze(1) - ray_directions_w = quat_apply(quat_w.repeat(1, self.num_rays), self.ray_directions[env_ids]) - else: - raise RuntimeError(f"Unsupported ray_alignment type: {self.cfg.ray_alignment}.") + # Repeat for each environment + ray_starts_torch = ray_starts_torch.repeat(self._view.count, 1, 1) + ray_directions_torch = ray_directions_torch.repeat(self._view.count, 1, 1) + + # Create warp arrays from the init-time torch data + # The warp arrays own the memory; torch views provide backward-compat indexing + self._ray_starts_local = wp.from_torch(ray_starts_torch.contiguous(), dtype=wp.vec3f) + self._ray_directions_local = wp.from_torch(ray_directions_torch.contiguous(), dtype=wp.vec3f) + + # Torch views (same attribute names as before for subclass compatibility) + self.ray_starts = wp.to_torch(self._ray_starts_local) + self.ray_directions = wp.to_torch(self._ray_directions_local) + + # Drift buffers (warp-owned, torch views for reset indexing) + self._drift = wp.zeros(self._view.count, dtype=wp.vec3f, device=self._device) + self._ray_cast_drift = wp.zeros(self._view.count, dtype=wp.vec3f, device=self._device) + self.drift = wp.to_torch(self._drift) + self.ray_cast_drift = wp.to_torch(self._ray_cast_drift) + + # World-frame ray buffers + self._ray_starts_w = wp.zeros((self._view.count, self.num_rays), dtype=wp.vec3f, device=self._device) + self._ray_directions_w = wp.zeros((self._view.count, self.num_rays), dtype=wp.vec3f, device=self._device) - self._ray_starts_w[env_ids] = ray_starts_w - self._ray_directions_w[env_ids] = ray_directions_w + # Torch views for subclass compatibility + self._ray_starts_w_torch = wp.to_torch(self._ray_starts_w) + self._ray_directions_w_torch = wp.to_torch(self._ray_directions_w) + + # Data buffers + self._data.create_buffers(self._view.count, self.num_rays, self._device) + + def _get_view_transforms_wp(self) -> wp.array: + """Get world transforms from the physics view as a warp array. + + Returns: + Warp array of ``wp.transformf`` with shape (num_envs,). + """ + if isinstance(self._view, XformPrimView): + pos_w, quat_w = self._view.get_world_poses() + poses = torch.cat([pos_w, quat_w], dim=-1).contiguous() + return wp.from_torch(poses).view(wp.transformf) + elif isinstance(self._view, physx.ArticulationView): + return self._view.get_root_transforms().view(wp.transformf) + elif isinstance(self._view, physx.RigidBodyView): + return self._view.get_transforms().view(wp.transformf) + else: + raise NotImplementedError(f"Cannot get transforms for view type '{type(self._view)}'.") + + def _update_ray_infos(self, env_mask: wp.array): + """Updates sensor poses and ray world-frame buffers via a single warp kernel.""" + transforms = self._get_view_transforms_wp() + + wp.launch( + update_ray_caster_kernel, + dim=(self._num_envs, self.num_rays), + inputs=[ + transforms, + env_mask, + self._offset_pos_wp, + self._offset_quat_wp, + self._drift, + self._ray_cast_drift, + self._ray_starts_local, + self._ray_directions_local, + self._alignment_mode, + ], + outputs=[ + self._data._pos_w, + self._data._quat_w, + self._ray_starts_w, + self._ray_directions_w, + ], + device=self._device, + ) def _update_buffers_impl(self, env_mask: wp.array): """Fills the buffers of the sensor data.""" - env_ids = wp.to_torch(env_mask).nonzero(as_tuple=False).squeeze(-1) - if len(env_ids) == 0: - return - self._update_ray_infos(env_ids) + self._update_ray_infos(env_mask) + + # Fill ray hits with inf before raycasting + wp.launch( + fill_vec3_inf_kernel, + dim=(self._num_envs, self.num_rays), + inputs=[env_mask, self._data._ray_hits_w, float("inf")], + device=self._device, + ) - # ray cast and store the hits - # TODO: Make this work for multiple meshes? - self._data.ray_hits_w[env_ids] = raycast_mesh( - self._ray_starts_w[env_ids], - self._ray_directions_w[env_ids], - max_dist=self.cfg.max_distance, - mesh=RayCaster.meshes[self.cfg.mesh_prim_paths[0]], - )[0] + # Ray-cast against the mesh + wp.launch( + raycast_mesh_masked_kernel, + dim=(self._num_envs, self.num_rays), + inputs=[ + RayCaster.meshes[self.cfg.mesh_prim_paths[0]].id, + env_mask, + self._ray_starts_w, + self._ray_directions_w, + self._data._ray_hits_w, + float(self.cfg.max_distance), + ], + device=self._device, + ) - # apply vertical drift to ray starting position in ray caster frame - self._data.ray_hits_w[env_ids, :, 2] += self.ray_cast_drift[env_ids, 2].unsqueeze(-1) + # Apply vertical drift to ray hits + wp.launch( + apply_z_drift_kernel, + dim=(self._num_envs, self.num_rays), + inputs=[env_mask, self._ray_cast_drift, self._data._ray_hits_w], + device=self._device, + ) def _set_debug_vis_impl(self, debug_vis: bool): - # set visibility of markers - # note: parent only deals with callbacks. not their visibility if debug_vis: if not hasattr(self, "ray_visualizer"): self.ray_visualizer = VisualizationMarkers(self.cfg.visualizer_cfg) - # set their visibility to true self.ray_visualizer.set_visibility(True) else: if hasattr(self, "ray_visualizer"): self.ray_visualizer.set_visibility(False) def _debug_vis_callback(self, event): - if self._data.ray_hits_w is None: + if self._data._ray_hits_w is None: return - # remove possible inf values - viz_points = self._data.ray_hits_w.reshape(-1, 3) + ray_hits_torch = wp.to_torch(self._data._ray_hits_w) + viz_points = ray_hits_torch.reshape(-1, 3) viz_points = viz_points[~torch.any(torch.isinf(viz_points), dim=1)] - self.ray_visualizer.visualize(viz_points) """ @@ -331,7 +363,7 @@ def _debug_vis_callback(self, event): def _obtain_trackable_prim_view( self, target_prim_path: str ) -> tuple[XformPrimView | any, tuple[torch.Tensor, torch.Tensor]]: - """Obtain a prim view that can be used to track the pose of the parget prim. + """Obtain a prim view that can be used to track the pose of the target prim. The target prim path is a regex expression that matches one or more mesh prims. While we can track its pose directly using XFormPrim, this is not efficient and can be slow. Instead, we create a prim view @@ -358,13 +390,11 @@ def _obtain_trackable_prim_view( prim_view = None while prim_view is None: - # TODO: Need to handle the case where API is present but it is disabled if current_prim.HasAPI(UsdPhysics.ArticulationRootAPI): prim_view = self._physics_sim_view.create_articulation_view(current_path_expr.replace(".*", "*")) logger.info(f"Created articulation view for mesh prim at path: {target_prim_path}") break - # TODO: Need to handle the case where API is present but it is disabled if current_prim.HasAPI(UsdPhysics.RigidBodyAPI): prim_view = self._physics_sim_view.create_rigid_body_view(current_path_expr.replace(".*", "*")) logger.info(f"Created rigid body view for mesh prim at path: {target_prim_path}") @@ -383,7 +413,6 @@ def _obtain_trackable_prim_view( ) break - # switch the current prim to the parent prim current_prim = new_root_prim # obtain the relative transforms between target prim and the view prims @@ -413,9 +442,7 @@ def _obtain_trackable_prim_view( def _invalidate_initialize_callback(self, event): """Invalidates the scene elements.""" - # call parent super()._invalidate_initialize_callback(event) - # set all existing views to None to invalidate them self._view = None def __del__(self): diff --git a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py index 6103a2167d6..e1c903e19b2 100644 --- a/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py +++ b/source/isaaclab/isaaclab/sensors/ray_caster/ray_caster_data.py @@ -3,28 +3,69 @@ # # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass +from __future__ import annotations -import torch +import warp as wp -@dataclass class RayCasterData: - """Data container for the ray-cast sensor.""" + """Data container for the ray-cast sensor. - pos_w: torch.Tensor = None - """Position of the sensor origin in world frame. - - Shape is (N, 3), where N is the number of sensors. + All public properties return :class:`wp.array` backed by device memory. Use + :func:`wp.to_torch` at the call-site when a PyTorch tensor is needed. """ - quat_w: torch.Tensor = None - """Orientation of the sensor origin in quaternion (x, y, z, w) in world frame. - Shape is (N, 4), where N is the number of sensors. - """ - ray_hits_w: torch.Tensor = None - """The ray hit positions in the world frame. + def __init__(self): + self._pos_w: wp.array | None = None + self._quat_w: wp.array | None = None + self._ray_hits_w: wp.array | None = None - Shape is (N, B, 3), where N is the number of sensors, B is the number of rays - in the scan pattern per sensor. - """ + self._pos_w_torch = None + self._quat_w_torch = None + self._ray_hits_w_torch = None + + @property + def pos_w(self) -> wp.array | None: + """Position of the sensor origin in world frame [m]. + + Shape is (N,), dtype ``wp.vec3f``. In torch this resolves to (N, 3), + where N is the number of sensors. + """ + return self._pos_w + + @property + def quat_w(self) -> wp.array | None: + """Orientation of the sensor origin in quaternion (x, y, z, w) in world frame. + + Shape is (N,), dtype ``wp.quatf``. In torch this resolves to (N, 4), + where N is the number of sensors. + """ + return self._quat_w + + @property + def ray_hits_w(self) -> wp.array | None: + """The ray hit positions in the world frame [m]. + + Shape is (N, B), dtype ``wp.vec3f``. In torch this resolves to (N, B, 3), + where N is the number of sensors, B is the number of rays per sensor. + Contains ``inf`` for missed hits. + """ + return self._ray_hits_w + + def create_buffers(self, num_envs: int, num_rays: int, device: str) -> None: + """Create internal warp buffers and corresponding zero-copy torch views. + + Args: + num_envs: Number of environments / sensors. + num_rays: Number of rays per sensor. + device: Device for tensor storage. + """ + self._device = device + + self._pos_w = wp.zeros(num_envs, dtype=wp.vec3f, device=device) + self._quat_w = wp.zeros(num_envs, dtype=wp.quatf, device=device) + self._ray_hits_w = wp.zeros((num_envs, num_rays), dtype=wp.vec3f, device=device) + + self._pos_w_torch = wp.to_torch(self._pos_w) + self._quat_w_torch = wp.to_torch(self._quat_w) + self._ray_hits_w_torch = wp.to_torch(self._ray_hits_w) diff --git a/source/isaaclab_tasks/isaaclab_tasks/direct/anymal_c/anymal_c_env.py b/source/isaaclab_tasks/isaaclab_tasks/direct/anymal_c/anymal_c_env.py index ea95dfe5b98..e58e377a0d6 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/direct/anymal_c/anymal_c_env.py +++ b/source/isaaclab_tasks/isaaclab_tasks/direct/anymal_c/anymal_c_env.py @@ -88,7 +88,9 @@ def _get_observations(self) -> dict: height_data = None if isinstance(self.cfg, AnymalCRoughEnvCfg): height_data = ( - self._height_scanner.data.pos_w[:, 2].unsqueeze(1) - self._height_scanner.data.ray_hits_w[..., 2] - 0.5 + wp.to_torch(self._height_scanner.data.pos_w)[:, 2].unsqueeze(1) + - wp.to_torch(self._height_scanner.data.ray_hits_w)[..., 2] + - 0.5 ).clip(-1.0, 1.0) obs = torch.cat( [