
Commit 39d4d4a

compute onsite shift mu from whole onsite block
1 parent 36b04d7


4 files changed, +58 -24 lines


dptb/nn/base.py

Lines changed: 8 additions & 1 deletion
@@ -6,7 +6,10 @@
 from dptb.utils.constants import dtype_dict
 from dptb.utils.tools import _get_activation_fn
 from e3nn.util.codegen import CodeGenMixin
+from e3nn.math import normalize2mom
 import torch.nn.functional as F
+import math
+from torch import fx
 import torch.nn as nn
 
 class AtomicLinear(torch.nn.Module):

@@ -574,4 +577,8 @@ def Proxy(n):
         self._codegen_register({"_forward": fx.GraphModule(base, graph)})
 
     def forward(self, x):
-        return self._forward(x)
+        return self._forward(x)
+
+@torch.jit.script
+def ShiftedSoftPlus(x: torch.Tensor):
+    return torch.nn.functional.softplus(x) - math.log(2.0)
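The ShiftedSoftPlus registered above is the standard shifted softplus activation: since softplus(0) = log 2, subtracting math.log(2.0) makes the function pass exactly through the origin, so it behaves like a smooth ReLU with no constant offset. A minimal standalone check of that property (the sample values below are illustrative only):

import math
import torch
import torch.nn.functional as F

@torch.jit.script
def shifted_softplus(x: torch.Tensor) -> torch.Tensor:
    # softplus(x) - log(2): zero at x = 0, ~x for large x, -> -log(2) as x -> -inf
    return F.softplus(x) - math.log(2.0)

x = torch.tensor([-2.0, 0.0, 2.0])
print(shifted_softplus(x))  # tensor([-0.5662, 0.0000, 1.4338])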

dptb/nn/embedding/lem.py

Lines changed: 0 additions & 2 deletions
@@ -8,8 +8,6 @@
 from e3nn.nn import Gate
 from torch_scatter import scatter_mean
 from e3nn.o3 import Linear, SphericalHarmonics
-from e3nn.math import normalize2mom
-from e3nn.util.jit import compile_mode
 from dptb.data import AtomicDataDict
 from dptb.nn.embedding.emb import Embedding
 from ..radial_basis import BesselBasis

dptb/nn/embedding/slem.py

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@
 from e3nn.nn import Gate
 from torch_scatter import scatter_mean
 from e3nn.o3 import Linear, SphericalHarmonics
-from e3nn.math import normalize2mom
 from e3nn.util.jit import compile_mode
 from dptb.data import AtomicDataDict
 from dptb.nn.embedding.emb import Embedding

dptb/nnops/loss.py

Lines changed: 50 additions & 20 deletions
@@ -352,18 +352,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when batchsize is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still taking the mean across the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
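The rewritten block above is the substance of this commit. Previously mu was the plain mean of the differences of the diagonal onsite elements (selected through self.idp.mask_to_ndiag). Now the shift is fit from the whole onsite block: a rigid potential shift mu moves the Hamiltonian by mu * S in a non-orthogonal basis, so the best single shift in the least-squares sense, argmin_mu ||dH - mu * S||^2, is mu = sum(dH * S) / sum(S * S), which is exactly the numerator accumulated above divided by the new (S * S) normalizer. A self-contained sketch of that fit (function name and shapes are illustrative, not from the repo):

import torch

def fit_onsite_shift(H_pred: torch.Tensor, H_ref: torch.Tensor, S_ref: torch.Tensor) -> torch.Tensor:
    # Least-squares fit of a single shift mu with H_pred ~ H_ref + mu * S_ref
    # over all flattened onsite entries: mu* = sum(dH * S) / sum(S * S).
    dH = H_pred - H_ref
    mu = (dH * S_ref).sum() / (S_ref * S_ref).sum()
    return mu.detach()  # detached, as in the loss code: the shift aligns references, it carries no gradient

# sanity check: a known rigid shift is recovered exactly
H_ref = torch.randn(8, 13)   # e.g. 8 atoms, 13 flattened onsite entries each
S_ref = torch.randn(8, 13)
H_pred = H_ref + 0.37 * S_ref
print(fit_onsite_shift(H_pred, H_ref, S_ref))  # tensor(0.3700)

The same hunk is applied verbatim to the remaining four loss terms below.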
@@ -438,18 +444,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
            if batch.max() == 0:  # when batchsize is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still taking the mean across the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
@@ -512,18 +524,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when batchsize is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still taking the mean across the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
@@ -652,18 +670,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when batchsize is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still taking the mean across the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
@@ -775,18 +799,24 @@ def __call__(self, data: AtomicDataDict, ref_data: AtomicDataDict, running_avg:
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when batchsize is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still taking the mean across the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
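In the batched branch, the same fit is carried out per frame: the per-atom numerators and the per-atom sum(S * S) values are averaged within each frame's slice of atoms (delimited by the cumulative ndiag_batch boundaries) before dividing. Since torch_scatter is already a dependency of this repo, a hypothetical, more compact per-frame reduction could look like the sketch below; this is an illustration only, not the commit's implementation, which uses the explicit slicing shown above:

import torch
from torch_scatter import scatter_mean

def per_frame_shift(dHS: torch.Tensor, SS: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    # dHS:   per-atom sum(dH * S), shape [natoms]
    # SS:    per-atom sum(S * S),  shape [natoms]
    # batch: frame index of each atom, shape [natoms]
    mu = scatter_mean(dHS, batch) / scatter_mean(SS, batch)  # one shift per frame
    return mu.detach()

batch = torch.tensor([0, 0, 0, 1, 1])
dHS = torch.tensor([0.2, 0.4, 0.6, 1.0, 3.0])
SS = torch.ones(5)
print(per_frame_shift(dHS, SS, batch))  # tensor([0.4000, 2.0000])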
