Skip to content

Commit 13e8f07

Browse files
fix non-diagonal adaptivity
1 parent 384ceea commit 13e8f07

File tree

6 files changed

+56
-28
lines changed

6 files changed

+56
-28
lines changed

src/caches/implicit_split_step_caches.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
mutable struct ISSEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
1+
mutable struct ISSEMCache{uType,rateType,J,JC,UF,
2+
uEltypeNoUnits,noiseRateType,F,dWType} <:
3+
StochasticDiffEqMutableCache
24
u::uType
35
uprev::uType
46
du1::rateType
@@ -18,6 +20,7 @@ mutable struct ISSEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F}
1820
κ::uEltypeNoUnits
1921
tol::uEltypeNoUnits
2022
newton_iters::Int
23+
dW_cache::dWType
2124
end
2225

2326
u_cache(c::ISSEMCache) = (c.uprev2,c.z,c.dz)
@@ -52,12 +55,14 @@ function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
5255

5356
if is_diagonal_noise(prob)
5457
gtmp2 = gtmp
58+
dW_cache = nothing
5559
else
5660
gtmp2 = similar(rate_prototype)
61+
dW_cache = similar(ΔW)
5762
end
5863

5964
ISSEMCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,J,W,jac_config,linsolve,uf,
60-
ηold,κ,tol,10000)
65+
ηold,κ,tol,10000,dW_cache)
6166
end
6267

6368
mutable struct ISSEMConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache
@@ -88,7 +93,9 @@ function alg_cache(alg::ISSEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototy
8893
ISSEMConstantCache(uf,ηold,κ,tol,100000)
8994
end
9095

91-
mutable struct ISSEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
96+
mutable struct ISSEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,
97+
noiseRateType,F,dWType} <:
98+
StochasticDiffEqMutableCache
9299
u::uType
93100
uprev::uType
94101
du1::rateType
@@ -109,6 +116,7 @@ mutable struct ISSEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRate
109116
κ::uEltypeNoUnits
110117
tol::uEltypeNoUnits
111118
newton_iters::Int
119+
dW_cache::dWType
112120
end
113121

114122
u_cache(c::ISSEulerHeunCache) = (c.uprev2,c.z,c.dz)
@@ -145,12 +153,14 @@ function alg_cache(alg::ISSEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_
145153

146154
if is_diagonal_noise(prob)
147155
gtmp3 = gtmp2
156+
dW_cache = nothing
148157
else
149158
gtmp3 = similar(noise_rate_prototype)
159+
dW_cache = similar(ΔW)
150160
end
151161

152162
ISSEulerHeunCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,gtmp3,
153-
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000)
163+
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000,dW_cache)
154164
end
155165

156166
mutable struct ISSEulerHeunConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache

src/caches/lamba_caches.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct LambaEMConstantCache <: StochasticDiffEqConstantCache end
2-
struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
2+
struct LambaEMCache{uType,rateType,rateNoiseType,dWType} <: StochasticDiffEqMutableCache
33
u::uType
44
uprev::uType
55
du1::rateType
@@ -8,6 +8,7 @@ struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCach
88
tmp::uType
99
L::rateType
1010
gtmp::rateNoiseType
11+
dW_cache::dWType
1112
end
1213

1314
u_cache(c::LambaEMCache) = ()
@@ -20,11 +21,16 @@ function alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_proto
2021
K = zeros(rate_prototype); tmp = similar(u);
2122
L = zeros(noise_rate_prototype)
2223
gtmp = zeros(noise_rate_prototype)
23-
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp)
24+
if is_diagonal_noise(prob)
25+
dW_cache = nothing
26+
else
27+
dW_cache = similar(ΔW)
28+
end
29+
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp,dW_cache)
2430
end
2531

2632
struct LambaEulerHeunConstantCache <: StochasticDiffEqConstantCache end
27-
struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
33+
struct LambaEulerHeunCache{uType,rateType,rateNoiseType,dWType} <: StochasticDiffEqMutableCache
2834
u::uType
2935
uprev::uType
3036
du1::rateType
@@ -33,6 +39,7 @@ struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMuta
3339
tmp::uType
3440
L::rateType
3541
gtmp::rateNoiseType
42+
dW_cache::dWType
3643
end
3744

