Skip to content

Commit 9b12282

Browse files
committed
update band post process.
1 parent 41a67a6 commit 9b12282

File tree

8 files changed

+137
-42
lines changed

8 files changed

+137
-42
lines changed

dptb/postprocess/bandstructure/band.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def __init__(self, model:torch.nn.Module, results_path: str=None, use_gui: bool=
170170
self.results_path = results_path
171171
self.use_gui = use_gui
172172

173-
def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict, AtomicData_options: dict={}):
173+
def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict, pbc:Union[bool,list]=None, Atomic_options:dict=None):
174174
kline_type = kpath_kwargs['kline_type']
175175

176176
# get the ase structure
@@ -208,7 +208,7 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict,
208208
log.error('Error, now, kline_type only support ase_kpath, abacus, or vasp.')
209209
raise ValueError
210210

211-
data, eigenvalues = self.get_eigs(data, klist, AtomicData_options)
211+
data, eigenvalues = self.get_eigs(data=data, klist=klist, pbc=pbc, Atomic_options=Atomic_options)
212212

213213

214214
# get the E_fermi from data
@@ -229,7 +229,7 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict,
229229
# estimated_E_fermi = None
230230
if nel_atom is not None:
231231
data,estimated_E_fermi = self.get_fermi_level(data=data, nel_atom=nel_atom, \
232-
klist = klist, AtomicData_options=AtomicData_options)
232+
klist = klist, pbc=pbc, Atomic_options=Atomic_options)
233233
else:
234234
estimated_E_fermi = None
235235

dptb/postprocess/elec_struc_cal.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
log = logging.getLogger(__name__)
1010
from dptb.data import AtomicData, AtomicDataDict
1111
from dptb.nn.energy import Eigenvalues
12+
from dptb.utils.argcheck import get_cutoffs_from_model_options
13+
from copy import deepcopy
1214

1315
# This class `ElecStruCal` is designed to calculate electronic structure properties such as
1416
# eigenvalues and Fermi energy based on provided input data and model.
@@ -61,8 +63,9 @@ def __init__ (
6163
device=self.device,
6264
dtype=model.dtype,
6365
)
64-
65-
def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: dict={},device: Union[str, torch.device]=None):
66+
r_max, er_max, oer_max = get_cutoffs_from_model_options(model.model_options)
67+
self.cutoffs = {'r_max': r_max, 'er_max': er_max, 'oer_max': oer_max}
68+
def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=None, device: Union[str, torch.device]=None, Atomic_options:dict=None):
6669
'''The function `get_data` takes input data in the form of a string, ase.Atoms object, or AtomicData
6770
object, processes it accordingly, and returns the AtomicData class.
6871
@@ -83,15 +86,36 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: di
8386
the loaded AtomicData object.
8487
8588
'''
89+
atomic_options = deepcopy(self.cutoffs)
90+
if pbc is not None:
91+
atomic_options.update({'pbc': pbc})
92+
93+
if Atomic_options is not None:
94+
if Atomic_options.get('r_max', None) is not None:
95+
if atomic_options['r_max'] != Atomic_options.get('r_max'):
96+
atomic_options['r_max'] = Atomic_options.get('r_max')
97+
log.warning(f'Overwrite the r_max setting in the model with the r_max setting in the Atomic_options: {Atomic_options.get("r_max")}')
98+
log.warning(f'This is very dangerous, please make sure you know what you are doing.')
99+
if Atomic_options.get('er_max', None) is not None:
100+
if atomic_options['er_max'] != Atomic_options.get('er_max'):
101+
atomic_options['er_max'] = Atomic_options.get('er_max')
102+
log.warning(f'Overwrite the er_max setting in the model with the er_max setting in the Atomic_options: {Atomic_options.get("er_max")}')
103+
log.warning(f'This is very dangerous, please make sure you know what you are doing.')
104+
if Atomic_options.get('oer_max', None) is not None:
105+
if atomic_options['oer_max'] != Atomic_options.get('oer_max'):
106+
atomic_options['oer_max'] = Atomic_options.get('oer_max')
107+
log.warning(f'Overwrite the oer_max setting in the model with the oer_max setting in the Atomic_options: {Atomic_options.get("oer_max")}')
108+
log.warning(f'This is very dangerous, please make sure you know what you are doing.')
86109

