13
13
from matplotlib .ticker import MultipleLocator , FormatStrFormatter
14
14
from dptb .data import AtomicData , AtomicDataDict
15
15
from dptb .nn .energy import Eigenvalues
16
-
16
+ from dptb . postprocess . elec_struc_cal import ElecStruCal
17
17
# class bandcalc(object):
18
18
# def __init__ (self, apiHrk, run_opt, jdata):
19
19
# self.apiH = apiHrk
163
163
# plt.show()
164
164
165
165
166
- class Band (object ):
167
- def __init__ (
168
- self ,
169
- model : torch .nn .Module ,
170
- results_path : Optional [str ]= None ,
171
- use_gui = False ,
172
- device : Union [str , torch .device ]= None
173
- ):
174
-
175
- if device is None :
176
- device = model .device
177
- if isinstance (device , str ):
178
- device = torch .device (device )
179
- self .device = device
180
- self .model = model
181
- self .model .eval ()
182
- self .use_gui = use_gui
166
+ class Band (ElecStruCal ):
167
+
168
+ def __init__ (self , model :torch .nn .Module , results_path : str = None , use_gui : bool = False , device : str = 'cpu' ):
169
+ super ().__init__ (model = model , device = device )
183
170
self .results_path = results_path
184
- self .overlap = hasattr (model , 'overlap' )
185
-
186
- if self .overlap :
187
- self .eigv = Eigenvalues (
188
- idp = model .idp ,
189
- device = self .device ,
190
- s_edge_field = AtomicDataDict .EDGE_OVERLAP_KEY ,
191
- s_node_field = AtomicDataDict .NODE_OVERLAP_KEY ,
192
- s_out_field = AtomicDataDict .OVERLAP_KEY ,
193
- dtype = model .dtype ,
194
- )
195
- else :
196
- self .eigv = Eigenvalues (
197
- idp = model .idp ,
198
- device = self .device ,
199
- dtype = model .dtype ,
200
- )
171
+ self .use_gui = use_gui
201
172
202
173
def get_bands (self , data : Union [AtomicData , ase .Atoms , str ], kpath_kwargs : dict , AtomicData_options : dict = {}):
203
174
kline_type = kpath_kwargs ['kline_type' ]
204
175
205
- # get the AtomicData structure and the ase structure
176
+ # get the ase structure
206
177
if isinstance (data , str ):
207
178
structase = read (data )
208
- data = AtomicData .from_ase (structase , ** AtomicData_options )
209
179
elif isinstance (data , ase .Atoms ):
210
180
structase = data
211
- data = AtomicData .from_ase (structase , ** AtomicData_options )
212
181
elif isinstance (data , AtomicData ):
213
182
structase = data .to ("cpu" ).to_ase ()
214
- data = data
215
-
216
-
217
- data = AtomicData .to_AtomicDataDict (data .to (self .device ))
218
- data = self .model .idp (data )
219
-
220
183
184
+
221
185
if kline_type == 'ase' :
222
186
kpath = kpath_kwargs ['kpath' ]
223
187
nkpoints = kpath_kwargs ['nkpoints' ]
@@ -239,48 +203,49 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict,
239
203
high_sym_kpoints = kpath_kwargs .get ('high_sym_kpoints' , None )
240
204
xlist = kpath_kwargs .get ('xlist' , None )
241
205
labels = kpath_kwargs .get ('labels' , None )
206
+
242
207
else :
243
208
log .error ('Error, now, kline_type only support ase_kpath, abacus, or vasp.' )
244
209
raise ValueError
245
210
246
- # set the kpoint of the AtomicData
247
- data [AtomicDataDict .KPOINT_KEY ] = torch .nested .as_nested_tensor ([torch .as_tensor (klist , dtype = self .model .dtype , device = self .device )])
248
-
249
- # get the eigenvalues
250
- data = self .model (data )
251
- if self .overlap == True :
252
- assert data .get (AtomicDataDict .EDGE_OVERLAP_KEY ) is not None
253
- data = self .eigv (data )
211
+ data , eigenvalues = self .get_eigs (data , klist , AtomicData_options )
212
+
254
213
255
214
# get the E_fermi from data
256
215
nel_atom = kpath_kwargs .get ('nel_atom' , None )
257
- assert isinstance (nel_atom , dict ) or nel_atom is None
216
+ # assert isinstance(nel_atom, dict) or nel_atom is None
258
217
218
+ # if nel_atom is not None:
219
+ # atomtype_list = self.data[AtomicDataDict.ATOM_TYPE_KEY].flatten().tolist()
220
+ # atomtype_symbols = np.asarray(self.model.idp.type_names)[atomtype_list].tolist()
221
+ # total_nel = np.array([nel_atom[s] for s in atomtype_symbols]).sum()
222
+ # if hasattr(self.model,'soc_param'):
223
+ # spindeg = 1
224
+ # else:
225
+ # spindeg = 2
226
+ # estimated_E_fermi = self.estimate_E_fermi(self.data[AtomicDataDict.ENERGY_EIGENVALUE_KEY][0].detach().cpu().numpy(), total_nel, spindeg)
227
+ # log.info(f'Estimated E_fermi: {estimated_E_fermi} based on the valence electrons setting nel_atom : {nel_atom} .')
228
+ # else:
229
+ # estimated_E_fermi = None
259
230
if nel_atom is not None :
260
- atomtype_list = data [AtomicDataDict .ATOM_TYPE_KEY ].flatten ().tolist ()
261
- atomtype_symbols = np .asarray (self .model .idp .type_names )[atomtype_list ].tolist ()
262
- total_nel = np .array ([nel_atom [s ] for s in atomtype_symbols ]).sum ()
263
- if hasattr (self .model ,'soc_param' ):
264
- spindeg = 1
265
- else :
266
- spindeg = 2
267
- estimated_E_fermi = self .estimate_E_fermi (data [AtomicDataDict .ENERGY_EIGENVALUE_KEY ][0 ].detach ().cpu ().numpy (), total_nel , spindeg )
268
- log .info (f'Estimated E_fermi: { estimated_E_fermi } based on the valence electrons setting nel_atom : { nel_atom } .' )
231
+ data ,estimated_E_fermi = self .get_fermi_level (data = data , nel_atom = nel_atom , \
232
+ klist = klist , AtomicData_options = AtomicData_options )
269
233
else :
270
234
estimated_E_fermi = None
271
235
272
236
self .eigenstatus = {'klist' : klist ,
273
237
'xlist' : xlist ,
274
238
'high_sym_kpoints' : high_sym_kpoints ,
275
239
'labels' : labels ,
276
- 'eigenvalues' : data [ AtomicDataDict . ENERGY_EIGENVALUE_KEY ][ 0 ]. detach (). cpu (). numpy () ,
240
+ 'eigenvalues' : eigenvalues ,
277
241
'E_fermi' : estimated_E_fermi }
278
242
279
243
if self .results_path is not None :
280
244
np .save (f'{ self .results_path } /bandstructure' ,self .eigenstatus )
281
245
282
246
return self .eigenstatus
283
247
248
+
284
249
@classmethod
285
250
def estimate_E_fermi (cls , eigenvalues : np .array , total_electrons : int , spindeg : int = 2 ):
286
251
assert len (eigenvalues .shape ) == 2
@@ -290,7 +255,7 @@ def estimate_E_fermi(cls, eigenvalues: np.array, total_electrons: int, spindeg:
290
255
EF = (sorteigs [numek ] + sorteigs [numek - 1 ])/ 2
291
256
292
257
return EF
293
-
258
+
294
259
295
260
def band_plot (
296
261
self ,
@@ -329,6 +294,8 @@ def band_plot(
329
294
raise ValueError
330
295
331
296
if ref_band .shape [0 ] != self .eigenstatus ["eigenvalues" ].shape [0 ]:
297
+ print ('ref_band.shape[0]' ,ref_band .shape [0 ])
298
+ print ('self.eigenstatus["eigenvalues"].shape[0]' ,self .eigenstatus ["eigenvalues" ].shape [0 ])
332
299
log .error ("Reference Eigenvalues' should have sampled from the sample kpath as model's prediction." )
333
300
raise ValueError
334
301
ref_band = ref_band - (np .min (ref_band ) - np .min (self .eigenstatus ["eigenvalues" ]))
0 commit comments