Skip to content

Commit 1ef7813

Browse files
LambaEulerHeun passes
1 parent 4c9570d commit 1ef7813

File tree

6 files changed

+235
-146
lines changed

6 files changed

+235
-146
lines changed

src/StochasticDiffEq.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ module StochasticDiffEq
4646
include("interp_func.jl")
4747
include("caches/cache_types.jl")
4848
include("caches/basic_method_caches.jl")
49+
include("caches/lamba_caches.jl")
4950
include("caches/iif_caches.jl")
5051
include("caches/sdirk_caches.jl")
5152
include("caches/sra_caches.jl")
@@ -62,6 +63,7 @@ module StochasticDiffEq
6263
include("solve.jl")
6364
include("initdt.jl")
6465
include("perform_step/low_order.jl")
66+
include("perform_step/lamba.jl")
6567
include("perform_step/iif.jl")
6668
include("perform_step/sri.jl")
6769
include("perform_step/sra.jl")

src/caches/basic_method_caches.jl

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -92,56 +92,6 @@ function alg_cache(alg::RandomEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prot
9292
RandomEMCache(u,uprev,tmp,rtmp)
9393
end
9494

95-
struct LambaEMConstantCache <: StochasticDiffEqConstantCache end
96-
struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
97-
u::uType
98-
uprev::uType
99-
du1::rateType
100-
du2::rateType
101-
K::rateType
102-
tmp::uType
103-
L::rateType
104-
gtmp::rateNoiseType
105-
end
106-
107-
u_cache(c::LambaEMCache) = ()
108-
du_cache(c::LambaEMCache) = (c.du1,c.du2,c.K,c.L)
109-
110-
alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = LambaEMConstantCache()
111-
112-
function alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
113-
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
114-
K = zeros(rate_prototype); tmp = similar(u);
115-
L = zeros(noise_rate_prototype)
116-
gtmp = zeros(noise_rate_prototype)
117-
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp)
118-
end
119-
120-
struct LambaEulerHeunConstantCache <: StochasticDiffEqConstantCache end
121-
struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
122-
u::uType
123-
uprev::uType
124-
du1::rateType
125-
du2::rateType
126-
K::rateType
127-
tmp::uType
128-
L::rateType
129-
gtmp::rateNoiseType
130-
end
131-
132-
u_cache(c::LambaEulerHeunCache) = ()
133-
du_cache(c::LambaEulerHeunCache) = (c.du1,c.du2,c.K,c.L)
134-
135-
alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = LambaEulerHeunConstantCache()
136-
137-
function alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
138-
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
139-
K = zeros(rate_prototype); tmp = similar(u);
140-
L = zeros(noise_rate_prototype)
141-
gtmp = zeros(noise_rate_prototype)
142-
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp)
143-
end
144-
14595
struct RKMilConstantCache <: StochasticDiffEqConstantCache end
14696
struct RKMilCache{uType,rateType} <: StochasticDiffEqMutableCache
14797
u::uType

