@@ -91,7 +91,7 @@ def compute_ect(
91
91
92
92
# Node heights have shape [num_points, num_directions]
93
93
nh = x @ v
94
- lin = torch .linspace (- radius , radius , resolution ).view (- 1 , 1 , 1 )
94
+ lin = torch .linspace (- radius , radius , resolution , device = x . device ).view (- 1 , 1 , 1 )
95
95
ecc = ect_fn (scale * torch .sub (lin , nh ))
96
96
97
97
output = torch .zeros (
@@ -217,7 +217,7 @@ def compute_ect_points(
217
217
218
218
# Node heights have shape [num_points, num_directions]
219
219
nh = x @ v
220
- lin = torch .linspace (- radius , radius , resolution ).view (- 1 , 1 , 1 )
220
+ lin = torch .linspace (- radius , radius , resolution , device = x . device ).view (- 1 , 1 , 1 )
221
221
ecc = torch .nn .functional .sigmoid (scale * torch .sub (lin , nh ))
222
222
output = torch .zeros (
223
223
size = out_shape ,
@@ -353,7 +353,7 @@ def compute_ect_mesh(
353
353
354
354
# Node heights have shape [num_points, num_directions]
355
355
nh = x @ v
356
- lin = torch .linspace (- radius , radius , resolution ).view (- 1 , 1 , 1 )
356
+ lin = torch .linspace (- radius , radius , resolution , device = x . device ).view (- 1 , 1 , 1 )
357
357
ecc = torch .nn .functional .sigmoid (scale * torch .sub (lin , nh ))
358
358
output = torch .zeros (
359
359
size = out_shape ,
0 commit comments