Skip to content

Commit f6b3cc9

Browse files
caches work with multiscale resize
1 parent 25fe723 commit f6b3cc9

File tree

2 files changed

+38
-33
lines changed

2 files changed

+38
-33
lines changed

src/caches.jl

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -230,59 +230,58 @@ end
230230
immutable SRIW1ConstantCache <: StochasticDiffEqConstantCache end
231231
alg_cache(alg::SRIW1,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = SRIW1ConstantCache()
232232

233-
immutable SRIW1Cache{randType,uType} <: StochasticDiffEqMutableCache
233+
immutable SRIW1Cache{randType,uType,rateType} <: StochasticDiffEqMutableCache
234234
u::uType
235235
uprev::uType
236236
chi1::randType
237237
chi2::randType
238238
chi3::randType
239-
fH01o4::uType
240-
g₁o2::uType
239+
fH01o4::rateType
240+
g₁o2::rateType
241241
H0::uType
242242
H11::uType
243243
H12::uType
244244
H13::uType
245-
g₂o3::uType
246-
Fg₂o3::uType
247-
g₃o3::uType
248-
Tg₃o3::uType
249-
mg₁::uType
250-
E₁::uType
251-
E₂::uType
252-
fH01::uType
253-
fH02::uType
254-
g₁::uType
255-
g₂::uType
256-
g₃::uType
257-
g₄::uType
245+
g₂o3::rateType
246+
Fg₂o3::rateType
247+
g₃o3::rateType
248+
Tg₃o3::rateType
249+
mg₁::rateType
250+
E₁::rateType
251+
E₂::rateType
252+
fH01::rateType
253+
fH02::rateType
254+
g₁::rateType
255+
g₂::rateType
256+
g₃::rateType
257+
g₄::rateType
258258
tmp::uType
259259
end
260260

261261
u_cache(c::SRIW1Cache) = ()
262-
du_cache(c::SRIW1Cache) = (c.chi1,c.chi2,c.chi3,c.fH01o4,c.g₁o2,c.H0,c.H11,
263-
c.H12,c.H13,c.g₂o3,c.Fg₂o3,c.g₃o3,c.Tg₃o3,c.mg₁,
262+
du_cache(c::SRIW1Cache) = (c.chi1,c.chi2,c.chi3,c.fH01o4,c.g₁o2,c.g₂o3,c.Fg₂o3,c.g₃o3,c.Tg₃o3,c.mg₁,
264263
c.E₁,c.E₂,c.fH01,c.fH02,c.g₁,c.g₂,c.g₃,c.g₄)
265-
264+
user_cache(c::SRIW1Cache) = (c.u,c.uprev,c.tmp,c.H0,c.H11,c.H12,c.H13)
266265

267266
function alg_cache(alg::SRIW1,u,ΔW,ΔZ,rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
268267
chi1 = similar(ΔW)
269268
chi2 = similar(ΔW)
270269
chi3 = similar(ΔW)
271-
fH01o4 = zeros(uprev)
272-
g₁o2 = zeros(u)
270+
fH01o4 = zeros(rate_prototype)
271+
g₁o2 = zeros(rate_prototype)
273272
H0 = zeros(u)
274273
H11 = zeros(u)
275274
H12 = zeros(u)
276275
H13 = zeros(u)
277-
g₂o3 = zeros(u)
278-
Fg₂o3 = zeros(u)
279-
g₃o3 = zeros(u)
280-
Tg₃o3 = zeros(u)
281-
mg₁ = zeros(u)
282-
E₁ = zeros(u)
283-
E₂ = zeros(u)
284-
fH01 = zeros(u); fH02 = zeros(u)
285-
g₁ = zeros(u); g₂ = zeros(u); g₃ = zeros(u); g₄ = zeros(u)
276+
g₂o3 = zeros(rate_prototype)
277+
Fg₂o3 = zeros(rate_prototype)
278+
g₃o3 = zeros(rate_prototype)
279+
Tg₃o3 = zeros(rate_prototype)
280+
mg₁ = zeros(rate_prototype)
281+
E₁ = zeros(rate_prototype)
282+
E₂ = zeros(rate_prototype)
283+
fH01 = zeros(rate_prototype); fH02 = zeros(rate_prototype)
284+
g₁ = zeros(rate_prototype); g₂ = zeros(rate_prototype); g₃ = zeros(rate_prototype); g₄ = zeros(rate_prototype)
286285
tmp = zeros(u)
287286
SRIW1Cache(u,uprev,chi1,chi2,chi3,fH01o4,g₁o2,H0,H11,H12,H13,g₂o3,Fg₂o3,g₃o3,Tg₃o3,mg₁,E₁,E₂,fH01,fH02,g₁,g₂,g₃,g₄,tmp)
288287
end

src/integrators/integrator_interface.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,11 @@ end
2626

2727
(integrator::SDEIntegrator)(val::AbstractArray,t::Union{Number,AbstractArray},deriv::Type=Val{0};idxs=eachindex(integrator.uprev)) = current_interpolant!(val,t,integrator,idxs,deriv)
2828

29-
user_cache(integrator::SDEIntegrator) = (integrator.cache.u,integrator.cache.uprev,integrator.cache.tmp)
29+
30+
user_cache(integrator::SDEIntegrator) = user_cache(integrator.cache)
3031
u_cache(integrator::SDEIntegrator) = u_cache(integrator.cache)
3132
du_cache(integrator::SDEIntegrator)= du_cache(integrator.cache)
33+
user_cache(c::StochasticDiffEqCache) = (c.u,c.uprev,c.tmp)
3234
full_cache(integrator::SDEIntegrator) = chain(user_cache(integrator),u_cache(integrator),du_cache(integrator.cache))
3335
default_non_user_cache(integrator::SDEIntegrator) = chain(u_cache(integrator),du_cache(integrator.cache))
3436
@inline add_tstop!(integrator::SDEIntegrator,t) = push!(integrator.opts.tstops,t)
@@ -37,10 +39,13 @@ resize_non_user_cache!(integrator::SDEIntegrator,i::Int) = resize_non_user_cache
3739
resize!(integrator::SDEIntegrator,i::Int) = resize!(integrator,integrator.cache,i)
3840

3941
function resize!(integrator::SDEIntegrator,cache,i)
40-
prev_len = length(integrator.u)
42+
resize_non_user_cache!(integrator,cache,i)
4143
for c in user_cache(integrator)
4244
resize!(c,i)
4345
end
46+
end
47+
48+
function resize_noise!(integrator,cache,prev_len,i)
4449
for c in integrator.S₁
4550
resize!(c[2],i)
4651
resize!(c[3],i)
@@ -67,10 +72,11 @@ function resize!(integrator::SDEIntegrator,cache,i)
6772
fill!(@view(integrator.W[prev_len:i]),zero(eltype(integrator.u)))
6873
fill!(@view(integrator.Z[prev_len:i]),zero(eltype(integrator.u)))
6974
end
70-
resize_non_user_cache!(integrator,cache,i)
7175
end
7276

7377
function resize_non_user_cache!(integrator::SDEIntegrator,cache,i)
78+
prev_len = length(integrator.u)
79+
resize_noise!(integrator,cache,prev_len,i)
7480
for c in default_non_user_cache(integrator)
7581
resize!(c,i)
7682
end

0 commit comments

Comments
 (0)