Skip to content

Commit f0ee099

Browse files
authored
Merge pull request #169 from FourierFlows/GridReturnsRanges
Grid returns ranges for x, y, z
2 parents 76dffa4 + 1c4ea09 commit f0ee099

File tree

7 files changed

+91
-87
lines changed

7 files changed

+91
-87
lines changed

src/diffusion.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,9 @@ updatevars!(prob) = updatevars!(prob.vars, prob.grid, prob.sol)
115115
Set the solution as the transform of `c`.
116116
"""
117117
function set_c!(prob, c)
118-
mul!(prob.sol, prob.grid.rfftplan, c)
118+
T = typeof(prob.vars.c)
119+
prob.vars.c .= T(c) # this makes sure that c is converted to the ArrayType used in prob.vars.c (e.g., convert to CuArray if user gives c as Arrray)
120+
mul!(prob.sol, prob.grid.rfftplan, prob.vars.c)
119121
updatevars!(prob)
120122
end
121123

src/domains.jl

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,20 @@ Constructs a OneDGrid object with size `Lx`, resolution `nx`, and leftmost
3131
position `x0`. FFT plans are generated for `nthreads` CPUs using
3232
FFTW flag `effort`.
3333
"""
34-
struct OneDGrid{T<:AbstractFloat, Ta<:AbstractArray, Tfft, Trfft} <: AbstractGrid{T, Ta}
34+
struct OneDGrid{T<:AbstractFloat, Tk, Tx, Tfft, Trfft} <: AbstractGrid{T, Tk}
3535
nx :: Int
3636
nk :: Int
3737
nkr :: Int
3838

3939
dx :: T
4040
Lx :: T
4141

42-
x :: Ta
43-
k :: Ta
44-
kr :: Ta
45-
invksq :: Ta
46-
invkrsq :: Ta
42+
x :: Tx
43+
44+
k :: Tk
45+
kr :: Tk
46+
invksq :: Tk
47+
invkrsq :: Tk
4748

