Skip to content

Commit 07c9871

Browse files
authored
fix: add checks for basis existence in NNSK methods to prevent key errors (deepmodeling#257)
1 parent b666d17 commit 07c9871

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

dptb/nn/nnsk.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,8 @@ def to_json(self, version=2, basisref=None):
10211021
rev_line = self.idp_sk.transform_bond(jan, ian)
10221022
for orbpair, slices in self.idp_sk.orbpair_maps.items():
10231023
fiorb, fjorb = orbpair.split("-")
1024+
if fiorb not in self.idp_sk.full_basis_to_basis[iasym] or fjorb not in self.idp_sk.full_basis_to_basis[jasym]:
1025+
continue
10241026
iorb = self.idp_sk.full_basis_to_basis[iasym].get(fiorb)
10251027
jorb = self.idp_sk.full_basis_to_basis[jasym].get(fjorb)
10261028
if to_uniform:
@@ -1060,6 +1062,8 @@ def to_json(self, version=2, basisref=None):
10601062
rev_line = self.idp_sk.transform_bond(jan, ian)
10611063
for orbpair, slices in self.idp_sk.orbpair_maps.items():
10621064
fiorb, fjorb = orbpair.split("-")
1065+
if fiorb not in self.idp_sk.full_basis_to_basis[iasym] or fjorb not in self.idp_sk.full_basis_to_basis[jasym]:
1066+
continue
10631067
iorb = self.idp_sk.full_basis_to_basis[iasym].get(fiorb)
10641068
jorb = self.idp_sk.full_basis_to_basis[jasym].get(fjorb)
10651069
if to_uniform:
@@ -1093,6 +1097,8 @@ def to_json(self, version=2, basisref=None):
10931097
for asym in self.idp_sk.type_names:
10941098
for orbpair, slices in self.idp_sk.skonsite_maps.items():
10951099
fiorb, fjorb = orbpair.split("-")
1100+
if fiorb not in self.idp_sk.full_basis_to_basis[asym] or fjorb not in self.idp_sk.full_basis_to_basis[asym]:
1101+
continue
10961102
if fiorb != fjorb:
10971103
iorb = self.idp_sk.full_basis_to_basis[asym][fiorb]
10981104
jorb = self.idp_sk.full_basis_to_basis[asym][fjorb]
@@ -1112,6 +1118,8 @@ def to_json(self, version=2, basisref=None):
11121118
ian, jan = torch.tensor(atomic_num_dict[iasym]), torch.tensor(atomic_num_dict[jasym])
11131119
for orbpair, slices in self.idp_sk.orbpair_maps.items():
11141120
fiorb, fjorb = orbpair.split("-")
1121+
if fiorb not in self.idp_sk.full_basis_to_basis[iasym] or fjorb not in self.idp_sk.full_basis_to_basis[jasym]:
1122+
continue
11151123
iorb = self.idp_sk.full_basis_to_basis[iasym].get(fiorb)
11161124
jorb = self.idp_sk.full_basis_to_basis[jasym].get(fjorb)
11171125
if to_uniform:
@@ -1128,6 +1136,8 @@ def to_json(self, version=2, basisref=None):
11281136
for asym in self.idp_sk.type_names:
11291137
for orbpair, slices in self.idp_sk.skonsite_maps.items():
11301138
fiorb, fjorb = orbpair.split("-")
1139+
if fiorb not in self.idp_sk.full_basis_to_basis[asym] or fjorb not in self.idp_sk.full_basis_to_basis[asym]:
1140+
continue
11311141
iorb = self.idp_sk.full_basis_to_basis[asym][fiorb]
11321142
jorb = self.idp_sk.full_basis_to_basis[asym][fjorb]
11331143
if to_uniform:
@@ -1153,6 +1163,8 @@ def to_json(self, version=2, basisref=None):
11531163
soc_param = {}
11541164
for asym in self.idp_sk.type_names:
11551165
for fiorb, slices in self.idp_sk.sksoc_maps.items():
1166+
if fiorb not in self.idp_sk.full_basis_to_basis[asym]:
1167+
continue
11561168
iorb = self.idp_sk.full_basis_to_basis[asym][fiorb]
11571169
if to_uniform:
11581170
iorb = basisref[asym][iorb]

0 commit comments

Comments
 (0)