99log = logging .getLogger (__name__ )
1010from dptb .data import AtomicData , AtomicDataDict
1111from dptb .nn .energy import Eigenvalues
12+ from dptb .utils .argcheck import get_cutoffs_from_model_options
13+ from copy import deepcopy
1214
1315# This class `ElecStruCal` is designed to calculate electronic structure properties such as
1416# eigenvalues and Fermi energy based on provided input data and model.
@@ -61,8 +63,9 @@ def __init__ (
6163 device = self .device ,
6264 dtype = model .dtype ,
6365 )
64-
65- def get_data (self ,data : Union [AtomicData , ase .Atoms , str ],AtomicData_options : dict = {},device : Union [str , torch .device ]= None ):
66+ r_max , er_max , oer_max = get_cutoffs_from_model_options (model .model_options )
67+ self .cutoffs = {'r_max' : r_max , 'er_max' : er_max , 'oer_max' : oer_max }
68+ def get_data (self ,data : Union [AtomicData , ase .Atoms , str ],pbc :Union [bool ,list ]= None , device : Union [str , torch .device ]= None , Atomic_options :dict = None ):
6669 '''The function `get_data` takes input data in the form of a string, ase.Atoms object, or AtomicData
6770 object, processes it accordingly, and returns the AtomicData class.
6871
@@ -83,15 +86,36 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: di
8386 the loaded AtomicData object.
8487
8588 '''
89+ atomic_options = deepcopy (self .cutoffs )
90+ if pbc is not None :
91+ atomic_options .update ({'pbc' : pbc })
92+
93+ if Atomic_options is not None :
94+ if Atomic_options .get ('r_max' , None ) is not None :
95+ if atomic_options ['r_max' ] != Atomic_options .get ('r_max' ):
96+ atomic_options ['r_max' ] = Atomic_options .get ('r_max' )
97+ log .warning (f'Overwrite the r_max setting in the model with the r_max setting in the Atomic_options: { Atomic_options .get ("r_max" )} ' )
98+ log .warning (f'This is very dangerous, please make sure you know what you are doing.' )
99+ if Atomic_options .get ('er_max' , None ) is not None :
100+ if atomic_options ['er_max' ] != Atomic_options .get ('er_max' ):
101+ atomic_options ['er_max' ] = Atomic_options .get ('er_max' )
102+ log .warning (f'Overwrite the er_max setting in the model with the er_max setting in the Atomic_options: { Atomic_options .get ("er_max" )} ' )
103+ log .warning (f'This is very dangerous, please make sure you know what you are doing.' )
104+ if Atomic_options .get ('oer_max' , None ) is not None :
105+ if atomic_options ['oer_max' ] != Atomic_options .get ('oer_max' ):
106+ atomic_options ['oer_max' ] = Atomic_options .get ('oer_max' )
107+ log .warning (f'Overwrite the oer_max setting in the model with the oer_max setting in the Atomic_options: { Atomic_options .get ("oer_max" )} ' )
108+ log .warning (f'This is very dangerous, please make sure you know what you are doing.' )
86109
87110 if isinstance (data , str ):
88111 structase = read (data )
89- data = AtomicData .from_ase (structase , ** AtomicData_options )
112+ data = AtomicData .from_ase (structase , ** atomic_options )
90113 elif isinstance (data , ase .Atoms ):
91114 structase = data
92- data = AtomicData .from_ase (structase , ** AtomicData_options )
115+ data = AtomicData .from_ase (structase , ** atomic_options )
93116 elif isinstance (data , AtomicData ):
94117 # structase = data.to("cpu").to_ase()
118+ log .info ('The data is already an instance of AtomicData. Then the data is used directly.' )
95119 data = data
96120 else :
97121 raise ValueError ('data should be either a string, ase.Atoms, or AtomicData' )
@@ -104,7 +128,7 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],AtomicData_options: di
104128 return data
105129
106130
107- def get_eigs (self , data : Union [AtomicData , ase .Atoms , str ], klist : np .ndarray , AtomicData_options : dict = {} ):
131+ def get_eigs (self , data : Union [AtomicData , ase .Atoms , str ], klist : np .ndarray , pbc : Union [ bool , list ] = None , Atomic_options : dict = None ):
108132 '''This function calculates eigenvalues for Hk at specified k-points.
109133
110134 Parameters
@@ -124,7 +148,7 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, A
124148
125149 '''
126150
127- data = self .get_data (data = data , AtomicData_options = AtomicData_options , device = self .device )
151+ data = self .get_data (data = data , pbc = pbc , device = self .device , Atomic_options = Atomic_options )
128152 # set the kpoint of the AtomicData
129153 data [AtomicDataDict .KPOINT_KEY ] = \
130154 torch .nested .as_nested_tensor ([torch .as_tensor (klist , dtype = self .model .dtype , device = self .device )])
@@ -137,7 +161,7 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, A
137161 return data , data [AtomicDataDict .ENERGY_EIGENVALUE_KEY ][0 ].detach ().cpu ().numpy ()
138162
139163 def get_fermi_level (self , data : Union [AtomicData , ase .Atoms , str ], nel_atom : dict , \
140- meshgrid : list = None , klist : np .ndarray = None , AtomicData_options : dict = {} ):
164+ meshgrid : list = None , klist : np .ndarray = None , pbc : Union [ bool , list ] = None , Atomic_options : dict = None ):
141165 '''This function calculates the Fermi level based on provided data with iteration method, electron counts per atom, and
142166 optional parameters like specific k-points and eigenvalues.
143167
@@ -188,7 +212,7 @@ def get_fermi_level(self, data: Union[AtomicData, ase.Atoms, str], nel_atom: dic
188212
189213 # eigenvalues would be used if provided, otherwise the eigenvalues would be calculated from the model on the specified k-points
190214 if not AtomicDataDict .ENERGY_EIGENVALUE_KEY in data :
191- data , eigs = self .get_eigs (data = data , klist = klist , AtomicData_options = AtomicData_options )
215+ data , eigs = self .get_eigs (data = data , klist = klist , pbc = pbc , Atomic_options = Atomic_options )
192216 log .info ('Getting eigenvalues from the model.' )
193217 else :
194218 log .info ('The eigenvalues are already in data. will use them.' )
0 commit comments