Skip to content

Commit e30a535

Browse files
fix callback saving
1 parent 77f77a5 commit e30a535

File tree

2 files changed

+26
-21
lines changed

2 files changed

+26
-21
lines changed

src/callbacks.jl

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -131,13 +131,15 @@ function find_callback_time(integrator,callback)
131131
end
132132

133133
function apply_callback!(integrator,callback::ContinuousCallback,cb_time,prev_sign)
134+
saved_in_cb = false
134135
if cb_time != zero(typeof(integrator.t))
135136
change_t_via_interpolation!(integrator,integrator.tprev+cb_time)
136137
end
137138

138139
if callback.save_positions[1]
139140
update_running_noise!(integrator)
140-
savevalues!(integrator)
141+
savevalues!(integrator,true)
142+
saved_in_cb = true
141143
end
142144

143145
integrator.u_modified = true
@@ -162,40 +164,44 @@ function apply_callback!(integrator,callback::ContinuousCallback,cb_time,prev_si
162164
if !callback.save_positions[1]
163165
update_running_noise!(integrator)
164166
end
165-
savevalues!(integrator)
167+
savevalues!(integrator,true)
168+
saved_in_cb = true
166169
end
167-
return true
170+
return true,saved_in_cb
168171
end
169-
false
172+
false,saved_in_cb
170173
end
171174

172175
#Base Case: Just one
173176
function apply_discrete_callback!(integrator::SDEIntegrator,callback::DiscreteCallback)
177+
saved_in_cb = false
174178
if callback.save_positions[1]
175-
savevalues!(integrator)
179+
savevalues!(integrator,true)
180+
saved_in_cb = true
176181
end
177182

178183
integrator.u_modified = true
179184
if callback.condition(integrator.t,integrator.u,integrator)
180185
callback.affect!(integrator)
181186
if callback.save_positions[2]
182-
savevalues!(integrator)
187+
savevalues!(integrator,true)
188+
saved_in_cb = true
183189
end
184190
end
185-
integrator.u_modified
191+
integrator.u_modified,saved_in_cb
186192
end
187193

188194
#Starting: Get bool from first and do next
189195
function apply_discrete_callback!(integrator::SDEIntegrator,callback::DiscreteCallback,args...)
190-
apply_discrete_callback!(integrator,apply_discrete_callback!(integrator,callback),args...)
196+
apply_discrete_callback!(integrator,apply_discrete_callback!(integrator,callback)...,args...)
191197
end
192198

193-
function apply_discrete_callback!(integrator::SDEIntegrator,discrete_modified::Bool,callback::DiscreteCallback,args...)
194-
bool = apply_discrete_callback!(integrator,apply_discrete_callback!(integrator,callback),args...)
195-
discrete_modified || bool
199+
function apply_discrete_callback!(integrator::SDEIntegrator,discrete_modified::Bool,saved_in_cb::Bool,callback::DiscreteCallback,args...)
200+
bool,saved_in_cb2 = apply_discrete_callback!(integrator,apply_discrete_callback!(integrator,callback)...,args...)
201+
discrete_modified || bool, saved_in_cb || saved_in_cb2
196202
end
197203

198-
function apply_discrete_callback!(integrator::SDEIntegrator,discrete_modified::Bool,callback::DiscreteCallback)
199-
bool = apply_discrete_callback!(integrator,callback)
200-
discrete_modified || bool
204+
function apply_discrete_callback!(integrator::SDEIntegrator,discrete_modified::Bool,saved_in_cb::Bool,callback::DiscreteCallback)
205+
bool,saved_in_cb2 = apply_discrete_callback!(integrator,callback)
206+
discrete_modified || bool, saved_in_cb || saved_in_cb2
201207
end

src/integrators/integrator_utils.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
end
7373

7474

75-
@inline function savevalues!(integrator::SDEIntegrator)
75+
@inline function savevalues!(integrator::SDEIntegrator,force_save=false)
7676
while !isempty(integrator.opts.saveat) && integrator.tdir*top(integrator.opts.saveat) <= integrator.tdir*integrator.t # Perform saveat
7777
integrator.saveiter += 1
7878
curt = pop!(integrator.opts.saveat)
@@ -97,7 +97,7 @@ end
9797
end
9898
end
9999
end
100-
if integrator.opts.save_everystep && integrator.iter%integrator.opts.timeseries_steps==0
100+
if force_save || (integrator.opts.save_everystep && integrator.iter%integrator.opts.timeseries_steps==0)
101101
integrator.saveiter += 1
102102
if integrator.opts.save_idxs == nothing
103103
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u)
@@ -189,18 +189,17 @@ end
189189

190190
continuous_modified = false
191191
discrete_modified = false
192+
saved_in_cb = false
192193
if !(typeof(continuous_callbacks)<:Tuple{})
193194
time,upcrossing,idx,counter = find_first_continuous_callback(integrator,continuous_callbacks...)
194195
if time != zero(typeof(integrator.t)) && upcrossing != 0 # if not, then no events
195-
atleast_one_callback = true
196-
continuous_modified = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
196+
continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
197197
end
198198
end
199199
if !(typeof(discrete_callbacks)<:Tuple{})
200-
atleast_one_callback = true
201-
discrete_modified = apply_discrete_callback!(integrator,discrete_callbacks...)
200+
discrete_modified,saved_in_cb = apply_discrete_callback!(integrator,discrete_callbacks...)
202201
end
203-
if !atleast_one_callback
202+
if !saved_in_cb
204203
update_running_noise!(integrator)
205204
savevalues!(integrator)
206205
end

0 commit comments

Comments
 (0)