6
6
import jax .numpy as jnp
7
7
import ray
8
8
9
- from evox import Algorithm , Problem , State , Workflow , use_state
10
- from evox .utils import algorithm_has_init_ask , parse_opt_direction
9
+ from evox import (
10
+ Algorithm ,
11
+ Problem ,
12
+ State ,
13
+ Workflow ,
14
+ use_state ,
15
+ has_init_ask ,
16
+ has_init_tell ,
17
+ )
18
+ from evox .utils import parse_opt_direction
11
19
12
20
13
21
class WorkerWorkflow (Workflow ):
@@ -36,7 +44,7 @@ def __init__(
36
44
self .fit_transforms = fit_transforms
37
45
38
46
def setup (self , key ):
39
- return State (generation = 0 )
47
+ return State (generation = 0 , first_step = True )
40
48
41
49
def _get_slice (self , pop_size ):
42
50
slice_per_worker = pop_size // self .num_workers
@@ -45,20 +53,23 @@ def _get_slice(self, pop_size):
45
53
end = start + slice_per_worker + (self .worker_index < remainder )
46
54
return start , end
47
55
56
+ def _ask (self , state ):
57
+ if has_init_ask (self .algorithm ) and state .first_step :
58
+ ask = self .algorithm .init_ask
59
+ else :
60
+ ask = self .algorithm .ask
61
+
62
+ # candidate: individuals that need to be evaluated (may differ from population)
63
+ # Note: num_cands can be different from init_ask() and ask()
64
+ cands , state = use_state (ask )(state )
65
+
66
+ return cands , state
67
+
48
68
def step1 (self , state : State ):
49
69
if "pre_ask" in self .non_empty_hooks :
50
70
ray .get (self .monitor_actor .push .remote ("pre_ask" , state ))
51
71
52
- if state .generation == 0 :
53
- is_init = algorithm_has_init_ask (self .algorithm , state )
54
- else :
55
- is_init = False
56
-
57
- if is_init :
58
- cand_sol , state = use_state (self .algorithm .init_ask )(state )
59
- else :
60
- cand_sol , state = use_state (self .algorithm .ask )(state )
61
-
72
+ cand_sol , state = self ._ask (state )
62
73
if "post_ask" in self .non_empty_hooks :
63
74
ray .get (self .monitor_actor .push .remote ("post_ask" , None , cand_sol ))
64
75
@@ -82,12 +93,17 @@ def step1(self, state: State):
82
93
83
94
return partial_fitness , state
84
95
85
- def step2 (self , state : State , fitness : List [ jax . Array ] ):
86
- if state . generation == 0 :
87
- is_init = algorithm_has_init_ask ( self .algorithm , state )
96
+ def _tell (self , state , transformed_fitness ):
97
+ if has_init_tell ( self . algorithm ) and state . first_step :
98
+ tell = self .algorithm . init_tell
88
99
else :
89
- is_init = False
100
+ tell = self . algorithm . tell
90
101
102
+ state = use_state (tell )(state , transformed_fitness )
103
+
104
+ return state
105
+
106
+ def step2 (self , state : State , fitness : List [jax .Array ]):
91
107
fitness = jnp .concatenate (fitness , axis = 0 )
92
108
fitness = fitness * self .opt_direction
93
109
@@ -112,15 +128,19 @@ def step2(self, state: State, fitness: List[jax.Array]):
112
128
)
113
129
)
114
130
115
- if is_init :
116
- state = use_state (self .algorithm .init_tell )(state , fitness )
117
- else :
118
- state = use_state (self .algorithm .tell )(state , fitness )
131
+ state = self ._tell (state , fitness )
119
132
120
133
if "post_tell" in self .non_empty_hooks :
121
134
ray .get (self .monitor_actor .push .remote ("post_tell" , state ))
122
-
123
- return state .update (generation = state .generation + 1 )
135
+
136
+
137
+ if has_init_ask (self .algorithm ) and state .first_step :
138
+ # this ensures that _step() will be re-jitted
139
+ state = state .replace (generation = state .generation + 1 , first_step = False )
140
+ else :
141
+ state = state .replace (generation = state .generation + 1 )
142
+
143
+ return state
124
144
125
145
def valid (self , state : State , metric : str ):
126
146
new_state = use_state (self .problem .valid )(state , metric = metric )
0 commit comments