@@ -364,7 +364,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
364
364
elif batch .max () >= 1 :
365
365
slices = [data ["__slices__" ]["pos" ][i ]- data ["__slices__" ]["pos" ][i - 1 ] for i in range (1 ,len (data ["__slices__" ]["pos" ]))]
366
366
slices = [0 ] + slices
367
- ndiag_batch = torch .stack ([i .shape [0 ] for i in self .idp .mask_to_ndiag [data [AtomicDataDict .ATOM_TYPE_KEY ].flatten ()].split (slices )])
367
+ ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], dtype=torch.long, device=self.device)
368
368
ndiag_batch = torch .cumsum (ndiag_batch , dim = 0 )
369
369
mu = torch .stack ([mu [ndiag_batch [i ]:ndiag_batch [i + 1 ]].mean () for i in range (len (ndiag_batch )- 1 )])
370
370
ss = (ref_data [AtomicDataDict .NODE_OVERLAP_KEY ].sum (dim = - 1 ) * ref_data [AtomicDataDict .NODE_OVERLAP_KEY ]).sum (dim = - 1 )
@@ -456,7 +456,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
456
456
elif batch .max () >= 1 :
457
457
slices = [data ["__slices__" ]["pos" ][i ]- data ["__slices__" ]["pos" ][i - 1 ] for i in range (1 ,len (data ["__slices__" ]["pos" ]))]
458
458
slices = [0 ] + slices
459
- ndiag_batch = torch .stack ([i .shape [0 ] for i in self .idp .mask_to_ndiag [data [AtomicDataDict .ATOM_TYPE_KEY ].flatten ()].split (slices )])
459
+ ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], dtype=torch.long, device=self.device)
460
460
ndiag_batch = torch .cumsum (ndiag_batch , dim = 0 )
461
461
mu = torch .stack ([mu [ndiag_batch [i ]:ndiag_batch [i + 1 ]].mean () for i in range (len (ndiag_batch )- 1 )])
462
462
ss = (ref_data [AtomicDataDict .NODE_OVERLAP_KEY ].sum (dim = - 1 ) * ref_data [AtomicDataDict .NODE_OVERLAP_KEY ]).sum (dim = - 1 )
@@ -536,7 +536,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
536
536
elif batch .max () >= 1 :
537
537
slices = [data ["__slices__" ]["pos" ][i ]- data ["__slices__" ]["pos" ][i - 1 ] for i in range (1 ,len (data ["__slices__" ]["pos" ]))]
538
538
slices = [0 ] + slices
539
- ndiag_batch = torch .stack ([i .shape [0 ] for i in self .idp .mask_to_ndiag [data [AtomicDataDict .ATOM_TYPE_KEY ].flatten ()].split (slices )])
539
+ ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], dtype=torch.long, device=self.device)
540
540
ndiag_batch = torch .cumsum (ndiag_batch , dim = 0 )
541
541
mu = torch .stack ([mu [ndiag_batch [i ]:ndiag_batch [i + 1 ]].mean () for i in range (len (ndiag_batch )- 1 )])
542
542
ss = (ref_data [AtomicDataDict .NODE_OVERLAP_KEY ].sum (dim = - 1 ) * ref_data [AtomicDataDict .NODE_OVERLAP_KEY ]).sum (dim = - 1 )
@@ -682,7 +682,7 @@ def forward(self, data: AtomicDataDict, ref_data: AtomicDataDict):
682
682
elif batch .max () >= 1 :
683
683
slices = [data ["__slices__" ]["pos" ][i ]- data ["__slices__" ]["pos" ][i - 1 ] for i in range (1 ,len (data ["__slices__" ]["pos" ]))]
684
684
slices = [0 ] + slices
685
- ndiag_batch = torch .stack ([i .shape [0 ] for i in self .idp .mask_to_ndiag [data [AtomicDataDict .ATOM_TYPE_KEY ].flatten ()].split (slices )])
685
+ ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], dtype=torch.long, device=self.device)
686
686
ndiag_batch = torch .cumsum (ndiag_batch , dim = 0 )
687
687
mu = torch .stack ([mu [ndiag_batch [i ]:ndiag_batch [i + 1 ]].mean () for i in range (len (ndiag_batch )- 1 )])
688
688
ss = (ref_data [AtomicDataDict .NODE_OVERLAP_KEY ].sum (dim = - 1 ) * ref_data [AtomicDataDict .NODE_OVERLAP_KEY ]).sum (dim = - 1 )
@@ -811,7 +811,7 @@ def __call__(self, data: AtomicDataDict, ref_data: AtomicDataDict, running_avg:
811
811
elif batch .max () >= 1 :
812
812
slices = [data ["__slices__" ]["pos" ][i ]- data ["__slices__" ]["pos" ][i - 1 ] for i in range (1 ,len (data ["__slices__" ]["pos" ]))]
813
813
slices = [0 ] + slices
814
- ndiag_batch = torch .stack ([i .shape [0 ] for i in self .idp .mask_to_ndiag [data [AtomicDataDict .ATOM_TYPE_KEY ].flatten ()].split (slices )])
814
+ ndiag_batch = torch.tensor([i.shape[0] for i in self.idp.mask_to_ndiag[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()].split(slices)], dtype=torch.long, device=self.device)
815
815
ndiag_batch = torch .cumsum (ndiag_batch , dim = 0 )
816
816
mu = torch .stack ([mu [ndiag_batch [i ]:ndiag_batch [i + 1 ]].mean () for i in range (len (ndiag_batch )- 1 )])
817
817
ss = (ref_data [AtomicDataDict .NODE_OVERLAP_KEY ].sum (dim = - 1 ) * ref_data [AtomicDataDict .NODE_OVERLAP_KEY ]).sum (dim = - 1 )
0 commit comments