@@ -86,6 +86,9 @@ def test_coord(self):
86
86
tmp_cell = self .system_1 .data ["cells" ]
87
87
tmp_cell = np .reshape (tmp_cell , [- 1 , 3 ])
88
88
tmp_cell_norm = np .reshape (np .linalg .norm (tmp_cell , axis = 1 ), [- 1 , 3 ])
89
+ if np .max (np .abs (tmp_cell_norm )) < 1e-12 :
90
+ # zero cell, no pbc case, set to [1., 1., 1.]
91
+ tmp_cell_norm = np .ones (tmp_cell_norm .shape )
89
92
for ff in range (self .system_1 .get_nframes ()):
90
93
for ii in range (sum (self .system_1 .data ["atom_numbs" ])):
91
94
for jj in range (3 ):
@@ -103,12 +106,21 @@ class CompLabeledSys(CompSys):
103
106
def test_energy (self ):
104
107
self .assertEqual (self .system_1 .get_nframes (), self .system_2 .get_nframes ())
105
108
for ff in range (self .system_1 .get_nframes ()):
106
- self .assertAlmostEqual (
107
- self .system_1 .data ["energies" ][ff ],
108
- self .system_2 .data ["energies" ][ff ],
109
- places = self .e_places ,
110
- msg = "energies[%d] failed" % (ff ),
111
- )
109
+ if abs (self .system_2 .data ["energies" ][ff ]) < 1e-12 :
110
+ self .assertAlmostEqual (
111
+ self .system_1 .data ["energies" ][ff ],
112
+ self .system_2 .data ["energies" ][ff ],
113
+ places = self .e_places ,
114
+ msg = "energies[%d] failed" % (ff ),
115
+ )
116
+ else :
117
+ self .assertAlmostEqual (
118
+ self .system_1 .data ["energies" ][ff ]
119
+ / self .system_2 .data ["energies" ][ff ],
120
+ 1.0 ,
121
+ places = self .e_places ,
122
+ msg = "energies[%d] failed" % (ff ),
123
+ )
112
124
113
125
def test_force (self ):
114
126
self .assertEqual (self .system_1 .get_nframes (), self .system_2 .get_nframes ())
0 commit comments