Skip to content

Commit 453b49f

Browse files
wanghan-iapcmHan Wangpre-commit-ci[bot]
authored
fix: compare pwscf energy by relative error (#1643)
ut failure caused by deepmodeling/dpdata#725 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Improved handling of edge cases in coordinate and energy tests, enhancing test robustness for zero values. - Added conditional checks to ensure accurate comparisons in energy calculations. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Han Wang <wang_han@iapcm.ac.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b5c6ea0 commit 453b49f

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

tests/generator/comp_sys.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,9 @@ def test_coord(self):
8686
tmp_cell = self.system_1.data["cells"]
8787
tmp_cell = np.reshape(tmp_cell, [-1, 3])
8888
tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis=1), [-1, 3])
89+
if np.max(np.abs(tmp_cell_norm)) < 1e-12:
90+
# zero cell, no pbc case, set to [1., 1., 1.]
91+
tmp_cell_norm = np.ones(tmp_cell_norm.shape)
8992
for ff in range(self.system_1.get_nframes()):
9093
for ii in range(sum(self.system_1.data["atom_numbs"])):
9194
for jj in range(3):
@@ -103,12 +106,21 @@ class CompLabeledSys(CompSys):
103106
def test_energy(self):
104107
self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())
105108
for ff in range(self.system_1.get_nframes()):
106-
self.assertAlmostEqual(
107-
self.system_1.data["energies"][ff],
108-
self.system_2.data["energies"][ff],
109-
places=self.e_places,
110-
msg="energies[%d] failed" % (ff),
111-
)
109+
if abs(self.system_2.data["energies"][ff]) < 1e-12:
110+
self.assertAlmostEqual(
111+
self.system_1.data["energies"][ff],
112+
self.system_2.data["energies"][ff],
113+
places=self.e_places,
114+
msg="energies[%d] failed" % (ff),
115+
)
116+
else:
117+
self.assertAlmostEqual(
118+
self.system_1.data["energies"][ff]
119+
/ self.system_2.data["energies"][ff],
120+
1.0,
121+
places=self.e_places,
122+
msg="energies[%d] failed" % (ff),
123+
)
112124

113125
def test_force(self):
114126
self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())

0 commit comments

Comments
 (0)