87110
if isinstance(data, str):
88111
structase = read(data)
89-
data = AtomicData.from_ase(structase, **AtomicData_options)
112+
data = AtomicData.from_ase(structase, **atomic_options)
90113
elif isinstance(data, ase.Atoms):
91114
structase = data
92-
data = AtomicData.from_ase(structase, **AtomicData_options)
115+
data = AtomicData.from_ase(structase, **atomic_options)
93116
elif isinstance(data, AtomicData):
94117
# structase = data.to("cpu").to_ase()
118+
log.info('The data is already an instance of AtomicData. Then the data is used directly.')
95119
data = data
96120
else:
97121
raise ValueError('data should be either a string, ase.Atoms, or AtomicData')
@@ -104,7 +128,7 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: di
104128
return data
105129

106130

107-
def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, AtomicData_options: dict={}):
131+
def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, pbc:Union[bool,list]=None, Atomic_options:dict=None):
108132
'''This function calculates eigenvalues for Hk at specified k-points.
109133
110134
Parameters
@@ -124,7 +148,7 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, A
124148
125149
'''
126150

127-
data = self.get_data(data=data, AtomicData_options=AtomicData_options, device=self.device)
151+
data = self.get_data(data=data, pbc=pbc, device=self.device,Atomic_options=Atomic_options)
128152
# set the kpoint of the AtomicData
129153
data[AtomicDataDict.KPOINT_KEY] = \
130154
torch.nested.as_nested_tensor([torch.as_tensor(klist, dtype=self.model.dtype, device=self.device)])
@@ -137,7 +161,7 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, A
137161
return data, data[AtomicDataDict.ENERGY_EIGENVALUE_KEY][0].detach().cpu().numpy()
138162

