Skip to content

Commit 3230cc6

Browse files
add reset method to vecenv wrapper and fix test
1 parent 11911f8 commit 3230cc6

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

source/isaaclab_rl/isaaclab_rl/rsl_rl/vecenv_wrapper.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,11 @@ def episode_length_buf(self, value: torch.Tensor):
136136
def seed(self, seed: int = -1) -> int: # noqa: D102
137137
return self.unwrapped.seed(seed)
138138

139+
def reset(self) -> tuple[TensorDict, dict]: # noqa: D102
140+
# reset the environment
141+
obs_dict, extras = self.env.reset()
142+
return TensorDict(obs_dict, batch_size=[self.num_envs]), extras
143+
139144
def get_observations(self) -> TensorDict:
140145
"""Returns the current observations of the environment."""
141146
if hasattr(self.unwrapped, "observation_manager"):

source/isaaclab_rl/test/test_rsl_rl_wrapper.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import gymnasium as gym
1818
import torch
19+
from tensordict import TensorDict
1920

2021
import carb
2122
import omni.usd
@@ -161,6 +162,8 @@ def _check_valid_tensor(data: torch.Tensor | dict) -> bool:
161162
"""
162163
if isinstance(data, torch.Tensor):
163164
return not torch.any(torch.isnan(data))
165+
elif isinstance(data, TensorDict):
166+
return not data.isnan().any()
164167
elif isinstance(data, dict):
165168
valid_tensor = True
166169
for value in data.values():

0 commit comments

Comments
 (0)