Skip to content

Commit f2966ce

Browse files
authored
Merge pull request #186 from devmotion/dde_updates
Add support for AbstractContinuousCallback and arbitrary discontinuities
2 parents 28f74d4 + 834aad7 commit f2966ce

File tree

5 files changed

+40
-31
lines changed

5 files changed

+40
-31
lines changed

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
julia 0.6.0
2-
DiffEqBase 1.20.0
2+
DiffEqBase 1.23.0
33
Parameters 0.5.0
44
ForwardDiff 0.5.0
55
GenericSVD 0.0.2

src/callbacks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# Use Recursion to find the first callback for type-stability
22

33
# Base Case: Only one callback
4-
function find_first_continuous_callback(integrator,callback::ContinuousCallback)
4+
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback)
55
(find_callback_time(integrator,callback)...,1,1)
66
end
77

88
# Starting Case: Compute on the first callback
9-
function find_first_continuous_callback(integrator,callback::ContinuousCallback,args...)
9+
function find_first_continuous_callback(integrator, callback::AbstractContinuousCallback, args...)
1010
find_first_continuous_callback(integrator,find_callback_time(integrator,callback)...,1,1,args...)
1111
end
1212

src/integrators/integrator_utils.jl

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ function modify_dt_for_tstops!(integrator)
7272
end
7373
end
7474

75-
function savevalues!(integrator::ODEIntegrator,force_save=false)
75+
function savevalues!(integrator::ODEIntegrator,force_save=false,reduce_size=true)
7676
while !isempty(integrator.opts.saveat) && integrator.tdir*top(integrator.opts.saveat) <= integrator.tdir*integrator.t # Perform saveat
7777
integrator.saveiter += 1
7878
curt = pop!(integrator.opts.saveat)
@@ -132,7 +132,7 @@ function savevalues!(integrator::ODEIntegrator,force_save=false)
132132
copyat_or_push!(integrator.sol.alg_choice,integrator.saveiter,integrator.cache.current)
133133
end
134134
end
135-
resize!(integrator.k,integrator.kshortsize)
135+
reduce_size && resize!(integrator.k,integrator.kshortsize)
136136
end
137137

138138
function postamble!(integrator)
@@ -302,7 +302,7 @@ function handle_callbacks!(integrator)
302302
continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
303303
end
304304
end
305-
if !(typeof(discrete_callbacks)<:Tuple{})
305+
if !integrator.force_stepfail && !(typeof(discrete_callbacks)<:Tuple{})
306306
discrete_modified,saved_in_cb = apply_discrete_callback!(integrator,discrete_callbacks...)
307307
end
308308
if !saved_in_cb
@@ -346,12 +346,14 @@ function apply_step!(integrator)
346346
end
347347

348348
# Update fsal if needed
349-
if isfsal(integrator.alg)
350-
if !isempty(integrator.opts.d_discontinuities) && top(integrator.opts.d_discontinuities) == integrator.t
351-
pop!(integrator.opts.d_discontinuities)
352-
reset_fsal!(integrator)
353-
elseif integrator.reeval_fsal || integrator.u_modified || (typeof(integrator.alg)<:DP8 && !integrator.opts.calck) || (typeof(integrator.alg)<:Union{Rosenbrock23,Rosenbrock32} && !integrator.opts.adaptive)
354-
reset_fsal!(integrator)
349+
if !isempty(integrator.opts.d_discontinuities) &&
350+
top(integrator.opts.d_discontinuities) == integrator.t
351+
352+
handle_discontinuities!(integrator)
353+
isfsal(integrator.alg) && reset_fsal!(integrator)
354+
elseif isfsal(integrator.alg)
355+
if integrator.reeval_fsal || integrator.u_modified || (typeof(integrator.alg)<:DP8 && !integrator.opts.calck) || (typeof(integrator.alg)<:Union{Rosenbrock23,Rosenbrock32} && !integrator.opts.adaptive)
356+
reset_fsal!(integrator)
355357
else # Do not reeval_fsal, instead copy! over
356358
if isinplace(integrator.sol.prob)
357359
recursivecopy!(integrator.fsalfirst,integrator.fsallast)
@@ -362,6 +364,10 @@ function apply_step!(integrator)
362364
end
363365
end
364366

367+
function handle_discontinuities!(integrator)
368+
pop!(integrator.opts.d_discontinuities)
369+
end
370+
365371
function calc_dt_propose!(integrator,dtnew)
366372
dtpropose = integrator.tdir*min(abs(integrator.opts.dtmax),abs(dtnew))
367373
dtpropose = integrator.tdir*max(abs(dtpropose),abs(integrator.opts.dtmin))

