|
2 | 2 | Tests the ect functions.
|
3 | 3 | """
|
4 | 4 |
|
5 |
| -import torch |
| 5 | +# import torch |
6 | 6 |
|
7 |
| -from dect.ect import compute_ecc, normalize |
| 7 | +# from dect.ect import compute_ecc, normalize |
8 | 8 |
|
9 | 9 |
|
10 |
| -class TestECT: |
11 |
| - """ |
12 |
| - 1. When normalized, the ect needs to be normalized. |
13 |
| - 2. The dimensions need to correspond. e.g. the batches need not to |
14 |
| - be mixed up. |
15 |
| - 3. Test that when one of the inputs has a gradient the out has one too. |
16 |
| - """ |
| 10 | +# class TestECT: |
| 11 | +# """ |
| 12 | +# 1. When normalized, the ect needs to be normalized. |
| 13 | +# 2. The dimensions need to correspond. e.g. the batches need not to |
| 14 | +# be mixed up. |
| 15 | +# 3. Test that when one of the inputs has a gradient the out has one too. |
| 16 | +# """ |
17 | 17 |
|
18 |
| - def test_ecc_single(self): |
19 |
| - """ |
20 |
| - Check that the dimensions are correct. |
21 |
| - lin of size [bump_steps, 1, 1, 1] |
22 |
| - """ |
23 |
| - lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
24 |
| - index = torch.tensor([0, 0, 0], dtype=torch.long) |
25 |
| - nh = torch.tensor([[0.0], [0.5], [0.5]]) |
26 |
| - scale = 100 |
27 |
| - ecc = compute_ecc(nh, index, lin, scale) |
28 |
| - assert ecc.shape == (1, 1, 13, 1) |
| 18 | +# def test_ecc_single(self): |
| 19 | +# """ |
| 20 | +# Check that the dimensions are correct. |
| 21 | +# lin of size [bump_steps, 1, 1, 1] |
| 22 | +# """ |
| 23 | +# lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
| 24 | +# index = torch.tensor([0, 0, 0], dtype=torch.long) |
| 25 | +# nh = torch.tensor([[0.0], [0.5], [0.5]]) |
| 26 | +# scale = 100 |
| 27 | +# ecc = compute_ecc(nh, index, lin, scale) |
| 28 | +# assert ecc.shape == (1, 1, 13, 1) |
29 | 29 |
|
30 |
| - # Check that min and max are 0 and 3 |
31 |
| - torch.testing.assert_close(ecc.max(), torch.tensor(3.0)) |
32 |
| - torch.testing.assert_close(ecc.min(), torch.tensor(0.0)) |
| 30 | +# # Check that min and max are 0 and 3 |
| 31 | +# torch.testing.assert_close(ecc.max(), torch.tensor(3.0)) |
| 32 | +# torch.testing.assert_close(ecc.min(), torch.tensor(0.0)) |
33 | 33 |
|
34 |
| - def test_ecc_multi_set_directions(self): |
35 |
| - """ |
36 |
| - Check that the dimensions are correct. |
37 |
| - lin of size [bump_steps, 1, 1, 1] |
38 |
| - """ |
39 |
| - lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
40 |
| - index = torch.tensor([0, 0, 0], dtype=torch.long) |
41 |
| - nh = torch.tensor([[0.0, 0.0], [0.5, 0.5], [0.5, 0.5]]) |
42 |
| - scale = 100 |
43 |
| - ecc = compute_ecc(nh, index, lin, scale) |
44 |
| - assert ecc.shape == (1, 1, 13, 2) |
| 34 | +# def test_ecc_multi_set_directions(self): |
| 35 | +# """ |
| 36 | +# Check that the dimensions are correct. |
| 37 | +# lin of size [bump_steps, 1, 1, 1] |
| 38 | +# """ |
| 39 | +# lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
| 40 | +# index = torch.tensor([0, 0, 0], dtype=torch.long) |
| 41 | +# nh = torch.tensor([[0.0, 0.0], [0.5, 0.5], [0.5, 0.5]]) |
| 42 | +# scale = 100 |
| 43 | +# ecc = compute_ecc(nh, index, lin, scale) |
| 44 | +# assert ecc.shape == (1, 1, 13, 2) |
45 | 45 |
|
46 |
| - def test_ecc_multi_batch(self): |
47 |
| - """ |
48 |
| - Check that the dimensions are correct. |
49 |
| - lin of size [bump_steps, 1, 1, 1] |
50 |
| - """ |
51 |
| - lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
52 |
| - index = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long) |
53 |
| - nh = torch.tensor([[0.0], [0.5], [0.5], [0.7], [0.7]]) |
54 |
| - scale = 100 |
55 |
| - ecc = compute_ecc(nh, index, lin, scale) |
56 |
| - assert ecc.shape == (2, 1, 13, 1) |
| 46 | +# def test_ecc_multi_batch(self): |
| 47 | +# """ |
| 48 | +# Check that the dimensions are correct. |
| 49 | +# lin of size [bump_steps, 1, 1, 1] |
| 50 | +# """ |
| 51 | +# lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
| 52 | +# index = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long) |
| 53 | +# nh = torch.tensor([[0.0], [0.5], [0.5], [0.7], [0.7]]) |
| 54 | +# scale = 100 |
| 55 | +# ecc = compute_ecc(nh, index, lin, scale) |
| 56 | +# assert ecc.shape == (2, 1, 13, 1) |
57 | 57 |
|
58 |
| - # Check that min and max are 0 and 1 |
59 |
| - torch.testing.assert_close(ecc[0].max(), torch.tensor(2.0)) |
60 |
| - torch.testing.assert_close(ecc[0].min(), torch.tensor(0.0)) |
| 58 | +# # Check that min and max are 0 and 1 |
| 59 | +# torch.testing.assert_close(ecc[0].max(), torch.tensor(2.0)) |
| 60 | +# torch.testing.assert_close(ecc[0].min(), torch.tensor(0.0)) |
61 | 61 |
|
62 |
| - torch.testing.assert_close(ecc[1].max(), torch.tensor(3.0)) |
63 |
| - torch.testing.assert_close(ecc[1].min(), torch.tensor(0.0)) |
| 62 | +# torch.testing.assert_close(ecc[1].max(), torch.tensor(3.0)) |
| 63 | +# torch.testing.assert_close(ecc[1].min(), torch.tensor(0.0)) |
64 | 64 |
|
65 |
| - def test_ecc_normalized(self): |
66 |
| - """ |
67 |
| - Check that the dimensions are correct. |
68 |
| - lin of size [bump_steps, 1, 1, 1] |
69 |
| - """ |
70 |
| - lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
71 |
| - index = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long) |
72 |
| - nh = torch.tensor([[0.0], [0.5], [0.5], [0.7], [0.7]]) |
73 |
| - scale = 100 |
74 |
| - ecc = compute_ecc(nh, index, lin, scale) |
75 |
| - assert ecc.shape == (2, 1, 13, 1) |
76 |
| - ecc_normalized = normalize(ecc) |
| 65 | +# def test_ecc_normalized(self): |
| 66 | +# """ |
| 67 | +# Check that the dimensions are correct. |
| 68 | +# lin of size [bump_steps, 1, 1, 1] |
| 69 | +# """ |
| 70 | +# lin = torch.linspace(-1, 1, 13).view(-1, 1, 1, 1) |
| 71 | +# index = torch.tensor([0, 0, 1, 1, 1], dtype=torch.long) |
| 72 | +# nh = torch.tensor([[0.0], [0.5], [0.5], [0.7], [0.7]]) |
| 73 | +# scale = 100 |
| 74 | +# ecc = compute_ecc(nh, index, lin, scale) |
| 75 | +# assert ecc.shape == (2, 1, 13, 1) |
| 76 | +# ecc_normalized = normalize(ecc) |
77 | 77 |
|
78 |
| - # Check that min and max are 0 and 1 |
79 |
| - torch.testing.assert_close(ecc_normalized[0].max(), torch.tensor(1.0)) |
80 |
| - torch.testing.assert_close(ecc_normalized[0].min(), torch.tensor(0.0)) |
| 78 | +# # Check that min and max are 0 and 1 |
| 79 | +# torch.testing.assert_close(ecc_normalized[0].max(), torch.tensor(1.0)) |
| 80 | +# torch.testing.assert_close(ecc_normalized[0].min(), torch.tensor(0.0)) |
81 | 81 |
|
82 |
| - torch.testing.assert_close(ecc_normalized[1].max(), torch.tensor(1.0)) |
83 |
| - torch.testing.assert_close(ecc_normalized[1].min(), torch.tensor(0.0)) |
| 82 | +# torch.testing.assert_close(ecc_normalized[1].max(), torch.tensor(1.0)) |
| 83 | +# torch.testing.assert_close(ecc_normalized[1].min(), torch.tensor(0.0)) |
0 commit comments