3845
u_cache(c::LambaEulerHeunCache) = ()
@@ -45,5 +52,10 @@ function alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rat
4552
K = zeros(rate_prototype); tmp = similar(u);
4653
L = zeros(noise_rate_prototype)
4754
gtmp = zeros(noise_rate_prototype)
48-
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp)
55+
if is_diagonal_noise(prob)
56+
dW_cache = nothing
57+
else
58+
dW_cache = similar(ΔW)
59+
end
60+
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp,dW_cache)
4961
end

src/caches/sdirk_caches.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
mutable struct ImplicitEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
1+
mutable struct ImplicitEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F,dWType} <: StochasticDiffEqMutableCache
22
u::uType
33
uprev::uType
44
du1::rateType
@@ -18,6 +18,7 @@ mutable struct ImplicitEMCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateTy
1818
κ::uEltypeNoUnits
1919
tol::uEltypeNoUnits
2020
newton_iters::Int
21+
dW_cache::dWType
2122
end
2223

2324
u_cache(c::ImplicitEMCache) = (c.uprev2,c.z,c.dz)
@@ -52,12 +53,14 @@ function alg_cache(alg::ImplicitEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_pr
5253

5354
if is_diagonal_noise(prob)
5455
gtmp2 = gtmp
56+
dW_cache = nothing
5557
else
5658
gtmp2 = similar(rate_prototype)
59+
dW_cache = similar(ΔW)
5760
end
5861

5962
ImplicitEMCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,J,W,jac_config,linsolve,uf,
60-
ηold,κ,tol,10000)
63+
ηold,κ,tol,10000,dW_cache)
6164
end
6265

6366
mutable struct ImplicitEMConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache
@@ -88,7 +91,7 @@ function alg_cache(alg::ImplicitEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_pr
8891
ImplicitEMConstantCache(uf,ηold,κ,tol,100000)
8992
end
9093

91-
mutable struct ImplicitEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F} <: StochasticDiffEqMutableCache
94+
mutable struct ImplicitEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,noiseRateType,F,dWType} <: StochasticDiffEqMutableCache
9295
u::uType
9396
uprev::uType
9497
du1::rateType
@@ -109,6 +112,7 @@ mutable struct ImplicitEulerHeunCache{uType,rateType,J,JC,UF,uEltypeNoUnits,nois
109112
κ::uEltypeNoUnits
110113
tol::uEltypeNoUnits
111114
newton_iters::Int
115+
dW_cache::dWType
112116
end
113117

114118
u_cache(c::ImplicitEulerHeunCache) = (c.uprev2,c.z,c.dz)
@@ -145,12 +149,14 @@ function alg_cache(alg::ImplicitEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_
145149

146150
if is_diagonal_noise(prob)
147151
gtmp3 = gtmp2
152+
dW_cache = nothing
148153
else
149154
gtmp3 = similar(noise_rate_prototype)
155+
dW_cache = similar(ΔW)
150156
end
151157

152158
ImplicitEulerHeunCache(u,uprev,du1,fsalfirst,k,z,dz,tmp,gtmp,gtmp2,gtmp3,
153-
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000)
159+
J,W,jac_config,linsolve,uf,ηold,κ,tol,10000,dW_cache)
154160
end
155161

156162
mutable struct ImplicitEulerHeunConstantCache{F,uEltypeNoUnits} <: StochasticDiffEqConstantCache

src/perform_step/implicit_split_step.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ end
124124
ISSEulerHeunCache},
125125
f=integrator.f)
126126
@unpack t,dt,uprev,u,p = integrator
127-
@unpack uf,du1,dz,z,k,J,W,jac_config,gtmp,gtmp2,tmp = cache
127+
@unpack uf,du1,dz,z,k,J,W,jac_config,gtmp,gtmp2,tmp,dW_cache = cache
128128
integrator.alg.symplectic ? a = dt/2 : a = dt
129129
dW = integrator.W.dW
130130
mass_matrix = integrator.sol.prob.mass_matrix
@@ -308,8 +308,8 @@ end
308308
if !is_diagonal_noise(integrator.sol.prob)
309309
integrator.g(gtmp,z,p,t)
310310
g_sized2 = norm(gtmp,2)
311-
@. dz = dW.^2 - dt
312-
diff_tmp = integrator.opts.internalnorm(dz)
311+
@. dW_cache = dW.^2 - dt
312+
diff_tmp = integrator.opts.internalnorm(dW_cache)
313313
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
314314
@. dz = En
315315
else
@@ -324,8 +324,8 @@ end
324324
if !is_diagonal_noise(integrator.sol.prob)
325325
integrator.g(gtmp,z,p,t)
326326
g_sized2 = norm(gtmp,2)
327-
@. dz = dW.^2
328-
diff_tmp = integrator.opts.internalnorm(dz)
327+
@. dW_cache = dW.^2
328+
diff_tmp = integrator.opts.internalnorm(dW_cache)
329329
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
330330
@. dz = En
331331
else

