|
| 1 | +import unittest |
1 | 2 | import torch
|
2 | 3 | import torch.nn as nn
|
3 | 4 | from typing import Dict, List
|
4 | 5 |
|
5 |
| -import os |
6 |
| -import sys |
7 |
| -current_directory = os.getcwd() |
8 |
| -if current_directory not in sys.path: |
9 |
| - sys.path.append(current_directory) |
10 |
| - |
11 |
| -from src.core import jit_class, ModuleBase, trace_impl, use_state |
12 |
| - |
13 |
| - |
14 |
| -if __name__ == "__main__": |
15 |
| - |
16 |
| - @jit_class |
17 |
| - class Test(ModuleBase): |
18 |
| - |
19 |
| - def __init__(self, threshold=0.5): |
20 |
| - super().__init__() |
21 |
| - self.threshold = threshold |
22 |
| - self.sub_mod = nn.Module() |
23 |
| - self.sub_mod.buf = nn.Buffer(torch.zeros(())) |
24 |
| - |
25 |
| - def h(self, q: torch.Tensor) -> torch.Tensor: |
26 |
| - if q.flatten()[0] > self.threshold: |
27 |
| - x = torch.sin(q) |
28 |
| - else: |
29 |
| - x = torch.tan(q) |
30 |
| - return x * x.shape[1] |
31 |
| - |
32 |
| - @trace_impl(h) |
33 |
| - def th(self, q: torch.Tensor) -> torch.Tensor: |
34 |
| - x = torch.where(q.flatten()[0] > self.threshold, q + 2, q + 5) |
35 |
| - x += self.g(x).abs() |
36 |
| - x *= x.shape[1] |
37 |
| - self.sub_mod.buf = x.sum() |
38 |
| - return x |
39 |
| - |
40 |
| - def g(self, p: torch.Tensor) -> torch.Tensor: |
41 |
| - x = torch.cos(p) |
42 |
| - return x * p.shape[0] |
43 |
| - |
44 |
| - t = Test() |
45 |
| - print(t.h.inlined_graph) |
46 |
| - result = t.g(torch.rand(100, 4)) |
47 |
| - print(result) |
48 |
| - t.add_mutable("mut_list", [torch.zeros(10), torch.ones(10)]) |
49 |
| - t.add_mutable("mut_dict", {"a": torch.zeros(20), "b": torch.ones(20)}) |
50 |
| - print(t.mut_list[0]) |
51 |
| - print(t.mut_dict["b"]) |
52 |
| - |
53 |
| - t = Test() |
54 |
| - fn = use_state(lambda: t.h, is_generator=True) |
55 |
| - trace_fn = torch.jit.trace(fn, (fn.init_state(), torch.ones(10, 1)), strict=False) |
56 |
| - |
57 |
| - def loop(init_state: Dict[str, torch.Tensor], init_x: torch.Tensor, n: int = 10): |
58 |
| - state = init_state |
59 |
| - ret = init_x |
60 |
| - rets: List[torch.Tensor] = [] |
61 |
| - for _ in range(n): |
62 |
| - state, ret = trace_fn(state, ret) |
63 |
| - rets.append(state["self.sub_mod.buf"]) |
64 |
| - return rets |
65 |
| - |
66 |
| - print(trace_fn.code) |
67 |
| - loop = torch.jit.trace(loop, (fn.init_state(), torch.rand(10, 2)), strict=False) |
68 |
| - print(loop.code) |
69 |
| - print(loop(fn.init_state(), torch.rand(10, 2))) |
| 6 | +from evox.core import jit_class, ModuleBase, trace_impl, use_state |
| 7 | + |
| 8 | + |
| 9 | +@jit_class |
| 10 | +class DummyModule(ModuleBase): |
| 11 | + |
| 12 | + def __init__(self, threshold=0.5): |
| 13 | + super().__init__() |
| 14 | + self.threshold = threshold |
| 15 | + self.sub_mod = nn.Module() |
| 16 | + self.sub_mod.buf = nn.Buffer(torch.zeros(())) |
| 17 | + |
| 18 | + def h(self, q: torch.Tensor) -> torch.Tensor: |
| 19 | + if q.flatten()[0] > self.threshold: |
| 20 | + x = torch.sin(q) |
| 21 | + else: |
| 22 | + x = torch.tan(q) |
| 23 | + return x * x.shape[1] |
| 24 | + |
| 25 | + @trace_impl(h) |
| 26 | + def th(self, q: torch.Tensor) -> torch.Tensor: |
| 27 | + x = torch.where(q.flatten()[0] > self.threshold, q + 2, q + 5) |
| 28 | + x += self.g(x).abs() |
| 29 | + x *= x.shape[1] |
| 30 | + self.sub_mod.buf = x.sum() |
| 31 | + return x |
| 32 | + |
| 33 | + def g(self, p: torch.Tensor) -> torch.Tensor: |
| 34 | + x = torch.cos(p) |
| 35 | + return x * p.shape[0] |
| 36 | + |
| 37 | + |
| 38 | +class TestModule(unittest.TestCase): |
| 39 | + |
| 40 | + def setUp(self): |
| 41 | + self.test_instance = DummyModule() |
| 42 | + |
| 43 | + def test_h_function(self): |
| 44 | + q = torch.rand(100, 4) |
| 45 | + result = self.test_instance.h(q) |
| 46 | + self.assertIsInstance(result, torch.Tensor) |
| 47 | + |
| 48 | + def test_g_function(self): |
| 49 | + p = torch.rand(100, 4) |
| 50 | + result = self.test_instance.g(p) |
| 51 | + self.assertIsInstance(result, torch.Tensor) |
| 52 | + |
| 53 | + def test_th_function(self): |
| 54 | + q = torch.rand(100, 4) |
| 55 | + result = self.test_instance.th(q) |
| 56 | + self.assertIsInstance(result, torch.Tensor) |
| 57 | + |
| 58 | + def test_add_mutable_list(self): |
| 59 | + self.test_instance.add_mutable("mut_list", [torch.zeros(10), torch.ones(10)]) |
| 60 | + self.assertTrue(torch.equal(self.test_instance.mut_list[0], torch.zeros(10))) |
| 61 | + self.assertTrue(torch.equal(self.test_instance.mut_list[1], torch.ones(10))) |
| 62 | + |
| 63 | + def test_add_mutable_dict(self): |
| 64 | + self.test_instance.add_mutable( |
| 65 | + "mut_dict", {"a": torch.zeros(20), "b": torch.ones(20)} |
| 66 | + ) |
| 67 | + self.assertTrue(torch.equal(self.test_instance.mut_dict["a"], torch.zeros(20))) |
| 68 | + self.assertTrue(torch.equal(self.test_instance.mut_dict["b"], torch.ones(20))) |
| 69 | + |
| 70 | + def test_trace_fn(self): |
| 71 | + fn = use_state(lambda: self.test_instance.h, is_generator=True) |
| 72 | + trace_fn = torch.jit.trace( |
| 73 | + fn, (fn.init_state(), torch.ones(10, 1)), strict=False |
| 74 | + ) |
| 75 | + self.assertIsNotNone(trace_fn) |
| 76 | + |
| 77 | + def test_loop_function(self): |
| 78 | + fn = use_state(lambda: self.test_instance.h, is_generator=True) |
| 79 | + trace_fn = torch.jit.trace( |
| 80 | + fn, (fn.init_state(), torch.ones(10, 1)), strict=False |
| 81 | + ) |
| 82 | + |
| 83 | + def loop( |
| 84 | + init_state: Dict[str, torch.Tensor], init_x: torch.Tensor, n: int = 10 |
| 85 | + ): |
| 86 | + state = init_state |
| 87 | + ret = init_x |
| 88 | + rets: List[torch.Tensor] = [] |
| 89 | + for _ in range(n): |
| 90 | + state, ret = trace_fn(state, ret) |
| 91 | + rets.append(state["self.sub_mod.buf"]) |
| 92 | + return rets |
| 93 | + |
| 94 | + loop_traced = torch.jit.trace( |
| 95 | + loop, (fn.init_state(), torch.rand(10, 2)), strict=False |
| 96 | + ) |
| 97 | + self.assertIsNotNone(loop_traced) |
0 commit comments