Skip to content

Commit f95e9ec

Browse files
committed
add compare_tensors_as_sets_float, remove the comparasion of order dependence in band index related tensors
1 parent 53654d5 commit f95e9ec

8 files changed

+232
-36
lines changed

dptb/tests/test_SKHamiltonian.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from dptb.nn.hamiltonian import SKHamiltonian
1111
from dptb.utils.constants import anglrMId, orbitalId
1212
from e3nn.o3 import wigner_3j, Irrep, xyz_to_angles, Irrep
13+
from dptb.tests.tstools import compare_tensors_as_sets_float
1314

1415
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
1516

@@ -146,7 +147,34 @@ def test_hoppingblocks(self):
146147
data = nnsk(self.batch)
147148
data = hamiltonian(data)
148149
assert data[AtomicDataDict.EDGE_FEATURES_KEY].shape == torch.Size([18, 13])
149-
150+
expected_edge_index = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
151+
[0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
152+
expected_edge_cell_shift = torch.tensor([[-1., 0., 0.],
153+
[-1., 0., 0.],
154+
[ 0., 1., 0.],
155+
[ 0., 1., 0.],
156+
[ 1., 1., 0.],
157+
[ 0., 0., 0.],
158+
[ 1., 1., 0.],
159+
[ 0., -1., 0.],
160+
[-1., 0., 0.],
161+
[ 1., -0., -0.],
162+
[ 1., -0., -0.],
163+
[-0., -1., -0.],
164+
[-0., -1., -0.],
165+
[-1., -1., -0.],
166+
[-0., -0., -0.],
167+
[-1., -1., -0.],
168+
[-0., 1., -0.],
169+
[ 1., -0., -0.]])
170+
171+
exp_val = torch.cat((expected_edge_index.T, expected_edge_cell_shift), axis=1)
172+
tar_val = torch.cat((data[AtomicDataDict.EDGE_INDEX_KEY].T, data[AtomicDataDict.EDGE_CELL_SHIFT_KEY]), axis=1)
173+
exp_val = exp_val.int()
174+
tar_val = tar_val.int()
175+
176+
assert compare_tensors_as_sets_float(exp_val, tar_val)
177+
150178

151179
expected_selected_hopblock = torch.tensor([[ 5.3185172379e-02, -4.6635824091e-09, 1.3500485174e-09,
152180
3.0885510147e-02, 8.2756355405e-02, 4.3990724937e-16,
@@ -179,7 +207,16 @@ def test_hoppingblocks(self):
179207
-1.8777785993e-09, -3.7203207612e-02, -1.8777785993e-09,
180208
-1.4753011055e-02]])
181209

182-
assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][[0,3,9,5,12,15]] - expected_selected_hopblock) < 1e-6)
210+
# assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][[0,3,9,5,12,15]] - expected_selected_hopblock) < 1e-6)
211+
testind = [0,3,9,5,12,15]
212+
tarind_list = []
213+
for i in range(len(testind)):
214+
ind = testind[i]
215+
bond = exp_val.tolist()[ind]
216+
assert bond in tar_val.tolist()
217+
tarind = tar_val.tolist().index(bond)
218+
tarind_list.append(tarind)
219+
assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][tarind_list] - expected_selected_hopblock) < 1e-4)
183220

184221
def test_onsite_stain(self):
185222
model_options = self.model_options

dptb/tests/test_atomicdata_rmaxdict.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dptb.utils.constants import atomic_num_dict, atomic_num_dict_r
66
import os
77
from pathlib import Path
8+
from dptb.tests.tstools import compare_tensors_as_sets
89

910
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
1011

@@ -17,10 +18,9 @@ def test_rmax_float():
1718
atomic_options['r_max'] = 2.6
1819