src/integrators/type.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
mutable struct DEOptions{uEltype,uEltypeNoUnits,tTypeNoUnits,tType,F2,F3,F4,F5,F6,tstopsType,ECType,SType,MI}
1+
mutable struct DEOptions{absType,relType,tTypeNoUnits,tType,F2,F3,F4,F5,F6,tstopsType,discType,ECType,SType,MI}
22
maxiters::MI
33
timeseries_steps::Int
44
save_everystep::Bool
55
adaptive::Bool
6-
abstol::uEltype
7-
reltol::uEltypeNoUnits
6+
abstol::absType
7+
reltol::relType
88
gamma::tTypeNoUnits
99
qmax::tTypeNoUnits
1010
qmin::tTypeNoUnits
@@ -17,7 +17,7 @@ mutable struct DEOptions{uEltype,uEltypeNoUnits,tTypeNoUnits,tType,F2,F3,F4,F5,F
1717
save_idxs::SType
1818
tstops::tstopsType
1919
saveat::tstopsType
20-
d_discontinuities::tstopsType
20+
d_discontinuities::discType
2121
userdata::ECType
2222
progress::Bool
2323
progress_steps::Int

src/solve.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function init{algType<:OrdinaryDiffEqAlgorithm,recompile_flag}(
7171
error("Timespan is trivial")
7272
end
7373

74-
tstops_vec = collect(tType,Iterators.filter(x->tdir*tspan[1]<tdir*xtdir*tspan[end],Iterators.flatten((tstops,d_discontinuities,tspan[end]))))
74+
tstops_vec = vec(collect(tType,Iterators.filter(x->tdir*tspan[1]<tdir*xtdir*tspan[end],Iterators.flatten((tstops,d_discontinuities,tspan[end])))))
7575

7676
if tdir>0
7777
tstops_internal = binary_minheap(tstops_vec)
@@ -154,7 +154,7 @@ function init{algType<:OrdinaryDiffEqAlgorithm,recompile_flag}(
154154
saveat_vec = collect(tType,tspan[1]+saveat:saveat:(tspan[end]-saveat))
155155
# Exclude the endpoint because of floating point issues
156156
else
157-
saveat_vec = collect(tType,Iterators.filter(x->tdir*tspan[1]<tdir*x<tdir*tspan[end],saveat))
157+
saveat_vec = vec(collect(tType,Iterators.filter(x->tdir*tspan[1]<tdir*x<tdir*tspan[end],saveat)))
158158
end
159159

160160
if tdir>0
@@ -163,7 +163,7 @@ function init{algType<:OrdinaryDiffEqAlgorithm,recompile_flag}(
163163
saveat_internal = binary_maxheap(saveat_vec)
164164
end
165165

166-
d_discontinuities_vec = collect(tType,d_discontinuities)
166+
d_discontinuities_vec = vec(collect(d_discontinuities))
167167

168168
if tdir>0
169169
d_discontinuities_internal = binary_minheap(d_discontinuities_vec)
@@ -209,18 +209,21 @@ function init{algType<:OrdinaryDiffEqAlgorithm,recompile_flag}(
209209
saveiter_dense = 0
210210
end
211211

212-
opts = DEOptions(maxiters,timeseries_steps,save_everystep,adaptive,abstol_internal,
213-
reltol_internal,tTypeNoUnits(gamma),tTypeNoUnits(qmax),tTypeNoUnits(qmin),
214-
tTypeNoUnits(qsteady_max),tTypeNoUnits(qsteady_min),
215-
tTypeNoUnits(failfactor),tType(dtmax),tType(dtmin),internalnorm,save_idxs,
216-
tstops_internal,saveat_internal,d_discontinuities_internal,
217-
userdata,
218-
progress,progress_steps,
219-
progress_name,progress_message,
220-
timeseries_errors,dense_errors,
221-
tTypeNoUnits(beta1),tTypeNoUnits(beta2),tTypeNoUnits(qoldinit),dense,save_start,
222-
callbacks_internal,isoutofdomain,unstable_check,verbose,calck,force_dtmin,
223-
advance_to_tstop,stop_at_next_tstop)
212+
opts = DEOptions{typeof(abstol_internal),typeof(reltol_internal),tTypeNoUnits,tType,
213+
typeof(internalnorm),typeof(callbacks_internal),typeof(isoutofdomain),
214+
typeof(progress_message),typeof(unstable_check),typeof(tstops_internal),
215+
typeof(d_discontinuities_internal),typeof(userdata),typeof(save_idxs),
216+
typeof(maxiters)}(
217+
maxiters,timeseries_steps,save_everystep,adaptive,abstol_internal,
218+
reltol_internal,tTypeNoUnits(gamma),tTypeNoUnits(qmax),
219+
tTypeNoUnits(qmin),tTypeNoUnits(qsteady_max),
220+
tTypeNoUnits(qsteady_min),tTypeNoUnits(failfactor),tType(dtmax),
221+
tType(dtmin),internalnorm,save_idxs,tstops_internal,saveat_internal,
222+
d_discontinuities_internal,userdata,progress,progress_steps,
223+
progress_name,progress_message,timeseries_errors,dense_errors,
224+
tTypeNoUnits(beta1),tTypeNoUnits(beta2),tTypeNoUnits(qoldinit),dense,
225+
save_start,callbacks_internal,isoutofdomain,unstable_check,verbose,
226+
calck,force_dtmin,advance_to_tstop,stop_at_next_tstop)
224227

225228
progress ? (prog = Juno.ProgressBar(name=progress_name)) : prog = nothing
226229

0 commit comments

Comments
 (0)