Skip to content

Commit 011d8f0

Browse files
committed
fix: init_ask detection and state.first_step
1 parent 02f00e7 commit 011d8f0

File tree

1 file changed

+43
-23
lines changed

1 file changed

+43
-23
lines changed

src/evox/workflows/distributed.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,16 @@
66
import jax.numpy as jnp
77
import ray
88

9-
from evox import Algorithm, Problem, State, Workflow, use_state
10-
from evox.utils import algorithm_has_init_ask, parse_opt_direction
9+
from evox import (
10+
Algorithm,
11+
Problem,
12+
State,
13+
Workflow,
14+
use_state,
15+
has_init_ask,
16+
has_init_tell,
17+
)
18+
from evox.utils import parse_opt_direction
1119

1220

1321
class WorkerWorkflow(Workflow):
@@ -36,7 +44,7 @@ def __init__(
3644
self.fit_transforms = fit_transforms
3745

3846
def setup(self, key):
39-
return State(generation=0)
47+
return State(generation=0, first_step=True)
4048

4149
def _get_slice(self, pop_size):
4250
slice_per_worker = pop_size // self.num_workers
@@ -45,20 +53,23 @@ def _get_slice(self, pop_size):
4553
end = start + slice_per_worker + (self.worker_index < remainder)
4654
return start, end
4755

56+
def _ask(self, state):
57+
if has_init_ask(self.algorithm) and state.first_step:
58+
ask = self.algorithm.init_ask
59+
else:
60+
ask = self.algorithm.ask
61+
62+
# candidate: individuals that need to be evaluated (may differ from population)
63+
# Note: num_cands can be different from init_ask() and ask()
64+
cands, state = use_state(ask)(state)
65+
66+
return cands, state
67+
4868
def step1(self, state: State):
4969
if "pre_ask" in self.non_empty_hooks:
5070
ray.get(self.monitor_actor.push.remote("pre_ask", state))
5171

52-
if state.generation == 0:
53-
is_init = algorithm_has_init_ask(self.algorithm, state)
54-
else:
55-
is_init = False
56-
57-
if is_init:
58-
cand_sol, state = use_state(self.algorithm.init_ask)(state)
59-
else:
60-
cand_sol, state = use_state(self.algorithm.ask)(state)
61-
72+
cand_sol, state = self._ask(state)
6273
if "post_ask" in self.non_empty_hooks:
6374
ray.get(self.monitor_actor.push.remote("post_ask", None, cand_sol))
6475

@@ -82,12 +93,17 @@ def step1(self, state: State):
8293

8394
return partial_fitness, state
8495

85-
def step2(self, state: State, fitness: List[jax.Array]):
86-
if state.generation == 0:
87-
is_init = algorithm_has_init_ask(self.algorithm, state)
96+
def _tell(self, state, transformed_fitness):
97+
if has_init_tell(self.algorithm) and state.first_step:
98+
tell = self.algorithm.init_tell
8899
else:
89-
is_init = False
100+
tell = self.algorithm.tell
90101

102+
state = use_state(tell)(state, transformed_fitness)
103+
104+
return state
105+
106+
def step2(self, state: State, fitness: List[jax.Array]):
91107
fitness = jnp.concatenate(fitness, axis=0)
92108
fitness = fitness * self.opt_direction
93109

@@ -112,15 +128,19 @@ def step2(self, state: State, fitness: List[jax.Array]):
112128
)
113129
)
114130

115-
if is_init:
116-
state = use_state(self.algorithm.init_tell)(state, fitness)
117-
else:
118-
state = use_state(self.algorithm.tell)(state, fitness)
131+
state = self._tell(state, fitness)
119132

120133
if "post_tell" in self.non_empty_hooks:
121134
ray.get(self.monitor_actor.push.remote("post_tell", state))
122-
123-
return state.update(generation=state.generation + 1)
135+
136+
137+
if has_init_ask(self.algorithm) and state.first_step:
138+
# this ensures that _step() will be re-jitted
139+
state = state.replace(generation=state.generation + 1, first_step=False)
140+
else:
141+
state = state.replace(generation=state.generation + 1)
142+
143+
return state
124144

125145
def valid(self, state: State, metric: str):
126146
new_state = use_state(self.problem.valid)(state, metric=metric)

0 commit comments

Comments
 (0)