src/caches/lamba_caches.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
struct LambaEMConstantCache <: StochasticDiffEqConstantCache end
2+
struct LambaEMCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
3+
u::uType
4+
uprev::uType
5+
du1::rateType
6+
du2::rateType
7+
K::rateType
8+
tmp::uType
9+
L::rateType
10+
gtmp::rateNoiseType
11+
end
12+
13+
u_cache(c::LambaEMCache) = ()
14+
du_cache(c::LambaEMCache) = (c.du1,c.du2,c.K,c.L)
15+
16+
alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = LambaEMConstantCache()
17+
18+
function alg_cache(alg::LambaEM,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
19+
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
20+
K = zeros(rate_prototype); tmp = similar(u);
21+
L = zeros(noise_rate_prototype)
22+
gtmp = zeros(noise_rate_prototype)
23+
LambaEMCache(u,uprev,du1,du2,K,tmp,L,gtmp)
24+
end
25+
26+
struct LambaEulerHeunConstantCache <: StochasticDiffEqConstantCache end
27+
struct LambaEulerHeunCache{uType,rateType,rateNoiseType} <: StochasticDiffEqMutableCache
28+
u::uType
29+
uprev::uType
30+
du1::rateType
31+
du2::rateType
32+
K::rateType
33+
tmp::uType
34+
L::rateType
35+
gtmp::rateNoiseType
36+
end
37+
38+
u_cache(c::LambaEulerHeunCache) = ()
39+
du_cache(c::LambaEulerHeunCache) = (c.du1,c.du2,c.K,c.L)
40+
41+
alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{false}}) = LambaEulerHeunConstantCache()
42+
43+
function alg_cache(alg::LambaEulerHeun,prob,u,ΔW,ΔZ,p,rate_prototype,noise_rate_prototype,uEltypeNoUnits,uBottomEltype,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
44+
du1 = zeros(rate_prototype); du2 = zeros(rate_prototype)
45+
K = zeros(rate_prototype); tmp = similar(u);
46+
L = zeros(noise_rate_prototype)
47+
gtmp = zeros(noise_rate_prototype)
48+
LambaEulerHeunCache(u,uprev,du1,du2,K,tmp,L,gtmp)
49+
end

src/perform_step/lamba.jl

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
@muladd function perform_step!(integrator,cache::LambaEMConstantCache,f=integrator.f)
2+
@unpack t,dt,uprev,u,W,p = integrator
3+
du1 = integrator.f(uprev,p,t)
4+
K = @muladd uprev + dt*du1
5+
L = integrator.g(uprev,p,t)
6+
mil_correction = zero(u)
7+
8+
u = K+L*W.dW
9+
10+
if integrator.opts.adaptive
11+
du2 = integrator.f(K,p,t+dt)
12+
Ed = dt*(du2 - du1)/2
13+
14+
utilde = K + L*integrator.sqdt
15+
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
16+
En = ggprime.*(W.dW.^2 .- dt)./2
17+
18+
integrator.EEst = integrator.opts.internalnorm((Ed + En)/((integrator.opts.abstol + max.(abs(uprev),abs(u))*integrator.opts.reltol)))
19+
end
20+
21+
integrator.u = u
22+
end
23+
24+
@muladd function perform_step!(integrator,cache::LambaEMCache,f=integrator.f)
25+
@unpack du1,du2,K,tmp,L,gtmp = cache
26+
@unpack t,dt,uprev,u,W,p = integrator
27+
integrator.f(du1,uprev,p,t)
28+
integrator.g(L,uprev,p,t)
29+
@. K = @muladd uprev + dt*du1
30+
31+
if is_diagonal_noise(integrator.sol.prob)
32+
@. tmp=L*W.dW
33+
else
34+
A_mul_B!(tmp,L,W.dW)
35+
end
36+
37+
@. u = K+tmp
38+
39+
if integrator.opts.adaptive
40+
41+
if !is_diagonal_noise(integrator.sol.prob)
42+
g_sized = norm(L,2)
43+
else
44+
g_sized = L
45+
end
46+
47+
@. tmp = @muladd K + L*integrator.sqdt
48+
49+
if !is_diagonal_noise(integrator.sol.prob)
50+
integrator.g(gtmp,tmp,p,t)
51+
g_sized2 = norm(gtmp,2)
52+
@. tmp = dW.^2 - dt
53+
diff_tmp = integrator.opts.internalnorm(tmp)
54+
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
55+
@. tmp = En
56+
else
57+
integrator.g(gtmp,tmp,p,t)
58+
@. tmp = (gtmp-L)/(2integrator.sqdt)*(W.dW.^2 - dt)
59+
end
60+
61+
# Ed
62+
integrator.f(du2,K,p,t+dt)
63+
@. tmp += integrator.opts.internalnorm(dt*(du2 - du1)/2)
64+
65+
66+
@tight_loop_macros for (i,atol,rtol) in zip(eachindex(u),Iterators.cycle(integrator.opts.abstol),Iterators.cycle(integrator.opts.reltol))
67+
@inbounds tmp[i] = (tmp[i])/(atol + max(abs(uprev[i]),abs(u[i]))*rtol)
68+
end
69+
integrator.EEst = integrator.opts.internalnorm(tmp)
70+
end
71+
end
72+
73+
@muladd function perform_step!(integrator,cache::LambaEulerHeunConstantCache,f=integrator.f)
74+
@unpack t,dt,uprev,u,W,p = integrator
75+
du1 = integrator.f(uprev,p,t)
76+
K = @muladd uprev + dt*du1
77+
L = integrator.g(uprev,p,t)
78+
79+
if is_diagonal_noise(integrator.sol.prob)
80+
noise = L.*W.dW
81+
else
82+
noise = L*W.dW
83+
end
84+
tmp = @muladd K+L*W.dW
85+
gtmp2 = (1/2).*(L.+integrator.g(tmp,p,t+dt))
86+
if is_diagonal_noise(integrator.sol.prob)
87+
noise2 = gtmp2.*W.dW
88+
else
89+
noise2 = gtmp2*W.dW
90+
end
91+
92+
u = @muladd uprev + (1/2)*dt*(du1+integrator.f(tmp,p,t+dt)) + noise2
93+
94+
if integrator.opts.adaptive
95+
du2 = integrator.f(K,p,t+dt)
96+
Ed = dt*(du2 - du1)/2
97+
98+
utilde = uprev + L*integrator.sqdt
99+
ggprime = (integrator.g(utilde,p,t).-L)./(integrator.sqdt)
100+
En = ggprime.*(W.dW.^2)./2
101+
102+
integrator.EEst = integrator.opts.internalnorm((Ed + En)/((integrator.opts.abstol + max.(abs(uprev),abs(u))*integrator.opts.reltol)))
103+
end
104+
105+
integrator.u = u
106+
end
107+
108+
@muladd function perform_step!(integrator,cache::LambaEulerHeunCache,f=integrator.f)
109+
@unpack du1,du2,K,tmp,L,gtmp = cache
110+
@unpack t,dt,uprev,u,W,p = integrator
111+
integrator.f(du1,uprev,p,t)
112+
integrator.g(L,uprev,p,t)
113+
@. K = @muladd uprev + dt*du1
114+
115+
if is_diagonal_noise(integrator.sol.prob)
116+
@. tmp=L*W.dW
117+
else
118+
A_mul_B!(tmp,L,W.dW)
119+
end
120+
121+
@. tmp = K+tmp
122+
123+
integrator.f(du2,tmp,p,t+dt)
124+
integrator.g(gtmp,tmp,p,t+dt)
125+
126+
if is_diagonal_noise(integrator.sol.prob)
127+
#@. nrtmp=(1/2)*W.dW*(gtmp1+gtmp2)
128+
@tight_loop_macros for i in eachindex(u)
129+
@inbounds dWo2 = (1/2)*W.dW[i]
130+
@inbounds tmp[i]=dWo2*(L[i]+gtmp[i])
131+
end
132+
else
133+
#@. gtmp1 = (1/2)*(gtmp1+gtmp2)
134+
@tight_loop_macros for i in eachindex(gtmp)
135+
@inbounds gtmp[i] = (1/2)*(L[i]+gtmp[i])
136+
end
137+
A_mul_B!(tmp,gtmp,W.dW)
138+
end
139+
140+
dto2 = dt*(1/2)
141+
#@. u = @muladd uprev + dto2*(ftmp1+ftmp2) + nrtmp
142+
@tight_loop_macros for i in eachindex(u)
143+
@inbounds u[i] = @muladd uprev[i] + dto2*(du1[i]+du2[i]) + tmp[i]
144+
end
145+
146+
if integrator.opts.adaptive
147+
148+
if !is_diagonal_noise(integrator.sol.prob)
149+
g_sized = norm(L,2)
150+
else
151+
g_sized = L
152+
end
153+
154+
@. tmp = @muladd uprev + L*integrator.sqdt
155+
156+
if !is_diagonal_noise(integrator.sol.prob)
157+
integrator.g(gtmp,tmp,p,t)
158+
g_sized2 = norm(gtmp,2)
159+
@. tmp = dW.^2
160+
diff_tmp = integrator.opts.internalnorm(tmp)
161+
En = (g_sized2-g_sized)/(2integrator.sqdt)*diff_tmp
162+
@. tmp = En
163+
else
164+
integrator.g(gtmp,tmp,p,t)
165+
@. tmp = (gtmp-L)/(2integrator.sqdt)*(W.dW.^2)
166+
end
167+
168+
# Ed
169+
integrator.f(du2,K,p,t+dt)
170+
@. tmp += integrator.opts.internalnorm(dt*(du2 - du1)/2)
171+
172+
173+
@tight_loop_macros for (i,atol,rtol) in zip(eachindex(u),Iterators.cycle(integrator.opts.abstol),Iterators.cycle(integrator.opts.reltol))
174+
@inbounds tmp[i] = (tmp[i])/(atol + max(abs(uprev[i]),abs(u[i]))*rtol)
175+
end
176+
integrator.EEst = integrator.opts.internalnorm(tmp)
177+
end
178+
end

0 commit comments

Comments
 (0)