Skip to content

Commit 95739a2

Browse files
committed
Merge branch 'main' of https://github.com/QG-phy/DeePTB into hotfix_tojson
2 parents 9292b49 + f9c4d66 commit 95739a2

22 files changed

+580
-14
lines changed

dptb/nn/nnsk.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,8 @@ def to_json(self, version=2, basisref=None):
10211021
rev_line = self.idp_sk.transform_bond(jan, ian)
10221022
for orbpair, slices in self.idp_sk.orbpair_maps.items():
10231023
fiorb, fjorb = orbpair.split("-")
1024+
if fiorb not in self.idp_sk.full_basis_to_basis[iasym] or fjorb not in self.idp_sk.full_basis_to_basis[jasym]:
1025+
continue
10241026
iorb = self.idp_sk.full_basis_to_basis[iasym].get(fiorb)
10251027
jorb = self.idp_sk.full_basis_to_basis[jasym].get(fjorb)
10261028

@@ -1064,6 +1066,8 @@ def to_json(self, version=2, basisref=None):
10641066
rev_line = self.idp_sk.transform_bond(jan, ian)
10651067
for orbpair, slices in self.idp_sk.orbpair_maps.items():
10661068
fiorb, fjorb = orbpair.split("-")
1069+
if fiorb not in self.idp_sk.full_basis_to_basis[iasym] or fjorb not in self.idp_sk.full_basis_to_basis[jasym]:
1070+
continue
10671071
iorb = self.idp_sk.full_basis_to_basis[iasym].get(fiorb)
10681072
jorb = self.idp_sk.full_basis_to_basis[jasym].get(fjorb)
10691073

@@ -1167,6 +1171,8 @@ def to_json(self, version=2, basisref=None):
11671171
soc_param = {}
11681172
for asym in self.idp_sk.type_names:
11691173
for fiorb, slices in self.idp_sk.sksoc_maps.items():
1174+
if fiorb not in self.idp_sk.full_basis_to_basis[asym]:
1175+
continue
11701176
iorb = self.idp_sk.full_basis_to_basis[asym][fiorb]
11711177
if fiorb not in self.idp_sk.full_basis_to_basis[asym]:
11721178
continue