4849
fftplan :: Tfft
4950
rfftplan :: Trfft
@@ -58,15 +59,15 @@ function OneDGrid(nx, Lx; x0=-Lx/2, nthreads=Sys.CPU_THREADS, effort=FFTW.MEASUR
5859

5960
dx = Lx/nx
6061

61-
nk = nx
62+
nk = nx
6263
nkr = Int(nx/2+1)
6364

6465
# Physical grid
65-
x = ArrayType{T}(range(x0, step=dx, length=nx))
66+
x = range(T(x0), step=T(dx), length=nx)
6667

6768
# Wavenubmer grid
68-
k = ArrayType{T}(fftfreq(nx, 2π/Lx*nx))
69-
kr = ArrayType{T}(rfftfreq(nx, 2π/Lx*nx))
69+
k = ArrayType{T}( fftfreq(nx, 2π/Lx*nx))
70+
kr = ArrayType{T}(rfftfreq(nx, 2π/Lx*nx))
7071

7172
invksq = @. 1/k^2
7273
invkrsq = @. 1/kr^2
@@ -79,11 +80,12 @@ function OneDGrid(nx, Lx; x0=-Lx/2, nthreads=Sys.CPU_THREADS, effort=FFTW.MEASUR
7980

8081
kalias, kralias = getaliasedwavenumbers(nk, nkr, dealias)
8182

82-
Ta = typeof(x)
83+
Tx = typeof(x)
84+
Tk = typeof(k)
8385
Tfft = typeof(fftplan)
8486
Trfft = typeof(rfftplan)
8587

86-
return OneDGrid{T, Ta, Tfft, Trfft}(nx, nk, nkr, dx, Lx, x, k, kr,
88+
return OneDGrid{T, Tk, Tx, Tfft, Trfft}(nx, nk, nkr, dx, Lx, x, k, kr,
8789
invksq, invkrsq, fftplan, rfftplan, kalias, kralias)
8890
end
8991

@@ -92,7 +94,7 @@ end
9294
9395
Constructs a TwoDGrid object.
9496
"""
95-
struct TwoDGrid{T<:AbstractFloat, Ta<:AbstractArray, Tfft, Trfft} <: AbstractGrid{T, Ta}
97+
struct TwoDGrid{T<:AbstractFloat, Tk, Tx, Tfft, Trfft} <: AbstractGrid{T, Tk}
9698
nx :: Int
9799
ny :: Int
98100
nk :: Int
@@ -104,15 +106,16 @@ struct TwoDGrid{T<:AbstractFloat, Ta<:AbstractArray, Tfft, Trfft} <: AbstractGri
104106
Lx :: T
105107
Ly :: T
106108

107-
x :: Ta
108-
y :: Ta
109-
k :: Ta
110-
l :: Ta
111-
kr :: Ta
112-
Ksq :: Ta
113-
invKsq :: Ta
114-
Krsq :: Ta
115-
invKrsq :: Ta
109+
x :: Tx
110+
y :: Tx
111+
112+
k :: Tk
113+
l :: Tk
114+
kr :: Tk
115+
Ksq :: Tk
116+
invKsq :: Tk
117+
Krsq :: Tk
118+
invKrsq :: Tk
116119

117120
fftplan :: Tfft
118121
rfftplan :: Trfft
@@ -129,18 +132,18 @@ function TwoDGrid(nx, Lx, ny=nx, Ly=Lx; x0=-Lx/2, y0=-Ly/2, nthreads=Sys.CPU_THR
129132
dx = Lx/nx
130133
dy = Ly/ny
131134

132-
nk = nx
133-
nl = ny
135+
nk = nx
136+
nl = ny
134137
nkr = Int(nx/2+1)
135138

136139
# Physical grid
137-
x = ArrayType{T}(reshape(range(x0, step=dx, length=nx), (nx, 1)))
138-
y = ArrayType{T}(reshape(range(y0, step=dy, length=ny), (1, ny)))
140+
x = range(T(x0), step=T(dx), length=nx)
141+
y = range(T(y0), step=T(dy), length=ny)
139142

140143
# Wavenubmer grid
141-
k = ArrayType{T}(reshape(fftfreq(nx, 2π/Lx*nx), (nk, 1)))
142-
l = ArrayType{T}(reshape(fftfreq(ny, 2π/Ly*ny), (1, nl)))
143-
kr = ArrayType{T}(reshape(rfftfreq(nx, 2π/Lx*nx), (nkr, 1)))
144+
k = ArrayType{T}(reshape( fftfreq(nx, 2π/Lx*nx), (nk, 1)))
145+
l = ArrayType{T}(reshape( fftfreq(ny, 2π/Ly*ny), (1, nl)))
146+
kr = ArrayType{T}(reshape(rfftfreq(nx, 2π/Lx*nx), (nkr, 1)))
144147

145148
Ksq = @. k^2 + l^2
146149
invKsq = @. 1/Ksq
@@ -159,11 +162,12 @@ function TwoDGrid(nx, Lx, ny=nx, Ly=Lx; x0=-Lx/2, y0=-Ly/2, nthreads=Sys.CPU_THR
159162
kalias, kralias = getaliasedwavenumbers(nk, nkr, dealias)
160163
lalias, _ = getaliasedwavenumbers(nl, nl, dealias)
161164

162-
Ta = typeof(x)
165+
Tx = typeof(x)
166+
Tk = typeof(k)
163167
Tfft = typeof(fftplan)
164168
Trfft = typeof(rfftplan)
165169

166-
return TwoDGrid{T, Ta, Tfft, Trfft}(nx, ny, nk, nl, nkr, dx, dy, Lx, Ly, x, y, k, l, kr, Ksq, invKsq, Krsq, invKrsq,
170+
return TwoDGrid{T, Tk, Tx, Tfft, Trfft}(nx, ny, nk, nl, nkr, dx, dy, Lx, Ly, x, y, k, l, kr, Ksq, invKsq, Krsq, invKrsq,
167171
fftplan, rfftplan, kalias, kralias, lalias)
168172
end
169173

@@ -172,7 +176,7 @@ end
172176
173177
Constructs a ThreeDGrid object.
174178
"""
175-
struct ThreeDGrid{T<:AbstractFloat, Ta<:AbstractArray, Tfft, Trfft} <: AbstractGrid{T, Ta}
179+
struct ThreeDGrid{T<:AbstractFloat, Tk, Tx, Tfft, Trfft} <: AbstractGrid{T, Tk}
176180
nx :: Int
177181
ny :: Int
178182
nz :: Int
@@ -188,17 +192,18 @@ struct ThreeDGrid{T<:AbstractFloat, Ta<:AbstractArray, Tfft, Trfft} <: AbstractG
188192
Ly :: T
189193
Lz :: T
190194

191-
x :: Ta
192-
y :: Ta
193-
z :: Ta
194-
k :: Ta
195-
l :: Ta
196-
m :: Ta
197-
kr :: Ta
198-
Ksq :: Ta
199-
invKsq :: Ta
200-
Krsq :: Ta
201-
invKrsq :: Ta
195+
x :: Tx
196+
y :: Tx
197+
z :: Tx
198+
199+
k :: Tk
200+
l :: Tk
201+
m :: Tk
202+
kr :: Tk
203+
Ksq :: Tk
204+
invKsq :: Tk
205+
Krsq :: Tk
206+
invKrsq :: Tk
202207

203208
fftplan :: Tfft
204209
rfftplan :: Trfft
@@ -223,14 +228,14 @@ function ThreeDGrid(nx, Lx, ny=nx, Ly=Lx, nz=nx, Lz=Lx; x0=-Lx/2, y0=-Ly/2, z0=-
223228
nkr = Int(nx/2+1)
224229

225230
# Physical grid
226-
x = ArrayType{T}(reshape(range(x0, step=dx, length=nx), (nx, 1, 1)))
227-
y = ArrayType{T}(reshape(range(y0, step=dy, length=ny), (1, ny, 1)))
228-
z = ArrayType{T}(reshape(range(z0, step=dz, length=nz), (1, 1, nz)))
231+
x = range(T(x0), step=T(dx), length=nx)
232+
y = range(T(y0), step=T(dy), length=ny)
233+
z = range(T(z0), step=T(dz), length=nz)
229234

230235
# Wavenubmer grid
231-
k = ArrayType{T}(reshape(fftfreq(nx, 2π/Lx*nx), (nk, 1, 1)))
232-
l = ArrayType{T}(reshape(fftfreq(ny, 2π/Ly*ny), (1, nl, 1)))
233-
m = ArrayType{T}(reshape(fftfreq(nz, 2π/Lz*nz), (1, 1, nm)))
236+
k = ArrayType{T}(reshape( fftfreq(nx, 2π/Lx*nx), (nk, 1, 1)))
237+
l = ArrayType{T}(reshape( fftfreq(ny, 2π/Ly*ny), (1, nl, 1)))
238+
m = ArrayType{T}(reshape( fftfreq(nz, 2π/Lz*nz), (1, 1, nm)))
234239
kr = ArrayType{T}(reshape(rfftfreq(nx, 2π/Lx*nx), (nkr, 1, 1)))
235240

236241
Ksq = @. k^2 + l^2 + m^2
@@ -250,11 +255,12 @@ function ThreeDGrid(nx, Lx, ny=nx, Ly=Lx, nz=nx, Lz=Lx; x0=-Lx/2, y0=-Ly/2, z0=-
250255
kalias, kralias = getaliasedwavenumbers(nk, nkr, dealias)
251256
lalias, malias = getaliasedwavenumbers(nl, nm, dealias)
252257

253-
Ta = typeof(x)
258+
Tx = typeof(x)
259+
Tk = typeof(k)
254260
Tfft = typeof(fftplan)
255261
Trfft = typeof(rfftplan)
256262

257-
return ThreeDGrid{T, Ta, Tfft, Trfft}(nx, ny, nz, nk, nl, nm, nkr, dx, dy, dz, Lx, Ly, Lz,
263+
return ThreeDGrid{T, Tk, Tx, Tfft, Trfft}(nx, ny, nz, nk, nl, nm, nkr, dx, dy, dz, Lx, Ly, Lz,
258264
x, y, z, k, l, m, kr, Ksq, invKsq, Krsq, invKrsq, fftplan, rfftplan,
259265
kalias, kralias, lalias, malias)
260266
end
@@ -268,20 +274,20 @@ TwoDGrid(dev::CPU, args...; kwargs...) = TwoDGrid(args...; ArrayType=Array, kwar
268274
ThreeDGrid(dev::CPU, args...; kwargs...) = ThreeDGrid(args...; ArrayType=Array, kwargs...)
269275

270276
"""
271-
gridpoints(g)
277+
gridpoints(grid)
272278
273-
Returns the collocation points of the grid `g` in 2D or 3D arrays `X, Y (and Z)`.
279+
Returns the collocation points of the `grid` in 2D or 3D arrays `X, Y` (and `Z`).
274280
"""
275-
function gridpoints(g::TwoDGrid{T, A}) where {T, A}
276-
X = [ g.x[i] for i=1:g.nx, j=1:g.ny]
277-
Y = [ g.y[j] for i=1:g.nx, j=1:g.ny]
281+
function gridpoints(grid::TwoDGrid{T, A}) where {T, A}
282+
X = [ grid.x[i] for i=1:grid.nx, i₂=1:grid.ny ]
283+
Y = [ grid.y[i₂] for i=1:grid.nx, i₂=1:grid.ny ]
278284
return A(X), A(Y)
279285
end
280286

281-
function gridpoints(g::ThreeDGrid{T, A}) where {T, A}
282-
X = [ g.x[i] for i=1:g.nx, j=1:g.ny, k=1:g.nz]
283-
Y = [ g.y[j] for i=1:g.nx, j=1:g.ny, k=1:g.nz]
284-
Z = [ g.z[k] for i=1:g.nx, j=1:g.ny, k=1:g.nz]
287+
function gridpoints(grid::ThreeDGrid{T, A}) where {T, A}
288+
X = [ grid.x[i] for i=1:grid.nx, i₂=1:grid.ny, i₃=1:grid.nz ]
289+
Y = [ grid.y[i₂] for i=1:grid.nx, i₂=1:grid.ny, i₃=1:grid.nz ]
290+
Z = [ grid.z[i₃] for i=1:grid.nx, i₂=1:grid.ny, i₃=1:grid.nz ]
285291
return A(X), A(Y), A(Z)
286292
end
287293

src/utils.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -260,18 +260,16 @@ function radialspectrum(ah, g::TwoDGrid; n=nothing, m=nothing, refinement=2)
260260
# Interpolate ah onto fine grid in (ρ,θ).
261261
ahρθ = zeros(eltype(ahshift), (n, m))
262262

263-
for i=2:n, j=1:m # ignore zeroth mode
264-
kk = ρ[i]*cos(θ[j])
265-
ll = ρ[i]*sin(θ[j])
266-
ahρθ[i, j] = itp(kk, ll)
263+
for i₁=2:n, i₂=1:m # ignore zeroth mode; i₁≥2
264+
ahρθ[i₁, i₂] = itp(ρ[i₁]*cos(θ[i₂]), ρ[i₁]*sin(θ[i₂]))
267265
end
268266

269267
# ahρ = ρ ∫ ah(ρ,θ) dθ => Ah = ∫ ahρ dρ = ∫∫ ah dk dl
270268
= θ[2]-θ[1]
271269
if size(ah)[1] == g.nkr
272270
ahρ = 2ρ.*sum(ahρθ, dims=2)*# multiply by 2 for conjugate symmetry
273271
else
274-
ahρ = ρ.*sum(ahρθ, dims=2)*
272+
ahρ = ρ.*sum(ahρθ, dims=2)*
275273
end
276274

277275
ahρ[1] = ah[1, 1] # zeroth mode

test/createffttestfunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535

3636
function create_testfuncs(g::TwoDGrid{Tg,<:Array}) where Tg
3737
g.nx > 8 || error("nx must be > 8")
38-
x, y = g.x, g.y
38+
x, y = gridpoints(g)
3939
nx, ny = g.nx, g.ny
4040
m, n = 5, 2
4141
k₀, l₀ = g.k[2], g.l[2]
@@ -88,7 +88,7 @@ end
8888

8989
function create_testfuncs(g::ThreeDGrid{Tg,<:Array}) where Tg
9090
g.nx > 8 || error("nx must be > 8")
91-
x, y, z = g.x, g.y, g.z
91+
x, y, z = gridpoints(g)
9292
nx, ny, nz = g.nx, g.ny, g.nz
9393
mx, my, mz = 5, 2, 3
9494
k₀, l₀, m₀ = g.k[2], g.l[2], g.m[2]

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ for dev in devices
273273
@test test_diagnosticsteps(dev, freq=1)
274274
@test test_diagnosticsteps(dev, freq=2)
275275
@test test_diagnosticsteps(dev, nsteps=100, freq=9, ndata=20)
276-
# @test test_basicdiagnostics(dev)
276+
@test test_basicdiagnostics(dev)
277277
@test test_scalardiagnostics(dev, freq=1)
278278
@test test_scalardiagnostics(dev, freq=2)
279279
end

test/test_diagnostics.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ function test_basicdiagnostics(dev::Device=CPU(); nx=6, Lx=2π, kappa=1e-2)
3939

4040
prob = Problem(nx=nx, Lx=Lx, kappa=kappa, dt=dt, stepper="ETDRK4", dev=dev)
4141
g = prob.grid
42-
43-
c0(x) = sin(k1*x)
42+
43+
c0 = @. sin(k1*g.x)
4444
set_c!(prob, c0)
4545

4646
getsol(prob) = prob.sol

test/test_timesteppers.jl

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@ function constantdiffusionproblem(stepper; nx=128, Lx=2π, kappa=1e-2, nsteps=10
33
dt = 1e-9 * τ # dynamics are resolved
44

55
prob = Problem(nx=nx, Lx=Lx, kappa=kappa, dt=dt, stepper=stepper, dev=dev)
6-
g = prob.grid
7-
6+
87
# a gaussian initial condition c(x, t=0)
98
c0ampl, σ = 0.01, 0.2
10-
c0func(x) = @. c0ampl*exp(-x^2/(2σ^2))
11-
c0 = c0func.(g.x)
9+
c0func(x) = c0ampl * exp(-x^2/(2σ^2))
10+
c0 = c0func.(prob.grid.x)
1211

1312
# analytic solution for for 1D heat equation with constant κ
1413
tfinal = nsteps*dt
1514
σt = sqrt(2*kappa*tfinal + σ^2)
16-
cfinal = @. c0ampl*σ/σt * exp(-g.x^2/(2*σt^2))
15+
cfinal = @. c0ampl * σ/σt * exp(-prob.grid.x^2/(2*σt^2))
1716

1817
set_c!(prob, c0)
1918
tcomp = @elapsed stepforward!(prob, nsteps)
@@ -31,17 +30,16 @@ function varyingdiffusionproblem(stepper; nx=128, Lx=2π, kappa=1e-2, nsteps=100
3130
# instead of just the linear coefficients L*sol
3231

3332
prob = Problem(nx=nx, Lx=Lx, kappa=kappa, dt=dt, stepper=stepper, dev=dev)
34-
g = prob.grid
35-
33+
3634
# a gaussian initial condition c(x, t=0)
3735
c0ampl, σ = 0.01, 0.2
38-
c0func(x) = @. c0ampl*exp(-x^2/(2σ^2))
39-
c0 = c0func.(g.x)
36+
c0func(x) = c0ampl * exp(-x^2/(2σ^2))
37+
c0 = c0func.(prob.grid.x)
4038

4139
# analytic solution for for 1D heat equation with constant κ
4240
tfinal = nsteps*dt
4341
σt = sqrt(2*kappa[1]*tfinal + σ^2)
44-
cfinal = @. c0ampl*σ/σt * exp(-g.x^2/(2*σt^2))
42+
cfinal = @. c0ampl * σ/σt * exp(-prob.grid.x^2/(2*σt^2))
4543

4644
set_c!(prob, c0)
4745
tcomp = @elapsed stepforward!(prob, nsteps)
@@ -52,15 +50,15 @@ end
5250

5351

5452
function constantdiffusiontest(stepper, dev::Device=CPU(); kwargs...)
55-
prob, c0, c1, nsteps, tcomp = constantdiffusionproblem(stepper; kwargs...)
53+
prob, c0, cfinal, nsteps, tcomp = constantdiffusionproblem(stepper; kwargs...)
5654
normmsg = "$stepper: relative error ="
57-
@printf("% 40s %.2e (%.3f s)\n", normmsg, norm(c1-prob.vars.c)/norm(c1), tcomp)
58-
isapprox(c1, prob.vars.c, rtol=nsteps*rtol_timesteppers)
55+
@printf("% 40s %.2e (%.3f s)\n", normmsg, norm(cfinal-Array(prob.vars.c))/norm(cfinal), tcomp)
56+
isapprox(cfinal, Array(prob.vars.c), rtol=nsteps*rtol_timesteppers)
5957
end
6058

6159
function varyingdiffusiontest(stepper, dev::Device=CPU(); kwargs...)
62-
prob, c0, c1, nsteps, tcomp = varyingdiffusionproblem(stepper; kwargs...)
60+
prob, c0, cfinal, nsteps, tcomp = varyingdiffusionproblem(stepper; kwargs...)
6361
normmsg = "$stepper: relative error ="
64-
@printf("% 40s %.2e (%.3f s)\n", normmsg, norm(c1-prob.vars.c)/norm(c1), tcomp)
65-
isapprox(c1, prob.vars.c, rtol=nsteps*rtol_timesteppers)
62+
@printf("% 40s %.2e (%.3f s)\n", normmsg, norm(cfinal-Array(prob.vars.c))/norm(cfinal), tcomp)
63+
isapprox(cfinal, Array(prob.vars.c), rtol=nsteps*rtol_timesteppers)
6664
end

0 commit comments

Comments
 (0)