
Commit 1861c4e

refactor(data preprocess): remove the cutoff options from info.json and collect the values from input.json
1 parent aaa6375 commit 1861c4e

25 files changed: +192 additions, -98 deletions
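Taken together, the change moves the cutoff radii (r_max, er_max, oer_max) out of each dataset folder's info.json and has them collected once from the training input.json, then threaded through build_dataset into every per-folder info dict. A simplified sketch of the new call path, assuming collect_cutoffs returns a plain dict with exactly those three keys (its implementation is not part of this diff):

```python
# Simplified sketch of the new cutoff flow (not the full train() entrypoint).
from dptb.utils.tools import j_loader
from dptb.utils.argcheck import collect_cutoffs
from dptb.data.build import build_dataset

jdata = j_loader("input.json")           # training input
cutoff_options = collect_cutoffs(jdata)  # assumed shape: {"r_max": ..., "er_max": ..., "oer_max": ...}

train_datasets = build_dataset(
    **cutoff_options,                    # new explicit kwargs of build_dataset
    **jdata["data_options"]["train"],
    **jdata["common_options"],
)
```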

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -6,6 +6,9 @@ dptb/tests/**/*.pth
 dptb/tests/**/*.npy
 dptb/tests/**/*.traj
 dptb/tests/**/out*/*
+dptb/tests/**/out*/*
+dptb/tests/**/*lmdb
+dptb/tests/**/*h5
 examples/_*
 *.dat
 *log*

dptb/data/build.py

Lines changed: 8 additions & 2 deletions
@@ -109,14 +109,17 @@ def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
 def build_dataset(
     # set_options
     root: str,
+    # dataset_options
+    r_max: float,
+    er_max: float = None,
+    oer_max: float = None,
     type: str = "DefaultDataset",
     prefix: str = None,
     separator:str='.',
     get_Hamiltonian: bool = False,
     get_overlap: bool = False,
     get_DM: bool = False,
     get_eigenvalues: bool = False,
-
     # common_options
     orthogonal: bool = False,
     basis: str = None,
@@ -224,7 +227,10 @@ def build_dataset(
     # We will sort the info_files here.
     # The order itself is not important, but must be consistant for the same list.
     info_files = {key: info_files[key] for key in sorted(info_files)}
-
+
+    for ikey in info_files:
+        info_files[ikey].update({'r_max': r_max, 'er_max': er_max, 'oer_max': oer_max})
+
     if dataset_type == "DeePHDataset":
         dataset = DeePHE3Dataset(
             root=root,
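The loop added after the sorting step is what makes each per-folder info dict self-contained. A small, self-contained illustration of that injection (folder names and values are made up; the update mirrors the hunk above):

```python
# Cutoffs as build_dataset now receives them (illustrative values).
r_max, er_max, oer_max = 6.0, 5.0, None

# Per-folder info dicts as parsed from the slimmed-down info.json files.
info_files = {
    "set.1": {"nframes": 4, "pos_type": "ase", "pbc": True},
    "set.0": {"nframes": 1, "pos_type": "ase", "pbc": True},
}

# Sort for a deterministic order, then stamp the cutoffs onto every info dict
# so the dataset classes can read them from `info` instead of AtomicData_options.
info_files = {key: info_files[key] for key in sorted(info_files)}
for ikey in info_files:
    info_files[ikey].update({'r_max': r_max, 'er_max': er_max, 'oer_max': oer_max})

print(info_files["set.0"]["r_max"])  # -> 6.0
```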

dptb/data/dataset/_deeph_dataset.py

Lines changed: 8 additions & 7 deletions
@@ -43,13 +43,11 @@ def __init__(
         for file in self.info_files.keys():
             # get the info here
             info = info_files[file]
-            assert "AtomicData_options" in info
-            AtomicData_options = info["AtomicData_options"]
-            assert "r_max" in AtomicData_options
-            assert "pbc" in AtomicData_options
+            assert "r_max" in info
+            assert "pbc" in info
             subdata = os.path.join(self.root, file)
             self.raw_data.append(subdata)
-            self.data_options[subdata] = AtomicData_options
+            self.data_options[subdata] = info

         # The AtomicData_options is never used here.
         # Because we always return a list of AtomicData object in `get_data()`.
@@ -68,12 +66,15 @@ def get_data(self):
         for subpath in tqdm(self.raw_data, desc="Loading data"):
             # the type_mapper here is loaded in PyG `dataset` type as `transform` attritube
             # so the OrbitalMapper can be accessed by self.transform here
-            AtomicData_options = self.data_options[subpath]
+            info = self.data_options[subpath]
             atomic_data = AtomicData.from_points(
                 pos = np.loadtxt(os.path.join(subpath, "site_positions.dat")).T,
                 cell = np.loadtxt(os.path.join(subpath, "lat.dat")).T,
                 atomic_numbers = np.loadtxt(os.path.join(subpath, "element.dat")),
-                **AtomicData_options,
+                pbc = info["pbc"],
+                r_max=info["r_max"],
+                er_max=info.get("er_max", None),
+                oer_max=info.get("oer_max", None)
                 )
             idp = self.type_mapper
             openmx_to_deeptb(atomic_data, idp, os.path.join(subpath, "./hamiltonians.h5"))

dptb/data/dataset/_default_dataset.py

Lines changed: 10 additions & 17 deletions
@@ -40,7 +40,6 @@ class _TrajData(object):

     def __init__(self,
                  root: str,
-                 AtomicData_options: Dict[str, Any] = {},
                  get_Hamiltonian = False,
                  get_overlap = False,
                  get_DM = False,
@@ -50,13 +49,10 @@ def __init__(self,

         assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData."
         self.root = root
-        self.AtomicData_options = AtomicData_options
         self.info = info
-
         self.data = {}
-        # load cell
-
-        pbc = AtomicData_options["pbc"]
+        pbc = info["pbc"]
+        # load cell
         if isinstance(pbc, bool):
             has_cell = pbc
         elif isinstance(pbc, list):
@@ -155,7 +151,6 @@ def __init__(self,
     @classmethod
     def from_ase_traj(cls,
                       root: str,
-                      AtomicData_options: Dict[str, Any] = {},
                       get_Hamiltonian = False,
                       get_overlap = False,
                       get_DM = False,
@@ -185,7 +180,6 @@ def from_ase_traj(cls,
         np.savetxt(os.path.join(root, "atomic_numbers.dat"), atomic_numbers, fmt='%d')

         return cls(root=root,
-                   AtomicData_options=AtomicData_options,
                    get_Hamiltonian=get_Hamiltonian,
                    get_overlap=get_overlap,
                    get_DM=get_DM,
@@ -218,10 +212,11 @@ def toAtomicDataList(self, idp: TypeMapper = None):
                                 dtype=torch.long)

             atomic_data = AtomicData.from_points(
+                r_max = self.info["r_max"],
+                pbc = self.info["pbc"],
+                er_max = self.info.get("er_max", None),
+                oer_max= self.info.get("oer_max", None),
                 **kwargs,
-                # pbc is stored in AtomicData_options now.
-                #pbc = self.info["pbc"],
-                **self.AtomicData_options
                 )
             if "hamiltonian_blocks" in self.data:
                 assert idp is not None, "LCAO Basis must be provided in `common_option` for loading Hamiltonian."
@@ -300,21 +295,19 @@ def __init__(
         for file in self.info_files.keys():
             # get the info here
             info = info_files[file]
-            assert "AtomicData_options" in info
-            AtomicData_options = info["AtomicData_options"]
-            assert "r_max" in AtomicData_options
-            assert "pbc" in AtomicData_options
+            # assert "AtomicData_options" in info
+            assert "r_max" in info
+            assert "pbc" in info
+            pbc = info["pbc"]
             if info["pos_type"] == "ase":
                 subdata = _TrajData.from_ase_traj(os.path.join(self.root, file),
-                                                  AtomicData_options,
                                                   get_Hamiltonian,
                                                   get_overlap,
                                                   get_DM,
                                                   get_eigenvalues,
                                                   info=info)
             else:
                 subdata = _TrajData(os.path.join(self.root, file),
-                                    AtomicData_options,
                                     get_Hamiltonian,
                                     get_overlap,
                                     get_DM,
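With this change, every dataset class reads the cutoffs straight from the per-folder info dict: r_max and pbc are required (asserted above), while er_max and oer_max stay optional through .get(..., None). A minimal sketch of that consumption pattern with illustrative values:

```python
# Illustrative info dict: info.json content plus the cutoffs injected by build_dataset.
info = {"nframes": 1, "pos_type": "ase", "pbc": True, "r_max": 6.0}

assert "r_max" in info and "pbc" in info  # mandatory keys
from_points_kwargs = dict(
    pbc=info["pbc"],
    r_max=info["r_max"],
    er_max=info.get("er_max", None),    # optional, falls back to None
    oer_max=info.get("oer_max", None),  # optional, falls back to None
)
print(from_points_kwargs)
# These are the keyword arguments handed to AtomicData.from_points(...) above.
```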

dptb/data/dataset/_hdf5_dataset.py

Lines changed: 7 additions & 11 deletions
@@ -38,17 +38,14 @@ class _HDF5_TrajData(object):

     def __init__(self,
                  root: str,
-                 AtomicData_options: Dict[str, Any] = {},
                  get_Hamiltonian = False,
                  get_overlap = False,
                  get_DM = False,
                  get_eigenvalues = False,
                  info = None):
         assert not get_Hamiltonian * get_DM, "Cannot get both Hamiltonian and DM"
         self.root = root
-        self.AtomicData_options = AtomicData_options
         self.info = info
-
         self.data = {}

         assert os.path.exists(os.path.join(root, "structure.pkl")), "structure file not found."
@@ -87,9 +84,11 @@ def toAtomicDataList(self, idp: TypeMapper = None):
                 pos = self.data['structure'][frame]["positions"][:],
                 cell = frame_cell,
                 atomic_numbers = self.data['structure'][frame]["atomic_numbers"][:],
-                # pbc is stored in AtomicData_options now.
-                #pbc = self.info["pbc"],
-                **self.AtomicData_options)
+                r_max = self.info["r_max"],
+                er_max = self.info.get("er_max", None),
+                oer_max = self.info.get("oer_max", None),
+                pbc = self.info["pbc"],
+                )

             if "hamiltonian_blocks" in self.data:
                 assert idp is not None, "LCAO Basis must be provided in `common_option` for loading Hamiltonian."
@@ -171,13 +170,10 @@ def __init__(
         for file in self.info_files.keys():
             # get the info here
             info = info_files[file]
-            assert "AtomicData_options" in info
-            AtomicData_options = info["AtomicData_options"]
-            assert "r_max" in AtomicData_options
-            assert "pbc" in AtomicData_options
+            assert "r_max" in info
+            assert "pbc" in info
             if info["pos_type"] in ["hdf5", 'pickle']:
                 subdata = _HDF5_TrajData(os.path.join(self.root, file),
-                                         AtomicData_options,
                                          get_Hamiltonian,
                                          get_overlap,
                                          get_DM,

dptb/entrypoints/train.py

Lines changed: 36 additions & 5 deletions
@@ -3,7 +3,7 @@
 from dptb.data.build import build_dataset
 from dptb.plugins.monitor import TrainLossMonitor, LearningRateMonitor, Validationer
 from dptb.plugins.train_logger import Logger
-from dptb.utils.argcheck import normalize
+from dptb.utils.argcheck import normalize, collect_cutoffs
 from dptb.plugins.saver import Saver
 from typing import Dict, List, Optional, Any
 from dptb.utils.tools import j_loader, setup_seed, j_must_have
@@ -18,6 +18,7 @@
 import json
 import os
 import time
+import copy

 __all__ = ["train"]

@@ -147,26 +148,33 @@ def train(
             jdata["train_options"] = f["config"]["train_options"]
         if jdata.get("model_options") is None:
             jdata["model_options"] = f["config"]["model_options"]
+
+        ## add some warning !
+        for k, v in jdata["model_options"].items():
+            if k not in f["config"]["model_options"]:
+                log.warning(f"The model options {k} is not defined in checkpoint, set to {v}.")
+            else:
+                deep_dict_difference(k, v, f["config"]["model_options"])
         del f
     else:
         j_must_have(jdata, "model_options")
         j_must_have(jdata, "train_options")

-
+    cutoff_options = collect_cutoffs(jdata)
     # setup seed
     setup_seed(seed=jdata["common_options"]["seed"])

     # with open(os.path.join(output, "train_config.json"), "w") as fp:
     # json.dump(jdata, fp, indent=4)

     # build dataset
-    train_datasets = build_dataset(**jdata["data_options"]["train"], **jdata["common_options"])
+    train_datasets = build_dataset(**cutoff_options, **jdata["data_options"]["train"], **jdata["common_options"])
     if jdata["data_options"].get("validation"):
-        validation_datasets = build_dataset(**jdata["data_options"]["validation"], **jdata["common_options"])
+        validation_datasets = build_dataset(**cutoff_options, **jdata["data_options"]["validation"], **jdata["common_options"])
     else:
         validation_datasets = None
     if jdata["data_options"].get("reference"):
-        reference_datasets = build_dataset(**jdata["data_options"]["reference"], **jdata["common_options"])
+        reference_datasets = build_dataset(**cutoff_options, **jdata["data_options"]["reference"], **jdata["common_options"])
     else:
         reference_datasets = None

@@ -227,3 +235,26 @@ def train(
     log.info(f"wall time: {(end_time - start_time):.3f} s")


+def deep_dict_difference(base_key, expected_value, model_options):
+    """
+    Recursively log the option differences found in a nested dict.
+
+    :param base_key: the base key name, used as the prefix of the warning message.
+    :param expected_value: the expected value, which may or may not be a dict.
+    :param model_options: the model options dict to compare against.
+    """
+    target_dict = copy.deepcopy(model_options)  # avoid modifying the original dict
+    if isinstance(expected_value, dict):
+        for subk, subv in expected_value.items():
+
+            if not isinstance(target_dict.get(base_key, {}), dict):
+                log.warning(f"The model option {subk} in {base_key} is not defined in checkpoint, set to {subv}.")
+
+            elif subk not in target_dict.get(base_key, {}):
+                log.warning(f"The model option {subk} in {base_key} is not defined in checkpoint, set to {subv}.")
+            else:
+                target2 = copy.deepcopy(target_dict[base_key])
+                deep_dict_difference(f"{subk}", subv, target2)
+    else:
+        if expected_value != target_dict[base_key]:
+            log.warning(f"The model option {base_key} is set to {expected_value}, but in checkpoint it is {target_dict[base_key]}, make sure it is correct!")
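The new deep_dict_difference helper only emits warnings and never mutates either options dict. A self-contained demonstration of the messages it produces when the requested model_options disagree with a checkpoint; the helper is restated here in condensed form and the option values are illustrative:

```python
import copy
import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("demo")

def deep_dict_difference(base_key, expected_value, model_options):
    # Condensed version of the helper added above: warn on missing keys,
    # recurse into nested dicts, warn on differing leaf values.
    target_dict = copy.deepcopy(model_options)
    if isinstance(expected_value, dict):
        for subk, subv in expected_value.items():
            if not isinstance(target_dict.get(base_key, {}), dict) or subk not in target_dict.get(base_key, {}):
                log.warning(f"The model option {subk} in {base_key} is not defined in checkpoint, set to {subv}.")
            else:
                deep_dict_difference(subk, subv, copy.deepcopy(target_dict[base_key]))
    elif expected_value != target_dict[base_key]:
        log.warning(f"The model option {base_key} is set to {expected_value}, but in checkpoint it is {target_dict[base_key]}.")

checkpoint_options = {"nnsk": {"hopping": {"method": "powerlaw", "rs": 6.0}}}
requested_options = {"nnsk": {"hopping": {"method": "powerlaw", "rs": 3.0}, "soc": {"method": "uniform"}}}

for k, v in requested_options.items():
    if k not in checkpoint_options:
        log.warning(f"The model options {k} is not defined in checkpoint, set to {v}.")
    else:
        deep_dict_difference(k, v, checkpoint_options)
# -> warns that "rs" differs (3.0 vs 6.0) and that "soc" is missing from the checkpoint
```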

dptb/postprocess/elec_struc_cal.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,10 @@ def __init__ (
         self.model.eval()
         self.overlap = hasattr(model, 'overlap')

+        if not self.model.transform:
+            log.error('The model.transform is not True, please check the model.')
+            raise RuntimeError('The model.transform is not True, please check the model.')
+
         if self.overlap:
             self.eigv = Eigenvalues(
                 idp=model.idp,
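The added guard makes the post-processing class reject a model whose transform attribute is falsy up front instead of failing later. A minimal sketch of that behavior with a stand-in object (what model.transform holds in dptb is not shown in this diff):

```python
# Stand-in for a dptb model that was built or loaded without its transform.
class _FakeModel:
    transform = False

def _check(model):
    # Same guard as the one added to __init__ above.
    if not model.transform:
        raise RuntimeError('The model.transform is not True, please check the model.')

try:
    _check(_FakeModel())
except RuntimeError as err:
    print(err)  # -> The model.transform is not True, please check the model.
```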

dptb/tests/data/Sn/soc/dataset/set.0/info.json

Lines changed: 1 addition & 6 deletions
@@ -2,12 +2,7 @@
     "nframes": 1,
     "natoms": -1,
     "pos_type": "ase",
-    "AtomicData_options": {
-        "r_max": 6.0,
-        "er_max": 5.0,
-        "oer_max":3.0,
-        "pbc": true
-    },
+    "pbc": true,
     "bandinfo": {
         "band_min": 0,
         "band_max":16,

dptb/tests/data/Sn/soc/input/input_soc.json

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@
     },
     "model_options": {
         "nnsk": {
-            "onsite": {"method": "strain","rs":6.0, "w": 0.1},
+            "onsite": {"method": "strain","rs":3.0, "w": 0.1},
             "hopping": {"method": "powerlaw", "rs":6.0, "w": 0.1},
             "soc":{"method":"uniform"},
             "push": false,

dptb/tests/data/hBN/dataset/kpath.0/info.json

Lines changed: 1 addition & 6 deletions
@@ -2,12 +2,7 @@
     "nframes": 1,
     "natoms": 2,
     "pos_type": "ase",
-    "AtomicData_options": {
-        "r_max": 2.6,
-        "er_max": 2.6,
-        "oer_max":1.6,
-        "pbc": true
-    },
+    "pbc": true,
     "bandinfo": {
         "band_min": 0,
         "band_max": 6,
