8
8
from evox import Problem , State , jit_method
9
9
10
10
11
- def vmap_rng_split (key : jax .Array , num : int = 2 ) -> jax .Array :
12
- # batched_key [B, 2] -> batched_keys [num, B, 2]
13
- return jax .vmap (jax .random .split , in_axes = (0 , None ), out_axes = 1 )(key , num )
14
-
15
-
16
11
class Brax (Problem ):
17
12
def __init__ (
18
13
self ,
19
14
policy : Callable ,
20
15
env_name : str ,
21
16
max_episode_length : int ,
22
17
num_episodes : int ,
18
+ rotate_key : bool = True ,
23
19
stateful_policy : bool = False ,
24
20
initial_state : Any = None ,
25
21
reduce_fn : Callable [[jax .Array , int ], jax .Array ] = jnp .mean ,
@@ -34,7 +30,9 @@ def __init__(
34
30
Then you need to set the `environment name <https://github.com/google/brax/tree/main/brax/envs>`_,
35
31
the maximum episode length, the number of episodes to evaluate for each individual.
36
32
For each individual,
37
- it will run the policy with the environment for num_episodes times and use the reduce_fn to reduce the rewards (default to average).
33
+ it will run the policy with the environment for num_episodes times, each with a different seed,
34
+ and use the reduce_fn to reduce the rewards (defaults to the average).
35
+ Different individuals will share the same set of random keys in each iteration.
38
36
39
37
Parameters
40
38
----------
@@ -46,6 +44,16 @@ def __init__(
46
44
The maximum number of timesteps of each episode.
47
45
num_episodes
48
46
The number of episodes to evaluate for each individual.
47
+ rotate_key
48
+ Indicates whether to rotate the random key for each iteration (default is True).
49
+
50
+ If True, the random key will rotate after each iteration,
51
+ resulting in non-deterministic and potentially noisy fitness evaluations.
52
+ This means that identical policy weights may yield different fitness values across iterations.
53
+
54
+ If False, the random key remains the same for all iterations,
55
+ ensuring consistent fitness evaluations.
56
+
49
57
stateful_policy
50
58
Whether the policy is stateful (for example, RNN).
51
59
Default to False.
@@ -61,6 +69,21 @@ def __init__(
61
69
backend
62
70
Brax's backend, one of "generalized", "positional", "spring".
63
71
Default to "generalized".
72
+
73
+ Notes
74
+ -----
75
+ When rotating keys, fitness evaluation is non-deterministic and may introduce noise.
76
+
77
+ Examples
78
+ --------
79
+ >>> from evox import problems
80
+ >>> problem = problems.neuroevolution.Brax(
81
+ ... env_name="swimmer",
82
+ ... policy=jit(model.apply),
83
+ ... max_episode_length=1000,
84
+ ... num_episodes=3,
85
+ ... rotate_key=False,
86
+ ...)
64
87
"""
65
88
if stateful_policy :
66
89
self .batched_policy = jit (vmap (vmap (policy , in_axes = (0 , None , 0 ))))
@@ -76,6 +99,7 @@ def __init__(
76
99
self .initial_state = initial_state
77
100
self .max_episode_length = max_episode_length
78
101
self .num_episodes = num_episodes
102
+ self .rotate_key = rotate_key
79
103
self .reduce_fn = reduce_fn
80
104
81
105
self .jit_reset = jit (vmap (self .env .reset ))
@@ -87,7 +111,10 @@ def setup(self, key):
87
111
@jit_method
88
112
def evaluate (self , state , weights ):
89
113
pop_size = jtu .tree_leaves (weights )[0 ].shape [0 ]
90
- key , eval_key = jax .random .split (state .key )
114
+ if self .rotate_key :
115
+ key , eval_key = jax .random .split (state .key )
116
+ else :
117
+ key , eval_key = state .key , state .key
91
118
92
119
def _cond_func (carry ):
93
120
counter , _state , done , _total_reward = carry
@@ -108,9 +135,11 @@ def _body_func(carry):
108
135
total_reward += (1 - done ) * brax_state .reward
109
136
return counter + 1 , rollout_state , done , total_reward
110
137
111
- brax_state = self .jit_reset (
112
- vmap_rng_split (jax .random .split (eval_key , self .num_episodes ), pop_size )
113
- )
138
+ # For each episode, we need a different random key.
139
+ keys = jax .random .split (eval_key , self .num_episodes )
140
+ # For each individual in the population, we need the same set of keys.
141
+ keys = jnp .broadcast_to (keys , (pop_size , * keys .shape ))
142
+ brax_state = self .jit_reset (keys )
114
143
115
144
if self .stateful_policy :
116
145
initial_state = jax .tree .map (
0 commit comments