1
+ import os
2
+ import h5py
1
3
import numpy as np
2
4
from ase .io import read
3
5
import ase
7
9
from typing import Optional
8
10
import logging
9
11
log = logging .getLogger (__name__ )
10
- from dptb .data import AtomicData , AtomicDataDict
12
+ from dptb .data import AtomicData , AtomicDataDict , block_to_feature
11
13
from dptb .nn .energy import Eigenvalues
12
14
from dptb .utils .argcheck import get_cutoffs_from_model_options
13
15
from copy import deepcopy
@@ -66,7 +68,12 @@ def __init__ (
66
68
)
67
69
r_max , er_max , oer_max = get_cutoffs_from_model_options (model .model_options )
68
70
self .cutoffs = {'r_max' : r_max , 'er_max' : er_max , 'oer_max' : oer_max }
69
- def get_data (self ,data : Union [AtomicData , ase .Atoms , str ],pbc :Union [bool ,list ]= None , device : Union [str , torch .device ]= None , AtomicData_options :dict = None ):
71
+ def get_data (self ,
72
+ data : Union [AtomicData , ase .Atoms , str ],
73
+ pbc :Union [bool ,list ]= None ,
74
+ device : Union [str , torch .device ]= None ,
75
+ AtomicData_options :dict = None ,
76
+ override_overlap :Optional [str ]= None ):
70
77
'''The function `get_data` takes input data in the form of a string, ase.Atoms object, or AtomicData
71
78
object, processes it accordingly, and returns the AtomicData class.
72
79
@@ -81,6 +88,7 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=N
81
88
device : Union[str, torch.device]
82
89
The `device` parameter in the `get_data` function is used to specify the device on which the data
83
90
should be processed. If no device is provided, it defaults to `self.device`.
91
+ override_overlap : the path for overlap.h5 to use and override overlap matrix from model.
84
92
85
93
Returns
86
94
-------
@@ -130,7 +138,30 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=N
130
138
data = data
131
139
else :
132
140
raise ValueError ('data should be either a string, ase.Atoms, or AtomicData' )
133
-
141
+
142
+ if isinstance (override_overlap , str ):
143
+ assert os .path .exists (override_overlap ), "Overlap file not found."
144
+ overlap_blocks = h5py .File (override_overlap , "r" )
145
+ if len (overlap_blocks ) != 1 :
146
+ log .info ('Overlap file contains more than one overlap matrix, only first will be used.' )
147
+ if self .overlap :
148
+ log .warning ('override_overlap is enabled while model contains overlap, override_overlap will be used.' )
149
+ if "0" in overlap_blocks :
150
+ overlaps = overlap_blocks ["0" ]
151
+ else :
152
+ overlaps = overlap_blocks ["1" ]
153
+ block_to_feature (data , self .model .idp , blocks = False , overlap_blocks = overlaps )
154
+ if not self .overlap :
155
+ self .eigv = Eigenvalues (
156
+ idp = self .model .idp ,
157
+ device = self .device ,
158
+ s_edge_field = AtomicDataDict .EDGE_OVERLAP_KEY ,
159
+ s_node_field = AtomicDataDict .NODE_OVERLAP_KEY ,
160
+ s_out_field = AtomicDataDict .OVERLAP_KEY ,
161
+ dtype = self .model .dtype ,
162
+ )
163
+ overlap_blocks .close ()
164
+
134
165
if device is None :
135
166
device = self .device
136
167
data = AtomicData .to_AtomicDataDict (data .to (device ))
@@ -139,7 +170,12 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=N
139
170
return data
140
171
141
172
142
- def get_eigs (self , data : Union [AtomicData , ase .Atoms , str ], klist : np .ndarray , pbc :Union [bool ,list ]= None , AtomicData_options :dict = None ):
173
+ def get_eigs (self ,
174
+ data : Union [AtomicData , ase .Atoms , str ],
175
+ klist : np .ndarray ,
176
+ pbc :Union [bool ,list ]= None ,
177
+ AtomicData_options :dict = None ,
178
+ override_overlap :Optional [str ]= None ):
143
179
'''This function calculates eigenvalues for Hk at specified k-points.
144
180
145
181
Parameters
@@ -152,20 +188,27 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, p
152
188
AtomicData_options : dict
153
189
The `AtomicData_options` parameter is a dictionary that contains options for configuring the
154
190
`AtomicData` object.
191
+ override_overlap : the path for overlap.h5 to use and override overlap matrix from model.
155
192
156
193
Returns
157
194
-------
158
195
The function `get_eigs` returns the loaded data and the energy eigenvalues as a numpy array.
159
196
160
197
'''
161
198
162
- data = self .get_data (data = data , pbc = pbc , device = self .device ,AtomicData_options = AtomicData_options )
199
+ data = self .get_data (data = data , pbc = pbc , device = self .device ,AtomicData_options = AtomicData_options , override_overlap = override_overlap )
163
200
# set the kpoint of the AtomicData
164
201
data [AtomicDataDict .KPOINT_KEY ] = \
165
202
torch .nested .as_nested_tensor ([torch .as_tensor (klist , dtype = self .model .dtype , device = self .device )])
203
+ if isinstance (override_overlap , str ):
204
+ override_overlap_edge = data [AtomicDataDict .EDGE_OVERLAP_KEY ]
205
+ override_overlap_node = data [AtomicDataDict .NODE_OVERLAP_KEY ]
166
206
# get the eigenvalues
167
207
data = self .model (data )
168
- if self .overlap == True :
208
+ if isinstance (override_overlap , str ):
209
+ data [AtomicDataDict .EDGE_OVERLAP_KEY ] = override_overlap_edge
210
+ data [AtomicDataDict .NODE_OVERLAP_KEY ] = override_overlap_node
211
+ if self .overlap or isinstance (override_overlap , str ):
169
212
assert data .get (AtomicDataDict .EDGE_OVERLAP_KEY ) is not None
170
213
data = self .eigv (data )
171
214
0 commit comments