Skip to content

Commit b7ed4fa

Browse files
Merge branch 'interp_api'
2 parents c460900 + d081104 commit b7ed4fa

File tree

3 files changed

+64
-37
lines changed

3 files changed

+64
-37
lines changed

src/dense/generic_dense.jl

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,15 @@ times ts (sorted), with values timeseries and derivatives ks
121121
@inbounds for j in idx
122122
t = tvals[j]
123123
i = searchsortedfirst(@view(ts[@view(notsaveat_idxs[i:end])]),t,rev=tdir<0)+i-1 # It's in the interval ts[i-1] to ts[i]
124-
if ts[notsaveat_idxs[i]] == t
124+
avoid_constant_ends = deriv != Val{0} || typeof(t) <: ForwardDiff.Dual
125+
avoid_constant_ends && i==1 && (i+=1)
126+
if !avoid_constant_ends && ts[notsaveat_idxs[i]] == t
125127
if idxs == nothing
126128
vals[j] = timeseries[notsaveat_idxs[i]]
127129
else
128130
vals[j] = timeseries[notsaveat_idxs[i]][idxs]
129131
end
130-
elseif ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
132+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
131133
if idxs == nothing
132134
vals[j] = timeseries[notsaveat_idxs[i-1]]
133135
else
@@ -144,7 +146,7 @@ times ts (sorted), with values timeseries and derivatives ks
144146
else
145147
idxs_internal = idxs
146148
end
147-
149+
148150
if typeof(cache) <: (DiscreteCache) || typeof(cache) <: DiscreteConstantCache
149151
vals[j] = ode_interpolant(Θ,dt,timeseries[notsaveat_idxs[i-1]],timeseries[notsaveat_idxs[i]],0,cache,idxs_internal,deriv)
150152
elseif !id.dense
@@ -178,13 +180,15 @@ times ts (sorted), with values timeseries and derivatives ks
178180
@inbounds for j in idx
179181
t = tvals[j]
180182
i = searchsortedfirst(@view(ts[@view(notsaveat_idxs[i:end])]),t,rev=tdir<0)+i-1 # It's in the interval ts[i-1] to ts[i]
181-
if ts[notsaveat_idxs[i]] == t
183+
avoid_constant_ends = deriv != Val{0} || typeof(t) <: ForwardDiff.Dual
184+
avoid_constant_ends && i==1 && (i+=1)
185+
if !avoid_constant_ends && ts[notsaveat_idxs[i]] == t
182186
if idxs == nothing
183187
vals[j] = timeseries[notsaveat_idxs[i]]
184188
else
185189
vals[j] = timeseries[notsaveat_idxs[i]][idxs]
186190
end
187-
elseif ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
191+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
188192
if idxs == nothing
189193
vals[j] = timeseries[notsaveat_idxs[i-1]]
190194
else
@@ -242,13 +246,15 @@ times ts (sorted), with values timeseries and derivatives ks
242246
tval < ts[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
243247
tdir = sign(ts[end]-ts[1])
244248
@inbounds i = searchsortedfirst(@view(ts[notsaveat_idxs]),tval,rev=tdir<0) # It's in the interval ts[i-1] to ts[i]
245-
@inbounds if ts[notsaveat_idxs[i]] == tval
249+
avoid_constant_ends = deriv != Val{0} || typeof(tval) <: ForwardDiff.Dual
250+
avoid_constant_ends && i==1 && (i+=1)
251+
@inbounds if !avoid_constant_ends && ts[notsaveat_idxs[i]] == tval
246252
if idxs == nothing
247253
val = timeseries[notsaveat_idxs[i]]
248254
else
249255
val = timeseries[notsaveat_idxs[i]][idxs]
250256
end
251-
elseif ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
257+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
252258
if idxs == nothing
253259
val = timeseries[notsaveat_idxs[i-1]]
254260
else
@@ -292,13 +298,15 @@ times ts (sorted), with values timeseries and derivatives ks
292298
tval < ts[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
293299
tdir = sign(ts[end]-ts[1])
294300
@inbounds i = searchsortedfirst(@view(ts[notsaveat_idxs]),tval,rev=tdir<0) # It's in the interval ts[i-1] to ts[i]
295-
@inbounds if ts[notsaveat_idxs[i]] == tval
301+
avoid_constant_ends = deriv != Val{0} || typeof(tval) <: ForwardDiff.Dual
302+
avoid_constant_ends && i==1 && (i+=1)
303+
@inbounds if !avoid_constant_ends && ts[notsaveat_idxs[i]] == tval
296304
if idxs == nothing
297305
copy!(out,timeseries[notsaveat_idxs[i]])
298306
else
299307
copy!(out,timeseries[notsaveat_idxs[i]][idxs])
300308
end
301-
elseif ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
309+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
302310
if idxs == nothing
303311
copy!(out,timeseries[notsaveat_idxs[i-1]])
304312
else
@@ -345,7 +353,7 @@ function ode_addsteps!{calcVal,calcVal2,calcVal3}(k,t,uprev,u,dt,f,cache,always_
345353
nothing
346354
end
347355

348-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idxs,T::Type{Val{0}})
356+
@inline function ode_interpolant{TI}(Θ,dt,y₀,y₁,k,cache::OrdinaryDiffEqMutableCache,idxs,T::Type{Val{TI}})
349357
if typeof(idxs) <: Tuple
350358
out = similar(y₀,idxs)
351359
idxs_internal=eachindex(y₀)
@@ -363,12 +371,21 @@ end
363371

364372
##################### Hermite Interpolants
365373

374+
# If no dispatch found, assume Hermite
375+
function ode_interpolant{TI}(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}})
376+
hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T)
377+
end
378+
379+
function ode_interpolant!{TI}(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}})
380+
hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
381+
end
382+
366383
"""
367384
Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190
368385
369386
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
370387
"""
371-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
388+
@inline function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
372389
if typeof(y₀) <: AbstractArray
373390
if typeof(idxs) <: Tuple
374391
out = similar(y₀,idxs)
@@ -391,7 +408,7 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble
391408
392409
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
393410
"""
394-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
411+
@inline function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
395412
if out == nothing
396413
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
397414
else
@@ -406,7 +423,7 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble
406423
407424
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
408425
"""
409-
@inline function ode_interpolant!(all_out::ArrayPartition,Θ,dt,all_y₀,all_y₁,all_k,cache,all_idxs,T::Type{Val{0}}) # Default interpolant is Hermite
426+
@inline function hermite_interpolant!(all_out::ArrayPartition,Θ,dt,all_y₀,all_y₁,all_k,cache,all_idxs,T::Type{Val{0}}) # Default interpolant is Hermite
410427
for (out,y₀,y₁,idxs,k1,k2) in zip(all_out.x,all_y₀.x,all_y₁.x,all_idxs,all_k[1].x,all_k[2].x)
411428
@inbounds for (j,i) in enumerate(idxs...)
412429
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k1[i] + Θ*dt*k2[i])
@@ -418,7 +435,7 @@ end
418435

