@@ -172,7 +172,7 @@ def visualize(
172
172
weights ,
173
173
output_type : str = "HTML" ,
174
174
respect_done : bool = False ,
175
- num_episodes : Optional [int ] = None ,
175
+ max_episode_length : Optional [int ] = None ,
176
176
* args ,
177
177
** kwargs ,
178
178
):
@@ -188,9 +188,9 @@ def visualize(
188
188
The output type, either "HTML" or "rgb_array".
189
189
respect_done
190
190
Whether to respect the done signal.
191
- num_episodes
192
- The number of episodes to visualize, used to override the num_episodes in the constructor.
193
- If None, use the num_episodes in the constructor.
191
+ max_episode_length
192
+ Used to override the max_episode_length in the constructor.
193
+ If None, use the max_episode_length in the constructor.
194
194
"""
195
195
assert output_type in [
196
196
"HTML" ,
@@ -208,8 +208,8 @@ def visualize(
208
208
else :
209
209
rollout_state = (brax_state ,)
210
210
211
- num_episodes = num_episodes or self .num_episodes
212
- for _ in range (num_episodes ):
211
+ max_episode_length = max_episode_length or self .max_episode_length
212
+ for _ in range (max_episode_length ):
213
213
if self .stateful_policy :
214
214
state , brax_state = rollout_state
215
215
action , state = self .policy (state , weights , brax_state .obs )
0 commit comments