Skip to content

Commit a78c308

Browse files
committed
Updated doc strings in ect.py
1 parent 21a02d2 commit a78c308

File tree

1 file changed

+40
-23
lines changed

1 file changed

+40
-23
lines changed

dect/ect.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -186,24 +186,32 @@ def compute_ect_edges(
186186
187187
Parameters
188188
----------
189-
batch : Batch
190-
A batch of data containing the node coordinates, the edges and batch
191-
index.
192-
v: torch.FloatTensor
193-
The direction vector that contains the directions.
194-
lin: torch.FloatTensor
195-
The discretization of the interval [-1,1] each node height falls in this
196-
range due to rescaling in normalizing the data.
189+
x : Tensor
190+
The point cloud of shape [B,N,D] where B is the number of point clouds,
191+
N is the number of points and D is the ambient dimension.
192+
edge_index : Tensor
193+
The edge index tensor in torch geometric format, has to have shape
194+
[2,num_edges]. Be careful when using undirected graphs, since torch
195+
geometric views undirected graphs as 2 directed edges, leading to
196+
double counts.
197+
v : Tensor
198+
The tensor of directions of shape [D,N], where D is the ambient
199+
dimension and N is the number of directions.
200+
radius : float
201+
Radius of the interval to discretize the ECT into.
202+
resolution : int
203+
Number of steps to divide the lin interval into.
204+
scale : Tensor
205+
The multipicative factor for the sigmoid function.
197206
"""
198207

199-
# ecc.shape[0], index.max().item() + 1, ecc.shape[2],
200208
if index is not None:
201209
batch_len = int(index.max() + 1)
202210
else:
203211
batch_len = 1
204212
index = torch.zeros(size=(len(x),), dtype=torch.int32)
205213

206-
# v is of shape [d, num_thetas]
214+
# v is of shape [ambient_dimension, num_thetas]
207215
num_thetas = v.shape[1]
208216

209217
out_shape = (resolution, batch_len, num_thetas)
@@ -219,10 +227,6 @@ def compute_ect_edges(
219227

220228
output.index_add_(1, index, ecc)
221229

222-
# For the calculation of the edges, loop over the simplex tensors.
223-
# Each index tensor is assumed to be of shape [d,num_simplices],
224-
# where d is the dimension of the simplex.
225-
226230
# Edges heights.
227231
sh, _ = nh[edge_index].max(dim=0)
228232

@@ -231,7 +235,7 @@ def compute_ect_edges(
231235
# batch indices of the nodes.
232236
index_simplex = index[edge_index[0]]
233237

234-
# Calculate the ECC of the simplices.
238+
# Calculate the ECC of the edges.
235239
secc = (-1) * torch.nn.functional.sigmoid(scale * torch.sub(lin, sh))
236240

237241
# Add the ECC of the simplices to the running total.
@@ -255,14 +259,27 @@ def compute_ect_mesh(
255259
256260
Parameters
257261
----------
258-
batch : Batch
259-
A batch of data containing the node coordinates, the edges and batch
260-
index.
261-
v: torch.FloatTensor
262-
The direction vector that contains the directions.
263-
lin: torch.FloatTensor
264-
The discretization of the interval [-1,1] each node height falls in this
265-
range due to rescaling in normalizing the data.
262+
x : Tensor
263+
The point cloud of shape [B,N,D] where B is the number of point clouds,
264+
N is the number of points and D is the ambient dimension.
265+
edge_index : Tensor
266+
The edge index tensor in torch geometric format, has to have shape
267+
[2,num_edges]. Be careful when using undirected graphs, since torch
268+
geometric views undirected graphs as 2 directed edges, leading to
269+
double counts.
270+
face_index : Tensor
271+
The face index tensor of shape [3,num_faces]. Each column is a face
272+
where a face is a triple of indices referencing to the rows of the
273+
x tensor with coordinates.
274+
v : Tensor
275+
The tensor of directions of shape [D,N], where D is the ambient
276+
dimension and N is the number of directions.
277+
radius : float
278+
Radius of the interval to discretize the ECT into.
279+
resolution : int
280+
Number of steps to divide the lin interval into.
281+
scale : Tensor
282+
The multipicative factor for the sigmoid function.
266283
"""
267284

268285
# ecc.shape[0], index.max().item() + 1, ecc.shape[2],

0 commit comments

Comments
 (0)