Skip to content

Commit 4185007

Browse files
committed
test: re-enable brax test
1 parent a32503c commit 4185007

File tree

2 files changed

+25
-10
lines changed

2 files changed

+25
-10
lines changed

src/evox/utils/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections.abc import Iterable
44

55
import jax
6+
import numpy as np
67
import jax.numpy as jnp
78
from jax import jit, vmap
89
from jax.tree_util import tree_flatten, tree_leaves, tree_unflatten

tests/visualize/test_brax.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,36 @@ def random_policy(rand_seed, x): # weights, observation
1616
)
1717

1818

19-
@pytest.mark.skip(
20-
reason="cost too much time"
21-
)
22-
def test():
19+
def random_stateful_policy(state, rand_seed, x): # state, weights, observation
20+
return jnp.tanh(
21+
jax.random.normal(
22+
jax.random.PRNGKey(jnp.array(x[0] * 1e7, dtype=jnp.int32) + rand_seed),
23+
shape=(8,),
24+
)
25+
), state
26+
27+
28+
@pytest.mark.parametrize("stateful_policy", [False, True])
29+
def test_brax(stateful_policy):
2330
seed = 41
2431
key = jax.random.PRNGKey(seed)
2532

26-
# It takes too much time to render 500 frames (474s on Nvidia RTX 3090)
27-
# I think it is good to add a progress bar to shrink waiting experience.
33+
if stateful_policy:
34+
policy = random_stateful_policy
35+
else:
36+
policy = random_policy
37+
2838
problem = problems.neuroevolution.Brax(
2939
env_name=gym_name,
30-
policy=jax.jit(random_policy),
31-
cap_episode=500,
40+
policy=policy,
41+
num_episodes=1,
42+
max_episode_length=3,
43+
stateful_policy=stateful_policy,
44+
initial_state=jnp.zeros(10) if stateful_policy else None,
3245
)
3346

3447
state = problem.init(key)
35-
frames = problem.visualize(key, seed, output_type="rgb_array", width=250, height=250)
36-
frames2gif(frames, f"{gym_name}_{seed}.gif")
48+
problem.evaluate(state, jnp.arange(3))
49+
50+
problem.visualize(key, seed, output_type="HTML")
3751
assert True

0 commit comments

Comments
 (0)