Skip to content

Commit 9665cdc

Browse files
committed
fix(gen_inputs): handle device type and freeze overlap param in nnsk
Ensure correct device type assignment by checking both string and torch.device instances. Additionally, freeze the overlap parameter in the nnsk model when overlap is detected to prevent unintended modifications.
1 parent 5c8542f commit 9665cdc

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

dptb/utils/gen_inputs.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,12 @@ def gen_inputs(mode, task='train', model=None):
4646
dtype = model.dtype
4747
else:
4848
dtype = model.dtype.__str__().split('.')[-1]
49-
dd = "cpu" if model.device == torch.device("cpu") else "cuda"
49+
50+
if model.device == 'cpu' or model.device == torch.device("cpu"):
51+
dd = "cpu"
52+
else:
53+
dd = "cuda"
54+
5055
common_options = {
5156
"basis": basis,
5257
"dtype": dtype,
@@ -55,6 +60,11 @@ def gen_inputs(mode, task='train', model=None):
5560
}
5661
input_dict["common_options"].update(common_options)
5762
input_dict["model_options"].update(model.model_options)
63+
if is_overlap:
64+
if "nnsk" in input_dict["model_options"]:
65+
# for nnsk if there is overlap param, freeze the overlap param in the nnsk model.
66+
input_dict["model_options"]["nnsk"].update({"freeze": ["overlap"]})
67+
5868
#with open(os.path.join(outdir,'input_template.json'), 'w') as f:
5969
# json.dump(input_dict, f, indent=4)
6070
return input_dict

0 commit comments

Comments
 (0)