Skip to content

Commit b1b5a51

Browse files
committed
Started working on new features, unit testing.
1 parent ba22b14 commit b1b5a51

File tree

9 files changed

+122
-11
lines changed

9 files changed

+122
-11
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Vscode setting.
2+
.vscode
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[cod]

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
2+
# Updates
3+
This the development branch of DECT. Currently implementing/changing the following
4+
features:
5+
- Simplifying the usage, decoupling the torch_geometric style interface.
6+
- Significant increase in unit testing.
7+
- Adding proper poetry for development.
8+
- Providing jit compiled versions of all functions -> 20x speedups.
9+
- Rendering of Point Clouds using Backprop.
10+
11+
112
# DECT - Differentiable Euler Characteristic Transform
213
[![arXiv](https://img.shields.io/badge/arXiv-2310.07630-b31b1b.svg)](https://arxiv.org/abs/2310.07630) ![GitHub contributors](https://img.shields.io/github/contributors/aidos-lab/dect) ![GitHub](https://img.shields.io/github/license/aidos-lab/dect) [![Maintainability](https://api.codeclimate.com/v1/badges/82f86d7e2f0aae342055/maintainability)](https://codeclimate.com/github/aidos-lab/dect/maintainability)
314

dect/ect.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""All ECT calculation functions."""
2+
3+
from typing import TypeAlias
4+
5+
import torch
6+
7+
Tensor: TypeAlias = torch.Tensor
8+
9+
10+
def compute_ect_point_cloud(
11+
x: Tensor,
12+
v: Tensor,
13+
radius: float,
14+
resolution: int,
15+
scale: float,
16+
) -> Tensor:
17+
"""
18+
Computes the ECT of a point cloud.
19+
20+
Parameters
21+
----------
22+
x : Tensor
23+
The point cloud of shape [B,N,D] where B is the number of point clouds,
24+
N is the number of points and D is the ambient dimension.
25+
v : Tensor
26+
The tensor of directions of shape [D,N], where D is the ambient
27+
dimension and N is the number of directions.
28+
radius : float
29+
Radius of the interval to discretize the ECT into.
30+
resolution : int
31+
Number of steps to divide the lin interval into.
32+
scale : Tensor
33+
The multipicative factor for the sigmoid function.
34+
35+
Returns
36+
-------
37+
Tensor
38+
The ECT of the point cloud of shape [B,N,R] where B is the number of
39+
point clouds (thus ECT's), N is the number of direction and R is the
40+
resolution.
41+
"""
42+
lin = torch.linspace(
43+
start=-radius, end=radius, steps=resolution, device=x.device
44+
).view(-1, 1, 1)
45+
nh = (x @ v).unsqueeze(1)
46+
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
47+
ect = torch.sum(ecc, dim=2)
48+
return ect

dect/ect_compiled.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""
2+
Contains all compiled versions, only use them if you are familiar with torch
3+
compile. Yields significant speed ups.
4+
"""
5+
6+
from typing import Callable, TypeAlias
7+
import torch
8+
9+
Tensor: TypeAlias = torch.Tensor
10+
11+
from dect.ect import compute_ect_point_cloud as compute_ect_point_cloud_vanilla
12+
13+
14+
@torch.compile
15+
def compute_ect_point_cloud(
16+
x: Tensor,
17+
v: Tensor,
18+
radius: float,
19+
resolution: int,
20+
scale: float,
21+
) -> Tensor:
22+
"""
23+
Compiled function wrapper around the compute_ect_pointcloud.
24+
25+
Parameters
26+
----------
27+
x : Tensor
28+
The point cloud of shape [B,N,D] where B is the number of point clouds,
29+
N is the number of points and D is the ambient dimension.
30+
v : Tensor
31+
The tensor of directions of shape [D,N], where D is the ambient
32+
dimension and N is the number of directions.
33+
radius : float
34+
Radius of the interval to discretize the ECT into.
35+
resolution : int
36+
Number of steps to divide the lin interval into.
37+
scale : Tensor
38+
The multipicative factor for the sigmoid function.
39+
40+
Returns
41+
-------
42+
Tensor
43+
The ECT of the point cloud of shape [B,N,R] where B is the number of
44+
point clouds (thus ECT's), N is the number of direction and R is the
45+
resolution.
46+
"""
47+
return compute_ect_point_cloud_vanilla(
48+
x=x, v=v, radius=radius, resolution=resolution, scale=scale
49+
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

scratch.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
import torch
22

3+
from dect.ect_compiled import compute_ect_point_cloud as ect_compiled
4+
from dect.ect import compute_ect_point_cloud
5+
from dect.directions import generate_uniform_directions
36

4-
def ect(x, *args, index):
5-
print("x", x)
6-
for i, arg in enumerate(args):
7-
print("arg", (-1) ** (i + 1) * arg)
7+
v = generate_uniform_directions(num_thetas=10, d=3, seed=10)
88

9+
x = torch.rand(size=(1, 10, 3))
910

10-
ect(
11-
torch.tensor(0),
12-
torch.tensor(4),
13-
torch.tensor(1),
14-
torch.tensor(2),
15-
index=torch.tensor(2),
16-
)
11+
12+
ect1 = compute_ect_point_cloud(x, v, radius=1, resolution=64, scale=1)
13+
print(ect1)
14+
15+
ect = ect_compiled(x, v, radius=1, resolution=64, scale=1)
16+
print(ect)

0 commit comments

Comments
 (0)