
Commit b332a62

Rewrite tests in unittest framework
1 parent 6452dfb commit b332a62

7 files changed: +573 -436 lines changed
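Note: after this change the suites follow the stock unittest layout, so they can be collected with the standard discovery loader instead of being run as scripts. A minimal sketch, assuming the unit_test/ directories are importable; the runner file name below is illustrative and not part of this commit:

# run_tests.py (hypothetical helper, not included in this commit)
import unittest

if __name__ == "__main__":
    # Discover every test_*.py module under unit_test/ and run it.
    suite = unittest.defaultTestLoader.discover("unit_test", pattern="test_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)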

unit_test/algorithms/test_pso.py

Lines changed: 47 additions & 50 deletions
@@ -1,58 +1,55 @@
+from unittest import TestCase
 import time
 import torch
 from torch.profiler import profile, ProfilerActivity
+from evox.core import vmap, Problem, use_state, jit
+from evox.workflows import StdWorkflow
+from evox.algorithms import PSO

-import os
-import sys
-current_directory = os.getcwd()
-if current_directory not in sys.path:
-    sys.path.append(current_directory)

-from src.core import vmap, Problem, use_state, jit
-from src.workflows import StdWorkflow
-from src.algorithms import PSO
+class TestPSO(TestCase):
+    def setUp(self):
+        class Sphere(Problem):
+            def __init__(self):
+                super().__init__()

+            def evaluate(self, pop: torch.Tensor):
+                return (pop**2).sum(-1)

-if __name__ == "__main__":
-    class Sphere(Problem):
-        def __init__(self):
-            super().__init__()
+        self.algo = PSO(pop_size=10, lb=-10 * torch.ones(3), ub=10 * torch.ones(3))
+        self.prob = Sphere()

-        def evaluate(self, pop: torch.Tensor):
-            return (pop**2).sum(-1)
-
-    torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
-    print(torch.get_default_device())
-    algo = PSO(pop_size=10, lb=-10 * torch.ones(3), ub=10 * torch.ones(3))
-    prob = Sphere()
-    workflow = StdWorkflow()
-    workflow.setup(algo, prob)
-    workflow.init_step()
-    workflow.step()
-    # with open("tests/a.md", "w") as ff:
-    #     ff.write(workflow.step.inlined_graph.__str__())
-    state_step = use_state(lambda: workflow.step)
-    vmap_state_step = vmap(state_step)
-    print(vmap_state_step.init_state(2))
-    state = state_step.init_state()
-    jit_state_step = jit(state_step, trace=True, example_inputs=(state,))
-    state = state_step.init_state()
-    # with open("tests/b.md", "w") as ff:
-    #     ff.write(jit_state_step.inlined_graph.__str__())
-    t = time.time()
-    with profile(
-        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
-    ) as prof:
-        for _ in range(1000):
-            workflow.step()
-    print(prof.key_averages().table())
-    torch.cuda.synchronize()
-    t = time.time()
-    with profile(
-        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True
-    ) as prof:
-        for _ in range(1000):
-            state = jit_state_step(state)
-    print(prof.key_averages().table())
-    torch.cuda.synchronize()
-    print(time.time() - t)
+    def test_pso(self):
+        torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
+        print(torch.get_default_device())
+        workflow = StdWorkflow()
+        workflow.setup(self.algo, self.prob)
+        workflow.init_step()
+        workflow.step()
+        state_step = use_state(lambda: workflow.step)
+        vmap_state_step = vmap(state_step)
+        print(vmap_state_step.init_state(2))
+        state = state_step.init_state()
+        jit_state_step = jit(state_step, trace=True, example_inputs=(state,))
+        state = state_step.init_state()
+        t = time.time()
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+            profile_memory=True,
+        ) as prof:
+            for _ in range(1000):
+                workflow.step()
+        print(prof.key_averages().table())
+        torch.cuda.synchronize()
+        t = time.time()
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+            profile_memory=True,
+        ) as prof:
+            for _ in range(1000):
+                state = jit_state_step(state)
+        print(prof.key_averages().table())
+        torch.cuda.synchronize()
+        print(time.time() - t)
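The two profiling loops in test_pso print the full key_averages() table. As a side note, torch.profiler can sort and truncate that summary when the output becomes hard to read; a minimal sketch reusing the same prof handle:

# Hypothetical variation: sort by accumulated CUDA time and keep the top 10 rows.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))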

unit_test/core/test_jit_util.py

Lines changed: 20 additions & 13 deletions
@@ -1,21 +1,28 @@
+import unittest
 from functools import partial
 import torch
+from evox.core import vmap, jit

-import os
-import sys
-current_directory = os.getcwd()
-if current_directory not in sys.path:
-    sys.path.append(current_directory)

-from src.core import vmap, jit
+@partial(vmap, example_ndim=2)
+def _single_eval(
+    x: torch.Tensor, p: float = 2.0, q: torch.Tensor = torch.as_tensor(range(2))
+):
+    return (x**p).sum() * q.sum()


-if __name__ == "__main__":
+class TestJitUtil(unittest.TestCase):
+    def setUp(self):
+        self.expected = torch.tensor([8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0])

-    @partial(vmap, example_ndim=2)
-    def _single_eval(x: torch.Tensor, p: float = 2.0, q: torch.Tensor = torch.as_tensor(range(2))):
-        return (x**p).sum() * q.sum()
+    def test_single_eval(self):
+        result = _single_eval(2 * torch.ones(10, 2))
+        self.assertTrue(torch.equal(result, self.expected))

-    print(_single_eval(2 * torch.ones(10, 2)))
-    print(jit(_single_eval)(2 * torch.ones(10, 2)))
-    print(jit(_single_eval, trace=True, lazy=True)(2 * torch.ones(10, 2)))
+    def test_jit_single_eval(self):
+        result = jit(_single_eval)(2 * torch.ones(10, 2))
+        self.assertTrue(torch.equal(result, self.expected))
+
+    def test_jit_single_eval_trace_lazy(self):
+        result = jit(_single_eval, trace=True, lazy=True)(2 * torch.ones(10, 2))
+        self.assertTrue(torch.equal(result, self.expected))
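The expected tensor of 8.0 values in setUp follows from the defaults of _single_eval: for each row x = [2.0, 2.0] with p = 2.0 and q = [0, 1], (x**p).sum() = 8.0 and q.sum() = 1, so each of the 10 rows maps to 8.0. A quick standalone check of that arithmetic:

import torch

x = 2 * torch.ones(2)          # one row of the test input
q = torch.as_tensor(range(2))  # default q = [0, 1]
print((x**2).sum() * q.sum())  # tensor(8.)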

unit_test/core/test_module.py

Lines changed: 93 additions & 65 deletions
@@ -1,69 +1,97 @@
+import unittest
 import torch
 import torch.nn as nn
 from typing import Dict, List

-import os
-import sys
-current_directory = os.getcwd()
-if current_directory not in sys.path:
-    sys.path.append(current_directory)
-
-from src.core import jit_class, ModuleBase, trace_impl, use_state
-
-
-if __name__ == "__main__":
-
-    @jit_class
-    class Test(ModuleBase):
-
-        def __init__(self, threshold=0.5):
-            super().__init__()
-            self.threshold = threshold
-            self.sub_mod = nn.Module()
-            self.sub_mod.buf = nn.Buffer(torch.zeros(()))
-
-        def h(self, q: torch.Tensor) -> torch.Tensor:
-            if q.flatten()[0] > self.threshold:
-                x = torch.sin(q)
-            else:
-                x = torch.tan(q)
-            return x * x.shape[1]
-
-        @trace_impl(h)
-        def th(self, q: torch.Tensor) -> torch.Tensor:
-            x = torch.where(q.flatten()[0] > self.threshold, q + 2, q + 5)
-            x += self.g(x).abs()
-            x *= x.shape[1]
-            self.sub_mod.buf = x.sum()
-            return x
-
-        def g(self, p: torch.Tensor) -> torch.Tensor:
-            x = torch.cos(p)
-            return x * p.shape[0]
-
-    t = Test()
-    print(t.h.inlined_graph)
-    result = t.g(torch.rand(100, 4))
-    print(result)
-    t.add_mutable("mut_list", [torch.zeros(10), torch.ones(10)])
-    t.add_mutable("mut_dict", {"a": torch.zeros(20), "b": torch.ones(20)})
-    print(t.mut_list[0])
-    print(t.mut_dict["b"])
-
-    t = Test()
-    fn = use_state(lambda: t.h, is_generator=True)
-    trace_fn = torch.jit.trace(fn, (fn.init_state(), torch.ones(10, 1)), strict=False)
-
-    def loop(init_state: Dict[str, torch.Tensor], init_x: torch.Tensor, n: int = 10):
-        state = init_state
-        ret = init_x
-        rets: List[torch.Tensor] = []
-        for _ in range(n):
-            state, ret = trace_fn(state, ret)
-            rets.append(state["self.sub_mod.buf"])
-        return rets
-
-    print(trace_fn.code)
-    loop = torch.jit.trace(loop, (fn.init_state(), torch.rand(10, 2)), strict=False)
-    print(loop.code)
-    print(loop(fn.init_state(), torch.rand(10, 2)))
+from evox.core import jit_class, ModuleBase, trace_impl, use_state
+
+
+@jit_class
+class DummyModule(ModuleBase):
+
+    def __init__(self, threshold=0.5):
+        super().__init__()
+        self.threshold = threshold
+        self.sub_mod = nn.Module()
+        self.sub_mod.buf = nn.Buffer(torch.zeros(()))
+
+    def h(self, q: torch.Tensor) -> torch.Tensor:
+        if q.flatten()[0] > self.threshold:
+            x = torch.sin(q)
+        else:
+            x = torch.tan(q)
+        return x * x.shape[1]
+
+    @trace_impl(h)
+    def th(self, q: torch.Tensor) -> torch.Tensor:
+        x = torch.where(q.flatten()[0] > self.threshold, q + 2, q + 5)
+        x += self.g(x).abs()
+        x *= x.shape[1]
+        self.sub_mod.buf = x.sum()
+        return x
+
+    def g(self, p: torch.Tensor) -> torch.Tensor:
+        x = torch.cos(p)
+        return x * p.shape[0]
+
+
+class TestModule(unittest.TestCase):
+
+    def setUp(self):
+        self.test_instance = DummyModule()
+
+    def test_h_function(self):
+        q = torch.rand(100, 4)
+        result = self.test_instance.h(q)
+        self.assertIsInstance(result, torch.Tensor)
+
+    def test_g_function(self):
+        p = torch.rand(100, 4)
+        result = self.test_instance.g(p)
+        self.assertIsInstance(result, torch.Tensor)
+
+    def test_th_function(self):
+        q = torch.rand(100, 4)
+        result = self.test_instance.th(q)
+        self.assertIsInstance(result, torch.Tensor)
+
+    def test_add_mutable_list(self):
+        self.test_instance.add_mutable("mut_list", [torch.zeros(10), torch.ones(10)])
+        self.assertTrue(torch.equal(self.test_instance.mut_list[0], torch.zeros(10)))
+        self.assertTrue(torch.equal(self.test_instance.mut_list[1], torch.ones(10)))
+
+    def test_add_mutable_dict(self):
+        self.test_instance.add_mutable(
+            "mut_dict", {"a": torch.zeros(20), "b": torch.ones(20)}
+        )
+        self.assertTrue(torch.equal(self.test_instance.mut_dict["a"], torch.zeros(20)))
+        self.assertTrue(torch.equal(self.test_instance.mut_dict["b"], torch.ones(20)))
+
+    def test_trace_fn(self):
+        fn = use_state(lambda: self.test_instance.h, is_generator=True)
+        trace_fn = torch.jit.trace(
+            fn, (fn.init_state(), torch.ones(10, 1)), strict=False
+        )
+        self.assertIsNotNone(trace_fn)
+
+    def test_loop_function(self):
+        fn = use_state(lambda: self.test_instance.h, is_generator=True)
+        trace_fn = torch.jit.trace(
+            fn, (fn.init_state(), torch.ones(10, 1)), strict=False
+        )
+
+        def loop(
+            init_state: Dict[str, torch.Tensor], init_x: torch.Tensor, n: int = 10
+        ):
+            state = init_state
+            ret = init_x
+            rets: List[torch.Tensor] = []
+            for _ in range(n):
+                state, ret = trace_fn(state, ret)
+                rets.append(state["self.sub_mod.buf"])
+            return rets
+
+        loop_traced = torch.jit.trace(
+            loop, (fn.init_state(), torch.rand(10, 2)), strict=False
+        )
+        self.assertIsNotNone(loop_traced)
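test_trace_fn and test_loop_function exercise the pattern of tracing a function that threads an explicit state dict and then iterating the traced function from Python. A minimal, evox-free sketch of that pattern (the step function below is hypothetical and only illustrates the shape of the calls, not the library's API):

from typing import Dict, Tuple
import torch

def step(
    state: Dict[str, torch.Tensor], x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
    # Accumulate into the state buffer and return a new output, like trace_fn above.
    return {"buf": state["buf"] + x.sum()}, torch.sin(x)

init_state = {"buf": torch.zeros(())}
# strict=False allows dict-valued inputs/outputs during tracing.
traced = torch.jit.trace(step, (init_state, torch.ones(10, 1)), strict=False)

state, out = init_state, torch.ones(10, 1)
for _ in range(3):
    state, out = traced(state, out)
print(state["buf"])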
