7
7
Both compute at decent speed, inference times were approx 1.5 seconds. The
8
8
current code is not optimized and relies heavily on torch.scatter to perform the 2d
9
9
bincount. It is therefore expected that a custom triton/cuda kernel will
10
- significantly reduce the compute time.
10
+ significantly reduce the compute time.
11
11
12
12
Type casting to the right types in torch is non-ideal, leading to a non-memory
13
13
optimized algorithm with much higher memory usage than necessary. Case in point,
17
17
unnecessary OOM errors.
18
18
19
19
Important: there are no guards against overflows (they happen silently) and no
20
- differentiability.
20
+ differentiability.
21
21
"""
22
22
23
23
import torch
29
29
30
30
def bincount(idx, resolution):
    """Calculates the histogram in resolution bins.

    Every entry ``idx[i, j]`` adds one to ``out[idx[i, j], j]`` (a scatter-add
    along dim 0), so column ``j`` of the result counts how often each bin
    index occurs in column ``j`` of ``idx``.

    Args:
        idx: 2D integer tensor of bin indices. ``scatter_`` requires the
            values to lie in ``[0, resolution)`` and ``idx.shape[1]`` to be at
            most ``resolution``.
        resolution: Number of bins; the result has shape
            ``(resolution, resolution)``.

    Returns:
        A ``float32`` tensor of counts, allocated on the same device as
        ``idx``.
    """
    # Allocate the accumulator on the device of `idx` instead of hard-coding
    # "cuda": scatter_ needs both tensors on one device, and a fixed device
    # fails on CPU-only machines or when the input lives on another GPU.
    counts = torch.zeros(
        size=(resolution, resolution),
        dtype=torch.float32,
        device=idx.device,
    )
    # scatter_ only accepts int64 indices, hence the cast.
    return counts.scatter_(0, idx.to(torch.int64), 1, reduce="add")
34
34
35
35
@@ -46,3 +46,30 @@ def fast_ect_edges(x, ei, v):
46
46
nh = ((torch .matmul (x , v ) + 1 ) * (resolution // 2 )).to (torch .int32 )
47
47
eh = nh [ei ].max (axis = 0 )[0 ]
48
48
return bincount (nh , resolution ), bincount (eh , resolution )
49
+
50
+
51
class FastECT(torch.autograd.Function):
    """Custom autograd function wrapping ``fast_ect_fn`` with a manual backward.

    The forward pass returns the cumulative sum (along dim 0) of the tensor
    produced by ``fast_ect_fn`` together with the raw tensor and the index
    tensor; the latter two are returned so they can be stored for the
    backward pass.
    """

    @staticmethod
    def forward(x, v):
        """Return ``(cumsum of raw result, raw result, indices)`` for ``x`` and ``v``."""
        # NOTE(review): fast_ect_fn is defined outside this block; it is
        # presumably the non-differentiable histogram computation -- confirm.
        raw, bin_idx = fast_ect_fn(x, v)
        cumulative = raw.cumsum(dim=0)
        return cumulative, raw, bin_idx

    @staticmethod
    def setup_context(ctx, inputs, outputs):
        """Store the forward outputs and ``v`` on ``ctx`` for ``backward``."""
        _, directions = inputs
        cumulative, raw, bin_idx = outputs
        ctx.save_for_backward(cumulative, raw, bin_idx, directions)

    @staticmethod
    def backward(ctx, grad_output, _, __):
        """Return the gradient w.r.t. ``x``; ``v`` receives no gradient.

        Only the gradient of the first forward output is used; the gradients
        of the two auxiliary outputs are ignored.
        """
        _cumulative, raw, bin_idx, directions = ctx.saved_tensors
        # Scale by the size of the second dimension of v.
        scaled = raw * grad_output / directions.shape[1]
        # NOTE(review): the original author was unsure whether this gather
        # yields the correct gradient -- treat the result as unverified.
        gathered = torch.gather(scaled, dim=0, index=bin_idx.to(torch.int64))
        return -1 * (gathered @ directions.T), None
71
+
72
+
73
def compute_fast_ect(x, v):
    """Apply ``FastECT`` to ``x`` and ``v`` and return only its first output.

    The two auxiliary outputs of ``FastECT.apply`` exist for the backward
    pass and are discarded here.
    """
    outputs = FastECT.apply(x, v)
    return outputs[0]
0 commit comments