419436

420437

421-
@inline function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
438+
@inline function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
422439
if typeof(y₀) <: AbstractArray
423440
if typeof(idxs) <: Tuple
424441
out = similar(y₀,idxs)

src/dense/interpolants.jl

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,25 @@ end
4545
"""
4646
From MATLAB ODE Suite by Shampine
4747
"""
48-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Rosenbrock23ConstantCache,idxs,T::Type{Val{0}})
48+
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{0}})
4949
d = cache.d
5050
c1 = Θ*(1-Θ)/(1-2d)
5151
c2 = Θ*-2d)/(1-2d)
5252
y₀ + dt*(c1*k[1] + c2*k[2])
5353
end
5454

55+
# First Derivative of the dense output
56+
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{1}})
57+
d = cache.d
58+
c1diff = (1-2*Θ)/(1-2*d)
59+
c2diff = (2*Θ-2*d)/(1-2*d)
60+
c1diff*k[1] + c2diff*k[2]
61+
end
62+
5563
"""
5664
From MATLAB ODE Suite by Shampine
5765
"""
58-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Rosenbrock23Cache,idxs,T::Type{Val{0}})
66+
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{0}})
5967
d = cache.tab.d
6068
c1 = Θ*(1-Θ)/(1-2d)
6169
c2 = Θ*-2d)/(1-2d)
@@ -68,29 +76,15 @@ From MATLAB ODE Suite by Shampine
6876
end
6977
end
7078

