Skip to content

Commit 6e4d32f

Browse files
iProzdpre-commit-ci[bot]njzjz
authored
Fix model-devi with mixed_type format (#2433)
Fix model-devi with mixed_type format; Add UTs; Add detection of whether models support mixed_type inference. --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 0510ab2 commit 6e4d32f

File tree

6 files changed

+26211
-2
lines changed

6 files changed

+26211
-2
lines changed

deepmd/infer/deep_pot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,16 @@ def __init__(
164164
except (ValueError, KeyError):
165165
self.modifier_type = None
166166

167+
try:
168+
t_jdata = self._get_tensor("train_attr/training_script:0")
169+
jdata = run_sess(self.sess, t_jdata).decode("UTF-8")
170+
import json
171+
172+
jdata = json.loads(jdata)
173+
self.descriptor_type = jdata["model"]["descriptor"]["type"]
174+
except (ValueError, KeyError):
175+
self.descriptor_type = None
176+
167177
if self.modifier_type == "dipole_charge":
168178
t_mdl_name = self._get_tensor("modifier_attr/mdl_name:0")
169179
t_mdl_charge_map = self._get_tensor("modifier_attr/mdl_charge_map:0")
@@ -243,6 +253,10 @@ def get_sel_type(self) -> List[int]:
243253
"""Unsupported in this model."""
244254
raise NotImplementedError("This model type does not support this attribute")
245255

256+
def get_descriptor_type(self) -> List[int]:
257+
"""Get the descriptor type of this model."""
258+
return self.descriptor_type
259+
246260
def get_dim_fparam(self) -> int:
247261
"""Get the number (dimension) of frame parameters of this DP."""
248262
return self.dfparam

deepmd/infer/model_devi.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def calc_model_devi(
143143
models,
144144
fname=None,
145145
frequency=1,
146+
mixed_type=False,
146147
):
147148
"""Python interface to calculate model deviation.
148149
@@ -160,6 +161,8 @@ def calc_model_devi(
160161
File to dump results, default None
161162
frequency : int
162163
Steps between frames (if the system is given by molecular dynamics engine), default 1
164+
mixed_type : bool
165+
Whether the input atype is in mixed_type format or not
163166
164167
Returns
165168
-------
@@ -187,6 +190,7 @@ def calc_model_devi(
187190
coord,
188191
box,
189192
atype,
193+
mixed_type=mixed_type,
190194
)
191195
energies.append(ret[0] / natom)
192196
forces.append(ret[1])
@@ -246,17 +250,21 @@ def make_model_devi(
246250
for system in all_sys:
247251
# create data-system
248252
dp_data = DeepmdData(system, set_prefix, shuffle_test=False, type_map=tmap)
253+
mixed_type = dp_data.mixed_type
249254

250255
data_sets = [dp_data._load_set(set_name) for set_name in dp_data.dirs]
251256
nframes_tot = 0
252257
devis = []
253258
for data in data_sets:
254259
coord = data["coord"]
255260
box = data["box"]
256-
atype = data["type"][0]
261+
if mixed_type:
262+
atype = data["type"]
263+
else:
264+
atype = data["type"][0]
257265
if not dp_data.pbc:
258266
box = None
259-
devi = calc_model_devi(coord, box, atype, dp_models)
267+
devi = calc_model_devi(coord, box, atype, dp_models, mixed_type=mixed_type)
260268
nframes_tot += coord.shape[0]
261269
devis.append(devi)
262270
devis = np.vstack(devis)

0 commit comments

Comments
 (0)