dptb/postprocess/bandstructure/band.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,12 @@ def get_bands(self, data: Union[AtomicData, ase.Atoms, str], kpath_kwargs: dict,
207207
else:
208208
log.error('Error, now, kline_type only support ase_kpath, abacus, or vasp.')
209209
raise ValueError
210+
211+
override_overlap = None
212+
if kpath_kwargs.get("override_overlap", None):
213+
override_overlap = kpath_kwargs["override_overlap"]
210214

211-
data, eigenvalues = self.get_eigs(data=data, klist=klist, pbc=pbc, AtomicData_options=AtomicData_options)
215+
data, eigenvalues = self.get_eigs(data=data, klist=klist, pbc=pbc, AtomicData_options=AtomicData_options, override_overlap=override_overlap)
212216

213217

214218
# get the E_fermi from data

dptb/postprocess/elec_struc_cal.py

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import h5py
13
import numpy as np
24
from ase.io import read
35
import ase
@@ -7,7 +9,7 @@
79
from typing import Optional
810
import logging
911
log = logging.getLogger(__name__)
10-
from dptb.data import AtomicData, AtomicDataDict
12+
from dptb.data import AtomicData, AtomicDataDict, block_to_feature
1113
from dptb.nn.energy import Eigenvalues
1214
from dptb.utils.argcheck import get_cutoffs_from_model_options
1315
from copy import deepcopy
@@ -66,7 +68,12 @@ def __init__ (
6668
)
6769
r_max, er_max, oer_max = get_cutoffs_from_model_options(model.model_options)
6870
self.cutoffs = {'r_max': r_max, 'er_max': er_max, 'oer_max': oer_max}
69-
def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=None, device: Union[str, torch.device]=None, AtomicData_options:dict=None):
71+
def get_data(self,
72+
data: Union[AtomicData, ase.Atoms, str],
73+
pbc:Union[bool,list]=None,
74+
device: Union[str, torch.device]=None,
75+
AtomicData_options:dict=None,
76+
override_overlap:Optional[str]=None):
7077
'''The function `get_data` takes input data in the form of a string, ase.Atoms object, or AtomicData
7178
object, processes it accordingly, and returns the AtomicData class.
7279
@@ -81,6 +88,7 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=N
8188
device : Union[str, torch.device]
8289
The `device` parameter in the `get_data` function is used to specify the device on which the data
8390
should be processed. If no device is provided, it defaults to `self.device`.
91+
override_overlap : the path for overlap.h5 to use and override overlap matrix from model.
8492
8593
Returns
8694
-------
@@ -130,7 +138,30 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=N
130138
data = data
131139
else:
132140
raise ValueError('data should be either a string, ase.Atoms, or AtomicData')
133-
141+
142+
if isinstance(override_overlap, str):
143+
assert os.path.exists(override_overlap), "Overlap file not found."
144+
overlap_blocks = h5py.File(override_overlap, "r")
145+
if len(overlap_blocks) != 1:
146+
log.info('Overlap file contains more than one overlap matrix, only first will be used.')
147+
if self.overlap:
148+
log.warning('override_overlap is enabled while model contains overlap, override_overlap will be used.')
149+
if "0" in overlap_blocks:
150+
overlaps = overlap_blocks["0"]
151+
else:
152+
overlaps = overlap_blocks["1"]
153+
block_to_feature(data, self.model.idp, blocks=False, overlap_blocks=overlaps)
154+
if not self.overlap:
155+
self.eigv = Eigenvalues(
156+
idp=self.model.idp,
157+
device=self.device,
158+
s_edge_field=AtomicDataDict.EDGE_OVERLAP_KEY,
159+
s_node_field=AtomicDataDict.NODE_OVERLAP_KEY,
160+
s_out_field=AtomicDataDict.OVERLAP_KEY,
161+
dtype=self.model.dtype,
162+
)
163+
overlap_blocks.close()
164+
134165
if device is None:
135166
device = self.device
136167
data = AtomicData.to_AtomicDataDict(data.to(device))
@@ -139,7 +170,12 @@ def get_data(self,data: Union[AtomicData, ase.Atoms, str],pbc:Union[bool,list]=N
139170
return data
140171

141172

142-
def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, pbc:Union[bool,list]=None, AtomicData_options:dict=None):
173+
def get_eigs(self,
174+
data: Union[AtomicData, ase.Atoms, str],
175+
klist: np.ndarray,
176+
pbc:Union[bool,list]=None,
177+
AtomicData_options:dict=None,
178+
override_overlap:Optional[str]=None):
143179
'''This function calculates eigenvalues for Hk at specified k-points.
144180
145181
Parameters
@@ -152,20 +188,27 @@ def get_eigs(self, data: Union[AtomicData, ase.Atoms, str], klist: np.ndarray, p
152188
AtomicData_options : dict
153189
The `AtomicData_options` parameter is a dictionary that contains options for configuring the
154190
`AtomicData` object.
191+
override_overlap : the path for overlap.h5 to use and override overlap matrix from model.
155192
156193
Returns
157194
-------
158195
The function `get_eigs` returns the loaded data and the energy eigenvalues as a numpy array.
159196
160197
'''
161198

162-
data = self.get_data(data=data, pbc=pbc, device=self.device,AtomicData_options=AtomicData_options)
199+
data = self.get_data(data=data, pbc=pbc, device=self.device,AtomicData_options=AtomicData_options, override_overlap=override_overlap)
163200
# set the kpoint of the AtomicData
164201
data[AtomicDataDict.KPOINT_KEY] = \
165202
torch.nested.as_nested_tensor([torch.as_tensor(klist, dtype=self.model.dtype, device=self.device)])
203+
if isinstance(override_overlap, str):
204+
override_overlap_edge = data[AtomicDataDict.EDGE_OVERLAP_KEY]
205+
override_overlap_node = data[AtomicDataDict.NODE_OVERLAP_KEY]
166206
# get the eigenvalues
167207
data = self.model(data)
168-
if self.overlap == True:
208+
if isinstance(override_overlap, str):
209+
data[AtomicDataDict.EDGE_OVERLAP_KEY] = override_overlap_edge
210+
data[AtomicDataDict.NODE_OVERLAP_KEY] = override_overlap_node
211+
if self.overlap or isinstance(override_overlap, str):
169212
assert data.get(AtomicDataDict.EDGE_OVERLAP_KEY) is not None
170213
data = self.eigv(data)
171214

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
14
2+
14
3+
14
4+
14
5+
14
6+
14
7+
14
8+
14
9+
14
10+
14
11+
14
12+
14
13+
14
14+
14
15+
14
16+
14
17+
14
18+
14
19+
14
20+
14
21+
14
22+
14
23+
14
24+
14
25+
14
26+
14
27+
14
28+
14
29+
14
30+
14
31+
14
32+
14
33+
14
34+
14
35+
14
36+
14
37+
14
38+
14
39+
14
40+
14
41+
14
42+
14
43+
14
44+
14
45+
14
46+
14
47+
14
48+
14
49+
14
50+
14
51+
14
52+
14
53+
14
54+
14
55+
14
56+
14
57+
14
58+
14
59+
14
60+
14
61+
14
62+
14
63+
14
64+
14
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{'Si': '1s1p'}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
1.088740444183349609e+01 0.000000000000000000e+00 0.000000000000000000e+00
2+
0.000000000000000000e+00 1.088740444183349609e+01 0.000000000000000000e+00
3+
0.000000000000000000e+00 0.000000000000000000e+00 1.088740444183349609e+01
Binary file not shown.
992 Bytes
Binary file not shown.
2.21 MB
Binary file not shown.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
1.000000000000000000e+00
2+
1.000000000000000000e+00
3+
1.000000000000000000e+00

0 commit comments

Comments
 (0)