1920
data = AtomicData.from_ase(atoms, **atomic_options)
20-
assert (data.edge_index == torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
21-
[0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]])).all()
22-
23-
assert (data.edge_cell_shift == torch.tensor([[-1., 0., 0.],
21+
expected_edge_index = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
22+
[0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
23+
expected_edge_cell_shift = torch.tensor([[-1., 0., 0.],
2424
[-1., 0., 0.],
2525
[ 0., 1., 0.],
2626
[ 0., 1., 0.],
@@ -37,7 +37,10 @@ def test_rmax_float():
3737
[-0., -0., -0.],
3838
[-1., -1., -0.],
3939
[-0., 1., -0.],
40-
[ 1., -0., -0.]])).all()
40+
[ 1., -0., -0.]])
41+
exp_val = torch.cat((expected_edge_index.T, expected_edge_cell_shift), axis=1)
42+
tar_val = torch.cat((data.edge_index.T, data.edge_cell_shift), axis=1)
43+
assert compare_tensors_as_sets(exp_val, tar_val)
4144

4245
def test_rmax_dict_eq():
4346
strfile = os.path.join(rootdir, "hBN", "hBN.vasp")
@@ -46,10 +49,10 @@ def test_rmax_dict_eq():
4649
atomic_options['pbc'] = True
4750
atomic_options['r_max'] = {'B': 2.6, 'N': 2.6}
4851
data = AtomicData.from_ase(atoms, **atomic_options)
49-
assert (data.edge_index == torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
50-
[0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]])).all()
51-
52-
assert (data.edge_cell_shift == torch.tensor([[-1., 0., 0.],
52+
53+
expected_edge_index = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
54+
[0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
55+
expected_edge_cell_shift = torch.tensor([[-1., 0., 0.],
5356
[-1., 0., 0.],
5457
[ 0., 1., 0.],
5558
[ 0., 1., 0.],
@@ -66,19 +69,21 @@ def test_rmax_dict_eq():
6669
[-0., -0., -0.],
6770
[-1., -1., -0.],
6871
[-0., 1., -0.],
69-
[ 1., -0., -0.]])).all()
70-
72+
[ 1., -0., -0.]])
73+
exp_val = torch.cat((expected_edge_index.T, expected_edge_cell_shift), axis=1)
74+
tar_val = torch.cat((data.edge_index.T, data.edge_cell_shift), axis=1)
75+
assert compare_tensors_as_sets(exp_val, tar_val)
76+
7177
def test_rmax_dict_neq():
7278
strfile = os.path.join(rootdir, "hBN", "hBN.vasp")
7379
atoms = read(strfile)
7480
atomic_options = {}
7581
atomic_options['pbc'] = True
7682
atomic_options['r_max'] = {'B':1.5,'N':2.6}
7783
data = AtomicData.from_ase(atoms, **atomic_options)
78-
assert (data.edge_index == torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
79-
[0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0]])).all()
80-
81-
assert (data.edge_cell_shift == torch.tensor([[-1., 0., 0.],
84+
expected_edge_index = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1],
85+
[0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0]])
86+
expected_edge_cell_shift = torch.tensor([[-1., 0., 0.],
8287
[-1., 0., 0.],
8388
[ 0., 1., 0.],
8489
[ 0., 1., 0.],
@@ -89,4 +94,8 @@ def test_rmax_dict_neq():
8994
[-0., -1., -0.],
9095
[-0., -1., -0.],
9196
[-1., -1., -0.],
92-
[-0., -0., -0.]])).all()
97+
[-0., -0., -0.]])
98+
exp_val = torch.cat((expected_edge_index.T, expected_edge_cell_shift), axis=1)
99+
tar_val = torch.cat((data.edge_index.T, data.edge_cell_shift), axis=1)
100+
assert compare_tensors_as_sets(exp_val, tar_val)
101+

dptb/tests/test_block_to_feature.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from dptb.data.interfaces.ham_to_feature import block_to_feature, feature_to_block
1212
from dptb.utils.constants import anglrMId
1313
from e3nn.o3 import wigner_3j, Irrep, xyz_to_angles, Irrep
14+
from dptb.tests.tstools import compare_tensors_as_sets_float, compare_tensors_as_sets
1415

1516
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
1617

@@ -127,6 +128,34 @@ def test_transform_hoppingblocks(self):
127128
data = nnsk(self.batch)
128129
data = hamiltonian(data)
129130

131+
expected_edge_index = torch.tensor([[0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1],
132+
[0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]])
133+
expected_edge_cell_shift = torch.tensor([[-1., 0., 0.],
134+
[-1., 0., 0.],
135+
[ 0., 1., 0.],
136+
[ 0., 1., 0.],
137+
[ 1., 1., 0.],
138+
[ 0., 0., 0.],
139+
[ 1., 1., 0.],
140+
[ 0., -1., 0.],
141+
[-1., 0., 0.],
142+
[ 1., -0., -0.],
143+
[ 1., -0., -0.],
144+
[-0., -1., -0.],
145+
[-0., -1., -0.],
146+
[-1., -1., -0.],
147+
[-0., -0., -0.],
148+
[-1., -1., -0.],
149+
[-0., 1., -0.],
150+
[ 1., -0., -0.]])
151+
152+
exp_val = torch.cat((expected_edge_index.T, expected_edge_cell_shift), axis=1)
153+
tar_val = torch.cat((data[AtomicDataDict.EDGE_INDEX_KEY].T, data[AtomicDataDict.EDGE_CELL_SHIFT_KEY]), axis=1)
154+
exp_val = exp_val.int()
155+
tar_val = tar_val.int()
156+
157+
assert compare_tensors_as_sets(exp_val, tar_val)
158+
130159
with torch.no_grad():
131160
block = feature_to_block(data, nnsk.idp)
132161
block_to_feature(data, nnsk.idp, blocks=block)
@@ -165,6 +194,15 @@ def test_transform_hoppingblocks(self):
165194
-1.8777785993e-09, -3.7203207612e-02, -1.8777785993e-09,
166195
-1.4753011055e-02]])
167196

