Skip to content

Commit d081104

Browse files
tests pass on Rosenbrock first derivative, ForwardDiff interp fix
1 parent a702304 commit d081104

File tree

3 files changed

+45
-53
lines changed

3 files changed

+45
-53
lines changed

src/dense/generic_dense.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -121,13 +121,15 @@ times ts (sorted), with values timeseries and derivatives ks
121121
@inbounds for j in idx
122122
t = tvals[j]
123123
i = searchsortedfirst(@view(ts[@view(notsaveat_idxs[i:end])]),t,rev=tdir<0)+i-1 # It's in the interval ts[i-1] to ts[i]
124-
if deriv == Val{0} && ts[notsaveat_idxs[i]] == t
124+
avoid_constant_ends = deriv != Val{0} || typeof(t) <: ForwardDiff.Dual
125+
avoid_constant_ends && i==1 && (i+=1)
126+
if !avoid_constant_ends && ts[notsaveat_idxs[i]] == t
125127
if idxs == nothing
126128
vals[j] = timeseries[notsaveat_idxs[i]]
127129
else
128130
vals[j] = timeseries[notsaveat_idxs[i]][idxs]
129131
end
130-
elseif deriv == Val{0} && ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
132+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
131133
if idxs == nothing
132134
vals[j] = timeseries[notsaveat_idxs[i-1]]
133135
else
@@ -178,13 +180,15 @@ times ts (sorted), with values timeseries and derivatives ks
178180
@inbounds for j in idx
179181
t = tvals[j]
180182
i = searchsortedfirst(@view(ts[@view(notsaveat_idxs[i:end])]),t,rev=tdir<0)+i-1 # It's in the interval ts[i-1] to ts[i]
181-
if deriv == Val{0} && ts[notsaveat_idxs[i]] == t
183+
avoid_constant_ends = deriv != Val{0} || typeof(t) <: ForwardDiff.Dual
184+
avoid_constant_ends && i==1 && (i+=1)
185+
if !avoid_constant_ends && ts[notsaveat_idxs[i]] == t
182186
if idxs == nothing
183187
vals[j] = timeseries[notsaveat_idxs[i]]
184188
else
185189
vals[j] = timeseries[notsaveat_idxs[i]][idxs]
186190
end
187-
elseif deriv == Val{0} && ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
191+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == t # Can happen if it's the first value!
188192
if idxs == nothing
189193
vals[j] = timeseries[notsaveat_idxs[i-1]]
190194
else
@@ -242,13 +246,15 @@ times ts (sorted), with values timeseries and derivatives ks
242246
tval < ts[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
243247
tdir = sign(ts[end]-ts[1])
244248
@inbounds i = searchsortedfirst(@view(ts[notsaveat_idxs]),tval,rev=tdir<0) # It's in the interval ts[i-1] to ts[i]
245-
@inbounds if deriv == Val{0} && ts[notsaveat_idxs[i]] == tval
249+
avoid_constant_ends = deriv != Val{0} || typeof(tval) <: ForwardDiff.Dual
250+
avoid_constant_ends && i==1 && (i+=1)
251+
@inbounds if !avoid_constant_ends && ts[notsaveat_idxs[i]] == tval
246252
if idxs == nothing
247253
val = timeseries[notsaveat_idxs[i]]
248254
else
249255
val = timeseries[notsaveat_idxs[i]][idxs]
250256
end
251-
elseif deriv == Val{0} && ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
257+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
252258
if idxs == nothing
253259
val = timeseries[notsaveat_idxs[i-1]]
254260
else
@@ -273,8 +279,6 @@ times ts (sorted), with values timeseries and derivatives ks
273279
val = ode_interpolant(Θ,dt,timeseries[notsaveat_idxs[i-1]],timeseries[notsaveat_idxs[i]],ks[i],cache.caches[id.alg_choice[notsaveat_idxs[i-1]]],idxs_internal,deriv)
274280
else
275281
ode_addsteps!(ks[i],ts[notsaveat_idxs[i-1]],timeseries[notsaveat_idxs[i-1]],timeseries[notsaveat_idxs[i]],dt,f,cache) # update the kcurrent
276-
@show deriv
277-
println(@which(ode_interpolant(Θ,dt,timeseries[notsaveat_idxs[i-1]],timeseries[notsaveat_idxs[i]],ks[i],cache,idxs_internal,deriv)))
278282
val = ode_interpolant(Θ,dt,timeseries[notsaveat_idxs[i-1]],timeseries[notsaveat_idxs[i]],ks[i],cache,idxs_internal,deriv)
279283
end
280284
end
@@ -294,13 +298,15 @@ times ts (sorted), with values timeseries and derivatives ks
294298
tval < ts[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
295299
tdir = sign(ts[end]-ts[1])
296300
@inbounds i = searchsortedfirst(@view(ts[notsaveat_idxs]),tval,rev=tdir<0) # It's in the interval ts[i-1] to ts[i]
297-
@inbounds if deriv == Val{0} && ts[notsaveat_idxs[i]] == tval
301+
avoid_constant_ends = deriv != Val{0} || typeof(tval) <: ForwardDiff.Dual
302+
avoid_constant_ends && i==1 && (i+=1)
303+
@inbounds if !avoid_constant_ends && ts[notsaveat_idxs[i]] == tval
298304
if idxs == nothing
299305
copy!(out,timeseries[notsaveat_idxs[i]])
300306
else
301307
copy!(out,timeseries[notsaveat_idxs[i]][idxs])
302308
end
303-
elseif deriv == Val{0} && ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
309+
elseif !avoid_constant_ends && ts[notsaveat_idxs[i-1]] == tval # Can happen if it's the first value!
304310
if idxs == nothing
305311
copy!(out,timeseries[notsaveat_idxs[i-1]])
306312
else
@@ -365,12 +371,21 @@ end
365371

366372
##################### Hermite Interpolants
367373

374+
# If no dispatch found, assume Hermite
375+
function ode_interpolant{TI}(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}})
376+
hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T)
377+
end
378+
379+
function ode_interpolant!{TI}(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{TI}})
380+
hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T)
381+
end
382+
368383
"""
369384
Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Problems Page 190
370385
371386
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
372387
"""
373-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
388+
@inline function hermite_interpolant(Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
374389
if typeof(y₀) <: AbstractArray
375390
if typeof(idxs) <: Tuple
376391
out = similar(y₀,idxs)
@@ -393,7 +408,7 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble
393408
394409
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
395410
"""
396-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
411+
@inline function hermite_interpolant!(out,Θ,dt,y₀,y₁,k,cache,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
397412
if out == nothing
398413
return (1-Θ)*y₀[idxs]+Θ*y₁[idxs]+Θ*-1)*((1-2Θ)*(y₁[idxs]-y₀[idxs])+-1)*dt*k[1][idxs] + Θ*dt*k[2][idxs])
399414
else
@@ -408,7 +423,7 @@ Hairer Norsett Wanner Solving Ordinary Differential Euations I - Nonstiff Proble
408423
409424
Herimte Interpolation, chosen if no other dispatch for ode_interpolant
410425
"""
411-
@inline function ode_interpolant!(all_out::ArrayPartition,Θ,dt,all_y₀,all_y₁,all_k,cache,all_idxs,T::Type{Val{0}}) # Default interpolant is Hermite
426+
@inline function hermite_interpolant!(all_out::ArrayPartition,Θ,dt,all_y₀,all_y₁,all_k,cache,all_idxs,T::Type{Val{0}}) # Default interpolant is Hermite
412427
for (out,y₀,y₁,idxs,k1,k2) in zip(all_out.x,all_y₀.x,all_y₁.x,all_idxs,all_k[1].x,all_k[2].x)
413428
@inbounds for (j,i) in enumerate(idxs...)
414429
out[j] = (1-Θ)*y₀[i]+Θ*y₁[i]+Θ*-1)*((1-2Θ)*(y₁[i]-y₀[i])+-1)*dt*k1[i] + Θ*dt*k2[i])
@@ -420,7 +435,7 @@ end
420435

