diff --git a/.gitignore b/.gitignore index 9d6232dd..96912866 100644 --- a/.gitignore +++ b/.gitignore @@ -172,3 +172,11 @@ MUJOCO_LOG.TXT _METADATA checkpoint wandb/ + +# VS Code settings +*.code-workspace + +# checkpoints +**/checkpoints/ +**/classifier_checkpoints/ +**/classifier_demos/ diff --git a/README.md b/README.md index f6c9bccd..63904914 100644 --- a/README.md +++ b/README.md @@ -1,124 +1,70 @@ -# SERL: A Software Suite for Sample-Efficient Robotic Reinforcement Learning +# FractalSerl — Fractal Symmetries for Sample-Efficient Robotic Learning -![](https://github.com/rail-berkeley/serl/workflows/pre-commit/badge.svg) +[![Discord](https://img.shields.io/discord/1302866684612444190?label=Join%20Us%20on%20Discord&logo=discord&color=7289da)](https://discord.com/invite/bAxjvvJzNM) +[![Notion](https://img.shields.io/badge/Notion-Workspace-000000?logo=notion&logoColor=white)](https://lipscomb-robotics.notion.site/?source=copy_link) +[![Paper](https://img.shields.io/badge/Paper-Frontiers-blue?logo=zenodo&logoColor=white)](https://www.frontiersin.org/journals/robotics-and-ai/articles/10.3389/frobt.2026.1791812/abstract) +[![Instagram](https://img.shields.io/badge/Instagram-Follow-E4405F?logo=instagram&logoColor=white)](https://www.instagram.com/lippyrobotics/) +[![YouTube](https://img.shields.io/badge/YouTube-Channel-FF0000?logo=youtube&logoColor=white)](https://www.youtube.com/@lippyrobotics) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) -[![Static Badge](https://img.shields.io/badge/Project-Page-a)](https://serl-robot.github.io/) -[![Discord](https://img.shields.io/discord/1302866684612444190?label=Join%20Us%20on%20Discord&logo=discord&color=7289da)](https://discord.gg/G4xPJEhwuC) -![](./docs/images/tasks-banner.gif) +Short description +----------------- -**Webpage: [https://serl-robot.github.io/](https://serl-robot.github.io/)** +FractalSERL implements Branched Euclidean Group Fractal Symmetries 
— a trajectory-level augmentation framework that accelerates policy learning by iteratively applying affine and Euclidean-group transformations to episodic trajectories. Treating an episodic MDP as a tree of state–action pairs, self-similar branching produces fractal symmetry expansions that populate replay buffers with diverse, consistent experiences. We demonstrate improvements on simulated and real Franka manipulation tasks, achieving robust policies in as little as 14 minutes (avg. ~20 minutes) of wall-clock training. +Contributions in this repo include: +- **SymmGrid Framework**: A preliminary research implementation of fractal symmetry for deep reinforcement learning, demonstrating how branched symmetries accelerate DRL policy learning in physical robots. +- **Data Augmentation via Super-Scaling**: Efficient robot data generation through trajectory-level augmentation that significantly speeds up policy learning while improving performance and consistency on physical hardware. +- **Fractal Symmetry Replay Buffer**: An optimized datastore and replay buffer implementation designed to support parallelized computations and image handling without excessive memory overhead, enabling faster training iterations. +- **nAUC Performance Metric**: Uses the normalized Area Under the Curve (nAUC) as a trajectory-wide performance metric to capture the combined contributions of sample efficiency and policy performance throughout training. -**Also check out our new project HIL-SERL: [https://hil-serl.github.io/](https://hil-serl.github.io/)** +
+ +
-SERL provides a set of libraries, env wrappers, and examples to train RL policies for robotic manipulation tasks. The following sections describe how to use SERL. We will illustrate the usage with examples. + -**Table of Contents** -- [SERL: A Software Suite for Sample-Efficient Robotic Reinforcement Learning](#serl-a-software-suite-for-sample-efficient-robotic-reinforcement-learning) - - [Installation](#installation) - - [Overview and Code Structure](#overview-and-code-structure) - - [Quick Start with SERL in Sim](#quick-start-with-serl-in-sim) - - [Run with Franka Arm on Real Robot](#run-with-franka-arm-on-real-robot) - - [Contribution](#contribution) - - [Citation](#citation) +Navigation +---------- -## Major updates -#### June 24, 2024 -For people who use SERL for tasks involving controlling the gripper (e.g.,pick up objects), we strong recommend adding a small penalty to the gripper action change, as it will greatly improves the training speed. -For detail, please refer to: [PR #65](https://github.com/rail-berkeley/serl/pull/65). +The `docs/` folder contains additional Markdown files with step-by-step guides. Quick links are provided below: +- [Overview of code structure](docs/overview.md) +- [Installation guide](docs/installation.md) +- [Run in simulation](docs/run_sim.md) +- [Run on the real robot](docs/run_realrobot.md) -Further, we also recommend providing interventions online during training in addition to loading the offline demos. If you have a Franka robot and SpaceMouse, this can be as easy as just touching the SpaceMouse during training. -#### April 25, 2024 -We fixed a major issue in the intervention action frame. See release [v0.1.1](https://github.com/rail-berkeley/serl/releases/tag/v0.1.1) Please update your code with the main branch. +Quick start (very short) +------------------------ -## Installation -1. **Setup Conda Environment:** - create an environment with - ```bash - conda create -n serl python=3.10 - ``` +1. 
Install dependencies: see `docs/installation.md`. +2. Run a demo in sim: see `docs/run_sim.md` for instructions to launch `franka_sim` +3. For real hardware, follow the instructions in `docs/run_realrobot.md` and configure the files related to `serl_robot_infra/`. -2. **Install Jax as follows:** - - For CPU (not recommended): - ```bash - pip install --upgrade "jax[cpu]" - ``` +Citation +-------- - - For GPU: - ```bash - pip install --upgrade "jax[cuda12_pip]==0.4.35" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - ``` - - - For TPU - ```bash - pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html - ``` - - See the [Jax Github page](https://github.com/google/jax) for more details on installing Jax. - -3. **Install the serl_launcher** - ```bash - cd serl_launcher - pip install -e . - pip install -r requirements.txt - ``` - -## Overview and Code Structure - -SERL provides a set of common libraries for users to train RL policies for robotic manipulation tasks. The main structure of running the RL experiments involves having an actor node and a learner node, both of which interact with the robot gym environment. Both nodes run asynchronously, with data being sent from the actor to the learner node via the network using [agentlace](https://github.com/youliangtan/agentlace). The learner will periodically synchronize the policy with the actor. This design provides flexibility for parallel training and inference. - -

- -

- -**Table for code structure** - -| Code Directory | Description | -| --- | --- | -| [serl_launcher](https://github.com/rail-berkeley/serl/blob/main/serl_launcher) | Main code for SERL | -| [serl_launcher.agents](https://github.com/rail-berkeley/serl/blob/main/serl_launcher/serl_launcher/agents/) | Agent Policies (e.g. DRQ, SAC, BC) | -| [serl_launcher.wrappers](https://github.com/rail-berkeley/serl/blob/main/serl_launcher/serl_launcher/wrappers) | Gym env wrappers | -| [serl_launcher.data](https://github.com/rail-berkeley/serl/blob/main/serl_launcher/serl_launcher/data) | Replay buffer and data store | -| [serl_launcher.vision](https://github.com/rail-berkeley/serl/blob/main/serl_launcher/serl_launcher/vision) | Vision related models and utils | -| [franka_sim](./franka_sim) | Franka mujoco simulation gym environment | -| [serl_robot_infra](./serl_robot_infra/) | Robot infra for running with real robots | -| [serl_robot_infra.robot_servers](https://github.com/rail-berkeley/serl/blob/main/serl_robot_infra/robot_servers/) | Flask server for sending commands to robot via ROS | -| [serl_robot_infra.franka_env](https://github.com/rail-berkeley/serl/blob/main/serl_robot_infra/franka_env/) | Gym env for real franka robot | - -## Quick Start with SERL in Sim - -We provide a simulated environment for trying out SERL with a franka robot. - -Check out the [Quick Start with SERL in Sim](/docs/sim_quick_start.md) - - [Training from state observation example](/docs/sim_quick_start.md#1-training-from-state-observation-example) - - [Training from image observation example](/docs/sim_quick_start.md#2-training-from-image-observation-example) - - [Training from image observation with 20 demo trajectories example](/docs/sim_quick_start.md#3-training-from-image-observation-with-20-demo-trajectories-example) - -## Run with Franka Arm on Real Robot - -We provide a step-by-step guide to run RL policies with SERL on the real Franka robot. 
- -Check out the [Run with Franka Arm on Real Robot](/docs/real_franka.md) - - [Peg Insertion πŸ“](/docs/real_franka.md#1-peg-insertion-πŸ“) - - [PCB Component Insertion πŸ–₯️](/docs/real_franka.md#2-pcb-component-insertion-πŸ–₯️) - - [Cable Routing πŸ”Œ](/docs/real_franka.md#3-cable-routing-πŸ”Œ) - - [Object Relocation πŸ—‘οΈ](/docs/real_franka.md#4-object-relocation-πŸ—‘οΈ) - -## Contribution - -We welcome contributions to this repository! Fork and submit a PR if you have any improvements to the codebase. Before submitting a PR, please run `pre-commit run --all-files` to ensure that the codebase is formatted correctly. - -## Citation - -If you use this code for your research, please cite our paper: +If you use FractalSERL in your research, please cite our paper: ```bibtex +@misc{vanderstelt2026SymmGrid, + title={Towards Accelerating Deep Reinforcement Learning via Branched Symmetries}, + author={Ryan Vanderstelt, Cleiver Ruiz Martinez, Caeden Rosen, Blake Hull, and Juan Rojas}, + year={2026}, + eprint={____}, + archivePrefix={arXiv}, + primaryClass={cs.RO} +} + @misc{luo2024serl, title={SERL: A Software Suite for Sample-Efficient Robotic Reinforcement Learning}, author={Jianlan Luo and Zheyuan Hu and Charles Xu and You Liang Tan and Jacob Berg and Archit Sharma and Stefan Schaal and Chelsea Finn and Abhishek Gupta and Sergey Levine}, @@ -127,4 +73,4 @@ If you use this code for your research, please cite our paper: archivePrefix={arXiv}, primaryClass={cs.RO} } -``` +``` \ No newline at end of file diff --git a/demos/demos/__init__.py b/demos/demos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/demos/demos/demoHandling.py b/demos/demos/demoHandling.py new file mode 100644 index 00000000..04a5e974 --- /dev/null +++ b/demos/demos/demoHandling.py @@ -0,0 +1,157 @@ +import os +from pathlib import Path +import numpy as np +from agentlace.data.data_store import QueuedDataStore + +class DemoHandling: + """ + Koads an .npz file containing 
demonstration data into a data object. + This class is designed to work with Gymnasium-style demonstration data + and is intended to be used with a QueuedDataStore or similar data store. + + The .npz file should contain the following arrays: + - 'obs' : shape (N, T+1, *obs_shape*), list of observations + - 'acs' : shape (N, T, *act_shape*), list of actions + - 'rewards' : shape (N, T), list of rewards + - 'terminateds' : shape (N, T), list of terminated flags + - 'truncateds' : shape (N, T), list of truncated flags + - 'info' : shape (N, T), list of info dicts + - 'dones' : shape (N, T), list of done flags (if available) + + Parameters + ---------- + demo_dir : str + Directory where demo .npz files live by default. + file_name : str + Name of the demo file to load. If not provided, a default will be used. + """ + def __init__( + self, + demo_dir: str = '/data/data/serl/demos', + file_name: str = 'data_franka_reach_random_20.npz' + ): + + self.debug = False # Set to True for debugging purposes + self.demo_dir = demo_dir + self.transition_ctr = 0 # Global counter for transitions across all episodes + + # Load the demo data from the .npz file + + # Check if the demo directory exists + if not os.path.exists(self.demo_dir): + raise FileNotFoundError(f"Demo directory '{self.demo_dir}' does not exist.") + + # Construct the full path to the demo file + self.demo_npz_path = os.path.join(self.demo_dir, file_name) + if not os.path.isfile(self.demo_npz_path): + raise FileNotFoundError(f"Demo file '{self.demo_npz_path}' does not exist.") + + # Load the .npz file + self.data = np.load(self.demo_npz_path, allow_pickle=True) + + def get_num_transitions(self): + """ + Returns the total number of transitions counted in the demo data. + """ + return int(self.data["transition_ctr"]) if "transition_ctr" in self.data else 0 + + def get_num_demos(self): + """ + Returns the total number of demonstrations in the demo data. 
+ """ + return int(self.data["num_demos"]) if "num_demos" in self.data else 0 + + def insert_data_to_buffer(self,data_store: QueuedDataStore): + """ + Load a raw Gymnasium-style .npz of expert episodes into data_store. + The .npz file must contain arrays named 'obs', 'acs', 'rewards', + 'terminateds', 'truncateds', 'info', and optionally 'dones'. + Each episode is processed, and transitions are inserted into the data_store. + Inserted transitions in data store will remain in the data_store as pointers. + + ***Note*** + Need to insert obs and acs in the same way as async_sac_state via jax + + Parameters + ---------- + data_store : QueuedDataStore + + Returns + ------- + None + """ + + obs_buffer = self.data['obs'] # shape (N, T+1, ...) + act_buffer = self.data['acs'] # shape (N, T, ...) + rew_buffer = self.data['rewards'] # shape (N, T) + term_buffer = self.data['terminateds'] # shape (N, T) + trunc_buffer = self.data['truncateds'] # shape (N, T) + info_buffer = self.data['info'] # shape (N, T) + done_buffer = self.data['dones'] # shape (N, T) #.get('dones', term_buffer | trunc_buffer) + + num_demos = self.get_num_demos() + if num_demos == 0: + raise ValueError("No demonstrations found in the provided .npz file.") + + num_transitions = self.get_num_transitions() + if num_transitions == 0: + raise ValueError("No transitions found in the provided .npz file.") + + + # Extract the number of episodes and transitions + for ep in range(num_demos): + ep_obs = obs_buffer[ep] + ep_acts = act_buffer[ep] + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] + + T = len(ep_acts) + for t in range(T): + obs_t = np.asarray(ep_obs[t], dtype=np.float32) + next_obs_t = np.asarray(ep_obs[t+1], dtype=np.float32) + a_t = np.asarray(ep_acts[t], dtype=np.float32) + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + #info_t = ep_info[t] + # masks will be created right 
before insert below + + if self.debug: + np.set_printoptions(precision=3, suppress=True) + + print(f"Demo {ep:2}, Step {t:3} \n " + f"Obs: [{obs_t[0]:.2f} {obs_t[1]:.2f} {obs_t[2]:.2f}] \n " + f"Action: [{a_t[0]:.2f} {a_t[1]:.2f} {a_t[2]:.2f}] \n " + f"Reward: {r_t:.2f} \n " + f"Done: {done_t}") + + # Insert using SERLs data_store/ReplayBuffer insert mechanism directly. + data_store.insert( + dict( + observations =obs_t, + actions =a_t, + next_observations=next_obs_t, + rewards =r_t, + masks =1.0 - done_t, + dones =done_t + ) + ) + + print(f"Loaded a total of {num_transitions} from {num_demos} episodes from '{self.demo_npz_path}' ") + + +# if __name__ == "__main__": +# # Instantiate a DemoHandling object +# handler = DemoHandling(demo_dir='/data/data/serl/demos', +# file_name='data_franka_reach_random_20.npz') + +# # Idenitfy the total number of transitions in the datastore +# print(f'We have {handler.data["transition_ctr"]} transitions in the datastore.') + +# # Simulate SERL's datastore creation w/ capacity 2000 +# ds = QueuedDataStore(2000) + +# # Insert the demo data into the datastore +# handler.insert_data_to_buffer(ds) diff --git a/demos/demos/franka_pick_n_place_drq_demo_script.py b/demos/demos/franka_pick_n_place_drq_demo_script.py new file mode 100644 index 00000000..3e1a8458 --- /dev/null +++ b/demos/demos/franka_pick_n_place_drq_demo_script.py @@ -0,0 +1,511 @@ +#!/usr/bin/env python3 +import os +import time +from datetime import datetime + +import numpy as np + +# Logging +from absl import app, flags, logging +from oxe_envlogger.envlogger import AutoOXEEnvLogger + +# DRL +import gym +import mujoco +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper + +# Needed to create the franka environment +import franka_sim + +# Teleoperation imports +import sys, select, termios, tty + +# RLDS/TFDS +import json, glob, inspect +import tensorflow as tf +import tensorflow_datasets as tfds 
+from tensorflow_datasets import folder_dataset +#------------------------------------------------------------------------------------------- +# Flags +#------------------------------------------------------------------------------------------- +flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_integer("max_traj_length", 200, "Maximum length of trajectory.") +flags.DEFINE_boolean("debug", True, "Debug mode.") # debug mode will disable wandb logging +#flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", + "Directory to save the output data. This is where the RLDS logs will be saved.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") +flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") +flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") + +FLAGS = flags.FLAGS + +#------------------------------------------------------------------------------------------- +## Telop Config Variables +#------------------------------------------------------------------------------------------- +ACTION_MAX = 10 # Maximum action value for clipping actions + +# Bind, xyz, gripper vals to keys +moveBindings = { + 'i':(1,0,0,0), + ',':(-1,0,0,0), + 'j':(0,1,0,0), + 'l':(0,-1,0,0), + 'u':(0,0,0,1), + 'o':(0,0,0,-1), + 'm':(0,0,1,0), + '.':(0,0,-1,0), + 'g':(0,0,0,0.1), # open gripper + 'h':(0,0,0,-0.1), # close gripper + + } + +# Extend bindings to include camera controls +camBindings = { + 'a': ("azimuth", -5), # rotate left + 'd': ("azimuth", 5), # rotate right + 'w': ("elevation", 2), # tilt up + 's': ("elevation", -2), # tilt down + 'q': ("distance", -0.1),# zoom in + 'e': ("distance", 0.1), # zoom out +} + +def activate_weld(env, constraint_name="grasp_weld"): + """ + 
Activate a weld constraint during pick portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to activate + :return: True if the weld was successfully activated, False if the constraint was not found + """ + + try: + # Activate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 1 + print("Activated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def deactivate_weld(env, constraint_name="grasp_weld"): + """ + Deactivate a weld constraint during place portion of a demo + :param env: The environment containing the model + :param constraint_name: The name of the weld constraint to deactivate + :return: True if the weld was successfully deactivated, False if the constraint was not + found + """ + + try: + # Deactivate the weld constraint + env.unwrapped.model.eq(constraint_name).active = 0 + print("Deactivated weld") + return True + + except KeyError: + print(f"Warning: Constraint '{constraint_name}' not found") + return False + +def update_camera(viewer,key): + """ + Update the camera view based on keyboard input. Assumes higher level function has checked for the existance of key in camBindings. + Controls: + 'a' : rotate left + 'd' : rotate right + 'w' : tilt up + 's' : tilt down + 'q' : zoom in + 'e' : zoom out + """ + if hasattr(viewer, 'cam'): + # Get current camera parameters + attr, delta = camBindings[key] + val = getattr(viewer.cam, attr) + + setattr(viewer.cam, attr, val + delta) + +def close_logger_and_env(env): + """ + Best-effort shutdown: + 1) close embedded envloggers/writers if present + 2) close dm_env (if you have a DeepMind-style env) + 3) close Gym/Gymnasium env + 4) close the Mujoco viewer + Also walks through common wrapper attributes (env, _env, unwrapped). 
+ """ + import logging + + visited = set() + + def _safe_close(obj, label=""): + if obj is None or id(obj) in visited: + return + visited.add(id(obj)) + + # 1) Close common logger/writer attributes first (flush TFRecord) + for name in ("_envlogger", "envlogger", "logger", "_logger", "writer"): + try: + logger_obj = getattr(obj, name, None) + if logger_obj is not None and hasattr(logger_obj, "close"): + logger_obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.{name}.close() raised {e!r}") + + # 2) Close dm_env if present + try: + dm = getattr(obj, "dm_env", None) + if dm is not None and hasattr(dm, "close"): + dm.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.dm_env.close() raised {e!r}") + + # 3) Close Gym/Gymnasium env + try: + if hasattr(obj, "close"): + obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.close() raised {e!r}") + + # 4) Close Mujoco viewer if accessible + try: + # Some stacks keep viewer at env._viewer.viewer; others just env._viewer + viewer = None + vwrap = getattr(obj, "_viewer", None) + if vwrap is not None: + viewer = getattr(vwrap, "viewer", vwrap) + if viewer is not None and hasattr(viewer, "close"): + viewer.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label} viewer close raised {e!r}") + + # Recurse into common wrapper links + for child_name in ("env", "_env", "unwrapped", "environment", "base_env"): + child = getattr(obj, child_name, None) + if child is not None and child is not obj: + _safe_close(child, f"{label}.{child_name}" if label else child_name) + + _safe_close(env, "env") + +def finalize_tfds_metadata_beamless(builder_dir: str): + """ + Beam-free finalize: count TFRecord examples per shard and write + numShards/shardLengths into dataset_info.json so TFDS will load. 
+ """ + import os, json, glob + import tensorflow as tf + + info_path = os.path.join(builder_dir, "dataset_info.json") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + + with open(info_path) as f: + info = json.load(f) + + ds_name = info["name"] # e.g. "PandaReachSparseCube-v0" + file_fmt = info.get("fileFormat", "tfrecord") + tmpl_str = info["splits"][0]["filepathTemplate"] + + # Prefer strict pattern "-.-" + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{ds_name}-*.{file_fmt}-*"))) + if not shard_paths: + # Fallback to any tfrecord-like file + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_fmt}*"))) + if not shard_paths: + raise FileNotFoundError( + f"No {file_fmt} shards found in {builder_dir}. " + f"Expected like '{ds_name}-train.{file_fmt}-00000' per template '{tmpl_str}'." + ) + + # Count episodes (1 Example = 1 episode with envlogger/RLDS) + shard_lengths = [sum(1 for _ in tf.data.TFRecordDataset(p)) for p in shard_paths] + + # Write lengths for each split using this template + for s in info["splits"]: + if s.get("filepathTemplate") == tmpl_str: + s["numShards"] = len(shard_paths) + s["shardLengths"] = shard_lengths + # Re-write dataset_info.json with updated shard info + with open(info_path, "w") as f: + json.dump(info, f, indent=2) + + # Sanity log + import tensorflow_datasets as tfds + b = tfds.builder_from_directory(builder_dir) + print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + +def ensure_dir_exists(): + """ + For oxe_envlogger + RLDS compatibility, data must be written in the following format: + /data/data/serl/demos/franka_reach_drq_demo_script/ + └── session_20250821_222412/ + └── PandaReachSparseCube-v0/ + └── 0.1.0/ + dataset_info.json + features.json + PandaReachSparseCube-v0-train.tfrecord-00000-of-00001 + + We can have a base path with customized sessions inside. 
+ Inside each session we have: env-version-files + + + Returns + ------- + out_path : str + The path to the output directory. + """ + # Customize the path + root = FLAGS.output_dir + session = datetime.now().strftime("session_%Y%m%d_%H%M%S") + session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") + #session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") + + # Dataset details + dataset_name = FLAGS.env + version = "0.1.0" # PArt of RLDS format. needed. + + # Create output filename with configuration details + dataset_dir = os.path.join(session_root, dataset_name, version) + os.makedirs(dataset_dir, exist_ok=True) + logging.info(f"TFDS builder dir: {dataset_dir}") + + return dataset_dir + +def getKey(settings): + """ + Waits briefly for a keypress and returns the pressed key. + + Parameters + ---------- + settings : list + Original terminal settings so they can be restored after reading. + + Returns + ------- + key : str + The key pressed by the user, or '' (empty string) if no key was pressed. + """ + # Put terminal into raw mode so keypress is captured instantly + tty.setraw(sys.stdin.fileno()) + + # Wait for human to input action, we will not provide a timeout so it is blocking. + rlist, _, _ = select.select([sys.stdin], [], []) # select.select(rlist, wlist, xlist[, timeout]) + + if rlist: + # Read exactly one character if a key is pressed + key = sys.stdin.read(1) + else: + key = '' + + # Restore terminal to original settings + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + return key + +def get_kb_demo_action(env,speed=0.075): + """ + Reads keyboard input and maps it to a 3D action vector for robot control or camera action. + TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. + Otherwise, none actions are still considered steps in the loop. + + The function uses non-blocking keyboard input to allow interactive + teleoperation. 
Keys are mapped to directions in Cartesian space: + - 'i' : +x (forward) + - ',' : -x (backward) + - 'j' : +y (left) + - 'l' : -y (right) + - 'm' : +z (up) + - '.' : -z (down) + - 'k' : stop (zero vector) + + Parameters + ---------- + speed : float, optional + Step size for each key press (default 0.2). + + Returns + ------- + np.ndarray + Action vector of shape (3,), where each entry corresponds to + [x, y, z] translation command. Example: [0.2, 0.0, 0.0]. + """ + # Save current terminal settings so we can restore later + settings = termios.tcgetattr(sys.stdin) + + # Initialize action as a zero vector (no movement) + action = np.zeros(4, dtype=float) + + try: + # Capture the pressed key + key = getKey(settings) + + # Check keys for camera first, if so, update camera and get another key for action + if key in camBindings: + if hasattr(env.unwrapped, "_viewer"): + update_camera(env.unwrapped._viewer.viewer,key) + key = getKey(settings) # get another key for action + + elif key in moveBindings: + # Lookup (x, y, z) direction and scale by speed + dx, dy, dz, g= moveBindings[key] + action = np.array([dx, dy, dz, g], dtype=float) * speed + + elif key == 'k': + # 'k' means stop β†’ zero vector + action = np.zeros(4, dtype=float) + + elif key == '\x03': # CTRL-C + raise KeyboardInterrupt + + finally: + # Restore terminal even if something goes wrong + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + + # Clip action values to prevent excessive commands + action = np.clip(action, -ACTION_MAX, ACTION_MAX) + + return action + +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
def main(unused_argv):
    """Create the simulated environment and record ``FLAGS.num_demos`` episodes.

    Steps:
      1. Build the gym environment named by ``FLAGS.env`` ('human' rendering
         when ``FLAGS.debug`` is set, otherwise 'rgb_array') and apply the
         observation wrappers that environment id needs.
      2. Point the camera with ``set_front_cam_view`` when the unwrapped env
         exposes a ``_viewer`` attribute.
      3. Optionally wrap the env in ``AutoOXEEnvLogger`` so every step is
         written under ``dataset_dir``.
      4. Run the episodes, taking each action from ``get_kb_demo_action``.
      5. Always release the environment; when logging was enabled, also flush
         the logger and write the dataset metadata.

    Args:
        unused_argv: Positional arguments forwarded by ``absl.app.run``; ignored.
    """
    logging.info('Creating gym environment...')

    # 'human' opens a GUI window; 'rgb_array' renders off-screen.
    _render_mode = 'human' if FLAGS.debug else 'rgb_array'

    # Create the environment with the specified render mode and wrappers
    env = gym.make(FLAGS.env, render_mode=_render_mode)

    if FLAGS.env == "PandaPickCube-v0":
        env = gym.wrappers.FlattenObservation(env)

    if FLAGS.env in ("PandaReachSparseCube-v0", "PandaPickCubeVision-v0"):
        env = SERLObsWrapper(
            env,
            target_hw=(128, 128),
            img_dtype=np.uint8,   # or np.float32
            normalize=False,      # True if using float32 in [0,1]
        )
        env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None)

    logging.info('Done creating %s environment.', FLAGS.env)

    # Set camera to front view if viewer is available
    if hasattr(env.unwrapped, '_viewer'):
        viewer = set_front_cam_view(env)
        if viewer:
            logging.info('Camera view set to front-facing perspective.')
        else:
            logging.warning('Failed to set camera view. Viewer not available.')

    # Wrap with oxe_envlogger to record demos
    dataset_dir = None
    if FLAGS.enable_envlogger:
        dataset_dir = ensure_dir_exists()

        env = AutoOXEEnvLogger(
            env=env,
            dataset_name=FLAGS.env,
            directory=dataset_dir,
        )
    logging.info('Recording %r demos...', FLAGS.num_demos)

    try:
        for i in range(FLAGS.num_demos):

            # Per-episode metadata. The "language_embedding" values are random
            # placeholders: no language model is involved in this script.
            if FLAGS.enable_envlogger:
                env.set_episode_metadata({
                    "language_embedding": np.random.random((5,)).astype(np.float32)
                })
                env.set_step_metadata({"timestamp": time.time()})

            logging.info('episode %r', i)

            # Start a new episode
            env.reset()
            terminated = False
            truncated = False
            step = 0

            # Run until the environment reports termination or truncation.
            while not (terminated or truncated):
                action = get_kb_demo_action(env)

                # NOTE(review): np.float32 cannot resolve sub-second epoch
                # times; kept as-is because the logged dtype may be part of
                # the dataset schema -- confirm before changing.
                if FLAGS.enable_envlogger:
                    env.set_step_metadata({"timestamp": np.float32(time.time())})

                return_step = env.step(action)

                # Support both the 5-tuple (gym >= 0.26) and the 4-tuple step API.
                if len(return_step) == 5:
                    obs, reward, terminated, truncated, info = return_step
                else:
                    obs, reward, terminated, info = return_step
                    truncated = False

                print(f" step: {step}", f"reward: {reward:.3f}\n")
                step += 1

        logging.info('Done recording %r demos.', FLAGS.num_demos)

    finally:
        if FLAGS.enable_envlogger and dataset_dir is not None:
            # Closing the logger flushes pending data to disk; this has to
            # happen before the metadata is generated from the written files.
            logging.info("Closing environment to flush data to disk...")
            close_logger_and_env(env)

            # Scan files and write metadata so the dataset can be loaded.
            finalize_tfds_metadata_beamless(dataset_dir)
        else:
            # Nothing to flush, but the simulator/viewer must still be
            # released; previously the env leaked when logging was disabled.
            env.close()
def activate_weld(env, constraint_name="grasp_weld"):
    """
    Switch on a named equality (weld) constraint for the pick phase of a demo.

    :param env: Environment whose unwrapped ``model`` holds the constraint
    :param constraint_name: Name of the weld constraint to activate
    :return: True when the constraint was found and activated, False when the
             lookup raised ``KeyError`` (constraint not present in the model)
    """
    try:
        weld = env.unwrapped.model.eq(constraint_name)
        weld.active = 1
    except KeyError:
        print(f"Warning: Constraint '{constraint_name}' not found")
        return False

    print("Activated weld")
    return True
def main():
    """
    Run scripted pick-and-place episodes and save the successful ones to NPZ.

    Creates the Franka pick-and-place environment (sparse reward, 50-step
    time limit), repeatedly calls ``pick_and_place_demo`` until the global
    ``actions`` list holds ``attempted_demos`` episodes, then writes the
    global episode lists to a compressed NPZ file.

    The arrays are stored with ``dtype=object`` so that episodes of different
    lengths can be saved together; load the file with ``allow_pickle=True``.
    """
    # Initialize the Franka pick-and-place environment
    env = FrankaPickAndPlaceEnv(reward_type="sparse", render_mode="rgb_array")
    env = TimeLimit(env, max_episode_steps=50)

    # Configuration parameters
    initStateSpace = "random"   # Initial state space configuration (used in the file name)
    attempted_demos = 1         # Number of demonstration episodes to generate
    num_demos = 0               # Counter for successful demonstration episodes

    # Reset environment to initial state - render for the first time.
    obs, _ = env.reset()
    print("Reset!")

    # ``actions`` only grows when an episode succeeds, so failed episodes are
    # retried until enough successful demos have been collected.
    while len(actions) < attempted_demos:
        obs, _ = env.reset()
        print(f"We will run a total of: {attempted_demos} demos!!")
        print("Demo: #", len(actions) + 1)

        res = pick_and_place_demo(env, obs)

        if res:
            num_demos += 1
            print("Episode completed successfully!")
            print(f"Total successful demos: {num_demos}/{attempted_demos}")

    # Output location. Assumes a data folder in the user directory.
    script_dir = '/home/student/data/franka_baselines/demos/pick_n_place'

    # File name encodes robot, initial-state configuration and demo count.
    fileName = "data_" + robot
    fileName += "_" + initStateSpace
    fileName += "_" + str(attempted_demos)
    fileName += ".npz"

    out_path = os.path.join(script_dir, fileName)

    # Saving fails if the target directory is missing, so create it first.
    os.makedirs(script_dir, exist_ok=True)

    # Object arrays keep variable-length episodes intact; handing the raw
    # nested lists to NumPy raises once episode lengths differ.
    np.savez_compressed(
        out_path,
        acs=np.array(actions, dtype=object),
        obs=np.array(observations, dtype=object),
        rewards=np.array(rewards, dtype=object),
        info=np.array(infos, dtype=object),
        terminateds=np.array(terminateds, dtype=object),
        truncateds=np.array(truncateds, dtype=object),
        dones=np.array(dones, dtype=object),
    )

    print(f"Data saved to {fileName}.")
+ + Args: + env: Gymnasium environment instance + lastObs: Last observation containing goal and object state information + - desired_goal: Target position for object placement + - observations: + ee_position[0:3], + ee_velocity[3:6], + fingers_width[6], + object_position[7:10], + object_rotation[10:13], + object_velp[13:16], + object_velr[16:19], + """ + + ## Init goal, current_pos, and object position from last observation + goal = np.zeros(3, dtype=np.float32) + current_pos = np.zeros(3, dtype=np.float32) + object_pos = np.zeros(3, dtype=np.float32) + object_rel_pos = np.zeros(3, dtype=np.float32) + fgr_pos = np.zeros(1, dtype=np.float32) + + # Initialize episode data collection + episodeObs = [] # Observations for this episode + episodeAcs = [] # Actions for this episode + episodeRews = [] # Rewards for this episode + episodeInfo = [] # Info for this episode + episodeTerminated = [] # Terminated flags for this episode + episodeTruncated = [] # Truncated flags for this episode + episodeDones = [] # Done flags (terminated or truncated) for this episode + + # Proportional control gain for action scaling -- empirically tuned + Kp = 8.0 + + # pre_pick_offset + pre_pick_offset = np.array([0,0,0.03], dtype=float) # Offset to approach object safely (3cm) + + # Error thresholds + error_threshold = 0.011 # Threshold for stopping condition (Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. 
+ finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + # Extract desired position from desired_goal dict + goal = lastObs["desired_goal"][0:3] + + # Current robot end-effector position from observation dict + current_pos = lastObs["observation"][0:3] + + # Current object position from observation dict: + object_pos = lastObs["observation"][7:10] + + # Relative position between end-effector and object + object_rel_pos = object_pos - current_pos + + ## Phase 1: Approach Object (Above) + # Create target position 3cm above the object. Use copy() method. + error = object_rel_pos.copy() + error+=pre_pick_offset # Move 3cm above object for safe approach. Fingers should still end up surrounding object. + + timeStep = 0 # Track total timesteps in episode + episodeObs.append(lastObs) + + # Phase 1: Move gripper to position above object + # Terminate when distance to above-object position < 5mm + print(f"----------------------------------------------- Phase 1: Approach Object -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and timeStep <= env._max_episode_steps: + env.render() # Visual feedback + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Proportional control with gain of 6 + # action = Kp * error + action[:3] = error * Kp + + # Open gripper for approach + action[ len(action)-1 ] = 0.05 + + # Unpack new Gymnasium step API + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeAcs.append(action) + episodeInfo.append(info) + episodeRews.append(reward) + episodeObs.append(new_obs) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = 
(object_pos+pre_pick_offset) - current_pos # Error with regard to offset position + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + # Phase 2: Descend Grasp Object + # Move gripper directly to object and close gripper + # Terminate when relative distance to object < 5mm + print(f"----------------------------------------------- Phase 2: Grip -----------------------------------------------") + error = object_pos - current_pos # remove offset + while (np.linalg.norm(error) >= error_threshold or fgr_pos>=0.39) and timeStep <= env._max_episode_steps: # Cube of width 4cm, each finger open to 2cm + env.render() + + # Initialize action vector [x, y, z, gripper] + action = np.array([0., 0., 0., 0.]) + + # Direct proportional control to object position + action[:3] = error * Kp + + # Close gripper to grasp object + action[len(action)-1] = -finger_delta_fast * 2 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + error = object_pos - current_pos #- np.array([0.,0.,0.01]) # Grab lower + + # Print debug information + print( + f"Time Step: {timeStep}, Error: {np.linalg.norm(error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, precision=3)}, " + 
f"obj_pos: {np.array2string(object_pos, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=3)}, " + f"Error: {np.array2string(error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + #sleep(0.5) # Optional: Slow down for better visualization + + # Phase 3: Transport to Goal + # Move grasped object to desired goal position + # Terminate when distance between object and goal < 1cm + print(f"----------------------------------------------- Phase 3: Transport to Goal -----------------------------------------------") + + # Weld activation + if weld_flag: + activate_weld(env, constraint_name="grasp_weld") + + # Set error between goal and hand assuming the object is grasped + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + while np.linalg.norm(gh_error) >= 0.01 and timeStep <= env._max_episode_steps: + env.render() + + action = np.array([0., 0., 0., 0.]) + + # Proportional control toward goal position + action[:3] = gh_error[:3] * Kp + + # Maintain grip on object + #action[len(action)-1] = 0 + + # Execute action and collect data + new_obs, reward, terminated, truncated, info = env.step(action) + done = terminated or truncated + timeStep += 1 + + # Store episode data + episodeObs.append(new_obs) + episodeRews.append(reward) + episodeAcs.append(action) + episodeInfo.append(info) + episodeTerminated.append(terminated) + episodeTruncated.append(truncated) + episodeDones.append(done) + + # Update state information + fgr_pos = new_obs["observation"][6] + current_pos = new_obs["observation"][0:3] + object_pos = new_obs['observation'][7:10] + gh_error = goal - current_pos # Error between goal and hand position + ho_error = object_pos - current_pos # Error between object and hand position + + # Print debug information + print( + f"Time Step: {timeStep}, Error Norm: {np.linalg.norm(gh_error):.4f}, " + f"Eff_pos: {np.array2string(current_pos, 
precision=3)}, " + f"goal_pos: {np.array2string(goal, precision=3)}, " + f"fgr_pos: {np.array2string(fgr_pos, precision=2)}, " + f"Error: {np.array2string(gh_error, precision=3)}, " + f"Action: {np.array2string(action, precision=3)}" + ) + + sleep(0.5) # Optional: Slow down for better visualization + + ## Check for success and store episode data + gh_norm = np.linalg.norm(gh_error) + ho_nomr = np.linalg.norm(ho_error) + if gh_norm < error_threshold and ho_nomr < error_threshold: + + # Store complete episode data in global lists only if we succeeded (avoid bad demos) + actions.append(episodeAcs) + observations.append(episodeObs) + infos.append(episodeInfo) + rewards.append(episodeRews) + + # Optionally, also store the done/terminated/truncated flags globally if needed: + terminateds.append(episodeTerminated) + truncateds.append(episodeTruncated) + dones.append(episodeDones) + + # Deactivate weld constraint after successful pick + if weld_flag: + deactivate_weld(env, constraint_name="grasp_weld") + + # Close mujoco viewer + env.close() + + # Break out of the loop to start a new episode + return True + + # If we reach here, the episode was not successful + if weld_flag: + print("Failed to transport object to goal position. Deactivating weld.") + deactivate_weld(env, constraint_name="grasp_weld") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/demos/franka_reach_demo_script.py b/demos/demos/franka_reach_demo_script.py new file mode 100755 index 00000000..f97848a4 --- /dev/null +++ b/demos/demos/franka_reach_demo_script.py @@ -0,0 +1,453 @@ +""" +Scripted Controller for Franka FR3 Robot - Demonstration Data Generation + +This script implements a scripted controller for generating expert demonstration data +for Deep Reinforcement Learning (DRL) algorithms using the Franka FR3 robot in a +pick-and-place task. The generated data serves as bootstrapping demonstrations for +training RL agents with your desired algorithm. 
+ +Note that different error thresholds lead to different rewards and in turn done values. +Adjust carefully. Depending on the the controller and clipping scaling (ACTION_MAX) settings, you may get very different behaviors. +The current program clips actions and leads to small increments that allows to more precise movements and close the error. + +The controller uses a 4-phase hierarchical approach: +1. Approach Object (move gripper above object) +2. Grasp Object (move to object and close gripper) +3. Transport to Goal (move grasped object to target) +4. Maintain Position (hold final position) + +Output: Compressed NPZ file containing action, observation, and info sequences + +TODO: Convert to a class-based structure for better modularity and reusability. +""" +import os +import numpy as np +import gym +from time import sleep, perf_counter +from datetime import datetime + +import franka_sim +import franka_sim.envs.panda_reach_gym_env as panda_reach_env + +# Global variables to store episode data across all iterations +observations = [] # List storing observation sequences for each episode +actions = [] # List storing action sequences for each episode +rewards = [] # List storing reward sequences for each episode +infos = [] # List storing info dictionaries for each episode +terminateds = [] # List storing terminated flags for each episode +truncateds = [] # List storing truncated flags for each episode +dones = [] # List storing done flags (terminated or truncated) for each episode +transition_ctr = 0 # Global counter for transitions across all episodes + +#------------------------------------------------------------------------------------------- +## Key Config Variables +#------------------------------------------------------------------------------------------- +# Proportional and derivative control gain for action scaling -- empirically tuned +Kp = 10.0 # Values between 20 and 24 seem to be somewhat stable for Kv = 24 +Kv = 10.0 + +ACTION_MAX = 10 # Maximum 
action value for clipping actions +ERROR_THRESHOLD = 0.008 # Note!! When this number is changed, the way rewards are computed in the PandaReachCubeEnv.step() L220 must also be changed such that done=True only at the end of a successfull run. + +# Number of demonstration episodes to generate +NUM_DEMOS = 20 + +# Robot configuration +robot = 'franka' # Robot type used in the environment, can be 'franka' or 'fetch' +task = 'reach' # Task type used in the environment, can be 'reach' or 'pick-and-place' + +# Debug mode for rendering and visualization +DEBUG = False + +if DEBUG: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' +else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + +# Indices for franka_sim reach environment observations +if robot == 'franka' and task == 'reach': + + opi = np.array([0, 3]) # Indices for object position in observation + gpi = np.array([3]) # Indices for gripper position in observation + rpi = np.array([4, 7]) # Indices for robot position in observation + rvi = np.array([7, 10]) # Indices for robot velocity in observation + +# Weld constraint flag +weld_flag = True # Flag to activate weld constraint during pick-and-place + +#------------------------------------------------------------------------------------------- +# Franka sim environments do not have weld constraints like the franka_mujoco environments. 
def set_front_cam_view(env):
    """
    Set the camera view to a front-facing perspective for better visualization.

    The viewer is read from ``env.unwrapped._viewer.viewer``. When either
    attribute is missing or the viewer is ``None`` (for example before a
    window has been created), nothing is changed and ``None`` is returned
    instead of raising ``AttributeError``.

    Args:
        env: The environment instance containing the viewer.

    Returns:
        viewer: The viewer (with updated camera settings when it exposes a
            ``cam`` attribute), or ``None`` when no viewer is available.
    """
    renderer = getattr(env.unwrapped, "_viewer", None)
    viewer = getattr(renderer, "viewer", None)
    if viewer is None:
        return None

    if hasattr(viewer, 'cam'):
        viewer.cam.lookat[:] = [0, 0, 0.1]  # Center of robot (adjust as needed)
        viewer.cam.distance = 3.0           # Camera distance
        viewer.cam.azimuth = 135            # 0 = right, 90 = front, 180 = left
        viewer.cam.elevation = -30          # Negative = above, positive = below

        # Hide menu
        viewer._hide_overlay = True

    return viewer
def compute_error(object_pos, current_pos, prev_error, dt):
    """Compute the position error and its finite-difference time derivative.

    Args:
        object_pos (np.ndarray): The position of the object in the environment.
        current_pos (np.ndarray): The current position of the end-effector.
        prev_error (np.ndarray): The previous error value for derivative
            calculation. It is not modified; the caller keeps its own copy.
        dt (float): Time step for derivative calculation.

    Returns:
        error (np.ndarray): The current error vector, ``object_pos - current_pos``.
        derror (np.ndarray): ``(error - prev_error) / dt``, or zeros when
            ``dt`` is not positive (avoids inf/NaN entering the control law).
    """
    error = object_pos - current_pos

    if dt > 0:
        derror = (error - prev_error) / dt
    else:
        # A zero or negative time step gives no usable rate estimate.
        derror = np.zeros_like(error)

    return error, derror
(Xmm) + + finger_delta_fast = 0.05 # Action delta for fingers 5cm per step (will get clipped by controller)... more of a scalar. + finger_delta_slow = 0.005 # Franka has a range from 0 to 4cm per finger + + ## Extract data + object_pos = lastObs[opi[0]:opi[1]] # block pos + current_pos = lastObs[rpi[0]:rpi[1]] # panda/tcp_pos + + # Relative position between end-effector and object + dt = env.unwrapped.model.opt.timestep # Mujoco time step + prev_error = np.zeros_like(object_pos) + error, derror = compute_error(object_pos, current_pos, prev_error, dt) + + time_step = 0 # Track total time_steps in episode + episodeObs.append(lastObs) # Store initial observation + + # Initialize previous time for dt calculation + prev_time = perf_counter() # Start time for dt calculation + + # Phase 1: Reach + # Terminate when distance to above-object position < error_threshold + print(f"----------------------------------------------- Phase 1: Reach -----------------------------------------------") + while np.linalg.norm(error) >= error_threshold and time_step <= env.spec.max_episode_steps: + env.render() # Visual feedback + + # Record current time and compute dt + curr_time = perf_counter() + dt = curr_time - prev_time + prev_time = curr_time + + # Initialize action vector [x, y, z] + action = np.array([0., 0., 0.]) + + # Proportional control with gain of 6 + action[:3] = error * Kp + derror * Kv + prev_error = error.copy() # Update previous error for next iteration + + # Clip action to prevent excessive movements + action = np.clip(action/ACTION_MAX, -0.1, 0.1) # + + # Keep gripper closed -- no need. only 3 dimensions of control + #action[ len(action)-1 ] = -finger_delta_fast # Maintain gripper closed. 
def main():
    """
    Drive the whole data-generation run for the reach task.

    Builds the flattened ``PandaReachCube-v0`` environment, runs ``demo``
    until the global ``actions`` list holds ``NUM_DEMOS`` episodes, closes the
    environment and writes the global episode lists, the transition counter
    and the demo count to a timestamped, compressed NPZ file.
    """
    # Environment with a flat observation vector.
    env = gym.wrappers.FlattenObservation(
        gym.make("PandaReachCube-v0", render_mode=_render_mode)
    )

    initStateSpace = "random"   # Initial state space configuration (file-name tag)
    target_demos = NUM_DEMOS    # How many episodes to collect
    demo_ctr = 0                # Episodes for which demo() reported success

    # First reset, then adjust the camera for visualization.
    env.reset()
    set_front_cam_view(env)

    print("Reset!")

    # Keep generating episodes until enough have been stored globally.
    while len(actions) < target_demos:
        start_obs, _ = env.reset()

        print(f"We will run a total of: {target_demos} demos!!")
        print("Demo: #", len(actions)+1)

        if demo(env, start_obs):
            demo_ctr += 1
            print("Episode completed successfully!")
            print(f"Total successful demos: {demo_ctr}/{target_demos}")

    # All episodes are done: release the environment.
    env.close()

    # Output folder. Assumes mounted /data folder and internal data folder.
    script_dir = '/data/data/serl/demos'

    # Name carries robot, task, configuration, demo count and a timestamp.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fileName = f"data_{robot}_{task}_{initStateSpace}_{target_demos}_{timestamp}.npz"
    out_path = os.path.join(script_dir, fileName)

    # Ensure the directory exists
    os.makedirs(script_dir, exist_ok=True)

    # Object arrays allow episodes of different lengths to be stored together.
    payload = {
        "acs": np.array(actions, dtype=object),
        "obs": np.array(observations, dtype=object),
        "rewards": np.array(rewards, dtype=object),
        "info": np.array(infos, dtype=object),
        "terminateds": np.array(terminateds, dtype=object),
        "truncateds": np.array(truncateds, dtype=object),
        "dones": np.array(dones, dtype=object),
        "transition_ctr": transition_ctr,
        "num_demos": target_demos,
    }
    np.savez_compressed(out_path, **payload)

    print(f"Data saved to {fileName}.")
    print(f"Total successful demos: {demo_ctr}/{target_demos}")
"Path to preload RLDS data.") +flags.DEFINE_string("output_dir", "/data/data/serl/demos/franka_reach_drq_demo_script", + "Directory to save the output data. This is where the RLDS logs will be saved.") +flags.DEFINE_integer("num_demos", 2, "Number of episodes to log.") +flags.DEFINE_boolean("enable_envlogger", True, "Enable envlogger.") +flags.DEFINE_string("teleop_mode", "keyboard", "Teleoperation mode: 'keyboard' or 'spacemouse'.") + +FLAGS = flags.FLAGS + +#------------------------------------------------------------------------------------------- +## Telop Config Variables +#------------------------------------------------------------------------------------------- +ACTION_MAX = 10 # Maximum action value for clipping actions + +# Bind, xyz, gripper vals to keys +moveBindings = { + 'i':(1,0,0,0), + ',':(-1,0,0,0), + 'j':(0,1,0,0), + 'l':(0,-1,0,0), + 'u':(0,0,0,1), + 'o':(0,0,0,-1), + 'm':(0,0,1,0), + '.':(0,0,-1,0), + } + +# Extend bindings to include camera controls +camBindings = { + 'a': ("azimuth", -5), # rotate left + 'd': ("azimuth", 5), # rotate right + 'w': ("elevation", 2), # tilt up + 's': ("elevation", -2), # tilt down + 'q': ("distance", -0.1),# zoom in + 'e': ("distance", 0.1), # zoom out +} + +def update_camera(viewer,key): + """ + Update the camera view based on keyboard input. Assumes higher level function has checked for the existance of key in camBindings. + Controls: + 'a' : rotate left + 'd' : rotate right + 'w' : tilt up + 's' : tilt down + 'q' : zoom in + 'e' : zoom out + """ + if hasattr(viewer, 'cam'): + # Get current camera parameters + attr, delta = camBindings[key] + val = getattr(viewer.cam, attr) + + setattr(viewer.cam, attr, val + delta) + +def close_logger_and_env(env): + """ + Best-effort shutdown: + 1) close embedded envloggers/writers if present + 2) close dm_env (if you have a DeepMind-style env) + 3) close Gym/Gymnasium env + 4) close the Mujoco viewer + Also walks through common wrapper attributes (env, _env, unwrapped). 
+ """ + import logging + + visited = set() + + def _safe_close(obj, label=""): + if obj is None or id(obj) in visited: + return + visited.add(id(obj)) + + # 1) Close common logger/writer attributes first (flush TFRecord) + for name in ("_envlogger", "envlogger", "logger", "_logger", "writer"): + try: + logger_obj = getattr(obj, name, None) + if logger_obj is not None and hasattr(logger_obj, "close"): + logger_obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.{name}.close() raised {e!r}") + + # 2) Close dm_env if present + try: + dm = getattr(obj, "dm_env", None) + if dm is not None and hasattr(dm, "close"): + dm.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.dm_env.close() raised {e!r}") + + # 3) Close Gym/Gymnasium env + try: + if hasattr(obj, "close"): + obj.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label}.close() raised {e!r}") + + # 4) Close Mujoco viewer if accessible + try: + # Some stacks keep viewer at env._viewer.viewer; others just env._viewer + viewer = None + vwrap = getattr(obj, "_viewer", None) + if vwrap is not None: + viewer = getattr(vwrap, "viewer", vwrap) + if viewer is not None and hasattr(viewer, "close"): + viewer.close() + except Exception as e: + logging.warning(f"close_logger_and_env: {label} viewer close raised {e!r}") + + # Recurse into common wrapper links + for child_name in ("env", "_env", "unwrapped", "environment", "base_env"): + child = getattr(obj, child_name, None) + if child is not None and child is not obj: + _safe_close(child, f"{label}.{child_name}" if label else child_name) + + _safe_close(env, "env") + +def finalize_tfds_metadata_beamless(builder_dir: str): + """ + Beam-free finalize: count TFRecord examples per shard and write + numShards/shardLengths into dataset_info.json so TFDS will load. 
+ """ + import os, json, glob + import tensorflow as tf + + info_path = os.path.join(builder_dir, "dataset_info.json") + if not os.path.exists(info_path): + raise FileNotFoundError(f"Missing dataset_info.json in {builder_dir}") + + with open(info_path) as f: + info = json.load(f) + + ds_name = info["name"] # e.g. "PandaReachSparseCube-v0" + file_fmt = info.get("fileFormat", "tfrecord") + tmpl_str = info["splits"][0]["filepathTemplate"] + + # Prefer strict pattern "-.-" + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"{ds_name}-*.{file_fmt}-*"))) + if not shard_paths: + # Fallback to any tfrecord-like file + shard_paths = sorted(glob.glob(os.path.join(builder_dir, f"*.{file_fmt}*"))) + if not shard_paths: + raise FileNotFoundError( + f"No {file_fmt} shards found in {builder_dir}. " + f"Expected like '{ds_name}-train.{file_fmt}-00000' per template '{tmpl_str}'." + ) + + # Count episodes (1 Example = 1 episode with envlogger/RLDS) + shard_lengths = [sum(1 for _ in tf.data.TFRecordDataset(p)) for p in shard_paths] + + # Write lengths for each split using this template + for s in info["splits"]: + if s.get("filepathTemplate") == tmpl_str: + s["numShards"] = len(shard_paths) + s["shardLengths"] = shard_lengths + # Re-write dataset_info.json with updated shard info + with open(info_path, "w") as f: + json.dump(info, f, indent=2) + + # Sanity log + import tensorflow_datasets as tfds + b = tfds.builder_from_directory(builder_dir) + print("[finalize] splits:", {k: v.num_examples for k, v in b.info.splits.items()}) + +def ensure_dir_exists(): + """ + For oxe_envlogger + RLDS compatibility, data must be written in the following format: + /data/data/serl/demos/franka_reach_drq_demo_script/ + └── session_20250821_222412/ + └── PandaReachSparseCube-v0/ + └── 0.1.0/ + dataset_info.json + features.json + PandaReachSparseCube-v0-train.tfrecord-00000-of-00001 + + We can have a base path with customized sessions inside. 
+ Inside each session we have: env-version-files + + + Returns + ------- + out_path : str + The path to the output directory. + """ + # Customize the path + root = FLAGS.output_dir + session = datetime.now().strftime("session_%Y%m%d_%H%M%S") + session_root = os.path.join(root, f"{FLAGS.num_demos}_demos_{session}") + #session_root = os.path.join(root, f"{session}_num_demos_{FLAGS.num_demos}") + + # Dataset details + dataset_name = FLAGS.env + version = "0.1.0" # PArt of RLDS format. needed. + + # Create output filename with configuration details + dataset_dir = os.path.join(session_root, dataset_name, version) + os.makedirs(dataset_dir, exist_ok=True) + logging.info(f"TFDS builder dir: {dataset_dir}") + + return dataset_dir + +def getKey(settings): + """ + Waits briefly for a keypress and returns the pressed key. + + Parameters + ---------- + settings : list + Original terminal settings so they can be restored after reading. + + Returns + ------- + key : str + The key pressed by the user, or '' (empty string) if no key was pressed. + """ + # Put terminal into raw mode so keypress is captured instantly + tty.setraw(sys.stdin.fileno()) + + # Wait for human to input action, we will not provide a timeout so it is blocking. + rlist, _, _ = select.select([sys.stdin], [], []) # select.select(rlist, wlist, xlist[, timeout]) + + if rlist: + # Read exactly one character if a key is pressed + key = sys.stdin.read(1) + else: + key = '' + + # Restore terminal to original settings + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + return key + +def get_kb_demo_action(env,speed=0.075): + """ + Reads keyboard input and maps it to a 3D action vector for robot control or camera action. + TODO: currently can only read one key at a time. Needs to be extended to read multiple keys to handle both. + Otherwise, none actions are still considered steps in the loop. + + The function uses non-blocking keyboard input to allow interactive + teleoperation. 
Keys are mapped to directions in Cartesian space: + - 'i' : +x (forward) + - ',' : -x (backward) + - 'j' : +y (left) + - 'l' : -y (right) + - 'm' : +z (up) + - '.' : -z (down) + - 'k' : stop (zero vector) + + Parameters + ---------- + speed : float, optional + Step size for each key press (default 0.2). + + Returns + ------- + np.ndarray + Action vector of shape (3,), where each entry corresponds to + [x, y, z] translation command. Example: [0.2, 0.0, 0.0]. + """ + # Save current terminal settings so we can restore later + settings = termios.tcgetattr(sys.stdin) + + # Initialize action as a zero vector (no movement) + action = np.zeros(4, dtype=float) + + try: + # Capture the pressed key + key = getKey(settings) + + # Check keys for camera first, if so, update camera and get another key for action + if key in camBindings: + if hasattr(env.unwrapped, "_viewer"): + update_camera(env.unwrapped._viewer.viewer,key) + key = getKey(settings) # get another key for action + + elif key in moveBindings: + # Lookup (x, y, z) direction and scale by speed + dx, dy, dz, g= moveBindings[key] + action = np.array([dx, dy, dz, g], dtype=float) * speed + + elif key == 'k': + # 'k' means stop β†’ zero vector + action = np.zeros(4, dtype=float) + + elif key == '\x03': # CTRL-C + raise KeyboardInterrupt + + finally: + # Restore terminal even if something goes wrong + termios.tcsetattr(sys.stdin, termios.TCSADRAIN, settings) + + # Clip action values to prevent excessive commands + action = np.clip(action, -ACTION_MAX, ACTION_MAX) + + return action + +def set_front_cam_view(env): + """ + Set the camera view to a front-facing perspective for better visualization. + + Args: + env: The environment instance containing the viewer. + + Returns: + viewer: The viewer with updated camera settings. 
+ """ + viewer = env.unwrapped._viewer.viewer # Access the viewer from the environment + + if hasattr(viewer, 'cam'): + viewer.cam.lookat[:] = [0, 0, 0.1] # Center of robot (adjust as needed) + viewer.cam.distance = 2.0 # Camera distance + viewer.cam.azimuth = 155 # 0 = right, 90 = front, 180 = left + viewer.cam.elevation = -30 # Negative = above, positive = below + + # Hide menu + viewer._hide_overlay = True + + return viewer + +############################################################################## +def main(unused_argv): + logging.info(f'Creating gym environment...') + + # Render mode configuration based on debug flag + if FLAGS.debug: + _render_mode = 'human' # Render mode for the environment, can be 'human' or 'rgb_array' + else: + _render_mode = 'rgb_array' # Use 'rgb_array' for automated testing without GUI + + # Create the environment with the specified render mode and wrappers + env = gym.make(FLAGS.env, render_mode=_render_mode) + + if FLAGS.env == "PandaPickCube-v0": + env = gym.wrappers.FlattenObservation(env) + + if FLAGS.env == "PandaReachSparseCube-v0": + env = SERLObsWrapper( + env, + target_hw=(128, 128), + img_dtype=np.uint8, # or np.float32 + normalize=False, # True if using float32 in [0,1] + ) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + + logging.info(f'Done creating {FLAGS.env} environment.') + + # Set camera to front view if viewer is available + if hasattr(env.unwrapped, '_viewer'): + viewer = set_front_cam_view(env) + if viewer: + logging.info('Camera view set to front-facing perspective.') + else: + logging.warning('Failed to set camera view. Viewer not available.') + + # Wrap with oxe_envlogger to record demos + dataset_dir = None + session_root = None + if FLAGS.enable_envlogger: + + dataset_dir = ensure_dir_exists() + + # Will save as many episodes as possible into files of 200MB each by default. 
+ env = AutoOXEEnvLogger( + env=env, + dataset_name=FLAGS.env, + directory=dataset_dir, + #split_name="train", # "train", "test", or "validation" + ) + logging.info('Recording %r demos...', FLAGS.num_demos) + + #--- LOOP DEMOS/EPISODES --- + # Loop through the number of demos specified by the user to record demonstrations + try: + for i in range(FLAGS.num_demos): + + # Log custom metadata during new episode: language embeddings randomly. + if FLAGS.enable_envlogger: + # The "language_embedding" is a standard field used in robotics datasets (like the OXE format that envlogger creates) to store a numerical representation of a natural language instruction for an episode. + # How to Reconcile it with 5 Random Numbers: The five random numbers are just placeholder data. This script is a demonstration and doesn't involve a real language model. + env.set_episode_metadata({ + "language_embedding": np.random.random((5,)).astype(np.float32) + }) + env.set_step_metadata({"timestamp": time.time()}) + + logging.info('episode %r', i) + + # Start a new episode + env.reset() + terminated = False + truncated = False + + step = 0 + + # Termination occurs when the hand reaches the target or the maximum trajectory length is reached + while not (terminated or truncated): + + # Get action from the demo function + action = get_kb_demo_action(env) + + # example to log custom step metadata + if FLAGS.enable_envlogger: + env.set_step_metadata({"timestamp": np.float32(time.time())}) + + return_step = env.step(action) + + # NOTE: to handle gym.Env.step() return value change in gym 0.26 + if len(return_step) == 5: + obs, reward, terminated, truncated, info = return_step + else: + obs, reward, terminated, info = return_step + truncated = False + + print(f" step: {step}", f"reward: {reward:.3f}\n") + step += 1 + + logging.info('Done recording %r demos.', FLAGS.num_demos) + + finally: + # Finalize TFDS metadata so SERL/TFDS can load the split + if FLAGS.enable_envlogger and dataset_dir is not 
None: + + # Close the environment to flush/write data. AutoOXEEnvLogger implements dm_env.close() + # which flushes data to disk. This is important to ensure all data is written before finalizing metadata. + # If you skip this step, some data may not be written and the dataset may be incomplete. + logging.info("Closing environment to flush data to disk...") + + # closes logger(s) + dm_env + gym + viewer + # env.unwrapped.unwrapped.unwrapped.close() # close mujoco viewer + # env.env.env.env.close() + close_logger_and_env(env) + + # Scan files and write metadata + finalize_tfds_metadata_beamless(dataset_dir) + + # Note: async_drq_sim will read from the dataset_dir you printed above. + +if __name__ == '__main__': + app.run(main) \ No newline at end of file diff --git a/demos/demos/load_demo_test.py b/demos/demos/load_demo_test.py new file mode 100644 index 00000000..8fd33aad --- /dev/null +++ b/demos/demos/load_demo_test.py @@ -0,0 +1,127 @@ +# Updated and advanced train.py that includes logging, vectorized environments, and periodic recorded evaluations +import os +from pathlib import Path +import numpy as np + +def load_demos_to_her_buffer_gymnasium(data_store, demo_npz_path: str, combine_done: bool = True): + """ + Load a raw Gymnasium-style .npz of expert episodes into model.replay_buffer. + + demo_npz_path must contain at least these arrays: + - 'episodeObs' : shape (T+1, *obs_shape*), list of observations + - 'episodeAcs' : shape (T, *act_shape*), list of actions + - 'episodeRews' : shape (T,), list of rewards + - 'episodeTerminated' : shape (T,), list of terminated flags + - 'episodeTruncated' : shape (T,), list of truncated flags + - 'episodeInfo' : shape (T,), list of info dicts + + Parameters + ---------- + data_store : DataStore + The data store to which the demo transitions will be added. + demo_npz_path : str + Path to the .npz file you saved from your demo collector. + combine_done : bool, default=True + If True, `done = terminated or truncated`. 
If False, `done = terminated` only. + """ + + # Load all demo data. Structure: var_name[num_demo][time_step][key if dict] = value + data = np.load(demo_npz_path, allow_pickle=True) + + obs_buffer = data['obs'] # length T+1 + act_buffer = data['acs'] # length T + rew_buffer = data['rewards'] # length T + term_buffer = data['terminateds'] # length T + trunc_buffer = data['truncateds'] # length T + info_buffer = data['info'] # length T + done_buffer = data['dones'] # length T, if available + + # Extract number of demonstrations + num_demos = obs_buffer.shape[0] + + # Extract rollout data for a single episode + for ep in range(num_demos): + ep_obs = obs_buffer[ep] # this is a length‐(T+1) array of dicts + ep_acts = act_buffer[ep] # length‐T array of actions + ep_rews = rew_buffer[ep] + ep_terms = term_buffer[ep] + ep_trunc = trunc_buffer[ep] + ep_done = done_buffer[ep] + ep_info = info_buffer[ep] # length‐T array of dicts + + # Length of episode: + T = len(ep_acts) + + # Extract single transitions from the episode data + for t in range(T): + # raw single‐step data: + obs_t = ep_obs[t] # dict[str, np.ndarray] (obs_dim,) + next_obs_t = ep_obs[t+1] + a_t = ep_acts[t] # np.ndarray (action_dim,) + r_t = float(ep_rews[t]) + done_t = bool(ep_done[t] or ep_terms[t] or ep_trunc[t]) + + # Rehydrate info dict and inject the timeout flag + raw_info = ep_info[t] # dict[str,Any] + if isinstance(raw_info, str): + import ast + info_t = ast.literal_eval(raw_info) + else: + info_t = raw_info.copy() + # Append truncated information to info_t + info_t["TimeLimit.truncated"] = bool(ep_trunc[t]) + + # Enter transition into data_store or QueuedDataStore + data_store.insert( + dict( + observations=obs_t, + actions=a_t, + next_observations=next_obs_t, + rewards=r_t, + masks=1.0 - done_t, + dones=done_t + ) + ) + + print(f"Can load {num_demos} transitions successfullly from {demo_npz_path}." 
+ f"(combine_done={combine_done}).") + +def get_demo_path(relative_path: str) -> str: + """ + Given a path relative to this script file, return + the absolute, normalized path as a string. + + Example: + # If your demos live at ../../../demos/data.npz + demo_file = get_demo_path("../../../demos/data_franka_random_10.npz") + """ + # 1) Resolve this script’s directory + script_dir = Path(__file__).resolve().parent + + # 2) Join with the user-supplied relative path and normalize + full_path = (script_dir / relative_path).resolve() + + return str(full_path) + + +def main(): + """ + Load demos + """ + + # Get abs path to demo file + script_dir = '/data/data/serl/demos' + default_file = 'data_franka_reach_random_20.npz' + + prompt = f"Please input the name of the file to load [{default_file}]: " + + file_name = input(prompt) or default_file + demo_file = os.path.join(script_dir, file_name) + + # Load the demo file into data_store as in async_sac_state.py + from agentlace.data.data_store import QueuedDataStore + data_store = QueuedDataStore(2000) + load_demos_to_her_buffer_gymnasium(data_store, demo_file, combine_done=True) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/demos/oxe_envlogger b/demos/oxe_envlogger new file mode 160000 index 00000000..de3c48bc --- /dev/null +++ b/demos/oxe_envlogger @@ -0,0 +1 @@ +Subproject commit de3c48bcf094ebba350ce0ba183efca4478a501a diff --git a/demos/requirements.txt b/demos/requirements.txt new file mode 100644 index 00000000..febfe029 --- /dev/null +++ b/demos/requirements.txt @@ -0,0 +1,4 @@ +tensorflow-metadata==1.17.2 +apache-beam==2.67.0 +protobuf>=4.21.6,<4.22 +git+https://github.com/rail-berkeley/oxe_envlogger.git@main?? 
diff --git a/demos/setup.py b/demos/setup.py new file mode 100644 index 00000000..d6259a7e --- /dev/null +++ b/demos/setup.py @@ -0,0 +1,7 @@ +from setuptools import setup, find_packages + +setup( + name="demos", + version="0.1", + packages=find_packages(), +) diff --git a/docs/images/cable-routing-realrobot.png b/docs/images/cable-routing-realrobot.png new file mode 100644 index 00000000..cdb8f389 Binary files /dev/null and b/docs/images/cable-routing-realrobot.png differ diff --git a/docs/images/cable-success-rate-with-robot.png b/docs/images/cable-success-rate-with-robot.png new file mode 100644 index 00000000..b7780ca2 Binary files /dev/null and b/docs/images/cable-success-rate-with-robot.png differ diff --git a/docs/images/fetch_rewards_function_grid_size.png b/docs/images/fetch_rewards_function_grid_size.png new file mode 100644 index 00000000..4a37f7bd Binary files /dev/null and b/docs/images/fetch_rewards_function_grid_size.png differ diff --git a/docs/images/fractal_grid.png b/docs/images/fractal_grid.png new file mode 100644 index 00000000..a8a75da6 Binary files /dev/null and b/docs/images/fractal_grid.png differ diff --git a/docs/images/peg-insert-realrobot.png b/docs/images/peg-insert-realrobot.png new file mode 100644 index 00000000..b1c54400 Binary files /dev/null and b/docs/images/peg-insert-realrobot.png differ diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 00000000..079c17d6 --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,105 @@ +# Installation Guide + +This document provides step-by-step instructions to set up FractalSERL on your system. The installation includes setting up the Conda environment, installing JAX with GPU support, and installing the core packages. 
+ +## Prerequisites + +### Hardware Requirements + +For optimal performance and to reproduce the experiments in our paper, we recommend the following system configuration: + +- **Processor:** AMD Ryzen Threadripper 1950x or equivalent (16+ cores recommended) +- **RAM:** 128 GB +- **GPU:** NVIDIA RTX 4070 (12 GB VRAM) or better +> **Note:** Baseline experiments were conducted with an RTX 4090 (82.6 FP32 TFLOPS, 1008 GB/s bandwidth). The RTX 4070 has approximately 2.84× lower FP32 compute (29.1 TFLOPS) and 2× lower bandwidth (504 GB/s), making our results on RTX 4070 all the more significant. + +### Software Requirements + +- **Operating System:** Ubuntu 20.04 LTS +- Python 3.10 +- Conda (Miniconda or Anaconda) +- CUDA Toolkit 12.0+ (for GPU support) + +> ⚠️ **Note on End-of-Life Support:** Ubuntu 20.04 LTS reaches end-of-standard support in April 2025. If using Ubuntu 20.04 for real-robot applications requiring ROS1, note that ROS1 is also in end-of-life status. For new deployments, consider upgrading to Ubuntu 22.04 LTS with ROS2. + +## Installation Steps + +### 1. Setup Conda Environment + +Create a new Conda environment with Python 3.10: + +```bash +conda create -n serl python=3.10 +conda activate serl +``` + +### 2. Install JAX + +Choose the installation method based on your hardware: + +#### For GPU (Recommended for RTX 4070 / 4090): + +```bash +pip install --upgrade "jax[cuda12]==0.6.2" +``` + +#### For CPU (Not Recommended): + +```bash +pip install --upgrade "jax[cpu]" +``` + +#### For TPU: + +```bash +pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +``` + +For more details on JAX installation, see the [JAX GitHub page](https://github.com/google/jax). + +### 3. Install serl_launcher + +Navigate to the `serl_launcher` directory and install: + +```bash +cd serl_launcher +pip install -e . +pip install -r requirements.txt +``` + +### 4.
Install franka_sim + +Navigate to the `franka_sim` directory and install: + +```bash +cd ../franka_sim +pip install -e . +pip install -r requirements.txt +``` + +### 5. Install serl_robot_infra + +Navigate to the `serl_robot_infra` directory and install: + +```bash +cd ../serl_robot_infra +pip install -e . +``` + +### 6. Install demos + +Navigate to the `demos` directory and install: + +```bash +cd ../demos +pip install -e . +``` + +Navigation +---------- +- [Home](../README.md) +- [Overview](overview.md) +- [Run in simulation](run_sim.md) +- [Run on the real robot](run_realrobot.md) +- [Training options](sim_training.md) +- [Collecting demonstrations](sim_demonstrations.md) diff --git a/docs/overview.md b/docs/overview.md new file mode 100644 index 00000000..e938ce87 --- /dev/null +++ b/docs/overview.md @@ -0,0 +1,41 @@ +# Overview — FractalSERL + +This document summarizes the code structure, system design, folder layout, and next steps for FractalSERL — our contribution and extension to the original SERL project. + +Module summary +------------------------ + +| Module | Purpose | Key files / subfolders | +|---|---|---| +| `demos/` | Demo scripts and logging | `demos/`, `oxe_envlogger/` | +| `examples/` | Training pipelines and launch scripts | `async_*/`, `bc_policy.py`, `run_*.sh` | +| `franka_sim/` | Simulation environments and controllers | `envs/`, `controllers/`, `mujoco_gym_env.py` | +| `serl_launcher/` | Core RL algorithms, agents, and utilities | `agents/`, `networks/`, `data/`, `wrappers/`, `utils/` | +| `serl_robot_infra/` | Real-robot interfaces and servers | `franka_env/`, `robot_servers/`, `camera/`, `spacemouse/` | + +Each top-level package includes a `setup.py` so it can be installed in editable mode (`pip install -e .`) during development. A `requirements.txt` file lists runtime dependencies used by demo and example scripts.
+ + +System design +----------------------- + +![Runtime architecture: actor/learner architecture](images/software_design.png) + +The actor node collects data from Gym-compatible environments (sim or real) and pushes transitions to a datastore/replay buffer; the learner node consumes that data to update policies and periodically pushes updated weights back to the actor node. Communication is asynchronous, using `agentlace` in experiments, so collection and learning scale independently. + +Key points: +- **Actor node:** environment stepping, action sampling, transition senders. +- **Learner node:** gradient updates, replay-buffer consumer, policy synchronization. +- **Environment wrappers:** `serl_launcher/wrappers/` provides a consistent Gym API across sim and real. +- **Hardware servers:** ROS code and hardware commands are located in `serl_robot_infra/robot_servers/`. + +Navigation +----------------------- +- [Home](../README.md) +- [Installation guide](installation.md) +- [Run in simulation](run_sim.md) +- [Run on the real robot](run_realrobot.md) +- [Training options](sim_training.md) +- [Collecting demonstrations](sim_demonstrations.md) + + diff --git a/docs/real_franka.md b/docs/real_franka.md old mode 100644 new mode 100755 index 9642b607..89662707 --- a/docs/real_franka.md +++ b/docs/real_franka.md @@ -1,13 +1,12 @@ # Run with Franka Arm on Real Robot -We demonstrate how to use SERL with real robot manipulators with 4 different tasks. Namely: Peg Insertion, PCB Component Insertion, Cable Routing, and Object Relocation. We provide detailed instruction on how to reproduce the Peg Insertion task as a setup test for the entire SERL package. When running with a real robot, a separate gym env is needed. For our examples, we isolated the gym env as a client to a robot server. The robot server is a Flask server that sends commands to the robot via ROS. The gym env communicates with the robot server via post requests.
![](./images/robot_infra_interfaces.png) -### Installation for `serl_robot_infra` +### Installation for `serl_franka_controllers` Follow the [README](../serl_robot_infra/README.md) in `serl_robot_infra` for installation and basic robot operation instructions. This contains the instruction for installing the impendence-based [serl_franka_controllers](https://github.com/rail-berkeley/serl_franka_controllers). @@ -17,7 +16,6 @@ After the installation, you should be able to run the robot server, interact wit ## 1. Peg Insertion πŸ“ -![](./images/peg.png) > Example is located in [examples/async_peg_insert_drq/](../examples/async_peg_insert_drq/) @@ -74,39 +72,8 @@ env = RecordEpisodeStatistics(env) # record episode statistics ``` -### 2. PCB Component Insertion πŸ–₯️ +### 2. Cable Routing πŸ”Œ -![](./images/pcb.png) - -> Example is located in [examples/async_pcb_insert_drq/](../examples/async_pcb_insert_drq/) - -> Env and default config are located in `serl_robot_infra/franka_env/envs/pcb_env/` - -Similar to peg insertion task, we define the reward in this task is given by checking whether the end-effector pose matches a fixed target pose. Update the `TARGET_POSE` in [peg_env/config.py](../serl_robot_infra/franka_env/envs/peg_env/config.py) with the measured end-effector pose. - -Here we record demo trajectories with the robot, then run the learner and actor nodes. -```bash -# record demo trajectories -python record_demo.py - -# run learner and actor nodes -bash run_learner.sh -bash run_actor.sh -``` - -A baseline of using BC as policy is also provided. To train BC, simply run the following command: -```bash -python3 examples/bc_policy.py ....TODO_ADD_ARGS..... -``` - -To run the BC policy, simply run the following command: -```bash -bash run_bc.sh -``` - -### 3. 
Cable Routing πŸ”Œ - -![](./images/cable.png) > Example is located in [examples/async_cable_routing_drq/](../examples/async_cable_routing_drq/) @@ -128,11 +95,9 @@ The reward classifier is used as a gym wrapper `franka_env.envs.wrapper.BinaryRe The reward classifier is then used in the BC policy and DRQ policy for the actor node, the path is provided as `--reward_classifier_ckpt_path` argument in `run_bc.sh` and `run_actor.sh` -### 4. Object Relocation πŸ—‘οΈ +### 3. Object Relocation πŸ—‘οΈ -![](./images/forward.png) -![](./images/backward.png) > Example is located in [examples/async_bin_relocation_fwbw_drq/](../examples/async_bin_relocation_fwbw_drq/) @@ -170,3 +135,14 @@ bash run_actor.sh bash run_fw_learner.sh bash run_bw_learner.sh ``` + +Navigation +---------- +- [Home](../README.md) +- [Overview](overview.md) +- [Installation guide](installation.md) +- [Run in simulation](run_sim.md) +- [Run on the real robot](run_realrobot.md) +- [Quick start](sim_quick_start.md) +- [Training options](sim_training.md) +- [Collecting demonstrations](sim_demonstrations.md) diff --git a/docs/run_realrobot.md b/docs/run_realrobot.md new file mode 100644 index 00000000..6b20e396 --- /dev/null +++ b/docs/run_realrobot.md @@ -0,0 +1,114 @@ +# Run with Franka Arm on Real Robot + +This guide covers the real-robot setup for FractalSERL. The robot stack is split into a Flask server that sends commands to the robot via ROS and a Gym environment client that communicates with that server over HTTP. + +![](./images/robot_infra_interfaces.png) + +## Installation for `serl_franka_controllers` + +Follow the [SERL Robot Infra README](../serl_robot_infra/README.md) for installation and basic robot operation instructions. That page includes the impedance-based [serl_franka_controllers](https://github.com/rail-berkeley/serl_franka_controllers) setup. + +After installation, you should be able to start the robot server and interact with the hardware Gym environment. 
+ +> NOTE: The example code below is a template. It assumes you have your own robot setup, camera calibration, data, and checkpoints. + +## 1. Peg Insertion + + +![peg-insert](../docs/images/peg-insert-realrobot.png) + +**Example:** [examples/async_peg_insert_drq/](../examples/async_peg_insert_drq/) + +**Environment and Default Configuration:** [serl_robot_infra/franka_env/envs/peg_env/](../serl_robot_infra/franka_env/envs/peg_env/) + +**Wrapper:** `franka_env.envs.wrappers.SpacemouseIntervention` + +Peg insertion is recommended as the first task for validating the real-robot stack. It is the simplest setup for checking the server, cameras, reward target, and training loop. + +### Procedure +1. Prepare the peg, board, and workspace. Fix the board in place and mount the peg in the gripper. +2. Mount the wrist cameras and update the camera serial numbers in `peg_env/config.py`. +3. Adjust the wrist-camera mass in Desk so the controller matches the payload. +4. Unlock the robot, enable FCI, and start the server: + + ```bash + python serl_robot_infra/robot_servers/franka_server.py \ + --gripper_type=<Robotiq|Franka|None> \ + --robot_ip=<robot_IP> \ + --gripper_ip=<[Optional] Robotiq_gripper_IP> + ``` + +5. Use the pose and gripper endpoints to measure the target pose, then update `TARGET_POSE` in `peg_env/config.py`. +6. Set `RANDOM_RESET=False` while debugging the base task. +7. Record demonstrations with the spacemouse: + + ```bash + cd examples/async_peg_insert_drq + python record_demo.py + ``` + +8. Update `demo_path` and `checkpoint_path` in the actor and learner scripts, then launch training. +9. Evaluate checkpoints with `--eval_checkpoint_step` and `--eval_n_trajs` in `run_actor.sh`. + +### Wrapper stack + +```python +env = gym.make('FrankaPegInsert-Vision-v0') +env = GripperCloseEnv(env) +env = SpacemouseIntervention(env) +env = RelativeFrame(env) +env = Quat2EulerWrapper(env) +env = SERLObsWrapper(env) +env = ChunkingWrapper(env) +env = RecordEpisodeStatistics(env) +``` + +## 2. 
Cable Routing + +![cable-routing](../docs/images/cable-routing-realrobot.png) + + +**Example:** [examples/async_cable_routing_drq/](../examples/async_cable_routing_drq/) + +**Env and default config:** [serl_robot_infra/franka_env/envs/cable_env/](../serl_robot_infra/franka_env/envs/cable_env/) + +Cable routing uses an image-based reward classifier instead of a fixed target pose. Train the classifier on successful and failed trajectories, then pass its checkpoint to the actor and learner scripts. + +```bash +python train_reward_classifier.py \ + --classifier_ckpt_path CHECKPOINT_OUTPUT_DIR \ + --positive_demo_paths PATH_TO_POSITIVE_DEMO1.pkl \ + --positive_demo_paths PATH_TO_POSITIVE_DEMO2.pkl \ + --negative_demo_paths PATH_TO_NEGATIVE_DEMO1.pkl +``` + +The classifier is used with `franka_env.envs.wrapper.BinaryRewardClassifier` so the policy can train from an observation-based reward. + +## 3. Object Relocation TODO + + + +## Navigation + +- [Home](../README.md) +- [Overview](overview.md) +- [Installation guide](installation.md) +- [Run in simulation](run_sim.md) +- [Quick start](sim_quick_start.md) +- [Training options](sim_training.md) +- [Collecting demonstrations](sim_demonstrations.md) diff --git a/docs/run_sim.md b/docs/run_sim.md new file mode 100644 index 00000000..0a15af1a --- /dev/null +++ b/docs/run_sim.md @@ -0,0 +1,46 @@ +# Running FractalSERL in Simulation + +This guide covers how to train RL policies in MuJoCo simulation using FractalSERL. The simulation environment includes a Franka Panda robot arm and manipulation tasks (e.g., reaching, cube lift). We support both state-based and image-based training, with optional behavior-cloning initialization via demonstrations. + +## Prerequisites + +Before starting, ensure you have: + +1. **Installed FractalSERL** β€” follow the [installation guide](installation.md) to set up all packages. +2. 
**Verified `franka_sim`** — test the simulation environment: + ```bash + python franka_sim/franka_sim/test/test_gym_env_human.py + ``` +3. **(Optional) `tmux` installed** — for convenient parallel actor/learner launch: + ```bash + sudo apt install tmux + ``` + + +## Environment overview + +The default simulation task is **Reach**, where: + +- **State space:** end-effector position/orientation, velocities, gripper state, target block position. +- **Action space:** 3D delta movements (Δx, Δy, Δz). +- **Reward:** Dense, based on distance to target: $r = \text{clip}(e^{-20d}, 0, 1)$ where $d$ is Euclidean distance. +- **Images:** When enabled, two RGB wrist-mounted camera views replace explicit block position. +- **Episode length:** 100 steps. + +This task enables testing of **branched symmetries and fractal variants** as described in the paper (pure translations applied to positional components only). + +## Training guide + +Choose your training approach based on your needs: + +- **[Training options](sim_training.md)** — State-based SAC, Image-based DRQ, or DRQ with behavior cloning initialization. +- **[Collecting demonstrations](sim_demonstrations.md)** — Record human teleoperated demos for initialization or analysis. + +Navigation +---------- +- [Home](../README.md) +- [Overview](overview.md) +- [Installation guide](installation.md) +- [Training options](sim_training.md) +- [Collecting demonstrations](sim_demonstrations.md) +- [Run on the real robot](run_realrobot.md) diff --git a/docs/sim_demonstrations.md b/docs/sim_demonstrations.md new file mode 100644 index 00000000..905786c8 --- /dev/null +++ b/docs/sim_demonstrations.md @@ -0,0 +1,95 @@ +# Collecting Demonstrations + +The `demos/` folder contains utilities to **collect and record human teleoperated demonstrations** in simulation. + +--- + +## Recording demos via keyboard teleoperation + +Use a keyboard controller to teleoperate the Franka arm in simulation. 
Each trajectory is automatically recorded. + +### Reach task: + +```bash +python demos/demos/franka_reach_demo_script.py +``` + +### Pick-and-place task: + +```bash +python demos/demos/franka_pick_n_place_demo_script.py +``` +--- + +## Keyboard controls + +Typical keyboard controls for demo collection (check your script for exact bindings): + +| Key | Action | +|-----|--------| +| Arrow keys or WASD | Move end-effector (X, Y, Z) | +| Q/E | Rotate gripper | +| Space | Toggle gripper open/close | +| R | Reset environment | +| Esc | Exit | + +--- + +## Loading and inspecting demos + +The `demoHandling.py` utilities let you inspect and manipulate demo trajectories: + +```python +from demos.demoHandling import load_demos, save_demos + +# Load recorded trajectories +trajectories = load_demos('/path/to/demo/file.pkl') + +# Inspect structure +print(f"Number of trajectories: {len(trajectories)}") +for i, traj in enumerate(trajectories): + print(f"Trajectory {i}: {len(traj)} steps") + # Inspect fields: traj['observations'], traj['actions'], traj['rewards'] + +# Save processed trajectories +save_demos(trajectories, '/path/to/new/demo/file.pkl') +``` + +--- + +## Using demos for training + +Once you have collected demonstrations, use them to initialize behavior cloning: + + + +```bash +cd examples/async_drq_sim +bash run_learner.sh --demo_path /path/to/your/demos.pkl +``` + + + +Navigation +---------- +- [Home](../README.md) +- [Overview](overview.md) +- [Installation guide](installation.md) +- [Run in simulation](run_sim.md) +- [Training options](sim_training.md) +- [Run on the real robot](run_realrobot.md) diff --git a/docs/sim_quick_start.md b/docs/sim_quick_start.md deleted file mode 100644 index 183bf61d..00000000 --- a/docs/sim_quick_start.md +++ /dev/null @@ -1,171 +0,0 @@ -# Quick Start with SERL in Sim - -This is a minimal mujoco simulation environment for training with SERL. The environment consists of a panda robot arm and a cube. 
The goal is to lift the cube to a target position. The environment is implemented using `franka_sim` and `gym` interface. - -![](./images/franka_sim.png) - -## Installation - -**Install Franka Sim library** -```bash - cd franka_sim - pip install -e . - pip install -r requirements.txt -``` - -Try if `franka_sim` is running via `python franka_sim/franka_sim/test/test_gym_env_human.py`. - -Before beginning, please make sure that the simulation environment with `franka_sim` is working. - -*Note: to set `MUJOCO_GL` as egl if you are doing off-screen rendering. -You can do so by ```export MUJOCO_GL=egl``` and remember to set the rendering argument to False in the script. -If receives `Cannot initialize a EGL device display due to GLIBCXX not found` error, try run `conda install -c conda-forge libstdcxx-ng` ([ref](https://stackoverflow.com/a/74132234))* - - -Optionally install `tmux`: `sudo apt install tmux` - -## 1. Training from state observation example - -**✨ One-liner launcher (requires `tmux`) ✨** -```bash -bash examples/async_sac_state_sim/tmux_launch.sh -``` - -To kill the tmux session, run `tmux kill-session -t serl_session`. - -### Without using one-liner tmux launcher - -You can opt for running the commands individually in 2 different terminals. - -```bash -cd examples/async_sac_state_sim -``` - -Run learner node: -```bash -bash run_learner.sh -``` - -Run actor node with rendering window: -```bash -# add --ip x.x.x.x if running on a different machine -bash run_actor.sh -``` - -You can optionally launch the learner and actor on separate machines. For example, if the learner node is running on a PC with `ip=x.x.x.x`, you can launch the actor node on a different machine with internet access to `ip=x.x.x.x` and add `--ip x.x.x.` to the commands in `run_actor.sh`. - -Remove `--debug` flag in `run_learner.sh` to upload training stats to `wandb`. - -## 2. 
Training from image observation example - -**✨ One-liner launcher (requires `tmux`) ✨** - -```bash -bash examples/async_drq_sim/tmux_launch.sh -``` - -### Without using one-liner tmux launcher - -You can opt for running the commands individually in 2 different terminals. - -```bash -cd examples/async_drq_sim - -# to use pre-trained ResNet weights, please download -wget https://github.com/rail-berkeley/serl/releases/download/resnet10/resnet10_params.pkl -``` - -Run learner node: -```bash -bash run_learner.sh -``` - -Run actor node with rendering window: -```bash -# add --ip x.x.x.x if running on a different machine -bash run_actor.sh -``` - -## 3. Training from image observation with 20 demo trajectories example - -**✨ One-liner launcher (requires `tmux`) ✨** -```bash -bash examples/async_drq_sim/tmux_rlpd_launch.sh -``` - -### Without using one-liner tmux launcher - -You can opt for running the commands individually in 2 different terminals. - -```bash -cd examples/async_drq_sim - -# to use pre-trained ResNet weights, please download -# note manual download is only for now, once repo is public, auto download will work -wget https://github.com/rail-berkeley/serl/releases/download/resnet10/resnet10_params.pkl - -# download 20 demo trajectories -wget \ -https://github.com/rail-berkeley/serl/releases/download/franka_sim_lift_cube_demos/franka_lift_cube_image_20_trajs.pkl -``` - -Run learner node, while provide the path to the demo trajectories in the `--demo_path` argument. -```bash -bash run_learner.sh --demo_path franka_lift_cube_image_20_trajs.pkl -``` - -Run actor node with rendering window: -```bash -# add --ip x.x.x.x if running on a different machine -bash run_actor.sh -``` - -## Use RLDS logger to save and load trajectories - -This provides a way to save and load trajectories for SERL training. [Tensorflow RLDS dataset](https://github.com/google-research/rlds) format is used to save and load trajectories. 
This standard is compliant with the [RTX datasets](https://robotics-transformer-x.github.io/), which can potentially can be used for other robot learning tasks. - -### Installation - -This requires additional installation of `oxe_envlogger`: -```bash -git clone git@github.com:rail-berkeley/oxe_envlogger.git -cd oxe_envlogger -pip install -e . -``` - -### Usage - -**Save the trajectories** - -With the example above, we can save the data from the replay buffer by providing the `rlds_logger_path` argument. This will save the data to the specified path. - -```bash -./run_learner.sh --log_rlds_path /path/to/save -``` - -This will save the data to the specified path in the following format: - -```bash - - /path/to/save - - dataset_info.json - - features.json - - serl_rlds_dataset-train.tfrecord-00000 - - serl_rlds_dataset-train.tfrecord-00001 - .... -``` - -**Load the trajectories** - -With the example above, we can load the data from the replay buffer by providing the `preload_rlds_path` argument. This will load the data from the specified path. - -```bash -./run_learner.sh --preload_rlds_path /path/to/load -``` - -This is similar to the `examples/async_rlpd_drq_sim/run_learner.sh` script, which uses `--demo_path` argument which load .pkl offline demo trajectories. - - -### Troubleshooting - -1. If you receive a Out of Memory error, try reducing the batch size in the `run_learner.sh` script. by adding the `--batch_size` argument. For example, `bash run_learner.sh --batch_size 64`. -2. If the provided offline RLDS data is throwing an error, this usually means the data is not compatible with current SERL format. You can provide a custom data transform with the `data_transform(data, metadata) -> data` function in the `examples/async_drq_sim/asyn_drq_sim.py` script. 
diff --git a/docs/sim_training.md b/docs/sim_training.md new file mode 100644 index 00000000..c1592fbb --- /dev/null +++ b/docs/sim_training.md @@ -0,0 +1,119 @@ +# Training Options + + +Choose your training approach based on your requirements: + +- **Option 1: State-based SAC** — Simplest and fastest. No images needed. +- **Option 2: Image-based DRQ** — Vision-based policy learning with pre-trained ResNet encoder. +- **Option 3: Image-based DRQ + Behavior Cloning** — Pre-train on human demonstrations, then fine-tune with RL. + +--- + +## Option 1: State-based SAC (simplest, fastest) + +Train using state observations (no images, less computation). + +### Quick start (with `tmux`): +```bash +bash examples/async_sac_state_sim/tmux_launch.sh +``` + +### Manual setup (two terminals): + +**Terminal 1 — Learner:** +```bash +cd examples/async_sac_state_sim +bash run_learner.sh +``` + +**Terminal 2 — Actor:** +```bash +cd examples/async_sac_state_sim +bash run_actor.sh +``` + + + +--- + +## Option 2: Image-based DRQ (vision-based policy) + +Train using visual observations from camera(s). DRQ (Data-Regularized Q-learning) is designed for image-based control. + +### Prerequisites: + +Download the pre-trained ResNet-10 encoder (required for feature extraction): +```bash +cd examples/async_drq_sim +wget https://github.com/rail-berkeley/serl/releases/download/resnet10/resnet10_params.pkl +``` + +### Quick start (with `tmux`): +```bash +bash examples/async_drq_sim/tmux_launch.sh +``` + +### Manual setup (two terminals): + +**Terminal 1 — Learner:** +```bash +cd examples/async_drq_sim +bash run_learner.sh +``` + +**Terminal 2 — Actor:** +```bash +cd examples/async_drq_sim +bash run_actor.sh +``` +--- + +## Option 3: Image-based DRQ + Behavior Cloning (with demonstrations) + +Pre-train on ~20 human demonstrations, then fine-tune with RL. 
This combines the best of both worlds: +- **BC phase:** Learn from expert demonstrations +- **RL phase:** Refine policy with environment interaction + +### Prerequisites: + +Download ResNet encoder and demo trajectories: +```bash +cd examples/async_drq_sim +wget https://github.com/rail-berkeley/serl/releases/download/resnet10/resnet10_params.pkl +wget https://github.com/rail-berkeley/serl/releases/download/franka_sim_lift_cube_demos/franka_lift_cube_image_20_trajs.pkl +``` + +### Quick start (with `tmux`): +```bash +bash examples/async_drq_sim/tmux_rlpd_launch.sh +``` + +### Manual setup (two terminals): + +**Terminal 1 — Learner (with demos):** +```bash +cd examples/async_drq_sim +bash run_learner.sh --demo_path franka_lift_cube_image_20_trajs.pkl +``` + +**Terminal 2 — Actor:** +```bash +cd examples/async_drq_sim +bash run_actor.sh +``` + +### Custom demonstrations: + +Don't have pre-recorded demos? Create your own: +- See [Collecting demonstrations](sim_demonstrations.md) for keyboard teleoperation +- Save demos and pass with `--demo_path /path/to/your/demos.pkl` to the learner + +Navigation +---------- +- [Home](../README.md) +- [Overview](overview.md) +- [Installation guide](installation.md) +- [Run in simulation](run_sim.md) +- [Collecting demonstrations](sim_demonstrations.md) +- [Run on the real robot](run_realrobot.md) + diff --git a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py index 78a8eee4..01e9105f 100644 --- a/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py +++ b/examples/async_bin_relocation_fwbw_drq/async_drq_randomized.py @@ -24,6 +24,7 @@ from agentlace.data.data_store import QueuedDataStore from serl_launcher.utils.launcher import ( + make_replay_buffer, make_drq_agent, make_trainer_config, make_wandb_logger, @@ -54,6 +55,16 @@ flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay 
buffer capacity.") +# replay buffer flags (fractal symmetry support) +flags.DEFINE_string("replay_buffer_type", "memory_efficient_replay_buffer", "Which replay buffer to use") +flags.DEFINE_integer("branching_factor", None, "Factor by which branch count is changed") +flags.DEFINE_integer("max_depth", None, "Maximum number of splits that may occur in one episode") +flags.DEFINE_string("branch_method", "constant", "Method for how many branches to generate") +flags.DEFINE_string("split_method", "never", "Method for when to change number of branches") +flags.DEFINE_float("alpha", 0.2, "Rate of change of max_traj_length") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("starting_branch_count", 27, "Initial number of branches") + flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.") flags.DEFINE_integer("training_starts", 300, "Training starts after this step.") flags.DEFINE_integer("steps_per_update", 30, "Number of steps per update the server.") @@ -137,6 +148,7 @@ def actor( success_count = {"fw": 0, "bw": 0} overall_success_count = 0 cycle_time = {"fw": [], "bw": []} + env.reset(joint_reset=True) for _ in range(FLAGS.eval_n_trajs): for task_id, task_name in id_to_task.items(): @@ -202,7 +214,7 @@ def update_params_bw(params): clients["bw"].recv_network_callback(update_params_bw) env.set_task_id(0) - obs, _ = env.reset() + obs, _ = env.reset(joint_reset=True) done = False # training loop @@ -233,7 +245,7 @@ def update_params_bw(params): actions = agents[task_name].sample_actions( observations=jax.device_put(obs), seed=key, - deterministic=False, + argmax=False, ) actions = np.asarray(jax.device_get(actions)) @@ -299,7 +311,7 @@ def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer): """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project="202605_obj_relocation", description=FLAGS.exp_name or FLAGS.env, debug=FLAGS.debug, ) @@ -425,7 +437,9 
@@ def main(_): # create env and load dataset env = gym.make( - FLAGS.env, fake_env=FLAGS.learner, save_video=FLAGS.eval_checkpoint_step + FLAGS.env, + fake_env=FLAGS.learner, + save_video=FLAGS.eval_checkpoint_step ) if FLAGS.actor: env = SpacemouseIntervention(env) @@ -485,7 +499,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) agents[v] = agent else: @@ -500,21 +514,53 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) + ## Set indices to be transformed by fractal class for serl_robot_infra/franka_env/envs/franka_env + # observation_space[state] is sorted into an OrderedDict by SERLObsWrapper: + # gripper_pose:0 + # tcp_force.x: 1 + # tcp_force.y: 2 + # tcp_force.z: 3 + # tcp_pose.x: 4 <-- rel_frame.x points to base.+y + # tcp_pose.y: 5 <-- rel_frame.y points to base.+x + # tcp_pose.z: 6 <-- rel_frame.z points to base.-z + x_obs_idx = np.array([4]) + y_obs_idx = np.array([5]) + if FLAGS.learner: sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) - replay_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, + replay_buffer = make_replay_buffer( + env, capacity=FLAGS.replay_buffer_capacity, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, + split_method=FLAGS.split_method, + alpha=FLAGS.alpha, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + 
y_obs_idx=y_obs_idx, image_keys=image_keys, ) - demo_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, - capacity=5000, + demo_buffer = make_replay_buffer( + env, + capacity=FLAGS.replay_buffer_capacity, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, + split_method=FLAGS.split_method, + alpha=FLAGS.alpha, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, image_keys=image_keys, ) import pickle as pkl @@ -548,4 +594,4 @@ def main(_): if __name__ == "__main__": - app.run(main) + app.run(main) \ No newline at end of file diff --git a/examples/async_bin_relocation_fwbw_drq/classifier/bw_classifier_trained/checkpoint_1000 b/examples/async_bin_relocation_fwbw_drq/classifier/bw_classifier_trained/checkpoint_1000 new file mode 100644 index 00000000..f39cfb83 Binary files /dev/null and b/examples/async_bin_relocation_fwbw_drq/classifier/bw_classifier_trained/checkpoint_1000 differ diff --git a/examples/async_bin_relocation_fwbw_drq/classifier/fw_classifier_trained/checkpoint_1000 b/examples/async_bin_relocation_fwbw_drq/classifier/fw_classifier_trained/checkpoint_1000 new file mode 100644 index 00000000..b6e3b592 Binary files /dev/null and b/examples/async_bin_relocation_fwbw_drq/classifier/fw_classifier_trained/checkpoint_1000 differ diff --git a/examples/async_bin_relocation_fwbw_drq/record_transitions.py b/examples/async_bin_relocation_fwbw_drq/classifier/record_transitions.py similarity index 89% rename from examples/async_bin_relocation_fwbw_drq/record_transitions.py rename to examples/async_bin_relocation_fwbw_drq/classifier/record_transitions.py index a87582d5..b6af15ff 100644 --- a/examples/async_bin_relocation_fwbw_drq/record_transitions.py +++ 
b/examples/async_bin_relocation_fwbw_drq/classifier/record_transitions.py @@ -35,7 +35,7 @@ arg_parser.add_argument( "--transitions_needed", type=int, - default=400, + default=500, help="number of transitions to collect", ) arg_parser.add_argument( @@ -85,7 +85,7 @@ def on_release(key): def check_all_done(): return ( len(fw_failed_transitions) >= args.transitions_needed - or len(bw_failed_transitions) >= args.transitions_needed + and len(bw_failed_transitions) >= args.transitions_needed ) else: @@ -93,7 +93,7 @@ def check_all_done(): def check_all_done(): return ( len(fw_success_transitions) >= args.transitions_needed - or len(bw_success_transitions) >= args.transitions_needed + and len(bw_success_transitions) >= args.transitions_needed ) # Loop until we have enough transitions @@ -152,13 +152,13 @@ def check_all_done(): # save success transitions if not args.record_failed_only: file_name = ( - f"fw_bin_relocate_{args.transitions_needed}_front_cam_goal_{uuid}.pkl" + f"./fw_classifier_demos/fw_bin_relocate_{args.transitions_needed}_front_cam_goal_{uuid}.pkl" ) with open(file_name, "wb") as f: pkl.dump(fw_success_transitions, f) print(f"saved {len(fw_success_transitions)} transitions to {file_name}") file_name = ( - f"bw_bin_relocate_{args.transitions_needed}_front_cam_goal_{uuid}.pkl" + f"./bw_classifier_demos/bw_bin_relocate_{args.transitions_needed}_front_cam_goal_{uuid}.pkl" ) with open(file_name, "wb") as f: pkl.dump(bw_success_transitions, f) @@ -166,13 +166,13 @@ def check_all_done(): # save failed transitions file_name = ( - f"fw_bin_relocate_{len(fw_failed_transitions)}_front_cam_failed_{uuid}.pkl" + f"./fw_classifier_demos/fw_bin_relocate_{len(fw_failed_transitions)}_front_cam_failed_{uuid}.pkl" ) with open(file_name, "wb") as f: pkl.dump(fw_failed_transitions, f) print(f"saved {len(fw_failed_transitions)} transitions to {file_name}") file_name = ( - f"bw_bin_relocate_{len(bw_failed_transitions)}_front_cam_failed_{uuid}.pkl" + 
f"./bw_classifier_demos/bw_bin_relocate_{len(bw_failed_transitions)}_front_cam_failed_{uuid}.pkl" ) with open(file_name, "wb") as f: pkl.dump(bw_failed_transitions, f) diff --git a/examples/async_bin_relocation_fwbw_drq/test_classifier.py b/examples/async_bin_relocation_fwbw_drq/classifier/test_classifier.py similarity index 89% rename from examples/async_bin_relocation_fwbw_drq/test_classifier.py rename to examples/async_bin_relocation_fwbw_drq/classifier/test_classifier.py index 78ac680d..6c7dc3df 100644 --- a/examples/async_bin_relocation_fwbw_drq/test_classifier.py +++ b/examples/async_bin_relocation_fwbw_drq/classifier/test_classifier.py @@ -38,14 +38,14 @@ key=key, sample=env.front_observation_space.sample(), image_keys=image_keys, - checkpoint_path="/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/fw_classifier_ckpt", + checkpoint_path="/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier/fw_classifier_trained", ) rng, key = jax.random.split(rng) bw_classifier_func = load_classifier_func( key=key, sample=env.front_observation_space.sample(), image_keys=image_keys, - checkpoint_path="/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bw_classifier_ckpt", + checkpoint_path="/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier/bw_classifier_trained", ) env = FWBWFrontCameraBinaryRewardClassifierWrapper( env, fw_classifier_func, bw_classifier_func diff --git a/examples/async_bin_relocation_fwbw_drq/classifier/trainCommand.txt b/examples/async_bin_relocation_fwbw_drq/classifier/trainCommand.txt new file mode 100644 index 00000000..7016a583 --- /dev/null +++ b/examples/async_bin_relocation_fwbw_drq/classifier/trainCommand.txt @@ -0,0 +1,36 @@ +python train_reward_classifier.py --num_epochs 1000 --batch_size 512 \ + --classifier_ckpt_path ./fw_classifier_trained \ + --positive_demo_paths ./fw_classifier_demos/P1.pkl \ + --positive_demo_paths ./fw_classifier_demos/P2.pkl \ + 
--positive_demo_paths ./fw_classifier_demos/P3.pkl \ + --positive_demo_paths ./fw_classifier_demos/P4.pkl \ + --negative_demo_paths ./fw_classifier_demos/N1.pkl \ + --negative_demo_paths ./fw_classifier_demos/N2.pkl \ + --negative_demo_paths ./fw_classifier_demos/N3.pkl \ + --negative_demo_paths ./fw_classifier_demos/N4.pkl \ + --negative_demo_paths ./fw_classifier_demos/N5.pkl \ + --negative_demo_paths ./fw_classifier_demos/N6.pkl \ + --negative_demo_paths ./fw_classifier_demos/N7.pkl \ + --negative_demo_paths ./fw_classifier_demos/N9.pkl \ + --negative_demo_paths ./fw_classifier_demos/N10.pkl \ + --negative_demo_paths ./fw_classifier_demos/N11.pkl \ + --negative_demo_paths ./fw_classifier_demos/N12.pkl \ + + +python train_reward_classifier.py --num_epochs 1000 --batch_size 512 \ + --classifier_ckpt_path ./bw_classifier_trained \ + --positive_demo_paths ./bw_classifier_demos/P1.pkl \ + --positive_demo_paths ./bw_classifier_demos/P2.pkl \ + --positive_demo_paths ./bw_classifier_demos/P3.pkl \ + --positive_demo_paths ./bw_classifier_demos/P4.pkl \ + --negative_demo_paths ./bw_classifier_demos/N1.pkl \ + --negative_demo_paths ./bw_classifier_demos/N2.pkl \ + --negative_demo_paths ./bw_classifier_demos/N3.pkl \ + --negative_demo_paths ./bw_classifier_demos/N4.pkl \ + --negative_demo_paths ./bw_classifier_demos/N5.pkl \ + --negative_demo_paths ./bw_classifier_demos/N6.pkl \ + --negative_demo_paths ./bw_classifier_demos/N7.pkl \ + --negative_demo_paths ./bw_classifier_demos/N9.pkl \ + --negative_demo_paths ./bw_classifier_demos/N10.pkl \ + --negative_demo_paths ./bw_classifier_demos/N11.pkl \ + --negative_demo_paths ./bw_classifier_demos/N12.pkl diff --git a/examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py b/examples/async_bin_relocation_fwbw_drq/classifier/train_reward_classifier.py similarity index 56% rename from examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py rename to 
examples/async_bin_relocation_fwbw_drq/classifier/train_reward_classifier.py index 0a16fe12..aab9ad31 100644 --- a/examples/async_bin_relocation_fwbw_drq/train_reward_classifier.py +++ b/examples/async_bin_relocation_fwbw_drq/classifier/train_reward_classifier.py @@ -1,3 +1,48 @@ +"""Beginner tutorial for training a vision-based reward classifier. + +This script teaches a classifier to answer a binary question: +"Does this camera observation look like success?" + +The training data comes from demonstration trajectories saved on disk: + +1. Positive demonstrations: trajectories that end in success. +2. Negative demonstrations: trajectories that show failure or non-success. + +At a high level, the script does the following: + +1. Build the robot environment only to recover the observation/action spaces. +2. Load positive and negative trajectories into replay buffers. +3. Repeatedly draw a half-batch from each buffer. +4. Convert those samples into images plus binary labels. +5. Train a classifier with binary cross-entropy. +6. Save the trained model as a checkpoint for later evaluation or reward shaping. + +Reading tips for new students: + +- `observations` are the states at time step t. +- `next_observations` are the states at time step t + 1. +- The classifier never predicts an action. It only predicts whether an + observation should be labeled positive or negative. +- JAX/Flax code often separates "define the computation" from "run the + computation". The `train_step` function below is the compiled update rule. + +Typical usage: + +```bash +python examples/async_bin_relocation_fwbw_drq/classifier/train_reward_classifier.py \ + --positive_demo_paths=/path/to/success.pkl \ + --negative_demo_paths=/path/to/failure.pkl \ + --classifier_ckpt_path=/tmp/reward_classifier_ckpt +``` +""" + +import os +# Set JAX's XLA backend not to grab most GPU memory up front when the program starts. +# Otherwise XLA will preallocate a large chunk of GPU memory and starve other processes. 
+os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +# Set JAX's XLA backend to cap GPU memory use to up to 20%. +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2" + import pickle as pkl import jax from jax import numpy as jnp @@ -7,7 +52,6 @@ import optax from tqdm import tqdm import gym -import os from absl import app, flags from serl_launcher.wrappers.chunking import ChunkingWrapper @@ -26,10 +70,6 @@ from franka_env.envs.wrappers import Quat2EulerWrapper from franka_env.envs.relative_env import RelativeFrame -# Set above env export to prevent OOM errors from memory preallocation -os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2" - FLAGS = flags.FLAGS flags.DEFINE_multi_string("positive_demo_paths", None, "paths to positive demos") flags.DEFINE_multi_string("negative_demo_paths", None, "paths to negative demos") @@ -39,6 +79,12 @@ def main(_): + """Construct the wrapped environment and launch training. + + The environment is used here as a convenient source of the observation and + action spaces expected by the classifier and replay buffers. Training itself + happens purely from the saved demonstrations, not from fresh online rollouts. + """ env = gym.make("FrankaBinRelocation-Vision-v0", save_video=False) env = RelativeFrame(env) env = Quat2EulerWrapper(env) @@ -46,42 +92,60 @@ def main(_): env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) env = FrontCameraWrapper(env) - # we will only use the front camera view for training the reward classifier + # This example keeps only the front camera stream for classification. train_reward_classifier(env.front_observation_space, env.action_space) def train_reward_classifier(observation_space, action_space): - """ - User can provide custom observation space to be used as the - input to the classifier. This function is used to train a reward - classifier using the provided positive and negative demonstrations. 
+ """Train a binary reward classifier from offline demonstrations. + + Tutorial overview: + + 1. Identify which observation entries are image-like inputs. + 2. Load success and failure trajectories into separate replay buffers. + 3. Sample equally from both buffers to keep the dataset balanced. + 4. Apply random crop augmentation to improve visual robustness. + 5. Optimize a classifier that outputs one logit per example. + 6. Save the final classifier state for future use. + + Args: + observation_space: Observation structure expected by the classifier. + action_space: Action structure needed by the replay buffer API. NOTE: this function is duplicated and used in both - async_bin_relocation_fwbw_drq and async_cable_route_drq examples + `async_bin_relocation_fwbw_drq` and `async_cable_route_drq` examples. """ devices = jax.local_devices() sharding = jax.sharding.PositionalSharding(devices) + # In this codebase, low-dimensional proprioceptive features usually contain + # "state" in the key name, so the remaining keys are camera observations. image_keys = [k for k in observation_space.keys() if "state" not in k] + # Positive buffer: successful demonstrations. pos_buffer = MemoryEfficientReplayBufferDataStore( observation_space, action_space, - capacity=10000, + capacity=20000, image_keys=image_keys, ) pos_buffer = populate_data_store(pos_buffer, FLAGS.positive_demo_paths) + # Negative buffer: failed or non-successful demonstrations. neg_buffer = MemoryEfficientReplayBufferDataStore( observation_space, action_space, - capacity=10000, + capacity=20000, image_keys=image_keys, ) neg_buffer = populate_data_store(neg_buffer, FLAGS.negative_demo_paths) print(f"failed buffer size: {len(neg_buffer)}") print(f"success buffer size: {len(pos_buffer)}") + + # Each iterator returns mini-batches directly on the JAX device layout. + # We use half of the final batch from positives and half from negatives so + # that the binary labels remain balanced. 
pos_iterator = pos_buffer.get_iterator( sample_args={ "batch_size": FLAGS.batch_size // 2, @@ -103,10 +167,13 @@ def train_reward_classifier(observation_space, action_space): neg_sample = next(neg_iterator) sample = concat_batches(pos_sample, neg_sample, axis=0) + # The classifier is initialized from a real example batch so Flax can infer + # the expected input shapes for every camera stream. rng, key = jax.random.split(rng) classifier = create_classifier(key, sample["next_observations"], image_keys) def data_augmentation_fn(rng, observations): + """Apply the same style of random crop augmentation to each image key.""" for pixel_key in image_keys: observations = observations.copy( add_or_replace={ @@ -120,6 +187,7 @@ def data_augmentation_fn(rng, observations): # Define the training step @jax.jit def train_step(state, batch, key): + """Run one compiled gradient step and report loss/accuracy.""" def loss_fn(params): logits = state.apply_fn( {"params": params}, batch["data"], rngs={"dropout": key}, train=True @@ -140,10 +208,18 @@ def loss_fn(params): # Sample equal number of positive and negative examples pos_sample = next(pos_iterator) neg_sample = next(neg_iterator) - # Merge and create labels + # Merge and create labels. + # + # POTENTIAL BUG: Positive examples use `next_observations` while + # negative examples use `observations`. If the intended supervision is + # "classify final success-like states versus final failure-like states", + # this mismatch may let the model learn temporal offset cues instead of + # the reward concept itself. sample = concat_batches( pos_sample["next_observations"], neg_sample["observations"], axis=0 ) + # Random crops are a standard vision trick: the label stays the same, + # but the model sees slightly different versions of the same image. 
rng, key = jax.random.split(rng) sample = data_augmentation_fn(key, sample) labels = jnp.concatenate( @@ -155,6 +231,7 @@ def loss_fn(params): ) batch = {"data": sample, "labels": labels} + # One optimizer step over the balanced batch. rng, key = jax.random.split(rng) classifier, train_loss, train_accuracy = train_step(classifier, batch, key) @@ -162,7 +239,8 @@ def loss_fn(params): f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}" ) - # this is used to save the without the orbax checkpointing + # Save a plain Flax checkpoint so downstream scripts can load the model + # without depending on Orbax checkpointing behavior. flax.config.update("flax_use_orbax_checkpointing", False) checkpoints.save_checkpoint( FLAGS.classifier_ckpt_path, diff --git a/examples/async_bin_relocation_fwbw_drq/record_demo.py b/examples/async_bin_relocation_fwbw_drq/record_demo.py index 57508984..6d767836 100644 --- a/examples/async_bin_relocation_fwbw_drq/record_demo.py +++ b/examples/async_bin_relocation_fwbw_drq/record_demo.py @@ -21,8 +21,6 @@ from serl_launcher.wrappers.chunking import ChunkingWrapper if __name__ == "__main__": - from pynput import keyboard - env = gym.make("FrankaBinRelocation-Vision-v0", save_video=False) env = SpacemouseIntervention(env) env = RelativeFrame(env) @@ -45,30 +43,34 @@ key=key, sample=env.front_observation_space.sample(), image_keys=image_keys, - checkpoint_path="/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/fw_classifier_ckpt", + checkpoint_path="/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier/fw_classifier_trained", ) rng, key = jax.random.split(rng) bw_classifier_func = load_classifier_func( key=key, sample=env.front_observation_space.sample(), image_keys=image_keys, - checkpoint_path="/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bw_classifier_ckpt", + 
checkpoint_path="/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier/bw_classifier_trained", ) env = FWBWFrontCameraBinaryRewardClassifierWrapper( env, fw_classifier_func, bw_classifier_func ) - transitions_needed = 2000 + successes_needed = 30 fw_transitions = [] bw_transitions = [] + fw_success_count = 0 + bw_success_count = 0 + episode_buffer = [] + episode_task_id = env.task_id - fw_pbar = tqdm(total=transitions_needed, desc="fw") - bw_pbar = tqdm(total=transitions_needed, desc="bw") + fw_pbar = tqdm(total=successes_needed, desc="fw successes") + bw_pbar = tqdm(total=successes_needed, desc="bw successes") uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - fw_file_name = f"fw_bin_demo_{uuid}.pkl" - bw_file_name = f"bw_bin_demo_{uuid}.pkl" - file_dir = os.path.dirname(os.path.realpath(__file__)) # same dir as this script + fw_file_name = f"./demos/fw_demos/fw_bin_demo_{successes_needed}_episodes_{uuid}.pkl" + bw_file_name = f"./demos/bw_demos/bw_bin_demo_{successes_needed}_episodes_{uuid}.pkl" + file_dir = os.path.dirname(os.path.realpath(__file__)) fw_file_path = os.path.join(file_dir, fw_file_name) bw_file_path = os.path.join(file_dir, bw_file_name) @@ -81,10 +83,7 @@ if not os.access(file_dir, os.W_OK): raise PermissionError(f"No permission to write to {file_dir}") - while ( - len(fw_transitions) < transitions_needed - or len(bw_transitions) < transitions_needed - ): + while fw_success_count < successes_needed or bw_success_count < successes_needed: actions = np.zeros((7,)) next_obs, rew, done, truncated, info = env.step(action=actions) if "intervene_action" in info: @@ -100,31 +99,56 @@ dones=done, ) ) - - if env.task_id == 0 and len(fw_transitions) < transitions_needed: - fw_transitions.append(transition) - fw_pbar.update(1) - elif env.task_id == 1 and len(bw_transitions) < transitions_needed: - bw_transitions.append(transition) - bw_pbar.update(1) - + episode_buffer.append(transition) obs = next_obs if done: - print(rew) + 
success = bool(rew) + if success and episode_task_id == 0 and fw_success_count < successes_needed: + fw_transitions.extend(episode_buffer) + fw_success_count += 1 + fw_pbar.update(1) + print( + f"fw success! kept {len(episode_buffer)} transitions " + f"({fw_success_count}/{successes_needed} fw, " + f"{bw_success_count}/{successes_needed} bw)" + ) + elif success and episode_task_id == 1 and bw_success_count < successes_needed: + bw_transitions.extend(episode_buffer) + bw_success_count += 1 + bw_pbar.update(1) + print( + f"bw success! kept {len(episode_buffer)} transitions " + f"({fw_success_count}/{successes_needed} fw, " + f"{bw_success_count}/{successes_needed} bw)" + ) + else: + print( + f"discarded episode (task_id={episode_task_id}, rew={rew}, " + f"len={len(episode_buffer)})" + ) + + episode_buffer = [] next_task_id = env.task_graph(env.get_front_cam_obs()) print(f"transition from {env.task_id} to next task: {next_task_id}") env.set_task_id(next_task_id) + episode_task_id = next_task_id obs, _ = env.reset() with open(fw_file_path, "wb") as f: pkl.dump(fw_transitions, f) - print(f"saved {len(fw_transitions)} transitions to {fw_file_path}") + print( + f"saved {fw_success_count} fw episodes " + f"({len(fw_transitions)} transitions) to {fw_file_path}" + ) with open(bw_file_path, "wb") as f: pkl.dump(bw_transitions, f) - print(f"saved {len(bw_transitions)} transitions to {bw_file_path}") + print( + f"saved {bw_success_count} bw episodes " + f"({len(bw_transitions)} transitions) to {bw_file_path}" + ) env.close() fw_pbar.close() - bw_pbar.close() + bw_pbar.close() \ No newline at end of file diff --git a/examples/async_bin_relocation_fwbw_drq/run_actor.sh b/examples/async_bin_relocation_fwbw_drq/run_actor.sh index 5f25bc16..c3c8f67d 100644 --- a/examples/async_bin_relocation_fwbw_drq/run_actor.sh +++ b/examples/async_bin_relocation_fwbw_drq/run_actor.sh @@ -1,17 +1,23 @@ -export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ 
+export XLA_PYTHON_CLIENT_PREALLOCATE=true && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.15 && \ python async_drq_randomized.py "$@" \ + --replay_buffer_type sample_efficient_replay_buffer \ + --seed 1 \ + --exp_name=serl_dev_drq_rlpd20demos_bin_fwbw_resnet_096 \ + --branch_method "constant" \ --actor \ --render \ --env FrankaBinRelocation-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd20demos_bin_fwbw_resnet_096 \ - --seed 0 \ - --random_steps 200 \ + --max_steps 30_000 \ + --random_steps 0 \ --encoder_type resnet-pretrained \ - --demo_path fw_bin_2000_demo_2024-01-23_18-49-56.pkl \ - --fw_ckpt_path /home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bin_fw_096 \ - --bw_ckpt_path /home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bin_bw_096 \ - --fw_reward_classifier_ckpt_path "/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/fw_classifier_ckpt" \ - --bw_reward_classifier_ckpt_path "/home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bw_classifier_ckpt" \ - --eval_checkpoint_step 31000 \ - --eval_checkpoint_step 100 + --split_method "time" \ + --branching_factor 3 \ + --max_depth 3 \ + --alpha 0.2 \ + --max_traj_length 100 \ + --workspace_width 0.3 \ + --fw_ckpt_path /home/student/code/serl/examples/async_bin_relocation_fwbw_drq/checkpoints/fw \ + --bw_ckpt_path /home/student/code/serl/examples/async_bin_relocation_fwbw_drq/checkpoints/bw \ + --fw_reward_classifier_ckpt_path "/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier/fw_classifier_trained" \ + --bw_reward_classifier_ckpt_path "/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier/bw_classifier_trained" \ \ No newline at end of file diff --git a/examples/async_bin_relocation_fwbw_drq/run_bw_learner.sh b/examples/async_bin_relocation_fwbw_drq/run_bw_learner.sh index ade37842..bcedffe7 100644 --- a/examples/async_bin_relocation_fwbw_drq/run_bw_learner.sh +++ 
b/examples/async_bin_relocation_fwbw_drq/run_bw_learner.sh @@ -1,17 +1,31 @@ -export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +# --replay_buffer_capacity 200_000 \ # automatically handled by replay buffer logic +export XLA_PYTHON_CLIENT_PREALLOCATE=true && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/bw-$TIMESTAMP" && \ + python async_drq_randomized.py "$@" \ + --seed 1 \ + --replay_buffer_type memory_efficient_replay_buffer \ + --demo_path ./demos/bw_demos/baseline_01.pkl \ + --exp_name=serl_dev_drq_rlpd20demos_bin_fwbw_resnet_096_bw \ --learner \ --env FrankaBinRelocation-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd20demos_bin_fwbw_resnet_096_bw \ - --seed 0 \ + --max_steps 30_000 \ --random_steps 200 \ --training_starts 200 \ --critic_actor_ratio 4 \ --batch_size 256 \ --eval_period 2000 \ --encoder_type resnet-pretrained \ + --starting_branch_count 27 \ + --branch_method "constant" \ + --split_method "never" \ + --alpha 0.2 \ + --max_depth 3 \ + --branching_factor 3 \ + --workspace_width 0.3 \ --fwbw bw \ - --demo_path ./demos/bw_bin_2000_demo_2024-01-23_18-49-56.pkl \ - --checkpoint_period 1000 \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bin_bw_096 + --checkpoint_period 500 \ + --checkpoint_path $CHECKPOINT_DIR diff --git a/examples/async_bin_relocation_fwbw_drq/run_eval.sh b/examples/async_bin_relocation_fwbw_drq/run_eval.sh new file mode 100644 index 00000000..23a05d78 --- /dev/null +++ b/examples/async_bin_relocation_fwbw_drq/run_eval.sh @@ -0,0 +1,22 @@ +# All export statements end with && \ to chain them together +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +# XLA memory fraction with learner+action <0.8. Learner needs more. 
+export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads. Requires CUDA 11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + +export CHECKPOINT_EVAL="/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/checkpoints" +export CLASSIFIER_DIR="/home/student/code/serl/examples/async_bin_relocation_fwbw_drq/classifier" +export STEP=15000 + +python async_drq_randomized.py \ + --actor \ + --render \ + --env FrankaBinRelocation-Vision-v0 \ + --bw_reward_classifier_ckpt_path "$CLASSIFIER_DIR/bw_classifier_trained/" \ + --fw_reward_classifier_ckpt_path "$CLASSIFIER_DIR/fw_classifier_trained/" \ + --eval_checkpoint_step $STEP \ + --eval_n_trajs 50 \ + --bw_ckpt_path "$CHECKPOINT_EVAL/baseline_01_bw" \ + --fw_ckpt_path "$CHECKPOINT_EVAL/baseline_01_fw" \ + "$@" diff --git a/examples/async_bin_relocation_fwbw_drq/run_fw_learner.sh b/examples/async_bin_relocation_fwbw_drq/run_fw_learner.sh index 67d16533..71c6ecd3 100644 --- a/examples/async_bin_relocation_fwbw_drq/run_fw_learner.sh +++ b/examples/async_bin_relocation_fwbw_drq/run_fw_learner.sh @@ -1,17 +1,31 @@ -export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +# --replay_buffer_capacity 200_000 \ # automatically handled by replay buffer logic +export XLA_PYTHON_CLIENT_PREALLOCATE=true && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/fw-$TIMESTAMP" && \ + python async_drq_randomized.py "$@" \ + --seed 1 \ + --replay_buffer_type memory_efficient_replay_buffer \ + --demo_path ./demos/fw_demos/baseline_01.pkl \ + --exp_name=serl_dev_drq_rlpd20demos_bin_fwbw_resnet_096_fw \ --learner \ --env FrankaBinRelocation-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd20demos_bin_fwbw_resnet_096_fw \ - --seed 0 \ + --max_steps 
30_000 \ --random_steps 200 \ --training_starts 200 \ --critic_actor_ratio 4 \ --batch_size 256 \ --eval_period 2000 \ --encoder_type resnet-pretrained \ + --starting_branch_count 27 \ + --branch_method "constant" \ + --split_method "never" \ + --alpha 0.2 \ + --max_depth 3 \ + --branching_factor 3 \ + --workspace_width 0.3 \ --fwbw fw \ - --demo_path ./demos/fw_bin_2000_demo_2024-01-23_18-49-56.pkl \ - --checkpoint_period 1000 \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_bin_relocation_fwbw_drq/bin_fw_096 + --checkpoint_period 500 \ + --checkpoint_path $CHECKPOINT_DIR diff --git a/examples/async_cable_route_drq/async_drq_randomized.py b/examples/async_cable_route_drq/async_drq_randomized.py index 90770fc4..4d3c73a8 100644 --- a/examples/async_cable_route_drq/async_drq_randomized.py +++ b/examples/async_cable_route_drq/async_drq_randomized.py @@ -22,8 +22,8 @@ from agentlace.trainer import TrainerServer, TrainerClient from agentlace.data.data_store import QueuedDataStore -from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore from serl_launcher.utils.launcher import ( + make_replay_buffer, make_drq_agent, make_trainer_config, make_wandb_logger, @@ -52,7 +52,6 @@ flags.DEFINE_integer("critic_actor_ratio", 4, "critic to actor update ratio.") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") -flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.") flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.") flags.DEFINE_integer("training_starts", 300, "Training starts after this step.") @@ -75,6 +74,18 @@ "reward_classifier_ckpt_path", None, "Path to reward classifier ckpt." 
) +# replay buffer flags +flags.DEFINE_string("replay_buffer_type", "memory_efficient_replay_buffer", "Which replay buffer to use") +flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.") +flags.DEFINE_integer("branching_factor", None, "Factor by which branch count is changed") +flags.DEFINE_integer("max_depth", None, "Maximum number of splits that may occur in one episode") +flags.DEFINE_string("branch_method", "constant", "Method for how many branches to generate") +flags.DEFINE_string("split_method", "never", "Method for when to change number of branches") +flags.DEFINE_float("alpha", 0.2, "Rate of change of max_traj_length") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("starting_branch_count", 27, "Initial number of branches") + + flags.DEFINE_integer( "eval_checkpoint_step", 0, "evaluate the policy from ckpt at this step" ) @@ -110,12 +121,18 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng): step=FLAGS.eval_checkpoint_step, ) agent = agent.replace(state=ckpt) + env.reset(joint_reset=True) for episode in range(FLAGS.eval_n_trajs): obs, _ = env.reset() done = False start_time = time.time() while not done: + # Use deterministic action selection for evaluation. + # `argmax=True` tells the agent to pick the highest-value + # / highest-probability action instead of sampling stochastically. + # This produces consistent, reproducible evaluation metrics + # (success rate and completion time) by removing exploration noise. actions = agent.sample_actions( observations=jax.device_put(obs), argmax=True, @@ -147,14 +164,19 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng): wait_for_server=True, ) - # Function to update the agent with new params + # Learner pushes new network parameters; keep this actor's policy in sync. 
def update_params(params): + """Replace only model parameters while preserving the rest of agent state.""" nonlocal agent agent = agent.replace(state=agent.state.replace(params=params)) + # registers update_params as the handler for incoming network parameter + # messages from the learner. When an update arrives, update_params(params) + # runs and swaps in new model weights client.recv_network_callback(update_params) - obs, _ = env.reset() + # Initialize first rollout episode. + obs, _ = env.reset(joint_reset=True) done = False # training loop @@ -162,11 +184,17 @@ def update_params(params): running_return = 0.0 for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): - timer.tick("total") + # Start manual 'total' stopwatch + timer.tick("total") + # Start automatic 'sample_actions' logging: with timer.context("sample_actions"): + + # Collect random steps to bootstrap the replay buffer. Initial policy poor. if step < FLAGS.random_steps: actions = env.action_space.sample() + + # Sample actions from agent. else: sampling_rng, key = jax.random.split(sampling_rng) actions = agent.sample_actions( @@ -174,11 +202,13 @@ def update_params(params): seed=key, deterministic=False, ) + # Move actions from JAX device (GPU/TPU) to host NumPy for Gym env.step. 
actions = np.asarray(jax.device_get(actions)) # Step environment with timer.context("step_env"): + # Collect tuple info and packet in transition dict next_obs, reward, done, truncated, info = env.step(actions) # override the action with the intervention action @@ -196,6 +226,8 @@ def update_params(params): masks=1.0 - done, dones=done, ) + + # Push transition to learner replay buffer data_store.insert(transition) obs = next_obs @@ -226,7 +258,7 @@ def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer): """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project="CableRoute-april_2026", description=FLAGS.exp_name or FLAGS.env, debug=FLAGS.debug, ) @@ -235,15 +267,21 @@ def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer): update_steps = 0 def stats_callback(type: str, payload: dict) -> dict: - """Callback for when server receives stats request.""" + """Callback for when server receives stats request. + Will log what actor sends via client.request("send-stats", payload). + i.e. from RecordEpisodeStatistics, usually info["episode"] = {"r": return, "l": length, "t": elapsed_time} + i.e. wrapper flags like left/right from SpacemouseIntervention + i.e. 
Timing stats inside timer: averages for keys like "sample_actions", "step_env", "total" from your Timer + Training info and log_period steps are typically captured""" + assert type == "send-stats", f"Invalid request type: {type}" if wandb_logger is not None: wandb_logger.log(payload, step=update_steps) return {} # not expecting a response - # Create server + # Create the Training server server = TrainerServer(make_trainer_config(), request_callback=stats_callback) - server.register_data_store("actor_env", replay_buffer) + server.register_data_store("actor_env", replay_buffer) # Learner registers where incoming actor data should go: server.start(threaded=True) # Loop to wait until replay_buffer is filled @@ -254,24 +292,39 @@ def stats_callback(type: str, payload: dict) -> dict: position=0, leave=True, ) + # Do not start gradient updates until enough actor data has arrived. while len(replay_buffer) < FLAGS.training_starts: pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + + # Advance only by newly added samples since last refresh. + pbar.update(len(replay_buffer) - pbar.n) time.sleep(1) pbar.update(len(replay_buffer) - pbar.n) # Update progress bar + + # Final sync so bar reaches the latest replay size before closing. + pbar.update(len(replay_buffer) - pbar.n) pbar.close() - # send the initial network to the actor + # Initial send: broadcast learner's current params so actors start rollouts with synced weights. server.publish_network(agent.state.params) print_green("sent initial network to actor") - # 50/50 sampling from RLPD, half from demo and half from online experience + # Build two equal-size iterators: one for online actor data and one for demos. + # Each learner step concatenates these batches for 50/50 mixed training. replay_iterator = replay_buffer.get_iterator( sample_args={ + # Half batch from online replay; the other half comes from demo_iterator. 
"batch_size": FLAGS.batch_size // 2, + # Keep (obs, next_obs) together in one packed sample for agent updates. "pack_obs_and_next_obs": True, }, + # Pre-shard onto local devices to match replicated learner state. + # Will prepare outputs as 'JAX device arrays' ahead of time vs NumPy. + # Results in less host-2-device overhead in hot training loop, fewer shape mismatches, more stable throughput. device=sharding.replicate(), ) + + # Same for demo replay buffer. demo_iterator = demo_buffer.get_iterator( sample_args={ "batch_size": FLAGS.batch_size // 2, @@ -282,35 +335,52 @@ def stats_callback(type: str, payload: dict) -> dict: # wait till the replay buffer is filled with enough data timer = Timer() + + # Main learner loop for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True, desc="learner"): + # run n-1 critic updates and 1 critic + actor update. # This makes training on GPU faster by reducing the large batch transfer time from CPU to GPU for critic_step in range(FLAGS.critic_actor_ratio - 1): + with timer.context("sample_replay_buffer"): - batch = next(replay_iterator) + batch = next(replay_iterator) demo_batch = next(demo_iterator) + batch = concat_batches(batch, demo_batch, axis=0) with timer.context("train_critics"): + # agent: updates the learner agent state (new critic params, targets, rng state, etc.). + # critis_info: contains logging metrics from that critic update (e.g., critic losses/stats). agent, critics_info = agent.update_critics( batch, ) + # Run utd "full update" steps (critic + actor + temperature). Actor weights not yet sent. with timer.context("train"): - batch = next(replay_iterator) + batch = next(replay_iterator) demo_batch = next(demo_iterator) batch = concat_batches(batch, demo_batch, axis=0) agent, update_info = agent.update_high_utd(batch, utd_ratio=1) - # publish the updated network + # publish the updated network to the actor. 
if step > 0 and step % (FLAGS.steps_per_update) == 0: + + # Force synchronization by waiting for pending JAX computations that produce agent. + # Key since JAX is asynchronous. Otherwise could publish params before latest updates. agent = jax.block_until_ready(agent) server.publish_network(agent.state.params) + # Log {critic/actor/temperature/optimizer} and {timer info:sample_replay_buffer/train_critics/train} info if update_steps % FLAGS.log_period == 0 and wandb_logger: + + # update_info comes from agent.update_high_utd wandb_logger.log(update_info, step=update_steps) + + # Timer class will return replay_buffer/critic avg wall-clock times for sections: wandb_logger.log({"timer": timer.get_average_times()}, step=update_steps) + # Save checkpoint/model if FLAGS.checkpoint_period and update_steps % FLAGS.checkpoint_period == 0: assert FLAGS.checkpoint_path is not None checkpoints.save_checkpoint( @@ -344,17 +414,28 @@ def main(_): env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) image_keys = [key for key in env.observation_space.keys() if key != "state"] if FLAGS.actor: - # initialize the classifier and wrap the env + # Actor uses a learned binary reward classifier to turn sparse visual success + # detection into a runtime reward signal. We require a checkpoint path here + # because actor/eval runs depend on classifier inference at every env step. if FLAGS.reward_classifier_ckpt_path is None: raise ValueError("reward_classifier_ckpt_path must be specified for actor") + # Build classifier model definition, restore its parameters from checkpoint, + # and return a jitted function: obs -> success logit. + # - key: RNG for model initialization scaffolding before checkpoint restore. + # - sample: example observation to initialize classifier parameter shapes. + # - image_keys: camera streams consumed by the classifier encoder. + # - checkpoint_path: directory containing saved classifier checkpoints. 
reward_func = load_classifier_func( key=sampling_rng, sample=env.observation_space.sample(), image_keys=image_keys, checkpoint_path=FLAGS.reward_classifier_ckpt_path, ) + + # Inject classifier-based reward computation into env.step(...): + # wrapper thresholds classifier logits into binary success and adds it to reward. env = BinaryRewardClassifierWrapper(env, reward_func) env = RecordEpisodeStatistics(env) @@ -370,21 +451,61 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) + ## Set indices to be transformed by fractal class for the serl_robot_infra/robot_env/envs/franka_env + # Note that observation_space[state] will be sorted and set as an ordered dict by SerlObservationWrapper + # gripper_pose:0 + # tcp_force.x: 1 + # tcp_force.y: 2 + # tcp_force.z: 3 + # tcp_pose.x: 4 <-- rel_frame.x points to base.+y + # tcp_pose.y: 5 <-- rel_frame.y points to base.+x + # tcp_pose.z: 6 <-- rel_frame.z points to base.-z + x_obs_idx = np.array([4]) + y_obs_idx = np.array([5]) + if FLAGS.learner: sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) - replay_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, + + x_obs_idx = np.array([4]) + y_obs_idx = np.array([5]) + + replay_buffer = make_replay_buffer( + env, capacity=FLAGS.replay_buffer_capacity, + # rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, + split_method=FLAGS.split_method, + alpha=FLAGS.alpha, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + # 
preload_rlds_path=FLAGS.preload_rlds_path, image_keys=image_keys, ) - demo_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, - capacity=10000, + demo_buffer = make_replay_buffer( + env, + capacity=FLAGS.replay_buffer_capacity, + # rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, + split_method=FLAGS.split_method, + alpha=FLAGS.alpha, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + # preload_rlds_path=FLAGS.preload_rlds_path, image_keys=image_keys, ) @@ -411,13 +532,14 @@ def main(_): elif FLAGS.actor: sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) data_store = QueuedDataStore(2000) # the queue size on the actor - # actor loop + + # actor loop print_green("starting actor loop") actor(agent, data_store, env, sampling_rng) else: raise NotImplementedError("Must be either a learner or an actor") - + return if __name__ == "__main__": app.run(main) diff --git a/examples/async_cable_route_drq/classifier/record_demo.py b/examples/async_cable_route_drq/classifier/record_demo.py new file mode 100644 index 00000000..dbc0900a --- /dev/null +++ b/examples/async_cable_route_drq/classifier/record_demo.py @@ -0,0 +1,170 @@ +import gym +from tqdm import tqdm +import numpy as np +import copy +import pickle as pkl +import datetime +import os + +import franka_env + +from franka_env.envs.relative_env import RelativeFrame +from franka_env.envs.wrappers import ( + GripperCloseEnv, + SpacemouseIntervention, + Quat2EulerWrapper, +) + +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper +import jax + +if __name__ == "__main__": + + ## Initialize Environment + env = 
gym.make("FrankaCableRoute-Vision-v0", save_video=False) + env = GripperCloseEnv(env) + env = SpacemouseIntervention(env) + env = RelativeFrame(env) + env = Quat2EulerWrapper(env) + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + image_keys = [k for k in env.observation_space.keys() if "state" not in k] + + ## Initialize RNG, POS/NEG Containers, and Counters + rng = jax.random.PRNGKey(0) + rng, key = jax.random.split(rng) + obs, _ = env.reset() + + pos_transitions = [] + neg_transitions = [] + transition_batch = [] + + pos_count = 0 + neg_count = 0 + pos_needed = 0 # Define a positive reward max + neg_needed = 20 + + neg_transition_count = 0 + pos_transition_count = 0 + + + ## Define Output file and safety checks + uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + pos_file_name = f"demos/positive_cable_route_{pos_needed}_demos_{uuid}.pkl" + neg_file_name = f"demos/negative_cable_route_{neg_needed}_demos_{uuid}.pkl" + + file_dir = os.path.dirname(os.path.realpath(__file__)) # same dir as this script + pos_file_path = os.path.join(file_dir, pos_file_name) + neg_file_path = os.path.join(file_dir, neg_file_name) + + if not os.path.exists(file_dir): + os.mkdir(file_dir) + if os.path.exists(pos_file_path): + raise FileExistsError(f"{pos_file_name} already exists in {file_dir}") + if os.path.exists(neg_file_path): + raise FileExistsError(f"{neg_file_name} already exists in {file_dir}") + if not os.access(file_dir, os.W_OK): + raise PermissionError(f"No permission to write to {file_dir}") + + ## Record Negative demos + print("Recording negative demos:\n") + while neg_count < neg_needed: + actions = np.zeros((6,)) + next_obs, rew, done, truncated, info = env.step(action=actions) + rew = 0 + if "intervene_action" in info: + actions = info["intervene_action"] + + transition = copy.deepcopy( + dict( + observations = obs, + actions = actions, + next_observations = next_obs, + rewards = rew, + masks = 1.0 - done, + 
dones = done, + ) + ) + + transition_batch.append(transition) + neg_transition_count += 1 + print(f"neg transitions: {neg_transition_count} | demos completed: {neg_count}") + obs = next_obs + + if done: + neg_transitions += transition_batch + neg_count += 1 + + print( + f"{neg_needed - neg_count} negative demos left." + ) + obs, _ = env.reset(pos_reset=False) + neg_transition_count = 0 + transition_batch.clear() + + ## Move to positive position | Asks to confirm every 50 steps + userInput = "n" + while (userInput != "y"): + prep_transition_count = 0 + while prep_transition_count < 50: + actions = np.zeros((6,)) + next_obs, rew, done, truncated, info = env.step(action=actions) + if "intervene_action" in info: + actions = info["intervene_action"] + prep_transition_count += 1 + print(f"Transition count: {prep_transition_count} / 50") + userInput = input("Is the robot in a successful pose? (y/n)") + + env.reset(pos_reset=False) + ## Record Positive demos + print("Recording positive demos:\n") + print("Please put robot in successful pose and press Enter...") + input() # pause, wait for user + while pos_count < pos_needed: + actions = np.zeros((6,)) + next_obs, rew, done, truncated, info = env.step(action=actions) + rew = 0 + if "intervene_action" in info: + actions = info["intervene_action"] + + transition = copy.deepcopy( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=rew, + masks=1.0 - done, + dones=done, + ) + ) + transition_batch.append(transition) + + obs = next_obs + pos_transition_count += 1 + print(f"pos transitions: {pos_transition_count} | demos completed: {pos_count}") + + if done: + pos_transitions += transition_batch + pos_count += 1 + pos_transition_count = 0 + + print( + f"{pos_needed - pos_count} positive demos left." 
+ ) + obs, _ = env.reset(pos_reset=False) + transition_batch.clear() + + with open(pos_file_path, "wb") as f: + pkl.dump(pos_transitions, f) + print( + f"saved {pos_needed} demos and {len(pos_transitions)} transitions to {pos_file_path}" + ) + + with open(neg_file_path, "wb") as f: + pkl.dump(neg_transitions, f) + print( + f"saved {neg_needed} demos and {len(neg_transitions)} transitions to {neg_file_path}" + ) + + env.close() diff --git a/examples/async_cable_route_drq/test_classifier.py b/examples/async_cable_route_drq/classifier/test_classifier.py similarity index 89% rename from examples/async_cable_route_drq/test_classifier.py rename to examples/async_cable_route_drq/classifier/test_classifier.py index 430283bf..2b78a18d 100644 --- a/examples/async_cable_route_drq/test_classifier.py +++ b/examples/async_cable_route_drq/classifier/test_classifier.py @@ -21,7 +21,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "reward_classifier_ckpt_path", None, "Path to reward classifier ckpt." + "reward_classifier_ckpt_path", "./checkpoints", "Path to reward classifier ckpt." 
) @@ -33,6 +33,7 @@ def main(_): env = Quat2EulerWrapper(env) env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + image_keys = [k for k in env.observation_space.keys() if "state" not in k] rng = jax.random.PRNGKey(0) @@ -55,9 +56,11 @@ def main(_): obs = next_obs - if done: + if (i % 3): print("Reward: ", rew) - obs, _ = env.reset() + # if done: + # print("Reward: ", rew) + # obs, _ = env.reset() if __name__ == "__main__": diff --git a/examples/async_cable_route_drq/train_reward_classifier.py b/examples/async_cable_route_drq/classifier/train_reward_classifier.py similarity index 88% rename from examples/async_cable_route_drq/train_reward_classifier.py rename to examples/async_cable_route_drq/classifier/train_reward_classifier.py index 63c88a4b..20ed907b 100644 --- a/examples/async_cable_route_drq/train_reward_classifier.py +++ b/examples/async_cable_route_drq/classifier/train_reward_classifier.py @@ -30,18 +30,20 @@ # Set above env export to prevent OOM errors from memory preallocation os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" -os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".2" +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".5" +os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async" + FLAGS = flags.FLAGS flags.DEFINE_multi_string("positive_demo_paths", None, "paths to positive demos") flags.DEFINE_multi_string("negative_demo_paths", None, "paths to negative demos") -flags.DEFINE_string("classifier_ckpt_path", ".", "Path to classifier checkpoint") +flags.DEFINE_string("classifier_ckpt_path", "./checkpoints", "Path to classifier checkpoint") flags.DEFINE_integer("batch_size", 256, "Batch size for training") flags.DEFINE_integer("num_epochs", 100, "Number of epochs for training") def main(_): - env = gym.make("FrankaRobotiqCableRoute-Vision-v0", fake_env=True, save_video=False) + env = gym.make("FrankaCableRoute-Vision-v0", fake_env=True, save_video=False) env = GripperCloseEnv(env) env = RelativeFrame(env) env = 
Quat2EulerWrapper(env) @@ -68,7 +70,7 @@ def train_reward_classifier(observation_space, action_space): pos_buffer = MemoryEfficientReplayBufferDataStore( observation_space, action_space, - capacity=10000, + capacity=20000, image_keys=image_keys, ) pos_buffer = populate_data_store(pos_buffer, FLAGS.positive_demo_paths) @@ -76,7 +78,7 @@ def train_reward_classifier(observation_space, action_space): neg_buffer = MemoryEfficientReplayBufferDataStore( observation_space, action_space, - capacity=10000, + capacity=20000, image_keys=image_keys, ) neg_buffer = populate_data_store(neg_buffer, FLAGS.negative_demo_paths) @@ -85,14 +87,14 @@ def train_reward_classifier(observation_space, action_space): print(f"success buffer size: {len(pos_buffer)}") pos_iterator = pos_buffer.get_iterator( sample_args={ - "batch_size": FLAGS.batch_size // 2, + "batch_size": FLAGS.batch_size * len(pos_buffer) // len(neg_buffer), "pack_obs_and_next_obs": False, }, device=sharding.replicate(), ) neg_iterator = neg_buffer.get_iterator( sample_args={ - "batch_size": FLAGS.batch_size // 2, + "batch_size": FLAGS.batch_size - (FLAGS.batch_size * len(pos_buffer) // len(neg_buffer)), "pack_obs_and_next_obs": False, }, device=sharding.replicate(), @@ -149,8 +151,8 @@ def loss_fn(params): sample = data_augmentation_fn(key, sample) labels = jnp.concatenate( [ - jnp.ones((FLAGS.batch_size // 2, 1)), - jnp.zeros((FLAGS.batch_size // 2, 1)), + jnp.ones((FLAGS.batch_size * len(pos_buffer) // len(neg_buffer), 1)), + jnp.zeros((FLAGS.batch_size - (FLAGS.batch_size * len(pos_buffer) // len(neg_buffer)), 1)), ], axis=0, ) diff --git a/examples/async_cable_route_drq/record_demo.py b/examples/async_cable_route_drq/record_demo.py index 55202e49..34edcc57 100644 --- a/examples/async_cable_route_drq/record_demo.py +++ b/examples/async_cable_route_drq/record_demo.py @@ -22,6 +22,8 @@ import jax if __name__ == "__main__": + + #Initializes Enviroment env = gym.make("FrankaCableRoute-Vision-v0", save_video=False) env = 
GripperCloseEnv(env) env = SpacemouseIntervention(env) @@ -37,20 +39,21 @@ key=key, sample=env.observation_space.sample(), image_keys=image_keys, - checkpoint_path="/home/undergrad/code/serl_dev/examples/async_cable_route_drq/classifier_ckpt/", + checkpoint_path="/home/student/code/serl/examples/async_cable_route_drq/classifier/checkpoints/", ) env = BinaryRewardClassifierWrapper(env, classifier_func) obs, _ = env.reset() + batch = [] transitions = [] success_count = 0 - success_needed = 70 + success_needed = 20 total_count = 0 pbar = tqdm(total=success_needed) uuid = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") - file_name = f"./bc_demos/cable_route_{success_needed}_demos_{uuid}.pkl" + file_name = f"./demos/cable_route_{success_needed}_demos_{uuid}.pkl" file_dir = os.path.dirname(os.path.realpath(__file__)) # same dir as this script file_path = os.path.join(file_dir, file_name) @@ -77,12 +80,15 @@ dones=done, ) ) - transitions.append(transition) + batch.append(transition) obs = next_obs if done: - success_count += rew + if rew: + transitions += batch + success_count += 1 total_count += 1 + batch.clear() print( f"{rew}\tGot {success_count} successes of {total_count} trials. {success_needed} successes needed." ) diff --git a/examples/async_cable_route_drq/run_actor.sh b/examples/async_cable_route_drq/run_actor.sh index 93a4f8a1..879a0af1 100644 --- a/examples/async_cable_route_drq/run_actor.sh +++ b/examples/async_cable_route_drq/run_actor.sh @@ -1,16 +1,20 @@ +# All export statements end with && \ to chain them together export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +# XLA memory fraction with learner+actor <0.8. Learner needs more. +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads.
Requires cuda11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + python async_drq_randomized.py "$@" \ --actor \ --render \ --env FrankaCableRoute-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd20demos_cable_random_resnet \ - --seed 0 \ + --seed 1 \ --random_steps 0 \ --encoder_type resnet-pretrained \ - --demo_path cable_route_20_demos_2024-01-04_12-10-54.pkl \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_cable_route_drq/10x10_30degs_20demos_rand_cable_096 \ - --reward_classifier_ckpt_path "/home/undergrad/code/serl_dev/examples/async_cable_route_drq/classifier_ckpt/" \ - --eval_checkpoint_step 20000 \ - --eval_n_trajs 20 \ + --reward_classifier_ckpt_path /home/student/code/serl/examples/async_cable_route_drq/classifier/checkpoints/ \ --max_traj_length 100 \ + # --eval_checkpoint_step 20000 \ + # --eval_n_trajs 20 \ + + diff --git a/examples/async_cable_route_drq/run_automated_eval.sh b/examples/async_cable_route_drq/run_automated_eval.sh new file mode 100644 index 00000000..c29d3229 --- /dev/null +++ b/examples/async_cable_route_drq/run_automated_eval.sh @@ -0,0 +1,11 @@ +# NOT CURRENTLY WORKING +export CHECKPOINT_EVAL="/home/student/code/serl/examples/async_cable_route_drq/checkpoints" && \ + +for STEP in 1000 8000 6000 4000 2000 1000; do + echo "Testing $STEP" + bash run_eval.sh \ + --eval_checkpoint_step $STEP \ + --checkpoint_path "$CHECKPOINT_EVAL/fractal_video_01" \ + --eval_n_trajs 50 \ + 2<&1 | grep "success rate" +done diff --git a/examples/async_cable_route_drq/run_eval.sh b/examples/async_cable_route_drq/run_eval.sh new file mode 100644 index 00000000..88f421cc --- /dev/null +++ b/examples/async_cable_route_drq/run_eval.sh @@ -0,0 +1,19 @@ +# All export statements end with && \ to chain them together +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +# XLA memory fraction with learner+actor <0.8. Learner needs more.
+export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads. Requires cuda11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + +export CHECKPOINT_EVAL="/home/student/code/serl/examples/async_cable_route_drq/checkpoints" # && \ +# export STEP=8000 + +python async_drq_randomized.py \ + --actor \ + --render \ + --env FrankaCableRoute-Vision-v0 \ + --reward_classifier_ckpt_path /home/student/code/serl/examples/async_cable_route_drq/classifier/checkpoints/ \ + --eval_checkpoint_step 7000 \ + --eval_n_trajs 50 \ + --checkpoint_path "$CHECKPOINT_EVAL/fractal_video_01" \ + "$@" diff --git a/examples/async_cable_route_drq/run_learner.sh b/examples/async_cable_route_drq/run_learner.sh index a8355d9b..1fe9e6f5 100644 --- a/examples/async_cable_route_drq/run_learner.sh +++ b/examples/async_cable_route_drq/run_learner.sh @@ -1,16 +1,33 @@ +# All export statements end with && \ to chain them together export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +# XLA memory fraction with learner+actor <0.8. Learner needs more. +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads.
Requires cuda11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ + python async_drq_randomized.py "$@" \ --learner \ --env FrankaCableRoute-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd20demos_cable_random_resnet_096 \ - --seed 0 \ + --exp_name="Franka-CableRoute-V2" \ + --seed 1 \ --random_steps 600 \ - --training_starts 200 \ + --training_starts 1 \ --critic_actor_ratio 4 \ --batch_size 256 \ - --eval_period 2000 \ + --max_steps 8001 \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --replay_buffer_capacity 3_600_000 \ + --starting_branch_count 27 \ + --branch_method "constant" \ + --split_method "never" \ + --alpha 0.2 \ + --max_depth 3 \ + --branching_factor 3 \ + --workspace_width 0.3 \ --encoder_type resnet-pretrained \ - --demo_path cable_route_20_demos_2024-01-04_12-10-54.pkl \ - --checkpoint_period 1000 \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_cable_route_drq/10x10_30degs_20demos_rand_cable_096 + --demo_path /home/student/code/serl/examples/async_cable_route_drq/demos/fractal_video_01.pkl \ + --checkpoint_period 500 \ + --checkpoint_path $CHECKPOINT_DIR diff --git a/examples/async_drq_sim/.vscode/launch.json b/examples/async_drq_sim/.vscode/launch.json new file mode 100644 index 00000000..4e3fbf20 --- /dev/null +++ b/examples/async_drq_sim/.vscode/launch.json @@ -0,0 +1,169 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_sim.py", + // Command-line arguments matching run_learner.sh + "args": [ + "--render", + "--env", "PandaReachSparseCube-v0", // Environment to use + "--agent", "drq", // Agent type (drq or sac) + 
"--exp_name", "PandaReachCubeVision-v0_001", // Experiment name for wandb logging + "--max_traj_length", "100", // Max episode length + "--seed", "42", // Random seed for reproducibility + // "--save_model", // Save model checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", "4", // Critic-to-actor update ratio + "--max_steps", "100_000 ", // Maximum training steps + "--replay_buffer_capacity", "200_000", // Replay buffer capacity + "--random_steps", "0", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--steps_per_update", "30", // Number of env steps per update + "--learner", // REQUIRED: Indicates this is a learner instance + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + + // Fractal + // "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + // "--branch_method", "constant", + // "--split_method", "constant", + // "--starting_branch_count", "3", // Initial number of branches + // "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + "--preload_rlds_path", "/media/bison/hdd/data/serl/demos/franka_reach_drq_demo_script/2_demos_session_20250916_100650/PandaReachSparseCube-v0/0.1.0", // Preload RLDS dataset for faster loading + + // Disassociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of
sectors to divide rollouts into for branch count splitting + //"--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_sim.py", + // Command-line arguments matching run_actor.sh + "args": [ + "--render", + "--env", "PandaReachSparseCube-v0", // Environment to use + "--agent", "drq", // Agent type (drq or sac) + "--exp_name", "PandaPickCubeVision-v0_sparse_001", // Experiment name for wandb logging + "--max_traj_length", "100", // Max episode length + "--seed", "42", // Random seed for reproducibility + // "--save_model", // Save model checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--max_steps", "50_000", // Maximum training steps + "--replay_buffer_capacity", "200_000", // Replay buffer capacity + "--random_steps", "300", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--steps_per_update", "30", // Number of env steps per update + "--actor", // REQUIRED: Indicates this is an actor instance + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + + // Fractal + // "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use
fractal symmetry replay buffer + // "--branch_method", "constant", + // "--split_method", "constant", + // "--starting_branch_count", "3", // Initial number of branches + // "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + "--preload_rlds_path", "/media/bison/hdd/data/serl/demos/franka_reach_drq_demo_script/2_demos_session_20250916_100650/PandaReachSparseCube-v0/0.1.0", // Preload RLDS dataset for faster loading + + // Disassociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + //"--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect to learner via
TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 'utd_ratio', note that no "Python: Learner (Fix utd_ratio)" configuration + is defined in this file; see the "utd_ratio" note below for how to supply the missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the FLAGS. Add a --utd_ratio flag (or a dedicated launch configuration that supplies it) to fix. + + DEBUG WORKFLOW: + + 1. Start with the "Python: Learner" configuration + 2. Set breakpoints at key sections you want to monitor + 3. Run the debugger and observe variable values at each step + 4. Once learner is properly running, launch the Actor in a separate instance + 5.
Watch for communication between the two + */ +} \ No newline at end of file diff --git a/examples/async_drq_sim/async_drq_sim.py b/examples/async_drq_sim/async_drq_sim.py index a84bc580..37af0fb5 100644 --- a/examples/async_drq_sim/async_drq_sim.py +++ b/examples/async_drq_sim/async_drq_sim.py @@ -38,9 +38,10 @@ FLAGS = flags.FLAGS -flags.DEFINE_string("env", "PandaPickCubeVision-v0", "Name of environment.") +flags.DEFINE_string("env", "PandaReachSparseCube-v0", "Name of environment.") flags.DEFINE_string("agent", "drq", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_string("run_name", None, "Name of run for wandb logging") flags.DEFINE_integer("max_traj_length", 1000, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") @@ -69,6 +70,21 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "memory_efficient_replay_buffer", "Which replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_float("alpha",None,"alpha value") +flags.DEFINE_enum("disassociated_type", None, ["octahedron", "hourglass"], + "Type of disassociated fracal rollout. 
Octahedron: expand from min to max then contract to min," + + " Hourglass: Contract from max to min then expand to max") +flags.DEFINE_integer("min_branch_count", None, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") + flags.DEFINE_boolean( "debug", False, "Debug mode." ) # debug mode will disable wandb logging @@ -108,7 +124,7 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env == "PandaReachSparseCube-v0": eval_env = SERLObsWrapper(eval_env) eval_env = ChunkingWrapper(eval_env, obs_horizon=1, act_exec_horizon=None) eval_env = RecordEpisodeStatistics(eval_env) @@ -191,9 +207,12 @@ def learner( """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project=FLAGS.exp_name, + name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, + # wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, + # offline=FLAGS.wandb_offline, ) # To track the step in the training loop @@ -325,11 +344,17 @@ def main(_): else: env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) - if FLAGS.env == "PandaPickCubeVision-v0": + if FLAGS.env in {"PandaPickCube-v0", "PandaReachCube-v0", "PandaPickSparseCube-v0", "PandaReachSparseCube-v0", "PandaPickCubeVision-v0", "PandaReachCubeVision-v0", "PandaPickSparseCubeVision-v0", "PandaReachSparseCubeVision-v0"}: + x_obs_idx=np.array([0,4]) + y_obs_idx=np.array([1,5]) + else: + raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") + + if FLAGS.env == "PandaReachSparseCube-v0": env = SERLObsWrapper(env) env = ChunkingWrapper(env, obs_horizon=1, act_exec_horizon=None) + else: + env = gym.wrappers.FlattenObservation(env) image_keys = [key for key in env.observation_space.keys() if key != "state"] @@ 
-345,7 +370,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) if FLAGS.learner: @@ -354,7 +379,21 @@ def main(_): env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="memory_efficient_replay_buffer", + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, image_keys=image_keys, ) @@ -370,14 +409,39 @@ def preload_data_transform(data, metadata) -> Optional[Dict[str, Any]]: # NOTE: Create your own custom data transform function here if you # are loading this via with --preload_rlds_path with tf rlds data # This default does nothing + # See: https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=X1KXM8IGecRO + # https://www.tensorflow.org/guide/data + # https://github.com/google-research/rlds/blob/main/docs/transformations.md + # Batch: rlds.transformations.batch (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=TGT3YfzFOrBm) + # Reverb: rlds.transformations.pattern_map (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_dataset_patterns.ipynb ) + # Nested data set manipulation: 
rlds.transformations.episode_length/.sum_dataset/.final_step/.map_nested_steps + # Concatenation: rlds.transformations.concatenate / .concat_if_terminal (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_examples.ipynb#scrollTo=pWNhxwJzOUJv) + # Stats: rlds.transformations.mean_and_std (https://colab.research.google.com/github/google-research/rlds/blob/main/rlds/examples/rlds_tutorial.ipynb#scrollTo=Z0TITfo_4oZr) + # Truncation: rlds.transformations.truncate_after_condition + # Alignment: rlds.transformations.shift_keys + # Zero Init: rlds.transformations.zeros_from_spec return data demo_buffer = make_replay_buffer( env, capacity=FLAGS.replay_buffer_capacity, - type="memory_efficient_replay_buffer", - image_keys=image_keys, + rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, + image_keys=image_keys, preload_data_transform=preload_data_transform, ) diff --git a/examples/async_drq_sim/automated_tests.sh b/examples/async_drq_sim/automated_tests.sh new file mode 100644 index 00000000..07a58f48 --- /dev/null +++ b/examples/async_drq_sim/automated_tests.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +SEEDS=$1 +# WANDB_OUTPUT_DIR=~/wandb_logs +TEST="async_sac_state_sim.py" +CONDA_ENV="serl" +ENV="PandaReachSparseCube-v0" +MAX_STEPS=1000000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +BATCH_SIZE=128 +EXP_NAME="FIRST-TESTS-$ENV" +REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" 
+PRELOAD_RLDS="/data/data/serl/demos/franka_reach_drq_demo_script/10_demos_session_202500914_213515/PandaReachSparseCube-v0/0.1.0" +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --batch_size $BATCH_SIZE --preload_rlds_path $PRELOAD_RLDS --encoder_type resnet-pretrained" +ARGS="" + +function run_test { + + for seed in $(seq 1 1 $SEEDS) + do + # OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + # PORTS=( $OPEN_PORTS ) + # PORT_NUMBER=${PORTS[0]} + # BROADCAST_PORT=${PORTS[1]} + + # ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_actor.sh --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash run_learner.sh --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for learner to finish + while ! tmux capture-pane -t serl_session:0.2 -p | grep "logout" > /dev/null; + do + sleep 100 + done + echo "Finished!" 
+ done +} + +# BASELINE TESTING +for replay_buffer_capacity in 1000000 +do + ARGS="--run_name baseline --replay_buffer_type memory_efficient_replay_buffer --replay_buffer_capacity $replay_buffer_capacity" + run_test +done + +# CONSTANT TESTING +for starting_branch_count in 1 27 +do + for workspace_width in 0.5 + do + for replay_buffer_capacity in 1000000 + do + ARGS="--run_name constant-$starting_branch_count^1 --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done + done +done + +tmux kill-window -t serl_session:$SEED diff --git a/examples/async_drq_sim/automated_tests_helper.sh b/examples/async_drq_sim/automated_tests_helper.sh new file mode 100644 index 00000000..3c02ecb6 --- /dev/null +++ b/examples/async_drq_sim/automated_tests_helper.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export MUJOCO_GL=egl + +python async_drq_sim.py "$@" \ No newline at end of file diff --git a/examples/async_drq_sim/run_actor.sh b/examples/async_drq_sim/run_actor.sh index 52fcfc41..ece2d80e 100644 --- a/examples/async_drq_sim/run_actor.sh +++ b/examples/async_drq_sim/run_actor.sh @@ -1,10 +1,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ -python async_drq_sim.py "$@" \ - --actor \ - --render \ - --exp_name=serl_dev_drq_sim_test_resnet \ - --seed 0 \ - --random_steps 1000 \ - --encoder_type resnet-pretrained \ - --debug +export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +export MUJOCO_GL=egl && \ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + +python async_drq_sim.py --actor "$@" diff --git a/examples/async_drq_sim/run_learner.sh b/examples/async_drq_sim/run_learner.sh index 39445448..48deba7e 100644 --- a/examples/async_drq_sim/run_learner.sh +++ b/examples/async_drq_sim/run_learner.sh 
@@ -1,11 +1,6 @@ export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ -python async_drq_sim.py "$@" \ - --learner \ - --exp_name=serl_dev_drq_sim_test_resnet \ - --seed 0 \ - --training_starts 1000 \ - --critic_actor_ratio 4 \ - --encoder_type resnet-pretrained \ - # --demo_path franka_lift_cube_image_20_trajs.pkl \ - --debug # wandb is disabled when debug +export XLA_PYTHON_CLIENT_MEM_FRACTION=.4 && \ +export MUJOCO_GL=egl && \ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ + +python async_drq_sim.py --learner "$@" diff --git a/examples/async_drq_sim/tmux_launch_tests.sh b/examples/async_drq_sim/tmux_launch_tests.sh new file mode 100644 index 00000000..aa7a6a57 --- /dev/null +++ b/examples/async_drq_sim/tmux_launch_tests.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +SEEDS=1 +# Create a new tmux session +tmux new-session -d -s serl_session +tmux setw -g remain-on-exit on + +# Split the window horizontally +tmux split-window -v +tmux split-pane -h -t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh $SEEDS" C-m + + +# Attach to the tmux session +tmux attach-session -t serl_session + +# kill the tmux session by running the following command +# tmux kill-session -t serl_session diff --git a/examples/async_drq_sim/tmux_rlpd_launch.sh b/examples/async_drq_sim/tmux_rlpd_launch.sh old mode 100644 new mode 100755 diff --git a/examples/async_pcb_insert_drq/async_drq_randomized.py b/examples/async_pcb_insert_drq/async_drq_randomized.py index 8248379e..5a1770ff 100644 --- a/examples/async_pcb_insert_drq/async_drq_randomized.py +++ b/examples/async_pcb_insert_drq/async_drq_randomized.py @@ -440,7 +440,7 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, 
agent), sharding.replicate() ) if FLAGS.learner: diff --git a/examples/async_peg_insert_drq/.vscode/launch.json b/examples/async_peg_insert_drq/.vscode/launch.json new file mode 100644 index 00000000..cc83bf63 --- /dev/null +++ b/examples/async_peg_insert_drq/.vscode/launch.json @@ -0,0 +1,119 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_randomized.py", + // Command-line arguments matching run_learner.sh + "args": [ + "--learner", // REQUIRED: Indicates this is a learner instance + "--render", + "--env", "FrankaPegInsert-Vision-v0" , // Environment to use + "--exp_name", "PegInsert-fractal_27_width_0.1_30_demos", // Experiment name for wandb logging + "--seed", "3", // Random seed for reproducibility + "--save_model", // Save model checkpoints + "--batch_size", "256", // Training batch size + "--critic_actor_ratio", "4", // Critic-to-actor update ratio + "--max_steps", "100_000 ", // Maximum training steps + "--replay_buffer_capacity", "3_600_000", // Replay buffer capacity + "--random_steps", "1_000", // Number of random steps at beginning + "--training_starts", "300", // Start training after buffer has this many samples + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + "--demo_path", "peg_insert_30_demos_2025-10-26_20-15-48.pkl", + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer", // Use fractal symmetry replay buffer + "--branch_method", "constant", + "--starting_branch_count", "27", // Start with 27 branches + "--workspace_width", "0.1", + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't 
preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_drq_randomized.py", + // Command-line arguments matching run_actor.sh + "args": [ + "--actor", // REQUIRED: Indicates this is an actor instance + "--render", + "--env", "FrankaPegInsert-Vision-v0" , // Environment to use + "--exp_name", "PegInsert-fractal_27_width_0.1_30_demos", // Experiment name for wandb logging + "--seed", "3", // Random seed for reproducibility + "--random_steps", "0", // Number of random steps at beginning + "--training_starts", "200", // Start training after buffer has this many samples + "--encoder_type", "resnet-pretrained", // Use pixel-based observations + "--demo_path", "peg_insert_30_demos_2025-10-26_20-15-48.pkl" + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".3" // Limit JAX memory usage to 30% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect
to learner via TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 'utd_ratio', use the "Python: Learner (Fix utd_ratio)" configuration + which adds this missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the FLAGS. Use the special configuration or add a --utd_ratio flag to fix. + + DEBUG WORKFLOW: + + 1. Start with "Python: Learner (Fix utd_ratio)" configuration + 2. Set breakpoints at key sections you want to monitor + 3. Run the debugger and observe variable values at each step + 4. Once learner is properly running, launch the Actor in a separate instance + 5. 
Watch for communication between the two + */ +} \ No newline at end of file diff --git a/examples/async_peg_insert_drq/async_drq_randomized.py b/examples/async_peg_insert_drq/async_drq_randomized.py index 4fd76f08..cefea7ef 100644 --- a/examples/async_peg_insert_drq/async_drq_randomized.py +++ b/examples/async_peg_insert_drq/async_drq_randomized.py @@ -25,8 +25,9 @@ make_drq_agent, make_trainer_config, make_wandb_logger, + make_replay_buffer ) -from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore +# from serl_launcher.data.data_store import MemoryEfficientReplayBufferDataStore from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper from franka_env.envs.relative_env import RelativeFrame from franka_env.envs.wrappers import ( @@ -45,10 +46,10 @@ flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") +flags.DEFINE_integer("batch_size", 256, "Batch size.") flags.DEFINE_integer("critic_actor_ratio", 4, "critic to actor update ratio.") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") -flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.") flags.DEFINE_integer("random_steps", 300, "Sample random actions for this many steps.") flags.DEFINE_integer("training_starts", 300, "Training starts after this step.") @@ -68,6 +69,17 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") +# replay buffer flags +flags.DEFINE_string("replay_buffer_type", "memory_efficient_replay_buffer", "Which replay buffer to use") +flags.DEFINE_integer("replay_buffer_capacity", 200000, "Replay buffer capacity.") +flags.DEFINE_integer("branching_factor", None, "Factor by which branch count is changed") +flags.DEFINE_integer("max_depth", None, "Maximum number of splits that may 
occur in one episode") +flags.DEFINE_string("branch_method", "constant", "Method for how many branches to generate") +flags.DEFINE_string("split_method", "never", "Method for when to change number of branches") +flags.DEFINE_float("alpha", 0.2, "Rate of change of max_traj_length") +flags.DEFINE_float("workspace_width", 0.3, "Workspace width in meters") +flags.DEFINE_integer("starting_branch_count", 27, "Initial number of branches") + flags.DEFINE_integer( "eval_checkpoint_step", 0, "evaluate the policy from ckpt at this step" ) @@ -103,6 +115,7 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng): step=FLAGS.eval_checkpoint_step, ) agent = agent.replace(state=ckpt) + env.reset(joint_reset=True) for episode in range(FLAGS.eval_n_trajs): obs, _ = env.reset() @@ -126,7 +139,7 @@ def actor(agent: DrQAgent, data_store, env, sampling_rng): success_counter += reward print(reward) - print(f"{success_counter}/{episode + 1}") + print(f"{success_counter}/{episode + 1} ({success_counter / (episode + 1):.2f})") print(f"success rate: {success_counter / FLAGS.eval_n_trajs}") print(f"average time: {np.mean(time_list)}") @@ -147,7 +160,7 @@ def update_params(params): client.recv_network_callback(update_params) - obs, _ = env.reset() + obs, _ = env.reset(joint_reset=True) done = False # training loop @@ -216,7 +229,7 @@ def learner(rng, agent: DrQAgent, replay_buffer, demo_buffer): """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project=FLAGS.exp_name, description=FLAGS.exp_name or FLAGS.env, debug=FLAGS.debug, ) @@ -324,6 +337,7 @@ def main(_): fake_env=FLAGS.learner, save_video=FLAGS.eval_checkpoint_step, ) + print(env.observation_space) env = GripperCloseEnv(env) if FLAGS.actor: env = SpacemouseIntervention(env) @@ -347,23 +361,60 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: DrQAgent = jax.device_put( - jax.tree_map(jnp.array, agent), 
sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) + ## Set indices to be transformed by fractal class for the serl_robot_infra/robot_env/envs/franka_env + # Note that observation_space[state] will be sorted and set as an ordered dict by SerlObservationWrapper + # gripper_pose:0 + # tcp_force.x: 1 + # tcp_force.y: 2 + # tcp_force.z: 3 + # tcp_pose.x: 4 <-- rel_frame.x points to base.+y + # tcp_pose.y: 5 <-- rel_frame.y points to base.+x + # tcp_pose.z: 6 <-- rel_frame.z points to base.-z + x_obs_idx = np.array([4]) + y_obs_idx = np.array([5]) + if FLAGS.learner: sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) - replay_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, + replay_buffer = make_replay_buffer( + env, capacity=FLAGS.replay_buffer_capacity, + # rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, + split_method=FLAGS.split_method, + alpha=FLAGS.alpha, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + # preload_rlds_path=FLAGS.preload_rlds_path, image_keys=image_keys, ) - demo_buffer = MemoryEfficientReplayBufferDataStore( - env.observation_space, - env.action_space, - capacity=10000, + demo_buffer = make_replay_buffer( + env, + capacity=FLAGS.replay_buffer_capacity, + # rlds_logger_path=FLAGS.log_rlds_path, + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + branching_factor=FLAGS.branching_factor, + max_depth=FLAGS.max_depth, + max_traj_length=FLAGS.max_traj_length, + split_method=FLAGS.split_method, + alpha=FLAGS.alpha, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + #
preload_rlds_path=FLAGS.preload_rlds_path, image_keys=image_keys, ) + print(demo_buffer._size) import pickle as pkl with open(FLAGS.demo_path, "rb") as f: @@ -391,7 +442,7 @@ def main(_): else: raise NotImplementedError("Must be either a learner or an actor") - + return if __name__ == "__main__": app.run(main) diff --git a/examples/async_peg_insert_drq/docs/TestSpaceMouse.py b/examples/async_peg_insert_drq/docs/TestSpaceMouse.py new file mode 100644 index 00000000..0e8b86ce --- /dev/null +++ b/examples/async_peg_insert_drq/docs/TestSpaceMouse.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +""" +SpaceMouse Test Script (Fixed Version) + +This script tests if a 3Dconnexion SpaceMouse is properly connected and accessible. +It's designed to handle different versions of pyspacemouse and potential API inconsistencies. + +Usage: + python FixedTestSpaceMouse.py + +Press Ctrl+C to exit the program. +""" + +import time +import sys +import traceback +import threading + +# Try importing the pyspacemouse library +try: + import pyspacemouse + print("Successfully imported pyspacemouse library") + + # Print the library version if available + try: + version = getattr(pyspacemouse, "__version__", "unknown") + print(f"pyspacemouse version: {version}") + except: + print("Could not determine pyspacemouse version") +except ImportError as e: + print(f"Failed to import pyspacemouse: {e}") + print("Try installing it with: pip install pyspacemouse") + sys.exit(1) +except Exception as e: + print(f"Unexpected error importing pyspacemouse: {e}") + traceback.print_exc() + sys.exit(1) + +def print_device_info(device): + """Print detailed information about the SpaceMouse device.""" + print("\n===== DEVICE INFORMATION =====") + + # Check if device is a string or an object + if isinstance(device, str): + print(f"Device information: {device}") + return + + # Try to access device attributes safely + try: + attrs = [ + ("Manufacturer", "manufacturer_string", "Unknown"), + ("Product", "product_string", 
"Unknown"), + ("Vendor ID", "vendor_id", "Unknown"), + ("Product ID", "product_id", "Unknown"), + ("Serial Number", "serial_number", "Unknown"), + ("Release Number", "release_number", "Unknown"), + ("Interface Number", "interface_number", "Unknown") + ] + + for label, attr, default in attrs: + value = getattr(device, attr, default) + if attr in ["vendor_id", "product_id"] and value != "Unknown": + print(f"{label}: 0x{value:04x}") + else: + print(f"{label}: {value}") + except Exception as e: + print(f"Error accessing device attributes: {e}") + print(f"Raw device data: {device}") + + print("===============================\n") + +def test_device_connection(): + """Test if the SpaceMouse device is connected and accessible.""" + print("Attempting to detect SpaceMouse devices...") + + try: + # List all available devices + try: + devices = pyspacemouse.list_devices() + print(f"list_devices() returned: {devices}") + + if not devices: + print("No SpaceMouse devices found!") + return None + + print(f"Found {len(devices)} devices.") + + # Safely print device information + for i, device in enumerate(devices): + print(f"Device {i+1}:") + if isinstance(device, str): + print(f" {device}") + else: + try: + vendor_id = getattr(device, "vendor_id", "Unknown") + product_id = getattr(device, "product_id", "Unknown") + product_string = getattr(device, "product_string", "Unknown Device") + + if vendor_id != "Unknown" and product_id != "Unknown": + print(f" {product_string} (Vendor ID: 0x{vendor_id:04x}, Product ID: 0x{product_id:04x})") + else: + print(f" {product_string}") + except Exception as e: + print(f" Error printing device {i+1} info: {e}") + print(f" Raw device data: {device}") + + # Use the first device found + return devices[0] + + except AttributeError: + # Alternative approach if list_devices doesn't work as expected + print("list_devices() method not working as expected, trying open() directly...") + if pyspacemouse.open(): + print("Successfully opened a SpaceMouse device 
directly") + return "SpaceMouse Device" + else: + print("Failed to open any SpaceMouse device directly") + return None + + except Exception as e: + print(f"Error detecting devices: {e}") + traceback.print_exc() + return None + +def open_device(device): + """Try to open the SpaceMouse device.""" + print("Attempting to open the SpaceMouse...") + try: + # Check if device is already open + if hasattr(pyspacemouse, "is_open") and pyspacemouse.is_open(): + print("Device is already open") + return True + + # Try to open the device + if isinstance(device, str): + # If device is a string, just try to open any device + if pyspacemouse.open(callback=None): + print("Successfully opened a SpaceMouse device") + return True + else: + # Try to open the specific device + try: + if pyspacemouse.open(callback=None, device=device): + print("Successfully opened the SpaceMouse device") + return True + except TypeError: + # If device parameter isn't supported, try without it + if pyspacemouse.open(callback=None): + print("Successfully opened a SpaceMouse device") + return True + + print("Failed to open the SpaceMouse device") + print("Check if another program is already using it") + return False + + except Exception as e: + print(f"Error opening device: {e}") + traceback.print_exc() + return False + +def safe_read(): + """Safely read from the device, handling potential API differences.""" + try: + state = pyspacemouse.read() + return state + except Exception as e: + print(f"Error reading from device: {e}") + return None + +def monitor_button_state(stop_event): + """Monitor and print button state changes in a separate thread.""" + last_buttons = None + + while not stop_event.is_set(): + try: + state = safe_read() + if state and hasattr(state, "buttons") and state.buttons is not None: + current_buttons = state.buttons + if last_buttons != current_buttons: + buttons_pressed = [i for i, pressed in enumerate(current_buttons) if pressed] + if buttons_pressed: + print(f"Buttons pressed: 
{buttons_pressed}") + last_buttons = current_buttons.copy() if hasattr(current_buttons, "copy") else current_buttons + time.sleep(0.05) # Small sleep to prevent 100% CPU usage + except Exception as e: + print(f"Error in button monitoring thread: {e}") + time.sleep(1) # Wait a bit longer on error + +def main(): + """Main function to test the SpaceMouse.""" + print("SpaceMouse Test Script (Fixed Version)") + print("-------------------------------------") + + # First, check if we can detect any devices + device = test_device_connection() + if not device: + print("\nNo SpaceMouse detected. Please ensure:") + print("1. The device is connected to your computer") + print("2. You have installed libhidapi (sudo apt-get install libhidapi-dev libhidapi-hidraw0)") + print("3. You have proper permissions (sudo usermod -a -G plugdev $USER)") + print("4. You've created proper udev rules if needed") + sys.exit(1) + + # Print detailed device information + print_device_info(device) + + # Try to open the device + if not open_device(device): + sys.exit(1) + + # Test basic functionality + print("\nTesting basic device functionality...") + test_state = safe_read() + if test_state: + print("Successfully read initial state from device:") + try: + attrs = ["x", "y", "z", "roll", "pitch", "yaw", "buttons"] + for attr in attrs: + if hasattr(test_state, attr): + value = getattr(test_state, attr) + print(f" {attr}: {value}") + else: + print(f" {attr}: Not available") + except Exception as e: + print(f"Error reading attributes: {e}") + print(f"Raw state: {test_state}") + else: + print("Could not read initial state from device!") + print("The device might be connected but not functioning correctly.") + sys.exit(1) + + # Create a thread to monitor button presses + stop_event = threading.Event() + button_thread = threading.Thread(target=monitor_button_state, args=(stop_event,)) + button_thread.daemon = True + button_thread.start() + + # Monitor device movement + print("\nMove your SpaceMouse to 
see the values") + print("Press Ctrl+C to exit") + print("\nReading SpaceMouse state...") + + try: + last_print_time = time.time() + while True: + # Read the current state + state = safe_read() + + if state: + current_time = time.time() + + # Check if any movement data is available + x = getattr(state, "x", 0) or 0 + y = getattr(state, "y", 0) or 0 + z = getattr(state, "z", 0) or 0 + roll = getattr(state, "roll", 0) or 0 + pitch = getattr(state, "pitch", 0) or 0 + yaw = getattr(state, "yaw", 0) or 0 + + if any(abs(val) > 0.01 for val in [x, y, z, roll, pitch, yaw]): + # Print at most 10 times per second + if current_time - last_print_time >= 0.1: + print(f"\rPosition: X:{x:6.2f} Y:{y:6.2f} Z:{z:6.2f} | " + f"Rotation: Roll:{roll:6.2f} Pitch:{pitch:6.2f} Yaw:{yaw:6.2f}", + end="", flush=True) + last_print_time = current_time + + time.sleep(0.01) # Small sleep to prevent 100% CPU usage + + except KeyboardInterrupt: + print("\n\nExiting...") + except Exception as e: + print(f"\n\nError reading from device: {e}") + traceback.print_exc() + finally: + # Clean up + stop_event.set() + button_thread.join(timeout=1.0) + try: + pyspacemouse.close() + print("\nClosed SpaceMouse connection") + except Exception as e: + print(f"\nError closing device: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/async_peg_insert_drq/docs/listGym_env.py b/examples/async_peg_insert_drq/docs/listGym_env.py new file mode 100644 index 00000000..c6179294 --- /dev/null +++ b/examples/async_peg_insert_drq/docs/listGym_env.py @@ -0,0 +1,5 @@ +import gymnasium as gym + +# Correct method to list environments +for env in gym.envs.registry.keys(): + print(env) diff --git a/examples/async_peg_insert_drq/record_demo.py b/examples/async_peg_insert_drq/record_demo.py index 5a0fd02b..7f3da22f 100644 --- a/examples/async_peg_insert_drq/record_demo.py +++ b/examples/async_peg_insert_drq/record_demo.py @@ -30,6 +30,7 @@ obs, _ = env.reset() + batch = [] transitions = 
[] success_count = 0 success_needed = 20 @@ -64,19 +65,23 @@ dones=done, ) ) - transitions.append(transition) - + batch.append(transition) obs = next_obs if done: - success_count += rew + if rew: + transitions += batch + success_count += 1 total_count += 1 + + batch.clear() print( f"{rew}\tGot {success_count} successes of {total_count} trials. {success_needed} successes needed." ) pbar.update(rew) obs, _ = env.reset() + with open(file_path, "wb") as f: pkl.dump(transitions, f) print(f"saved {success_needed} demos to {file_path}") diff --git a/examples/async_peg_insert_drq/run_actor.sh b/examples/async_peg_insert_drq/run_actor.sh index a251e756..89eb0a63 100644 --- a/examples/async_peg_insert_drq/run_actor.sh +++ b/examples/async_peg_insert_drq/run_actor.sh @@ -1,12 +1,30 @@ +# All export statements end with && \ to chain them together export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.1 && \ +# XLA memory fraction with learner+action <0.8. Learner needs more. +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads. Requires cuda11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ +export CHECKPOINT_EVAL="/home/student/code/serl/examples/async_peg_insert_drq/checkpoints/checkpoints-07-14-2025-23-15-59" && \ + + +# # Create checkpoint directory if it doesn't exist +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!"
>&2 +# exit 1 +# } +# fi + python async_drq_randomized.py "$@" \ --actor \ --render \ - --env FrankaPegInsert-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet \ - --seed 0 \ + --env "FrankaPegInsert-Vision-v0" \ --random_steps 0 \ + --seed 5 \ --training_starts 200 \ + --save_model \ + --max_steps 15000 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \ diff --git a/examples/async_peg_insert_drq/run_automated_eval.sh b/examples/async_peg_insert_drq/run_automated_eval.sh new file mode 100644 index 00000000..e14cdc3f --- /dev/null +++ b/examples/async_peg_insert_drq/run_automated_eval.sh @@ -0,0 +1,10 @@ +export CHECKPOINT_EVAL="/home/student/code/serl/examples/async_peg_insert_drq/checkpoints" && \ + + +for STEP in 3500 3000 2500 2000 1500 1000 500; do + echo "Testing $STEP" + bash run_eval.sh \ + --eval_checkpoint_step $STEP \ + --checkpoint_path "$CHECKPOINT_EVAL/fractal27_05" \ + 2>&1 | grep "success rate" +done diff --git a/examples/async_peg_insert_drq/run_eval.sh b/examples/async_peg_insert_drq/run_eval.sh new file mode 100644 index 00000000..4511e4b3 --- /dev/null +++ b/examples/async_peg_insert_drq/run_eval.sh @@ -0,0 +1,29 @@ +# All export statements end with && \ to chain them together +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +# XLA memory fraction with learner+action <0.8. Learner needs more. +export XLA_PYTHON_CLIENT_MEM_FRACTION=.3 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads. Requires cuda11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ +export CHECKPOINT_EVAL="/home/student/code/serl/examples/async_peg_insert_drq/checkpoints" && \ + + +## Create checkpoint directory if it doesn't exist +#if [ !
-d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" >&2 +# exit 1 +# } +#fi + +python async_drq_randomized.py \ + --actor \ + --render \ + --env "FrankaPegInsert-Vision-v0" \ + --eval_checkpoint_step 3500 \ + --eval_n_trajs 50000000 \ + --checkpoint_path "$CHECKPOINT_EVAL/checkpoints-04-23-2026-19-01-43" \ + "$@" diff --git a/examples/async_peg_insert_drq/run_learner.sh b/examples/async_peg_insert_drq/run_learner.sh index c2823a19..808868c3 100644 --- a/examples/async_peg_insert_drq/run_learner.sh +++ b/examples/async_peg_insert_drq/run_learner.sh @@ -1,16 +1,45 @@ +# All export statements end with && \ to chain them together export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.2 && \ +# XLA memory fraction with learner+action <0.8. Learner needs more. +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +# Use malloc_async to reduce fragmentation, overlap memory allocation with compute, lower stalls and improve workloads. Requires cuda11.2+ +export TF_GPU_ALLOCATOR=cuda_malloc_async && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export ENV_NAME="FrankaPegInsert-Vision-v0" && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="$SCRIPT_DIR/checkpoints/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist. Used to save the learned policy. +if [ ! -d "$CHECKPOINT_DIR" ]; then + echo "Creating checkpoint directory: $CHECKPOINT_DIR" + mkdir -p "$CHECKPOINT_DIR" || { + echo "Failed to create checkpoint directory!"
>&2 + exit 1 + } +fi + python async_drq_randomized.py "$@" \ --learner \ - --env FrankaPegInsert-Vision-v0 \ - --exp_name=serl_dev_drq_rlpd10demos_peg_insert_random_resnet_097 \ - --seed 0 \ - --random_steps 1000 \ - --training_starts 200 \ - --critic_actor_ratio 4 \ + --env $ENV_NAME \ + --exp_name="PegInsert-march_2026" \ + --seed 5 \ + --random_steps 1_000 \ + --training_starts 1 \ + --critic_actor_ratio 8 \ --batch_size 256 \ - --eval_period 2000 \ + --max_steps 3501 \ + --replay_buffer_type "fractal_symmetry_replay_buffer" \ + --save_model \ + --replay_buffer_capacity 3_600_000 \ + --starting_branch_count 27 \ + --branch_method "constant" \ + --split_method "never" \ + --alpha 0.2 \ + --max_depth 3 \ + --branching_factor 3 \ + --workspace_width 0.3 \ --encoder_type resnet-pretrained \ - --demo_path peg_insert_20_demos_2023-12-25_16-13-25.pkl \ - --checkpoint_period 1000 \ - --checkpoint_path /home/undergrad/code/serl_dev/examples/async_peg_insert_drq/5x5_20degs_20demos_rand_peg_insert_097 + --demo_path peg_insert_20_demos_2026-04-23_17-52-59.pkl \ + --checkpoint_period 500 \ + --checkpoint_path "$CHECKPOINT_DIR" \ + --debug \ diff --git a/examples/async_sac_state_sim/.vscode/launch.json b/examples/async_sac_state_sim/.vscode/launch.json new file mode 100644 index 00000000..a1768ccb --- /dev/null +++ b/examples/async_sac_state_sim/.vscode/launch.json @@ -0,0 +1,159 @@ +{ + // VSCode debug configuration version + "version": "0.2.0", + "configurations": [ + + + + // LEARNER + { + // Configuration for the RL agent learner component + "name": "Python: Learner", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_sac_state_sim.py", + // Command-line arguments matching run_learner.sh + "args": [ + + "--learner", // REQUIRED: Indicates this is a learner instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name", "PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb 
logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1_000", // Number of random steps at beginning + "--max_steps", "50_000 ", // Maximum training steps + "--training_starts", "1_000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + "--branch_method", "constant", + "--split_method", "constant", + "--starting_branch_count", "3", // Start with 3 branches + "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + + // Disassociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide rollouts into for branch count splitting + "--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_learner.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + // Additional helpful debugging options + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for
better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + }, + // ACTOR + { + // Configuration for the RL agent actor component + "name": "Python: Actor", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/async_sac_state_sim.py", + // Command-line arguments matching run_actor.sh + "args": [ + "--actor", // REQUIRED: Indicates this is an actor instance + "--env", "PandaReachCube-v0", // Environment to use + "--exp_name", "PandaReachCube-v0_state_sim_3d_parallel_const_3^1_batch_256_replay_1M_utd_8", // Experiment name for wandb logging + "--seed", "0", // Random seed for reproducibility + "--random_steps", "1_000", // Number of random steps at beginning + "--max_steps", "500_000 ", // Maximum training steps + "--training_starts", "1_000", // Start training after buffer has this many samples + "--critic_actor_ratio", "8", // Critic-to-actor update ratio + "--batch_size", "256", // Training batch size + "--replay_buffer_capacity", "1_000_000", // Replay buffer capacity + + // Fractal + "--replay_buffer_type", "fractal_symmetry_replay_buffer_parallel", // Use fractal symmetry replay buffer + "--branch_method", "constant", + "--split_method", "constant", + "--starting_branch_count", "3", // Start with 27 branches + "--workspace_width", "0.5", + + // Demonstration data loading options + // "--load_demos", // Load demo dataset + // "--demo_dir", "/data/data/serl/demos", + // "--file_name", "data_franka_reach_random_5_2.npz", // Name of the demo file to load + + // Dissasociated Fractals + // "--branch_method", "disassociated", // branch method type + // "--split_method", "disassociated", // split method type + // "--disassociated_type", "octahedron", // Type of disassociated test to perform + // "--min_branch_count", "3", // Minimum branch count for disassociated testing + // "--max_branch_count", "9", // Maximum branch count for disassociated testing + // "--num_depth_sectors", "13", // Desired number of sectors to divide 
rollouts into for branch count splitting + "--alpha", "1" // alpha + + ], + "console": "integratedTerminal", // Use integrated terminal to see output + "justMyCode": false, // Allow stepping into libraries + // Environment variables from run_actor.sh + "env": { + "XLA_PYTHON_CLIENT_PREALLOCATE": "false", // Don't preallocate JAX/XLA memory + "XLA_PYTHON_CLIENT_MEM_FRACTION": ".5" // Limit JAX memory usage to 50% + }, + "showReturnValue": true, // Show function return values + "purpose": ["debug-in-terminal"], // Run in terminal for better output + "cwd": "${workspaceFolder}" // Set working directory explicitly + } + ], + // Compound configurations to launch multiple configurations together + "compounds": [ + { + // Launch both learner and actor at the same time + "name": "Learner + Actor", + "configurations": ["Python: Learner", "Python: Actor"], + // Note: Actor will connect to learner via TrainerClient, assuming localhost IP + // Both will run in separate debug sessions with independent controls + } + ], + + /* + DEBUGGING TIPS: + + - IMPORTANT: You must select either Learner or Actor configuration when debugging + (the NotImplementedError occurs if neither --learner nor --actor flag is specified) + + - If you get an error about 'utd_ratio', use the "Python: Learner (Fix utd_ratio)" configuration + which adds this missing parameter + + - Set breakpoints in learner() or actor() functions to step through the main training loops + + - Key places to set breakpoints: + * In actor(): near the action sampling logic (step < FLAGS.random_steps) + * In learner(): where agent.update_high_utd() is called + * Server/client communication points (client.update(), server.publish_network()) + + - For memory issues: Watch replay buffer size growth with breakpoints in data_store.insert() + + - JAX issues: Set breakpoints after jax.device_put() calls to ensure proper device placement + + - The "utd_ratio" parameter seems to be used in the learner function but isn't defined + in the 
FLAGS. Use the special configuration or add a --utd_ratio flag to fix. + + DEBUG WORKFLOW: + + 1. Start with "Python: Learner (Fix utd_ratio)" configuration + 2. Set breakpoints at key sections you want to monitor + 3. Run the debugger and observe variable values at each step + 4. Once learner is properly running, launch the Actor in a separate instance + 5. Watch for communication between the two + */ +} \ No newline at end of file diff --git a/examples/async_sac_state_sim/async_sac_state_sim.py b/examples/async_sac_state_sim/async_sac_state_sim.py index 90a1acf8..a5ad5657 100644 --- a/examples/async_sac_state_sim/async_sac_state_sim.py +++ b/examples/async_sac_state_sim/async_sac_state_sim.py @@ -27,16 +27,23 @@ import franka_sim +# from demos.demoHandling import DemoHandling + FLAGS = flags.FLAGS flags.DEFINE_string("env", "HalfCheetah-v4", "Name of environment.") flags.DEFINE_string("agent", "sac", "Name of agent.") flags.DEFINE_string("exp_name", None, "Name of the experiment for wandb logging.") +flags.DEFINE_string("run_name", None, "Name of run for wandb logging") flags.DEFINE_integer("max_traj_length", 100, "Maximum length of trajectory.") flags.DEFINE_integer("seed", 42, "Random seed.") flags.DEFINE_bool("save_model", False, "Whether to save model.") flags.DEFINE_integer("batch_size", 256, "Batch size.") flags.DEFINE_integer("critic_actor_ratio", 8, "critic to actor update ratio.") +flags.DEFINE_integer("port_number", 5488, "Port for server") +flags.DEFINE_integer("broadcast_port", 5489, "Port for server") +flags.DEFINE_boolean("wandb_offline", False, "Save locally to be synced with 'wandb sync ") +flags.DEFINE_string("wandb_output_dir", None, "Where to save local wandb files") flags.DEFINE_integer("max_steps", 1000000, "Maximum number of training steps.") flags.DEFINE_integer("replay_buffer_capacity", 1000000, "Replay buffer capacity.") @@ -49,7 +56,7 @@ flags.DEFINE_integer("eval_period", 2000, "Evaluation period.") flags.DEFINE_integer("eval_n_trajs", 
5, "Number of trajectories for evaluation.") -# flag to indicate if this is a leaner or a actor +# flag to indicate if this is a learner or a actor flags.DEFINE_boolean("learner", False, "Is this a learner or a trainer.") flags.DEFINE_boolean("actor", False, "Is this a learner or a trainer.") flags.DEFINE_boolean("render", False, "Render the environment.") @@ -57,14 +64,34 @@ flags.DEFINE_integer("checkpoint_period", 0, "Period to save checkpoints.") flags.DEFINE_string("checkpoint_path", None, "Path to save checkpoints.") -flags.DEFINE_boolean( - "debug", False, "Debug mode." -) # debug mode will disable wandb logging - +# flags for replay buffer +flags.DEFINE_string("replay_buffer_type", "replay_buffer", "Which replay buffer to use") +flags.DEFINE_string("branch_method", None, "Method for how many branches to generate") +flags.DEFINE_string("split_method", None, "Method for when to change number of branches generated") +flags.DEFINE_float("workspace_width", 0.5, "Workspace width in meters") +flags.DEFINE_integer("max_depth",None,"Maximum layers of depth") +flags.DEFINE_integer("starting_branch_count", None, "Initial number of branches") +flags.DEFINE_integer("branching_factor", None, "Rate of change of branches per dimension (x,y)") # For fractal_branch and fractal_contraction +flags.DEFINE_float("alpha",None,"alpha value") +flags.DEFINE_enum("disassociated_type", None, ["octahedron", "hourglass"], + "Type of disassociated fracal rollout. 
Octahedron: expand from min to max then contract to min," + + " Hourglass: Contract from max to min then expand to max") +flags.DEFINE_integer("min_branch_count", None, "Minimum number of branches for disassociated fractal rollout") +flags.DEFINE_integer("max_branch_count", None, "Maximum number of branches for disassociated fractal rollout") + +# Debug +flags.DEFINE_boolean("debug", False, "Debug mode.") # debug mode will disable wandb logging + +# Logging flags.DEFINE_string("log_rlds_path", None, "Path to save RLDS logs.") flags.DEFINE_string("preload_rlds_path", None, "Path to preload RLDS data.") +# Load demonstation data +flags.DEFINE_boolean("load_demos", False, "Whether to load demo dataset.") +flags.DEFINE_string("demo_dir", "/data/data/serl/demos", "Path to demo dataset.") +flags.DEFINE_string("file_name", "data_franka_reach_random_20.npz", "Name of the demo file to load.") + def print_green(x): return print("\033[92m {}\033[00m".format(x)) @@ -72,14 +99,14 @@ def print_green(x): ############################################################################## -def actor(agent: SACAgent, data_store, env, sampling_rng): +def actor(agent: SACAgent, data_store, env, sampling_rng, demos_handler=None): """ This is the actor loop, which runs when "--actor" is set to True. """ client = TrainerClient( "actor_env", FLAGS.ip, - make_trainer_config(), + make_trainer_config(port_number=FLAGS.port_number, broadcast_port=FLAGS.broadcast_port), data_store, wait_for_server=True, ) @@ -92,8 +119,8 @@ def update_params(params): client.recv_network_callback(update_params) eval_env = gym.make(FLAGS.env) - if FLAGS.env == "PandaPickCube-v0": - eval_env = gym.wrappers.FlattenObservation(eval_env) + #if FLAGS.env == "PandaPickCube-v0": + eval_env = gym.wrappers.FlattenObservation(eval_env) ## Note!! 
eval_env = RecordEpisodeStatistics(eval_env) obs, _ = env.reset() @@ -102,6 +129,16 @@ def update_params(params): # training loop timer = Timer() running_return = 0.0 + + # Load demos: handler.run will insert all transition demo data into the data store. + if FLAGS.load_demos: + with timer.context("sample and step into env with loaded demos"): + + # Insert complete demonstration into the data store + print(f"Inserting {demos_handler.data['transition_ctr']} transitions into the data store.") + demos_handler.insert_data_to_buffer(data_store) + FLAGS.random_steps = 0 # Set random steps to 0 since we have demo data + # For subsequent steps, sample actions from the agent for step in tqdm.tqdm(range(FLAGS.max_steps), dynamic_ncols=True): timer.tick("total") @@ -117,30 +154,30 @@ def update_params(params): ) actions = np.asarray(jax.device_get(actions)) - # Step environment - with timer.context("step_env"): + # Step environment + with timer.context("step_env"): - next_obs, reward, done, truncated, info = env.step(actions) - next_obs = np.asarray(next_obs, dtype=np.float32) - reward = np.asarray(reward, dtype=np.float32) + next_obs, reward, done, truncated, info = env.step(actions) + next_obs = np.asarray(next_obs, dtype=np.float32) + reward = np.asarray(reward, dtype=np.float32) - running_return += reward + running_return += reward - data_store.insert( - dict( - observations=obs, - actions=actions, - next_observations=next_obs, - rewards=reward, - masks=1.0 - done, - dones=done or truncated, + data_store.insert( + dict( + observations=obs, + actions=actions, + next_observations=next_obs, + rewards=reward, + masks=1.0 - done, + dones=done or truncated, + ) ) - ) - obs = next_obs - if done or truncated: - running_return = 0.0 - obs, _ = env.reset() + obs = next_obs + if done or truncated: + running_return = 0.0 + obs, _ = env.reset() if FLAGS.render: env.render() @@ -167,21 +204,23 @@ def update_params(params): 
############################################################################## - + def learner(rng, agent: SACAgent, replay_buffer, replay_iterator): """ The learner loop, which runs when "--learner" is set to True. """ # set up wandb and logging wandb_logger = make_wandb_logger( - project="serl_dev", + project=FLAGS.exp_name, + name=FLAGS.run_name, description=FLAGS.exp_name or FLAGS.env, + # wandb_output_dir=FLAGS.wandb_output_dir, debug=FLAGS.debug, + # offline=FLAGS.wandb_offline, ) # To track the step in the training loop update_steps = 0 - def stats_callback(type: str, payload: dict) -> dict: """Callback for when server receives stats request.""" assert type == "send-stats", f"Invalid request type: {type}" @@ -190,7 +229,7 @@ def stats_callback(type: str, payload: dict) -> dict: return {} # not expecting a response # Create server - server = TrainerServer(make_trainer_config(), request_callback=stats_callback) + server = TrainerServer(make_trainer_config(port_number=FLAGS.port_number, broadcast_port=FLAGS.broadcast_port), request_callback=stats_callback) server.register_data_store("actor_env", replay_buffer) server.start(threaded=True) @@ -228,7 +267,7 @@ def stats_callback(type: str, payload: dict) -> dict: batch = next(replay_iterator) with timer.context("train"): - agent, update_info = agent.update_high_utd(batch, utd_ratio=FLAGS.utd_ratio) + agent, update_info = agent.update_high_utd(batch, utd_ratio=FLAGS.critic_actor_ratio) agent = jax.block_until_ready(agent) # publish the updated network @@ -265,9 +304,14 @@ def main(_): env = gym.make(FLAGS.env, render_mode="human") else: env = gym.make(FLAGS.env) - - if FLAGS.env == "PandaPickCube-v0": - env = gym.wrappers.FlattenObservation(env) + + if FLAGS.env in {"PandaPickCube-v0", "PandaReachCube-v0", "PandaPickSparseCube-v0", "PandaReachSparseCube-v0"}: + x_obs_idx=np.array([0,4]) + y_obs_idx=np.array([1,5]) + else: + raise NotImplementedError(f"Unknown observation layout for {FLAGS.env}") + + env = 
gym.wrappers.FlattenObservation(env) rng, sampling_rng = jax.random.split(rng) agent: SACAgent = make_sac_agent( @@ -279,17 +323,56 @@ def main(_): # replicate agent across devices # need the jnp.array to avoid a bug where device_put doesn't recognize primitives agent: SACAgent = jax.device_put( - jax.tree_map(jnp.array, agent), sharding.replicate() + jax.tree.map(jnp.array, agent), sharding.replicate() ) + # Demo Data + if FLAGS.load_demos: + print_green("Setting demo parameters") + # Create a handler for the demo data + demos_handler = DemoHandling( + demo_dir=FLAGS.demo_dir, + file_name=FLAGS.file_name, + ) + + # 1. Modify actor data_store size + # Extract number of demo transitions + demo_transitions = demos_handler.get_num_transitions() + + if demo_transitions > 2000: + qds_size = demo_transitions + 1000 # Increment the queue size on the actor + else: + qds_size = 2000 # the original queue size on the actor + + # 2. Modify training starts (since we have good data) + FLAGS.training_starts = 1 + + else: + demos_handler = None + qds_size = 2000 # the original queue size on the actor + + if FLAGS.learner: sampling_rng = jax.device_put(sampling_rng, device=sharding.replicate()) replay_buffer = make_replay_buffer( env, capacity=FLAGS.replay_buffer_capacity, rlds_logger_path=FLAGS.log_rlds_path, - type="replay_buffer", + type=FLAGS.replay_buffer_type, + branch_method=FLAGS.branch_method, + split_method=FLAGS.split_method, + branching_factor=FLAGS.branching_factor, + starting_branch_count=FLAGS.starting_branch_count, + workspace_width=FLAGS.workspace_width, + max_traj_length=FLAGS.max_traj_length, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, preload_rlds_path=FLAGS.preload_rlds_path, + max_depth=FLAGS.max_depth, + alpha=FLAGS.alpha, + disassociated_type=FLAGS.disassociated_type, + min_branch_count=FLAGS.min_branch_count, + max_branch_count=FLAGS.max_branch_count, ) replay_iterator = replay_buffer.get_iterator( sample_args={ @@ -308,11 +391,20 @@ def main(_): elif 
FLAGS.actor: sampling_rng = jax.device_put(sampling_rng, sharding.replicate()) - data_store = QueuedDataStore(2000) # the queue size on the actor + + if FLAGS.load_demos: + print_green("loading demo data") + + # Create a data store for the actor + data_store = QueuedDataStore(qds_size) # the queue size on the actor + else: + print_green("no demo data, using empty data store") + # Create a data store for the actor + data_store = QueuedDataStore(2000) # the queue size on the actor # actor loop print_green("starting actor loop") - actor(agent, data_store, env, sampling_rng) + actor(agent, data_store, env, sampling_rng, demos_handler) else: raise NotImplementedError("Must be either a learner or an actor") diff --git a/examples/async_sac_state_sim/automated_tests.sh b/examples/async_sac_state_sim/automated_tests.sh new file mode 100644 index 00000000..185ae9bc --- /dev/null +++ b/examples/async_sac_state_sim/automated_tests.sh @@ -0,0 +1,119 @@ +#!/bin/bash + +SEEDS=$1 +WANDB_OUTPUT_DIR=~/wandb_logs +TEST="async_sac_state_sim.py" +CONDA_ENV="serl" +ENV="PandaReachCube-v0" +MAX_STEPS=10000 +TRAINING_STARTS=1000 +RANDOM_STEPS=1000 +BATCH_SIZE=256 +CRITIC_ACTOR_RATIO=8 +EXP_NAME="DISASSOCIATED_FRACTAL_TEST" +REPLAY_BUFFER_TYPE="fractal_symmetry_replay_buffer" + +BASE_ARGS="--env $ENV --exp_name $EXP_NAME --wandb_output_dir $WANDB_OUTPUT_DIR --training_starts $TRAINING_STARTS --random_steps $RANDOM_STEPS --critic_actor_ratio $CRITIC_ACTOR_RATIO --batch_size $BATCH_SIZE" +ARGS="" + +function run_test { + + for seed in $(seq 1 1 $SEEDS) + do + # OPEN_PORTS=$( comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 2 ) + # PORTS=( $OPEN_PORTS ) + # PORT_NUMBER=${PORTS[0]} + # BROADCAST_PORT=${PORTS[1]} + + # ARGS+=" --port_number $PORT_NUMBER --broadcast_port $BROADCAST_PORT" + + echo "Running constant with args: $ARGS" + tmux respawn-pane -k -t serl_session:0.1 + tmux respawn-pane -k -t serl_session:0.2 + tmux send-keys -t 
serl_session:0.1 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --actor --max_steps 2000000000 --seed $seed $BASE_ARGS $ARGS" C-m + tmux send-keys -t serl_session:0.2 "conda activate $CONDA_ENV && bash automated_tests_helper.sh --learner --max_steps $MAX_STEPS --seed $seed $BASE_ARGS $ARGS" C-m "exit" C-m + + # Wait for learner to finish + while ! tmux capture-pane -t serl_session:0.2 -p | grep "Pane is dead" > /dev/null; + do + sleep 1 + done + echo "Finished!" + done +} + +# BASELINE TESTING + +for replay_buffer_capacity in 5000 +do + ARGS="--run_name baseline --replay_buffer_type replay_buffer --replay_buffer_capacity $replay_buffer_capacity" + run_test +done + +# CONSTANT TESTING + +for starting_branch_count in 27 +do + for workspace_width in 0.5 + do + for replay_buffer_capacity in 5000 + do + ARGS="--run_name constant-$starting_branch_count*$starting_branch_count --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'constant' --starting_branch_count $starting_branch_count" + run_test + done + done +done + + +# FRACTAL TESTING +for replay_buffer_capacity in 5000 +do + for workspace_width in 0.5 + do + for alpha in 0.1 0.5 0.9 + do + for branching_factor in 3 9 + do + for max_depth in 2 4 + do + # Fractal Expansion + ARGS="--run_name fractal_expansion-$branching_factor^$max_depth --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'fractal' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + run_test + + # Fractal Contraction + ARGS="--run_name fractal_contraction-$branching_factor^$max_depth --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'contraction' --alpha $alpha --branching_factor $branching_factor --max_depth $max_depth" + run_test + done + done + 
done + done +done + + +# DISASSOCIATIVE TESTING + +for replay_buffer_capacity in 5000 +do + for workspace_width in 0.5 + do + for alpha in 0.1 0.5 0.9 + do + for min_branch_count in 1 3 9 + do + for max_branch_count in 3 9 27 + do + # Disassociative (Hourglass) + ARGS="--run_name disassociative-hourglass-$min_branch_count:$max_branch_count --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'hourglass' --alpha $alpha" + run_test + + # Disassociative (Octahedron) + ARGS="--run_name disassociative-octahedron-$min_branch_count:$max_branch_count --replay_buffer_type $REPLAY_BUFFER_TYPE --replay_buffer_capacity $replay_buffer_capacity --workspace_width $workspace_width --branch_method 'disassociated' --min_branch_count $min_branch_count --max_branch_count $max_branch_count --disassociated_type 'octahedron' --alpha $alpha" + + done + done + done + done +done + + +tmux kill-window -t serl_session:$SEED diff --git a/examples/async_sac_state_sim/automated_tests_helper.sh b/examples/async_sac_state_sim/automated_tests_helper.sh new file mode 100644 index 00000000..90702a2f --- /dev/null +++ b/examples/async_sac_state_sim/automated_tests_helper.sh @@ -0,0 +1,6 @@ +#!/bin/bash + +export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ + +python async_sac_state_sim.py "$@" \ No newline at end of file diff --git a/examples/async_sac_state_sim/run_actor.sh b/examples/async_sac_state_sim/run_actor.sh index 57677916..d6397e08 100644 --- a/examples/async_sac_state_sim/run_actor.sh +++ b/examples/async_sac_state_sim/run_actor.sh @@ -1,10 +1,44 @@ +#!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ -python async_sac_state_sim.py "$@" \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export 
SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create checkpoint directory if it doesn't exist +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" >&2 +# exit 1 +# } +# fi + +python async_sac_state_sim.py \ --actor \ - --render \ - --env PandaPickCube-v0 \ - --exp_name=serl_dev_sim_test \ - --seed 0 \ + --env PandaReachCube-v0 \ + --exp_name this_is_a_fake_test_experiment \ + --run_name this_is_a_custom_run_name \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 50_000 \ + --training_starts 1000 \ --random_steps 1000 \ - --debug + --critic_actor_ratio 8 \ + --batch_size 256 \ + --replay_buffer_capacity 1_000_000 \ + --save_model True \ + --branch_method constant \ + --split_method constant \ + --starting_branch_count 3 \ + --workspace_width 0.5 \ + --alpha 1 \ + # --debug # wandb is disabled when debug + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ + # --max_traj_length 100 \ + # --max_depth 4 \ + # --branching_factor 3 \ + # --checkpoint_period 10000 \ + # --checkpoint_path "$CHECKPOINT_DIR" \ + #--render diff --git a/examples/async_sac_state_sim/run_learner.sh b/examples/async_sac_state_sim/run_learner.sh index 10a203c1..7415620f 100644 --- a/examples/async_sac_state_sim/run_learner.sh +++ b/examples/async_sac_state_sim/run_learner.sh @@ -1,11 +1,43 @@ +#!/bin/bash export XLA_PYTHON_CLIENT_PREALLOCATE=false && \ -export XLA_PYTHON_CLIENT_MEM_FRACTION=.05 && \ -python async_sac_state_sim.py "$@" \ +export XLA_PYTHON_CLIENT_MEM_FRACTION=.5 && \ +export SCRIPT_DIR=$(dirname "$(realpath "$0")") && \ +export TIMESTAMP=$(date +"%m-%d-%Y-%H-%M-%S") && \ +export CHECKPOINT_DIR="/data/fsrb_testing/checkpoints-$TIMESTAMP" && \ + +# Create 
checkpoint directory if it doesn't exist +# if [ ! -d "$CHECKPOINT_DIR" ]; then +# echo "Creating checkpoint directory: $CHECKPOINT_DIR" +# mkdir -p "$CHECKPOINT_DIR" || { +# echo "Failed to create checkpoint directory!" >&2 +# exit 1 +# } +# fi + +python async_sac_state_sim.py "$@"\ --learner \ - --env PandaPickCube-v0 \ - --exp_name=serl_dev_sim_test \ - --seed 0 \ + --env PandaReachCube-v0 \ + --exp_name this_is_a_fake_test_experiment \ + --run_name this_is_a_custom_run_name \ + --replay_buffer_type fractal_symmetry_replay_buffer \ + --max_steps 50_000 \ --training_starts 1000 \ + --random_steps 1000 \ --critic_actor_ratio 8 \ --batch_size 256 \ - --debug # wandb is disabled when debug + --replay_buffer_capacity 1_000_000 \ + --save_model True \ + --branch_method constant \ + --split_method constant \ + --starting_branch_count 3 \ + --workspace_width 0.5 \ + --alpha 1 \ + # --debug # wandb is disabled when debug + # --load_demos \ + # --demo_dir /data/data/serl/demos \ + # --file_name data_franka_reach_random_5_2.npz \ + # --max_traj_length 100 \ + # --max_depth 4 \ + # --branching_factor 3 \ + # --checkpoint_period 10000 \ + # --checkpoint_path "$CHECKPOINT_DIR" \ \ No newline at end of file diff --git a/examples/async_sac_state_sim/tmux_launch.sh b/examples/async_sac_state_sim/tmux_launch.sh index 78ff94a8..dff31d54 100644 --- a/examples/async_sac_state_sim/tmux_launch.sh +++ b/examples/async_sac_state_sim/tmux_launch.sh @@ -1,9 +1,9 @@ #!/bin/bash -EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} +# EXAMPLE_DIR=${EXAMPLE_DIR:-"examples/async_sac_state_sim"} CONDA_ENV=${CONDA_ENV:-"serl"} -cd $EXAMPLE_DIR +# cd $EXAMPLE_DIR echo "Running from $(pwd)" # Create a new tmux session @@ -13,10 +13,10 @@ tmux new-session -d -s serl_session tmux split-window -v # Navigate to the activate the conda environment in the first pane -tmux send-keys -t serl_session:0.0 "conda activate $CONDA_ENV && bash run_actor.sh" C-m +tmux send-keys -t serl_session:0.0 "conda 
activate $CONDA_ENV && bash run_actor.sh '$@'" C-m # Navigate to the activate the conda environment in the second pane -tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh" C-m +tmux send-keys -t serl_session:0.1 "conda activate $CONDA_ENV && bash run_learner.sh '$@'" C-m # Attach to the tmux session tmux attach-session -t serl_session diff --git a/examples/async_sac_state_sim/tmux_launch_tests.sh b/examples/async_sac_state_sim/tmux_launch_tests.sh new file mode 100644 index 00000000..80d47f46 --- /dev/null +++ b/examples/async_sac_state_sim/tmux_launch_tests.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +SEEDS=5 +# Create a new tmux session +tmux new-session -d -s serl_session +tmux setw -g remain-on-exit on + +# Split the window horizontally +tmux split-window -v +tmux split-pane -h -t serl_session:0.1 + +# Navigate to the activate the conda environment in the first pane +tmux send-keys -t serl_session:0.0 "bash automated_tests.sh $SEEDS" C-m + + +# Attach to the tmux session +tmux attach-session -t serl_session + +# kill the tmux session by running the following command +# tmux kill-session -t serl_session diff --git a/franka_sim/README.md b/franka_sim/README.md index 486f0e05..05c17d43 100644 --- a/franka_sim/README.md +++ b/franka_sim/README.md @@ -8,7 +8,7 @@ It includes a state-based and a vision-based Franka lift cube task environment. - run `pip install -r requirements.txt` to install sim dependencies. # Explore the Environments -- Run `python franka_sim/test/test_gym_env_human.py` to launch a display window and visualize the task. +- Run `python3 franka_sim/test/test_gym_env_human.py` to launch a display window and visualize the task. # Credits: - This simulation is initially built by [Kevin Zakka](https://kzakka.com/). @@ -20,3 +20,12 @@ It includes a state-based and a vision-based Franka lift cube task environment. 
export MUJOCO_GL=egl conda install -c conda-forge libstdcxx-ng ``` + +Navigation +---------- +- [Home](../README.md) +- [Overview](../docs/overview.md) +- [Installation guide](../docs/installation.md) +- [Quick start](../docs/sim_quick_start.md) +- [Run in simulation](../docs/run_sim.md) +- [Training options](../docs/sim_training.md) diff --git a/franka_sim/franka_sim/__init__.py b/franka_sim/franka_sim/__init__.py index 967e9e5d..ff6ecc2b 100644 --- a/franka_sim/franka_sim/__init__.py +++ b/franka_sim/franka_sim/__init__.py @@ -5,6 +5,7 @@ "GymRenderingSpec", ] +# Register environments from franka_sim.envs where specific classes are found in the PandaXXX.py scripts from gym.envs.registration import register register( @@ -15,6 +16,17 @@ register( id="PandaPickCubeVision-v0", entry_point="franka_sim.envs:PandaPickCubeGymEnv", - max_episode_steps=100, + max_episode_steps=200, kwargs={"image_obs": True}, ) +register( + id="PandaReachCube-v0", + entry_point="franka_sim.envs:PandaReachCubeGymEnv", + max_episode_steps=100, +) +register( + id="PandaReachSparseCube-v0", + entry_point="franka_sim.envs:PandaReachSparseCubeGymEnv", + max_episode_steps=200, +) + diff --git a/franka_sim/franka_sim/envs/__init__.py b/franka_sim/franka_sim/envs/__init__.py index 50c68828..8315e983 100644 --- a/franka_sim/franka_sim/envs/__init__.py +++ b/franka_sim/franka_sim/envs/__init__.py @@ -1,5 +1,9 @@ from franka_sim.envs.panda_pick_gym_env import PandaPickCubeGymEnv +from franka_sim.envs.panda_reach_gym_env import PandaReachCubeGymEnv +from franka_sim.envs.panda_reach_sparse_gym_env import PandaReachSparseCubeGymEnv __all__ = [ "PandaPickCubeGymEnv", + "PandaReachCubeGymEnv", + "PandaReachSparseCubeGymEnv" ] diff --git a/franka_sim/franka_sim/envs/panda_pick_gym_env.py b/franka_sim/franka_sim/envs/panda_pick_gym_env.py index f0a8d25b..c3d3324e 100644 --- a/franka_sim/franka_sim/envs/panda_pick_gym_env.py +++ b/franka_sim/franka_sim/envs/panda_pick_gym_env.py @@ -144,6 +144,9 @@ def 
__init__( self._viewer = MujocoRenderer( self.model, self.data, + width=960, + height=960, + camera_id=0 ) self._viewer.render(self.render_mode) @@ -226,7 +229,7 @@ def render(self): rendered_frames = [] for cam_id in self.camera_id: rendered_frames.append( - self._viewer.render(render_mode="rgb_array", camera_id=cam_id) + self._viewer.render(render_mode="rgb_array") ) return rendered_frames @@ -291,7 +294,7 @@ def _compute_reward(self) -> float: if __name__ == "__main__": env = PandaPickCubeGymEnv(render_mode="human") env.reset() - for i in range(100): + for i in range(1000): env.step(np.random.uniform(-1, 1, 4)) env.render() env.close() diff --git a/franka_sim/franka_sim/envs/panda_reach_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_gym_env.py new file mode 100644 index 00000000..b1980d7c --- /dev/null +++ b/franka_sim/franka_sim/envs/panda_reach_gym_env.py @@ -0,0 +1,296 @@ +from pathlib import Path +from typing import Any, Literal, Tuple, Dict + +import gym +import mujoco +import numpy as np +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None + +from franka_sim.controllers import opspace +from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv +from gym.envs.registration import register + +_HERE = Path(__file__).parent +_XML_PATH = _HERE / "xmls" / "arena.xml" +_PANDA_HOME = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)) +_CARTESIAN_BOUNDS = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]) +_SAMPLING_BOUNDS = np.asarray([[0.25, -0.25], [0.55, 0.25]]) + + +class PandaReachCubeGymEnv(MujocoGymEnv): + metadata = {"render_modes": ["rgb_array", "human"]} + + def __init__( + self, + action_scale: np.ndarray = np.asarray([0.1, 1]), + seed: int = 0, + control_dt: float = 0.02, + physics_dt: float = 0.002, + time_limit: float = 10.0, + render_spec: GymRenderingSpec = GymRenderingSpec(), + render_mode: Literal["rgb_array", "human"] = None, + image_obs: bool 
= False, + demo: str = "None", + ): + self._action_scale = action_scale + + super().__init__( + xml_path=_XML_PATH, + seed=seed, + control_dt=control_dt, + physics_dt=physics_dt, + time_limit=time_limit, + render_spec=render_spec, + ) + + self.metadata = { + "render_modes": [ + "human", + "rgb_array", + ], + "render_fps": int(np.round(1.0 / self.control_dt)), + } + + self.render_mode = render_mode + self.camera_id = (0, 1) + self.image_obs = image_obs + + # Caching. + self._panda_dof_ids = np.asarray( + [self._model.joint(f"joint{i}").id for i in range(1, 8)] + ) + self._panda_ctrl_ids = np.asarray( + [self._model.actuator(f"actuator{i}").id for i in range(1, 8)] + ) + self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id + self._pinch_site_id = self._model.site("pinch").id + self._block_z = self._model.geom("block").size[2] + + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + "block_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + } + ), + } + ) + + if self.image_obs: + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + } + ), + "images": gym.spaces.Dict( + { + "front": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + "wrist": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + } + ), + } + ) + + self.action_space = gym.spaces.Box( + 
low=np.asarray([-1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0]), + dtype=np.float32, + ) + + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. It + # is possible to add a similar viewer feature with gym, but that can be a future TODO + from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + + self._viewer = MujocoRenderer( + self.model, + self.data, + width=960, + height=960, + camera_id=0 + ) + if self.render_mode: + self._viewer.render(self.render_mode) + + self.demo = demo + + def reset( + self, seed=None, **kwargs + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" + mujoco.mj_resetData(self._model, self._data) + + # Reset arm to home position. + self._data.qpos[self._panda_dof_ids] = _PANDA_HOME + mujoco.mj_forward(self._model, self._data) + + # Reset mocap body to home position. + tcp_pos = self._data.sensor("2f85/pinch_pos").data + self._data.mocap_pos[0] = tcp_pos + + # Sample a new block position. + block_xy = np.random.uniform(*_SAMPLING_BOUNDS) + self._data.jnt("block").qpos[:3] = (*block_xy, self._block_z) + mujoco.mj_forward(self._model, self._data) + + # Cache the initial block height. + # self._z_init = self._data.sensor("block_pos").data[2] + # self._z_success = self._z_init + 0.2 + + obs = self._compute_observation() + return obs, {} + + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. + Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ + x, y, z = action + + # Set the mocap position. + pos = self._data.mocap_pos[0].copy() + dpos = np.asarray([x, y, z]) * self._action_scale[0] + npos = np.clip(pos + dpos, *_CARTESIAN_BOUNDS) + self._data.mocap_pos[0] = npos + + # Set gripper grasp. 
+ self._data.ctrl[self._gripper_ctrl_id] = 0 # Fully open position + + for _ in range(self._n_substeps): + tau = opspace( + model=self._model, + data=self._data, + site_id=self._pinch_site_id, + dof_ids=self._panda_dof_ids, + pos=self._data.mocap_pos[0], + ori=self._data.mocap_quat[0], + joint=_PANDA_HOME, + gravity_comp=True, + ) + self._data.ctrl[self._panda_ctrl_ids] = tau + mujoco.mj_step(self._model, self._data) + + obs = self._compute_observation() + rew = self._compute_reward() + + # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. + # IF ACCIDENTALLY MERGED, IT WILL REDUCE PERFORMANCE OF THE AGENT. + if self.demo == "franka_reach_demo": + if rew >= 0.85: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 + terminated = True + else: + # Check if the time limit is exceeded. + if self._time_limit is not None: + terminated = self.time_limit_exceeded() + else: + terminated = False + + return obs, rew, terminated, False, {} + + def render(self): + rendered_frames = [] + for cam_id in self.camera_id: + rendered_frames.append( + self._viewer.render(render_mode="rgb_array") + ) + return rendered_frames + + # Helper methods. 
+ + def _compute_observation(self) -> dict: + obs = {} + obs["state"] = {} + + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + + if self.image_obs: + obs["images"] = {} + obs["images"]["front"], obs["images"]["wrist"] = self.render() + else: + block_pos = self._data.sensor("block_pos").data.astype(np.float32) + obs["state"]["block_pos"] = block_pos + + if self.render_mode == "human": + self._viewer.render(self.render_mode) + + return obs + + def _compute_reward(self) -> float: + # Get positions + block_pos = self._data.sensor("block_pos").data + tcp_pos = self._data.sensor("2f85/pinch_pos").data + + # Calculate distance + dist = np.linalg.norm(block_pos - tcp_pos) + + # Distance-based reward. 
Note at norm of ~0.035 (ln(2)/20), reward will be 0.5 + r_close = np.exp(-20 * dist) + r_close = np.clip(r_close, 0.0, 1.0) + + return r_close + + +if __name__ == "__main__": + # Create wrapped environment + env = PandaReachCubeGymEnv(render_mode="human") + env.reset() + for i in range(5000): + env.step(np.random.uniform(-1, 1, 3)) + env.render() + env.close() diff --git a/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py new file mode 100644 index 00000000..840be6da --- /dev/null +++ b/franka_sim/franka_sim/envs/panda_reach_sparse_gym_env.py @@ -0,0 +1,304 @@ +from pathlib import Path +from typing import Any, Literal, Tuple, Dict + +import gym +import mujoco +import numpy as np +from gym import spaces + +try: + import mujoco_py +except ImportError as e: + MUJOCO_PY_IMPORT_ERROR = e +else: + MUJOCO_PY_IMPORT_ERROR = None + +from franka_sim.controllers import opspace +from franka_sim.mujoco_gym_env import GymRenderingSpec, MujocoGymEnv +from gym.envs.registration import register + +_HERE = Path(__file__).parent +_XML_PATH = _HERE / "xmls" / "arena.xml" +_PANDA_HOME = np.asarray((0, -0.785, 0, -2.35, 0, 1.57, np.pi / 4)) +_CARTESIAN_BOUNDS = np.asarray([[0.2, -0.3, 0], [0.6, 0.3, 0.5]]) +_SAMPLING_BOUNDS = np.asarray([[0.25, -0.25], [0.55, 0.25]]) + + +class PandaReachSparseCubeGymEnv(MujocoGymEnv): + metadata = {"render_modes": ["rgb_array", "human"]} + + def __init__( + self, + action_scale: np.ndarray = np.asarray([0.1, 1]), + seed: int = 0, + control_dt: float = 0.02, + physics_dt: float = 0.002, + time_limit: float = 10.0, + render_spec: GymRenderingSpec = GymRenderingSpec(), + render_mode: Literal["rgb_array", "human"] = "rgb_array", + image_obs: bool = True, + demo: str = "None", + ): + self._action_scale = action_scale + + super().__init__( + xml_path=_XML_PATH, + seed=seed, + control_dt=control_dt, + physics_dt=physics_dt, + time_limit=time_limit, + render_spec=render_spec, + ) + + self.metadata = { +
"render_modes": [ + "human", + "rgb_array", + ], + "render_fps": int(np.round(1.0 / self.control_dt)), + } + + self.render_mode = render_mode + self.camera_id = (0, 1) + self.image_obs = image_obs + + # Caching. + self._panda_dof_ids = np.asarray( + [self._model.joint(f"joint{i}").id for i in range(1, 8)] + ) + self._panda_ctrl_ids = np.asarray( + [self._model.actuator(f"actuator{i}").id for i in range(1, 8)] + ) + self._gripper_ctrl_id = self._model.actuator("fingers_actuator").id + self._pinch_site_id = self._model.site("pinch").id + self._block_z = self._model.geom("block").size[2] + + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + "block_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + } + ), + } + ) + + if self.image_obs: + self.observation_space = gym.spaces.Dict( + { + "state": gym.spaces.Dict( + { + "panda/tcp_pos": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/tcp_vel": spaces.Box( + -np.inf, np.inf, shape=(3,), dtype=np.float32 + ), + "panda/gripper_pos": spaces.Box( + -np.inf, np.inf, shape=(1,), dtype=np.float32 + ), + } + ), + "images": gym.spaces.Dict( + { + "front": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + "wrist": gym.spaces.Box( + low=0, + high=255, + shape=(render_spec.height, render_spec.width, 3), + dtype=np.uint8, + ), + } + ), + } + ) + + self.action_space = gym.spaces.Box( + low=np.asarray([-1.0, -1.0, -1.0, -1.0]), + high=np.asarray([1.0, 1.0, 1.0, 1.0]), + dtype=np.float32, + ) + + # NOTE: gymnasium is used here since MujocoRenderer is not available in gym. 
It + # is possible to add a similar viewer feature with gym, but that can be a future TODO + from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer + + self._viewer = MujocoRenderer( + self.model, + self.data, + width=128, + height=128, + camera_id=0 + ) + self._viewer.render(self.render_mode) + + def reset( + self, seed=None, **kwargs + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Reset the environment.""" + mujoco.mj_resetData(self._model, self._data) + + # Reset arm to home position. + self._data.qpos[self._panda_dof_ids] = _PANDA_HOME + mujoco.mj_forward(self._model, self._data) + + # Reset mocap body to home position. + tcp_pos = self._data.sensor("2f85/pinch_pos").data + self._data.mocap_pos[0] = tcp_pos + + # Sample a new block position. + block_xy = np.random.uniform(*_SAMPLING_BOUNDS) + self._data.jnt("block").qpos[:3] = (*block_xy, self._block_z) + mujoco.mj_forward(self._model, self._data) + + # Cache the initial block height. + # self._z_init = self._data.sensor("block_pos").data[2] + # self._z_success = self._z_init + 0.2 + + obs = self._compute_observation() + return obs, {} + + def step( + self, action: np.ndarray + ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]: + """ + take a step in the environment. + Params: + action: np.ndarray + + Returns: + observation: dict[str, np.ndarray], + reward: float, + done: bool, + truncated: bool, + info: dict[str, Any] + """ + x, y, z, _ = action + + # Set the mocap position. + pos = self._data.mocap_pos[0].copy() + dpos = np.asarray([x, y, z]) * self._action_scale[0] + npos = np.clip(pos + dpos, *_CARTESIAN_BOUNDS) + self._data.mocap_pos[0] = npos + + # Set gripper grasp. + self._data.ctrl[self._gripper_ctrl_id] = 0 # Keep gripper closed. 
+ # g = self._data.ctrl[self._gripper_ctrl_id] / 255 + # dg = grasp * self._action_scale[1] + # ng = np.clip(g + dg, 0.0, 1.0) + # self._data.ctrl[self._gripper_ctrl_id] = ng * 255 + + for _ in range(self._n_substeps): + tau = opspace( + model=self._model, + data=self._data, + site_id=self._pinch_site_id, + dof_ids=self._panda_dof_ids, + pos=self._data.mocap_pos[0], + ori=self._data.mocap_quat[0], + joint=_PANDA_HOME, + gravity_comp=True, + ) + self._data.ctrl[self._panda_ctrl_ids] = tau + mujoco.mj_step(self._model, self._data) + + # Compute observation. + obs = self._compute_observation() + + # For sparse reward we return 1 if the task is achieved, else 0. + rew = self._compute_reward() + + # # For demo reach environment, finger --- THIS SHOULD NEVER BE MERGED TO OTHER BRANCHES AFFECTING REGULAR USE OF ENVIRONMENTS. ONLY FOR DEMOS. + # # IF ACCIDENTALLY MERGED, IT WILL REDUCE PERFORMANCE OF THE AGENT. + # if self.demo == "franka_reach_demo": + # if rew >= 0.85: # Demo ERROR_THRESHOLD @ 0.3->0.675; ERROR_THRESHOLD @ 0.2->0.55 + # terminated = True + # else: + # Check if the time limit is exceeded. + + # Episode is terminated if the reward is 1.0 (i.e. the task is achieved). + terminated = (rew == 1.0) + + if self._time_limit is not None: + terminated = terminated or self.time_limit_exceeded() + + return obs, rew, terminated, False, {} + + def render(self): + rendered_frames = [] + for cam_id in self.camera_id: + rendered_frames.append( + self._viewer.render(render_mode="rgb_array") + ) + return rendered_frames + + # Helper methods. 
+ + def _compute_observation(self) -> dict: + obs = {} + obs["state"] = {} + + tcp_pos = self._data.sensor("2f85/pinch_pos").data + obs["state"]["panda/tcp_pos"] = tcp_pos.astype(np.float32) + + tcp_vel = self._data.sensor("2f85/pinch_vel").data + obs["state"]["panda/tcp_vel"] = tcp_vel.astype(np.float32) + + gripper_pos = np.array( + self._data.ctrl[self._gripper_ctrl_id] / 255, dtype=np.float32 + ) + obs["state"]["panda/gripper_pos"] = gripper_pos + + if self.image_obs: + obs["images"] = {} + obs["images"]["front"], obs["images"]["wrist"] = self.render() + else: + block_pos = self._data.sensor("block_pos").data.astype(np.float32) + obs["state"]["block_pos"] = block_pos + + if self.render_mode == "human": + self._viewer.render(self.render_mode) + + return obs + + def _compute_reward(self) -> float: + # Get positions + block_pos = self._data.sensor("block_pos").data + tcp_pos = self._data.sensor("2f85/pinch_pos").data + + # Calculate distance + dist = np.linalg.norm(block_pos - tcp_pos) + + # Distance-based reward. 
Sparse variant: reward is 1.0 when the norm is below 0.015, otherwise 0.0 + if dist < 0.015: + reward = 1.0 + else: + reward = 0.0 + + return reward + + +if __name__ == "__main__": + # Create wrapped environment + env = PandaReachSparseCubeGymEnv(render_mode="human") + env.reset() + for i in range(5000): + env.step(np.random.uniform(-1, 1, 4)) + env.render() + env.close() diff --git a/franka_sim/franka_sim/envs/xmls/arena.xml b/franka_sim/franka_sim/envs/xmls/arena.xml index e8b69cd9..5c766831 100644 --- a/franka_sim/franka_sim/envs/xmls/arena.xml +++ b/franka_sim/franka_sim/envs/xmls/arena.xml @@ -19,7 +19,7 @@ - + diff --git a/franka_sim/franka_sim/envs/xmls/franka_emika_panda/README.md b/franka_sim/franka_sim/envs/xmls/franka_emika_panda/README.md index 0eab07b7..1ae460f7 100644 --- a/franka_sim/franka_sim/envs/xmls/franka_emika_panda/README.md +++ b/franka_sim/franka_sim/envs/xmls/franka_emika_panda/README.md @@ -43,3 +43,9 @@ description](https://github.com/frankaemika/franka_ros/tree/develop/franka_descr ## License This model is released under an [Apache-2.0 License](LICENSE). + +Navigation +---------- +- [Franka Sim README](../../../../../franka_sim/README.md) +- [Home](../../../../../README.md) +- [Overview](../../../../../docs/overview.md) diff --git a/franka_sim/franka_sim/envs/xmls/robotiq_2f85/README.md b/franka_sim/franka_sim/envs/xmls/robotiq_2f85/README.md index 7a23dfb7..daf3775e 100644 --- a/franka_sim/franka_sim/envs/xmls/robotiq_2f85/README.md +++ b/franka_sim/franka_sim/envs/xmls/robotiq_2f85/README.md @@ -31,3 +31,9 @@ description](https://github.com/ros-industrial/robotiq/tree/kinetic-devel/roboti ## License This model is released under a [BSD-2-Clause License](LICENSE).
+ +Navigation +---------- +- [Franka Sim README](../../../../../franka_sim/README.md) +- [Home](../../../../../README.md) +- [Overview](../../../../../docs/overview.md) diff --git a/franka_sim/franka_sim/test/test_gym_env_human.py b/franka_sim/franka_sim/test/test_gym_env_human.py index 592bb24e..971a4516 100644 --- a/franka_sim/franka_sim/test/test_gym_env_human.py +++ b/franka_sim/franka_sim/test/test_gym_env_human.py @@ -6,7 +6,7 @@ from franka_sim import envs -env = envs.PandaPickCubeGymEnv(action_scale=(0.1, 1)) +env = envs.PandaPickCubeGymEnv(action_scale=(0.1, 1)) # or render_mode="human") action_spec = env.action_space diff --git a/franka_sim/requirements.txt b/franka_sim/requirements.txt index 1c7bae24..cbfb7970 100644 --- a/franka_sim/requirements.txt +++ b/franka_sim/requirements.txt @@ -2,6 +2,6 @@ dm_env mujoco==2.3.7 gym >= 0.26 gymnasium -dm-robotics-transformations +dm-robotics-transformations==0.9.0 imageio[ffmpeg] lxml diff --git a/serl_launcher/README.md b/serl_launcher/README.md index 83c07904..03f8e2fd 100644 --- a/serl_launcher/README.md +++ b/serl_launcher/README.md @@ -3,3 +3,11 @@ - Dependencies: `jax`, `agentlace` Code and scripts are modified from [jaxrl_m](https://github.com/dibyaghosh/jaxrl_m) or [jaxrl_m private].(https://github.com/rail-berkeley/jaxrl_minimal) + +Navigation +---------- +- [Home](../README.md) +- [Overview](../docs/overview.md) +- [Installation guide](../docs/installation.md) +- [Run in simulation](../docs/run_sim.md) +- [Run on the real robot](../docs/run_realrobot.md) diff --git a/serl_launcher/requirements.txt b/serl_launcher/requirements.txt index 86f92e95..801d773a 100644 --- a/serl_launcher/requirements.txt +++ b/serl_launcher/requirements.txt @@ -1,14 +1,16 @@ +opencv-python<4.12.0.88 gym >= 0.26 -numpy>=1.24.3 -flax>=0.8.0 +numpy>=1.24.3, < 2.0.0 +flax>=0.8.0, < 0.10.6 distrax>=0.1.2 +matplotlib == 3.10.9 ml_collections >= 0.1.0 tqdm >= 4.60.0 chex>=0.1.85 optax>=0.1.5 orbax-checkpoint>=0.5.10 absl-py >= 0.12.0 
-scipy==1.11.4 +scipy>=1.11.4 wandb >= 0.12.14 tensorflow>=2.16.0 tensorflow_probability>=0.24.0 diff --git a/serl_launcher/serl_launcher/agents/continuous/drq.py b/serl_launcher/serl_launcher/agents/continuous/drq.py index 9b2b690b..39b043ff 100644 --- a/serl_launcher/serl_launcher/agents/continuous/drq.py +++ b/serl_launcher/serl_launcher/agents/continuous/drq.py @@ -21,6 +21,13 @@ class DrQAgent(SACAgent): + """SAC-style agent specialized for pixel observations with DrQ augmentation. + + DrQ (Data-regularized Q-learning) improves visual RL by applying image + augmentations during training so the critic learns features that are less + sensitive to small visual changes (for example camera jitter or lighting). + """ + @classmethod def create( cls, @@ -51,6 +58,36 @@ def create( critic_subsample_size: Optional[int] = None, image_keys: Iterable[str] = ("image",), ): + """Construct a DrQAgent from already-built actor/critic/temperature modules. + + This method is the low-level constructor used after network definitions + are prepared. It initializes parameters, optimizers, and target network + state, then stores algorithm hyperparameters in ``config``. + + Args: + rng: JAX random key used for parameter/state initialization. + observations: Example observation batch used to initialize module + shapes. + actions: Example action batch used to initialize critic/action shapes. + actor_def: Actor network module. + critic_def: Critic network module. + temperature_def: Learnable temperature module for entropy weighting. + actor_optimizer_kwargs: Optimizer settings for actor updates. + critic_optimizer_kwargs: Optimizer settings for critic updates. + temperature_optimizer_kwargs: Optimizer settings for temperature. + discount: TD discount factor. + soft_target_update_rate: Polyak averaging factor for target critic. + target_entropy: Target policy entropy. If ``None``, uses a default + based on action dimension. 
+ entropy_per_dim: Unsupported in this class (kept for API parity). + backup_entropy: Whether to include entropy term in critic target. + critic_ensemble_size: Number of critic heads in the ensemble. + critic_subsample_size: Optional REDQ-style subset size for target min. + image_keys: Observation keys treated as pixel tensors. + + Returns: + Initialized ``DrQAgent`` with online and target parameters. + """ networks = { "actor": actor_def, "critic": critic_def, @@ -127,13 +164,47 @@ def create_drq( image_keys: Iterable[str] = ("image",), **kwargs, ): + """Build and initialize a pixel-based DrQ agent from high-level choices. + + This helper chooses an image encoder family, wraps image/proprio inputs, + constructs actor and critic networks, and then calls ``create`` to + initialize training state. + + Designed for students: + - Think of this as an "agent factory" that wires all pieces together. + - ``encoder_type`` picks how images become feature vectors. + - The critic is an ensemble (multiple Q-networks) for stability. + - DrQ augmentation is configured later in ``data_augmentation_fn`` and + used during updates. + + Args: + rng: Random key used for model initialization. + observations: Example observation used for shape inference. + actions: Example action batch used for shape inference. + encoder_type: One of ``"small"``, ``"resnet"``, or + ``"resnet-pretrained"``. + shared_encoder: Kept for API compatibility; current implementation + uses a shared encoder instance for actor and critic. + use_proprio: If ``True``, include non-image state features. + critic_network_kwargs: MLP kwargs for critic backbone. + policy_network_kwargs: MLP kwargs for policy backbone. + policy_kwargs: Extra kwargs forwarded to ``Policy``. + critic_ensemble_size: Number of critic ensemble heads. + critic_subsample_size: Optional number of heads sampled for targets. + temperature_init: Initial value for entropy temperature. + image_keys: Keys in observation dict that contain image stacks. 
+ **kwargs: Forwarded to ``create`` (e.g., discount, optimizers). + + Returns: + A fully initialized ``DrQAgent`` ready for training. """ - Create a new pixel-based agent, with no encoders. - """ - + # Ensure the final MLP layer is active so the heads can freely shape + # outputs before final linear projections in Policy/Critic modules. policy_network_kwargs["activate_final"] = True critic_network_kwargs["activate_final"] = True + # Choose the visual encoder architecture. + # Each image key gets its own named encoder module. if encoder_type == "small": from serl_launcher.vision.small_encoders import SmallEncoder @@ -151,6 +222,7 @@ def create_drq( for image_key in image_keys } elif encoder_type == "resnet": + # Train a lightweight ResNet-10 style encoder from scratch. from serl_launcher.vision.resnet_v1 import resnetv1_configs encoders = { @@ -163,11 +235,13 @@ def create_drq( for image_key in image_keys } elif encoder_type == "resnet-pretrained": + # Use a frozen pretrained ResNet trunk, then task-specific pooling/head. from serl_launcher.vision.resnet_v1 import ( PreTrainedResNetEncoder, resnetv1_configs, ) + # Shared frozen feature extractor used by every image stream. pretrained_encoder = resnetv1_configs["resnetv1-10-frozen"]( pre_pooling=True, name="pretrained_encoder", @@ -185,6 +259,10 @@ def create_drq( else: raise NotImplementedError(f"Unknown encoder type: {encoder_type}") + # Wrap image encoders so the model can: + # - merge multiple camera streams + # - optionally concatenate proprioception + # - handle frame-stacked inputs consistently encoder_def = EncodingWrapper( encoder=encoders, use_proprio=use_proprio, @@ -192,20 +270,24 @@ def create_drq( image_keys=image_keys, ) + # Actor and critic currently share the same encoder wrapper instance. encoders = { "critic": encoder_def, "actor": encoder_def, } - # Define networks + # Build critic backbone and replicate it into an ensemble. + # Ensemble critics reduce overestimation and improve stability. 
critic_backbone = partial(MLP, **critic_network_kwargs) critic_backbone = ensemblize(critic_backbone, critic_ensemble_size)( name="critic_ensemble" ) + # Critic consumes encoded observations + actions. critic_def = partial( Critic, encoder=encoders["critic"], network=critic_backbone )(name="critic") + # Policy consumes encoded observations and outputs an action distribution. policy_def = Policy( encoder=encoders["actor"], network=MLP(**policy_network_kwargs), @@ -214,6 +296,7 @@ def create_drq( name="actor", ) + # Entropy temperature alpha (>= 0) is learned with a constrained module. temperature_def = GeqLagrangeMultiplier( init_value=temperature_init, constraint_shape=(), @@ -221,6 +304,7 @@ def create_drq( name="temperature", ) + # Delegate common initialization (params, optimizers, target params). agent = cls.create( rng, observations, @@ -234,7 +318,8 @@ def create_drq( **kwargs, ) - if encoder_type == "resnet-pretrained": # load pretrained weights for ResNet-10 + # For pretrained mode, load frozen ResNet-10 weights after module setup. + if encoder_type == "resnet-pretrained": from serl_launcher.utils.train_utils import load_resnet10_params agent = load_resnet10_params(agent, image_keys) @@ -242,7 +327,31 @@ def create_drq( return agent def data_augmentation_fn(self, rng, observations): + """Apply DrQ image augmentation (random crop) to each configured image key. + + The augmentation is applied to pixel observations only; non-image fields + in ``observations`` are preserved. This regularizes visual features so + Q-learning is less sensitive to small image-level perturbations. + + Args: + rng: Random key controlling crop offsets. + observations: Observation dict/FrozenDict containing pixel stacks. + + Returns: + A copy of ``observations`` where each image tensor in ``image_keys`` + is replaced by a randomly cropped version. + """ + # Iterate over every configured camera/image input. 
for pixel_key in self.config["image_keys"]: + # Replace this key with an augmented tensor while leaving all other + # observation entries unchanged. + # + # padding=4: + # First pad the image by 4 pixels on each side (edge padding), then + # crop back to original size at a random offset. + # num_batch_dims=2: + # Data is usually [batch, stack, H, W, C], so we treat the first + # two axes as batch-like dimensions when vectorizing crops. observations = observations.copy( add_or_replace={ pixel_key: batched_random_crop( @@ -260,16 +369,22 @@ def update_high_utd( utd_ratio: int, pmap_axis: Optional[str] = None, ) -> Tuple["DrQAgent", dict]: - """ - Fast JITted high-UTD version of `.update`. - - Splits the batch into minibatches, performs `utd_ratio` critic - (and target) updates, and then one actor/temperature update. - - Batch dimension must be divisible by `utd_ratio`. - - It also performs data augmentation on the observations and next_observations - before updating the network. + """JIT-compiled high-UTD training step with DrQ augmentation. + + Steps: + 1. Ensure packed replay samples are unpacked when needed. + 2. Apply random-crop augmentation to current and next observations. + 3. Run the parent SAC high-UTD update: + - multiple critic updates + - one actor/temperature update + + Args: + batch: Training batch sampled from replay. + utd_ratio: Number of critic updates per actor update. + pmap_axis: Optional PMAP axis name for distributed reductions. + + Returns: + Tuple of updated agent and merged logging metrics. """ new_agent = self if self.config["image_keys"][0] not in batch["next_observations"]: @@ -300,6 +415,19 @@ def update_critics( *, pmap_axis: Optional[str] = None, ) -> Tuple["DrQAgent", dict]: + """Run a critic-only update with DrQ augmentation. + + This method is useful when training with a critic:actor update ratio + greater than 1. 
It augments observations, updates only critic parameters + (and target critic via parent update logic), and returns critic metrics. + + Args: + batch: Training batch from replay (packed or unpacked). + pmap_axis: Optional PMAP axis name for distributed reductions. + + Returns: + Tuple of updated agent and critic-only logging info. + """ new_agent = self if self.config["image_keys"][0] not in batch["next_observations"]: batch = _unpack(batch) diff --git a/serl_launcher/serl_launcher/agents/continuous/sac.py b/serl_launcher/serl_launcher/agents/continuous/sac.py index ca933db4..dd723924 100644 --- a/serl_launcher/serl_launcher/agents/continuous/sac.py +++ b/serl_launcher/serl_launcher/agents/continuous/sac.py @@ -575,11 +575,11 @@ def scan_body(carry: Tuple[SACAgent], data: Tuple[Batch]): def make_minibatch(data: jnp.ndarray): return jnp.reshape(data, (utd_ratio, minibatch_size) + data.shape[1:]) - minibatches = jax.tree_map(make_minibatch, batch) + minibatches = jax.tree.map(make_minibatch, batch) (agent,), critic_infos = jax.lax.scan(scan_body, (self,), (minibatches,)) - critic_infos = jax.tree_map(lambda x: jnp.mean(x, axis=0), critic_infos) + critic_infos = jax.tree.map(lambda x: jnp.mean(x, axis=0), critic_infos) del critic_infos["actor"] del critic_infos["temperature"] diff --git a/serl_launcher/serl_launcher/common/common.py b/serl_launcher/serl_launcher/common/common.py index a3a53e74..1e80fa06 100644 --- a/serl_launcher/serl_launcher/common/common.py +++ b/serl_launcher/serl_launcher/common/common.py @@ -22,7 +22,7 @@ def shard_batch(batch, sharding): batch: A pytree of arrays. sharding: A jax Sharding object with shape (num_devices,). 
""" - return jax.tree_map( + return jax.tree.map( lambda x: jax.device_put( x, sharding.reshape(sharding.shape[0], *((1,) * (x.ndim - 1))) ), @@ -115,7 +115,7 @@ class JaxRLTrainState(struct.PyTreeNode): @staticmethod def _tx_tree_map(*args, **kwargs): - return jax.tree_map( + return jax.tree.map( *args, is_leaf=lambda x: isinstance(x, optax.GradientTransformation), **kwargs, @@ -128,7 +128,7 @@ def target_update(self, tau: float) -> "JaxRLTrainState": new_target_params = tau * params + (1 - tau) * target_params """ - new_target_params = jax.tree_map( + new_target_params = jax.tree.map( lambda p, tp: p * tau + tp * (1 - tau), self.params, self.target_params ) return self.replace(target_params=new_target_params) @@ -158,7 +158,7 @@ def apply_gradients(self, *, grads: Any) -> "JaxRLTrainState": ) # apply all the updates additively - updates_acc = jax.tree_map( + updates_acc = jax.tree.map( lambda *xs: jnp.sum(jnp.array(xs), axis=0), *updates_flat ) new_params = optax.apply_updates(self.params, updates_acc) @@ -200,7 +200,7 @@ def apply_loss_fns( rngs = jax.tree_util.tree_unflatten(treedef, rngs) # compute gradients - grads_and_aux = jax.tree_map( + grads_and_aux = jax.tree.map( lambda loss_fn, rng: jax.grad(loss_fn, has_aux=has_aux)(self.params, rng), loss_fns, rngs, @@ -214,8 +214,8 @@ def apply_loss_fns( grads_and_aux = jax.lax.pmean(grads_and_aux, axis_name=pmap_axis) if has_aux: - grads = jax.tree_map(lambda _, x: x[0], loss_fns, grads_and_aux) - aux = jax.tree_map(lambda _, x: x[1], loss_fns, grads_and_aux) + grads = jax.tree.map(lambda _, x: x[0], loss_fns, grads_and_aux) + aux = jax.tree.map(lambda _, x: x[1], loss_fns, grads_and_aux) return self.apply_gradients(grads=grads), aux else: return self.apply_gradients(grads=grads_and_aux) diff --git a/serl_launcher/serl_launcher/common/wandb.py b/serl_launcher/serl_launcher/common/wandb.py index 2ef341aa..4b27cb9e 100644 --- a/serl_launcher/serl_launcher/common/wandb.py +++ 
b/serl_launcher/serl_launcher/common/wandb.py @@ -6,6 +6,7 @@ import absl.flags as flags import ml_collections import wandb +import uuid def _recursive_flatten_dict(d: dict): @@ -41,12 +42,13 @@ def __init__( variant, wandb_output_dir=None, debug=False, + offline=False, ): self.config = wandb_config if self.config.unique_identifier == "": self.config.unique_identifier = datetime.datetime.now().strftime( - "%Y%m%d_%H%M%S" - ) + "%m-%d-%Y" + ) + str(uuid.uuid1()) self.config.experiment_id = ( self.experiment_id @@ -65,11 +67,15 @@ def __init__( if debug: mode = "disabled" else: - mode = "online" + if offline: + mode = "offline" + else: + mode = "online" self.run = wandb.init( config=self._variant, project=self.config.project, + name=self.config.name, entity=self.config.entity, group=self.config.group, tags=self.config.tag, diff --git a/serl_launcher/serl_launcher/data/data_store.py b/serl_launcher/serl_launcher/data/data_store.py index 20e65813..6503fdc9 100644 --- a/serl_launcher/serl_launcher/data/data_store.py +++ b/serl_launcher/serl_launcher/data/data_store.py @@ -7,6 +7,9 @@ from serl_launcher.data.memory_efficient_replay_buffer import ( MemoryEfficientReplayBuffer, ) +from serl_launcher.data.fractal_symmetry_replay_buffer import ( + FractalSymmetryReplayBuffer +) from agentlace.data.data_store import DataStoreBase @@ -112,8 +115,6 @@ def insert(self, data): RLDSStepType.TRUNCATION, }: self.step_type = RLDSStepType.RESTART - elif self.step_type == RLDSStepType.TRUNCATION: - self.step_type = RLDSStepType.RESTART elif not data["masks"]: # 0 is done, 1 is not done self.step_type = RLDSStepType.TERMINATION elif data["dones"]: @@ -143,6 +144,69 @@ def latest_data_id(self): def get_latest_data(self, from_id: int): raise NotImplementedError # TODO +class FractalSymmetryReplayBufferDataStore(FractalSymmetryReplayBuffer, DataStoreBase): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + workspace_width: int = None, + 
x_obs_idx = None, + y_obs_idx = None, + branch_method: str = None, + split_method: str = None, + rlds_logger: Optional[RLDSLogger] = None, + image_keys: Iterable[str] = ("image",), + **kwargs: dict, + ): + FractalSymmetryReplayBuffer.__init__(self, observation_space, action_space, capacity, workspace_width, x_obs_idx, y_obs_idx, branch_method, split_method, img_keys=image_keys, **kwargs) + DataStoreBase.__init__(self, capacity) + self._lock = Lock() + self._logger = None + + if rlds_logger: + self.step_type = RLDSStepType.TERMINATION # to init the state for restart + self._logger = rlds_logger + + # ensure thread safety + def insert(self, data): + with self._lock: + super(FractalSymmetryReplayBufferDataStore, self).insert(data) + + # TODO: Data logging currently does NOT WORK as shown if we want to log our transformed transitions + # add data to the rlds logger + if self._logger: + if self.step_type in { + RLDSStepType.TERMINATION, + RLDSStepType.TRUNCATION, + }: + self.step_type = RLDSStepType.RESTART + elif not data["masks"]: # 0 is done, 1 is not done + self.step_type = RLDSStepType.TERMINATION + elif data["dones"]: + self.step_type = RLDSStepType.TRUNCATION + else: + self.step_type = RLDSStepType.TRANSITION + + self._logger( + action=data["actions"], + obs=data["next_observations"], # TODO: check if this is correct + reward=data["rewards"], + step_type=self.step_type, + ) + + # ensure thread safety + def sample(self, *args, **kwargs): + with self._lock: + return super(FractalSymmetryReplayBufferDataStore, self).sample(*args, **kwargs) + + # NOTE: method for DataStoreBase + def latest_data_id(self): + return self._insert_index + + # NOTE: method for DataStoreBase + def get_latest_data(self, from_id: int): + raise NotImplementedError # TODO def populate_data_store( data_store: DataStoreBase, @@ -162,7 +226,6 @@ def populate_data_store( print(f"Loaded {len(data_store)} transitions.") return data_store - def populate_data_store_with_z_axis_only( data_store: 
DataStoreBase, demos_path: str, diff --git a/serl_launcher/serl_launcher/data/dataset.py b/serl_launcher/serl_launcher/data/dataset.py index 0760fb02..2daba136 100644 --- a/serl_launcher/serl_launcher/data/dataset.py +++ b/serl_launcher/serl_launcher/data/dataset.py @@ -1,45 +1,99 @@ -from functools import partial +from functools import partial # unused. from typing import Dict, Iterable, Optional, Tuple, Union import jax import jax.numpy as jnp import numpy as np + +# frozen_dict is an immutable & nested dictionary-like structure to manage parameters and states in NNs. Key in Flax, model params passed explicitly vs stored in mutable objects. from flax.core import frozen_dict from gym.utils import seeding +# Nested data structures where the leaves are NumPy arrays, and internal nodes are dictionaries with string keys. DataType = Union[np.ndarray, Dict[str, "DataType"]] DatasetDict = Dict[str, DataType] +# Utility functions to check lengths and subselect data from a dataset dictionary. def _check_lengths(dataset_dict: DatasetDict, dataset_len: Optional[int] = None) -> int: + """ + Check the lengths of items in a dataset dictionary. + + Upon initializing a Dataset, _check_lengths is invoked to assert that all data arrays are of equal length. This is critical because: + The Dataset assumes a uniform length across features to support indexing, sampling, and batching. + Inconsistent lengths would lead to silent errors or runtime failures during sampling, model training, or evaluation. + + If all items are of the same length, return that length. + If items are of different lengths, raise an assertion error. + If the dataset is empty, return 0. + If the dataset is not a dictionary, raise a TypeError. + Args: + dataset_dict (DatasetDict): The dataset dictionary to check. + dataset_len (Optional[int]): The length to compare against, if provided. + Returns: + int: The length of the dataset if all items are of the same length. 
+ Raises: + TypeError: If the dataset is not a dictionary or contains unsupported types. + """ + for v in dataset_dict.values(): if isinstance(v, dict): dataset_len = dataset_len or _check_lengths(v, dataset_len) + elif isinstance(v, np.ndarray): item_len = len(v) dataset_len = dataset_len or item_len assert dataset_len == item_len, "Inconsistent item lengths in the dataset." + else: raise TypeError("Unsupported type.") return dataset_len def _subselect(dataset_dict: DatasetDict, index: np.ndarray) -> DatasetDict: + """ + Subselect enables flexible, consistent indexing into complex datasets to either split or filter data. + It is especially important when working with nested dictionary-based datasets β€” a common structure in modern machine learning. + Our dataset will be deeply structure and we want indexing to be applied consistently at every depth. + Used by Dataset.split() and Dataset.filter() methods to extract subsets of data based on indices. + + Args: + dataset_dict (DatasetDict): The dataset dictionary to subselect from. + index (np.ndarray): The indices to select from the dataset. + Returns: + DatasetDict: A new dataset dictionary with items selected based on the index. + Raises: + TypeError: If the dataset contains unsupported types. + """ + + new_dataset_dict = {} for k, v in dataset_dict.items(): if isinstance(v, dict): new_v = _subselect(v, index) + elif isinstance(v, np.ndarray): new_v = v[index] + else: raise TypeError("Unsupported type.") + new_dataset_dict[k] = new_v return new_dataset_dict -def _sample( - dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray -) -> DatasetDict: +def _sample(dataset_dict: Union[np.ndarray, DatasetDict], indx: np.ndarray) -> DatasetDict: + """ + This function is used to extract a subset of data from the dataset dictionary, which can be either a NumPy array or a nested dictionary structure. + Args: + dataset_dict (Union[np.ndarray, DatasetDict]): The dataset dictionary or array to sample from. 
+ indx (np.ndarray): The indices to sample from the dataset. + Returns: + DatasetDict: A new dataset dictionary with items sampled based on the indices. + Raises: + TypeError: If the dataset is not a NumPy array or a dictionary. + """ + if isinstance(dataset_dict, np.ndarray): return dataset_dict[indx] elif isinstance(dataset_dict, dict): @@ -52,7 +106,11 @@ def _sample( class Dataset(object): - def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): + + def __init__(self, + dataset_dict: DatasetDict, + seed: Optional[int] = None ): + self.dataset_dict = dataset_dict self.dataset_len = _check_lengths(dataset_dict) @@ -63,6 +121,7 @@ def __init__(self, dataset_dict: DatasetDict, seed: Optional[int] = None): if seed is not None: self.seed(seed) + # @property decorator is used here to expose np_random as a read-only attribute @property def np_random(self) -> np.random.RandomState: if self._np_random is None: @@ -70,20 +129,43 @@ def np_random(self) -> np.random.RandomState: return self._np_random def seed(self, seed: Optional[int] = None) -> list: + """ + Set the random seed for reproducibility. Ensures a valid RNG is encapsulated behind an attribute-like interface. + Users do not need to call self.seed() or self.get_np_random() explicitly. Just self.np_random above. + + Args: + seed (Optional[int]): The seed to set. If None, a random seed will be generated. + Returns: + list: A list containing the seed used. + """ self._np_random, self._seed = seeding.np_random(seed) return [self._seed] def __len__(self) -> int: return self.dataset_len - def sample( - self, - batch_size: int, - keys: Optional[Iterable[str]] = None, - indx: Optional[np.ndarray] = None, - ) -> frozen_dict.FrozenDict: + def sample(self, + batch_size: int, + keys: Optional[Iterable[str]] = None, + indx: Optional[np.ndarray] = None, + ) -> frozen_dict.FrozenDict: + """ + Sample a random batch of data from the dataset. 
+ + This method allows for flexible sampling of data, either by specifying keys or using random indices. + This is useful for training models, where you might want to sample a batch of data points from a larger dataset. + Args: + batch_size (int): The number of samples to return. + keys (Optional[Iterable[str]]): The keys to sample from the dataset. If None, all keys will be sampled. + indx (Optional[np.ndarray]): Specific indices to sample from. If None, random indices will be generated. + Returns: + frozen_dict.FrozenDict: A frozen dictionary containing the sampled data. + """ + if indx is None: if hasattr(self.np_random, "integers"): + + # Generate batch_size num of rand ints, each sampled uniformly at random from the range [0, len(self) - 1]. indx = self.np_random.integers(len(self), size=batch_size) else: indx = self.np_random.randint(len(self), size=batch_size) @@ -101,7 +183,17 @@ def sample( return frozen_dict.freeze(batch) - def sample_jax(self, batch_size: int, keys: Optional[Iterable[str]] = None): + def sample_jax(self, + batch_size: int, + keys: Optional[Iterable[str]] = None): + """ + Sample a batch of data from the dataset using JAX. This method is optimized for performance and can be used in JAX-based training loops. + Args: + batch_size (int): The number of samples to return. + keys (Optional[Iterable[str]]): The keys to sample from the dataset. If None, all keys will be sampled. + Returns: + Tuple[int, frozen_dict.FrozenDict]: A tuple containing the maximum index sampled and a frozen dictionary with the sampled data. 
+ """ if not hasattr(self, "rng"): self.rng = jax.random.PRNGKey(self._seed or 42) @@ -118,7 +210,7 @@ def _sample_jax(rng, src, max_indx: int): return ( rng, indx.max(), - jax.tree_map(lambda d: jnp.take(d, indx, axis=0), src), + jax.tree.map(lambda d: jnp.take(d, indx, axis=0), src), ) self._sample_jax = _sample_jax @@ -129,12 +221,25 @@ def _sample_jax(rng, src, max_indx: int): return indx_max, sample def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: + """ + Split the dataset into two parts based on the given ratio. The first part will contain a fraction of the dataset specified by the ratio, and the second part will contain the rest. + This method is useful for creating training and testing datasets, where you want to split the data into two parts for model evaluation. + Args: + ratio (float): The fraction of the dataset to include in the first part. Must be between 0 and 1. + Returns: + Tuple[Dataset, Dataset]: A tuple containing two Dataset objects, the first part and the second part of the split dataset. + Raises: + AssertionError: If the ratio is not between 0 and 1. + """ + assert 0 < ratio and ratio < 1 - train_index = np.index_exp[: int(self.dataset_len * ratio)] - test_index = np.index_exp[int(self.dataset_len * ratio) :] + train_index = np.index_exp[: int(self.dataset_len * ratio)] # First part of the dataset. + test_index = np.index_exp[int(self.dataset_len * ratio) :] # Second part of the dataset. + # Shuffle the indices to ensure random sampling. 
index = np.arange(len(self), dtype=np.int32) self.np_random.shuffle(index) + train_index = index[: int(self.dataset_len * ratio)] test_index = index[int(self.dataset_len * ratio) :] @@ -143,53 +248,102 @@ def split(self, ratio: float) -> Tuple["Dataset", "Dataset"]: return Dataset(train_dataset_dict), Dataset(test_dataset_dict) def _trajectory_boundaries_and_returns(self) -> Tuple[list, list, list]: + """ + This method computes the boundaries of episodes in the dataset and calculates the returns for each episode. + It identifies the start and end indices of each episode based on the 'dones' array in the dataset. + The returns for each episode are calculated by summing the rewards within each episode. + This is useful for reinforcement learning tasks where episodes are defined by sequences of states, actions, and rewards. + Returns: + Tuple[list, list, list]: A tuple containing three lists: + - episode_starts: The starting indices of each episode. + - episode_ends: The ending indices of each episode. + - episode_returns: The total returns for each episode. + """ + + # Initialize lists (note plural) to store episode boundaries and returns. episode_starts = [0] episode_ends = [] + # Initialize variables to track the current episode return and a list to store returns. episode_return = 0 episode_returns = [] + # Iterate through the dataset to find episode boundaries and calculate returns. + # The dataset_dict is expected to have 'rewards' and 'dones' keys. for i in range(len(self)): episode_return += self.dataset_dict["rewards"][i] + # If the current index indicates the end of an episode, store the return and update boundaries. if self.dataset_dict["dones"][i]: episode_returns.append(episode_return) - episode_ends.append(i + 1) + episode_ends.append(i + 1) # Store the end index of the episode including the current index. + + # If this is not the last episode, set the start of the next episode. 
if i + 1 < len(self): episode_starts.append(i + 1) episode_return = 0.0 return episode_starts, episode_ends, episode_returns - def filter( - self, take_top: Optional[float] = None, threshold: Optional[float] = None + def filter(self, + take_top: Optional[float] = None, + threshold: Optional[float] = None ): + """ + Filter the dataset based on episode returns. This method allows you to keep only the episodes that meet a certain return threshold or are among the top returns. + This is useful for focusing on high-performing episodes in reinforcement learning tasks. + Args: + take_top (Optional[float]): If specified, keep only the top N percent of episodes based on their returns. + threshold (Optional[float]): If specified, keep only the episodes with returns greater than or equal to this value. + Raises: + AssertionError: If both take_top and threshold are specified, or if neither is specified. + """ assert (take_top is None and threshold is not None) or ( take_top is not None and threshold is None ) + # Create a tupe of lists of episode boundaries and returns. ( episode_starts, episode_ends, episode_returns, ) = self._trajectory_boundaries_and_returns() + # If no threshold is specified, calculate it based on the top N percent of returns. if take_top is not None: + # np.percentile gives the value below which XX% of the data lies threshold = np.percentile(episode_returns, 100 - take_top) + # create a boolean index array to filter episodes based on the threshold. bool_indx = np.full((len(self),), False, dtype=bool) for i in range(len(episode_returns)): if episode_returns[i] >= threshold: bool_indx[episode_starts[i] : episode_ends[i]] = True + # Return a new dataset dictionary containing only the episodes that meet the threshold. self.dataset_dict = _subselect(self.dataset_dict, bool_indx) + # Update the dataset length after filtering. 
self.dataset_len = _check_lengths(self.dataset_dict) def normalize_returns(self, scaling: float = 1000): + """ + Normalize the returns in the dataset to a specified scaling factor. This is useful for stabilizing training in reinforcement learning tasks. + Normally done per batch of episodes before training a model to update the policy. + + Args: + scaling (float): The scaling factor to normalize the returns. Default is 1000. + Raises: + AssertionError: If the dataset does not contain 'rewards' or 'dones' keys. + """ + + # Extract episode returns (_, _, episode_returns) = self._trajectory_boundaries_and_returns() - self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min( - episode_returns - ) + + # Normalize rewards in the dataset from 0-1 by dividing by the max range of returns. + self.dataset_dict["rewards"] /= np.max(episode_returns) - np.min(episode_returns) + + # Scale rewards. Note that large scaling factors can lead to numerical instability, small scaleing factors can lead to poor learning. + # Scaling allows you to control the learning dynamics. 
self.dataset_dict["rewards"] *= scaling diff --git a/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py new file mode 100644 index 00000000..449ebebe --- /dev/null +++ b/serl_launcher/serl_launcher/data/fractal_symmetry_replay_buffer.py @@ -0,0 +1,440 @@ +import copy +from typing import Iterable, Optional + +import gym +import numpy as np +from serl_launcher.data.dataset import DatasetDict, _sample +from serl_launcher.data.replay_buffer import ReplayBuffer +from flax.core import frozen_dict + +class FractalSymmetryReplayBuffer(ReplayBuffer): + def __init__( + self, + observation_space: gym.Space, + action_space: gym.Space, + capacity: int, + workspace_width: int, + x_obs_idx : np.ndarray, + y_obs_idx : np.ndarray, + branch_method: str, + split_method: str, + img_keys: list, + kwargs: dict, + ): + + # Initialize values + self.debug_time = True + self.current_branch_count = 1 + self.update_max_traj_length = False + self.workspace_width = workspace_width + self.img_keys = img_keys + self._img_insert_index_ = 0 + + # Set the idx value (changes depending on environment/wrapper) of the x and y observations and next_observations + self.x_obs_idx = x_obs_idx + self.y_obs_idx = y_obs_idx + + # Set initial fractal config values + self.timestep = 0 + self.current_depth = 0 + + self.split_method = split_method + self.branch_method = branch_method + + self._handle_methods_(kwargs) + + # Warn about unused kwargs + for k in kwargs.keys(): + print(f"\033[33mWARNING \033[0m argument \"{k}\" not used") + + # Account for images + self._num_stack = None + observation_space = copy.deepcopy(observation_space) + next_observation_space = None + if self.img_keys: + self.img_buffer = {} + next_observation_space_dict = copy.deepcopy(observation_space.spaces) + for k in img_keys: + img_obs_space = observation_space.spaces[k] + if self._num_stack is None: + self._num_stack = img_obs_space.shape[0] + 
self.img_buffer_size = ((self.expected_branches + capacity - 1) // self.expected_branches) * (self._num_stack + 1) + buffer_shape = list(img_obs_space.shape[1:]) + buffer_shape.insert(0, self.img_buffer_size) + self.img_buffer[k] = np.empty(buffer_shape, img_obs_space.dtype) + + observation_space.spaces[k] = gym.spaces.Box(low=float('-inf'), high=float('inf'), shape=(), dtype=np.int32) + next_observation_space_dict.pop(k) + next_observation_space = gym.spaces.Dict(next_observation_space_dict) + + + # Init replay buffer class + super().__init__( + observation_space=observation_space, + next_observation_space=next_observation_space, + action_space=action_space, + capacity=capacity * self.expected_branches, + ) + + self.generate_transform_deltas() + + def _handle_method_arg_(self, value, method_type, method, kwargs): + if hasattr(self, value): + return + assert value in kwargs.keys(), f"\033[31mERROR: \033[0m{value} must be defined for {method_type} \"{method}\"" + setattr(self, value, kwargs[value]) + del kwargs[value] + + def _handle_methods_(self, kwargs): + + # Initialize branch_method + match self.branch_method: + case "fractal": + self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) + + self.branch = self.fractal_branch + if not self.split_method: + self.split_method = "time" + self.expected_branches = (self.branching_factor ** self.max_depth) ** 2 + + case "contraction": + self._handle_method_arg_("max_depth", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("branching_factor", "branch_method", self.branch_method, kwargs) + + self.branch = self.fractal_contraction + if not self.split_method: + self.split_method = "time" + self.expected_branches = (self.branching_factor ** self.max_depth) ** 2 + + case "linear": + raise NotImplementedError("linear branch method is not yet implemented") + # self.branch = 
self.linear_branch + + + case "disassociated": + self._handle_method_arg_("min_branch_count", "branch_method", self.branch_method, kwargs) + self._handle_method_arg_("max_branch_count", "branch_method", self.branch_method, kwargs) + + if self.min_branch_count > self.max_branch_count: + raise ValueError(f"min_branch_count ({self.min_branch_count}) is larger than max_branch_count ({self.max_branch_count})") + + match kwargs["disassociated_type"]: + case "hourglass": + self.starting_branch_count = self.max_branch_count + case "octahedron": + self.starting_branch_count = self.min_branch_count + case _: + raise ValueError(f"incorrect value passed to disassociated_type") + + self.disassociated_type = kwargs["disassociated_type"] + del kwargs["disassociated_type"] + self.branch = self.disassociated_branch + if not self.split_method: + self.split_method = "time" + self.expected_branches = self.max_branch_count ** 2 + + case "constant": + self._handle_method_arg_("starting_branch_count", "branch_method", self.branch_method, kwargs) + + self.branch = self.constant_branch + if not self.split_method: + self.split_method = "never" + self.expected_branches = self.starting_branch_count ** 2 + + case _: + raise ValueError("incorrect value passed to branch_method") + + match self.split_method: + case "time": + self._handle_method_arg_("max_traj_length", "split_method", self.split_method, kwargs) + self._handle_method_arg_("alpha", "split_method", self.split_method, kwargs) + + self.update_max_traj_length = True + self.split = self.time_split + + case "constant": + self.split = self.constant_split + + case "never": + self.split = self.never_split + + case _: + raise ValueError("incorrect value passed to split_method") + + if hasattr(self, "starting_branch_count"): + self.current_branch_count = self.starting_branch_count + + def generate_transform_deltas(self): + + obs_state = self.dataset_dict["observations"] + if self.img_keys: + obs_state = 
self.dataset_dict["observations"]["state"] + + obs_size = obs_state.shape[-1] + total_branches = self.current_branch_count ** 2 + + self.transform_deltas = np.zeros(shape=(total_branches, obs_size), dtype=np.float32) + + idx = np.arange(total_branches) + x_deltas, y_deltas = np.divmod(idx, self.current_branch_count) + + x_deltas = (2 * x_deltas + 1) * self.workspace_width / (2 * self.current_branch_count) + y_deltas = (2 * y_deltas + 1) * self.workspace_width / (2 * self.current_branch_count) + x_deltas = np.repeat(x_deltas, self.x_obs_idx.size) + y_deltas = np.repeat(y_deltas, self.y_obs_idx.size) + x_deltas = np.reshape(x_deltas, (total_branches, self.x_obs_idx.size)) + y_deltas = np.reshape(y_deltas, (total_branches, self.y_obs_idx.size)) + + self.transform_deltas[..., self.x_obs_idx] = x_deltas + self.transform_deltas[..., self.y_obs_idx] = y_deltas + + if self._num_stack: + self.transform_deltas = np.expand_dims(self.transform_deltas, axis=1) + self.transform_deltas = np.repeat(self.transform_deltas, self._num_stack, axis=1) + + def fractal_branch(self): + ''' + Computes the number of branches for the current depth using an exponential growth rule. + + This method implements a "fractal branching" strategy, where the number of branches + increases exponentially with depth. At each depth `d`, the number of branches is calculated as: + + num_branches = branching_factor ** current_depth + + where: + - branching_factor: The base number of branches at each split. + - current_depth: The current depth in the fractal tree (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' + # return a new number of branches = branching_factor ^ depth + return self.branching_factor ** self.current_depth + + def fractal_contraction(self): + ''' + Computes the number of branches for the current depth using a contraction rule. 
+ + This method implements a "fractal contraction" branching strategy, where the number + of branches decreases exponentially with depth. At each depth `d`, the number of branches + is calculated as: + + num_branches = start_num / (branching_factor ** (d - 1)) + + where: + - start_num: The initial number of branches at depth 1. + - branching_factor: The factor by which the number of branches contracts at each depth. + - d: The current depth (self.current_depth). + + Returns: + int: The computed number of branches for the current depth. + ''' + + return self.branching_factor ** (self.max_depth - self.current_depth + 1) + + def constant_branch(self): + ''' + Used to create pure translations with no further branching. + self.current_branch_count used to set the total number of transformations. + ''' + # return current number of branches + return self.current_branch_count + + def disassociated_branch(self): + ''' + Used to create branches for disassociated fractal methods. + self.min_branch_count specifies the mininum branch count desired during the fractal rollout + self.max_branch_count specifies the maximum branch count desired during the fractal rollout + self.disassociated_type specifies whether to expand and then contract or to contract and then expand + self.steps_per_depth specifies the number of timesteps to take before splitting + (calculated indirectly via self.max_traj_length / self.num_depth_sectors) + self.num_depth_sectors specifies the number of sectors the rollout should be divided into for even splitting + ''' + if self.disassociated_type == "hourglass": + return int((self.max_branch_count - self.min_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.min_branch_count) + elif self.disassociated_type == "octahedron": + return int((self.min_branch_count - self.max_branch_count)/(self.max_depth/2) * np.abs(self.current_depth - (self.max_depth/2)) + self.max_branch_count) + + def linear_branch(self): + # return a new 
number of branches = branches_count + n + return self.current_branch_count + self.branching_factor + + def time_split(self, data_dict: DatasetDict): + if self.timestep % (self.max_traj_length//self.max_depth) or self.current_depth >= self.max_depth: + return False + self.current_depth += 1 + return True + + def constant_split(self, data_dict: DatasetDict): + self.current_depth += 1 + return True + + def never_split(self, data_dict: DatasetDict): + return False + + def insert_images(self, observation: dict): + for k in self.img_keys: + if self._num_stack: + self.img_buffer[k][self._img_insert_index_] = observation[k][0, ...] + else: + self.img_buffer[k][self._img_insert_index_] = observation[k] + self._img_insert_index_ = (self._img_insert_index_ + 1) % self.img_buffer_size + + def insert(self, data: DatasetDict): + + data_dict = copy.deepcopy(data) + + if self.img_keys: + obs = data_dict["observations"]["state"] + n_obs = data_dict["next_observations"]["state"] + else: + obs = data_dict["observations"] + n_obs = data_dict["next_observations"] + + actions = data_dict["actions"] + rewards = data_dict["rewards"] + masks = data_dict["masks"] + dones = data_dict["dones"] + + # Update number of branches if needed + if self.split(data_dict): + temp = self.current_branch_count + self.current_branch_count = self.branch() + # Update transform_deltas if needed + if temp != self.current_branch_count: + self.generate_transform_deltas() + + # Initialize to extreme x and y + base_diff = -self.workspace_width/2 + obs[..., self.x_obs_idx] += base_diff + obs[..., self.y_obs_idx] += base_diff + n_obs[..., self.x_obs_idx] += base_diff + n_obs[..., self.y_obs_idx] += base_diff + + # Transform transitions + num_transforms = self.current_branch_count ** 2 + + obs_shape = np.ones(len(obs.shape) + 1, dtype=int) + obs_shape[0] = num_transforms + obs = np.tile(obs, obs_shape) + n_obs = np.tile(n_obs, obs_shape) + actions = np.tile(actions, (num_transforms, 1)) + rewards = np.tile(rewards, 
num_transforms) + masks = np.tile(masks, num_transforms) + dones = np.tile(dones, num_transforms) + + obs += self.transform_deltas + n_obs += self.transform_deltas + + # Insert images + if self.img_keys: + if self.timestep == 0: + for i in range(self._num_stack): + self.insert_images(data_dict["observations"]) + self.insert_images(data_dict["next_observations"]) + + for k in self.img_keys: + data_dict["observations"][k] = (self._img_insert_index_ - 1) % len(self.img_buffer[k]) + data_dict["observations"][k] = np.tile(data_dict["observations"][k], num_transforms) + data_dict["next_observations"].pop(k) + + # Pack back into dictionary and insert + if self.img_keys: + data_dict["observations"]["state"] = obs + data_dict["next_observations"]["state"] = n_obs + else: + data_dict["observations"] = obs + data_dict["next_observations"] = n_obs + + data_dict["actions"] = actions + data_dict["rewards"] = rewards + data_dict["masks"] = masks + data_dict["dones"] = dones + + super().insert(data_dict, batch_size=num_transforms) + + # Reset current_depth, timestep, and max_traj_length + self.timestep += 1 + if data_dict["dones"][0]: + self.current_depth = 0 + if self.update_max_traj_length: + self.max_traj_length = int(self.timestep * self.alpha + self.max_traj_length * (1 - self.alpha)) + self.timestep = 0 + + def sample( + self, batch_size: int, keys: Optional[Iterable[str]] = None, indx: Optional[np.ndarray] = None, pack_obs_and_next_obs: bool = False, + ) -> frozen_dict.FrozenDict: + """Samples from the replay buffer. + + Args: + batch_size: Minibatch size. + keys: Keys to sample. + indx: Take indices instead of sampling. + pack_obs_and_next_obs: whether to pack img and next_img into one image. + It's useful when they have overlapping frames. + + Returns: + A frozen dictionary. 
+ """ + # If no images, sample normally + if not self.img_keys: + return super().sample(batch_size, keys, indx) + + # Generate random indexes for sampling + if indx is None: + if hasattr(self.np_random, "integers"): + indx = self.np_random.integers(len(self), size=batch_size) + else: + indx = self.np_random.randint(len(self), size=batch_size) + + for i in range(batch_size): + while indx[i] >= self._size: + if hasattr(self.np_random, "integers"): + indx[i] = self.np_random.integers(len(self)) + else: + indx[i] = self.np_random.randint(len(self)) + else: + raise NotImplementedError() + + # Sample w/o images + if keys is None: + keys = self.dataset_dict.keys() + else: + assert "observations" in keys + + keys = list(keys) + keys.remove("observations") + + batch = super().sample(batch_size, keys, indx) + batch = batch.unfreeze() + + obs_keys = self.dataset_dict["observations"].keys() + obs_keys = list(obs_keys) + for k in self.img_keys: + obs_keys.remove(k) + + batch["observations"] = {} + for k in obs_keys: + batch["observations"][k] = _sample( + self.dataset_dict["observations"][k], indx + ) + + # Sample images + for k in self.img_keys: + obs_imgs = self.img_buffer[k] + obs_imgs = np.lib.stride_tricks.sliding_window_view( + obs_imgs, self._num_stack + 1, axis=0 + ) + obs_imgs = obs_imgs[self.dataset_dict["observations"][k][indx] - self._num_stack] + # transpose from (B, H, W, C, T) to (B, T, H, W, C) to follow jaxrl_m convention + obs_imgs = obs_imgs.transpose((0, 4, 1, 2, 3)) + + if pack_obs_and_next_obs: + batch["observations"][k] = obs_imgs + else: + batch["observations"][k] = obs_imgs[:, :-1, ...] + if "next_observations" in keys: + batch["next_observations"][k] = obs_imgs[:, 1:, ...] 
+ + return frozen_dict.freeze(batch) diff --git a/serl_launcher/serl_launcher/data/fsrb_test.py b/serl_launcher/serl_launcher/data/fsrb_test.py new file mode 100644 index 00000000..f6b9e256 --- /dev/null +++ b/serl_launcher/serl_launcher/data/fsrb_test.py @@ -0,0 +1,306 @@ +import gym.wrappers +import numpy as np +import gym +from serl_launcher.utils.launcher import make_replay_buffer +from serl_launcher.data.fractal_symmetry_replay_buffer import FractalSymmetryReplayBuffer +from serl_launcher.wrappers.serl_obs_wrappers import SERLObsWrapper +from serl_launcher.wrappers.chunking import ChunkingWrapper +from absl import app, flags +import franka_sim +# import pandas as pd + +FLAGS = flags.FLAGS + +flags.DEFINE_integer("capacity", 10, "Replay buffer capacity.") +flags.DEFINE_string("branch_method", "constant", "Method for determining the number of transforms per dimension (x,y)") +flags.DEFINE_string("split_method", "never", "Method for determining whether to change the number of transforms per dimension (x,y)") +flags.DEFINE_float("workspace_width", 0.5, "workspace width in meters") +flags.DEFINE_integer("max_depth", 4, "Maximum level of depth") # For fractal_branch only +flags.DEFINE_integer("max_steps",100,"Maximum steps") +flags.DEFINE_integer("branching_factor", 3, "Rate of change of number of transforms per dimension (x,y)") # For fractal_branch only +flags.DEFINE_integer("starting_branch_count", 1, "Initial number of transforms per dimension (x,y)") # For constant_branch only +flags.DEFINE_integer("alpha",1,"alpha value") +# Density Workspace width +flags.DEFINE_string("workspace_width_method",'increase', 'Controls workspace width dimensions configurations') + +def main(_): + + x_obs_idx = np.array([0, 4]) + y_obs_idx = np.array([1, 5]) + + # Initialize replay buffer + env = gym.make("PandaPickCubeVision-v0") + env = SERLObsWrapper(env) + env = ChunkingWrapper(env, obs_horizon=3, act_exec_horizon=None) + + # env = gym.make("PandaReachCube-v0") + # env = 
gym.wrappers.FlattenObservation(env) + + image_keys = [key for key in env.observation_space.keys() if key != "state"] + + their_buffer = make_replay_buffer( + env, + type="memory_efficient_replay_buffer", + capacity=FLAGS.capacity, + image_keys=image_keys, + + ) + + replay_buffer = make_replay_buffer( + env, + type="fractal_symmetry_replay_buffer", + capacity=FLAGS.capacity, + split_method=FLAGS.split_method, + branch_method=FLAGS.branch_method, + workspace_width=FLAGS.workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx= y_obs_idx, + image_keys=image_keys, + # max_depth=FLAGS.max_depth, + max_traj_length = 100, + # branching_factor=FLAGS.branching_factor, + # alpha = FLAGS.alpha, + starting_branch_count = FLAGS.starting_branch_count, + ) + + observation, info = env.reset() + action = env.action_space.sample() + next_observation, reward, terminated, truncated, info = env.step(action) + + # observation = np.zeros_like(observation) + for k in observation.keys(): + observation[k] = np.zeros_like(observation[k]) + next_observation[k] = np.ones_like(next_observation[k]) + + action = np.ones_like(action) + # next_observation = np.ones_like(next_observation) + reward = 1 + + data_dict = dict( + observations=observation, + next_observations=next_observation, + actions=action, + rewards=reward, + masks=not truncated and not terminated, + dones=truncated or terminated, + ) + + del env, observation, next_observation, action, reward, truncated, terminated, info, y_obs_idx, x_obs_idx, _ + + for i in range(6): + + replay_buffer.insert(data_dict) + their_buffer.insert(data_dict) + assert(replay_buffer.dataset_dict["observations"]["state"][i % 10].all() == their_buffer.dataset_dict["observations"]["state"][(i + 3) % 10].all()) + assert(replay_buffer.dataset_dict["next_observations"]["state"][i % 10].all() == their_buffer.dataset_dict["next_observations"]["state"][(i + 3) % 10].all()) + assert(replay_buffer.img_buffer["front"][replay_buffer.dataset_dict["observations"]["front"][i % 
10]].all() == their_buffer.dataset_dict["observations"]["front"][(i + 3) % 10].all()) + + data_dict["observations"]["state"] += 1 + data_dict["next_observations"]["state"] += 1 + for k in image_keys: + data_dict["observations"][k] += 1 + data_dict["next_observations"][k] += 1 + + replay_buffer.sample(batch_size=3, indx=np.array([2,3,4])) + their_buffer.sample(batch_size=3, indx=np.array([5,6,7])) + + + + # branch() tests + + #------------------------------------------------------------------- + # Fractal Associative Expansions + #------------------------------------------------------------------- + replay_buffer.branching_factor = 3 + + replay_buffer.current_depth = 1 + result = replay_buffer.fractal_branch() + expected = 3 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 2 + result = replay_buffer.fractal_branch() + expected = 9 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 3 + result = replay_buffer.fractal_branch() + expected = 27 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 4 + result = replay_buffer.fractal_branch() + expected = 81 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 0 + + del result, expected + + print("\033[32mTEST PASSED \033[0m fractal_branch() tests passed") + + #------------------------------------------------------------------- + # Fractal Associative Contractions + #------------------------------------------------------------------- + replay_buffer.branching_factor = 3 + + replay_buffer.current_depth = 1 + result = replay_buffer.fractal_contraction() + expected = 81 + assert result == 
expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 2 + result = replay_buffer.fractal_contraction() + expected = 27 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 3 + result = replay_buffer.fractal_contraction() + expected = 9 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 4 + result = replay_buffer.fractal_contraction() + expected = 3 + assert result == expected, f"\033[31mTEST FAILED\033[0m fractal_branch() test failed (expected {expected} but got {result})" + + replay_buffer.current_depth = 0 + + del result, expected + + print("\033[32mTEST PASSED \033[0m fractal_contraction() tests passed") + + #------------------------------------------------------------------- + # split() tests + #------------------------------------------------------------------- + + ## time + replay_buffer.max_steps = 100 + replay_buffer.max_depth = 4 + + replay_buffer.timestep = 0 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 25 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 50 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + replay_buffer.timestep = 75 + result = replay_buffer.time_split(data_dict) + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} 
but got {result})" + + + replay_buffer.timestep = 100 + result = replay_buffer.time_split(data_dict) + expected = False + assert result == expected, f"\033[31mTEST FAILED\033[0m split() test failed (expected {expected} but got {result})" + + del result, expected + + print("\033[32mTEST PASSED \033[0m time_split() test passed") + + # insert() tests + # insert() tests + initial_size = len(replay_buffer.dataset_dict['observations'][0]) * replay_buffer._insert_index % len(replay_buffer.dataset_dict['observations']) + + replay_buffer.insert(data_dict) + final_size = len(replay_buffer.dataset_dict['observations'][0]) * replay_buffer._insert_index % len(replay_buffer.dataset_dict['observations']) + + result = final_size > initial_size + expected = True + assert result == expected, f"\033[31mTEST FAILED\033[0m insert() test failed (expected buffer size to increase from {initial_size} to {final_size})" + del result, expected, initial_size, final_size + + + print("\033[32mTEST PASSED \033[0m insert() tests passed") + + #------------------------------------------------------------------- + # Fractal Expansions with workspace_width_modification + #------------------------------------------------------------------- + print('\nWorkspace width tests....') + + replay_buffer.branching_factor = 3 + replay_buffer.current_depth = 1 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = 
replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + + replay_buffer.current_depth = 2 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + + replay_buffer.current_depth = 3 + + if FLAGS.workspace_width_method == 'constant': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'decrease': + result = replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width - 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + elif FLAGS.workspace_width_method == 'increase': + result = 
replay_buffer.get_workspace_width() + expected = FLAGS.workspace_width + 0.05*replay_buffer.current_depth + assert result == expected, f"\033[31mTEST FAILED\033[0m get_workspace_width() test failed (expected {expected} but got {result})" + + else: + raise NameError('There is no workspace width method with that name.') + + replay_buffer.current_depth = 0 + + del result, expected + + print("\n\033[32mTEST PASSED \033[0m workspace_width_method() test passed") + + + print("\nfinished!\n") + + + +if __name__ == "__main__": + app.run(main) \ No newline at end of file diff --git a/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py b/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py index d94f1143..c37fcf5f 100644 --- a/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py +++ b/serl_launcher/serl_launcher/data/memory_efficient_replay_buffer.py @@ -18,7 +18,6 @@ def __init__( pixel_keys: Tuple[str, ...] = ("pixels",), ): self.pixel_keys = pixel_keys - observation_space = copy.deepcopy(observation_space) self._num_stack = None for pixel_key in self.pixel_keys: diff --git a/serl_launcher/serl_launcher/data/replay_buffer.py b/serl_launcher/serl_launcher/data/replay_buffer.py index f7d798a2..0a1ab414 100644 --- a/serl_launcher/serl_launcher/data/replay_buffer.py +++ b/serl_launcher/serl_launcher/data/replay_buffer.py @@ -22,17 +22,24 @@ def _init_replay_dict( def _insert_recursively( - dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int + dataset_dict: DatasetDict, data_dict: DatasetDict, insert_index: int, capacity: int, batch_size: int = None, ): if isinstance(dataset_dict, np.ndarray): - dataset_dict[insert_index] = data_dict + if batch_size: + if insert_index + batch_size > capacity: + dataset_dict[insert_index:capacity] = data_dict[0:(capacity - insert_index)] + dataset_dict[0:(insert_index + batch_size - capacity)] = data_dict[(capacity - insert_index):batch_size] + else: + 
dataset_dict[insert_index:(insert_index + batch_size)] = data_dict + else: + dataset_dict[insert_index] = data_dict elif isinstance(dataset_dict, dict): assert dataset_dict.keys() == data_dict.keys(), ( dataset_dict.keys(), data_dict.keys(), ) for k in dataset_dict.keys(): - _insert_recursively(dataset_dict[k], data_dict[k], insert_index) + _insert_recursively(dataset_dict[k], data_dict[k], insert_index, capacity, batch_size) else: raise TypeError() @@ -68,15 +75,20 @@ def __init__( def __len__(self) -> int: return self._size - def insert(self, data_dict: DatasetDict): - _insert_recursively(self.dataset_dict, data_dict, self._insert_index) + def insert(self, data_dict: DatasetDict, batch_size : int = None): + _insert_recursively(self.dataset_dict, data_dict, self._insert_index, self._capacity, batch_size) - self._insert_index = (self._insert_index + 1) % self._capacity - self._size = min(self._size + 1, self._capacity) + if batch_size: + self._insert_index = (self._insert_index + batch_size) % self._capacity + self._size = min(self._size + batch_size, self._capacity) + else: + self._insert_index = (self._insert_index + 1) % self._capacity + self._size = min(self._size + 1, self._capacity) def get_iterator(self, queue_size: int = 2, sample_args: dict = {}, device=None): # See https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device - # queue_size = 2 should be ok for one GPU. + + # queue_size = 2 should be ok for one GPU. 
See more at https://chatgpt.com/share/687af063-d6b0-8004-92b6-0e88b9c5f1e8 queue = collections.deque() def enqueue(n): diff --git a/serl_launcher/serl_launcher/networks/actor_critic_nets.py b/serl_launcher/serl_launcher/networks/actor_critic_nets.py index 189eef82..d0bb07ca 100644 --- a/serl_launcher/serl_launcher/networks/actor_critic_nets.py +++ b/serl_launcher/serl_launcher/networks/actor_critic_nets.py @@ -13,24 +13,68 @@ class ValueCritic(nn.Module): + """State-value network that predicts ``V(s)`` from observations. + + In actor-critic RL, a value critic estimates how good a state is on average, + independent of a specific action. This module maps observations to a single + scalar per sample: + - input: observation tensor(s) + - output: one value estimate per observation (shape ``[batch]``) + + Architecture used here: + 1. ``encoder`` transforms raw observations into learned features. + 2. ``network`` processes those features into a hidden representation. + 3. A final linear layer (``Dense(1)``) produces the scalar value. + + Initialization behavior: + - If ``init_final`` is provided, the final layer kernel is initialized + uniformly in ``[-init_final, init_final]`` for tighter initial output scale. + - Otherwise, the project default initializer is used. + + Notes for students: + - ``train`` is passed through so submodules can switch behavior (for example, + dropout/batch-norm if those are used in ``encoder``/``network``). + - The returned tensor is ``squeeze``d on the last axis, converting + ``[..., 1]`` to ``[...]`` for convenience in losses. + """ + # Encodes raw observations into feature vectors. encoder: nn.Module + # Backbone network that maps encoded features to a critic hidden representation. network: nn.Module + # Optional bound for uniform init of the final value head weights. init_final: Optional[float] = None + # Define submodules inline in __call__; Flax registers params automatically. 
@nn.compact def __call__(self, observations: jnp.ndarray, train: bool = False) -> jnp.ndarray: outputs = self.network(self.encoder(observations), train=train) if self.init_final is not None: + # Use a narrow, explicit final-layer init when requested. value = nn.Dense( 1, kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), )(outputs) else: + # Fall back to the default project initializer. value = nn.Dense(1, kernel_init=default_init())(outputs) return jnp.squeeze(value, -1) def multiple_action_q_function(forward): + """Decorator that makes a critic forward pass support multiple actions/state. + + Args: + forward: The original critic method (typically ``__call__``) that maps + ``(self, observations, actions, **kwargs)`` to Q-values for one + action per state. + + Returns: + A wrapped function that: + - calls ``forward`` directly when ``actions`` is 2D (single action/state) + - vmaps ``forward`` over the action axis when ``actions`` is 3D + (multiple candidate actions/state), returning stacked Q-values. + vmap ensures that ``forward`` is evaluated once per candidate action along that axis. + """ # Forward the q function with multiple actions on each state, to be used as a decorator def wrapped(self, observations, actions, **kwargs): if jnp.ndim(actions) == 3: @@ -51,6 +95,7 @@ class Critic(nn.Module): network: nn.Module init_final: Optional[float] = None + # Define submodules inline in __call__; Flax registers params automatically. @nn.compact @multiple_action_q_function def __call__( @@ -74,6 +119,16 @@ def __call__( class DistributionalCritic(nn.Module): + """Distributional variant of ``Critic`` that predicts a value distribution. + + Unlike ``Critic`` (which outputs one scalar Q-value per state-action pair), + this module outputs: + - ``logits`` over ``num_atoms`` support points + - the corresponding ``atoms`` linearly spaced in ``[q_low, q_high]`` + + The expected Q-value can be recovered downstream from these distributional + outputs, while training can use distributional RL losses. 
+ """ encoder: Optional[nn.Module] network: nn.Module q_low: float @@ -81,6 +136,7 @@ class DistributionalCritic(nn.Module): num_atoms: int = 51 init_final: Optional[float] = None + # Define submodules inline in __call__; Flax registers params automatically. @nn.compact def __call__( self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False @@ -95,7 +151,7 @@ def __call__( if self.init_final is not None: logits = nn.Dense( self.num_atoms, - kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), + kernel_init=nn.initializers.uniform(-self.init_final, self.init_final), # uniform weight initialization. gives a strict bound on weights vs gaussian. early atom logits/Q outputs small for stable training. )(outputs) else: logits = nn.Dense(self.num_atoms, kernel_init=default_init())(outputs) @@ -107,6 +163,18 @@ def __call__( class ContrastiveCritic(nn.Module): + """Contrastive value model using state-action and goal embeddings. + + Unlike ``Critic`` (scalar Q per state-action) and ``DistributionalCritic`` + (distribution over fixed value atoms), this module does not predict Q + directly. It learns embeddings for: + - ``(state, action)`` via ``sa_net`` + - ``goal`` via ``g_net`` + + It returns their pairwise similarity matrix (dot products). With + ``twin_q=True``, a second independent similarity head is stacked on the + last axis. + """ encoder: nn.Module sa_net: nn.Module g_net: nn.Module @@ -116,6 +184,7 @@ class ContrastiveCritic(nn.Module): g_net2: Optional[nn.Module] = None init_final: Optional[float] = None + # Define submodules inline in __call__; Flax registers params automatically. @nn.compact def __call__( self, observations: jnp.ndarray, actions: jnp.ndarray, train: bool = False @@ -165,6 +234,50 @@ def ensemblize(cls, num_qs, out_axes=0): class Policy(nn.Module): + """Gaussian actor that outputs an action distribution from observations. 
+ + This is the stochastic policy used by actor-critic methods (for example, + SAC-style training). Given an observation, the module predicts: + - ``means``: the center of the action distribution for each action dimension + - ``stds``: the spread (uncertainty / exploration scale) per action dimension + + The policy returns a distribution object, not a sampled action. Downstream + code can then sample actions, compute log-probabilities, or evaluate entropy. + + High-level flow: + 1. Optionally encode observations with ``encoder``. + 2. Pass features through ``network``. + 3. Predict action means with a linear layer. + 4. Compute standard deviations using one of ``std_parameterization`` modes. + 5. Clip stds to ``[std_min, std_max]`` and scale by ``sqrt(temperature)``. + 6. Return either a Gaussian distribution or a tanh-squashed Gaussian. + + Std parameterization modes: + - ``"exp"``: + Predict ``log_stds`` from the network and set ``stds = exp(log_stds)``. + Common choice; ensures positivity while letting the network adapt stds by + state. + - ``"softplus"``: + Predict unconstrained values and map with ``softplus`` for positive stds. + Similar goal to ``exp`` but with different gradients/saturation behavior. + - ``"uniform"``: + Use one learned vector ``log_stds`` shared across all states (state- + independent std), then exponentiate. + - ``"fixed"``: + Do not learn std from this module; use ``fixed_std`` provided externally. + + Action bounds: + - If ``tanh_squash_distribution=False``, output is an unconstrained Gaussian + (``distrax.MultivariateNormalDiag``). + - If ``True``, output is transformed through tanh + (``TanhMultivariateNormalDiag``), commonly used for bounded actions. + + Notes for students: + - ``init_final`` is currently a config field but is not used in this class's + final layers. + - ``temperature`` controls exploration scale at runtime by multiplying std by + ``sqrt(temperature)``. Larger temperature => broader sampling. 
+ """ encoder: Optional[nn.Module] network: nn.Module action_dim: int @@ -175,6 +288,7 @@ class Policy(nn.Module): tanh_squash_distribution: bool = False fixed_std: Optional[jnp.ndarray] = None + # Define submodules inline in __call__; Flax registers params automatically. @nn.compact def __call__( self, observations: jnp.ndarray, temperature: float = 1.0, train: bool = False @@ -228,6 +342,33 @@ def __call__( class TanhMultivariateNormalDiag(distrax.Transformed): + """Diagonal Gaussian policy transformed by tanh (and optional rescaling). + + This class wraps a base multivariate normal distribution and applies a + bijective transform to produce bounded actions. It is commonly used in + continuous-control RL because: + - a Gaussian is easy to optimize in latent action space + - tanh squashing keeps sampled actions bounded + - log-probabilities remain correct through change-of-variables + + Conceptually, it does: + 1. Sample latent action ``u ~ Normal(loc, scale_diag)``. + 2. Squash with ``tanh`` to get values in ``(-1, 1)``. + 3. If ``low`` and ``high`` are provided, map from ``(-1, 1)`` into + ``[low, high]`` elementwise. + + Because this is a ``distrax.Transformed`` distribution, ``log_prob`` and + sampling automatically account for the bijector Jacobian (including the + custom affine rescale Jacobian when bounds are provided). + + Args: + loc: Mean vector of the base diagonal Gaussian. + scale_diag: Per-dimension standard deviations of the base Gaussian. + low: Optional lower action bounds. Must broadcast with action shape. + high: Optional upper action bounds. Must broadcast with action shape. + If either ``low`` or ``high`` is missing, no extra rescaling is + applied and outputs stay in ``(-1, 1)`` after tanh. 
+ """ def __init__( self, loc: jnp.ndarray, diff --git a/serl_launcher/serl_launcher/networks/reward_classifier.py b/serl_launcher/serl_launcher/networks/reward_classifier.py index b05e196c..6b60f17c 100644 --- a/serl_launcher/serl_launcher/networks/reward_classifier.py +++ b/serl_launcher/serl_launcher/networks/reward_classifier.py @@ -67,7 +67,7 @@ def create_classifier( with open(pretrained_encoder_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) diff --git a/serl_launcher/serl_launcher/utils/launcher.py b/serl_launcher/serl_launcher/utils/launcher.py index 782221eb..d9cdf5f4 100644 --- a/serl_launcher/serl_launcher/utils/launcher.py +++ b/serl_launcher/serl_launcher/utils/launcher.py @@ -18,6 +18,7 @@ from serl_launcher.data.data_store import ( MemoryEfficientReplayBufferDataStore, ReplayBufferDataStore, + FractalSymmetryReplayBufferDataStore, ) ############################################################################## @@ -179,13 +180,17 @@ def make_trainer_config(port_number: int = 5488, broadcast_port: int = 5489): def make_wandb_logger( project: str = "agentlace", + name: str = "placeholder_run_name", description: str = "serl_launcher", + wandb_output_dir: str = None, debug: bool = False, + offline: bool = False, ): wandb_config = WandBLogger.get_default_config() wandb_config.update( { "project": project, + "name": name, "exp_descriptor": description, "tag": description, } @@ -193,7 +198,9 @@ def make_wandb_logger( wandb_logger = WandBLogger( wandb_config=wandb_config, variant={}, + wandb_output_dir=wandb_output_dir, debug=debug, + offline=offline ) return wandb_logger @@ -206,6 +213,13 @@ def make_replay_buffer( image_keys: list = [], # used only type=="memory_efficient_replay_buffer" preload_rlds_path: Optional[str] = None, 
preload_data_transform: Optional[callable] = None, + branch_method: str = None, # used only type=="fractal_symmetry_replay_buffer" + split_method : str = None, # used only type=="fractal_symmetry_replay_buffer" + workspace_width : float = None, # used only type=="fractal_symmetry_replay_buffer" + workspace_width_method : str = None, # used only type=="fractal_symmetry_replay_buffer" + x_obs_idx = None, + y_obs_idx = None, + **kwargs: dict # used only type=="fractal_symmetry_replay_buffer" ): """ This is the high-level helper function to @@ -215,7 +229,7 @@ def make_replay_buffer( - env: gym or gymasium environment - capacity: capacity of the replay buffer - rlds_logger_path: path to save RLDS logs - - type: support only for "replay_buffer" and "memory_efficient_replay_buffer" + - type: support only for "replay_buffer", "memory_efficient_replay_buffer", and "fractal_symmetry_replay_buffer" - image_keys: list of image keys, used only "memory_efficient_replay_buffer" - preload_rlds_path: path to preloaded RLDS trajectories - preload_data_transform: data transformation function for preloaded RLDS data @@ -254,9 +268,31 @@ def make_replay_buffer( rlds_logger=rlds_logger, image_keys=image_keys, ) + elif type == "fractal_symmetry_replay_buffer": + replay_buffer = FractalSymmetryReplayBufferDataStore( + env.observation_space, + env.action_space, + capacity=capacity, + branch_method=branch_method, + split_method=split_method, + workspace_width=workspace_width, + x_obs_idx=x_obs_idx, + y_obs_idx=y_obs_idx, + rlds_logger=rlds_logger, + image_keys=image_keys, + kwargs=kwargs, + ) + else: raise ValueError(f"Unsupported replay_buffer_type: {type}") + # Load RLDS or oxe_envlogger recorded data with tfds.builder_from_directory. 
+ # Choose number of episodes by passing split="train[:N%]" or split="test[:N%]" + # or ds = tfds.builder_from_directory(builder_dir).as_dataset(split="train").take(5) + # See more details: https://www.tensorflow.org/datasets/splits + # + # It's also possible to filter specific episodes, e.g. by time: + # ds = ds.filter(lambda ep: tf.strings.regex_full_match(ep['some/session_id'], "20250821_222412")) if preload_rlds_path: print(f" - Preloaded {preload_rlds_path} to replay buffer") dataset = tfds.builder_from_directory(preload_rlds_path).as_dataset(split="all") diff --git a/serl_launcher/serl_launcher/utils/nAUC_computation.py b/serl_launcher/serl_launcher/utils/nAUC_computation.py new file mode 100644 index 00000000..0d632cdc --- /dev/null +++ b/serl_launcher/serl_launcher/utils/nAUC_computation.py @@ -0,0 +1,77 @@ +import wandb +import numpy as np +import pandas as pd +from scipy.integrate import trapezoid as trap +import matplotlib.pyplot as plt + +# --- Pulling run data from wandb --- + +# Accessing API +api = wandb.Api() + +# Pulling data for specified project and group +good_runs = api.runs("lipscomb-robotics/PegInsert_results", {"group":"ours"}) +baseline_runs = api.runs("lipscomb-robotics/PegInsert_results", {"group":"theirs"}) +# Specifying data to consider - wall_time instead of _step is likely what we need for this +good_hist = [r.history(keys=["_step", "success_rate", "_runtime"]) for r in good_runs] +#good_hist["replay_buffer_type"] = r.config.get("replay_buffer_type") + +baseline_hist = [r.history(keys=["_step", "success_rate", "_runtime"]) for r in baseline_runs] + +# Labeling data by run number +for i,df in enumerate(good_hist): df["run"]=i +for i,df in enumerate(baseline_hist): df["run"]=i + +# Building dataframe of all runs +good_df = pd.concat(good_hist) +baseline_df = pd.concat(baseline_hist) + +# Merging all runs +good_time_sorted = good_df.sort_values("_step").reset_index(drop=True) +baseline_time_sorted = 
baseline_df.sort_values("_step").reset_index(drop=True) + +# Smoothing runs +good_time_avg = good_df.groupby('_step', as_index=False)['success_rate'].mean() +baseline_time_avg = baseline_df.groupby('_step', as_index=False)['success_rate'].mean() +good_time_avg['group']='ours' +baseline_time_avg['group']='theirs' + +# Checking work +good_time_avg.plot(x='_step', y='success_rate') +plt.savefig('Ours - PegInsert Results') +baseline_time_avg.plot(x='_step', y='success_rate') +plt.savefig('Baseline - PegInsert Results') + +# --- Calculating integrals from run data --- + +# Combining run dfs +combined_df = pd.concat([good_time_avg, baseline_time_avg]) + +# Creating blank dictionaries to group run data and integrals +combined_rundata_arrays = {} +combined_integrals = {} +combined_runlengths = {} +combined_maxrewards = {} +combined_normalizingfactor = {} +combined_integrals_normalized = {} + +# Computing run integrals + +for run_label, group in combined_df.groupby("group"): + x = group["_step"].to_numpy() # Uncomment to run by steps. Comment to run by time + #x = group["_runtime"].to_numpy() # Comment line to run by steps. 
Uncomment to run by time + y = group["success_rate"].to_numpy() + combined_rundata_arrays[run_label] = (x,y) + combined_integrals[run_label] = trap(y,x) + combined_runlengths[run_label] = float(x[-1]) + combined_maxrewards[run_label] = float(max(y)) + combined_normalizingfactor[run_label] = combined_runlengths[run_label] * combined_maxrewards[run_label] + combined_integrals_normalized[run_label] = float(combined_integrals[run_label] / combined_normalizingfactor[run_label]) + +pd.DataFrame( + combined_integrals_normalized.items(), + columns=["group", "normalized_integral"] +).to_csv( + "PegInsert Results - Normalized Integrals.csv", + index=False +) \ No newline at end of file diff --git a/serl_launcher/serl_launcher/utils/train_utils.py b/serl_launcher/serl_launcher/utils/train_utils.py index 31037317..dd8816a7 100644 --- a/serl_launcher/serl_launcher/utils/train_utils.py +++ b/serl_launcher/serl_launcher/utils/train_utils.py @@ -108,7 +108,7 @@ def load_resnet10_params(agent, image_keys=("image",), public=True): with open(file_path, "rb") as f: encoder_params = pkl.load(f) - param_count = sum(x.size for x in jax.tree_leaves(encoder_params)) + param_count = sum(x.size for x in jax.tree.leaves(encoder_params)) print( f"Loaded {param_count/1e6}M parameters from ResNet-10 pretrained on ImageNet-1K" ) diff --git a/serl_launcher/serl_launcher/utils/upload_eval_from_csv.py b/serl_launcher/serl_launcher/utils/upload_eval_from_csv.py new file mode 100644 index 00000000..f4c6df78 --- /dev/null +++ b/serl_launcher/serl_launcher/utils/upload_eval_from_csv.py @@ -0,0 +1,49 @@ +import wandb +import pandas as pd +import sys + + +def upload_runs_from_csv(filename, project_name): + df = pd.read_csv(filename) + + # Group by Name + grouped = df.groupby("Name") + + for run_name, group in grouped: + # Optional: sort if order matters (e.g., by step or time) + group = group.reset_index(drop=True) + + with wandb.init( + project=project_name, + name=str(run_name), + reinit=True + ) as 
run: + + for step_idx, (_, row) in enumerate(group.iterrows()): + row_dict = row.to_dict() + + log_data = { + "success_rate": row_dict.get("success_rate"), + "Steps": row_dict.get("Steps"), + "Time (min)": row_dict.get("Time (min)"), + "Time (sec)": row_dict.get("Time (sec)") + } + + # Log each row as a step in the same run + wandb.log(log_data, step=step_idx) + + + +def main(): + if len(sys.argv) != 3: + print("Usage: python upload_eval_from_csv.py ") + return + + filename = sys.argv[1] + project_name = sys.argv[2] + + + upload_runs_from_csv(filename, project_name) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/serl_launcher/serl_launcher/vision/data_augmentations.py b/serl_launcher/serl_launcher/vision/data_augmentations.py index 2c2440fa..f455bbf1 100644 --- a/serl_launcher/serl_launcher/vision/data_augmentations.py +++ b/serl_launcher/serl_launcher/vision/data_augmentations.py @@ -169,7 +169,7 @@ def hsv_to_rgb(h, s, v): def adjust_brightness(rgb_tuple, delta): - return jax.tree_map(lambda x: x + delta, rgb_tuple) + return jax.tree.map(lambda x: x + delta, rgb_tuple) def adjust_contrast(image, factor): @@ -177,7 +177,7 @@ def _adjust_contrast_channel(channel): mean = jnp.mean(channel, axis=(-2, -1), keepdims=True) return factor * (channel - mean) + mean - return jax.tree_map(_adjust_contrast_channel, image) + return jax.tree.map(_adjust_contrast_channel, image) def adjust_saturation(h, s, v, factor): @@ -256,7 +256,7 @@ def identity_fn(x, unused_rng, unused_param): def cond_fn(args, i): def clip(args): - return jax.tree_map(lambda arg: jnp.clip(arg, 0.0, 1.0), args) + return jax.tree.map(lambda arg: jnp.clip(arg, 0.0, 1.0), args) out = jax.lax.cond( should_apply & should_apply_color & (i == idx), @@ -275,7 +275,7 @@ def clip(args): random_hue_cond = _make_cond(_random_hue, idx=3) def _color_jitter(x): - rgb_tuple = tuple(jax.tree_map(jnp.squeeze, jnp.split(x, 3, axis=-1))) + rgb_tuple = tuple(jax.tree.map(jnp.squeeze, 
jnp.split(x, 3, axis=-1))) if shuffle: order = jax.random.permutation(perm_rng, jnp.arange(4, dtype=jnp.int32)) else: diff --git a/serl_launcher/serl_launcher/wrappers/chunking.py b/serl_launcher/serl_launcher/wrappers/chunking.py index 175c9a5f..404a31c5 100644 --- a/serl_launcher/serl_launcher/wrappers/chunking.py +++ b/serl_launcher/serl_launcher/wrappers/chunking.py @@ -9,7 +9,7 @@ def stack_obs(obs): dict_list = {k: [dic[k] for dic in obs] for k in obs[0]} - return jax.tree_map( + return jax.tree.map( lambda x: np.stack(x), dict_list, is_leaf=lambda x: isinstance(x, list) ) diff --git a/serl_launcher/serl_launcher/wrappers/remap.py b/serl_launcher/serl_launcher/wrappers/remap.py index 7acb2d93..1d724dc1 100644 --- a/serl_launcher/serl_launcher/wrappers/remap.py +++ b/serl_launcher/serl_launcher/wrappers/remap.py @@ -31,4 +31,4 @@ def __init__(self, env: gym.Env, new_structure: Any): raise TypeError(f"Unsupported type {type(new_structure)}") def observation(self, observation): - return jax.tree_map(lambda x: observation[x], self.new_structure) + return jax.tree.map(lambda x: observation[x], self.new_structure) diff --git a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py index 41c169f9..0c789762 100644 --- a/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py +++ b/serl_launcher/serl_launcher/wrappers/serl_obs_wrappers.py @@ -2,24 +2,184 @@ from gym.spaces import flatten_space, flatten +# Optional resizers +import numpy as np +try: + import cv2 + _HAS_CV2 = True +except Exception: + _HAS_CV2 = False + +try: + from PIL import Image + _HAS_PIL = True +except Exception: + _HAS_PIL = False + +def _resize_hwc(img: np.ndarray, hw: tuple[int, int]) -> np.ndarray: + """ + Resize HxWxC image to (H', W', C) without changing dtype/range. + Supports cv2, PIL, or pure NumPy resizing. + If img is float32, it will be scaled to [0,1] if not already in that range. 
+ If img is uint8, it will be resized as-is. + + Args: + img (np.ndarray): Input image in HxWxC format. + hw (tuple[int, int]): Target height and width (H', W'). + Returns: + np.ndarray: Resized image in H'xW'xC format. + """ + H, W = hw + if _HAS_CV2: + # cv2 wants (W, H) + return cv2.resize(img, (W, H), interpolation=cv2.INTER_AREA) + + if _HAS_PIL: + pil = Image.fromarray(img if img.dtype == np.uint8 else np.clip(img, 0, 255).astype(np.uint8)) + pil = pil.resize((W, H), resample=Image.Resampling.BILINEAR) + out = np.asarray(pil) + if img.dtype != np.uint8: + # If original was float, map back to float [0,1] + out = out.astype(np.float32) / 255.0 + return out + + # Pure NumPy (simple nearest neighbor) + y_idx = (np.linspace(0, img.shape[0] - 1, H)).astype(np.int32) + x_idx = (np.linspace(0, img.shape[1] - 1, W)).astype(np.int32) + return img[y_idx][:, x_idx] + class SERLObsWrapper(gym.ObservationWrapper): """ - This observation wrapper treat the observation space as a dictionary - of a flattened state space and the images. + Observation wrapper for SERL environments. + Flattens the 'state' space and resizes images to a target height and width. + Supports both uint8 and float32 images, with optional normalization. + The observation space is a Dict with 'state' and resized image spaces. + + Args: + env (gym.Env): The environment to wrap. + target_hw (tuple[int, int]): Target height and width for resized images. + img_dtype (np.dtype): Data type for images, either np.uint8 or np.float32. + normalize (bool): If True, scales float32 images to [0,1]. + image_parent_key (str): Key in the observation dict where images are stored. + + Defaults to "images". + Returns: + gym.spaces.Dict: The new observation space with flattened state and resized images. 
""" - def __init__(self, env): + def __init__( + self, + env, + target_hw=(128, 128), # (H, W) for resized images + img_dtype=np.uint8, # np.uint8 for [0..255], or np.float32 for [0..1] + normalize=False, # if True and img_dtype=float32, scale to [0,1] + image_parent_key="images", # where images live in the original obs dict + ): super().__init__(env) - self.observation_space = gym.spaces.Dict( - { - "state": flatten_space(self.env.observation_space["state"]), - **(self.env.observation_space["images"]), - } - ) + assert isinstance(self.env.observation_space, gym.spaces.Dict), \ + "Expected Dict observation_space with keys {'state', 'images'}" + + # ---- Build new observation_space ---- + base_space = self.env.observation_space + assert "state" in base_space.spaces, "Missing 'state' in observation_space" + assert image_parent_key in base_space.spaces, f"Missing '{image_parent_key}' in observation_space" + img_space_dict = base_space.spaces[image_parent_key] + assert isinstance(img_space_dict, gym.spaces.Dict), \ + f"'{image_parent_key}' must be a Dict of image spaces" + + # Flattened state space + state_space = flatten_space(base_space.spaces["state"]) + + + # Image spaces (resized) + H, W = target_hw + image_spaces = {} + for k, sp in img_space_dict.spaces.items(): + # Assume HWC input; preserve channel count + if hasattr(sp, "shape") and sp.shape is not None: + if len(sp.shape) != 3: + raise ValueError(f"Image space '{k}' must be HxWxC; got shape {sp.shape}") + C = sp.shape[-1] + else: + raise ValueError(f"Image space '{k}' missing shape") + + if img_dtype == np.uint8: + low, high = 0, 255 + elif img_dtype == np.float32: + low, high = 0.0, 1.0 if normalize else float(getattr(sp, "high", 1.0)) + else: + raise ValueError("img_dtype must be np.uint8 or np.float32") + + image_spaces[k] = gym.spaces.Box( + low=low, + high=high, + shape=(H, W, C), + dtype=img_dtype, + ) + + # Final Dict space: {'state': ..., 'front': Box(...), 'wrist': Box(...), ...} + 
self.observation_space = gym.spaces.Dict({ + "state": state_space, + **image_spaces + }) + # self.observation_space = gym.spaces.Dict( + # { + # "state": flatten_space(self.env.observation_space["state"]), + # **(self.env.observation_space["images"]), + # } + # ) + + # Store config + self._target_hw = target_hw + self._img_dtype = img_dtype + self._normalize = normalize + self._image_parent_key = image_parent_key + + # def observation(self, obs): + # obs = { + # "state": flatten(self.env.observation_space["state"], obs["state"]), + # **(obs["images"]), + # } + # return obs def observation(self, obs): - obs = { - "state": flatten(self.env.observation_space["state"], obs["state"]), - **(obs["images"]), - } - return obs + # Flatten state using original (pre-flatten) state space definition + flat_state = flatten(self.env.observation_space.spaces["state"], obs["state"]) + + # Pull original images dict + imgs = obs[self._image_parent_key] + + # Resize & cast each image to match observation_space spec + out = {"state": flat_state} + for k, sp in self.observation_space.spaces.items(): + if k == "state": + continue + img = imgs[k] + # Ensure HWC + if img.ndim != 3: + raise ValueError(f"Image '{k}' must be HxWxC; got shape {img.shape}") + + # If float32 images in [0,1] but we want uint8, scale up before resize for best quality + want_uint8 = (self._img_dtype == np.uint8) + if want_uint8: + if img.dtype != np.uint8: + # Assume 0..1 range; if 0..255 float, clip and cast + img = np.clip(img, 0.0, 1.0) if img.max() <= 1.0 else np.clip(img/255.0, 0.0, 1.0) + img = (img * 255.0 + 0.5).astype(np.uint8) + resized = _resize_hwc(img, self._target_hw).astype(np.uint8) + else: + # float32 output + if img.dtype == np.uint8: + if self._normalize: + img = img.astype(np.float32) / 255.0 + else: + img = img.astype(np.float32) # keep 0..255 range if you really want that + else: + img = img.astype(np.float32) + if self._normalize and img.max() > 1.0: + img = img / 255.0 + resized = 
_resize_hwc(img, self._target_hw).astype(np.float32) + + out[k] = resized + + return out \ No newline at end of file diff --git a/serl_robot_infra/README.md b/serl_robot_infra/README.md index 1abcd2ee..fd48b9ff 100644 --- a/serl_robot_infra/README.md +++ b/serl_robot_infra/README.md @@ -100,3 +100,11 @@ env = gym.make("FrankaEnv-Vision-v0") 4. bin relocation Please refer to their respective examples in `serl/examples/` directory. + +Navigation +---------- +- [Home](../README.md) +- [Overview](../docs/overview.md) +- [Installation guide](../docs/installation.md) +- [Run in simulation](../docs/run_sim.md) +- [Run on the real robot](../docs/run_realrobot.md) diff --git a/serl_robot_infra/franka_env/camera/rs_capture.py b/serl_robot_infra/franka_env/camera/rs_capture.py index 249135f8..acb7f464 100644 --- a/serl_robot_infra/franka_env/camera/rs_capture.py +++ b/serl_robot_infra/franka_env/camera/rs_capture.py @@ -1,3 +1,4 @@ +import time import numpy as np import pyrealsense2 as rs # Intel RealSense cross-platform open-source API @@ -9,16 +10,39 @@ def get_device_serial_numbers(self): def __init__(self, name, serial_number, dim=(640, 480), fps=15, depth=False): self.name = name - assert serial_number in self.get_device_serial_numbers() self.serial_number = serial_number self.depth = depth + + # Wait for camera to enumerate + for _ in range(5): + if serial_number in self.get_device_serial_numbers(): + break + time.sleep(1) + else: + raise RuntimeError( + f"Camera {serial_number} ({name}) not found after 5s" + ) + self.pipe = rs.pipeline() self.cfg = rs.config() self.cfg.enable_device(self.serial_number) self.cfg.enable_stream(rs.stream.color, dim[0], dim[1], rs.format.bgr8, fps) if self.depth: self.cfg.enable_stream(rs.stream.depth, dim[0], dim[1], rs.format.z16, fps) - self.profile = self.pipe.start(self.cfg) + + # Retry a few times before giving up. 
+ last_err = None + for _ in range(3): + try: + self.profile = self.pipe.start(self.cfg) + break + except RuntimeError as e: + last_err = e + time.sleep(1.0) + else: + raise RuntimeError( + f"pipe.start failed for {name} ({serial_number}) after 3 attempts: {last_err}" + ) # Create an align object # rs.align allows us to perform alignment of depth frames to others frames diff --git a/serl_robot_infra/franka_env/camera/video_capture.py b/serl_robot_infra/franka_env/camera/video_capture.py index 842010ff..c5926ccc 100644 --- a/serl_robot_infra/franka_env/camera/video_capture.py +++ b/serl_robot_infra/franka_env/camera/video_capture.py @@ -19,8 +19,11 @@ def __init__(self, cap, name=None): def _reader(self): while self.enable: - time.sleep(0.01) - ret, frame = self.cap.read() + try: + time.sleep(0.01) + ret, frame = self.cap.read() + except Exception: + continue if not ret: break if not self.q.empty(): diff --git a/serl_robot_infra/franka_env/envs/bin_relocation_env/config.py b/serl_robot_infra/franka_env/envs/bin_relocation_env/config.py index f4887a03..bda1c10b 100644 --- a/serl_robot_infra/franka_env/envs/bin_relocation_env/config.py +++ b/serl_robot_infra/franka_env/envs/bin_relocation_env/config.py @@ -4,33 +4,35 @@ class BinEnvConfig(DefaultEnvConfig): """Set the configuration for FrankaEnv.""" - + WAIT_FOR_GRIPPER_SETTLED: bool = True SERVER_URL: str = "http://127.0.0.1:5000/" REALSENSE_CAMERAS = { - "wrist_1": "130322274175", - "front": "128422272758", + "wrist_1": "218622274083", + "front": "218622276001", } TARGET_POSE = np.array( [ - 0.485, - 0.025, - 0.047555915476419935, - 3.1331234, - 0.0182487, - 1.5824805, + 0.575, + 0.0, + 0.0, + 3.14, + 0.0, + 0.015 ] ) RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0]) REWARD_THRESHOLD: np.ndarray = np.zeros(6) - ACTION_SCALE = np.array([0.05, 0.1, 1]) + APPLY_GRIPPER_PENALTY = True + ACTION_SCALE = np.array([0.1, 0.2, 1]) RANDOM_RESET = False RANDOM_XY_RANGE = 0.1 RANDOM_RZ_RANGE = np.pi / 6 + # All 
the upper and lower adjustments happen in franka_bin_relocation.py:FrankBinRelocation:30 ABS_POSE_LIMIT_LOW = np.array( [ - TARGET_POSE[0] - 0.07, - TARGET_POSE[1] - 0.15, - TARGET_POSE[2] - 0.001, + TARGET_POSE[0] - 0.13, # -x axis + TARGET_POSE[1] - 0.17, # -y axis + TARGET_POSE[2] - (-0.0065), # -z axis TARGET_POSE[3] - 0.01, TARGET_POSE[4] - 0.01, TARGET_POSE[5] - RANDOM_RZ_RANGE, @@ -38,9 +40,9 @@ class BinEnvConfig(DefaultEnvConfig): ) ABS_POSE_LIMIT_HIGH = np.array( [ - TARGET_POSE[0] + 0.07, - TARGET_POSE[1] + 0.15, - TARGET_POSE[2] + 0.1, + TARGET_POSE[0] + 0.13, # +x axis + TARGET_POSE[1] + 0.17, # +y axis + TARGET_POSE[2] + 0.1, # +z axis TARGET_POSE[3] + 0.01, TARGET_POSE[4] + 0.01, TARGET_POSE[5] + RANDOM_RZ_RANGE, @@ -54,7 +56,7 @@ class BinEnvConfig(DefaultEnvConfig): "translational_Ki": 0, "translational_clip_x": 0.006, "translational_clip_y": 0.006, - "translational_clip_z": 0.005, + "translational_clip_z": 0.008, "translational_clip_neg_x": 0.006, "translational_clip_neg_y": 0.006, "translational_clip_neg_z": 0.005, @@ -85,4 +87,4 @@ class BinEnvConfig(DefaultEnvConfig): "rotational_clip_neg_y": 0.05, "rotational_clip_neg_z": 0.05, "rotational_Ki": 0.1, - } + } \ No newline at end of file diff --git a/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py b/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py index cc757058..ca536a40 100644 --- a/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py +++ b/serl_robot_infra/franka_env/envs/bin_relocation_env/franka_bin_relocation.py @@ -28,8 +28,8 @@ def __init__(self, **kwargs): and clips actions that will lead to collision. 
""" self.inner_safety_box = gym.spaces.Box( - self._TARGET_POSE[:3] - np.array([0.07, 0.03, 0.001]), - self._TARGET_POSE[:3] + np.array([0.07, 0.03, 0.04]), + self._TARGET_POSE[:3] - np.array([0, 0, 0]), + self._TARGET_POSE[:3] + np.array([0, 0, 0]), dtype=np.float64, ) @@ -126,10 +126,26 @@ def set_task_id(self, task_id): self.task_id = task_id def reset(self, joint_reset=False, **kwargs): + ''' + Set resest position for end-effector based on TARGET_POSE. + Select values for forward and backward policy that maximize + your viewing range from global camera. + + This is experiment-setup specific and will depend on where + your global camera is positioned and the size of your trays. + ''' + # Forward policy offset if self.task_id == 0: - self.resetpos[1] = self._TARGET_POSE[1] + 0.1 + X_OFFSET_FW = -0.025 + Y_OFFSET_FW = 0.025 + self.resetpos[0] = self._TARGET_POSE[0] + X_OFFSET_FW + self.resetpos[1] = self._TARGET_POSE[1] + Y_OFFSET_FW + # Backward policy offset elif self.task_id == 1: - self.resetpos[1] = self._TARGET_POSE[1] - 0.1 + X_OFFSET_BW = 0.2 + Y_OFFSET_BW = 0.05 + self.resetpos[0] = self._TARGET_POSE[0] - Y_OFFSET_BW + self.resetpos[1] = self._TARGET_POSE[1] - X_OFFSET_BW else: raise ValueError(f"Task id {self.task_id} should be 0 or 1") @@ -140,16 +156,18 @@ def go_to_rest(self, joint_reset=False): Move to the rest position defined in base class. Add a small z offset before going to rest to avoid collision with object. 
""" - self._send_gripper_command(1) + # Open gripper + self._send_gripper_command(1, reset=True) + + # Get current position self._update_currpos() self._send_pos_command(self.currpos) time.sleep(0.5) - # Move up to clear the slot - self._update_currpos() + # Move up 0.05m in Z-axis to clear any objects in the slot reset_pose = copy.deepcopy(self.currpos) reset_pose[2] += 0.05 self.interpolate_move(reset_pose, timeout=1) - # execute the go_to_rest method from the parent class - super().go_to_rest(joint_reset) + # Call parent to finish reset + super().go_to_rest(joint_reset) \ No newline at end of file diff --git a/serl_robot_infra/franka_env/envs/cable_env/config.py b/serl_robot_infra/franka_env/envs/cable_env/config.py index 8cf4603f..8e72de1f 100644 --- a/serl_robot_infra/franka_env/envs/cable_env/config.py +++ b/serl_robot_infra/franka_env/envs/cable_env/config.py @@ -7,19 +7,17 @@ class CableEnvConfig(DefaultEnvConfig): SERVER_URL: str = "http://127.0.0.1:5000/" REALSENSE_CAMERAS = { - "wrist_1": "130322274175", - "wrist_2": "127122270572", + "wrist_1": "218622274083", + "wrist_2": "218622271526", } TARGET_POSE = np.array( - [ - 0.460639895728905, - -0.02439473272513422, - 0.026321125814908725, - 3.1331234, - 0.0182487, - 1.5824805, - ] - ) + [0.5712090556314777, + 0.08195494837073188, + -0.027075318869685436, + 3.1256426554697843, + -0.0539225368305436, + 1.6104289490096944] + ) RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0]) REWARD_THRESHOLD: np.ndarray = np.zeros(6) APPLY_GRIPPER_PENALTY = False diff --git a/serl_robot_infra/franka_env/envs/cable_env/franka_cable_route.py b/serl_robot_infra/franka_env/envs/cable_env/franka_cable_route.py index d6d3b633..5e615275 100644 --- a/serl_robot_infra/franka_env/envs/cable_env/franka_cable_route.py +++ b/serl_robot_infra/franka_env/envs/cable_env/franka_cable_route.py @@ -14,12 +14,12 @@ class FrankaCableRoute(FrankaEnv): def __init__(self, **kwargs): super().__init__(**kwargs, 
config=CableEnvConfig) - def go_to_rest(self, joint_reset=False): + def go_to_rest(self, joint_reset=False, ): """ Move to the rest position defined in base class. Add a small z offset before going to rest to avoid collision with object. """ - self._send_gripper_command(-1) + # self._send_gripper_command(1) <----- uncomment if picking up cable is part of experiment self._update_currpos() self._send_pos_command(self.currpos) time.sleep(0.5) diff --git a/serl_robot_infra/franka_env/envs/franka_env.py b/serl_robot_infra/franka_env/envs/franka_env.py index 66d3f307..5d99cec6 100644 --- a/serl_robot_infra/franka_env/envs/franka_env.py +++ b/serl_robot_infra/franka_env/envs/franka_env.py @@ -61,7 +61,8 @@ class DefaultEnvConfig: PRECISION_PARAM: Dict[str, float] = {} BINARY_GRIPPER_THREASHOLD: float = 0.5 APPLY_GRIPPER_PENALTY: bool = True - GRIPPER_PENALTY: float = 0.1 + GRIPPER_PENALTY: float = 0.05 + WAIT_FOR_GRIPPER_SETTLED: bool = False ############################################################################## @@ -76,6 +77,8 @@ def __init__( config: DefaultEnvConfig = None, max_episode_length=100, ): + self.last_gripper_cmd_time = 0.0 + self.min_gripper_cmd_interval = 1.5 self.action_scale = config.ACTION_SCALE self._TARGET_POSE = config.TARGET_POSE self._REWARD_THRESHOLD = config.REWARD_THRESHOLD @@ -219,14 +222,47 @@ def step(self, action: np.ndarray) -> tuple: return ob, reward, done, False, {} def compute_reward(self, obs, gripper_action_effective) -> bool: - """We are using a sparse reward function.""" + """ + compute_reward() computes a sparse reward. A reward of 1 is given when the current pose and the target pose are within a threshold. + + Modification: originally this code compared the pose and had a significant weakness: the difference between y-angles could sometimes yield values close to 2pi given that we have +pi and -pi discontinuity relative to the peg-insertion orientation. This led to not picking rewards when they should have been produced. 
+ We will now separate position and orientation and compute the angle difference with utility method angle_diff. + """ + + # Get cartesian and quaternion pose information current_pose = obs["state"]["tcp_pose"] + # convert from quat to euler first - euler_angles = quat_2_euler(current_pose[3:]) - euler_angles = np.abs(euler_angles) - current_pose = np.hstack([current_pose[:3], euler_angles]) - delta = np.abs(current_pose - self._TARGET_POSE) - if np.all(delta < self._REWARD_THRESHOLD): + # euler_angles = quat_2_euler(current_pose[3:]) + # euler_angles = np.abs(euler_angles) + + # Get current orientation in euler angles + current_euler = quat_2_euler(current_pose[3:]) + + # Get target euler + target_euler = self._TARGET_POSE[3:] + + # Compute position and orientation deltas + pos_delta = np.abs(current_pose[:3] - self._TARGET_POSE[:3]) + + rot_delta = self.angle_diff(current_euler,target_euler) + + # Stack position and orientation differences + delta = np.hstack([pos_delta,rot_delta]) + + # current_pose = np.hstack([current_pose[:3], euler_angles]) + # delta = np.abs(current_pose - self._TARGET_POSE) + + # An all-zero REWARD_THRESHOLD is the explicit "disabled" sentinel: delta + # is non-negative by construction, so no pose could ever satisfy it. + # Tasks that source reward elsewhere (e.g. bin relocation, which uses the + # FWBW classifier wrapper) leave it zero. Return 0 directly so this path + # reads as intentionally inert rather than an accidentally-false compare. 
+ if not np.any(self._REWARD_THRESHOLD): + reward = 0 + # If difference meets threshold produce reward + elif np.all(delta < self._REWARD_THRESHOLD): reward = 1 else: # print(f'Goal not reached, the difference is {delta}, the desired threshold is {_REWARD_THRESHOLD}') @@ -319,7 +355,7 @@ def go_to_rest(self, joint_reset=False): # Change to compliance mode requests.post(self.url + "update_param", json=self.config.COMPLIANCE_PARAM) - def reset(self, joint_reset=False, **kwargs): + def reset(self, joint_reset=False, pos_reset=True, **kwargs): requests.post(self.url + "update_param", json=self.config.COMPLIANCE_PARAM) if self.save_video: self.save_video_recording() @@ -328,8 +364,9 @@ def reset(self, joint_reset=False, **kwargs): if self.cycle_count % self.joint_reset_cycle == 0: self.cycle_count = 0 joint_reset = True - - self.go_to_rest(joint_reset=joint_reset) + + if pos_reset: + self.go_to_rest(joint_reset=joint_reset) self._recover() self.curr_path_length = 0 @@ -385,29 +422,43 @@ def _send_pos_command(self, pos: np.ndarray): data = {"arr": arr.tolist()} requests.post(self.url + "pose", json=data) - def _send_gripper_command(self, pos: float, mode="binary"): - """Internal function to send gripper command to the robot.""" - if mode == "binary": + def _send_gripper_command(self, pos: float, mode="binary", reset=False): + """Send binary gripper command, but rate-limit physical commands.""" + if mode != "binary": + raise NotImplementedError("Continuous gripper control is optional") + + now = time.time() + + # Rate limit physical gripper commands. 
+ if now - self.last_gripper_cmd_time < self.min_gripper_cmd_interval and reset == False: + return False + + try: if ( pos <= -self.config.BINARY_GRIPPER_THREASHOLD and self.gripper_binary_state == 0 - ): # close gripper - requests.post(self.url + "close_gripper") - time.sleep(0.6) + ): + requests.post(self.url + "close_gripper", timeout=1.0) self.gripper_binary_state = 1 + self.last_gripper_cmd_time = time.time() + # time.sleep(0.8) return True + elif ( pos >= self.config.BINARY_GRIPPER_THREASHOLD and self.gripper_binary_state == 1 - ): # open gripper - requests.post(self.url + "open_gripper") - time.sleep(0.6) + ): + requests.post(self.url + "open_gripper", timeout=1.0) self.gripper_binary_state = 0 + self.last_gripper_cmd_time = time.time() + # time.sleep(0.8) return True - else: # do nothing to the gripper - return False - elif mode == "continuous": - raise NotImplementedError("Continuous gripper control is optional") + + return False + + except requests.exceptions.RequestException as e: + print(f"Gripper command failed: {e}") + return False def _update_currpos(self): """ @@ -436,3 +487,6 @@ def _get_obs(self) -> dict: "tcp_torque": self.currtorque, } return copy.deepcopy(dict(images=images, state=state_observation)) + + def angle_diff(self,a,b): + return np.abs((a-b+np.pi)%(2*np.pi) - np.pi) \ No newline at end of file diff --git a/serl_robot_infra/franka_env/envs/peg_env/config.py b/serl_robot_infra/franka_env/envs/peg_env/config.py index d2bcae9b..d6dc8e01 100644 --- a/serl_robot_infra/franka_env/envs/peg_env/config.py +++ b/serl_robot_infra/franka_env/envs/peg_env/config.py @@ -7,24 +7,17 @@ class PegEnvConfig(DefaultEnvConfig): SERVER_URL: str = "http://127.0.0.1:5000/" REALSENSE_CAMERAS = { - "wrist_1": "130322274175", - "wrist_2": "127122270572", + "wrist_1": "218622274083", + "wrist_2": "218622271526", } TARGET_POSE = np.array( - [ - 0.5906439143742067, - 0.07771711953459341, - 0.0937835826958042, - 3.1099675, - 0.0146619, - -0.0078615, - ] - ) + 
[0.6140227114814025,-0.08283620130379565,0.07813127975429472,-3.128427831028935,-0.03985221222452284,1.5944114064687414] + ) RESET_POSE = TARGET_POSE + np.array([0.0, 0.0, 0.1, 0.0, 0.0, 0.0]) REWARD_THRESHOLD: np.ndarray = np.array([0.01, 0.01, 0.01, 0.2, 0.2, 0.2]) APPLY_GRIPPER_PENALTY = False ACTION_SCALE = np.array([0.02, 0.1, 1]) - RANDOM_RESET = True + RANDOM_RESET = True #Turn to true after basic task is finished RANDOM_XY_RANGE = 0.05 RANDOM_RZ_RANGE = np.pi / 6 ABS_POSE_LIMIT_LOW = np.array( @@ -53,10 +46,10 @@ class PegEnvConfig(DefaultEnvConfig): "rotational_stiffness": 150, "rotational_damping": 7, "translational_Ki": 0, - "translational_clip_x": 0.003, + "translational_clip_x": 0.008, # JR: mod from 0.003 to 0.008 to improve negative motion. But not consistent. "translational_clip_y": 0.003, "translational_clip_z": 0.01, - "translational_clip_neg_x": 0.003, + "translational_clip_neg_x": 0.008, # JR: mod from 0.003 to 0.008. improved forward motion. "translational_clip_neg_y": 0.003, "translational_clip_neg_z": 0.01, "rotational_clip_x": 0.02, diff --git a/serl_robot_infra/franka_env/envs/relative_env.py b/serl_robot_infra/franka_env/envs/relative_env.py index 7d48f526..ab5aceff 100644 --- a/serl_robot_infra/franka_env/envs/relative_env.py +++ b/serl_robot_infra/franka_env/envs/relative_env.py @@ -10,8 +10,18 @@ class RelativeFrame(gym.Wrapper): """ - This wrapper transforms the observation and action to be expressed in the end-effector frame. - Optionally, it can transform the tcp_pose into a relative frame defined as the reset pose. + This wrapper transforms the observation and action to be expressed in the end-effector frame at reset. + All measurements are thus relative to the starting reset frame. + + Consider the following frames and nomenclatures: + o: base frame + r: the frozen tcp frame that happens only at reset time. + b: end-effector frame + + Notation: T_a_b would represent a transformation from b to a. 
+ Transformation: reset_tcp_frame -> base_frame: T_r_o (o to r) + Transformation: end_effector_frame -> base_frame: T_b_o (o to b) + Transformation: end_effector_frame -> reset_tcp_frame: T_b_r = T_r_o_inv * T_b_o (r to b) This wrapper is expected to be used on top of the base Franka environment, which has the following observation space: @@ -30,10 +40,16 @@ class RelativeFrame(gym.Wrapper): def __init__(self, env: Env, include_relative_pose=True): super().__init__(env) + + # Adjoint matrix used to convert tcp_vel or actions from base frame to end-effector frame via Adj(T)^(-1)*tcp_vel self.adjoint_matrix = np.zeros((6, 6)) self.include_relative_pose = include_relative_pose if self.include_relative_pose: + # o: base frame + # r: the frozen tcp frame that happens only at reset time. + # b: end-effector frame + # Transformation from base to tcp: T_r_o # Homogeneous transformation matrix from reset pose's relative frame to base frame self.T_r_o_inv = np.zeros((4, 4)) @@ -73,8 +89,15 @@ def reset(self, **kwargs): def transform_observation(self, obs): """ - Transform observations from spatial(base) frame into body(end-effector) frame - using the adjoint matrix + Convert the environment's observation into the frames expected by the policy. + + * Linear/angular velocities are provided by the wrapped env in the spatial (base) + frame; we left-multiply them by ``Adj(T)^{-1}`` so they are expressed in the + instantaneous body (end-effector) frame. + * When ``include_relative_pose`` is enabled, tcp poses are re-expressed relative + to the pose at reset. That is, we compute ``T_r^b = T_r^o @ T_o^b`` and return + the position/quaternion extracted from ``T_r^b``. + * Image observations pass through untouched. 
""" adjoint_inv = np.linalg.inv(self.adjoint_matrix) obs["state"]["tcp_vel"] = adjoint_inv @ obs["state"]["tcp_vel"] diff --git a/serl_robot_infra/franka_env/envs/wrappers.py b/serl_robot_infra/franka_env/envs/wrappers.py index af69605e..28a53874 100644 --- a/serl_robot_infra/franka_env/envs/wrappers.py +++ b/serl_robot_infra/franka_env/envs/wrappers.py @@ -90,11 +90,18 @@ def __init__(self, env: Env, reward_classifier_func): def compute_reward(self, obs): if self.reward_classifier_func is not None: + + # Extract the classifier output as a scalar logit via .item() logit = self.reward_classifier_func(obs).item() + + # Sigmoid converts the logit to a probability. A probability of at least 0.5 counts as success; * 1 casts the bool to int. return (sigmoid(logit) >= 0.5) * 1 return 0 def step(self, action): + """ + Step the wrapped env, add classifier success bonus, and terminate on success. + """ obs, rew, done, truncated, info = self.env.step(action) success = self.compute_reward(obs) rew += success diff --git a/serl_robot_infra/robot_servers/franka_gripper_server.py b/serl_robot_infra/robot_servers/franka_gripper_server.py index 48b285ce..03a05074 100644 --- a/serl_robot_infra/robot_servers/franka_gripper_server.py +++ b/serl_robot_infra/robot_servers/franka_gripper_server.py @@ -28,10 +28,10 @@ def open(self): def close(self): msg = GraspActionGoal() msg.goal.width = 0.01 - msg.goal.speed = 0.3 - msg.goal.epsilon.inner = 1 - msg.goal.epsilon.outer = 1 - msg.goal.force = 130 + msg.goal.speed = 0.5 + msg.goal.epsilon.inner = 0.001 + msg.goal.epsilon.outer = 1.0 + msg.goal.force = 30.0 self.grippergrasppub.publish(msg) def move(self, position: int): diff --git a/serl_robot_infra/robot_servers/franka_server.py b/serl_robot_infra/robot_servers/franka_server.py index 0c582257..5d45378e 100644 --- a/serl_robot_infra/robot_servers/franka_server.py +++ b/serl_robot_infra/robot_servers/franka_server.py @@ -17,7 +17,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string( - "robot_ip", "172.16.0.2", "IP 
address of the franka robot's controller box" + "robot_ip", "10.200.110.10", "IP address of the franka robot's controller box" ) flags.DEFINE_string( "gripper_ip", "192.168.1.114", "IP address of the robotiq gripper if being used"