Skip to content

Commit b8149d6

Browse files
committed
Added template base nn and compiled files.
1 parent 0b941c4 commit b8149d6

File tree

2 files changed

+303
-0
lines changed

2 files changed

+303
-0
lines changed

dect/compiled.py

Whitespace-only changes.

dect/nn.py

Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
"""
2+
Implementation of the ECT with learnable parameters.
3+
TODO: Needs implementation and refactoring.
4+
5+
"""
6+
7+
from dataclasses import dataclass
8+
import torch
9+
from torch import nn
10+
11+
12+
@dataclass
13+
class EctBatch(Batch):
14+
x: Tensor | None = None
15+
ect: Tensor | None = None
16+
17+
18+
class EctConfig(BaseModel):
19+
"""
20+
Config for initializing an ect layer.
21+
"""
22+
23+
num_thetas: int
24+
resolution: int
25+
r: float
26+
scale: float
27+
ect_type: Literal["points"]
28+
ambient_dimension: int
29+
normalized: bool
30+
seed: int
31+
32+
33+
# ---------------------------------------------------------------------------- #
34+
# To be depreciated #
35+
# ---------------------------------------------------------------------------- #
36+
37+
38+
@dataclass(frozen=True)
39+
class ECTConfig:
40+
"""
41+
Configuration of the ECT Layer.
42+
43+
Parameters
44+
----------
45+
bump_steps : int
46+
The number of steps to discretize the ECT into.
47+
radius : float
48+
The radius of the circle the directions lie on. Usually this is a bit
49+
larger than the objects we wish to compute the ECT for, which in most
50+
cases have radius 1. For now it defaults to 1 as well.
51+
ect_type : str
52+
The type of ECT we wish to compute. Can be "points" for point clouds,
53+
"edges" for graphs or "faces" for meshes.
54+
normalized: bool
55+
Whether or not to normalize the ECT. Only work with ect_type set to
56+
points and normalized the ECT to the interval [0,1].
57+
fixed: bool
58+
Option to keep the directions fixed or not. In case the directions are
59+
learnable, we can use backpropagation to optimize over a set of
60+
directions. See notebooks for examples.
61+
"""
62+
63+
bump_steps: int = 32
64+
radius: float = 1.1
65+
ect_type: str = "points"
66+
normalized: bool = False
67+
fixed: bool = True
68+
69+
70+
@dataclass()
71+
class Batch:
72+
"""Template of the required attributes for a data batch.
73+
74+
Parameters
75+
----------
76+
x : torch.FloatTensor
77+
The coordinates of the nodes in the simplical complex provided in the
78+
format [num_nodes,feature_size].
79+
batch: torch.LongTensor
80+
An index that indicates to which pointcloud a point belongs to, in
81+
principle automatically created by torch_geometric when initializing the
82+
batch.
83+
edge_index: torch.LongTensor
84+
The indices of the points that span an edge in the graph. Conforms to
85+
pytorch_geometric standards. Shape has to be of the form [2,num_edges].
86+
face:
87+
The indices of the points that span a face in the simplicial complex.
88+
Conforms to pytorch_geometric standards. Shape has to be of the form
89+
[3,num_faces] or [4, num_faces], depending on the type of complex
90+
(simplicial or cubical).
91+
node_weights: torch.FloatTensor
92+
Optional weights for the nodes in the complex. The shape has to be
93+
[num_nodes,].
94+
"""
95+
96+
x: torch.FloatTensor
97+
batch: torch.LongTensor
98+
edge_index: torch.LongTensor | None = None
99+
face: torch.LongTensor | None = None
100+
node_weights: torch.FloatTensor | None = None
101+
102+
103+
def compute_ecc(
104+
nh: torch.FloatTensor,
105+
index: torch.LongTensor,
106+
lin: torch.FloatTensor,
107+
scale: float = 100,
108+
) -> torch.FloatTensor:
109+
"""Computes the Euler Characteristic Curve.
110+
111+
Parameters
112+
----------
113+
nh : torch.FloatTensor
114+
The node heights, computed as the inner product of the node coordinates
115+
x and the direction vector v.
116+
index: torch.LongTensor
117+
The index that indicates to which pointcloud a node height belongs. For
118+
the node heights it is the same as the batch index, for the higher order
119+
simplices it will have to be recomputed.
120+
lin: torch.FloatTensor
121+
The discretization of the interval [-1,1] each node height falls in this
122+
range due to rescaling in normalizing the data.
123+
scale: torch.FloatTensor
124+
A single number that scales the sigmoid function by multiplying the
125+
sigmoid with the scale. With high (100>) values, the ect will resemble a
126+
discrete ECT and with lower values it will smooth the ECT.
127+
"""
128+
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
129+
130+
# Due to (I believe) a bug in segment_add_coo, we have to first transpose
131+
# and then apply segment add. In the original code movedim was applied after
132+
# and that yields an bug in the backwards pass. Will have to be reported to
133+
# pytorch eventually.
134+
ecc = ecc.movedim(0, 2).movedim(0, 1)
135+
return segment_add_coo(ecc, index)
136+
137+
138+
def compute_ect_points(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
139+
"""Computes the Euler Characteristic Transform of a batch of point clouds.
140+
141+
Parameters
142+
----------
143+
batch : Batch
144+
A batch of data containing the node coordinates and batch index.
145+
v: torch.FloatTensor
146+
The direction vector that contains the directions.
147+
lin: torch.FloatTensor
148+
The discretization of the interval [-1,1] each node height falls in this
149+
range due to rescaling in normalizing the data.
150+
"""
151+
nh = batch.x @ v
152+
return compute_ecc(nh, batch.batch, lin)
153+
154+
155+
def compute_ect_edges(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
156+
"""Computes the Euler Characteristic Transform of a batch of graphs.
157+
158+
Parameters
159+
----------
160+
batch : Batch
161+
A batch of data containing the node coordinates, the edges and batch
162+
index.
163+
v: torch.FloatTensor
164+
The direction vector that contains the directions.
165+
lin: torch.FloatTensor
166+
The discretization of the interval [-1,1] each node height falls in this
167+
range due to rescaling in normalizing the data.
168+
"""
169+
# Compute the node heigths
170+
nh = batch.x @ v
171+
172+
# Perform a lookup with the edge indices on node heights, this replaces the
173+
# node index with its node height and then compute the maximum over the
174+
# columns to compute the edge height.
175+
eh, _ = nh[batch.edge_index].max(dim=0)
176+
177+
# Compute which batch an edge belongs to. We take the first index of the
178+
# edge (or faces) and do a lookup on the batch index of that node in the
179+
# batch indices of the nodes.
180+
batch_index_nodes = batch.batch
181+
batch_index_edges = batch.batch[batch.edge_index[0]]
182+
183+
return compute_ecc(nh, batch_index_nodes, lin) - compute_ecc(
184+
eh, batch_index_edges, lin
185+
)
186+
187+
188+
def compute_ect_faces(batch: Batch, v: torch.FloatTensor, lin: torch.FloatTensor):
189+
"""Computes the Euler Characteristic Transform of a batch of meshes.
190+
191+
Parameters
192+
----------
193+
batch : Batch
194+
A batch of data containing the node coordinates, edges, faces and batch
195+
index.
196+
v: torch.FloatTensor
197+
The direction vector that contains the directions.
198+
lin: torch.FloatTensor
199+
The discretization of the interval [-1,1] each node height falls in this
200+
range due to rescaling in normalizing the data.
201+
"""
202+
# Compute the node heigths
203+
nh = batch.x @ v
204+
205+
# Perform a lookup with the edge indices on node heights, this replaces the
206+
# node index with its node height and then compute the maximum over the
207+
# columns to compute the edge height.
208+
eh, _ = nh[batch.edge_index].max(dim=0)
209+
210+
# Do the same thing for the faces.
211+
fh, _ = nh[batch.face].max(dim=0)
212+
213+
# Compute which batch an edge belongs to. We take the first index of the
214+
# edge (or faces) and do a lookup on the batch index of that node in the
215+
# batch indices of the nodes.
216+
batch_index_nodes = batch.batch
217+
batch_index_edges = batch.batch[batch.edge_index[0]]
218+
batch_index_faces = batch.batch[batch.face[0]]
219+
220+
return (
221+
compute_ecc(nh, batch_index_nodes, lin)
222+
- compute_ecc(eh, batch_index_edges, lin)
223+
+ compute_ecc(fh, batch_index_faces, lin)
224+
)
225+
226+
227+
def normalize(ect):
228+
"""Returns the normalized ect, scaled to lie in the interval 0,1"""
229+
return ect / torch.amax(ect, dim=(2, 3)).unsqueeze(2).unsqueeze(2)
230+
231+
232+
class ECTLayer(nn.Module):
233+
"""Machine learning layer for computing the ECT.
234+
235+
Parameters
236+
----------
237+
v: torch.FloatTensor
238+
The direction vector that contains the directions. The shape of the
239+
tensor v is either [ndims, num_thetas] or [n_channels, ndims,
240+
num_thetas].
241+
config: ECTConfig
242+
The configuration config of the ECT layer.
243+
244+
"""
245+
246+
def __init__(self, config: ECTConfig, v=None):
247+
super().__init__()
248+
self.config = config
249+
self.lin = nn.Parameter(
250+
torch.linspace(-config.radius, config.radius, config.bump_steps).view(
251+
-1, 1, 1, 1
252+
),
253+
requires_grad=False,
254+
)
255+
256+
# If provided with one set of directions.
257+
# For backwards compatibility.
258+
if v.ndim == 2:
259+
v.unsqueeze(0)
260+
261+
# The set of directions is added
262+
if config.fixed:
263+
self.v = nn.Parameter(v.movedim(-1, -2), requires_grad=False)
264+
else:
265+
# Movedim to make geotorch happy, me not happy.
266+
self.v = nn.Parameter(torch.zeros_like(v.movedim(-1, -2)))
267+
geotorch.constraints.sphere(self, "v", radius=config.radius)
268+
# Since geotorch randomizes the vector during initialization, we
269+
# assign the values after registering it with spherical constraints.
270+
# See Geotorch documentation for examples.
271+
self.v = v.movedim(-1, -2)
272+
273+
if config.ect_type == "points":
274+
self.compute_ect = compute_ect_points
275+
elif config.ect_type == "edges":
276+
self.compute_ect = compute_ect_edges
277+
elif config.ect_type == "faces":
278+
self.compute_ect = compute_ect_faces
279+
280+
def forward(self, batch: Batch):
281+
"""Forward method for the ECT Layer.
282+
283+
284+
Parameters
285+
----------
286+
batch : Batch
287+
A batch of data containing the node coordinates, edges, faces and
288+
batch index. It should follow the pytorch geometric conventions.
289+
290+
Returns
291+
----------
292+
ect: torch.FloatTensor
293+
Returns the ECT of each data object in the batch. If the layer is
294+
initialized with v of the shape [ndims,num_thetas], the returned ECT
295+
has shape [batch,num_thetas,bump_steps]. In case the layer is
296+
initialized with v of the form [n_channels, ndims, num_thetas] the
297+
returned ECT has the shape [batch,n_channels,num_thetas,bump_steps]
298+
"""
299+
# Movedim for geotorch.
300+
ect = self.compute_ect(batch, self.v.movedim(-1, -2), self.lin)
301+
if self.config.normalized:
302+
return normalize(ect)
303+
return ect.squeeze()

0 commit comments

Comments
 (0)