
Commit e9a25fb

chore: add the mlp_engine option (#1576)
I am going to use DP-GEN to develop models trained by other MLP software. This may or may not be merged into the main branch, but I think a general `mlp_engine` option can be added anyway.

## Summary by CodeRabbit

- **New Features**
  - Introduced handling for multiple ML potential engines with specialized training argument functions.
- **Improvements**
  - Enhanced training initialization by splitting into common and engine-specific functions.
  - Improved error handling for unsupported ML potential engines.
- **Bug Fixes**
  - Corrected logic to differentiate between `dp` and other engine values during training and model initialization.

---

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8782483 commit e9a25fb
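
The new key sits next to the existing training options in the run parameter file. A minimal sketch of how it is read (the surrounding values are illustrative placeholders, not a complete working param):

```python
# Illustrative fragment of a parsed run-param dict (jdata); not a complete config.
# "mlp_engine" is optional and falls back to "dp", so existing params keep working.
jdata = {
    "mlp_engine": "dp",              # currently the only supported engine
    "numb_models": 4,                # engine-independent key (training_args_common)
    "train_backend": "tensorflow",   # DeePMD-kit-specific key (training_args_dp)
    "default_training_param": {},    # DeePMD-kit input would go here
}
print(jdata.get("mlp_engine", "dp"))  # -> "dp"
```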

4 files changed: +74 −14 lines

dpgen/generator/arginfo.py

Lines changed: 27 additions & 5 deletions
```diff
@@ -79,7 +79,14 @@ def data_args() -> list[Argument]:
 # Training
 
 
-def training_args() -> list[Argument]:
+def training_args_common() -> list[Argument]:
+    doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
+    return [
+        Argument("numb_models", int, optional=False, doc=doc_numb_models),
+    ]
+
+
+def training_args_dp() -> list[Argument]:
     """Traning arguments.
 
     Returns
@@ -90,7 +97,6 @@ def training_args() -> list[Argument]:
     doc_train_backend = (
         "The backend of the training. Currently only support tensorflow and pytorch."
     )
-    doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
     doc_training_iter0_model_path = "The model used to init the first iter training. Number of element should be equal to numb_models."
     doc_training_init_model = "Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
     doc_default_training_param = "Training parameters for deepmd-kit in 00.train. You can find instructions from `DeePMD-kit documentation <https://docs.deepmodeling.org/projects/deepmd/>`_."
@@ -133,7 +139,6 @@ def training_args() -> list[Argument]:
             default="tensorflow",
             doc=doc_train_backend,
         ),
-        Argument("numb_models", int, optional=False, doc=doc_numb_models),
         Argument(
             "training_iter0_model_path",
             list[str],
@@ -224,6 +229,19 @@ def training_args() -> list[Argument]:
     ]
 
 
+def training_args() -> Variant:
+    doc_mlp_engine = "Machine learning potential engine. Currently, only DeePMD-kit (defualt) is supported."
+    doc_dp = "DeePMD-kit."
+    return Variant(
+        "mlp_engine",
+        [
+            Argument("dp", dict, training_args_dp(), doc=doc_dp),
+        ],
+        default_tag="dp",
+        doc=doc_mlp_engine,
+    )
+
+
 # Exploration
 def model_devi_jobs_template_args() -> Argument:
     doc_template = (
@@ -987,7 +1005,11 @@ def run_jdata_arginfo() -> Argument:
     return Argument(
         "run_jdata",
         dict,
-        sub_fields=basic_args() + data_args() + training_args() + fp_args(),
-        sub_variants=model_devi_args() + [fp_style_variant_type_args()],
+        sub_fields=basic_args() + data_args() + training_args_common() + fp_args(),
+        sub_variants=[
+            training_args(),
+            *model_devi_args(),
+            fp_style_variant_type_args(),
+        ],
         doc=doc_run_jdata,
     )
```
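
Since `training_args()` is now a dargs `Variant` keyed on `mlp_engine` with `default_tag="dp"`, a run param that omits the key should still validate. A minimal sketch of checking a training-only fragment against the new schema (assumes dargs is installed; the fragment name and the dict values are placeholders, not a complete configuration):

```python
from dargs import Argument

from dpgen.generator.arginfo import training_args, training_args_common

# Training-only slice of the run_jdata schema; the real run_jdata_arginfo()
# also adds the basic/data/fp fields and the model_devi/fp_style variants.
schema = Argument(
    "training_fragment",  # hypothetical name, used only for this sketch
    dict,
    sub_fields=training_args_common(),
    sub_variants=[training_args()],
)

jdata = {
    "mlp_engine": "dp",            # may be omitted; "dp" is the default tag
    "numb_models": 4,
    "default_training_param": {},  # placeholder for the DeePMD-kit input
}
schema.check_value(jdata, strict=False)  # raises if a required key is missing or the engine tag is unknown
```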

dpgen/generator/run.py

Lines changed: 36 additions & 8 deletions
```diff
@@ -128,15 +128,19 @@
 
 def _get_model_suffix(jdata) -> str:
     """Return the model suffix based on the backend."""
-    suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
-    backend = jdata.get("train_backend", "tensorflow")
-    if backend in suffix_map:
-        suffix = suffix_map[backend]
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        suffix_map = {"tensorflow": ".pb", "pytorch": ".pth"}
+        backend = jdata.get("train_backend", "tensorflow")
+        if backend in suffix_map:
+            suffix = suffix_map[backend]
+        else:
+            raise ValueError(
+                f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
+            )
+        return suffix
     else:
-        raise ValueError(
-            f"The backend {backend} is not available. Supported backends are: 'tensorflow', 'pytorch'."
-        )
-    return suffix
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
 
 
 def get_job_names(jdata):
@@ -270,6 +274,14 @@ def dump_to_deepmd_raw(dump, deepmd_raw, type_map, fmt="gromacs/gro", charge=Non
 
 
 def make_train(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        return make_train_dp(iter_index, jdata, mdata)
+    else:
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
+
+
+def make_train_dp(iter_index, jdata, mdata):
     # load json param
     # train_param = jdata['train_param']
     train_input_file = default_train_input_file
@@ -714,6 +726,14 @@ def get_nframes(system):
 
 
 def run_train(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        return run_train_dp(iter_index, jdata, mdata)
+    else:
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
+
+
+def run_train_dp(iter_index, jdata, mdata):
     # print("debug:run_train:mdata", mdata)
     # load json param
     numb_models = jdata["numb_models"]
@@ -899,6 +919,14 @@ def run_train(iter_index, jdata, mdata):
 
 
 def post_train(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        return post_train_dp(iter_index, jdata, mdata)
+    else:
+        raise ValueError(f"Unsupported engine: {mlp_engine}")
+
+
+def post_train_dp(iter_index, jdata, mdata):
     # load json param
     numb_models = jdata["numb_models"]
     # paths
```
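
`make_train`, `run_train`, and `post_train` now share the same three-line dispatcher and delegate the existing DeePMD-kit logic to `*_dp` helpers. A stand-alone sketch of that pattern (the `run_train_dp` body below is a placeholder so the snippet runs outside DP-GEN):

```python
def run_train_dp(iter_index, jdata, mdata):
    # Placeholder body; in DP-GEN this prepares and submits the DeePMD-kit training jobs.
    return f"iter {iter_index:06d}: training {jdata['numb_models']} DeePMD-kit models"


def run_train(iter_index, jdata, mdata):
    mlp_engine = jdata.get("mlp_engine", "dp")  # a missing key keeps the old behaviour
    if mlp_engine == "dp":
        return run_train_dp(iter_index, jdata, mdata)
    else:
        raise ValueError(f"Unsupported engine: {mlp_engine}")


print(run_train(0, {"numb_models": 4}, {}))  # dispatches to run_train_dp
```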

dpgen/simplify/arginfo.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -12,6 +12,7 @@
     fp_style_siesta_args,
     fp_style_vasp_args,
     training_args,
+    training_args_common,
 )
 
 
@@ -201,10 +202,11 @@ def simplify_jdata_arginfo() -> Argument:
             *data_args(),
             *general_simplify_arginfo(),
             # simplify use the same training method as run
-            *training_args(),
+            *training_args_common(),
             *fp_args(),
         ],
         sub_variants=[
+            training_args(),
             fp_style_variant_type_args(),
         ],
         doc=doc_run_jdata,
```

dpgen/simplify/simplify.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -103,6 +103,14 @@ def get_multi_system(path: Union[str, list[str]], jdata: dict) -> dpdata.MultiSy
 
 
 def init_model(iter_index, jdata, mdata):
+    mlp_engine = jdata.get("mlp_engine", "dp")
+    if mlp_engine == "dp":
+        init_model_dp(iter_index, jdata, mdata)
+    else:
+        raise TypeError(f"unsupported engine {mlp_engine}")
+
+
+def init_model_dp(iter_index, jdata, mdata):
     training_init_model = jdata.get("training_init_model", False)
     if not training_init_model:
         return
```
