
Commit 09f1a39

drop E3statistics when model is valid and does not have a scale shift
1 parent 39d4d4a commit 09f1a39

File tree

3 files changed: +13 additions, -5 deletions


dptb/data/dataset/_default_dataset.py

Lines changed: 4 additions & 0 deletions
@@ -401,6 +401,10 @@ def E3statistics(self, model: torch.nn.Module=None, decay=False):
 
         if self.data[AtomicDataDict.EDGE_FEATURES_KEY].abs().sum() < 1e-7:
             return None
+
+        if model is not None:
+            if not isinstance(model.node_prediction_h, torch.nn.Module):
+                return None
 
         typed_dataset = idp(self.data.clone().to_dict())
         e3h = E3Hamiltonian(basis=idp.basis, decompose=True)
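The new guard short-circuits E3statistics when the caller passes a model whose node_prediction_h head is not a torch.nn.Module, i.e. there is no learnable scale/shift to seed from the dataset statistics. A minimal sketch of that behavior, using a hypothetical stand-in model class (not part of dptb):

```python
import torch

class DummyModel:
    """Hypothetical stand-in: its node prediction head is a plain
    callable rather than a torch.nn.Module, i.e. no scale/shift."""
    node_prediction_h = staticmethod(lambda x: x)

def e3statistics_guard(model=None):
    # Mirrors the guard added in this commit: bail out before doing
    # any statistics work when there is no nn.Module head to seed.
    if model is not None:
        if not isinstance(model.node_prediction_h, torch.nn.Module):
            return None
    return "compute statistics"

assert e3statistics_guard(DummyModel()) is None        # guard fires
assert e3statistics_guard() == "compute statistics"    # no model: proceed
```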

dptb/data/dataset/lmdb_dataset.py

Lines changed: 4 additions & 0 deletions
@@ -158,6 +158,10 @@ def E3statistics(self, model: torch.nn.Module=None):
 
         if not self.get_Hamiltonian and not self.get_DM:
             return None
+
+        if model is not None:
+            if not isinstance(model.node_prediction_h, torch.nn.Module):
+                return None
 
         assert self.transform is not None
         idp = self.transform
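The LMDB-backed dataset receives the identical guard, keeping the two E3statistics implementations in step: when a model is passed but its node_prediction_h head is not an nn.Module, there is no scale/shift to initialize, so the whole statistics pass is skipped by returning None (see the sketch after the first file above).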

dptb/nnops/loss.py

Lines changed: 5 additions & 5 deletions
@@ -364,7 +364,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         elif batch.max() >= 1:
             slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
             slices = [0] + slices
-            ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+            ndiag_batch = torch.IntTensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], device=self.device)
             ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
             mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
             ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
@@ -456,7 +456,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         elif batch.max() >= 1:
             slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
             slices = [0] + slices
-            ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+            ndiag_batch = torch.IntTensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], device=self.device)
             ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
             mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
             ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
@@ -536,7 +536,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         elif batch.max() >= 1:
             slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
             slices = [0] + slices
-            ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+            ndiag_batch = torch.IntTensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], device=self.device)
             ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
             mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
             ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
@@ -682,7 +682,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         elif batch.max() >= 1:
             slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
             slices = [0] + slices
-            ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+            ndiag_batch = torch.IntTensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], device=self.device)
             ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
             mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
             ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
@@ -811,7 +811,7 @@ def __call__(self, data: AtomicDataDict, ref_data: AtomicDataDict, running_avg:
         elif batch.max() >= 1:
             slices = [data["__slices__"]["pos"][i]-data["__slices__"]["pos"][i-1] for i in range(1,len(data["__slices__"]["pos"]))]
             slices = [0] + slices
-            ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+            ndiag_batch = torch.IntTensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], device=self.device)
             ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
             mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
             ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
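All five hunks apply the same one-line fix. The comprehension yields plain Python ints (i.shape[0]), and torch.stack only accepts a sequence of tensors, so the old line raised a TypeError whenever the batch held more than one structure; building the tensor directly avoids that and places it on the loss's device. Below is a minimal sketch of the per-structure mean reduction with illustrative counts; torch.tensor stands in for the legacy torch.IntTensor constructor used in the diff, since the legacy constructors build CPU tensors only:

```python
import torch

device = torch.device("cpu")  # assumption: stands in for self.device
counts = [0, 3, 2, 4]         # leading 0 mirrors `slices = [0] + slices`
mu = torch.arange(9, dtype=torch.float32, device=device)  # flat per-entry means

# torch.stack(counts) would raise TypeError (ints, not tensors);
# constructing the tensor directly is the commit's fix.
ndiag_batch = torch.tensor(counts, dtype=torch.int64, device=device)
ndiag_batch = torch.cumsum(ndiag_batch, dim=0)  # tensor([0, 3, 5, 9])

# Per-structure mean, as in loss.py: structure i owns mu[start:stop].
mu_batch = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i + 1]].mean()
                        for i in range(len(ndiag_batch) - 1)])
print(mu_batch)  # tensor([1.0000, 3.5000, 6.5000])
```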
