Skip to content

Commit 1450a01

Browse files
committed
Minor updates
1 parent 64cdbfb commit 1450a01

File tree

5 files changed

+101
-6
lines changed

5 files changed

+101
-6
lines changed

.github/workflows/docs.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@ name: website
33
# build the documentation whenever there are new commits on main
44
on:
55
push:
6-
# Alternative: only build for tags.
7-
tags:
8-
- '*'
6+
branches:
7+
- main
98

109
# security: restrict permissions for CI jobs.
1110
permissions:
@@ -57,4 +56,4 @@ jobs:
5756
url: ${{ steps.deployment.outputs.page_url }}
5857
steps:
5958
- id: deployment
60-
uses: actions/deploy-pages@v4
59+
uses: actions/deploy-pages@v4

Makefile

Lines changed: 0 additions & 2 deletions
This file was deleted.

dect/ect.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from typing import Callable, TypeAlias
1919

2020
import torch
21+
2122
from dect.ect_fn import indicator
2223

2324
Tensor: TypeAlias = torch.Tensor

dect/wect.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from typing import Literal
2+
3+
import geotorch
4+
import torch
5+
from torch import nn
6+
7+
def compute_wecc(
8+
nh: torch.FloatTensor,
9+
index: torch.LongTensor,
10+
lin: torch.FloatTensor,
11+
weight: torch.FloatTensor,
12+
scale: float = 500,
13+
):
14+
"""Computes the Weighted Euler Characteristic Curve.
15+
16+
Parameters
17+
----------
18+
nh : torch.FloatTensor
19+
The node heights, computed as the inner product of the node coordinates
20+
x and the direction vector v.
21+
index: torch.LongTensor
22+
The index that indicates to which pointcloud a node height belongs. For
23+
the node heights it is the same as the batch index, for the higher order
24+
simplices it will have to be recomputed.
25+
lin: torch.FloatTensor
26+
The discretization of the interval [-1,1] each node height falls in this
27+
range due to rescaling in normalizing the data.
28+
weight: torch.FloatTensor
29+
The weight of the node, edge or face. It is the maximum of the node
30+
weights for the edges and faces.
31+
scale: torch.FloatTensor
32+
A single number that scales the sigmoid function by multiplying the
33+
sigmoid with the scale. With high (100>) values, the ect will resemble a
34+
discrete ECT and with lower values it will smooth the ECT.
35+
"""
36+
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh)) * weight.view(
37+
1, -1, 1
38+
)
39+
ecc = ecc.movedim(0, 2).movedim(0, 1)
40+
return segment_add_coo(ecc, index)
41+
42+
43+
def compute_wect(
44+
batch: Batch,
45+
v: torch.FloatTensor,
46+
lin: torch.FloatTensor,
47+
wect_type: Literal["points"] | Literal["edges"] | Literal["faces"],
48+
):
49+
"""
50+
Computes the Weighted Euler Characteristic Transform of a batch of point
51+
clouds.
52+
53+
Parameters
54+
----------
55+
batch : Batch
56+
A batch of data containing the node coordinates, batch index,
57+
edge_index, face, and node weights.
58+
v: torch.FloatTensor
59+
The direction vector that contains the directions.
60+
lin: torch.FloatTensor
61+
The discretization of the interval [-1,1] each node height falls in this
62+
range due to rescaling in normalizing the data.
63+
wect_type: str
64+
The type of WECT to compute. Can be "points", "edges", or "faces".
65+
"""
66+
67+
nh = batch.x @ v
68+
if wect_type in ["edges", "faces"]:
69+
edge_weights, _ = batch.node_weights[batch.edge_index].max(axis=0)
70+
eh, _ = nh[batch.edge_index].min(dim=0)
71+
if wect_type == "faces":
72+
face_weights, _ = batch.node_weights[batch.face].max(axis=0)
73+
fh, _ = nh[batch.face].min(dim=0)
74+
75+
if wect_type == "points":
76+
return compute_wecc(nh, batch.batch, lin, batch.node_weights)
77+
if wect_type == "edges":
78+
# noinspection PyUnboundLocalVariable
79+
return compute_wecc(
80+
nh, batch.batch, lin, batch.node_weights
81+
) - compute_wecc(
82+
eh, batch.batch[batch.edge_index[0]], lin, edge_weights
83+
)
84+
if wect_type == "faces":
85+
# noinspection PyUnboundLocalVariable
86+
return (
87+
compute_wecc(nh, batch.batch, lin, batch.node_weights)
88+
- compute_wecc(
89+
eh, batch.batch[batch.edge_index[0]], lin, edge_weights
90+
)
91+
+ compute_wecc(fh, batch.batch[batch.face[0]], lin, face_weights)
92+
)
93+
raise ValueError(f"Invalid wect_type: {wect_type}")
94+
95+
96+

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ ipykernel = "^6.29.5"
3030
pytest-cov = "^6.0.0"
3131
pytest = "^8.3.4"
3232

33+
3334
[build-system]
3435
requires = ["poetry-core"]
3536
build-backend = "poetry.core.masonry.api"

0 commit comments

Comments
 (0)