src/perform_step/lamba.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
end
2929

3030
@muladd function perform_step!(integrator,cache::LambaEMCache,f=integrator.f)
31-
@unpack du1,du2,K,tmp,L,gtmp = cache
31+
@unpack du1,du2,K,tmp,L,gtmp,dW_cache = cache
3232
@unpack t,dt,uprev,u,W,p = integrator
3333

3434
integrator.f(du1,uprev,p,t)
@@ -61,8 +61,8 @@ end
6161
if !is_diagonal_noise(integrator.sol.prob)
6262
integrator.g(gtmp,tmp,p,t)
6363
g_sized2 = norm(gtmp,2)
64-
@. tmp = dW.^2 - dt
65-
diff_tmp = integrator.opts.internalnorm(tmp)
64+
@. dW_cache = dW.^2 - dt
65+
diff_tmp = integrator.opts.internalnorm(dW_cache)
6666
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
6767
@. tmp = En
6868
else
@@ -118,7 +118,7 @@ end
118118
end
119119

120120
@muladd function perform_step!(integrator,cache::LambaEulerHeunCache,f=integrator.f)
121-
@unpack du1,du2,K,tmp,L,gtmp = cache
121+
@unpack du1,du2,K,tmp,L,gtmp,dW_cache = cache
122122
@unpack t,dt,uprev,u,W,p = integrator
123123
integrator.f(du1,uprev,p,t)
124124
integrator.g(L,uprev,p,t)
@@ -168,8 +168,8 @@ end
168168
if !is_diagonal_noise(integrator.sol.prob)
169169
integrator.g(gtmp,tmp,p,t)
170170
g_sized2 = norm(gtmp,2)
171-
@. tmp = dW.^2
172-
diff_tmp = integrator.opts.internalnorm(tmp)
171+
@. dW_cache = dW.^2
172+
diff_tmp = integrator.opts.internalnorm(dW_cache)
173173
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
174174
@. tmp = En
175175
else

src/perform_step/sdirk.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ end
318318
# k is Ed
319319
# dz is En
320320
if typeof(cache) <: Union{ImplicitEMCache,ImplicitEulerHeunCache}
321-
321+
dW_cache = cache.dW_cache
322322
if !is_diagonal_noise(integrator.sol.prob)
323323
g_sized = norm(gtmp,2)
324324
else
@@ -331,8 +331,8 @@ end
331331
if !is_diagonal_noise(integrator.sol.prob)
332332
integrator.g(gtmp,z,p,t)
333333
g_sized2 = norm(gtmp,2)
334-
@. dz = dW.^2 - dt
335-
diff_tmp = integrator.opts.internalnorm(dz)
334+
@. dW_cache = dW.^2 - dt
335+
diff_tmp = integrator.opts.internalnorm(dW_cache)
336336
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
337337
@. dz = En
338338
else
@@ -347,8 +347,8 @@ end
347347
if !is_diagonal_noise(integrator.sol.prob)
348348
integrator.g(gtmp,z,p,t)
349349
g_sized2 = norm(gtmp,2)
350-
@. dz = dW.^2
351-
diff_tmp = integrator.opts.internalnorm(dz)
350+
@. dW_cache = dW.^2
351+
diff_tmp = integrator.opts.internalnorm(dW_cache)
352352
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
353353
@. dz = En
354354
else

0 commit comments

Comments
 (0)