Skip to content

Commit a6c5fc3

Browse files
addat deleteat
1 parent f6b3cc9 commit a6c5fc3

File tree

3 files changed

+113
-35
lines changed

3 files changed

+113
-35
lines changed

src/StochasticDiffEq.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ module StochasticDiffEq
1111

1212
import DiffEqBase: solve, solve!, init, step!, build_solution
1313

14-
import DiffEqBase: resize!,deleteat!,full_cache,user_cache, u_cache,du_cache,
15-
resize_non_user_cache!,
14+
import DiffEqBase: resize!,deleteat!,addat!,full_cache,user_cache, u_cache,du_cache,
15+
resize_non_user_cache!,deleteat_non_user_cache!,addat_non_user_cache!,
1616
terminate!,get_du, get_dt,get_proposed_dt,modify_proposed_dt!,
1717
u_modified!,savevalues!,add_tstop!,add_saveat!,set_reltol!,
1818
set_abstol!

src/integrators/integrator_interface.jl

Lines changed: 111 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ default_non_user_cache(integrator::SDEIntegrator) = chain(u_cache(integrator),du
3636
@inline add_tstop!(integrator::SDEIntegrator,t) = push!(integrator.opts.tstops,t)
3737

3838
resize_non_user_cache!(integrator::SDEIntegrator,i::Int) = resize_non_user_cache!(integrator,integrator.cache,i)
39+
deleteat_non_user_cache!(integrator::SDEIntegrator,i) = deleteat_non_user_cache!(integrator,integrator.cache,i)
40+
addat_non_user_cache!(integrator::SDEIntegrator,i) = addat_non_user_cache!(integrator,integrator.cache,i)
3941
resize!(integrator::SDEIntegrator,i::Int) = resize!(integrator,integrator.cache,i)
4042

4143
function resize!(integrator::SDEIntegrator,cache,i)
@@ -45,19 +47,19 @@ function resize!(integrator::SDEIntegrator,cache,i)
4547
end
4648
end
4749

48-
function resize_noise!(integrator,cache,prev_len,i)
50+
function resize_noise!(integrator,cache,bot_idx,i)
4951
for c in integrator.S₁
5052
resize!(c[2],i)
5153
resize!(c[3],i)
52-
if i > prev_len # fill in rands
53-
resize_noise_caches!(integrator,c,c[1],prev_len:i)
54+
if i > bot_idx # fill in rands
55+
fill_new_noise_caches!(integrator,c,sqrt(c[1]),bot_idx:i)
5456
end
5557
end
5658
for c in integrator.S₂
5759
resize!(c[2],i)
5860
resize!(c[3],i)
59-
if i > prev_len # fill in rands
60-
resize_noise_caches!(integrator,c,c[1],prev_len:i)
61+
if i > bot_idx # fill in rands
62+
fill_new_noise_caches!(integrator,c,sqrt(c[1]),bot_idx:i)
6163
end
6264
end
6365
resize!(integrator.ΔW,i)
@@ -68,26 +70,122 @@ function resize_noise!(integrator,cache,prev_len,i)
6870
resize!(integrator.ΔZtmp,i)
6971
resize!(integrator.W,i)
7072
resize!(integrator.Z,i)
71-
if i > prev_len # fill in rands
72-
fill!(@view(integrator.W[prev_len:i]),zero(eltype(integrator.u)))
73-
fill!(@view(integrator.Z[prev_len:i]),zero(eltype(integrator.u)))
73+
if i > bot_idx # fill in rands
74+
fill!(@view(integrator.W[bot_idx:i]),zero(eltype(integrator.u)))
75+
fill!(@view(integrator.Z[bot_idx:i]),zero(eltype(integrator.u)))
76+
end
77+
end
78+
79+
@inline function fill_new_noise_caches!(integrator,c,scaling_factor,idxs)
80+
if isinplace(integrator.noise)
81+
integrator.noise(@view c[2][idxs])
82+
for i in idxs
83+
c[2][i] *= scaling_factor
84+
end
85+
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
86+
integrator.noise(@view c[3][idxs])
87+
for i in idxs
88+
c[3][i] .*= scaling_factor
89+
end
90+
end
91+
else
92+
c[2][idxs] .= scaling_factor.*integrator.noise(length(idxs))
93+
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
94+
c[3][idxs] .= scaling_factor.*integrator.noise(length(idxs))
95+
end
7496
end
7597
end
7698

7799
function resize_non_user_cache!(integrator::SDEIntegrator,cache,i)
78-
prev_len = length(integrator.u)
79-
resize_noise!(integrator,cache,prev_len,i)
100+
bot_idx = length(integrator.u)
101+
resize_noise!(integrator,cache,bot_idx,i)
80102
for c in default_non_user_cache(integrator)
81103
resize!(c,i)
82104
end
83105
end
84106

85-
function deleteat!(integrator::SDEIntegrator,i::Int)
86-
for c in full_cache(integrator)
87-
deleteat!(c,i)
107+
function deleteat!(integrator::SDEIntegrator,idxs)
108+
deleteat_non_user_cache!(integrator,cache,idxs)
109+
for c in user_cache(integrator)
110+
deleteat!(c,idxs)
111+
end
112+
end
113+
114+
function addat!(integrator::SDEIntegrator,idxs)
115+
addat_non_user_cache!(integrator,cache,idxs)
116+
for c in user_cache(integrator)
117+
addat!(c,idxs)
118+
end
119+
end
120+
121+
function deleteat_non_user_cache!(integrator::SDEIntegrator,cache,idxs)
122+
deleteat_noise!(integrator,cache,idxs)
123+
i = length(integrator.u)
124+
# Ordering doesn't matter in these caches
125+
# So just resize
126+
for c in default_non_user_cache(integrator)
127+
resize!(c,i)
128+
end
129+
end
130+
131+
function addat_non_user_cache!(integrator::SDEIntegrator,cache,idxs)
132+
addat_noise!(integrator,cache,idxs)
133+
i = length(integrator.u)
134+
# Ordering doesn't matter in these caches
135+
# So just resize
136+
for c in default_non_user_cache(integrator)
137+
resize!(c,i)
88138
end
89139
end
90140

141+
function deleteat_noise!(integrator,cache,idxs)
142+
for c in integrator.S₁
143+
deleteat!(c[2],idxs)
144+
deleteat!(c[3],idxs)
145+
end
146+
for c in integrator.S₂
147+
deleteat!(c[2],idxs)
148+
deleteat!(c[3],idxs)
149+
end
150+
deleteat!(integrator.ΔW,idxs)
151+
deleteat!(integrator.ΔZ,idxs)
152+
deleteat!(integrator.ΔWtilde,idxs)
153+
deleteat!(integrator.ΔZtilde,idxs)
154+
deleteat!(integrator.ΔWtmp,idxs)
155+
deleteat!(integrator.ΔZtmp,idxs)
156+
deleteat!(integrator.W,idxs)
157+
deleteat!(integrator.Z,idxs)
158+
end
159+
160+
function addat_noise!(integrator,cache,idxs)
161+
for c in integrator.S₁
162+
addat!(c[2],idxs)
163+
addat!(c[3],idxs)
164+
fill_new_noise_caches!(integrator,c,sqrt(c[1]),idxs)
165+
end
166+
for c in integrator.S₂
167+
addat!(c[2],idxs)
168+
addat!(c[3],idxs)
169+
fill_new_noise_caches!(integrator,c,sqrt(c[1]),idxs)
170+
end
171+
172+
addat!(integrator.ΔW,idxs)
173+
addat!(integrator.ΔZ,idxs)
174+
addat!(integrator.W,idxs)
175+
addat!(integrator.Z,idxs)
176+
177+
i = length(integrator.u)
178+
resize!(integrator.ΔWtilde,i)
179+
resize!(integrator.ΔZtilde,i)
180+
resize!(integrator.ΔWtmp,i)
181+
resize!(integrator.ΔZtmp,i)
182+
183+
# fill in rands
184+
fill!(@view(integrator.W[idxs]),zero(eltype(integrator.u)))
185+
fill!(@view(integrator.Z[idxs]),zero(eltype(integrator.u)))
186+
end
187+
188+
91189
function terminate!(integrator::SDEIntegrator)
92190
integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)()
93191
end

src/integrators/integrator_utils.jl

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -434,26 +434,6 @@ end
434434
end
435435
end
436436

437-
@inline function resize_noise_caches!(integrator,c,scaling_factor,idxs)
438-
if isinplace(integrator.noise)
439-
integrator.noise(@view c[2][idxs])
440-
for i in idxs
441-
c[2][i] *= scaling_factor
442-
end
443-
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
444-
integrator.noise(@view c[3][idxs])
445-
for i in idxs
446-
c[3][i] .*= scaling_factor
447-
end
448-
end
449-
else
450-
c[2][idxs] .= scaling_factor.*integrator.noise(length(idxs))
451-
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
452-
c[3][idxs] .= scaling_factor.*integrator.noise(length(idxs))
453-
end
454-
end
455-
end
456-
457437
@inline function generate_tildes(integrator,add1,add2,scaling)
458438
if isinplace(integrator.noise)
459439
integrator.noise(integrator.ΔWtilde)

0 commit comments

Comments
 (0)