@@ -352,18 +352,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when the batch size is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still take the mean over the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i-1] for i in range(1, len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
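
A minimal, self-contained sketch of what the new `+` lines compute for a single structure (the batch.max() == 0 branch): the onsite shift mu is the overlap-weighted residual between predicted and reference node blocks, normalized by the squared norm of the reference overlap, and the reference node and edge blocks are then shifted by mu times the overlap. The tensor names below are illustrative toys, not the repository's AtomicDataDict keys, and the normalization follows the patch only up to this simplification.

import torch

# Toy stand-ins for the flattened per-atom blocks used in the patch
# (NODE_FEATURES_KEY / NODE_OVERLAP_KEY), shape [natoms, nfeat].
natoms, nfeat = 4, 9
h_pred = torch.randn(natoms, nfeat)   # predicted onsite blocks
h_ref  = torch.randn(natoms, nfeat)   # reference onsite blocks
s_ref  = torch.randn(natoms, nfeat)   # reference overlap blocks

# Overlap-weighted residual, summed over the feature axis -> one value per atom.
mu = ((h_pred - h_ref) * s_ref).sum(dim=-1)            # [natoms]

# Normalize by <S_ref, S_ref> (averaged over atoms) and average over atoms;
# detach so the shift acts as a constant offset on the reference.
mu = mu / (s_ref * s_ref).sum(dim=-1).mean()
mu = mu.mean().detach()                                 # scalar shift

# Shift the reference blocks: H_ref <- H_ref + mu * S_ref.
h_ref_shifted = h_ref + mu * s_ref
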
@@ -438,18 +444,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when the batch size is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still take the mean over the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i-1] for i in range(1, len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
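
For the batched branch (batch.max() >= 1), the patch replaces the boolean-mask sum with i.shape[0] so that ndiag_batch holds per-structure atom counts, takes per-structure means of mu between the cumulative offsets, divides by a per-structure overlap normalization ss, and broadcasts the result back onto atoms with mu[batch, None]. The sketch below reproduces that bookkeeping on a toy two-structure batch; the names are illustrative only, and the normalization here uses the per-atom <S_ref, S_ref> mean rather than the patch's exact ss expression.

import torch

# Toy batch of two structures with 2 and 3 atoms.
natoms_per_structure = [2, 3]
batch = torch.tensor([0, 0, 1, 1, 1])        # structure index of each atom
nfeat = 9
ntot = sum(natoms_per_structure)
h_pred = torch.randn(ntot, nfeat)
h_ref  = torch.randn(ntot, nfeat)
s_ref  = torch.randn(ntot, nfeat)

mu = ((h_pred - h_ref) * s_ref).sum(dim=-1)  # [natoms]
ss = (s_ref * s_ref).sum(dim=-1)             # [natoms]

# Cumulative atom offsets play the role of ndiag_batch in the patch.
offsets = torch.cumsum(torch.tensor([0] + natoms_per_structure), dim=0)
mu = torch.stack([mu[offsets[i]:offsets[i + 1]].mean() for i in range(len(offsets) - 1)])
ss = torch.stack([ss[offsets[i]:offsets[i + 1]].mean() for i in range(len(offsets) - 1)])
mu = (mu / ss).detach()                      # one shift per structure

# Broadcast the per-structure shift back onto atoms, as mu[batch, None] does.
h_ref_shifted = h_ref + mu[batch, None] * s_ref
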
@@ -512,18 +524,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when the batch size is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still take the mean over the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i-1] for i in range(1, len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
@@ -652,18 +670,24 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when the batch size is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still take the mean over the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i-1] for i in range(1, len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)
@@ -775,18 +799,24 @@ def __call__(self, data: AtomicDataDict, ref_data: AtomicDataDict, running_avg:
         if self.onsite_shift:
             batch = data.get("batch", torch.zeros(data[AtomicDataDict.POSITIONS_KEY].shape[0]))
             # assert batch.max() == 0, "The onsite shift is only supported for batchsize=1."
-            mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
-                ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            # mu = data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]] - \
+            #     ref_data[AtomicDataDict.NODE_FEATURES_KEY][self.idp.mask_to_ndiag[ref_data[AtomicDataDict.ATOM_TYPE_KEY].flatten()]]
+            mu = (data[AtomicDataDict.NODE_FEATURES_KEY] - ref_data[AtomicDataDict.NODE_FEATURES_KEY]) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
+            mu = mu.sum(dim=-1)  # [natoms]
             if batch.max() == 0:  # when the batch size is one
-                mu = mu.mean().detach()
+                mu = mu / (ref_data[AtomicDataDict.NODE_OVERLAP_KEY] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1).mean()
+                mu = mu.mean().detach()  # still take the mean over the atom dimension to avoid overflow
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 ref_data[AtomicDataDict.EDGE_FEATURES_KEY] = ref_data[AtomicDataDict.EDGE_FEATURES_KEY] + mu * ref_data[AtomicDataDict.EDGE_OVERLAP_KEY]
             elif batch.max() >= 1:
                 slices = [data["__slices__"]["pos"][i] - data["__slices__"]["pos"][i-1] for i in range(1, len(data["__slices__"]["pos"]))]
                 slices = [0] + slices
-                ndiag_batch = torch.stack([i.sum() for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
+                ndiag_batch = torch.stack([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)])
                 ndiag_batch = torch.cumsum(ndiag_batch, dim=0)
                 mu = torch.stack([mu[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                ss = (ref_data[AtomicDataDict.NODE_OVERLAP_KEY].sum(dim=-1) * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]).sum(dim=-1)
+                ss = torch.stack([ss[ndiag_batch[i]:ndiag_batch[i+1]].mean() for i in range(len(ndiag_batch)-1)])
+                mu = mu / ss
                 mu = mu.detach()
                 ref_data[AtomicDataDict.NODE_FEATURES_KEY] = ref_data[AtomicDataDict.NODE_FEATURES_KEY] + mu[batch, None] * ref_data[AtomicDataDict.NODE_OVERLAP_KEY]
                 edge_mu_index = torch.zeros(data[AtomicDataDict.EDGE_INDEX_KEY].shape[1], dtype=torch.long, device=self.device)