@@ -78,6 +78,10 @@ def compute_ect(
78
78
"""
79
79
80
80
# ecc.shape[0], index.max().item() + 1, ecc.shape[2],
81
+
82
+ # ensure that the scale is in the right device
83
+ scale = torch .tensor ([scale ], device = x .device )
84
+
81
85
if index is not None :
82
86
batch_len = int (index .max () + 1 )
83
87
else :
@@ -165,6 +169,10 @@ def compute_ect_point_cloud(
165
169
point clouds (thus ECT's), N is the number of direction and R is the
166
170
resolution.
167
171
"""
172
+
173
+ # ensure that the scale is in the right device
174
+ scale = torch .tensor ([scale ], device = x .device )
175
+
168
176
lin = torch .linspace (
169
177
start = - radius , end = radius , steps = resolution , device = x .device
170
178
).view (- 1 , 1 , 1 )
@@ -208,6 +216,9 @@ def compute_ect_points(
208
216
The index tensor is assumed to start at 0.
209
217
"""
210
218
219
+ # ensure that the scale is in the right device
220
+ scale = torch .tensor ([scale ], device = x .device )
221
+
211
222
if index is not None :
212
223
batch_len = int (index .max () + 1 )
213
224
else :
@@ -273,6 +284,9 @@ def compute_ect_edges(
273
284
The index tensor is assumed to start at 0.
274
285
"""
275
286
287
+ # ensure that the scale is in the right device
288
+ scale = torch .tensor ([scale ], device = x .device )
289
+
276
290
if index is not None :
277
291
batch_len = int (index .max () + 1 )
278
292
else :
@@ -357,6 +371,9 @@ def compute_ect_mesh(
357
371
The index tensor is assumed to start at 0.
358
372
"""
359
373
374
+ # ensure that the scale is in the right device
375
+ scale = torch .tensor ([scale ], device = x .device )
376
+
360
377
if index is not None :
361
378
batch_len = int (index .max () + 1 )
362
379
else :
0 commit comments