Skip to content

Commit 9b16b65

Browse files
committed
Instantiate linspace on correct device
1 parent 71fb4a5 commit 9b16b65

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

dect/ect.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def compute_ect(
9191

9292
# Node heights have shape [num_points, num_directions]
9393
nh = x @ v
94-
lin = torch.linspace(-radius, radius, resolution).view(-1, 1, 1)
94+
lin = torch.linspace(-radius, radius, resolution,device=x.device).view(-1, 1, 1)
9595
ecc = ect_fn(scale * torch.sub(lin, nh))
9696

9797
output = torch.zeros(
@@ -217,7 +217,7 @@ def compute_ect_points(
217217

218218
# Node heights have shape [num_points, num_directions]
219219
nh = x @ v
220-
lin = torch.linspace(-radius, radius, resolution).view(-1, 1, 1)
220+
lin = torch.linspace(-radius, radius, resolution, device=x.device).view(-1, 1, 1)
221221
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
222222
output = torch.zeros(
223223
size=out_shape,
@@ -353,7 +353,7 @@ def compute_ect_mesh(
353353

354354
# Node heights have shape [num_points, num_directions]
355355
nh = x @ v
356-
lin = torch.linspace(-radius, radius, resolution).view(-1, 1, 1)
356+
lin = torch.linspace(-radius, radius, resolution,device=x.device).view(-1, 1, 1)
357357
ecc = torch.nn.functional.sigmoid(scale * torch.sub(lin, nh))
358358
output = torch.zeros(
359359
size=out_shape,

0 commit comments

Comments
 (0)