Skip to content

Commit 745b3ae

Browse files
full scalar noise
1 parent 34b7f37 commit 745b3ae

File tree

4 files changed

+47
-9
lines changed

4 files changed

+47
-9
lines changed

src/caches/rossler_caches.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ du_cache(c::SRA1Cache) = (c.chi2,c.E₁,c.E₂,c.gt,c.k₁,c.k₂,c.gpdt)
2020
user_cache(c::SRA1Cache) = (c.u,c.uprev,c.tmp,c.tmp1)
2121

2222
function alg_cache(alg::SRA1,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_prototype,uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{true}})
23-
chi2 = similar(ΔW)
23+
if typeof(ΔW) <: Union{SArray,Number}
24+
chi2 = copy(ΔW)
25+
else
26+
chi2 = similar(ΔW)
27+
end
2428
tmp1 = zeros(u)
2529
E₁ = zeros(rate_prototype); gt = zeros(noise_rate_prototype); gpdt = zeros(noise_rate_prototype)
2630
E₂ = zeros(rate_prototype); k₁ = zeros(rate_prototype); k₂ = zeros(rate_prototype)
@@ -81,7 +85,12 @@ function alg_cache(alg::SRA,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_prototype,u
8185
push!(H0,zeros(u))
8286
end
8387
A0temp = zeros(rate_prototype); B0temp = zeros(rate_prototype)
84-
ftmp = zeros(rate_prototype); gtmp = zeros(noise_rate_prototype); chi2 = similar(ΔW)
88+
ftmp = zeros(rate_prototype); gtmp = zeros(noise_rate_prototype);
89+
if typeof(ΔW) <: Union{SArray,Number}
90+
chi2 = copy(ΔW)
91+
else
92+
chi2 = similar(ΔW)
93+
end
8594
atemp = zeros(rate_prototype); btemp = zeros(rate_prototype); E₂ = zeros(rate_prototype); E₁temp = zeros(rate_prototype)
8695
E₁ = zeros(rate_prototype)
8796
tmp = zeros(u)

src/integrators/sra.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ end
6868
integrator.g(t,uprev,gt)
6969
integrator.g(t+dt,uprev,gpdt)
7070
integrator.f(t,uprev,k₁); k₁*=dt
71-
@. chi2 = (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
71+
if typeof(W.dW) <: Union{SArray,Number}
72+
chi2 = @. (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
73+
else
74+
@. chi2 = (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
75+
end
7276

7377
if is_diagonal_noise(integrator.sol.prob)
7478
@. E₁ = chi2*gpdt
@@ -155,7 +159,13 @@ end
155159
@unpack t,dt,uprev,u,W = integrator
156160
@unpack H0,A0temp,B0temp,ftmp,gtmp,chi2,atemp,btemp,E₁,E₁temp,E₂,tmp = cache
157161
@unpack c₀,c₁,A₀,B₀,α,β₁,β₂,stages = cache.tab
158-
@. chi2 = (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
162+
163+
if typeof(W.dW) <: Union{SArray,Number}
164+
chi2 = @. (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
165+
else
166+
@. chi2 = (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
167+
end
168+
159169
for i in 1:stages
160170
fill!(H0[i],zero(eltype(integrator.u)))
161171
end

src/integrators/sri.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,17 @@ end
5656
@unpack c₀,c₁,A₀,A₁,B₀,B₁,α,β₁,β₂,β₃,β₄,stages,error_terms = cache.tab
5757
@unpack H0,H1,A0temp,A1temp,B0temp,B1temp,A0temp2,A1temp2,B0temp2,B1temp2,atemp,btemp,E₁,E₂,E₁temp,ftemp,gtemp,chi1,chi2,chi3,tmp = cache
5858
@unpack t,dt,uprev,u,W = integrator
59-
@tight_loop_macros for i in eachindex(u)
60-
@inbounds chi1[i] = .5*(W.dW[i].^2 - dt)/integrator.sqdt #I_(1,1)/sqrt(h)
61-
@inbounds chi2[i] = .5*(W.dW[i] + W.dZ[i]/sqrt(3)) #I_(1,0)/h
62-
@inbounds chi3[i] = 1/6 * (W.dW[i].^3 - 3*W.dW[i]*dt)/dt #I_(1,1,1)/h
59+
60+
if typeof(W.dW) <: Union{SArray,Number}
61+
chi1 = @. (W.dW.^2 - dt)/2integrator.sqdt #I_(1,1)/sqrt(h)
62+
chi2 = @. (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
63+
chi3 = @. (W.dW.^3 - 3W.dW*dt)/6dt #I_(1,1,1)/h
64+
else
65+
@. chi1 = (W.dW.^2 - dt)/2integrator.sqdt #I_(1,1)/sqrt(h)
66+
@. chi2 = (W.dW + W.dZ/sqrt(3))/2 #I_(1,0)/h
67+
@. chi3 = (W.dW.^3 - 3W.dW*dt)/6dt #I_(1,1,1)/h
6368
end
69+
6470
for i=1:stages
6571
fill!(H0[i],zero(eltype(integrator.u)))
6672
fill!(H1[i],zero(eltype(integrator.u)))

test/scalar_noise.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,21 @@ using StochasticDiffEq, DiffEqNoiseProcess
33
f(t,u,du) = (du .= u)
44
g(t,u,du) = (du .= u)
55
u0 = rand(4,2)
6+
67
W = WienerProcess(0.0,0.0,0.0)
78
prob = SDEProblem(f,g,u0,(0.0,1.0),noise=W)
89
sol = solve(prob,SRIW1())
910

10-
similar(0.0)
11+
W = WienerProcess(0.0,0.0,0.0)
12+
prob = SDEProblem(f,g,u0,(0.0,1.0),noise=W)
13+
sol = solve(prob,SRI())
14+
15+
g(t,u,du) = (du .= 1)
16+
17+
W = WienerProcess(0.0,0.0,0.0)
18+
prob = SDEProblem(f,g,u0,(0.0,1.0),noise=W)
19+
sol = solve(prob,SRA1())
20+
21+
W = WienerProcess(0.0,0.0,0.0)
22+
prob = SDEProblem(f,g,u0,(0.0,1.0),noise=W)
23+
sol = solve(prob,SRA())

0 commit comments

Comments
 (0)