Skip to content

Commit 1f26508

Browse files
add a new fermi level calculation method in band.py and its example (#176)
* add kmesh mode in band.py * add kmesh and run * add example for get_fermi * add abstract_process.py and update get_fermi.ipynb * add abstract_process in band.py * remove kmesh mode in band.py * update example for get_fermi * add get_eigs and get_fermi_level in abstract_process.py * update band.py with get_fermi * remove unnecessary packages * rename abstracprocess as elec_struc_cal.py * add docstring in elec_struc_cal.py * remove usegui and results_path in elec_struc_cal.py * use klist to calculate efermi in band.py * update get_fermi example * add unitest for get_fermi * add nnsk.best.pth * remove test * Refactor Band class to use torch.nn.Module for model parameter in __init__ * Refactor test_get_fermi to use meshgrid instead of kmesh --------- Co-authored-by: qqgu <guqq_phy@qq.com>
1 parent a2e503d commit 1f26508

File tree

10 files changed

+815
-64
lines changed

10 files changed

+815
-64
lines changed

dptb/postprocess/bandstructure/band.py

Lines changed: 31 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
1414
from dptb.data import AtomicData, AtomicDataDict
1515
from dptb.nn.energy import Eigenvalues
16-
16+
from dptb.postprocess.elec_struc_cal import ElecStruCal
1717
# class bandcalc(object):
1818
# def __init__ (self, apiHrk, run_opt, jdata):
1919
# self.apiH = apiHrk
@@ -163,61 +163,25 @@
163163
# plt.show()
164164

165165

166-
class Band(object):
167-
def __init__ (
168-
self,
169-
model: torch.nn.Module,
170-
results_path: Optional[str]=None,
171-
use_gui=False,
172-
device: Union[str, torch.device]=None
173-
):
174-
175-
if device is None:
176-
device = model.device
177-
if isinstance(device, str):
178-
device = torch.device(device)
179-
self.device = device
180-
self.model = model
181-
self.model.eval()
182-
self.use_gui = use_gui
166+
class Band(ElecStruCal):
167+
168+
def __init__(self, model:torch.nn.Module, results_path: str=None, use_gui: bool=False, device: str='cpu'):
169+
super().__init__(model=model, device=device)
183170
self.results_path = results_path
184-
self.overlap = hasattr(model, 'overlap')
185-
186-
if self.overlap:
187-
self.eigv = Eigenvalues(
188-
idp=model.idp,
189-
device=self.device,
190-
s_edge_field=AtomicDataDict.EDGE_OVERLAP_KEY,
191-
s_node_field=AtomicDataDict.NODE_OVERLAP_KEY,
192-
s_out_field=AtomicDataDict.OVERLAP_KEY,
193-
dtype=model.dtype,
194-
)
195-
else:
196-
self.eigv = Eigenvalues(
197-
idp=model.idp,
198-
device=self.device,
199-
dtype=model.dtype,
200-
)
171+
self.use_gui = use_gui
201172

202173
def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict, AtomicData_options: dict={}):
203174
kline_type = kpath_kwargs['kline_type']
204175

205-
# get the AtomicData structure and the ase structure
176+
# get the ase structure
206177
if isinstance(data, str):
207178
structase = read(data)
208-
data = AtomicData.from_ase(structase, **AtomicData_options)
209179
elif isinstance(data, ase.Atoms):
210180
structase = data
211-
data = AtomicData.from_ase(structase, **AtomicData_options)
212181
elif isinstance(data, AtomicData):
213182
structase = data.to("cpu").to_ase()
214-
data = data
215-
216-
217-
data = AtomicData.to_AtomicDataDict(data.to(self.device))
218-
data = self.model.idp(data)
219-
220183

184+
221185
if kline_type == 'ase':
222186
kpath = kpath_kwargs['kpath']
223187
nkpoints = kpath_kwargs['nkpoints']
@@ -239,48 +203,49 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict,
239203
high_sym_kpoints = kpath_kwargs.get('high_sym_kpoints', None)
240204
xlist = kpath_kwargs.get('xlist', None)
241205
labels = kpath_kwargs.get('labels', None)
206+
242207
else:
243208
log.error('Error, now, kline_type only support ase_kpath, abacus, or vasp.')
244209
raise ValueError
245210