421436

422437

423-
@inline function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}}) # Default interpolant is Hermite
438+
@inline function linear_interpolant(Θ,dt,y₀,y₁,idxs,T::Type{Val{0}})
424439
if typeof(y₀) <: AbstractArray
425440
if typeof(idxs) <: Tuple
426441
out = similar(y₀,idxs)

src/dense/interpolants.jl

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,25 @@ end
4545
"""
4646
From MATLAB ODE Suite by Shampine
4747
"""
48-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Rosenbrock23ConstantCache,idxs,T::Type{Val{0}})
48+
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{0}})
4949
d = cache.d
5050
c1 = Θ*(1-Θ)/(1-2d)
5151
c2 = Θ*-2d)/(1-2d)
5252
y₀ + dt*(c1*k[1] + c2*k[2])
5353
end
5454

5555
# First Derivative of the dense output
56-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Rosenbrock23ConstantCache,idxs,T::Type{Val{1}})
56+
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{1}})
5757
d = cache.d
5858
c1diff = (1-2*Θ)/(1-2*d)
5959
c2diff = (2*Θ-2*d)/(1-2*d)
60-
@show "here!"
6160
c1diff*k[1] + c2diff*k[2]
6261
end
6362

6463
"""
6564
From MATLAB ODE Suite by Shampine
6665
"""
67-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Rosenbrock23Cache,idxs,T::Type{Val{0}})
66+
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{0}})
6867
d = cache.tab.d
6968
c1 = Θ*(1-Θ)/(1-2d)
7069
c2 = Θ*-2d)/(1-2d)
@@ -77,7 +76,7 @@ From MATLAB ODE Suite by Shampine
7776
end
7877
end
7978

