Skip to content

Commit 6580c58

Browse files
Toni-SM authored and mohanksriram committed
Updates skrl integration to support training/evaluation using JAX (isaac-sim#592)
# Description

This PR updates the skrl integration to support training/evaluation using the JAX ML framework.

## Type of change

- New feature (non-breaking change which adds functionality)
- This change requires a documentation update

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format`
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have run all the tests with `./isaaclab.sh --test` and they pass
- [x] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there
1 parent d6a4ae3 commit 6580c58

File tree

6 files changed

+140
-35
lines changed

6 files changed

+140
-35
lines changed

docs/source/setup/sample.rst

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,31 @@ from the environments into the respective libraries function argument and return
184184
- Training an agent with
185185
`SKRL <https://skrl.readthedocs.io>`__ on ``Isaac-Reach-Franka-v0``:
186186

187-
.. code:: bash
187+
.. tab-set::
188188

189-
# install python module (for skrl)
190-
./isaaclab.sh -i skrl
191-
# run script for training
192-
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
193-
# run script for playing with 32 environments
194-
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
189+
.. tab-item:: PyTorch
190+
191+
.. code:: bash
192+
193+
# install python module (for skrl)
194+
./isaaclab.sh -i skrl
195+
# run script for training
196+
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless
197+
# run script for playing with 32 environments
198+
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --checkpoint /PATH/TO/model.pt
199+
200+
.. tab-item:: JAX
201+
202+
.. code:: bash
203+
204+
# install python module (for skrl)
205+
./isaaclab.sh -i skrl
206+
# install skrl dependencies for JAX. Visit https://skrl.readthedocs.io/en/latest/intro/installation.html for more details
207+
./isaaclab.sh -p -m pip install "skrl[jax]"
208+
# run script for training
209+
./isaaclab.sh -p source/standalone/workflows/skrl/train.py --task Isaac-Reach-Franka-v0 --headless --ml_framework jax
210+
# run script for playing with 32 environments
211+
./isaaclab.sh -p source/standalone/workflows/skrl/play.py --task Isaac-Reach-Franka-v0 --num_envs 32 --ml_framework jax --checkpoint /PATH/TO/model.pt
195212
196213
- Training an agent with
197214
`RL-Games <https://github.com/Denys88/rl_games>`__ on ``Isaac-Ant-v0``:

source/extensions/omni.isaac.lab_tasks/config/extension.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[package]
22

33
# Note: Semantic Versioning is used: https://semver.org/
4-
version = "0.7.9"
4+
version = "0.7.10"
55

66
# Description
77
title = "Isaac Lab Environments"

source/extensions/omni.isaac.lab_tasks/docs/CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
Changelog
22
---------
33

4+
0.7.10 (2024-07-02)
5+
~~~~~~~~~~~~~~~~~~~
6+
7+
Added
8+
^^^^^
9+
10+
* Extended skrl wrapper to support training/evaluation using JAX
11+
12+
413
0.7.9 (2024-07-01)
514
~~~~~~~~~~~~~~~~~~
615

source/extensions/omni.isaac.lab_tasks/omni/isaac/lab_tasks/utils/wrappers/skrl.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
1212
from omni.isaac.lab_tasks.utils.wrappers.skrl import SkrlVecEnvWrapper
1313
14-
env = SkrlVecEnvWrapper(env)
14+
env = SkrlVecEnvWrapper(env, ml_framework="torch") # or ml_framework="jax"
1515
1616
Or, equivalently, by directly calling the skrl library API as follows:
1717
1818
.. code-block:: python
1919
20-
from skrl.envs.torch.wrappers import wrap_env
20+
from skrl.envs.wrappers.torch import wrap_env # for PyTorch, or...
21+
from skrl.envs.wrappers.jax import wrap_env # for JAX
2122
2223
env = wrap_env(env, wrapper="isaaclab")
2324
@@ -26,10 +27,7 @@
2627
# needed to import for type hinting: Agent | list[Agent]
2728
from __future__ import annotations
2829

29-
from skrl.envs.wrappers.torch import wrap_env
30-
from skrl.resources.preprocessors.torch import RunningStandardScaler # noqa: F401
31-
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa: F401
32-
from skrl.utils.model_instantiators.torch import Shape # noqa: F401
30+
from typing import Literal
3331

3432
from omni.isaac.lab.envs import DirectRLEnv, ManagerBasedRLEnv
3533

@@ -38,14 +36,18 @@
3836
"""
3937

4038

41-
def process_skrl_cfg(cfg: dict) -> dict:
39+
def process_skrl_cfg(cfg: dict, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch") -> dict:
4240
"""Convert simple YAML types to skrl classes/components.
4341
4442
Args:
4543
cfg: A configuration dictionary.
44+
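ml_framework: The ML framework ("torch", "jax" or "jax-numpy") whose skrl components (preprocessors, schedulers, model instantiators) are used to convert the configuration. Defaults to "torch".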
4645
4746
Returns:
4847
A dictionary containing the converted configuration.
48+
49+
Raises:
50+
ValueError: If the specified ML framework is not valid.
4951
"""
5052
_direct_eval = [
5153
"learning_rate_scheduler",
@@ -62,6 +64,20 @@ def reward_shaper(rewards, timestep, timesteps):
6264
return reward_shaper
6365

6466
def update_dict(d):
67+
# import statements according to the ML framework
68+
if ml_framework.startswith("torch"):
69+
from skrl.resources.preprocessors.torch import RunningStandardScaler # noqa: F401
70+
from skrl.resources.schedulers.torch import KLAdaptiveLR # noqa: F401
71+
from skrl.utils.model_instantiators.torch import Shape # noqa: F401
72+
elif ml_framework.startswith("jax"):
73+
from skrl.resources.preprocessors.jax import RunningStandardScaler # noqa: F401
74+
from skrl.resources.schedulers.jax import KLAdaptiveLR # noqa: F401
75+
from skrl.utils.model_instantiators.jax import Shape # noqa: F401
76+
else:
77+
ValueError(
78+
f"Invalid ML framework for skrl: {ml_framework}. Available options are: 'torch', 'jax' or 'jax-numpy'"
79+
)
80+
6581
for key, value in d.items():
6682
if isinstance(value, dict):
6783
update_dict(value)
@@ -84,7 +100,7 @@ def update_dict(d):
84100
"""
85101

86102

87-
def SkrlVecEnvWrapper(env: ManagerBasedRLEnv):
103+
def SkrlVecEnvWrapper(env: ManagerBasedRLEnv, ml_framework: Literal["torch", "jax", "jax-numpy"] = "torch"):
88104
"""Wraps around Isaac Lab environment for skrl.
89105
90106
This function wraps around the Isaac Lab environment. Since the :class:`ManagerBasedRLEnv` environment
@@ -94,9 +110,11 @@ def SkrlVecEnvWrapper(env: ManagerBasedRLEnv):
94110
95111
Args:
96112
env: The environment to wrap around.
113+
ml_framework: The ML framework to use for the wrapper. Defaults to "torch".
97114
98115
Raises:
99116
ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv`.
117+
ValueError: If the specified ML framework is not valid.
100118
101119
Reference:
102120
https://skrl.readthedocs.io/en/latest/api/envs/wrapping.html
@@ -106,5 +124,16 @@ def SkrlVecEnvWrapper(env: ManagerBasedRLEnv):
106124
raise ValueError(
107125
f"The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type: {type(env)}"
108126
)
127+
128+
# import statements according to the ML framework
129+
if ml_framework.startswith("torch"):
130+
from skrl.envs.wrappers.torch import wrap_env
131+
elif ml_framework.startswith("jax"):
132+
from skrl.envs.wrappers.jax import wrap_env
133+
else:
134+
ValueError(
135+
f"Invalid ML framework for skrl: {ml_framework}. Available options are: 'torch', 'jax' or 'jax-numpy'"
136+
)
137+
109138
# wrap and return the environment
110139
return wrap_env(env, wrapper="isaaclab")

source/standalone/workflows/skrl/play.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@
2626
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
2727
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
2828
parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
29+
parser.add_argument(
30+
"--ml_framework",
31+
type=str,
32+
default="torch",
33+
choices=["torch", "jax", "jax-numpy"],
34+
help="The ML framework used for training the skrl agent.",
35+
)
36+
2937
# append AppLauncher cli args
3038
AppLauncher.add_app_launcher_args(parser)
3139
# parse the arguments
@@ -41,8 +49,14 @@
4149
import os
4250
import torch
4351

44-
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
45-
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
52+
import skrl
53+
54+
if args_cli.ml_framework.startswith("torch"):
55+
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
56+
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
57+
elif args_cli.ml_framework.startswith("jax"):
58+
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
59+
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
4660

4761
import omni.isaac.lab_tasks # noqa: F401
4862
from omni.isaac.lab_tasks.utils import get_checkpoint_path, load_cfg_from_registry, parse_env_cfg
@@ -51,7 +65,10 @@
5165

5266
def main():
5367
"""Play with skrl agent."""
54-
# parse env configuration
68+
# configure the ML framework into the global skrl variable
69+
if args_cli.ml_framework.startswith("jax"):
70+
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
71+
# parse configuration
5572
env_cfg = parse_env_cfg(
5673
args_cli.task, use_gpu=not args_cli.cpu, num_envs=args_cli.num_envs, use_fabric=not args_cli.disable_fabric
5774
)
@@ -60,24 +77,26 @@ def main():
6077
# create isaac environment
6178
env = gym.make(args_cli.task, cfg=env_cfg)
6279
# wrap around environment for skrl
63-
env = SkrlVecEnvWrapper(env) # same as: `wrap_env(env, wrapper="isaaclab")`
80+
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
6481

6582
# instantiate models using skrl model instantiator utility
6683
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
6784
models = {}
85+
if args_cli.ml_framework.startswith("jax"):
86+
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX
6887
# non-shared models
6988
if experiment_cfg["models"]["separate"]:
7089
models["policy"] = gaussian_model(
7190
observation_space=env.observation_space,
7291
action_space=env.action_space,
7392
device=env.device,
74-
**process_skrl_cfg(experiment_cfg["models"]["policy"]),
93+
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
7594
)
7695
models["value"] = deterministic_model(
7796
observation_space=env.observation_space,
7897
action_space=env.action_space,
7998
device=env.device,
80-
**process_skrl_cfg(experiment_cfg["models"]["value"]),
99+
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
81100
)
82101
# shared models
83102
else:
@@ -88,17 +107,21 @@ def main():
88107
structure=None,
89108
roles=["policy", "value"],
90109
parameters=[
91-
process_skrl_cfg(experiment_cfg["models"]["policy"]),
92-
process_skrl_cfg(experiment_cfg["models"]["value"]),
110+
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
111+
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
93112
],
94113
)
95114
models["value"] = models["policy"]
115+
# instantiate models' state dict
116+
if args_cli.ml_framework.startswith("jax"):
117+
for role, model in models.items():
118+
model.init_state_dict(role)
96119

97120
# configure and instantiate PPO agent
98121
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
99122
agent_cfg = PPO_DEFAULT_CONFIG.copy()
100123
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
101-
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"]))
124+
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework))
102125

103126
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
104127
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})

source/standalone/workflows/skrl/train.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@
3333
"--distributed", action="store_true", default=False, help="Run training with multiple GPUs or nodes."
3434
)
3535
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
36+
parser.add_argument(
37+
"--ml_framework",
38+
type=str,
39+
default="torch",
40+
choices=["torch", "jax", "jax-numpy"],
41+
help="The ML framework used for training the skrl agent.",
42+
)
3643

3744
# append AppLauncher cli args
3845
AppLauncher.add_app_launcher_args(parser)
@@ -52,11 +59,19 @@
5259
import os
5360
from datetime import datetime
5461

55-
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
56-
from skrl.memories.torch import RandomMemory
57-
from skrl.trainers.torch import SequentialTrainer
62+
import skrl
5863
from skrl.utils import set_seed
59-
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
64+
65+
if args_cli.ml_framework.startswith("torch"):
66+
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
67+
from skrl.memories.torch import RandomMemory
68+
from skrl.trainers.torch import SequentialTrainer
69+
from skrl.utils.model_instantiators.torch import deterministic_model, gaussian_model, shared_model
70+
elif args_cli.ml_framework.startswith("jax"):
71+
from skrl.agents.jax.ppo import PPO, PPO_DEFAULT_CONFIG
72+
from skrl.memories.jax import RandomMemory
73+
from skrl.trainers.jax import SequentialTrainer
74+
from skrl.utils.model_instantiators.jax import deterministic_model, gaussian_model
6075

6176
from omni.isaac.lab.utils.dict import print_dict
6277
from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
@@ -68,6 +83,10 @@
6883

6984
def main():
7085
"""Train with skrl agent."""
86+
# configure the ML framework into the global skrl variable
87+
if args_cli.ml_framework.startswith("jax"):
88+
skrl.config.jax.backend = "jax" if args_cli.ml_framework == "jax" else "numpy"
89+
7190
# read the seed from command line
7291
args_cli_seed = args_cli.seed
7392

@@ -93,6 +112,8 @@ def main():
93112

94113
# multi-gpu training config
95114
if args_cli.distributed:
115+
if args_cli.ml_framework.startswith("jax"):
116+
raise ValueError("Multi-GPU distributed training not yet supported in JAX")
96117
# update env config device
97118
env_cfg.sim.device = f"cuda:{app_launcher.local_rank}"
98119

@@ -120,27 +141,29 @@ def main():
120141
print_dict(video_kwargs, nesting=4)
121142
env = gym.wrappers.RecordVideo(env, **video_kwargs)
122143
# wrap around environment for skrl
123-
env = SkrlVecEnvWrapper(env) # same as: `wrap_env(env, wrapper="isaaclab")`
144+
env = SkrlVecEnvWrapper(env, ml_framework=args_cli.ml_framework) # same as: `wrap_env(env, wrapper="isaaclab")`
124145

125146
# set seed for the experiment (override from command line)
126147
set_seed(args_cli_seed if args_cli_seed is not None else experiment_cfg["seed"])
127148

128149
# instantiate models using skrl model instantiator utility
129150
# https://skrl.readthedocs.io/en/latest/api/utils/model_instantiators.html
130151
models = {}
152+
if args_cli.ml_framework.startswith("jax"):
153+
experiment_cfg["models"]["separate"] = True # shared model is not supported in JAX
131154
# non-shared models
132155
if experiment_cfg["models"]["separate"]:
133156
models["policy"] = gaussian_model(
134157
observation_space=env.observation_space,
135158
action_space=env.action_space,
136159
device=env.device,
137-
**process_skrl_cfg(experiment_cfg["models"]["policy"]),
160+
**process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
138161
)
139162
models["value"] = deterministic_model(
140163
observation_space=env.observation_space,
141164
action_space=env.action_space,
142165
device=env.device,
143-
**process_skrl_cfg(experiment_cfg["models"]["value"]),
166+
**process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
144167
)
145168
# shared models
146169
else:
@@ -151,11 +174,15 @@ def main():
151174
structure=None,
152175
roles=["policy", "value"],
153176
parameters=[
154-
process_skrl_cfg(experiment_cfg["models"]["policy"]),
155-
process_skrl_cfg(experiment_cfg["models"]["value"]),
177+
process_skrl_cfg(experiment_cfg["models"]["policy"], ml_framework=args_cli.ml_framework),
178+
process_skrl_cfg(experiment_cfg["models"]["value"], ml_framework=args_cli.ml_framework),
156179
],
157180
)
158181
models["value"] = models["policy"]
182+
# instantiate models' state dict
183+
if args_cli.ml_framework.startswith("jax"):
184+
for role, model in models.items():
185+
model.init_state_dict(role)
159186

160187
# instantiate a RandomMemory as rollout buffer (any memory can be used for this)
161188
# https://skrl.readthedocs.io/en/latest/api/memories/random.html
@@ -166,7 +193,7 @@ def main():
166193
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html
167194
agent_cfg = PPO_DEFAULT_CONFIG.copy()
168195
experiment_cfg["agent"]["rewards_shaper"] = None # avoid 'dictionary changed size during iteration'
169-
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"]))
196+
agent_cfg.update(process_skrl_cfg(experiment_cfg["agent"], ml_framework=args_cli.ml_framework))
170197

171198
agent_cfg["state_preprocessor_kwargs"].update({"size": env.observation_space, "device": env.device})
172199
agent_cfg["value_preprocessor_kwargs"].update({"size": 1, "device": env.device})

0 commit comments

Comments
 (0)