
Commit ca4043c

Adds wandb native support in rl_games (#2650)
# Description

This PR adds support for wandb logging in rl_games training. rl_games already supports wandb logging; an example of how to configure it can be seen in [rl_games-wandb_support](https://github.com/Denys88/rl_games/blob/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/runner.py). This change follows the same style so the current rl_games pipeline can use wandb as well; an example invocation is given after the checklist below.

## Type of change

- New feature (non-breaking change which adds functionality)

## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with `./isaaclab.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the changelog and the corresponding version in the extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already exists there
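With these changes in place, a tracked run could presumably be launched by adding the new flags to the usual training command, e.g. something like `./isaaclab.sh -p scripts/reinforcement_learning/rl_games/train.py --task <TASK> --headless --track --wandb-entity <ENTITY>`, where `<TASK>` and `<ENTITY>` are placeholders. `--wandb-project-name` and `--wandb-name` are optional and fall back to the rl_games config name and the timestamped log directory, respectively.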
1 parent 7a489ad · commit ca4043c

File tree: 1 file changed, +33 −2 lines

  • scripts/reinforcement_learning/rl_games/train.py

scripts/reinforcement_learning/rl_games/train.py

Lines changed: 33 additions & 2 deletions
```diff
@@ -9,6 +9,7 @@
 
 import argparse
 import sys
+from distutils.util import strtobool
 
 from isaaclab.app import AppLauncher
 
@@ -26,7 +27,17 @@
 parser.add_argument("--checkpoint", type=str, default=None, help="Path to model checkpoint.")
 parser.add_argument("--sigma", type=str, default=None, help="The policy's initial standard deviation.")
 parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
-
+parser.add_argument("--wandb-project-name", type=str, default=None, help="the wandb's project name")
+parser.add_argument("--wandb-entity", type=str, default=None, help="the entity (team) of wandb's project")
+parser.add_argument("--wandb-name", type=str, default=None, help="the name of wandb's run")
+parser.add_argument(
+    "--track",
+    type=lambda x: bool(strtobool(x)),
+    default=False,
+    nargs="?",
+    const=True,
+    help="if toggled, this experiment will be tracked with Weights and Biases",
+)
 # append AppLauncher cli args
 AppLauncher.add_app_launcher_args(parser)
 # parse the arguments
```
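A note on the `--track` flag added above: combining `nargs="?"` with `const=True` makes the bare flag act as a switch while still accepting an explicit value, which is parsed through `strtobool`. Below is a minimal standalone sketch of that behavior (the parser here is illustrative, not part of this PR; also note that `distutils`, which provides `strtobool`, was removed from the standard library in Python 3.12):

```python
import argparse
from distutils.util import strtobool  # removed from the stdlib in Python 3.12

# Illustrative parser reproducing only the --track flag from the diff.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--track",
    type=lambda x: bool(strtobool(x)),  # "yes"/"true"/"1" -> True, "no"/"false"/"0" -> False
    default=False,  # flag absent: tracking disabled
    nargs="?",      # the value is optional ...
    const=True,     # ... so a bare "--track" enables tracking
)

print(parser.parse_args([]).track)                    # False
print(parser.parse_args(["--track"]).track)           # True
print(parser.parse_args(["--track", "false"]).track)  # False
```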
```diff
@@ -109,7 +120,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
     env_cfg.seed = agent_cfg["params"]["seed"]
 
     # specify directory for logging experiments
-    log_root_path = os.path.join("logs", "rl_games", agent_cfg["params"]["config"]["name"])
+    config_name = agent_cfg["params"]["config"]["name"]
+    log_root_path = os.path.join("logs", "rl_games", config_name)
     log_root_path = os.path.abspath(log_root_path)
     print(f"[INFO] Logging experiment in directory: {log_root_path}")
     # specify directory for logging runs
@@ -118,6 +130,8 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
     # logging directory path: <train_dir>/<full_experiment_name>
     agent_cfg["params"]["config"]["train_dir"] = log_root_path
     agent_cfg["params"]["config"]["full_experiment_name"] = log_dir
+    wandb_project = config_name if args_cli.wandb_project_name is None else args_cli.wandb_project_name
+    experiment_name = log_dir if args_cli.wandb_name is None else args_cli.wandb_name
 
     # dump the configuration into log-directory
     dump_yaml(os.path.join(log_root_path, log_dir, "params", "env.yaml"), env_cfg)
```
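The two new assignments above choose defaults when the wandb flags are omitted: the project name falls back to the rl_games config name and the run name falls back to the timestamped log directory. A small self-contained illustration with made-up values (the variable names mirror the diff, but the CLI stand-ins are hypothetical):

```python
# Stand-ins for agent_cfg["params"]["config"]["name"] and the per-run log directory.
config_name = "cartpole_direct"
log_dir = "2025-01-01_12-00-00"

# Stand-ins for the parsed CLI arguments; both flags are omitted here.
cli_wandb_project_name = None
cli_wandb_name = None

# Same fallback logic as the diff: an explicit flag wins, otherwise reuse the defaults.
wandb_project = config_name if cli_wandb_project_name is None else cli_wandb_project_name
experiment_name = log_dir if cli_wandb_name is None else cli_wandb_name
assert (wandb_project, experiment_name) == ("cartpole_direct", "2025-01-01_12-00-00")
```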
```diff
@@ -168,6 +182,23 @@ def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agen
     # reset the agent and env
     runner.reset()
     # train the agent
+
+    global_rank = int(os.getenv("RANK", "0"))
+    if args_cli.track and global_rank == 0:
+        if args_cli.wandb_entity is None:
+            raise ValueError("Weights and Biases entity must be specified for tracking.")
+        import wandb
+
+        wandb.init(
+            project=wandb_project,
+            entity=args_cli.wandb_entity,
+            name=experiment_name,
+            sync_tensorboard=True,
+            config=agent_cfg,
+            monitor_gym=True,
+            save_code=True,
+        )
+
     if args_cli.checkpoint is not None:
         runner.run({"train": True, "play": False, "sigma": train_sigma, "checkpoint": resume_path})
     else:
```
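The rank check above matters for multi-GPU training: launchers such as torchrun set a `RANK` environment variable per process, and gating on rank 0 ensures only one process creates the wandb run. Here is a minimal sketch of the same pattern factored into a helper (the function is hypothetical, not part of this PR; it uses only documented `wandb.init` parameters):

```python
import os


def maybe_init_wandb(track: bool, project: str, entity: str | None, run_name: str, cfg: dict):
    """Initialize wandb on the rank-0 process only; return None everywhere else."""
    global_rank = int(os.getenv("RANK", "0"))  # single-process runs default to rank 0
    if not (track and global_rank == 0):
        return None
    if entity is None:
        raise ValueError("Weights and Biases entity must be specified for tracking.")
    import wandb  # lazy import: untracked runs do not need the package installed

    return wandb.init(
        project=project,
        entity=entity,
        name=run_name,
        sync_tensorboard=True,  # mirror the TensorBoard scalars written by rl_games
        config=cfg,             # snapshot the full agent configuration
    )
```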
