21
21
from dptb .data .AtomicDataDict import with_edge_vectors
22
22
from dptb .nn .hamiltonian import E3Hamiltonian
23
23
from tqdm import tqdm
24
+ import logging
25
+
26
+ log = logging .getLogger (__name__ )
24
27
25
28
class _TrajData (object ):
26
29
'''
@@ -40,67 +43,18 @@ class _TrajData(object):
40
43
41
44
def __init__ (self ,
42
45
root : str ,
46
+ data = {},
43
47
get_Hamiltonian = False ,
44
48
get_overlap = False ,
45
49
get_DM = False ,
46
50
get_eigenvalues = False ,
47
- info = None ,
48
- _clear = False ):
51
+ info = None ):
49
52
50
53
assert not get_Hamiltonian * get_DM , "Hamiltonian and Density Matrix can only loaded one at a time, for which will occupy the same attribute in the AtomicData."
51
54
self .root = root
52
55
self .info = info
53
- self .data = {}
54
- pbc = info ["pbc" ]
55
- # load cell
56
- if isinstance (pbc , bool ):
57
- has_cell = pbc
58
- elif isinstance (pbc , list ):
59
- has_cell = any (pbc )
60
- else :
61
- raise ValueError ("pbc must be bool or list." )
62
-
63
- if has_cell :
64
- cell = np .loadtxt (os .path .join (root , "cell.dat" ))
65
- if cell .shape [0 ] == 3 :
66
- # same cell size, then copy it to all frames.
67
- cell = np .expand_dims (cell , axis = 0 )
68
- self .data ["cell" ] = np .broadcast_to (cell , (self .info ["nframes" ], 3 , 3 ))
69
- elif cell .shape [0 ] == self .info ["nframes" ] * 3 :
70
- self .data ["cell" ] = cell .reshape (self .info ["nframes" ], 3 , 3 )
71
- else :
72
- raise ValueError ("Wrong cell dimensions." )
73
-
74
- # load positions, stored as cartesion no matter what provided.
75
- pos = np .loadtxt (os .path .join (root , "positions.dat" ))
76
- if len (pos .shape ) == 1 :
77
- pos = pos .reshape (1 ,3 )
78
- natoms = self .info ["natoms" ]
79
- if natoms < 0 :
80
- natoms = int (pos .shape [0 ] / self .info ["nframes" ])
81
- assert pos .shape [0 ] == self .info ["nframes" ] * natoms
82
- pos = pos .reshape (self .info ["nframes" ], natoms , 3 )
83
- # ase use cartesian by default.
84
- if self .info ["pos_type" ] == "cart" or self .info ["pos_type" ] == "ase" :
85
- self .data ["pos" ] = pos
86
- elif self .info ["pos_type" ] == "frac" :
87
- self .data ["pos" ] = pos @ self .data ["cell" ]
88
- else :
89
- raise NameError ("Position type must be cart / frac." )
90
-
91
- # load atomic numbers
92
- atomic_numbers = np .loadtxt (os .path .join (root , "atomic_numbers.dat" ))
93
- if atomic_numbers .shape == ():
94
- atomic_numbers = atomic_numbers .reshape (1 )
95
- if atomic_numbers .shape [0 ] == natoms :
96
- # same atomic_numbers, copy it to all frames.
97
- atomic_numbers = np .expand_dims (atomic_numbers , axis = 0 )
98
- self .data ["atomic_numbers" ] = np .broadcast_to (atomic_numbers , (self .info ["nframes" ], natoms ))
99
- elif atomic_numbers .shape [0 ] == natoms * self .info ["nframes" ]:
100
- self .data ["atomic_numbers" ] = atomic_numbers .reshape (self .info ["nframes" ],natoms )
101
- else :
102
- raise ValueError ("Wrong atomic_number dimensions." )
103
-
56
+ self .data = data
57
+
104
58
# load optional data files
105
59
if get_eigenvalues == True :
106
60
if os .path .exists (os .path .join (self .root , "eigenvalues.npy" )):
@@ -142,12 +96,74 @@ def __init__(self,
142
96
else :
143
97
self .data ["DM_blocks" ] = h5py .File (os .path .join (self .root , "DM.h5" ), "r" )
144
98
145
- # this is used to clear the tmp files to load ase trajectory only.
146
- if _clear :
147
- os .remove (os .path .join (root , "positions.dat" ))
148
- os .remove (os .path .join (root , "cell.dat" ))
149
- os .remove (os .path .join (root , "atomic_numbers.dat" ))
150
-
99
+ @classmethod
100
+ def from_text_data (cls ,
101
+ root : str ,
102
+ get_Hamiltonian = False ,
103
+ get_overlap = False ,
104
+ get_DM = False ,
105
+ get_eigenvalues = False ,
106
+ info = None ):
107
+
108
+ data = {}
109
+ pbc = info ["pbc" ]
110
+ # load cell
111
+ if isinstance (pbc , bool ):
112
+ has_cell = pbc
113
+ elif isinstance (pbc , list ):
114
+ has_cell = any (pbc )
115
+ else :
116
+ raise ValueError ("pbc must be bool or list." )
117
+
118
+ if has_cell :
119
+ cell = np .loadtxt (os .path .join (root , "cell.dat" ))
120
+ if cell .shape [0 ] == 3 :
121
+ # same cell size, then copy it to all frames.
122
+ cell = np .expand_dims (cell , axis = 0 )
123
+ data ["cell" ] = np .broadcast_to (cell , (info ["nframes" ], 3 , 3 ))
124
+ elif cell .shape [0 ] == info ["nframes" ] * 3 :
125
+ data ["cell" ] = cell .reshape (info ["nframes" ], 3 , 3 )
126
+ else :
127
+ raise ValueError ("Wrong cell dimensions." )
128
+
129
+ # load positions, stored as cartesion no matter what provided.
130
+ pos = np .loadtxt (os .path .join (root , "positions.dat" ))
131
+ if len (pos .shape ) == 1 :
132
+ pos = pos .reshape (1 ,3 )
133
+ natoms = info ["natoms" ]
134
+ if natoms < 0 :
135
+ natoms = int (pos .shape [0 ] / info ["nframes" ])
136
+ assert pos .shape [0 ] == info ["nframes" ] * natoms
137
+ pos = pos .reshape (info ["nframes" ], natoms , 3 )
138
+ # ase use cartesian by default.
139
+ if info ["pos_type" ] == "cart" or info ["pos_type" ] == "ase" :
140
+ data ["pos" ] = pos
141
+ elif info ["pos_type" ] == "frac" :
142
+ data ["pos" ] = pos @ data ["cell" ]
143
+ else :
144
+ raise NameError ("Position type must be cart / frac." )
145
+
146
+ # load atomic numbers
147
+ atomic_numbers = np .loadtxt (os .path .join (root , "atomic_numbers.dat" ))
148
+ if atomic_numbers .shape == ():
149
+ atomic_numbers = atomic_numbers .reshape (1 )
150
+ if atomic_numbers .shape [0 ] == natoms :
151
+ # same atomic_numbers, copy it to all frames.
152
+ atomic_numbers = np .expand_dims (atomic_numbers , axis = 0 )
153
+ data ["atomic_numbers" ] = np .broadcast_to (atomic_numbers , (info ["nframes" ], natoms ))
154
+ elif atomic_numbers .shape [0 ] == natoms * info ["nframes" ]:
155
+ data ["atomic_numbers" ] = atomic_numbers .reshape (info ["nframes" ],natoms )
156
+ else :
157
+ raise ValueError ("Wrong atomic_number dimensions." )
158
+
159
+ return cls (root = root ,
160
+ data = data ,
161
+ get_Hamiltonian = get_Hamiltonian ,
162
+ get_overlap = get_overlap ,
163
+ get_DM = get_DM ,
164
+ get_eigenvalues = get_eigenvalues ,
165
+ info = info )
166
+
151
167
@classmethod
152
168
def from_ase_traj (cls ,
153
169
root : str ,
@@ -162,30 +178,63 @@ def from_ase_traj(cls,
162
178
traj_file = glob .glob (f"{ root } /*.traj" )
163
179
assert len (traj_file ) == 1 , print ("only one ase trajectory file can be provided." )
164
180
traj = Trajectory (traj_file [0 ], 'r' )
181
+ nframes = len (traj )
182
+ assert nframes > 0 , print ("trajectory file is empty." )
183
+ if nframes != info .get ("nframes" , None ):
184
+ info ['nframes' ] = nframes
185
+ log .info (f"Number of frames ({ nframes } ) in trajectory file does not match the number of frames in info file." )
186
+
187
+ natoms = traj [0 ].positions .shape [0 ]
188
+ if natoms != info ["natoms" ]:
189
+ info ["natoms" ] = natoms
190
+
191
+ pbc = info .get ("pbc" ,None )
192
+ if pbc is None :
193
+ pbc = traj [0 ].pbc .tolist ()
194
+ info ["pbc" ] = pbc
195
+
196
+ if isinstance (pbc , bool ):
197
+ pbc = [pbc ] * 3
198
+
199
+ if pbc != traj [0 ].pbc .tolist ():
200
+ log .warning ("!! PBC setting in info file does not match the PBC setting in trajectory file, we use the one in info json. BE CAREFUL!" )
201
+
165
202
positions = []
166
203
cell = []
167
204
atomic_numbers = []
205
+
168
206
for atoms in traj :
169
207
positions .append (atoms .get_positions ())
170
- cell . append ( atoms . get_cell ())
208
+
171
209
atomic_numbers .append (atoms .get_atomic_numbers ())
210
+ if (np .abs (atoms .get_cell ()- np .zeros ([3 ,3 ]))< 1e-6 ).all ():
211
+ cell = None
212
+ else :
213
+ cell .append (atoms .get_cell ())
214
+
172
215
positions = np .array (positions )
173
- positions = positions .reshape (- 1 , 3 )
174
- cell = np .array (cell )
175
- cell = cell .reshape (- 1 , 3 )
216
+ positions = positions .reshape (nframes ,natoms , 3 )
217
+
218
+ if cell is not None :
219
+ cell = np .array (cell )
220
+ cell = cell .reshape (nframes ,3 , 3 )
221
+
176
222
atomic_numbers = np .array (atomic_numbers )
177
- atomic_numbers = atomic_numbers .reshape (- 1 , 1 )
178
- np .savetxt (os .path .join (root , "positions.dat" ), positions )
179
- np .savetxt (os .path .join (root , "cell.dat" ), cell )
180
- np .savetxt (os .path .join (root , "atomic_numbers.dat" ), atomic_numbers , fmt = '%d' )
223
+ atomic_numbers = atomic_numbers .reshape (nframes , natoms )
224
+
225
+ data = {}
226
+ if cell is not None :
227
+ data ["cell" ] = cell
228
+ data ["pos" ] = positions
229
+ data ["atomic_numbers" ] = atomic_numbers
181
230
182
231
return cls (root = root ,
232
+ data = data ,
183
233
get_Hamiltonian = get_Hamiltonian ,
184
234
get_overlap = get_overlap ,
185
235
get_DM = get_DM ,
186
236
get_eigenvalues = get_eigenvalues ,
187
- info = info ,
188
- _clear = True )
237
+ info = info )
189
238
190
239
def toAtomicDataList (self , idp : TypeMapper = None ):
191
240
data_list = []
@@ -307,7 +356,7 @@ def __init__(
307
356
get_eigenvalues ,
308
357
info = info )
309
358
else :
310
- subdata = _TrajData (os .path .join (self .root , file ),
359
+ subdata = _TrajData . from_text_data (os .path .join (self .root , file ),
311
360
get_Hamiltonian ,
312
361
get_overlap ,
313
362
get_DM ,
0 commit comments