Skip to content

Commit 3fa0944

Browse files
committed
feat(emp_sk): add command to generate empirical SK parameters
Introduce a new command 'esk' to generate initial empirical SK parameters. This includes the addition of the `emp_sk.py` module, which handles the conversion of model parameters to empirical SK format. The command supports both 'poly2' and 'poly4' base models and saves the output in JSON format. Also, fix a minor bug in `gen_inputs.py` where `model_options` was incorrectly accessed as a method.
1 parent f2dfe1b commit 3fa0944

File tree

3 files changed

+134
-2
lines changed

3 files changed

+134
-2
lines changed

dptb/entrypoints/emp_sk.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import torch
2+
import numpy as np
3+
from dptb.nn.build import build_model
4+
import json
5+
import logging
6+
from dptb.nn.sktb.onsiteDB import onsite_energy_database
7+
import re
8+
import os
9+
from dptb.utils.gen_inputs import gen_inputs
10+
import json
11+
log = logging.getLogger(__name__)
12+
13+
def to_empsk(
        INPUT,
        output='./',
        basemodel='poly2',
        **kwargs):
    """
    Convert the model to empirical SK parameters.

    Reads the ``common_options`` section of a JSON input file, builds an
    ``EmpSK`` object for the requested base model and writes the result
    through ``EmpSK.to_json``.

    Args:
        INPUT: path to the input parameter file. It is parsed with
            ``json.load``, so it has to be valid JSON.
        output: directory handed to ``EmpSK.to_json`` as ``outdir``.
        basemodel: base model type, forwarded unchanged to ``EmpSK``.
        **kwargs: accepted and ignored, so the whole dictionary of parsed
            command-line arguments can be passed in.

    Raises:
        ValueError: if ``INPUT`` is None.
        KeyError: if the input file has no ``common_options`` entry.
    """
    if INPUT is None:
        raise ValueError('INPUT is None.')

    # `config` rather than `input`, which would shadow the builtin.
    with open(INPUT, 'r', encoding='utf-8') as f:
        config = json.load(f)

    common_options = config['common_options']
    EmpSK(common_options, basemodel=basemodel).to_json(outdir=output)
