Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 45 additions & 110 deletions scripts/reinforcement_learning/rsl_rl/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
"""Launch Isaac Sim Simulator first."""

import argparse
from collections import deque
import math
import os
import sys

Expand Down Expand Up @@ -87,16 +85,6 @@
import time
import torch

import isaaclab.utils.math as math_utils

try:
import isaacsim.util.debug_draw._debug_draw as omni_debug_draw
except Exception:
try:
import omni.isaac.debug_draw._debug_draw as omni_debug_draw
except Exception:
omni_debug_draw = None

from rsl_rl.runners import OnPolicyRunner

from isaaclab.devices import Se2Keyboard, Se2KeyboardCfg
Expand Down Expand Up @@ -154,7 +142,6 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env_cfg.events.push_robot = None
env_cfg.curriculum.command_levels = None

keyboard_command_state = None
if args_cli.keyboard:
env_cfg.scene.num_envs = 1
env_cfg.terminations.time_out = None
Expand All @@ -165,14 +152,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
omega_z_sensitivity=env_cfg.commands.base_velocity.ranges.ang_vel_z[1],
)
controller = Se2Keyboard(config)

def _keyboard_obs_term(env):
nonlocal keyboard_command_state
keyboard_command_state = torch.tensor(controller.advance(), dtype=torch.float32).unsqueeze(0).to(env.device)
return keyboard_command_state

env_cfg.observations.policy.velocity_commands = ObsTerm(
func=_keyboard_obs_term,
func=lambda env: torch.tensor(controller.advance(), dtype=torch.float32).unsqueeze(0).to(env.device),
)

# specify directory for logging experiments
Expand Down Expand Up @@ -250,6 +231,36 @@ def _keyboard_obs_term(env):
filename="policy.pt",
)

# ====== Visualization init ======
VIS_ENABLED = True
draw_interface = None
foot_ids = None
act_hist = None

try:
import isaacsim.util.debug_draw._debug_draw as omni_debug_draw
draw_interface = omni_debug_draw.acquire_debug_draw_interface()
if draw_interface is None:
VIS_ENABLED = False
except:
VIS_ENABLED = False

if VIS_ENABLED:
robot = env.unwrapped.scene["robot"]
# Auto-detect foot/wheel bodies
foot_candidates = [name for name in robot.body_names
if "foot" in name.lower() or "wheel" in name.lower()]
foot_candidates.sort()

if len(foot_candidates) >= 4:
foot_names = foot_candidates[:4]
foot_ids = [robot.find_bodies(name)[0] for name in foot_names]
from collections import deque
act_hist = [deque(maxlen=100) for _ in range(4)]
print(f"[INFO] Using foot bodies: {foot_names}")
else:
VIS_ENABLED = False

dt = env.unwrapped.step_dt
# reset environment
obs, _ = env.reset()
Expand All @@ -265,102 +276,26 @@ def _keyboard_obs_term(env):

# env stepping
obs, _, _, _ = env.step(actions)

if (
VIS_ENABLEd
and draw_interface is not None
and foot_ids is not None
and phase_offsets is not None
and cycle_time is not None
and gait_span is not None
and gait_psi is not None
and gait_delta is not None
and x_offset is not None
and stance_span is not None
and cmd_threshold is not None
and stand_ref_z_offset is not None
and cmd_hist is not None
and act_hist is not None
):
local_foot_ids = foot_ids
# ----- Draw foot trajectories -----
if VIS_ENABLED and draw_interface and foot_ids and act_hist:
robot = env.unwrapped.scene["robot"]
root_pos = robot.data.root_pos_w[0]
root_quat = robot.data.root_quat_w[0].unsqueeze(0)

# Initialize base-fixed stand reference once from current posture.
if stand_ref_body is None:
rel_init = robot.data.body_pos_w[0, local_foot_ids, :] - root_pos.unsqueeze(0)
stand_ref_body = math_utils.quat_apply_inverse(root_quat.expand(len(local_foot_ids), -1), rel_init)
stand_ref_body[:, 2] += stand_ref_z_offset

elapsed_t = float(env.unwrapped.common_step_counter) * dt
phase_s = torch.remainder((2.0 * elapsed_t / max(cycle_time, 1e-6)) + phase_offsets, 2.0)
cmd_local = _mujoco_phase_traj_body(
phase_s=phase_s,
gait_span=gait_span,
gait_psi=gait_psi,
gait_delta=gait_delta,
x_offset=x_offset,
stance_span=stance_span,
)
ref_body = stand_ref_body + cmd_local
ref_world = root_pos.unsqueeze(0) + math_utils.quat_apply(root_quat.expand(len(local_foot_ids), -1), ref_body)

actual_world = robot.data.body_pos_w[0, local_foot_ids, :]
actual_world = robot.data.body_pos_w[0, foot_ids, :]

for i in range(4):
cmd_hist[i].append(ref_world[i].detach().cpu().tolist())
act_hist[i].append(actual_world[i].detach().cpu().tolist())

if args_cli.keyboard and keyboard_command_state is not None:
cmd_vec = keyboard_command_state[0, :3]
else:
cmd_vec = env.unwrapped.command_manager.get_command("base_velocity")[0, :3]

cmd_norm = torch.linalg.norm(cmd_vec).item()
gate_on = cmd_norm > cmd_threshold

if args_cli.keyboard:
ref_gate_on = cmd_norm > 0.1
else:
ref_gate_on = gate_on


draw_interface.clear_lines()
starts = []
ends = []
colors = []
widths = []

ref_alpha = 0.95 if gate_on else 0.35
act_alpha = 0.35 if gate_on else 0.20

if not phase_vis_z_printed:
print(
"[INFO] phase_foot_trajectory_exp z check: "
f"ref_z_mean={ref_world[:, 2].mean().item():.4f}, "
f"act_z_mean={actual_world[:, 2].mean().item():.4f}, "
f"stand_ref_z_offset={stand_ref_z_offset:.4f}"
)
phase_vis_z_printed = True

starts, ends = [], []

for i in range(4):
act_pts = list(act_hist[i])
for j in range(1, len(act_pts)):
starts.append(act_pts[j - 1])
ends.append(act_pts[j])
colors.append([0.0, 0.0, 0.0, act_alpha])
widths.append(1.5)

cmd_pts = list(cmd_hist[i])
if VIS_REF_ENABLE and ref_gate_on:
for j in range(1, len(cmd_pts)):
starts.append(cmd_pts[j - 1])
ends.append(cmd_pts[j])
color = color_palette[i].copy()
color[3] = ref_alpha
colors.append(color)
widths.append(2.8)

pts = list(act_hist[i])
for j in range(1, len(pts)):
starts.append(pts[j-1][0])
ends.append(pts[j][0])

if starts:
colors = [[0.0, 0.0, 0.0, 0.6]] * len(starts)
widths = [2.0] * len(starts)
draw_interface.draw_lines(starts, ends, colors, widths)
if args_cli.video:
timestep += 1
Expand Down