@@ -16,22 +16,36 @@ def random_policy(rand_seed, x): # weights, observation
16
16
)
17
17
18
18
19
- @pytest .mark .skip (
20
- reason = "cost too much time"
21
- )
22
- def test ():
19
+ def random_stateful_policy (state , rand_seed , x ): # state, weights, observation
20
+ return jnp .tanh (
21
+ jax .random .normal (
22
+ jax .random .PRNGKey (jnp .array (x [0 ] * 1e7 , dtype = jnp .int32 ) + rand_seed ),
23
+ shape = (8 ,),
24
+ )
25
+ ), state
26
+
27
+
28
+ @pytest .mark .parametrize ("stateful_policy" , [False , True ])
29
+ def test_brax (stateful_policy ):
23
30
seed = 41
24
31
key = jax .random .PRNGKey (seed )
25
32
26
- # It takes too much time to render 500 frames (474s on Nvidia RTX 3090)
27
- # I think it is good to add a progress bar to shrink waiting experience.
33
+ if stateful_policy :
34
+ policy = random_stateful_policy
35
+ else :
36
+ policy = random_policy
37
+
28
38
problem = problems .neuroevolution .Brax (
29
39
env_name = gym_name ,
30
- policy = jax .jit (random_policy ),
31
- cap_episode = 500 ,
40
+ policy = policy ,
41
+ num_episodes = 1 ,
42
+ max_episode_length = 3 ,
43
+ stateful_policy = stateful_policy ,
44
+ initial_state = jnp .zeros (10 ) if stateful_policy else None ,
32
45
)
33
46
34
47
state = problem .init (key )
35
- frames = problem .visualize (key , seed , output_type = "rgb_array" , width = 250 , height = 250 )
36
- frames2gif (frames , f"{ gym_name } _{ seed } .gif" )
48
+ problem .evaluate (state , jnp .arange (3 ))
49
+
50
+ problem .visualize (key , seed , output_type = "HTML" )
37
51
assert True
0 commit comments