Skip to content

Commit f1a265b

Browse files
Merge remote-tracking branch 'origin/master'
2 parents d72c814 + 7a7caa6 commit f1a265b

File tree

8 files changed

+170
-6
lines changed

8 files changed

+170
-6
lines changed

src/OrdinaryDiffEq.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,5 +152,5 @@ module OrdinaryDiffEq
152152
export SplitEuler
153153

154154
export Nystrom4, Nystrom4VelocityIndependent, Nystrom5VelocityIndependent,
155-
IRKN3, IRKN4, DPRKN6, DPRKN8, DPRKN12, ERKN4
155+
IRKN3, IRKN4, DPRKN6, DPRKN8, DPRKN12, ERKN4, ERKN5
156156
end # module

src/alg_utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ isfsal(alg::DPRKN6) = true
102102
isfsal(alg::DPRKN8) = true
103103
isfsal(alg::DPRKN12) = true
104104
isfsal(alg::ERKN4) = true
105+
isfsal(alg::ERKN5) = true
105106

106107
fsal_typeof(alg::OrdinaryDiffEqAlgorithm,rate_prototype) = typeof(rate_prototype)
107108
#fsal_typeof(alg::LawsonEuler,rate_prototype) = Vector{typeof(rate_prototype)}
@@ -269,6 +270,7 @@ alg_order(alg::DPRKN6) = 6
269270
alg_order(alg::DPRKN8) = 8
270271
alg_order(alg::DPRKN12) = 12
271272
alg_order(alg::ERKN4) = 4
273+
alg_order(alg::ERKN5) = 5
272274

273275
alg_order(alg::Midpoint) = 2
274276
alg_order(alg::GenericIIF1) = 1

src/algorithms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ struct DPRKN6 <: OrdinaryDiffEqAdaptiveAlgorithm end
133133
struct DPRKN8 <: OrdinaryDiffEqAdaptiveAlgorithm end
134134
struct DPRKN12 <: OrdinaryDiffEqAdaptiveAlgorithm end
135135
struct ERKN4 <: OrdinaryDiffEqAdaptiveAlgorithm end
136+
struct ERKN5 <: OrdinaryDiffEqAdaptiveAlgorithm end
136137

137138
################################################################################
138139

src/caches/rkn_caches.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,3 +294,34 @@ function alg_cache(alg::ERKN4,u,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev
294294
tmp = similar(u)
295295
ERKN4Cache(u,uprev,k1,k2,k3,k4,k,utilde,tmp,atmp,tab)
296296
end
297+
298+
struct ERKN5Cache{uType,uArrayType,rateType,reducedRateType,uEltypeNoUnits,TabType} <: OrdinaryDiffEqMutableCache
299+
u::uType
300+
uprev::uType
301+
fsalfirst::rateType
302+
k2::reducedRateType
303+
k3::reducedRateType
304+
k4::reducedRateType
305+
k::rateType
306+
utilde::uArrayType
307+
tmp::uType
308+
atmp::uEltypeNoUnits
309+
tab::TabType
310+
end
311+
312+
u_cache(c::ERKN5Cache) = (c.atmp,c.utilde)
313+
du_cache(c::ERKN5Cache) = (c.fsalfirst,c.k2,c.k3,c.k4,c.k)
314+
315+
function alg_cache(alg::ERKN5,u,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,::Type{Val{true}})
316+
reduced_rate_prototype = rate_prototype.x[2]
317+
tab = ERKN5ConstantCache(real(uEltypeNoUnits),real(tTypeNoUnits))
318+
k1 = zeros(rate_prototype)
319+
k2 = zeros(reduced_rate_prototype)
320+
k3 = zeros(reduced_rate_prototype)
321+
k4 = zeros(reduced_rate_prototype)
322+
k = zeros(rate_prototype)
323+
utilde = similar(u,indices(u))
324+
atmp = similar(u,uEltypeNoUnits)
325+
tmp = similar(u)
326+
ERKN5Cache(u,uprev,k1,k2,k3,k4,k,utilde,tmp,atmp,tab)
327+
end

src/dense/interpolants.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,13 @@ end
576576
b12Θ = @evalpoly(Θ, 0, 0, r122, r123, r124, r125, r126)
577577

578578
if out == nothing
579-
return y₀[idxs] + dt*(k[1][idxs]*b1Θ + k[4][idxs]*b4Θ + k[5][idxs]*b5Θ + k[6][idxs]*b6Θ + k[7][idxs]*b7Θ + k[8][idxs]*b8Θ + k[9][idxs]*b9Θ + k[10][idxs]*b10Θ + k[11][idxs]*b11Θ + k[12][idxs]*b12Θ)
579+
if idxs == nothing
580+
# return @. y₀ + dt*(k[1]*b1Θ + k[4]*b4Θ + k[5]*b5Θ + k[6]*b6Θ + k[7]*b7Θ + k[8]*b8Θ + k[9]*b9Θ + k[10]*b10Θ + k[11]*b11Θ + k[12]*b12Θ)
581+
return y₀ + dt*(k[1]*b1Θ + k[4]*b4Θ + k[5]*b5Θ + k[6]*b6Θ + k[7]*b7Θ + k[8]*b8Θ + k[9]*b9Θ + k[10]*b10Θ + k[11]*b11Θ + k[12]*b12Θ)
582+
else
583+
# return @. y₀[idxs] + dt*(k[1][idxs]*b1Θ + k[4][idxs]*b4Θ + k[5][idxs]*b5Θ + k[6][idxs]*b6Θ + k[7][idxs]*b7Θ + k[8][idxs]*b8Θ + k[9][idxs]*b9Θ + k[10][idxs]*b10Θ + k[11][idxs]*b11Θ + k[12][idxs]*b12Θ)
584+
return y₀[idxs] + dt*(k[1][idxs]*b1Θ + k[4][idxs]*b4Θ + k[5][idxs]*b5Θ + k[6][idxs]*b6Θ + k[7][idxs]*b7Θ + k[8][idxs]*b8Θ + k[9][idxs]*b9Θ + k[10][idxs]*b10Θ + k[11][idxs]*b11Θ + k[12][idxs]*b12Θ)
585+
end
580586
elseif idxs == nothing
581587
#@. out = y₀ + dt*(k[1]*b1Θ + k[4]*b4Θ + k[5]*b5Θ + k[6]*b6Θ + k[7]*b7Θ + k[8]*b8Θ + k[9]*b9Θ + k[10]*b10Θ + k[11]*b11Θ + k[12]*b12Θ)
582588
@inbounds for i in eachindex(out)

src/integrators/rkn_integrators.jl

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
const NystromDefaultInitialization = Union{Nystrom4Cache,
88
Nystrom4VelocityIndependentCache,
99
Nystrom5VelocityIndependentCache,
10-
IRKN3Cache, IRKN4Cache,
10+
IRKN3Cache, IRKN4Cache,
1111
DPRKN8Cache, DPRKN12Cache,
12-
ERKN4Cache}
12+
ERKN4Cache, ERKN5Cache}
1313