80-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Rosenbrock23Cache,idxs,T::Type{Val{1}})
79+
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{1}})
8180
d = cache.tab.d
8281
c1diff = (1-2*Θ)/(1-2*d)
8382
c2diff = (2*Θ-2*d)/(1-2*d)
@@ -90,33 +89,6 @@ end
9089
end
9190
end
9291

93-
"""
94-
From MATLAB ODE Suite by Shampine
95-
"""
96-
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Rosenbrock32ConstantCache,idxs,T::Type{Val{0}})
97-
d = cache.d
98-
c1 = Θ*(1-Θ)/(1-2d)
99-
c2 = Θ*-2d)/(1-2d)
100-
y₀ + dt*(c1*k[1] + c2*k[2])
101-
end
102-
103-
"""
104-
From MATLAB ODE Suite by Shampine
105-
"""
106-
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Rosenbrock32Cache,idxs,T::Type{Val{0}})
107-
d = cache.tab.d
108-
c1 = Θ*(1-Θ)/(1-2d)
109-
c2 = Θ*-2d)/(1-2d)
110-
y₀ + dt*(c1*k[1] + c2*k[2])
111-
if out == nothing
112-
return y₀[idxs] + dt*(c1*k[1][idxs] + c2*k[2][idxs])
113-
else
114-
@inbounds for (j,i) in enumerate(idxs)
115-
out[j] = y₀[i] + dt*(c1*k[1][i] + c2*k[2][i])
116-
end
117-
end
118-
end
119-
12092
"""
12193
Runge–Kutta pairs of order 5(4) satisfying only the first column
12294
simplifying assumption

test/ode/ode_dense_tests.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using OrdinaryDiffEq, DiffEqProblemLibrary,Base.Test, DiffEqBase
2+
using Calculus, ForwardDiff
23

34
bools = Vector{Bool}(0)
45
prob = prob_ode_linear
@@ -380,13 +381,17 @@ sol(0:1//2^(4):1)
380381

381382
sol(0:1//2^(4):1,Val{1})
382383

383-
@which sol(0.55,Val{1})
384-
@which ode_interpolation(tvals,interp,idxs,deriv)
385-
@which sol.interp(0.55,nothing,Val{1})
386-
using ForwardDiff
387-
ForwardDiff.derivative(sol,0.55)
388-
using Calculus
389-
derivative(sol,0.55)
384+
const deriv_test_points = linspace(0,1,10)
385+
386+
for t in deriv_test_points
387+
deriv = sol(t,Val{1})
388+
if t == 0
389+
#@test deriv ≈ derivative(sol,0.00,:forward)
390+
elseif t != 1
391+
#@test deriv ≈ derivative(sol,t)
392+
end
393+
@test deriv ForwardDiff.derivative(sol,t)
394+
end
390395

391396
sol2 =solve(prob,Rosenbrock23(),dt=1//2^(4),dense=true,adaptive=false)
392397

0 commit comments

Comments
 (0)