Skip to content

Commit 3a55b05

Browse files
full static array support by isinplace instead of array checks
1 parent 8098423 commit 3a55b05

File tree

4 files changed

+30
-3
lines changed

4 files changed

+30
-3
lines changed

src/integrators/integrator_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,13 +233,13 @@ end
233233

234234
#Update uprev
235235
if alg_extrapolates(integrator.alg)
236-
if typeof(integrator.u) <: AbstractArray
236+
if isinplace(integrator.sol.prob)
237237
recursivecopy!(integrator.uprev2,integrator.uprev)
238238
else
239239
integrator.uprev2 = integrator.uprev
240240
end
241241
end
242-
if typeof(integrator.u) <: AbstractArray
242+
if isinplace(integrator.sol.prob)
243243
recursivecopy!(integrator.uprev,integrator.u)
244244
else
245245
integrator.uprev = integrator.u
@@ -260,7 +260,7 @@ end
260260
elseif integrator.reeval_fsal || (typeof(integrator.alg)<:DP8 && !integrator.opts.calck) || (typeof(integrator.alg)<:Union{Rosenbrock23,Rosenbrock32} && !integrator.opts.adaptive)
261261
reset_fsal!(integrator)
262262
else # Do not reeval_fsal, instead copy! over
263-
if typeof(integrator.fsalfirst) <: Union{AbstractArray,ArrayPartition}
263+
if isinplace(integrator.sol.prob)
264264
recursivecopy!(integrator.fsalfirst,integrator.fsallast)
265265
else
266266
integrator.fsalfirst = integrator.fsallast

test/REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ ParameterizedFunctions 0.5.0
66
ODEInterfaceDiffEq
77
DiffEqBase 0.10.0
88
SpecialMatrices
9+
StaticArrays

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ tic()
3434
@time @testset "saveat Tests" begin include("ode/ode_saveat_tests.jl") end
3535
(LONGER_TESTS) && @time @testset "Feagin Tests" begin include("ode/ode_feagin_tests.jl") end
3636
@time @testset "Number Type Tests" begin include("ode/ode_numbertype_tests.jl") end
37+
@time @testset "Static Array Tests" begin include("static_array_tests.jl") end
3738
@time @testset "Data Array Tests" begin include("data_array_test.jl") end
3839
@time @testset "Ndim Complex Tests" begin include("ode/ode_ndim_complex_tests.jl") end
3940
@time @testset "Iterator Tests" begin include("iterator_tests.jl") end

test/static_array_tests.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using StaticArrays
2+
using DiffEqBase, OrdinaryDiffEq
3+
4+
u0 = zeros(MVector{2,Float64}, 2) + 1
5+
u0[1] = ones(MVector{2,Float64}) + 1
6+
f = (t,u,du) -> du .= u
7+
ode = ODEProblem(f, u0, (0.,1.))
8+
sol = solve(ode, Euler(), dt=1.e-2)
9+
10+
u0 = zeros(SVector{2,Float64}, 2) + 1
11+
u0[1] = ones(SVector{2,Float64}) + 1
12+
ode = ODEProblem(f, u0, (0.,1.))
13+
sol = solve(ode, Euler(), dt=1.e-2)
14+
15+
sol = solve(ode, SSPRK22(), dt=1.e-2)
16+
17+
18+
u0 = zero(MVector{2,Float64}) + 1
19+
ode = ODEProblem(f, u0, (0.,1.))
20+
sol = solve(ode, Euler(), dt=1.e-2)
21+
22+
u0 = zero(SVector{2,Float64}) + 1
23+
f = (t,u) -> u
24+
ode = ODEProblem(f, u0, (0.,1.))
25+
sol = solve(ode, Euler(), dt=1.e-2)

0 commit comments

Comments
 (0)