1414
function initialize!(integrator,cache::NystromDefaultInitialization)
1515
@unpack fsalfirst,k = cache
@@ -492,3 +492,40 @@ end
492492
integrator.EEst = integrator.opts.internalnorm(atmp)
493493
end
494494
end
495+
496+
@muladd function perform_step!(integrator,cache::ERKN5Cache,repeat_step=false)
497+
@unpack t,dt,f = integrator
498+
u,du = integrator.u.x
499+
uprev,duprev = integrator.uprev.x
500+
@unpack tmp,atmp,fsalfirst,k2,k3,k4,k,utilde = cache
501+
@unpack c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, b4, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4 = cache.tab
502+
ku, kdu = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2]
503+
uidx = eachindex(integrator.uprev.x[2])
504+
k1 = fsalfirst.x[2]
505+
506+
@. ku = uprev + dt*(c1*duprev + dt*a21*k1)
507+
508+
f.f2(t+dt*c1,ku,du,k2)
509+
@. ku = uprev + dt*(c2*duprev + dt*(a31*k1 + a32*k2))
510+
511+
f.f2(t+dt*c2,ku,du,k3)
512+
@. ku = uprev + dt*(c3*duprev + dt*(a41*k1 + a42*k2 + a43*k3))
513+
514+
f.f2(t+dt*c3,ku,du,k4)
515+
@tight_loop_macros for i in uidx
516+
@inbounds u[i] = uprev[i] + dt*(duprev[i] + dt*(b1 *k1[i] + b2 *k2[i] + b3 *k3[i] + b4 *k4[i]))
517+
@inbounds du[i] = duprev[i] + dt*(bp1*k1[i] + bp2*k2[i] + bp3*k3[i] + bp4*k4[i])
518+
end
519+
520+
f.f1(t+dt,u,du,k.x[1])
521+
f.f2(t+dt,u,du,k.x[2])
522+
if integrator.opts.adaptive
523+
uhat, duhat = utilde.x
524+
dtsq = dt^2
525+
@tight_loop_macros for i in uidx
526+
@inbounds uhat[i] = dtsq*(btilde1*k1[i] + btilde2*k2[i] + btilde3*k3[i] + btilde4*k4[i])
527+
end
528+
calculate_residuals!(atmp, uhat, integrator.uprev, integrator.u, integrator.opts.abstol, integrator.opts.reltol)
529+
integrator.EEst = integrator.opts.internalnorm(atmp)
530+
end
531+
end

