Skip to content

Commit 7eb57e4

Browse files
committed
refactor(default_dataset): refactor the _TrajData for ase data.
Previously, ASE data was first written out to text files and then loaded by _TrajData. This commit refactors that: text data and ASE data are now treated equally, and each is loaded by a class method that initializes the _TrajData class.
1 parent 3a1e1ef commit 7eb57e4

File tree

1 file changed

+119
-70
lines changed

1 file changed

+119
-70
lines changed

dptb/data/dataset/_default_dataset.py

Lines changed: 119 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from dptb.data.AtomicDataDict import with_edge_vectors
2222
from dptb.nn.hamiltonian import E3Hamiltonian
2323
from tqdm import tqdm
24+
import logging
25+
26+
log = logging.getLogger(__name__)
2427

2528
class _TrajData(object):
2629
'''
@@ -40,67 +43,18 @@ class _TrajData(object):
4043

4144
def __init__(self,
4245
root: str,
46+
data ={},
4347
get_Hamiltonian = False,
4448
get_overlap = False,
4549
get_DM = False,
4650
get_eigenvalues = False,
47-
info = None,
48-
_clear = False):
51+
info = None):
4952

5053
assert not get_Hamiltonian * get_DM, "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData."
5154
self.root = root
5255
self.info = info
53-
self.data = {}
54-
pbc = info["pbc"]
55-
# load cell
56-
if isinstance(pbc, bool):
57-
has_cell = pbc
58-
elif isinstance(pbc, list):
59-
has_cell = any(pbc)
60-
else:
61-
raise ValueError("pbc must be bool or list.")
62-
63-
if has_cell:
64-
cell = np.loadtxt(os.path.join(root, "cell.dat"))
65-
if cell.shape[0] == 3:
66-
# same cell size, then copy it to all frames.
67-
cell = np.expand_dims(cell, axis=0)
68-
self.data["cell"] = np.broadcast_to(cell, (self.info["nframes"], 3, 3))
69-
elif cell.shape[0] == self.info["nframes"] * 3:
70-
self.data["cell"] = cell.reshape(self.info["nframes"], 3, 3)
71-
else:
72-
raise ValueError("Wrong cell dimensions.")
73-
74-
# load positions, stored as cartesion no matter what provided.
75-
pos = np.loadtxt(os.path.join(root, "positions.dat"))
76-
if len(pos.shape) == 1:
77-
pos = pos.reshape(1,3)
78-
natoms = self.info["natoms"]
79-
if natoms < 0:
80-
natoms = int(pos.shape[0] / self.info["nframes"])
81-
assert pos.shape[0] == self.info["nframes"] * natoms
82-
pos = pos.reshape(self.info["nframes"], natoms, 3)
83-
# ase use cartesian by default.
84-
if self.info["pos_type"] == "cart" or self.info["pos_type"] == "ase":
85-
self.data["pos"] = pos
86-
elif self.info["pos_type"] == "frac":
87-
self.data["pos"] = pos @ self.data["cell"]
88-
else:
89-
raise NameError("Position type must be cart / frac.")
90-
91-
# load atomic numbers
92-
atomic_numbers = np.loadtxt(os.path.join(root, "atomic_numbers.dat"))
93-
if atomic_numbers.shape == ():
94-
atomic_numbers = atomic_numbers.reshape(1)
95-
if atomic_numbers.shape[0] == natoms:
96-
# same atomic_numbers, copy it to all frames.
97-
atomic_numbers = np.expand_dims(atomic_numbers, axis=0)
98-
self.data["atomic_numbers"] = np.broadcast_to(atomic_numbers, (self.info["nframes"], natoms))
99-
elif atomic_numbers.shape[0] == natoms * self.info["nframes"]:
100-
self.data["atomic_numbers"] = atomic_numbers.reshape(self.info["nframes"],natoms)
101-
else:
102-
raise ValueError("Wrong atomic_number dimensions.")
103-
56+
self.data = data
57+
10458
# load optional data files
10559
if get_eigenvalues == True:
10660
if os.path.exists(os.path.join(self.root, "eigenvalues.npy")):
@@ -142,12 +96,74 @@ def __init__(self,
14296
else:
14397
self.data["DM_blocks"] = h5py.File(os.path.join(self.root, "DM.h5"), "r")
14498

145-
# this is used to clear the tmp files to load ase trajectory only.
146-
if _clear:
147-
os.remove(os.path.join(root, "positions.dat"))
148-
os.remove(os.path.join(root, "cell.dat"))
149-
os.remove(os.path.join(root, "atomic_numbers.dat"))
150-
99+
@classmethod
def from_text_data(cls,
                   root: str,
                   get_Hamiltonian = False,
                   get_overlap = False,
                   get_DM = False,
                   get_eigenvalues = False,
                   info = None):
    """Build an instance from the plain-text structure files stored in ``root``.

    Reads ``positions.dat`` and ``atomic_numbers.dat`` (plus ``cell.dat`` when
    ``info["pbc"]`` enables at least one periodic direction), reshapes them to
    per-frame arrays and forwards them as ``data`` to the class constructor.

    Parameters
    ----------
    root : str
        Directory containing the text files.
    get_Hamiltonian, get_overlap, get_DM, get_eigenvalues : bool
        Forwarded unchanged to the constructor.
    info : dict
        Must provide the keys ``"pbc"`` (bool, list or tuple), ``"nframes"``,
        ``"natoms"`` (a negative value means "infer from positions.dat") and
        ``"pos_type"`` (``"cart"``, ``"ase"`` or ``"frac"``).

    Returns
    -------
    The object returned by ``cls(root=..., data=..., ...)``.

    Raises
    ------
    ValueError
        If ``info`` is missing, ``pbc`` has an unsupported type, the cell,
        position or atomic-number arrays do not match ``nframes`` / ``natoms``,
        or fractional positions are requested without a cell.
    NameError
        If ``info["pos_type"]`` is not one of the supported values.
    """
    if info is None:
        # The default exists only for signature symmetry; nothing can be loaded without it.
        raise ValueError("info must be provided to load text data.")

    nframes = info["nframes"]
    data = {}

    # A cell is only read when at least one direction is periodic.
    pbc = info["pbc"]
    if isinstance(pbc, bool):
        has_cell = pbc
    elif isinstance(pbc, (list, tuple)):
        has_cell = any(pbc)
    else:
        raise ValueError("pbc must be bool or list.")

    if has_cell:
        cell = np.loadtxt(os.path.join(root, "cell.dat"))
        if cell.shape[0] == 3:
            # A single cell was given: share it across all frames.
            cell = np.expand_dims(cell, axis=0)
            data["cell"] = np.broadcast_to(cell, (nframes, 3, 3))
        elif cell.shape[0] == nframes * 3:
            data["cell"] = cell.reshape(nframes, 3, 3)
        else:
            raise ValueError("Wrong cell dimensions.")

    # Load positions; they are stored as Cartesian no matter what was provided.
    pos = np.loadtxt(os.path.join(root, "positions.dat"))
    if pos.ndim == 1:
        # np.loadtxt collapses a one-row file to 1-D.
        pos = pos.reshape(1, 3)
    natoms = info["natoms"]
    if natoms < 0:
        natoms = pos.shape[0] // nframes
    if pos.shape[0] != nframes * natoms:
        # Explicit raise instead of assert so the check survives `python -O`.
        raise ValueError(
            f"positions.dat has {pos.shape[0]} rows, expected "
            f"nframes * natoms = {nframes} * {natoms}."
        )
    pos = pos.reshape(nframes, natoms, 3)

    # ASE uses Cartesian coordinates by default.
    pos_type = info["pos_type"]
    if pos_type == "cart" or pos_type == "ase":
        data["pos"] = pos
    elif pos_type == "frac":
        if "cell" not in data:
            # Without a cell there is nothing to convert fractional coordinates with.
            raise ValueError("Fractional positions require a cell, but pbc is disabled in info.")
        data["pos"] = pos @ data["cell"]
    else:
        raise NameError("Position type must be cart / frac.")

    # Load atomic numbers.
    atomic_numbers = np.loadtxt(os.path.join(root, "atomic_numbers.dat"))
    if atomic_numbers.shape == ():
        # Single-atom file: np.loadtxt returns a 0-d array.
        atomic_numbers = atomic_numbers.reshape(1)
    if atomic_numbers.shape[0] == natoms:
        # Same species list for every frame: share it across all frames.
        atomic_numbers = np.expand_dims(atomic_numbers, axis=0)
        data["atomic_numbers"] = np.broadcast_to(atomic_numbers, (nframes, natoms))
    elif atomic_numbers.shape[0] == natoms * nframes:
        data["atomic_numbers"] = atomic_numbers.reshape(nframes, natoms)
    else:
        raise ValueError("Wrong atomic_number dimensions.")

    return cls(root=root,
               data=data,
               get_Hamiltonian=get_Hamiltonian,
               get_overlap=get_overlap,
               get_DM=get_DM,
               get_eigenvalues=get_eigenvalues,
               info=info)
166+
151167
@classmethod
152168
def from_ase_traj(cls,
153169
root: str,
@@ -162,30 +178,63 @@ def from_ase_traj(cls,
162178
traj_file = glob.glob(f"{root}/*.traj")
163179
assert len(traj_file) == 1, print("only one ase trajectory file can be provided.")
164180
traj = Trajectory(traj_file[0], 'r')
181+
nframes = len(traj)
182+
assert nframes > 0, print("trajectory file is empty.")
183+
if nframes != info.get("nframes", None):
184+
info['nframes'] = nframes
185+
log.info(f"Number of frames ({nframes}) in trajectory file does not match the number of frames in info file.")
186+
187+
natoms = traj[0].positions.shape[0]
188+
if natoms != info["natoms"]:
189+
info["natoms"] = natoms
190+
191+
pbc = info.get("pbc",None)
192+
if pbc is None:
193+
pbc = traj[0].pbc.tolist()
194+
info["pbc"] = pbc
195+
196+
if isinstance(pbc, bool):
197+
pbc = [pbc] * 3
198+
199+
if pbc != traj[0].pbc.tolist():
200+
log.warning("!! PBC setting in info file does not match the PBC setting in trajectory file, we use the one in info json. BE CAREFUL!")
201+
165202
positions = []
166203
cell = []
167204
atomic_numbers = []
205+
168206
for atoms in traj:
169207
positions.append(atoms.get_positions())
170-
cell.append(atoms.get_cell())
208+
171209
atomic_numbers.append(atoms.get_atomic_numbers())
210+
if (np.abs(atoms.get_cell()-np.zeros([3,3]))< 1e-6).all():
211+
cell = None
212+
else:
213+
cell.append(atoms.get_cell())
214+
172215
positions = np.array(positions)
173-
positions = positions.reshape(-1, 3)
174-
cell = np.array(cell)
175-
cell = cell.reshape(-1, 3)
216+
positions = positions.reshape(nframes,natoms, 3)
217+
218+
if cell is not None:
219+
cell = np.array(cell)
220+
cell = cell.reshape(nframes,3, 3)
221+
176222
atomic_numbers = np.array(atomic_numbers)
177-
atomic_numbers = atomic_numbers.reshape(-1, 1)
178-
np.savetxt(os.path.join(root, "positions.dat"), positions)
179-
np.savetxt(os.path.join(root, "cell.dat"), cell)
180-
np.savetxt(os.path.join(root, "atomic_numbers.dat"), atomic_numbers, fmt='%d')
223+
atomic_numbers = atomic_numbers.reshape(nframes, natoms)
224+
225+
data = {}
226+
if cell is not None:
227+
data["cell"] = cell
228+
data["pos"] = positions
229+
data["atomic_numbers"] = atomic_numbers
181230

182231
return cls(root=root,
232+
data=data,
183233
get_Hamiltonian=get_Hamiltonian,
184234
get_overlap=get_overlap,
185235
get_DM=get_DM,
186236
get_eigenvalues=get_eigenvalues,
187-
info=info,
188-
_clear=True)
237+
info=info)
189238

190239
def toAtomicDataList(self, idp: TypeMapper = None):
191240
data_list = []
@@ -307,7 +356,7 @@ def __init__(
307356
get_eigenvalues,
308357
info=info)
309358
else:
310-
subdata = _TrajData(os.path.join(self.root, file),
359+
subdata = _TrajData.from_text_data(os.path.join(self.root, file),
311360
get_Hamiltonian,
312361
get_overlap,
313362
get_DM,

0 commit comments

Comments
 (0)