diff --git a/scripts/reinforcement_learning/rsl_rl/play.py b/scripts/reinforcement_learning/rsl_rl/play.py index 1828983..c141723 100644 --- a/scripts/reinforcement_learning/rsl_rl/play.py +++ b/scripts/reinforcement_learning/rsl_rl/play.py @@ -14,8 +14,6 @@ """Launch Isaac Sim Simulator first.""" import argparse -from collections import deque -import math import os import sys @@ -87,16 +85,6 @@ import time import torch -import isaaclab.utils.math as math_utils - -try: - import isaacsim.util.debug_draw._debug_draw as omni_debug_draw -except Exception: - try: - import omni.isaac.debug_draw._debug_draw as omni_debug_draw - except Exception: - omni_debug_draw = None - from rsl_rl.runners import OnPolicyRunner from isaaclab.devices import Se2Keyboard, Se2KeyboardCfg @@ -154,7 +142,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen env_cfg.events.push_robot = None env_cfg.curriculum.command_levels = None - keyboard_command_state = None if args_cli.keyboard: env_cfg.scene.num_envs = 1 env_cfg.terminations.time_out = None @@ -165,14 +152,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen omega_z_sensitivity=env_cfg.commands.base_velocity.ranges.ang_vel_z[1], ) controller = Se2Keyboard(config) - - def _keyboard_obs_term(env): - nonlocal keyboard_command_state - keyboard_command_state = torch.tensor(controller.advance(), dtype=torch.float32).unsqueeze(0).to(env.device) - return keyboard_command_state - env_cfg.observations.policy.velocity_commands = ObsTerm( - func=_keyboard_obs_term, + func=lambda env: torch.tensor(controller.advance(), dtype=torch.float32).unsqueeze(0).to(env.device), ) # specify directory for logging experiments @@ -250,6 +231,36 @@ def _keyboard_obs_term(env): filename="policy.pt", ) + # ====== Visualization init ====== + VIS_ENABLED = True + draw_interface = None + foot_ids = None + act_hist = None + + try: + import isaacsim.util.debug_draw._debug_draw as omni_debug_draw + draw_interface = omni_debug_draw.acquire_debug_draw_interface() + if draw_interface is None: + VIS_ENABLED = False + except: + VIS_ENABLED = False + + if VIS_ENABLED: + robot = env.unwrapped.scene["robot"] + # Auto-detect foot/wheel bodies + foot_candidates = [name for name in robot.body_names + if "foot" in name.lower() or "wheel" in name.lower()] + foot_candidates.sort() + + if len(foot_candidates) >= 4: + foot_names = foot_candidates[:4] + foot_ids = [robot.find_bodies(name)[0] for name in foot_names] + from collections import deque + act_hist = [deque(maxlen=100) for _ in range(4)] + print(f"[INFO] Using foot bodies: {foot_names}") + else: + VIS_ENABLED = False + dt = env.unwrapped.step_dt # reset environment obs, _ = env.reset() @@ -265,102 +276,26 @@ def _keyboard_obs_term(env): # env stepping obs, _, _, _ = env.step(actions) - - if ( - VIS_ENABLEd - and draw_interface is not None - and foot_ids is not None - and phase_offsets is not None - and cycle_time is not None - and gait_span is not None - and gait_psi is not None - and gait_delta is not None - and x_offset is not None - and stance_span is not None - and cmd_threshold is not None - and stand_ref_z_offset is not None - and cmd_hist is not None - and act_hist is not None - ): - local_foot_ids = foot_ids + # ----- Draw foot trajectories ----- + if VIS_ENABLED and draw_interface and foot_ids and act_hist: robot = env.unwrapped.scene["robot"] - root_pos = robot.data.root_pos_w[0] - root_quat = robot.data.root_quat_w[0].unsqueeze(0) - - # Initialize base-fixed stand reference once from current posture. - if stand_ref_body is None: - rel_init = robot.data.body_pos_w[0, local_foot_ids, :] - root_pos.unsqueeze(0) - stand_ref_body = math_utils.quat_apply_inverse(root_quat.expand(len(local_foot_ids), -1), rel_init) - stand_ref_body[:, 2] += stand_ref_z_offset - - elapsed_t = float(env.unwrapped.common_step_counter) * dt - phase_s = torch.remainder((2.0 * elapsed_t / max(cycle_time, 1e-6)) + phase_offsets, 2.0) - cmd_local = _mujoco_phase_traj_body( - phase_s=phase_s, - gait_span=gait_span, - gait_psi=gait_psi, - gait_delta=gait_delta, - x_offset=x_offset, - stance_span=stance_span, - ) - ref_body = stand_ref_body + cmd_local - ref_world = root_pos.unsqueeze(0) + math_utils.quat_apply(root_quat.expand(len(local_foot_ids), -1), ref_body) - - actual_world = robot.data.body_pos_w[0, local_foot_ids, :] + actual_world = robot.data.body_pos_w[0, foot_ids, :] + for i in range(4): - cmd_hist[i].append(ref_world[i].detach().cpu().tolist()) act_hist[i].append(actual_world[i].detach().cpu().tolist()) - - if args_cli.keyboard and keyboard_command_state is not None: - cmd_vec = keyboard_command_state[0, :3] - else: - cmd_vec = env.unwrapped.command_manager.get_command("base_velocity")[0, :3] - - cmd_norm = torch.linalg.norm(cmd_vec).item() - gate_on = cmd_norm > cmd_threshold - - if args_cli.keyboard: - ref_gate_on = cmd_norm > 0.1 - else: - ref_gate_on = gate_on - + draw_interface.clear_lines() - starts = [] - ends = [] - colors = [] - widths = [] - - ref_alpha = 0.95 if gate_on else 0.35 - act_alpha = 0.35 if gate_on else 0.20 - - if not phase_vis_z_printed: - print( - "[INFO] phase_foot_trajectory_exp z check: " - f"ref_z_mean={ref_world[:, 2].mean().item():.4f}, " - f"act_z_mean={actual_world[:, 2].mean().item():.4f}, " - f"stand_ref_z_offset={stand_ref_z_offset:.4f}" - ) - phase_vis_z_printed = True - + starts, ends = [], [] + for i in range(4): - act_pts = list(act_hist[i]) - for j in range(1, len(act_pts)): - starts.append(act_pts[j - 1]) - ends.append(act_pts[j]) - colors.append([0.0, 0.0, 0.0, act_alpha]) - widths.append(1.5) - - cmd_pts = list(cmd_hist[i]) - if VIS_REF_ENABLE and ref_gate_on: - for j in range(1, len(cmd_pts)): - starts.append(cmd_pts[j - 1]) - ends.append(cmd_pts[j]) - color = color_palette[i].copy() - color[3] = ref_alpha - colors.append(color) - widths.append(2.8) - + pts = list(act_hist[i]) + for j in range(1, len(pts)): + starts.append(pts[j-1][0]) + ends.append(pts[j][0]) + if starts: + colors = [[0.0, 0.0, 0.0, 0.6]] * len(starts) + widths = [2.0] * len(starts) draw_interface.draw_lines(starts, ends, colors, widths) if args_cli.video: timestep += 1