Skip to content

Commit 13f3472

Browse files
committed
CUDA.@allowscalar is not just for tests
1 parent 32c0370 commit 13f3472

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,16 @@ end
248248
Returns an array, of the ArrayType of the device `grid` lives on, that contains the values of
249249
function `func` evaluated on the `grid`.
250250
"""
251-
on_grid(func, grid::OneDGrid) = @. func(grid.x)
251+
on_grid(func, grid::OneDGrid) = CUDA.@allowscalar @. func(grid.x)
252252

253253
function on_grid(func, grid::TwoDGrid)
254254
x, y = gridpoints(grid)
255-
return func.(x, y)
255+
return CUDA.@allowscalar func.(x, y)
256256
end
257257

258258
function on_grid(func, grid::ThreeDGrid)
259259
x, y, z = gridpoints(grid)
260-
return func.(x, y, z)
260+
return CUDA.@allowscalar func.(x, y, z)
261261
end
262262

263263
"""

test/test_utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,5 +157,5 @@ function test_ongrid(dev::Device)
157157
X₃, Y₃, Z₃ = gridpoints(g₃)
158158
f₃(x, y, z) = x^2 - y^3 + sin(z)
159159

160-
return CUDA.@allowscalar (FourierFlows.on_grid(f₁, g₁) == f₁.(X₁) && FourierFlows.on_grid(f₂, g₂) == f₂.(X₂, Y₂) && FourierFlows.on_grid(f₃, g₃) == f₃.(X₃, Y₃, Z₃))
160+
return (FourierFlows.on_grid(f₁, g₁) == f₁.(X₁) && FourierFlows.on_grid(f₂, g₂) == f₂.(X₂, Y₂) && FourierFlows.on_grid(f₃, g₃) == f₃.(X₃, Y₃, Z₃))
161161
end

0 commit comments

Comments
 (0)