168-
assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][[0,3,9,5,12,15]] - expected_selected_hopblock) < 1e-6)
169-
170-
197+
# assert compare_tensors_as_sets_float(data[AtomicDataDict.EDGE_FEATURES_KEY][[0,3,9,5,12,15]], expected_selected_hopblock)
198+
# assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][[0,3,9,5,12,15]] - expected_selected_hopblock) < 1e-6)
199+
200+
testind = [0,3,9,5,12,15]
201+
tarind_list = []
202+
for i in range(len(testind)):
203+
ind = testind[i]
204+
bond = exp_val.tolist()[ind]
205+
assert bond in tar_val.tolist()
206+
tarind = tar_val.tolist().index(bond)
207+
tarind_list.append(tarind)
208+
assert torch.all(torch.abs(data[AtomicDataDict.EDGE_FEATURES_KEY][tarind_list] - expected_selected_hopblock) < 1e-4)

dptb/tests/test_dataloader_batch.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dptb.utils.torch_geometric.batch import Batch
99
from dptb.utils.torch_geometric.data import Data
1010
from collections.abc import Mapping
11+
from dptb.tests.tstools import compare_tensors_as_sets_float
1112

1213
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
1314

@@ -88,7 +89,8 @@ def test_batch(self):
8889
4.5023179054, 4.5023179054, 2.3512587547, 4.5023179054, 4.5023179054,
8990
3.8395895958, 3.8395895958, 3.8395895958, 3.8395895958, 3.8395895958,
9091
3.8395895958])
91-
assert torch.all(torch.abs(batch[AtomicDataDict.EDGE_LENGTH_KEY] - expected_length) < 1e-8)
92+
assert compare_tensors_as_sets_float(batch[AtomicDataDict.EDGE_LENGTH_KEY], expected_length, precision=7)
93+
# assert torch.all(torch.abs(batch[AtomicDataDict.EDGE_LENGTH_KEY] - expected_length) < 1e-8)
9294

9395
assert batch[AtomicDataDict.EDGE_VECTORS_KEY].shape == torch.Size([56, 3])
9496
expected_edgevectors = torch.tensor([[ 1.9197947979, 3.3251821995, 0.0000000000],
@@ -147,7 +149,8 @@ def test_batch(self):
147149
[-1.9197947979, 1.1083940268, 3.1350116730],
148150
[ 0.0000000000, -2.2167882919, 3.1350116730],
149151
[ 1.9197947979, 1.1083940268, 3.1350116730]])
150-
assert torch.all(torch.abs(batch[AtomicDataDict.EDGE_VECTORS_KEY] - expected_edgevectors) < 1e-8)
152+
assert compare_tensors_as_sets_float(batch[AtomicDataDict.EDGE_VECTORS_KEY], expected_edgevectors, precision=7)
153+
# assert torch.all(torch.abs(batch[AtomicDataDict.EDGE_VECTORS_KEY] - expected_edgevectors) < 1e-8)
151154

152155

153156
batch = AtomicDataDict.with_env_vectors(batch, with_lengths=True)
@@ -159,7 +162,7 @@ def test_batch(self):
159162
[0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1,
160163
1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
161164
0, 0, 1, 1, 1, 1, 1, 1]])
162-
assert torch.all(batch[AtomicDataDict.ENV_INDEX_KEY] == expected_env_index)
165+
163166

164167
expected_env_length = torch.tensor([3.8395895958, 4.5023179054, 3.8395895958, 4.5023179054, 3.8395895958,
165168
4.5023179054, 3.8395895958, 2.3512587547, 4.5023179054, 2.3512589931,
@@ -173,7 +176,8 @@ def test_batch(self):
173176
4.5023179054, 4.5023179054, 2.3512587547, 4.5023179054, 4.5023179054,
174177
3.8395895958, 3.8395895958, 3.8395895958, 3.8395895958, 3.8395895958,
175178
3.8395895958])
176-
assert torch.all(torch.abs(batch[AtomicDataDict.ENV_LENGTH_KEY] - expected_env_length) < 1e-8)
179+
180+
177181

