@@ -186,24 +186,32 @@ def compute_ect_edges(
186
186
187
187
Parameters
188
188
----------
189
- batch : Batch
190
- A batch of data containing the node coordinates, the edges and batch
191
- index.
192
- v: torch.FloatTensor
193
- The direction vector that contains the directions.
194
- lin: torch.FloatTensor
195
- The discretization of the interval [-1,1] each node height falls in this
196
- range due to rescaling in normalizing the data.
189
+ x : Tensor
190
+ The point cloud of shape [B,N,D] where B is the number of point clouds,
191
+ N is the number of points and D is the ambient dimension.
192
+ edge_index : Tensor
193
+ The edge index tensor in torch geometric format, has to have shape
194
+ [2,num_edges]. Be careful when using undirected graphs, since torch
195
+ geometric views undirected graphs as 2 directed edges, leading to
196
+ double counts.
197
+ v : Tensor
198
+ The tensor of directions of shape [D,N], where D is the ambient
199
+ dimension and N is the number of directions.
200
+ radius : float
201
+ Radius of the interval to discretize the ECT into.
202
+ resolution : int
203
+ Number of steps to divide the lin interval into.
204
+ scale : Tensor
205
+ The multipicative factor for the sigmoid function.
197
206
"""
198
207
199
- # ecc.shape[0], index.max().item() + 1, ecc.shape[2],
200
208
if index is not None :
201
209
batch_len = int (index .max () + 1 )
202
210
else :
203
211
batch_len = 1
204
212
index = torch .zeros (size = (len (x ),), dtype = torch .int32 )
205
213
206
- # v is of shape [d , num_thetas]
214
+ # v is of shape [ambient_dimension , num_thetas]
207
215
num_thetas = v .shape [1 ]
208
216
209
217
out_shape = (resolution , batch_len , num_thetas )
@@ -219,10 +227,6 @@ def compute_ect_edges(
219
227
220
228
output .index_add_ (1 , index , ecc )
221
229
222
- # For the calculation of the edges, loop over the simplex tensors.
223
- # Each index tensor is assumed to be of shape [d,num_simplices],
224
- # where d is the dimension of the simplex.
225
-
226
230
# Edges heights.
227
231
sh , _ = nh [edge_index ].max (dim = 0 )
228
232
@@ -231,7 +235,7 @@ def compute_ect_edges(
231
235
# batch indices of the nodes.
232
236
index_simplex = index [edge_index [0 ]]
233
237
234
- # Calculate the ECC of the simplices .
238
+ # Calculate the ECC of the edges .
235
239
secc = (- 1 ) * torch .nn .functional .sigmoid (scale * torch .sub (lin , sh ))
236
240
237
241
# Add the ECC of the simplices to the running total.
@@ -255,14 +259,27 @@ def compute_ect_mesh(
255
259
256
260
Parameters
257
261
----------
258
- batch : Batch
259
- A batch of data containing the node coordinates, the edges and batch
260
- index.
261
- v: torch.FloatTensor
262
- The direction vector that contains the directions.
263
- lin: torch.FloatTensor
264
- The discretization of the interval [-1,1] each node height falls in this
265
- range due to rescaling in normalizing the data.
262
+ x : Tensor
263
+ The point cloud of shape [B,N,D] where B is the number of point clouds,
264
+ N is the number of points and D is the ambient dimension.
265
+ edge_index : Tensor
266
+ The edge index tensor in torch geometric format, has to have shape
267
+ [2,num_edges]. Be careful when using undirected graphs, since torch
268
+ geometric views undirected graphs as 2 directed edges, leading to
269
+ double counts.
270
+ face_index : Tensor
271
+ The face index tensor of shape [3,num_faces]. Each column is a face
272
+ where a face is a triple of indices referencing to the rows of the
273
+ x tensor with coordinates.
274
+ v : Tensor
275
+ The tensor of directions of shape [D,N], where D is the ambient
276
+ dimension and N is the number of directions.
277
+ radius : float
278
+ Radius of the interval to discretize the ECT into.
279
+ resolution : int
280
+ Number of steps to divide the lin interval into.
281
+ scale : Tensor
282
+ The multipicative factor for the sigmoid function.
266
283
"""
267
284
268
285
# ecc.shape[0], index.max().item() + 1, ecc.shape[2],
0 commit comments