src/tableaus/rkn_tableaus.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,88 @@ ERKN4ConstantCache(
227227
T(0.0016835016835016834))
228228
end
229229

230+
struct ERKN5ConstantCache{T,T2} <: OrdinaryDiffEqConstantCache
231+
c1::T2
232+
c2::T2
233+
c3::T2
234+
a21::T
235+
a31::T
236+
a32::T
237+
a41::T
238+
a42::T
239+
a43::T
240+
b1::T
241+
b2::T
242+
b3::T
243+
b4::T
244+
bp1::T # bp denotes bprime
245+
bp2::T
246+
bp3::T
247+
bp4::T
248+
btilde1::T
249+
btilde2::T
250+
btilde3::T
251+
btilde4::T
252+
# bptilde1::T
253+
# bptilde2::T
254+
# bptilde3::T
255+
# bptilde4::T
256+
end
257+
258+
function ERKN5ConstantCache(T::Type,T2::Type)
259+
c1 = T2(1//2)
260+
c2 = T2(19//70)
261+
c3 = T2(44//51)
262+
a21 = T(1//8)
263+
a31 = T(2907//343000)
264+
a32 = T(1216//42875)
265+
a41 = T(6624772//Int64(128538819))
266+
a42 = T(6273905//Int64(54121608))
267+
a43 = T(Int64(210498365)//Int64(1028310552))
268+
b1 = T(479//5016)
269+
b2 = T(235//1776)
270+
b3 = T(145775//641744)
271+
b4 = T(309519//6873416)
272+
btilde1 = T(479//5016 - 184883//2021250)
273+
btilde2 = T(235//1776 - 411163//3399375)
274+
btilde3 = T(145775//641744 - 6//25)
275+
btilde4 = T(309519//6873416 - 593028//Int64(12464375))
276+
bp1 = b1
277+
bp2 = T(235//888)
278+
bp3 = T(300125//962616)
279+
bp4 = T(2255067//6873416)
280+
# bptilde1 = T(0)
281+
# bptilde2 = T(0)
282+
# bptilde3 = T(0)
283+
# bptilde4 = T(0)
284+
ERKN5ConstantCache(c1, c2, c3, a21, a31, a32, a41, a42, a43, b1, b2, b3, b4, bp1, bp2, bp3, bp4, btilde1, btilde2, btilde3, btilde4)
285+
end
286+
287+
Base.@pure function ERKN5ConstantCache{T<:CompiledFloats,T2<:CompiledFloats}(::Type{T},::Type{T2})
288+
ERKN5ConstantCache(
289+
T2(0.5),
290+
T2(0.2714285714285714),
291+
T2(0.8627450980392157),
292+
T(0.125),
293+
T(0.008475218658892128),
294+
T(0.028361516034985424),
295+
T(0.051539076300366506),
296+
T(0.11592236875149756),
297+
T(0.20470310704348388),
298+
T(0.09549441786283891),
299+
T(0.13231981981981983),
300+
T(0.22715444164651324),
301+
T(0.04503132067082801),
302+
T(0.09549441786283891),
303+
T(0.26463963963963966),
304+
T(0.3117806061814888),
305+
T(0.32808533631603265),
306+
T(0.004024782736060931),
307+
T(0.011367291781577495),
308+
T(-0.012845558353486749),
309+
T(-0.0025465161641516788))
310+
end
311+
230312
struct DPRKN6ConstantCache{T,T2} <: OrdinaryDiffEqConstantCache
231313
c1::T2
232314
c2::T2

test/partitioned_methods_tests.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ sim = test_convergence(dts,prob_big,DPRKN12(),dense_errors=true)
144144
sim = test_convergence(dts,prob_big,ERKN4(),dense_errors=true)
145145
@test sim.𝒪est[:l2] 5 rtol = 1e-1
146146
@test sim.𝒪est[:L2] 4 rtol = 1e-1
147+
sim = test_convergence(dts,prob_big,ERKN5(),dense_errors=true)
148+
@test sim.𝒪est[:l2] 5 rtol = 1e-1
149+
@test sim.𝒪est[:L2] 4 rtol = 1e-1
147150

148151
# Adaptive methods regression test
149152
sol = solve(prob, DPRKN6())
@@ -152,8 +155,10 @@ sol = solve(prob, DPRKN8())
152155
@test length(sol.u) < 13
153156
sol = solve(prob, DPRKN12())
154157
@test length(sol.u) < 9
155-
sol = solve(prob, ERKN4())
156-
@test length(sol.u) < 15
158+
sol = solve(prob, ERKN4(),reltol=1e-8)
159+
@test length(sol.u) < 38
160+
sol = solve(prob, ERKN5(),reltol=1e-8)
161+
@test length(sol.u) < 29
157162

158163
f = function (t,u,du)
159164
du.x[1] .= u.x[2]

0 commit comments

Comments
 (0)