178182
expected_env_vectors = torch.tensor([[ 1.9197947979, 3.3251821995, 0.0000000000],
179183
[ 0.0000000000, 4.4335761070, 0.7837529182],
@@ -232,19 +236,22 @@ def test_batch(self):
232236
[ 0.0000000000, -2.2167882919, 3.1350116730],
233237
[ 1.9197947979, 1.1083940268, 3.1350116730]])
234238

235-
assert torch.all(torch.abs(batch[AtomicDataDict.ENV_VECTORS_KEY] - expected_env_vectors) < 1e-8)
236-
239+
expect_envs = torch.cat([expected_env_index.T, expected_env_length.unsqueeze(1), expected_env_vectors], dim=1)
240+
target_envs = torch.cat([batch[AtomicDataDict.ENV_INDEX_KEY].T, batch[AtomicDataDict.ENV_LENGTH_KEY].unsqueeze(1), batch[AtomicDataDict.ENV_VECTORS_KEY]], dim=1)
241+
assert compare_tensors_as_sets_float(target_envs, expect_envs, precision=7)
242+
243+
#assert torch.all(torch.abs(batch[AtomicDataDict.ENV_VECTORS_KEY] - expected_env_vectors) < 1e-8)
244+
#assert torch.all(batch[AtomicDataDict.ENV_INDEX_KEY] == expected_env_index)
245+
#assert torch.all(torch.abs(batch[AtomicDataDict.ENV_LENGTH_KEY] - expected_env_length) < 1e-8)
237246

238247
batch = AtomicDataDict.with_onsitenv_vectors(batch, with_lengths=True)
239248
assert batch[AtomicDataDict.ONSITENV_INDEX_KEY].shape == torch.Size([2, 8])
240249
assert batch[AtomicDataDict.ONSITENV_LENGTH_KEY].shape == torch.Size([8])
241250

242251
expected_onsiteenv_index = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],
243252
[1, 1, 1, 1, 0, 0, 0, 0]])
244-
assert torch.all(batch[AtomicDataDict.ONSITENV_INDEX_KEY] == expected_onsiteenv_index)
245253
expected_onsiteenv_length = torch.tensor([2.3512587547, 2.3512589931, 2.3512587547, 2.3512587547, 2.3512587547,
246254
2.3512589931, 2.3512587547, 2.3512587547])
247-
assert torch.all(torch.abs(batch[AtomicDataDict.ONSITENV_LENGTH_KEY] - expected_onsiteenv_length) < 1e-8)
248255
expected_onsiteenv_vectors = torch.tensor([[-1.9197947979, 1.1083940268, 0.7837529182],
249256
[ 0.0000000000, -2.2167882919, 0.7837529182],
250257
[ 1.9197947979, 1.1083940268, 0.7837529182],
@@ -253,4 +260,11 @@ def test_batch(self):
253260
[ 0.0000000000, 2.2167882919, -0.7837529182],
254261
[-1.9197947979, -1.1083940268, -0.7837529182],
255262
[ 0.0000000000, 0.0000000000, 2.3512587547]])
256-
assert torch.all(torch.abs(batch[AtomicDataDict.ONSITENV_VECTORS_KEY] - expected_onsiteenv_vectors) < 1e-8)
263+
264+
expected_onsiteenvs = torch.cat([expected_onsiteenv_index.T, expected_onsiteenv_length.unsqueeze(1), expected_onsiteenv_vectors], dim=1)
265+
target_onsiteenvs = torch.cat([batch[AtomicDataDict.ONSITENV_INDEX_KEY].T, batch[AtomicDataDict.ONSITENV_LENGTH_KEY].unsqueeze(1), batch[AtomicDataDict.ONSITENV_VECTORS_KEY]], dim=1)
266+
assert compare_tensors_as_sets_float(target_onsiteenvs, expected_onsiteenvs, precision=7)
267+
268+
#assert torch.all(batch[AtomicDataDict.ONSITENV_INDEX_KEY] == expected_onsiteenv_index)
269+
#assert torch.all(torch.abs(batch[AtomicDataDict.ONSITENV_LENGTH_KEY] - expected_onsiteenv_length) < 1e-8)
270+
#assert torch.all(torch.abs(batch[AtomicDataDict.ONSITENV_VECTORS_KEY] - expected_onsiteenv_vectors) < 1e-8)

