Skip to content

Commit a2e503d

Browse files
Update: support reading multiple LMDB data files with LMDBDataset (#194)
* Update: support reading multiple LMDB data files with LMDBDataset * fix: run write_block
1 parent 16bf20a commit a2e503d

File tree

4 files changed

+66
-46
lines changed

4 files changed

+66
-46
lines changed

dptb/data/build.py

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dptb.data import AtomicDataset, register_fields
1414
from dptb.utils import instantiate, get_w_prefix
1515
from dptb.utils.tools import j_loader
16-
from dptb.utils.argcheck import normalize_setinfo
16+
from dptb.utils.argcheck import normalize_setinfo, normalize_lmdbsetinfo
1717

1818

1919
def dataset_from_config(config, prefix: str = "dataset") -> AtomicDataset:
@@ -153,7 +153,7 @@ def build_dataset(
153153
else:
154154
idp = None
155155

156-
if dataset_type in ["DefaultDataset", "DeePHDataset", "HDF5Dataset"]:
156+
if dataset_type in ["DefaultDataset", "DeePHDataset", "HDF5Dataset", "LMDBDataset"]:
157157

158158
# Explore the dataset's folder structure.
159159
#include_folders = []
@@ -176,7 +176,10 @@ def build_dataset(
176176
include_folders=[]
177177
for idir in prefix_folders:
178178
if os.path.isdir(idir):
179-
if not glob.glob(os.path.join(idir, '*.dat')) and not glob.glob(os.path.join(idir, '*.traj')) and not glob.glob(os.path.join(idir, '*.h5')):
179+
if not glob.glob(os.path.join(idir, '*.dat')) \
180+
and not glob.glob(os.path.join(idir, '*.traj')) \
181+
and not glob.glob(os.path.join(idir, '*.h5')) \
182+
and not glob.glob(os.path.join(idir, '*.mdb')):
180183
raise Exception(f"{idir} does not have the proper traj data files. Please check the data files.")
181184
include_folders.append(idir.split('/')[-1])
182185

@@ -191,7 +194,10 @@ def build_dataset(
191194
#if "info.json" in os.listdir(root):
192195
if os.path.exists(f"{root}/info.json"):
193196
public_info = j_loader(os.path.join(root, "info.json"))
194-
public_info = normalize_setinfo(public_info)
197+
if dataset_type == "LMDBDataset":
198+
public_info = normalize_lmdbsetinfo(public_info)
199+
else:
200+
public_info = normalize_setinfo(public_info)
195201
print("A public `info.json` file is provided, and will be used by the subfolders who do not have their own `info.json` file.")
196202
else:
197203
public_info = None
@@ -202,7 +208,10 @@ def build_dataset(
202208
if os.path.exists(f"{root}/{file}/info.json"):
203209
# use info provided in this trajectory.
204210
info = j_loader(f"{root}/{file}/info.json")
205-
info = normalize_setinfo(info)
211+
if dataset_type == "LMDBDataset":
212+
info = normalize_lmdbsetinfo(info)
213+
else:
214+
info = normalize_setinfo(info)
206215
info_files[file] = info
207216
elif public_info is not None:
208217
# use public info instead
@@ -234,7 +243,7 @@ def build_dataset(
234243
get_eigenvalues=get_eigenvalues,
235244
info_files = info_files
236245
)
237-
else:
246+
elif dataset_type == "HDF5Dataset":
238247
dataset = HDF5Dataset(
239248
root=root,
240249
type_mapper=idp,
@@ -244,40 +253,16 @@ def build_dataset(
244253
get_eigenvalues=get_eigenvalues,
245254
info_files = info_files
246255
)
247-
248-
elif dataset_type == "LMDBDataset":
249-
assert prefix is not None, "The prefix is not provided. Please provide the prefix to select the trajectory folders."
250-
prefix_folders = glob.glob(f"{root}/{prefix}*.lmdb")
251-
include_folders=[]
252-
for idir in prefix_folders:
253-
if os.path.isdir(idir):
254-
if not glob.glob(os.path.join(idir, '*.mdb')):
255-
raise Exception(f"{idir} does not have the proper traj data files. Please check the data files.")
256-
include_folders.append(idir.split('/')[-1])
257-
258-
assert isinstance(include_folders, list) and len(include_folders) == 1, "No trajectory folders are found. Please check the prefix."
259-
260-
# See if a public info is provided.
261-
#if "info.json" in os.listdir(root):
262-
263-
if os.path.exists(f"{root}/info.json"):
264-
info = j_loader(f"{root}/info.json")
265-
else:
266-
print("Please provide a info.json file.")
267-
raise Exception("info.json is not properly provided for this dataset.")
268-
269-
# We will sort the info_files here.
270-
# The order itself is not important, but must be consistent for the same list.
271-
272-
dataset = LMDBDataset(
273-
root=os.path.join(root, include_folders[0]),
256+
elif dataset_type == "LMDBDataset":
257+
dataset = LMDBDataset(
258+
root=root,
274259
type_mapper=idp,
275-
info=info,
276260
orthogonal=orthogonal,
277261
get_Hamiltonian=get_Hamiltonian,
278262
get_overlap=get_overlap,
279263
get_DM=get_DM,
280264
get_eigenvalues=get_eigenvalues,
265+
info_files = info_files
281266
)
282267

283268
else:

dptb/data/dataset/lmdb_dataset.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class LMDBDataset(AtomicDataset):
2525
def __init__(
2626
self,
2727
root: str,
28-
info: dict,
28+
info_files: dict,
2929
url: Optional[str] = None,
3030
include_frames: Optional[List[int]] = None,
3131
type_mapper: TypeMapper = None,
@@ -39,9 +39,7 @@ def __init__(
3939
# See if a subclass defines some inputs
4040
self.url = getattr(type(self), "URL", url)
4141
self.include_frames = include_frames
42-
self.info = info # there should be one info file for one LMDB Dataset
43-
44-
assert "r_max" in info
42+
self.info_files = info_files # there should be one info file for one LMDB Dataset
4543

4644

4745
self.data = None
@@ -66,10 +64,16 @@ def __init__(
6664
assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData."
6765

6866

69-
db_env = lmdb.open(os.path.join(self.root), readonly=True, lock=False)
70-
with db_env.begin() as txn:
71-
self.num_graphs = txn.stat()['entries']
72-
db_env.close()
67+
self.num_graphs = 0
68+
self.file_map = []
69+
self.index_map = []
70+
for file in self.info_files.keys():
71+
db_env = lmdb.open(os.path.join(self.root, file), readonly=True, lock=False)
72+
with db_env.begin() as txn:
73+
self.num_graphs += txn.stat()['entries']
74+
self.file_map += [file] * txn.stat()['entries']
75+
self.index_map += list(range(txn.stat()['entries']))
76+
db_env.close()
7377

7478
def len(self):
7579
return self.num_graphs
@@ -94,9 +98,9 @@ def download(self):
9498
extract_zip(download_path, self.raw_dir)
9599

96100
def get(self, idx):
97-
db_env = lmdb.open(os.path.join(self.root), readonly=True, lock=False)
101+
db_env = lmdb.open(os.path.join(self.root, self.file_map[idx]), readonly=True, lock=False)
98102
with db_env.begin() as txn:
99-
data_dict = txn.get(int(idx).to_bytes(length=4, byteorder='big'))
103+
data_dict = txn.get(self.index_map[int(idx)].to_bytes(length=4, byteorder='big'))
100104
data_dict = pickle.loads(data_dict)
101105
cell, pos, atomic_numbers = \
102106
data_dict[AtomicDataDict.CELL_KEY], \
@@ -141,7 +145,7 @@ def get(self, idx):
141145
cell=cell.reshape(3,3),
142146
atomic_numbers=atomic_numbers,
143147
pbc=pbc,
144-
**self.info
148+
**self.info_files[self.file_map[idx]]
145149
)
146150

147151
# transform blocks to atomicdata features

dptb/entrypoints/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def run(
9191
block = write_ham(data=struct_file, AtomicData_options=jdata['AtomicData_options'], model=model, device=jdata["device"])
9292
# write to h5 file, block is a dict, write to a h5 file
9393
with h5py.File(os.path.join(results_path, task+".h5"), 'w') as fid:
94-
default_group = fid.create_group("1")
94+
default_group = fid.create_group("0")
9595
for key_str, value in block.items():
9696
default_group[key_str] = value.detach().cpu().numpy()
9797
log.info(msg='write block successfully completed.')

dptb/utils/argcheck.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,10 +1423,41 @@ def set_info_options():
14231423

14241424
return Argument("setinfo", dict, sub_fields=args)
14251425

1426+
def set_info_options():
1427+
doc_nframes = "Number of frames in this trajectory."
1428+
doc_natoms = "Number of atoms in each frame."
1429+
doc_pos_type = "Type of atomic position input. Can be frac / cart / ase."
1430+
1431+
args = [
1432+
Argument("nframes", int, optional=False, doc=doc_nframes),
1433+
Argument("natoms", int, optional=True, default=-1, doc=doc_natoms),
1434+
Argument("pos_type", str, optional=False, doc=doc_pos_type),
1435+
bandinfo_sub(),
1436+
AtomicData_options_sub()
1437+
]
1438+
1439+
return Argument("setinfo", dict, sub_fields=args)
1440+
1441+
def lmdbset_info_options():
1442+
doc_r_max = "the cutoff value for bond considering in TB model."
1443+
1444+
args = [
1445+
Argument("r_max", [float, int, dict], optional=False, doc=doc_r_max, default=4.0)
1446+
]
1447+
return Argument("setinfo", dict, sub_fields=args)
1448+
14261449
def normalize_setinfo(data):
14271450

14281451
setinfo = set_info_options()
14291452
data = setinfo.normalize_value(data)
14301453
setinfo.check_value(data, strict=True)
14311454

1455+
return data
1456+
1457+
def normalize_lmdbsetinfo(data):
1458+
1459+
setinfo = lmdbset_info_options()
1460+
data = setinfo.normalize_value(data)
1461+
setinfo.check_value(data, strict=True)
1462+
14321463
return data

0 commit comments

Comments
 (0)