From d9767508b197f84c93f55d9a17b92b34bfdde3bc Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 24 Sep 2024 08:49:41 +0800 Subject: [PATCH 1/2] fix: compare pwscf energy by relative error --- tests/generator/comp_sys.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/generator/comp_sys.py b/tests/generator/comp_sys.py index 8806ddb5e..eda0762c2 100644 --- a/tests/generator/comp_sys.py +++ b/tests/generator/comp_sys.py @@ -86,6 +86,9 @@ def test_coord(self): tmp_cell = self.system_1.data["cells"] tmp_cell = np.reshape(tmp_cell, [-1, 3]) tmp_cell_norm = np.reshape(np.linalg.norm(tmp_cell, axis=1), [-1, 3]) + if np.max(np.abs(tmp_cell_norm)) < 1e-12: + # zero cell, no pbc case, set to [1., 1., 1.] + tmp_cell_norm = np.ones(tmp_cell_norm.shape) for ff in range(self.system_1.get_nframes()): for ii in range(sum(self.system_1.data["atom_numbs"])): for jj in range(3): @@ -103,12 +106,21 @@ class CompLabeledSys(CompSys): def test_energy(self): self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes()) for ff in range(self.system_1.get_nframes()): - self.assertAlmostEqual( + if abs(self.system_2.data["energies"][ff]) < 1e-12: + self.assertAlmostEqual( self.system_1.data["energies"][ff], self.system_2.data["energies"][ff], places=self.e_places, msg="energies[%d] failed" % (ff), - ) + ) + else: + self.assertAlmostEqual( + self.system_1.data["energies"][ff]/ + self.system_2.data["energies"][ff], + 1.0, + places=self.e_places, + msg="energies[%d] failed" % (ff), + ) def test_force(self): self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes()) From 2757fbf5d21d15224c8e75a2dbb515591116a45d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 00:50:45 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/generator/comp_sys.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/generator/comp_sys.py b/tests/generator/comp_sys.py index eda0762c2..db37ad843 100644 --- a/tests/generator/comp_sys.py +++ b/tests/generator/comp_sys.py @@ -107,20 +107,20 @@ def test_energy(self): self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes()) for ff in range(self.system_1.get_nframes()): if abs(self.system_2.data["energies"][ff]) < 1e-12: - self.assertAlmostEqual( - self.system_1.data["energies"][ff], - self.system_2.data["energies"][ff], - places=self.e_places, - msg="energies[%d] failed" % (ff), - ) + self.assertAlmostEqual( + self.system_1.data["energies"][ff], + self.system_2.data["energies"][ff], + places=self.e_places, + msg="energies[%d] failed" % (ff), + ) else: - self.assertAlmostEqual( - self.system_1.data["energies"][ff]/ - self.system_2.data["energies"][ff], - 1.0, - places=self.e_places, - msg="energies[%d] failed" % (ff), - ) + self.assertAlmostEqual( + self.system_1.data["energies"][ff] + / self.system_2.data["energies"][ff], + 1.0, + places=self.e_places, + msg="energies[%d] failed" % (ff), + ) def test_force(self): self.assertEqual(self.system_1.get_nframes(), self.system_2.get_nframes())