9
9
log = logging .getLogger (__name__ )
10
10
from dptb .data import AtomicData , AtomicDataDict
11
11
from dptb .nn .energy import Eigenvalues
12
+ from dptb .utils .argcheck import get_cutoffs_from_model_options
13
+ from copy import deepcopy
12
14
13
15
# This class `ElecStruCal` is designed to calculate electronic structure properties such as
14
16
# eigenvalues and Fermi energy based on provided input data and model.
@@ -61,8 +63,9 @@ def __init__ (
61
63
device = self .device ,
62
64
dtype = model .dtype ,
63
65
)
64
-
65
- def get_data (self ,data : Union [AtomicData , ase .Atoms , str ],AtomicData_options : dict = {},device : Union [str , torch .device ]= None ):
66
+ r_max , er_max , oer_max = get_cutoffs_from_model_options (model .model_options )
67
+ self .cutoffs = {'r_max' : r_max , 'er_max' : er_max , 'oer_max' : oer_max }
68
+ def get_data (self ,data : Union [AtomicData , ase .Atoms , str ],pbc :Union [bool ,list ]= None , device : Union [str , torch .device ]= None , Atomic_options :dict = None ):
66
69
'''The function `get_data` takes input data in the form of a string, ase.Atoms object, or AtomicData
67
70
object, processes it accordingly, and returns the AtomicData class.
68
71
@@ -83,15 +86,36 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: di
83
86
the loaded AtomicData object.
84
87
85
88
'''
89
+ atomic_options = deepcopy (self .cutoffs )
90
+ if pbc is not None :
91
+ atomic_options .update ({'pbc' : pbc })
92
+
93
+ if Atomic_options is not None :
94
+ if Atomic_options .get ('r_max' , None ) is not None :
95
+ if atomic_options ['r_max' ] != Atomic_options .get ('r_max' ):
96
+ atomic_options ['r_max' ] = Atomic_options .get ('r_max' )
97
+ log .warning (f'Overwrite the r_max setting in the model with the r_max setting in the Atomic_options: { Atomic_options .get ("r_max" )} ' )
98
+ log .warning (f'This is very dangerous, please make sure you know what you are doing.' )
99
+ if Atomic_options .get ('er_max' , None ) is not None :
100
+ if atomic_options ['er_max' ] != Atomic_options .get ('er_max' ):
101
+ atomic_options ['er_max' ] = Atomic_options .get ('er_max' )
102
+ log .warning (f'Overwrite the er_max setting in the model with the er_max setting in the Atomic_options: { Atomic_options .get ("er_max" )} ' )
103
+ log .warning (f'This is very dangerous, please make sure you know what you are doing.' )
104
+ if Atomic_options .get ('oer_max' , None ) is not None :
105
+ if atomic_options ['oer_max' ] != Atomic_options .get ('oer_max' ):
106
+ atomic_options ['oer_max' ] = Atomic_options .get ('oer_max' )
107
+ log .warning (f'Overwrite the oer_max setting in the model with the oer_max setting in the Atomic_options: { Atomic_options .get ("oer_max" )} ' )
108
+ log .warning (f'This is very dangerous, please make sure you know what you are doing.' )
86
109
87
110
if isinstance (data , str ):
88
111
structase = read (data )
89
- data = AtomicData .from_ase (structase , ** AtomicData_options )
112
+ data = AtomicData .from_ase (structase , ** atomic_options )
90
113
elif isinstance (data , ase .Atoms ):
91
114
structase = data
92
- data = AtomicData .from_ase (structase , ** AtomicData_options )
115
+ data = AtomicData .from_ase (structase , ** atomic_options )
93
116
elif isinstance (data , AtomicData ):
94
117
# structase = data.to("cpu").to_ase()
118
+ log .info ('The data is already an instance of AtomicData. Then the data is used directly.' )
95
119
data = data
96
120
else :
97
121
raise ValueError ('data should be either a string, ase.Atoms, or AtomicData' )
@@ -104,7 +128,7 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: di
104
128
return data
105
129
106
130
107
- def get_eigs (self , data : Union [AtomicData , ase .Atoms , str ], klist : np .ndarray , AtomicData_options : dict = {} ):
131
+ def get_eigs (self , data : Union [AtomicData , ase .Atoms , str ], klist : np .ndarray , pbc : Union [ bool , list ] = None , Atomic_options : dict = None ):
108
132
'''This function calculates eigenvalues for Hk at specified k-points.
109
133
110
134
Parameters
@@ -124,7 +148,7 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, A
124
148
125
149
'''
126
150
127
- data = self .get_data (data = data , AtomicData_options = AtomicData_options , device = self .device )
151
+ data = self .get_data (data = data , pbc = pbc , device = self .device , Atomic_options = Atomic_options )
128
152
# set the kpoint of the AtomicData
129
153
data [AtomicDataDict .KPOINT_KEY ] = \
130
154
torch .nested .as_nested_tensor ([torch .as_tensor (klist , dtype = self .model .dtype , device = self .device )])
@@ -137,7 +161,7 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, A
137
161
return data , data [AtomicDataDict .ENERGY_EIGENVALUE_KEY ][0 ].detach ().cpu ().numpy ()
138
162
139
163
def get_fermi_level (self , data : Union [AtomicData , ase .Atoms , str ], nel_atom : dict , \
140
- meshgrid : list = None , klist : np .ndarray = None , AtomicData_options : dict = {} ):
164
+ meshgrid : list = None , klist : np .ndarray = None , pbc : Union [ bool , list ] = None , Atomic_options : dict = None ):
141
165
'''This function calculates the Fermi level based on provided data with iteration method, electron counts per atom, and
142
166
optional parameters like specific k-points and eigenvalues.
143
167
@@ -188,7 +212,7 @@ def get_fermi_level(self, data: Union[AtomicData, ase.Atoms, str], nel_atom: dic
188
212
189
213
# eigenvalues would be used if provided, otherwise the eigenvalues would be calculated from the model on the specified k-points
190
214
if not AtomicDataDict .ENERGY_EIGENVALUE_KEY in data :
191
- data , eigs = self .get_eigs (data = data , klist = klist , AtomicData_options = AtomicData_options )
215
+ data , eigs = self .get_eigs (data = data , klist = klist , pbc = pbc , Atomic_options = Atomic_options )
192
216
log .info ('Getting eigenvalues from the model.' )
193
217
else :
194
218
log .info ('The eigenvalues are already in data. will use them.' )
0 commit comments