Skip to content

Commit eddc8ac

Browse files
committed
dev: improve key management in brax
1 parent 656025f commit eddc8ac

File tree

1 file changed

+39
-10
lines changed
  • src/evox/problems/neuroevolution/reinforcement_learning

1 file changed

+39
-10
lines changed

src/evox/problems/neuroevolution/reinforcement_learning/brax.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,14 @@
88
from evox import Problem, State, jit_method
99

1010

11-
def vmap_rng_split(key: jax.Array, num: int = 2) -> jax.Array:
    """Split every PRNG key in a batch of keys.

    Parameters
    ----------
    key
        A batch of PRNG keys with shape ``[B, 2]``.
    num
        How many new keys to derive from each key (default 2).

    Returns
    -------
    jax.Array
        New keys with shape ``[num, B, 2]`` — the split axis comes first,
        the original batch axis second.
    """
    # Vectorize the per-key split over the batch axis; out_axes=1 places
    # the batch dimension second, yielding [num, B, 2] rather than [B, num, 2].
    batched_split = jax.vmap(jax.random.split, in_axes=(0, None), out_axes=1)
    return batched_split(key, num)
14-
15-
1611
class Brax(Problem):
1712
def __init__(
1813
self,
1914
policy: Callable,
2015
env_name: str,
2116
max_episode_length: int,
2217
num_episodes: int,
18+
rotate_key: bool = True,
2319
stateful_policy: bool = False,
2420
initial_state: Any = None,
2521
reduce_fn: Callable[[jax.Array, int], jax.Array] = jnp.mean,
@@ -34,7 +30,9 @@ def __init__(
3430
Then you need to set the `environment name <https://github.com/google/brax/tree/main/brax/envs>`_,
3531
the maximum episode length, the number of episodes to evaluate for each individual.
3632
For each individual,
37-
it will run the policy with the environment for num_episodes times and use the reduce_fn to reduce the rewards (default to average).
33+
it will run the policy with the environment for num_episodes times with different seeds,
34+
and use the reduce_fn to reduce the rewards (default to average).
35+
Different individuals will share the same set of random keys in each iteration.
3836
3937
Parameters
4038
----------
@@ -46,6 +44,16 @@ def __init__(
4644
The maximum number of timesteps of each episode.
4745
num_episodes
4846
The number of episodes to evaluate for each individual.
47+
rotate_key
48+
Indicates whether to rotate the random key for each iteration (default is True).
49+
50+
If True, the random key will rotate after each iteration,
51+
resulting in non-deterministic and potentially noisy fitness evaluations.
52+
This means that identical policy weights may yield different fitness values across iterations.
53+
54+
If False, the random key remains the same for all iterations,
55+
ensuring consistent fitness evaluations.
56+
4957
stateful_policy
5058
Whether the policy is stateful (for example, RNN).
5159
Default to False.
@@ -61,6 +69,21 @@ def __init__(
6169
backend
6270
Brax's backend, one of "generalized", "positional", "spring".
6371
Default to "generalized".
72+
73+
Notes
74+
-----
75+
When rotating keys, fitness evaluation is non-deterministic and may introduce noise.
76+
77+
Examples
78+
--------
79+
>>> from evox import problems
80+
>>> problem = problems.neuroevolution.Brax(
81+
... env_name="swimmer",
82+
... policy=jit(model.apply),
83+
... max_episode_length=1000,
84+
... num_episodes=3,
85+
... rotate_key=False,
86+
...)
6487
"""
6588
if stateful_policy:
6689
self.batched_policy = jit(vmap(vmap(policy, in_axes=(0, None, 0))))
@@ -76,6 +99,7 @@ def __init__(
7699
self.initial_state = initial_state
77100
self.max_episode_length = max_episode_length
78101
self.num_episodes = num_episodes
102+
self.rotate_key = rotate_key
79103
self.reduce_fn = reduce_fn
80104

81105
self.jit_reset = jit(vmap(self.env.reset))
@@ -87,7 +111,10 @@ def setup(self, key):
87111
@jit_method
88112
def evaluate(self, state, weights):
89113
pop_size = jtu.tree_leaves(weights)[0].shape[0]
90-
key, eval_key = jax.random.split(state.key)
114+
if self.rotate_key:
115+
key, eval_key = jax.random.split(state.key)
116+
else:
117+
key, eval_key = state.key, state.key
91118

92119
def _cond_func(carry):
93120
counter, _state, done, _total_reward = carry
@@ -108,9 +135,11 @@ def _body_func(carry):
108135
total_reward += (1 - done) * brax_state.reward
109136
return counter + 1, rollout_state, done, total_reward
110137

111-
brax_state = self.jit_reset(
112-
vmap_rng_split(jax.random.split(eval_key, self.num_episodes), pop_size)
113-
)
138+
# For each episode, we need a different random key.
139+
keys = jax.random.split(eval_key, self.num_episodes)
140+
# For each individual in the population, we need the same set of keys.
141+
keys = jnp.broadcast_to(keys, (pop_size, *keys.shape))
142+
brax_state = self.jit_reset(keys)
114143

115144
if self.stateful_policy:
116145
initial_state = jax.tree.map(

0 commit comments

Comments
 (0)