
Adds changes for rsl_rl 3.0.0 #2962

Open · wants to merge 12 commits into base: main
8 changes: 4 additions & 4 deletions scripts/reinforcement_learning/rsl_rl/cli_args.py
@@ -10,7 +10,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
- from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg
+ from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg


def add_rsl_rl_args(parser: argparse.ArgumentParser):
@@ -39,7 +39,7 @@ def add_rsl_rl_args(parser: argparse.ArgumentParser):
)


- def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
+ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlBaseRunnerCfg:
"""Parse configuration for RSL-RL agent based on inputs.

Args:
@@ -52,12 +52,12 @@ def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPol
from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry

# load the default configuration
- rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
+ rslrl_cfg: RslRlBaseRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
return rslrl_cfg


- def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
+ def update_rsl_rl_cfg(agent_cfg: RslRlBaseRunnerCfg, args_cli: argparse.Namespace):
"""Update configuration for RSL-RL agent based on inputs.

Args:
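For reference, a minimal usage sketch of the updated helpers (not part of this PR). The task name is hypothetical, and this assumes the Isaac Lab app and task registry are already initialized so that `load_cfg_from_registry` can resolve the entry point.

```python
# Minimal usage sketch for the updated cli_args helpers (not part of this PR).
# The task name is hypothetical; any task registered with an
# "rsl_rl_cfg_entry_point" should behave the same way.
import argparse

import cli_args

parser = argparse.ArgumentParser()
cli_args.add_rsl_rl_args(parser)
args_cli = parser.parse_args([])

# parse_rsl_rl_cfg now returns an RslRlBaseRunnerCfg subclass (PPO or
# distillation runner config), so callers can branch on agent_cfg.class_name.
agent_cfg = cli_args.parse_rsl_rl_cfg("Isaac-Velocity-Flat-Anymal-D-v0", args_cli)
print(agent_cfg.class_name)
```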
39 changes: 25 additions & 14 deletions scripts/reinforcement_learning/rsl_rl/play.py
@@ -55,7 +55,7 @@
import time
import torch

- from rsl_rl.runners import OnPolicyRunner
+ from rsl_rl.runners import DistillationRunner, OnPolicyRunner

from isaaclab.envs import (
DirectMARLEnv,
@@ -68,7 +68,7 @@
from isaaclab.utils.dict import print_dict
from isaaclab.utils.pretrained_checkpoint import get_published_pretrained_checkpoint

- from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
+ from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx

import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path
@@ -78,14 +78,14 @@


@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
- def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
+ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
"""Play with RSL-RL agent."""
# grab task name for checkpoint path
task_name = args_cli.task.split(":")[-1]
train_task_name = task_name.replace("-Play", "")

# override configurations with non-hydra CLI arguments
- agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
+ agent_cfg: RslRlBaseRunnerCfg = cli_args.parse_rsl_rl_cfg(task_name, args_cli)
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs

# set the environment seed
@@ -133,32 +133,43 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen

print(f"[INFO]: Loading model checkpoint from: {resume_path}")
# load previously trained model
- ppo_runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
- ppo_runner.load(resume_path)
+ if agent_cfg.class_name == "OnPolicyRunner":
+ runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
+ elif agent_cfg.class_name == "DistillationRunner":
+ runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
+ else:
+ raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
+ runner.load(resume_path)

# obtain the trained policy for inference
- policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)
+ policy = runner.get_inference_policy(device=env.unwrapped.device)

# extract the neural network module
# we do this in a try-except to maintain backwards compatibility.
try:
# version 2.3 onwards
- policy_nn = ppo_runner.alg.policy
+ policy_nn = runner.alg.policy
except AttributeError:
# version 2.2 and below
- policy_nn = ppo_runner.alg.actor_critic
+ policy_nn = runner.alg.actor_critic

# extract the normalizer
if hasattr(policy_nn, "actor_obs_normalizer"):
normalizer = policy_nn.actor_obs_normalizer
elif hasattr(policy_nn, "student_obs_normalizer"):
normalizer = policy_nn.student_obs_normalizer
else:
normalizer = None

# export policy to onnx/jit
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
- export_policy_as_jit(policy_nn, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt")
- export_policy_as_onnx(
- policy_nn, normalizer=ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
- )
+ export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
+ export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")

dt = env.unwrapped.step_dt

# reset environment
- obs, _ = env.get_observations()
+ obs = env.get_observations()
timestep = 0
# simulate environment
while simulation_app.is_running():
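Once play.py has written the exported files, the TorchScript artifact can be used without rsl_rl. A minimal sketch (not part of this PR); the checkpoint path and observation size below are hypothetical placeholders.

```python
# Sketch: run the exported TorchScript policy standalone.
# The path and the observation dimension are hypothetical placeholders.
import torch

policy = torch.jit.load("logs/rsl_rl/my_experiment/exported/policy.pt")
policy.eval()

with torch.no_grad():
    obs = torch.zeros(1, 48)  # replace 48 with the task's actual observation size
    action = policy(obs)
print(action.shape)
```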
18 changes: 11 additions & 7 deletions scripts/reinforcement_learning/rsl_rl/train.py
@@ -15,7 +15,6 @@
# local imports
import cli_args # isort: skip


# add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with RSL-RL.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
@@ -53,9 +52,9 @@
from packaging import version

# for distributed training, check minimum supported rsl-rl version
RSL_RL_VERSION = "2.3.1"
RSL_RL_VERSION = "3.0.0"
installed_version = metadata.version("rsl-rl-lib")
- if args_cli.distributed and version.parse(installed_version) < version.parse(RSL_RL_VERSION):
+ if version.parse(installed_version) < version.parse(RSL_RL_VERSION):
if platform.system() == "Windows":
cmd = [r".\isaaclab.bat", "-p", "-m", "pip", "install", f"rsl-rl-lib=={RSL_RL_VERSION}"]
else:
@@ -74,7 +73,7 @@
import torch
from datetime import datetime

- from rsl_rl.runners import OnPolicyRunner
+ from rsl_rl.runners import DistillationRunner, OnPolicyRunner

from isaaclab.envs import (
DirectMARLEnv,
Expand All @@ -86,7 +85,7 @@
from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import dump_pickle, dump_yaml

- from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper
+ from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper

import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path
@@ -101,7 +100,7 @@


@hydra_task_config(args_cli.task, "rsl_rl_cfg_entry_point")
- def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlOnPolicyRunnerCfg):
+ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
"""Train with RSL-RL agent."""
# override configurations with non-hydra CLI arguments
agent_cfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
@@ -164,7 +163,12 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

# create runner from rsl-rl
- runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
+ if agent_cfg.class_name == "OnPolicyRunner":
+ runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
+ elif agent_cfg.class_name == "DistillationRunner":
+ runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
+ else:
+ raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
# write git state to logs
runner.add_git_repo_to_log(__file__)
# load the checkpoint
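The class_name dispatch above could equally be written as a small lookup table. A sketch of that variant (not what this PR implements), kept as a standalone helper:

```python
# Alternative sketch of the runner dispatch used in train.py/play.py.
from rsl_rl.runners import DistillationRunner, OnPolicyRunner

RUNNER_CLASSES = {
    "OnPolicyRunner": OnPolicyRunner,
    "DistillationRunner": DistillationRunner,
}


def make_runner(env, agent_cfg, log_dir=None):
    """Create the rsl_rl runner that matches ``agent_cfg.class_name``."""
    runner_cls = RUNNER_CLASSES.get(agent_cfg.class_name)
    if runner_cls is None:
        raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
    return runner_cls(env, agent_cfg.to_dict(), log_dir=log_dir, device=agent_cfg.device)
```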
12 changes: 12 additions & 0 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/distillation_cfg.py
@@ -28,6 +28,12 @@ class RslRlDistillationStudentTeacherCfg:
noise_std_type: Literal["scalar", "log"] = "scalar"
"""The type of noise standard deviation for the policy. Default is scalar."""

student_obs_normalization: bool = False
"""Whether to normalize the observation for the student network. Default is False."""

teacher_obs_normalization: bool = False
"""Whether to normalize the observation for the teacher network. Default is False."""

student_hidden_dims: list[int] = MISSING
"""The hidden dimensions of the student network."""

@@ -81,3 +87,9 @@ class RslRlDistillationAlgorithmCfg:

max_grad_norm: None | float = None
"""The maximum norm the gradient is clipped to."""

optimizer: Literal["adam", "adamw", "sgd", "rmsprop"] = "adam"
"""The optimizer to use for the student policy."""

loss_type: Literal["mse", "huber"] = "mse"
"""The loss type to use for the student policy."""
84 changes: 69 additions & 15 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/rl_cfg.py
@@ -32,6 +32,12 @@ class RslRlPpoActorCriticCfg:
noise_std_type: Literal["scalar", "log"] = "scalar"
"""The type of noise standard deviation for the policy. Default is scalar."""

actor_obs_normalization: bool = False
"""Whether to normalize the observation for the actor network. Default is False."""

critic_obs_normalization: bool = False
"""Whether to normalize the observation for the critic network. Default is False."""

actor_hidden_dims: list[int] = MISSING
"""The hidden dimensions of the actor network."""

@@ -114,23 +120,21 @@ class RslRlPpoAlgorithmCfg:
Otherwise, the advantage is normalized over the entire collected trajectories.
"""

+ rnd_cfg: RslRlRndCfg | None = None
+ """The RND configuration. Default is None, in which case RND is not used."""

symmetry_cfg: RslRlSymmetryCfg | None = None
"""The symmetry configuration. Default is None, in which case symmetry is not used."""

- rnd_cfg: RslRlRndCfg | None = None
- """The configuration for the Random Network Distillation (RND) module. Default is None,
- in which case RND is not used.
- """


#########################
# Runner configurations #
#########################


@configclass
- class RslRlOnPolicyRunnerCfg:
- """Configuration of the runner for on-policy algorithms."""
+ class RslRlBaseRunnerCfg:
+ """Base configuration of the runner."""

seed: int = 42
"""The seed for the experiment. Default is 42."""
@@ -144,17 +148,36 @@ class RslRlOnPolicyRunnerCfg:
max_iterations: int = MISSING
"""The maximum number of iterations."""

- empirical_normalization: bool = MISSING
- """Whether to use empirical normalization."""
-
- policy: RslRlPpoActorCriticCfg | RslRlDistillationStudentTeacherCfg = MISSING
- """The policy configuration."""
-
- algorithm: RslRlPpoAlgorithmCfg | RslRlDistillationAlgorithmCfg = MISSING
- """The algorithm configuration."""
+ empirical_normalization: bool | None = None
+ """This parameter is deprecated and will be removed in the future.
+
+ Use `actor_obs_normalization` and `critic_obs_normalization` instead.
+ """
+
+ obs_groups: dict[str, list[str]] = MISSING
+ """A mapping from observation groups to observation sets.
+
+ The keys of the dictionary are predefined observation sets used by the underlying algorithm
+ and values are lists of observation groups provided by the environment.
+
+ For instance, if the environment provides a dictionary of observations with groups "policy", "images",
+ and "privileged", these can be mapped to algorithmic observation sets as follows:
+
+ .. code-block:: python
+
+     obs_groups = {
+         "policy": ["policy", "images"],
+         "critic": ["policy", "privileged"],
+     }
+
+ This way, the policy will receive the "policy" and "images" observations, and the critic will
+ receive the "policy" and "privileged" observations.
+
+ For more details, please check ``vec_env.py`` in the rsl_rl library.
+ """

clip_actions: float | None = None
- """The clipping value for actions. If ``None``, then no clipping is done.
+ """The clipping value for actions. If None, then no clipping is done. Defaults to None.

.. note::
This clipping is performed inside the :class:`RslRlVecEnvWrapper` wrapper.
@@ -184,7 +207,10 @@ class RslRlOnPolicyRunnerCfg:
"""The wandb project name. Default is "isaaclab"."""

resume: bool = False
"""Whether to resume. Default is False."""
"""Whether to resume a previous training. Default is False.

This flag will be ignored for distillation.
"""

load_run: str = ".*"
"""The run directory to load. Default is ".*" (all).
@@ -197,3 +223,31 @@

If regex expression, the latest (alphabetical order) matching file will be loaded.
"""


@configclass
class RslRlOnPolicyRunnerCfg(RslRlBaseRunnerCfg):
"""Configuration of the runner for on-policy algorithms."""

class_name: str = "OnPolicyRunner"
"""The runner class name. Default is OnPolicyRunner."""

policy: RslRlPpoActorCriticCfg = MISSING
"""The policy configuration."""

algorithm: RslRlPpoAlgorithmCfg = MISSING
"""The algorithm configuration."""


@configclass
class RslRlDistillationRunnerCfg(RslRlBaseRunnerCfg):
"""Configuration of the runner for distillation algorithms."""

class_name: str = "DistillationRunner"
"""The runner class name. Default is DistillationRunner."""

policy: RslRlDistillationStudentTeacherCfg = MISSING
"""The policy configuration."""

algorithm: RslRlDistillationAlgorithmCfg = MISSING
"""The algorithm configuration."""
6 changes: 2 additions & 4 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/symmetry_cfg.py
@@ -39,13 +39,11 @@ class RslRlSymmetryCfg:
Args:

env (VecEnv): The environment object. This is used to access the environment's properties.
- obs (torch.Tensor | None): The observation tensor. If None, the observation is not used.
+ obs (tensordict.TensorDict | None): The observation tensor dictionary. If None, the observation is not used.
action (torch.Tensor | None): The action tensor. If None, the action is not used.
obs_type (str): The name of the observation type. Defaults to "policy".
This is useful when handling augmentation for different observation groups.

Returns:
- A tuple containing the augmented observation and action tensors. The tensors can be None,
+ A tuple containing the augmented observation dictionary and action tensors. The tensors can be None,
if their respective inputs are None.
"""