dptb/tests/test_default_dataset.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path
99
from ase.io.trajectory import Trajectory
1010
import torch as th
11+
from dptb.tests.tstools import compare_tensors_as_sets
1112

1213
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data/test_sktb/dataset")
1314

@@ -88,7 +89,8 @@ def test_get_data(self):
8889
assert (np.abs(atomic_data.pos.numpy() - self.strase[0].positions) < 1e-6).all()
8990
assert (np.abs(atomic_data.cell.numpy() - self.strase[0].cell) < 1e-6).all()
9091

91-
assert th.abs(atomic_data.edge_index - expected_edge_index).sum() < 1e-8
92+
assert compare_tensors_as_sets(atomic_data.edge_index.T, expected_edge_index.T)
93+
# assert th.abs(atomic_data.edge_index - expected_edge_index).sum() < 1e-8
9294
assert atomic_data.node_features.shape == (2, 1)
9395
assert not "node_attrs" in data[0]
9496
assert not "batch" in data[0]

dptb/tests/test_dftbsk.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pathlib import Path
99
from dptb.data import AtomicDataset, DataLoader, AtomicDataDict, AtomicData
1010
import numpy as np
11-
11+
from dptb.tests.tstools import compare_tensors_as_sets_float
1212

1313
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
1414

@@ -86,7 +86,8 @@ def test_forward_dftbsk(self):
8686
[-1.1053317, -1.4127309, 1.7213905, -0.3220515],
8787
[-1.1053317, -1.4127309, 1.7213905, -0.3220515]])
8888

89-
assert torch.allclose(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_edge_feature)
89+
assert compare_tensors_as_sets_float(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_edge_feature, precision=5)
90+
# assert torch.allclose(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_edge_feature)
9091

9192
expected_edge_overlap = torch.tensor([[ 0.0115951, -0.0208762, -0.0355638, 0.0046229],
9293
[ 0.2665880, 0.3365866, 0.3249941, -0.1442579],
@@ -107,7 +108,8 @@ def test_forward_dftbsk(self):
107108
[ 0.0399277, 0.0585683, -0.0838758, 0.0126881],
108109
[ 0.0399277, 0.0585683, -0.0838758, 0.0126881]])
109110

110-
assert torch.allclose(data[AtomicDataDict.EDGE_OVERLAP_KEY], expected_edge_overlap)
111+
assert compare_tensors_as_sets_float(data[AtomicDataDict.EDGE_OVERLAP_KEY], expected_edge_overlap, precision=5)
112+
# assert torch.allclose(data[AtomicDataDict.EDGE_OVERLAP_KEY], expected_edge_overlap)
111113

112114
assert AtomicDataDict.NODE_SOC_SWITCH_KEY in data
113115
assert not data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all()

dptb/tests/test_nnsk.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from dptb.data import AtomicDataset, DataLoader, AtomicDataDict, AtomicData
99
import numpy as np
1010
from dptb.utils.constants import atomic_num_dict_r
11+
from dptb.tests.tstools import compare_tensors_as_sets_float
1112

1213
rootdir = os.path.join(Path(os.path.abspath(__file__)).parent, "data")
1314

@@ -121,7 +122,8 @@ def test_nnsk_none_powerlaw(self):
121122
[ 0.0109460149, -0.0026458376, -0.0233188029, 0.0033660505],
122123
[ 0.0109460149, -0.0026458376, -0.0233188029, 0.0033660505]])
123124

124-
assert torch.allclose(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_hopskint, atol=1e-10)
125+
assert compare_tensors_as_sets_float(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_hopskint, precision=6)
126+
# assert torch.allclose(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_hopskint, atol=1e-10)
125127

126128
def test_nnsk_uniform_varTang96(self):
127129
model_options = self.model_options
@@ -195,7 +197,8 @@ def test_nnsk_uniform_varTang96(self):
195197
[-0.0043444759, 0.0260002706, -0.0492796339, 0.0716556087],
196198
[-0.0043444759, 0.0260002706, -0.0492796339, 0.0716556087]])
197199

198-
assert torch.allclose(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_hopskint, atol=1e-10)
200+
assert compare_tensors_as_sets_float(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_hopskint, precision=6)
201+
# assert torch.allclose(data[AtomicDataDict.EDGE_FEATURES_KEY], expected_hopskint, atol=1e-10)
199202

200203
def test_nnsk_onsite_strain(self):
201204
model_options = self.model_options

0 commit comments

Comments
 (0)