
Commit 2de5159

Added gradient ect computation to fect.py
1 parent: 6b54483 · commit: 2de5159

File tree

2 files changed · +56 -3 lines changed

README.md

Lines changed: 26 additions & 0 deletions
@@ -66,6 +66,32 @@ plt.imshow(ect.detach().squeeze().numpy().T)
 plt.show()
 ```

+
+## Fast Euler Characteristic Transform
+
+```
+import torch
+
+from dect.directions import generate_2d_directions
+from dect.fect import compute_fast_ect
+
+v = generate_2d_directions(num_thetas=2048).cuda()
+x_true = 0.5 * torch.rand(size=(10000, 2)).cuda()
+x = torch.nn.Parameter(0.2 * (torch.rand(size=(10000, 2), device="cuda") - 0.5))
+
+optimizer = torch.optim.Adam([x], lr=0.01)
+
+for epoch in range(200):
+    optimizer.zero_grad()
+    ect_true = compute_fast_ect(x_true, v)
+    ect_pred = compute_fast_ect(x, v)
+    loss = torch.nn.functional.mse_loss(ect_pred, ect_true)
+    loss.backward()
+    optimizer.step()
+```
+
 ## License

 Our code is released under a BSD-3-Clause license. This license essentially
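
As a possible follow-up to the Fast Euler Characteristic Transform example above (not part of this commit), the optimized point cloud can be compared against the target once the loop has finished. This assumes matplotlib is available, as in the earlier README example; the variable names below are only illustrative.

```
import matplotlib.pyplot as plt

# Move both point clouds off the GPU for plotting.
x_opt = x.detach().cpu().numpy()
x_ref = x_true.cpu().numpy()

plt.scatter(x_ref[:, 0], x_ref[:, 1], s=1, label="target")
plt.scatter(x_opt[:, 0], x_opt[:, 1], s=1, label="optimized")
plt.legend()
plt.show()
```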

dect/fect.py

Lines changed: 30 additions & 3 deletions
@@ -7,7 +7,7 @@
 Both compute at decent speed; inference times were approximately 1.5 seconds. The
 current code is not optimized and relies heavily on torch.scatter to perform the 2d
 bincount. It is therefore expected that a custom triton/cuda kernel will
-significantly reduce the compute time.
+significantly reduce the compute time.
 
 Type casting to the right types in torch is non-ideal, leading to a non-memory-optimized
 algorithm with much higher memory use than necessary. Case in point,
@@ -17,7 +17,7 @@
 unnecessary OOM errors.
 
 Important: there are no guards against overflow (it happens silently) and no
-differentiability.
+differentiability.
 """
 
 import torch
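
The overflow caveat above refers to torch's integer arithmetic wrapping around silently instead of raising; with an int16 histogram, counts past 32767 simply go negative. A toy illustration, not part of the commit:

```
import torch

# int16 counters wrap silently once they exceed 32767.
counts = torch.tensor([32767], dtype=torch.int16)
print(counts + 1)  # tensor([-32768], dtype=torch.int16)
```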
@@ -29,7 +29,7 @@
 
 def bincount(idx, resolution):
     """Calculates the histogram in resolution bins."""
-    x = torch.zeros(size=(resolution, resolution), dtype=torch.int16)
+    x = torch.zeros(size=(resolution, resolution), dtype=torch.float32, device="cuda")
     return x.scatter_(0, idx.to(torch.int64), 1, reduce="add")
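
As an aside (not part of the commit), the scatter_ call with reduce="add" is what turns the index tensor into a per-direction histogram: a one is accumulated at row idx[i, j] of column j for every point i. A small CPU sketch with made-up bin indices; the real code allocates a resolution-by-resolution buffer instead of the (bins, directions) shape used here:

```
import torch

# 3 points, 2 directions; entry (i, j) is the bin index of point i along direction j.
idx = torch.tensor([[0, 1],
                    [0, 3],
                    [2, 1]])
hist = torch.zeros(size=(4, 2))  # 4 bins, 2 directions
hist.scatter_(0, idx, 1, reduce="add")
print(hist)
# column 0: bin 0 -> 2, bin 2 -> 1; column 1: bin 1 -> 2, bin 3 -> 1
```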
@@ -46,3 +46,30 @@ def fast_ect_edges(x, ei, v):
     nh = ((torch.matmul(x, v) + 1) * (resolution // 2)).to(torch.int32)
     eh = nh[ei].max(axis=0)[0]
     return bincount(nh, resolution), bincount(eh, resolution)
+
+
+class FastECT(torch.autograd.Function):
+    @staticmethod
+    def forward(x, v):
+        ect, idx = fast_ect_fn(x, v)
+        return ect.cumsum(dim=0), ect, idx
+
+    @staticmethod
+    def setup_context(ctx, inputs, outputs):
+        (ect, ect_grad, idx) = outputs
+        (_, v) = inputs
+        ctx.save_for_backward(ect, ect_grad, idx, v)
+
+    @staticmethod
+    def backward(ctx, grad_output, _, __):
+        (ect, ect_grad, idx, v) = ctx.saved_tensors
+        grad = ect_grad * grad_output / v.shape[1]
+        # Do not know if this will be correct.
+        ect_final_grad = torch.gather(grad, dim=0, index=idx.to(torch.int64))
+        out = ect_final_grad @ v.T
+        return -1 * out, None
+
+
+def compute_fast_ect(x, v):
+    ect, _, _ = FastECT.apply(x, v)
+    return ect
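
Below is a hedged smoke test for the new gradient path, not part of the commit. It assumes a CUDA device (the updated bincount hard-codes device="cuda") and that fast_ect_fn, which FastECT.forward calls, is defined earlier in fect.py; the chosen num_thetas and point count are arbitrary.

```
import torch

from dect.directions import generate_2d_directions
from dect.fect import compute_fast_ect

v = generate_2d_directions(num_thetas=64).cuda()
x = torch.nn.Parameter(torch.rand(size=(100, 2), device="cuda") - 0.5)

# The custom backward should leave a finite gradient of the same shape as x.
ect = compute_fast_ect(x, v)
ect.sum().backward()

print(x.grad.shape)                  # expected: torch.Size([100, 2])
print(torch.isfinite(x.grad).all())
```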