27+
28+
class EmpSK(object):
    """
    Empirical SK parameters.

    Loads one of the base checkpoints located in ``../nn/dftb`` relative to
    this file through ``build_model`` and exports it with the model's own
    ``to_json`` method.
    """

    # Checkpoint file name for every supported base model.
    _BASE_CKPT = {
        'poly2': "base_poly2.pth",
        'poly4': "base_poly4.pth",
    }

    def __init__(self, common_options, basemodel='poly2'):
        """
        Args:
            common_options: common options for the model. especially contain the basis information.
            basemodel: base model type for the empirical SK parameters either 'poly2' or 'poly4'.

        Raises:
            ValueError: if ``common_options`` has no basis information or
                ``basemodel`` is not supported.
        """
        self.common_options, self.basisref = self.format_common_options(common_options)
        if basemodel not in self._BASE_CKPT:
            raise ValueError(f'basemodel {basemodel} is not supported.')
        model_ckpt = os.path.join(
            os.path.dirname(__file__), '..', 'nn', 'dftb', self._BASE_CKPT[basemodel]
        )

        # Pass the formatted options explicitly: format_common_options no
        # longer rewrites the caller's dictionary in place.
        self.model = build_model(model_ckpt, common_options=self.common_options, no_check=True)

    def to_json(self, outdir='./'):
        """
        Convert the model to json format.

        Writes the dictionary returned by ``self.model.to_json`` to
        ``<outdir>/sktb.json`` and returns it.

        Args:
            outdir: target directory; it is created when it does not exist.

        Returns:
            The object returned by ``self.model.to_json(basisref=self.basisref)``.
        """
        json_dict = self.model.to_json(basisref=self.basisref)

        # An empty outdir means the current directory; makedirs('') would fail.
        if outdir:
            os.makedirs(outdir, exist_ok=True)
        outfile = os.path.join(outdir, 'sktb.json')
        with open(outfile, 'w') as f:
            json.dump(json_dict, f, indent=4)

        log.info(f'Empirical SK parameters are saved in {outfile}')
        log.info('If you want to further train the model, please use `dptb config` command to generate input template.')
        return json_dict

    def format_common_options(self, common_options):
        """
        Format the common options for the model.

        The basis of every element is a list of orbital labels, written
        either as single characters (e.g. ``s``) or as longer labels that
        must be keys of ``onsite_energy_database`` for that element. The two
        notations can not be mixed. Longer labels are reduced to their first
        letter for the model, and the original label is kept in ``basisref``.

        Args:
            common_options: dictionary with a ``basis`` entry mapping each
                element to its list of orbital labels. It is not modified.

        Returns:
            ``(common_options, None)`` for the single-character notation,
            where ``common_options`` is the object that was passed in.
            Otherwise ``(options, basisref)``, where ``options`` is a shallow
            copy whose ``basis`` holds the one-letter labels and ``basisref``
            maps element -> one-letter label -> original label.

        Raises:
            ValueError: if ``basis`` is missing from ``common_options``.
            AssertionError: if the basis is malformed, mixes both notations,
                or names an orbital absent from ``onsite_energy_database``.
        """
        # check basis in common_options
        if 'basis' not in common_options:
            raise ValueError('basis information is not given in common_options.')
        # check basis type
        basis = common_options['basis']
        assert isinstance(basis, dict), 'basis information is not a dictionary.'
        sys_ele = "".join(basis)
        log.info(f'Extracting empirical SK parameters for {sys_ele}')

        mixed_msg = 'wrong basis seting eithor s, p ,d or ns np d*. can not be both s and ns np d*.'
        seen_plain = False   # a single-character label was found
        use_basis_ref = False  # a longer, database-backed label was found
        basisref = {}
        std_basis = {}
        for ie, orbitals in basis.items():
            assert isinstance(orbitals, list), f'basis information for {ie} is not a list.'
            basisref[ie] = {}
            std_basis[ie] = []
            for ieorb in orbitals:
                assert isinstance(ieorb, str), f'basis information for {ie} is not a string.'
                if len(ieorb) == 1:
                    seen_plain = True
                else:
                    use_basis_ref = True
                # Checked after either kind of label so that the mix is
                # rejected whatever the order of the labels is.
                assert not (seen_plain and use_basis_ref), mixed_msg
                if len(ieorb) == 1:
                    continue
                assert ieorb in onsite_energy_database[ie], f'basis information for {ie} is not in onsite_energy_database : {onsite_energy_database[ie].keys()}.'
                orbsymb = re.findall(r'[A-Za-z]', ieorb)[0]
                basisref[ie][orbsymb] = ieorb
                std_basis[ie].append(orbsymb)

        if not use_basis_ref:
            return common_options, None

        # Build new containers instead of updating the caller's basis in place.
        formatted = dict(common_options)
        formatted['basis'] = {**basis, **std_basis}
        return formatted, basisref

dptb/entrypoints/main.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from dptb.utils.loggers import set_log_handles
1414
from dptb.utils.config_check import check_config_train
1515
from dptb.entrypoints.collectskf import skf2pth, skf2nnsk
16+
from dptb.entrypoints.emp_sk import to_empsk
17+
1618
from dptb import __version__
1719

1820

@@ -402,7 +404,30 @@ def main_parser() -> argparse.ArgumentParser:
402404
help="The output files in training.",
403405
)
404406

405-
407+
parser_esk = subparsers.add_parser(
408+
"esk",
409+
parents=[parser_log],
410+
help="Generate initial empirical SK parameters.",
411+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
412+
)
413+
parser_esk.add_argument(
414+
"INPUT", help="the input parameter file in json or yaml format",
415+
type=str,
416+
default=None
417+
)
418+
parser_esk.add_argument(
419+
"-o",
420+
"--output",
421+
type=str,
422+
default="./",
423+
help="The output files in training."
424+
)
425+
parser_esk.add_argument(
426+
"-m",
427+
"--basemodel",
428+
type=str,
429+
default="poly2",
430+
)
406431
return parser
407432

408433
def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
@@ -466,3 +491,6 @@ def main():
466491

467492
elif args.command == 'skf2nn':
468493
skf2nnsk(**dict_args)
494+
495+
elif args.command == 'esk':
496+
to_empsk(**dict_args)

dptb/utils/gen_inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def gen_inputs(mode, task='train', model=None):
5454
"overlap": is_overlap,
5555
}
5656
input_dict["common_options"].update(common_options)
57-
input_dict["model_options"].update(model.model_options())
57+
input_dict["model_options"].update(model.model_options)
5858
#with open(os.path.join(outdir,'input_template.json'), 'w') as f:
5959
# json.dump(input_dict, f, indent=4)
6060
return input_dict

0 commit comments

Comments
 (0)