Skip to content

Commit 1ea5402

Browse files
committed
Rewrite algorithm test cases and simplify code
1 parent 4b779d9 commit 1ea5402

File tree

14 files changed

+135
-190
lines changed

14 files changed

+135
-190
lines changed

unit_test/algorithms/de_variants/test_de.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

unit_test/algorithms/es_variants/test_openes.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_clpso.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_cso.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_dms_pso_el.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_fs_pso.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_pso.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_sl_pso_gs.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/pso_variants/test_sl_pso_us.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

unit_test/algorithms/test_base.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from unittest import TestCase
2+
3+
import torch
4+
5+
from evox.core import Algorithm, Problem, jit, use_state, vmap
6+
from evox.workflows import EvalMonitor, StdWorkflow
7+
8+
9+
class Sphere(Problem):
10+
def __init__(self):
11+
super().__init__()
12+
13+
def evaluate(self, pop: torch.Tensor):
14+
return (pop**2).sum(-1)
15+
16+
17+
class TestBase(TestCase):
18+
def run_algorithm(self, algo: Algorithm):
19+
monitor = EvalMonitor(full_fit_history=False, full_sol_history=False)
20+
prob = Sphere()
21+
workflow = StdWorkflow()
22+
workflow.setup(algo, prob, monitor)
23+
workflow.init_step()
24+
state_step = use_state(lambda: workflow.step)
25+
state = state_step.init_state()
26+
self.assertIsNotNone(state["self.algorithm._monitor_.topk_fitness"])
27+
jit_state_step = jit(state_step, trace=True, example_inputs=(state,))
28+
for _ in range(3):
29+
state = jit_state_step(state)
30+
31+
def run_vmap_algorithm(self, algo: Algorithm):
32+
prob = Sphere()
33+
workflow = StdWorkflow()
34+
workflow.setup(algo, prob)
35+
state_step = use_state(lambda: workflow.step)
36+
vmap_state_step = vmap(state_step)
37+
state = vmap_state_step.init_state(3)
38+
vmap_state_step = jit(
39+
vmap_state_step,
40+
trace=True,
41+
lazy=False,
42+
example_inputs=(state,),
43+
)
44+
for _ in range(3):
45+
state = vmap_state_step(state)

0 commit comments

Comments
 (0)