246-
# set the kpoint of the AtomicData
247-
data[AtomicDataDict.KPOINT_KEY] = torch.nested.as_nested_tensor([torch.as_tensor(klist, dtype=self.model.dtype, device=self.device)])
248-
249-
# get the eigenvalues
250-
data = self.model(data)
251-
if self.overlap == True:
252-
assert data.get(AtomicDataDict.EDGE_OVERLAP_KEY) is not None
253-
data = self.eigv(data)
211+
data, eigenvalues = self.get_eigs(data, klist, AtomicData_options)
212+
254213

255214
# get the E_fermi from data
256215
nel_atom = kpath_kwargs.get('nel_atom', None)
257-
assert isinstance(nel_atom, dict) or nel_atom is None
216+
# assert isinstance(nel_atom, dict) or nel_atom is None
258217

218+
# if nel_atom is not None:
219+
# atomtype_list = self.data[AtomicDataDict.ATOM_TYPE_KEY].flatten().tolist()
220+
# atomtype_symbols = np.asarray(self.model.idp.type_names)[atomtype_list].tolist()
221+
# total_nel = np.array([nel_atom[s] for s in atomtype_symbols]).sum()
222+
# if hasattr(self.model,'soc_param'):
223+
# spindeg = 1
224+
# else:
225+
# spindeg = 2
226+
# estimated_E_fermi = self.estimate_E_fermi(self.data[AtomicDataDict.ENERGY_EIGENVALUE_KEY][0].detach().cpu().numpy(), total_nel, spindeg)
227+
# log.info(f'Estimated E_fermi: {estimated_E_fermi} based on the valence electrons setting nel_atom : {nel_atom} .')
228+
# else:
229+
# estimated_E_fermi = None
259230
if nel_atom is not None:
260-
atomtype_list = data[AtomicDataDict.ATOM_TYPE_KEY].flatten().tolist()
261-
atomtype_symbols = np.asarray(self.model.idp.type_names)[atomtype_list].tolist()
262-
total_nel = np.array([nel_atom[s] for s in atomtype_symbols]).sum()
263-
if hasattr(self.model,'soc_param'):
264-
spindeg = 1
265-
else:
266-
spindeg = 2
267-
estimated_E_fermi = self.estimate_E_fermi(data[AtomicDataDict.ENERGY_EIGENVALUE_KEY][0].detach().cpu().numpy(), total_nel, spindeg)
268-
log.info(f'Estimated E_fermi: {estimated_E_fermi} based on the valence electrons setting nel_atom : {nel_atom} .')
231+
data,estimated_E_fermi = self.get_fermi_level(data=data, nel_atom=nel_atom, \
232+
klist = klist, AtomicData_options=AtomicData_options)
269233
else:
270234
estimated_E_fermi = None
271235

272236
self.eigenstatus = {'klist': klist,
273237
'xlist': xlist,
274238
'high_sym_kpoints': high_sym_kpoints,
275239
'labels': labels,
276-
'eigenvalues': data[AtomicDataDict.ENERGY_EIGENVALUE_KEY][0].detach().cpu().numpy(),
240+
'eigenvalues': eigenvalues,
277241
'E_fermi': estimated_E_fermi}
278242

279243
if self.results_path is not None:
280244
np.save(f'{self.results_path}/bandstructure',self.eigenstatus)
281245

282246
return self.eigenstatus
283247

248+
284249
@classmethod
285250
def estimate_E_fermi(cls, eigenvalues: np.array, total_electrons: int, spindeg: int=2):
286251
assert len(eigenvalues.shape) == 2
@@ -290,7 +255,7 @@ def estimate_E_fermi(cls, eigenvalues: np.array, total_electrons: int, spindeg:
290255
EF=(sorteigs[numek] + sorteigs[numek-1])/2
291256

292257
return EF
293-
258+
294259

295260
def band_plot(
296261
self,
@@ -329,6 +294,8 @@ def band_plot(
329294
raise ValueError
330295

331296
if ref_band.shape[0] != self.eigenstatus["eigenvalues"].shape[0]:
297+
print('ref_band.shape[0]',ref_band.shape[0])
298+
print('self.eigenstatus["eigenvalues"].shape[0]',self.eigenstatus["eigenvalues"].shape[0])
332299
log.error("Reference Eigenvalues' should have sampled from the sample kpath as model's prediction.")
333300
raise ValueError
334301
ref_band = ref_band - (np.min(ref_band) - np.min(self.eigenstatus["eigenvalues"]))

0 commit comments

Comments
 (0)