Skip to content

Commit d364a27

Browse files
committed
make numb_models as common argument
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 6957426 commit d364a27

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

dpgen/generator/arginfo.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ def data_args() -> list[Argument]:
7878

7979
# Training
8080

81+
def training_args_common() -> list[Argument]:
82+
doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
83+
return [
84+
Argument("numb_models", int, optional=False, doc=doc_numb_models),
85+
]
86+
8187

8288
def training_args_dp() -> list[Argument]:
8389
"""Traning arguments.
@@ -90,7 +96,6 @@ def training_args_dp() -> list[Argument]:
9096
doc_train_backend = (
9197
"The backend of the training. Currently only support tensorflow and pytorch."
9298
)
93-
doc_numb_models = "Number of models to be trained in 00.train. 4 is recommend."
9499
doc_training_iter0_model_path = "The model used to init the first iter training. Number of element should be equal to numb_models."
95100
doc_training_init_model = "Iteration > 0, the model parameters will be initilized from the model trained at the previous iteration. Iteration == 0, the model parameters will be initialized from training_iter0_model_path."
96101
doc_default_training_param = "Training parameters for deepmd-kit in 00.train. You can find instructions from `DeePMD-kit documentation <https://docs.deepmodeling.org/projects/deepmd/>`_."
@@ -133,7 +138,6 @@ def training_args_dp() -> list[Argument]:
133138
default="tensorflow",
134139
doc=doc_train_backend,
135140
),
136-
Argument("numb_models", int, optional=False, doc=doc_numb_models),
137141
Argument(
138142
"training_iter0_model_path",
139143
list[str],
@@ -999,7 +1003,7 @@ def run_jdata_arginfo() -> Argument:
9991003
return Argument(
10001004
"run_jdata",
10011005
dict,
1002-
sub_fields=basic_args() + data_args() + fp_args(),
1006+
sub_fields=basic_args() + data_args() + training_args_common() + fp_args(),
10031007
sub_variants=[training_args(), *model_devi_args(), fp_style_variant_type_args()],
10041008
doc=doc_run_jdata,
10051009
)

dpgen/simplify/arginfo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
fp_style_siesta_args,
1313
fp_style_vasp_args,
1414
training_args,
15+
training_args_common,
1516
)
1617

1718

@@ -201,6 +202,7 @@ def simplify_jdata_arginfo() -> Argument:
201202
*data_args(),
202203
*general_simplify_arginfo(),
203204
# simplify use the same training method as run
205+
*training_args_common(),
204206
*fp_args(),
205207
],
206208
sub_variants=[

0 commit comments

Comments
 (0)