Skip to content

Commit e47029d

Browse files
committed
Merge branch 'main' of github.com:aidos-lab/dect
2 parents ce60b9c + 93fbd0a commit e47029d

File tree

6 files changed

+123
-59
lines changed

6 files changed

+123
-59
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# DECT - Differentiable Euler Characteristic Transform
2-
[![arXiv](https://img.shields.io/badge/arXiv-2310.07630-b31b1b.svg)](https://arxiv.org/abs/2310.07630) ![GitHub contributors](https://img.shields.io/github/contributors/aidos-lab/dect) ![GitHub](https://img.shields.io/github/license/aidos-lab/dect) [![Maintainability](https://api.codeclimate.com/v1/badges/82f86d7e2f0aae342055/maintainability)](https://codeclimate.com/github/aidos-lab/dect/maintainability)
2+
[![arXiv](https://img.shields.io/badge/arXiv-2310.07630-b31b1b.svg)](https://arxiv.org/abs/2310.07630) ![GitHub contributors](https://img.shields.io/github/contributors/aidos-lab/dect) ![GitHub](https://img.shields.io/github/license/aidos-lab/dect) [![Maintainability](https://qlty.sh/badges/b7958e48-8382-4fb0-ac0b-63b95b7a5426/maintainability.svg)](https://qlty.sh/gh/aidos-lab/projects/dect)
33

44
This is the official implementation for the **Differentiable Euler Characteristic
55
Transform**, a geometrical-topological method for shape classification. Our

dect/ect.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def compute_ect(
7878
"""
7979

8080
# ecc.shape[0], index.max().item() + 1, ecc.shape[2],
81+
82+
# ensure that the scale is in the right device
83+
scale = torch.tensor([scale], device=x.device)
84+
8185
if index is not None:
8286
batch_len = int(index.max() + 1)
8387
else:
@@ -165,6 +169,10 @@ def compute_ect_point_cloud(
165169
point clouds (thus ECT's), N is the number of direction and R is the
166170
resolution.
167171
"""
172+
173+
# ensure that the scale is in the right device
174+
scale = torch.tensor([scale], device=x.device)
175+
168176
lin = torch.linspace(
169177
start=-radius, end=radius, steps=resolution, device=x.device
170178
).view(-1, 1, 1)
@@ -208,6 +216,9 @@ def compute_ect_points(
208216
The index tensor is assumed to start at 0.
209217
"""
210218

219+
# ensure that the scale is in the right device
220+
scale = torch.tensor([scale], device=x.device)
221+
211222
if index is not None:
212223
batch_len = int(index.max() + 1)
213224
else:
@@ -273,6 +284,9 @@ def compute_ect_edges(
273284
The index tensor is assumed to start at 0.
274285
"""
275286

287+
# ensure that the scale is in the right device
288+
scale = torch.tensor([scale], device=x.device)
289+
276290
if index is not None:
277291
batch_len = int(index.max() + 1)
278292
else:
@@ -357,6 +371,9 @@ def compute_ect_mesh(
357371
The index tensor is assumed to start at 0.
358372
"""
359373

374+
# ensure that the scale is in the right device
375+
scale = torch.tensor([scale], device=x.device)
376+
360377
if index is not None:
361378
batch_len = int(index.max() + 1)
362379
else:

tests/test_ect_edges.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,40 @@
11
"""
22
Tests the core functions for computing the
3-
ECT over a point cloud.
3+
ECT over a edges.
44
"""
55

6+
import pytest
7+
import torch
8+
import warnings
69

7-
def test_true():
8-
pass
10+
from dect.directions import generate_uniform_directions
11+
from dect.ect import compute_ect_edges
12+
13+
14+
@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
15+
def test_compute_ect_edges_noindex(device):
16+
"""
17+
Test the `compute_ect` function for point clouds.
18+
"""
19+
20+
# Check if device is available, else skip the test.
21+
if not getattr(torch, device).is_available():
22+
warnings.warn(f"Device {device} not available, skipping the tests.")
23+
return
24+
25+
seed = 2024
26+
ambient_dimension = 5
27+
num_points = 10
28+
v = generate_uniform_directions(
29+
num_thetas=17, seed=seed, device=device, d=ambient_dimension
30+
).to(device)
31+
x = torch.rand(size=(num_points, ambient_dimension), device=device)
32+
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]], device=device)
33+
ect = compute_ect_edges(
34+
x, edge_index=edge_index, v=v, radius=1, resolution=13, scale=10
35+
)
36+
37+
assert ect.device.type == device
38+
39+
# TODO: Implement proper tests that the ect has been computed correctly.
40+
# Most likely a parametrized set of fixtures are the way to go here.

tests/test_ect_faces.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,47 @@
11
"""
22
Tests the core functions for computing the
3-
ECT over a point cloud.
3+
ECT over a edges.
44
"""
55

6+
import pytest
7+
import torch
68

7-
def test_true():
8-
pass
9+
from dect.directions import generate_uniform_directions
10+
from dect.ect import compute_ect_mesh
11+
12+
13+
@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
14+
def test_compute_ect_mesh_noindex(device):
15+
"""
16+
Test the `compute_ect` function for point clouds.
17+
"""
18+
19+
# Check if device is available, else skip the test.
20+
if not getattr(torch, device).is_available():
21+
return
22+
23+
seed = 2024
24+
ambient_dimension = 5
25+
num_points = 10
26+
v = generate_uniform_directions(
27+
num_thetas=17, seed=seed, device=device, d=ambient_dimension
28+
).to(device)
29+
x = torch.rand(size=(num_points, ambient_dimension), device=device)
30+
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]], device=device)
31+
face_index = torch.tensor(
32+
[[0, 1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6], [2, 3, 4, 5, 6, 7]], device=device
33+
)
34+
ect = compute_ect_mesh(
35+
x,
36+
edge_index=edge_index,
37+
face_index=face_index,
38+
v=v,
39+
radius=1,
40+
resolution=13,
41+
scale=10,
42+
)
43+
44+
assert ect.device.type == device
45+
46+
# TODO: Implement proper tests that the ect has been computed correctly.
47+
# Most likely a parametrized set of fixtures are the way to go here.

tests/test_ect_point_clouds.py

Lines changed: 17 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -6,57 +6,48 @@
66
third is the ambient dimension.
77
Note that in this format each point in the point
88
cloud has to have the same cardinality.
9+
10+
11+
NOTE: This will run on the devices available on
12+
the machine when executed. Currently it is only
13+
tested on GPU and CPU and MPS is skipped.
14+
If you run the tests on MPS and tests fail, please
15+
contact me.
916
"""
1017

18+
import pytest
1119
import torch
1220

13-
from dect.directions import generate_2d_directions, generate_uniform_directions
21+
from dect.directions import generate_uniform_directions
1422
from dect.ect import compute_ect_point_cloud
1523

1624

17-
def test_compute_ect_point_cloud_case_cpu():
25+
@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
26+
def test_compute_ect_point_cloud(device):
1827
"""
1928
Tests the ECT computation of a point cloud.
2029
Differs in that it expects an input shape of
2130
size [B,N,D], where B is the batch size,
2231
N is the number of points and D is the ambient
2332
dimension.
2433
"""
25-
ambient_dimension = 4
26-
num_pts = 100
27-
batch_size = 8
28-
num_thetas = 100
29-
seed = 0
30-
x = torch.rand(size=(batch_size, num_pts, ambient_dimension))
31-
v = generate_uniform_directions(
32-
num_thetas=num_thetas, d=ambient_dimension, device="cpu", seed=seed
33-
)
34-
35-
ect = compute_ect_point_cloud(x, v, radius=10, resolution=30, scale=500)
36-
37-
assert ect[0].max() == num_pts
38-
assert ect[0].min() == 0
3934

35+
if not getattr(torch, device).is_available():
36+
return
4037

41-
def test_compute_ect_point_cloud_case_cuda():
42-
"""
43-
Tests the ECT computation of a point cloud.
44-
Differs in that it expects an input shape of
45-
size [B,N,D], where B is the batch size,
46-
N is the number of points and D is the ambient
47-
dimension.
48-
"""
4938
ambient_dimension = 4
5039
num_pts = 100
5140
batch_size = 8
5241
num_thetas = 100
5342
seed = 0
54-
x = torch.rand(size=(batch_size, num_pts, ambient_dimension), device="cuda")
43+
x = torch.rand(size=(batch_size, num_pts, ambient_dimension), device=device)
5544
v = generate_uniform_directions(
56-
num_thetas=num_thetas, d=ambient_dimension, device="cuda", seed=seed
45+
num_thetas=num_thetas, d=ambient_dimension, device=device, seed=seed
5746
)
5847

5948
ect = compute_ect_point_cloud(x, v, radius=10, resolution=30, scale=500)
6049

50+
assert ect.device.type == device
51+
6152
assert ect[0].max() == num_pts
6253
assert ect[0].min() == 0

tests/test_ect_points.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,53 +3,38 @@
33
ECT over a point cloud.
44
"""
55

6+
import pytest
67
import torch
78

8-
from dect.directions import generate_2d_directions, generate_uniform_directions
9+
from dect.directions import generate_uniform_directions
910
from dect.ect import compute_ect_points
1011

1112
"""
1213
Test the ECT for
1314
"""
1415

1516

16-
def test_compute_ect_case_points():
17+
@pytest.mark.parametrize("device", ["cpu", "cuda", "mps"])
18+
def test_compute_ect_case_points_noindex_cpu(device):
1719
"""
1820
Test the `compute_ect` function for point clouds.
1921
"""
20-
device = "cpu"
22+
23+
# Check if device is available, else skip the test.
24+
if not getattr(torch, device).is_available():
25+
return
26+
2127
# 2D Case
2228
seed = 2024
2329
ambient_dimension = 5
2430
num_points = 10
2531
v = generate_uniform_directions(
2632
num_thetas=17, seed=seed, device=device, d=ambient_dimension
2733
).to(device)
28-
x = torch.rand(size=(num_points, ambient_dimension))
29-
ect = compute_ect_points(x, v=v, radius=1, resolution=13, scale=10)
30-
31-
assert ect.get_device() == -1
32-
33-
# Check that min and max are 0 and num_pts
34-
torch.testing.assert_close(ect.max(), torch.tensor(num_points, dtype=torch.float32))
35-
torch.testing.assert_close(ect.min(), torch.tensor(0.0, dtype=torch.float32))
36-
37-
38-
def test_compute_ect_case_points_cuda():
39-
"""
40-
Test the `compute_ect` function for point clouds on the gpu.
41-
"""
42-
if not torch.cuda.is_available():
43-
return
44-
45-
device = "cuda:0"
46-
# 2D Case
47-
num_points = 10
48-
v = generate_2d_directions(num_thetas=17).to(device)
49-
x = torch.rand(size=(num_points, 2), device=device)
34+
x = torch.rand(size=(num_points, ambient_dimension), device=device)
5035
ect = compute_ect_points(x, v=v, radius=1, resolution=13, scale=10)
5136

52-
assert ect.get_device() == 0 # 0 indicates cuda.
37+
assert ect.device.type == device
5338

5439
# Check that min and max are 0 and num_pts
5540
torch.testing.assert_close(

0 commit comments

Comments
 (0)