71-
"""
72-
From MATLAB ODE Suite by Shampine
73-
"""
74-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Rosenbrock32ConstantCache,idxs,T::Type{Val{0}})
75-
d = cache.d
76-
c1 = Θ*(1-Θ)/(1-2d)
77-
c2 = Θ*-2d)/(1-2d)
78-
y₀ + dt*(c1*k[1] + c2*k[2])
79-
end
80-
81-
"""
82-
From MATLAB ODE Suite by Shampine
83-
"""
84-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Rosenbrock32Cache,idxs,T::Type{Val{0}})
79+
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{1}})
8580
d = cache.tab.d
86-
c1 = Θ*(1-Θ)/(1-2d)
87-
c2 = Θ*-2d)/(1-2d)
88-
y₀ + dt*(c1*k[1] + c2*k[2])
81+
c1diff = (1-2*Θ)/(1-2*d)
82+
c2diff = (2*Θ-2*d)/(1-2*d)
8983
if out == nothing
90-
return y₀[idxs] + dt*(c1*k[1][idxs] + c2*k[2][idxs])
84+
return c1diff*k[1][idxs] + c2diff*k[2][idxs]
9185
else
9286
@inbounds for (j,i) in enumerate(idxs)
93-
out[j] = y₀[i] + dt*(c1*k[1][i] + c2*k[2][i])
87+
out[j] = c1diff*k[1][i] + c2diff*k[2][i]
9488
end
9589
end
9690
end

test/ode/ode_dense_tests.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using OrdinaryDiffEq, DiffEqProblemLibrary,Base.Test, DiffEqBase
2+
using Calculus, ForwardDiff
23

34
bools = Vector{Bool}(0)
45
prob = prob_ode_linear
@@ -12,6 +13,7 @@ sol2 =solve(prob,Euler(),dt=1//2^(4),dense=true)
1213
sol3 =solve(prob,Euler(),dt=1//2^(5),dense=true)
1314

1415
prob = prob_ode_2Dlinear
16+
1517
sol =solve(prob,Euler(),dt=1//2^(2),dense=true)
1618

1719
interpd = sol(0:1//2^(4):1)
@@ -373,9 +375,23 @@ sol2 =solve(prob,Vern9(),dt=1//2^(4),dense=true,adaptive=false)
373375

374376
prob = prob_ode_linear
375377

376-
sol =solve(prob,Rosenbrock23(),dt=1//2^(2),dense=true)
378+
sol =solve(prob,Rosenbrock23(),dt=1//2^(12),dense=true)
377379

378-
sol(interpd_1d,0:1//2^(4):1)
380+
sol(0:1//2^(4):1)
381+
382+
sol(0:1//2^(4):1,Val{1})
383+
384+
const deriv_test_points = linspace(0,1,10)
385+
386+
for t in deriv_test_points
387+
deriv = sol(t,Val{1})
388+
if t == 0
389+
#@test deriv ≈ derivative(sol,0.00,:forward)
390+
elseif t != 1
391+
#@test deriv ≈ derivative(sol,t)
392+
end
393+
@test deriv ForwardDiff.derivative(sol,t)
394+
end
379395

380396
sol2 =solve(prob,Rosenbrock23(),dt=1//2^(4),dense=true,adaptive=false)
381397

0 commit comments

Comments
 (0)