Skip to content

Commit a9b310b

Browse files
authored
Merge pull request #11 from aidos-lab/dev-jpga
Convert scale to tensor in appropriate device
2 parents 6e6b6ac + db81372 commit a9b310b

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

dect/ect.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ def compute_ect(
7878
"""
7979

8080
# ecc.shape[0], index.max().item() + 1, ecc.shape[2],
81+
82+
# ensure that the scale is in the right device
83+
scale = torch.tensor([scale], device=x.device)
84+
8185
if index is not None:
8286
batch_len = int(index.max() + 1)
8387
else:
@@ -165,6 +169,10 @@ def compute_ect_point_cloud(
165169
point clouds (thus ECT's), N is the number of direction and R is the
166170
resolution.
167171
"""
172+
173+
# ensure that the scale is in the right device
174+
scale = torch.tensor([scale], device=x.device)
175+
168176
lin = torch.linspace(
169177
start=-radius, end=radius, steps=resolution, device=x.device
170178
).view(-1, 1, 1)
@@ -208,6 +216,9 @@ def compute_ect_points(
208216
The index tensor is assumed to start at 0.
209217
"""
210218

219+
# ensure that the scale is in the right device
220+
scale = torch.tensor([scale], device=x.device)
221+
211222
if index is not None:
212223
batch_len = int(index.max() + 1)
213224
else:
@@ -273,6 +284,9 @@ def compute_ect_edges(
273284
The index tensor is assumed to start at 0.
274285
"""
275286

287+
# ensure that the scale is in the right device
288+
scale = torch.tensor([scale], device=x.device)
289+
276290
if index is not None:
277291
batch_len = int(index.max() + 1)
278292
else:
@@ -357,6 +371,9 @@ def compute_ect_mesh(
357371
The index tensor is assumed to start at 0.
358372
"""
359373

374+
# ensure that the scale is in the right device
375+
scale = torch.tensor([scale], device=x.device)
376+
360377
if index is not None:
361378
batch_len = int(index.max() + 1)
362379
else:

0 commit comments

Comments
 (0)