139163
def get_fermi_level(self, data: Union[AtomicData, ase.Atoms, str], nel_atom: dict, \
140-
meshgrid: list = None, klist: np.ndarray=None, AtomicData_options: dict={}):
164+
meshgrid: list = None, klist: np.ndarray=None, pbc:Union[bool,list]=None,Atomic_options:dict=None):
141165
'''This function calculates the Fermi level based on provided data with iteration method, electron counts per atom, and
142166
optional parameters like specific k-points and eigenvalues.
143167
@@ -188,7 +212,7 @@ def get_fermi_level(self, data: Union[AtomicData, ase.Atoms, str], nel_atom: dic
188212

189213
# eigenvalues would be used if provided, otherwise the eigenvalues would be calculated from the model on the specified k-points
190214
if not AtomicDataDict.ENERGY_EIGENVALUE_KEY in data:
191-
data, eigs = self.get_eigs(data=data, klist=klist, AtomicData_options=AtomicData_options)
215+
data, eigs = self.get_eigs(data=data, klist=klist, pbc=pbc, Atomic_options=Atomic_options)
192216
log.info('Getting eigenvalues from the model.')
193217
else:
194218
log.info('The eigenvalues are already in data. will use them.')

dptb/tests/test_from_v1json.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,8 @@ def test_bands(self):
6666
device=model.device)
6767

6868
stru_data = f"{rootdir}/json_model/AlAs.vasp"
69-
AtomicData_options = {"r_max": 5.2, "pbc": True}
70-
7169
eigenstatus = bcal.get_bands(data=stru_data,
72-
kpath_kwargs=kpath_kwargs,
73-
AtomicData_options=AtomicData_options)
70+
kpath_kwargs=kpath_kwargs)
7471

7572
expected_bands =np.array([[-2.48727150e+01, -1.29382324e+01, -1.29382257e+01, -1.29382229e+01, -1.10868120e+01, -8.07862854e+00, -8.07862568e+00, -8.07861805e+00, 9.56408596e+00, 9.56408691e+00, 1.25271873e+01, 1.25271950e+01, 1.25271978e+01, 4.23655891e+01, 4.23656044e+01, 4.32170753e+01, 4.32170792e+01, 4.32170868e+01],
7673
[-2.41187267e+01, -1.61148472e+01, -1.42793083e+01, -1.42793045e+01, -1.03604565e+01, -8.68612957e+00, -5.90628624e+00, -5.90628576e+00, 2.25617599e+00, 5.51729870e+00, 5.51730347e+00, 5.61441135e+00, 5.90860081e+00, 2.50449829e+01, 2.82622643e+01, 2.82622776e+01, 2.84239502e+01, 3.07470131e+01],
@@ -149,11 +146,10 @@ def test_bands(self):
149146
device=model.device)
150147

151148
stru_data = f"{rootdir}/json_model/silicon.vasp"
152-
AtomicData_options = {"r_max": 2.6, "oer_max":2.5, "pbc": True}
149+
AtomicData_options = {"r_max": 2.6, "oer_max":2.5}
153150

154151
eigenstatus = bcal.get_bands(data=stru_data,
155-
kpath_kwargs=kpath_kwargs,
156-
AtomicData_options=AtomicData_options)
152+
kpath_kwargs=kpath_kwargs,Atomic_options=AtomicData_options)
157153

158154
expected_bands =np.array([[-20.259584 , -8.328452 , -8.328452 , -8.328451 , -5.782879 , -5.782879 , -5.7828774 , -4.800206 , -0.8470682 , -0.8470663 , 4.9619126 , 4.961913 , 4.9619136 , 6.4527135 , 6.452714 , 6.452715 , 10.1427765 , 10.142781 ],
159155
[-19.173727 , -11.876228 , -10.340221 , -10.34022 , -6.861969 , -4.9920564 , -2.1901789 , -2.1901765 , -0.9258757 , 0.76235735, 4.2745295 , 4.2745323 , 4.990632 , 5.55916 , 5.559161 , 8.533346 , 8.716906 , 11.661528 ],

dptb/tests/test_from_v2json.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ def test_bands(self):
4242
device=model.device)
4343

4444
stru_data = f"{rootdir}/json_model/AlAs.vasp"
45-
AtomicData_options = {"r_max": 5.2, "pbc": True}
45+
AtomicData_options = {"r_max": 5.2}
4646

4747
eigenstatus = bcal.get_bands(data=stru_data,
48-
kpath_kwargs=kpath_kwargs,
49-
AtomicData_options=AtomicData_options)
48+
kpath_kwargs=kpath_kwargs)
5049

5150
expected_bands =np.array([[-2.48727150e+01, -1.29382324e+01, -1.29382257e+01, -1.29382229e+01, -1.10868120e+01, -8.07862854e+00, -8.07862568e+00, -8.07861805e+00, 9.56408596e+00, 9.56408691e+00, 1.25271873e+01, 1.25271950e+01, 1.25271978e+01, 4.23655891e+01, 4.23656044e+01, 4.32170753e+01, 4.32170792e+01, 4.32170868e+01],
5251
[-2.41187267e+01, -1.61148472e+01, -1.42793083e+01, -1.42793045e+01, -1.03604565e+01, -8.68612957e+00, -5.90628624e+00, -5.90628576e+00, 2.25617599e+00, 5.51729870e+00, 5.51730347e+00, 5.61441135e+00, 5.90860081e+00, 2.50449829e+01, 2.82622643e+01, 2.82622776e+01, 2.84239502e+01, 3.07470131e+01],
@@ -99,11 +98,11 @@ def test_bands(self):
9998
device=model.device)
10099

101100
stru_data = f"{rootdir}/json_model/silicon.vasp"
102-
AtomicData_options = {"r_max": 2.6, "oer_max":2.5, "pbc": True}
101+
AtomicData_options = {"r_max": 2.6, "oer_max":2.5}
103102

104103
eigenstatus = bcal.get_bands(data=stru_data,
105104
kpath_kwargs=kpath_kwargs,
106-
AtomicData_options=AtomicData_options)
105+
Atomic_options=AtomicData_options)
107106

108107
expected_bands =np.array([[-20.259584 , -8.328452 , -8.328452 , -8.328451 , -5.782879 , -5.782879 , -5.7828774 , -4.800206 , -0.8470682 , -0.8470663 , 4.9619126 , 4.961913 , 4.9619136 , 6.4527135 , 6.452714 , 6.452715 , 10.1427765 , 10.142781 ],
109108
[-19.173727 , -11.876228 , -10.340221 , -10.34022 , -6.861969 , -4.9920564 , -2.1901789 , -2.1901765 , -0.9258757 , 0.76235735, 4.2745295 , 4.2745323 , 4.990632 , 5.55916 , 5.559161 , 8.533346 , 8.716906 , 11.661528 ],

dptb/tests/test_get_fermi.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,12 @@ def test_get_fermi():
1313
stru_data = f"{rootdir}/test_get_fermi/PRIMCELL.vasp"
1414

1515
model = build_model(checkpoint=ckpt)
16-
AtomicData_options={
17-
"r_max": 5.50,
18-
"pbc": True
19-
}
20-
21-
AtomicData_options = AtomicData_options
2216
nel_atom = {"Au":11}
2317

2418
elec_cal = ElecStruCal(model=model,device='cpu')
2519
_, efermi =elec_cal.get_fermi_level(data=stru_data,
2620
nel_atom = nel_atom,
27-
meshgrid=[30,30,30],
28-
AtomicData_options=AtomicData_options)
21+
meshgrid=[30,30,30])
2922

3023
assert abs(efermi + 3.25725233554) < 1e-5
3124

dptb/tests/test_nrl.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,15 @@ def test_nrl_json_band():
4444
}
4545

4646
stru_data = f"{rootdir}/json_model/silicon.vasp"
47-
AtomicData_options = {"r_max": 5.0, "oer_max":6.6147151362875, "pbc": True}
47+
AtomicData_options = {"r_max": 5.0, "oer_max":6.6147151362875}
4848
kpath_kwargs = jdata["task_options"]
4949
bcal = Band(model=model,
5050
use_gui=True,
5151
results_path='./',
5252
device=model.device)
5353

5454
eigenstatus = bcal.get_bands(data=stru_data,
55-
kpath_kwargs=kpath_kwargs,
56-
AtomicData_options=AtomicData_options)
55+
kpath_kwargs=kpath_kwargs, Atomic_options = AtomicData_options)
5756

5857
expected_eigenvalues = np.array([[-6.1745434 , 5.282297 , 5.282303 , 5.2823052 , 8.658317 , 8.6583185 , 8.658324 , 9.862869 , 14.152446 , 14.152451 , 15.180438 , 15.180452 , 16.983887 , 16.983889 , 16.983896 , 23.09491 , 23.094921 , 23.094925 ],
5958
[-5.5601606 , 2.1920488 , 3.4229636 , 3.4229672 , 7.347074 , 9.382092 , 11.1772175 , 11.177221 , 14.349099 , 14.924912 , 15.062427 , 15.064081 , 16.540335 , 16.54034 , 20.871534 , 20.871536 , 21.472364 , 28.740482 ],

dptb/tests/test_soc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,11 @@ def test_soc_json_band():
4444
device=model.device)
4545

4646
stru_data = f"{rootdir}/Sn/soc/dataset/Sn.vasp"
47-
AtomicData_options = {"r_max": 6.0, "oer_max":3.0, "pbc": True}
47+
AtomicData_options = {"r_max": 6.0, "oer_max":3.0}
4848

4949
eigenstatus = bcal.get_bands(data=stru_data,
5050
kpath_kwargs=kpath_kwargs,
51-
AtomicData_options=AtomicData_options)
51+
Atomic_options=AtomicData_options)
5252

5353
expected_eigenvalues = np.array([[-18.796585 , -18.796577 , -8.796718 , -8.796717 ,
5454
-8.467822 , -8.46782 , -8.202273 , -8.202273 ,

examples/hBN/band_plot.ipynb

Lines changed: 90 additions & 6 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)