Skip to content

Commit 9faa8be

Browse files
NoiseProcess API change
1 parent d63a805 commit 9faa8be

File tree

4 files changed

+48
-41
lines changed

4 files changed

+48
-41
lines changed

src/integrators/integrator_interface.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,20 +78,20 @@ end
7878

7979
@inline function fill_new_noise_caches!(integrator,c,scaling_factor,idxs)
8080
if isinplace(integrator.noise)
81-
integrator.noise(@view c[2][idxs])
81+
integrator.noise(@view(c[2][idxs]),integrator)
8282
for i in idxs
8383
c[2][i] *= scaling_factor
8484
end
8585
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
86-
integrator.noise(@view c[3][idxs])
86+
integrator.noise(@view(c[3][idxs]),integrator)
8787
for i in idxs
8888
c[3][i] .*= scaling_factor
8989
end
9090
end
9191
else
92-
c[2][idxs] .= scaling_factor.*integrator.noise(length(idxs))
92+
c[2][idxs] .= scaling_factor.*integrator.noise(length(idxs),integrator)
9393
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
94-
c[3][idxs] .= scaling_factor.*integrator.noise(length(idxs))
94+
c[3][idxs] .= scaling_factor.*integrator.noise(length(idxs),integrator)
9595
end
9696
end
9797
end

src/integrators/integrator_utils.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -445,34 +445,34 @@ end
445445

446446
@inline function update_noise!(integrator,scaling_factor=integrator.sqdt)
447447
if isinplace(integrator.noise)
448-
integrator.noise(integrator.ΔW)
448+
integrator.noise(integrator.ΔW,integrator)
449449
for i in eachindex(integrator.ΔW)
450450
integrator.ΔW[i] *= scaling_factor
451451
end
452452
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
453-
integrator.noise(integrator.ΔZ)
453+
integrator.noise(integrator.ΔZ,integrator)
454454
for i in eachindex(integrator.ΔW)
455455
integrator.ΔZ[i] .*= scaling_factor
456456
end
457457
end
458458
else
459459
if (typeof(integrator.u) <: AbstractArray)
460-
integrator.ΔW .= scaling_factor.*integrator.noise(size(integrator.u))
460+
integrator.ΔW .= scaling_factor.*integrator.noise(size(integrator.u),integrator)
461461
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
462-
integrator.ΔZ .= scaling_factor.*integrator.noise(size(integrator.u))
462+
integrator.ΔZ .= scaling_factor.*integrator.noise(size(integrator.u),integrator)
463463
end
464464
else
465-
integrator.ΔW = scaling_factor*integrator.noise()
465+
integrator.ΔW = scaling_factor*integrator.noise(integrator)
466466
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
467-
integrator.ΔZ = scaling_factor*integrator.noise()
467+
integrator.ΔZ = scaling_factor*integrator.noise(integrator)
468468
end
469469
end
470470
end
471471
end
472472

473473
@inline function generate_tildes(integrator,add1,add2,scaling)
474474
if isinplace(integrator.noise)
475-
integrator.noise(integrator.ΔWtilde)
475+
integrator.noise(integrator.ΔWtilde,integrator)
476476
if add1 != 0
477477
for i in eachindex(integrator.ΔW)
478478
integrator.ΔWtilde[i] = add1[i] + scaling*integrator.ΔWtilde[i]
@@ -483,7 +483,7 @@ end
483483
end
484484
end
485485
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
486-
integrator.noise(integrator.ΔZtilde)
486+
integrator.noise(integrator.ΔZtilde,integrator)
487487
if add2 != 0
488488
for i in eachindex(integrator.ΔW)
489489
integrator.ΔZtilde[i] = add2[i] + scaling*integrator.ΔZtilde[i]
@@ -497,21 +497,21 @@ end
497497
else
498498
if (typeof(integrator.u) <: AbstractArray)
499499
if add1 != 0
500-
integrator.ΔWtilde = add1 .+ scaling.*integrator.noise(size(integrator.u))
500+
integrator.ΔWtilde = add1 .+ scaling.*integrator.noise(size(integrator.u),integrator)
501501
else
502-
integrator.ΔWtilde = scaling.*integrator.noise(size(integrator.u))
502+
integrator.ΔWtilde = scaling.*integrator.noise(size(integrator.u),integrator)
503503
end
504504
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
505505
if add2 != 0
506-
integrator.ΔZtilde = add2 .+ scaling.*integrator.noise(size(integrator.u))
506+
integrator.ΔZtilde = add2 .+ scaling.*integrator.noise(size(integrator.u),integrator)
507507
else
508-
integrator.ΔZtilde = scaling.*integrator.noise(size(integrator.u))
508+
integrator.ΔZtilde = scaling.*integrator.noise(size(integrator.u),integrator)
509509
end
510510
end
511511
else
512-
integrator.ΔWtilde = add1 + scaling*integrator.noise()
512+
integrator.ΔWtilde = add1 + scaling*integrator.noise(integrator)
513513
if !(typeof(integrator.alg) <: EM) || !(typeof(integrator.alg) <: RKMil)
514-
integrator.ΔZtilde = add2 + scaling*integrator.noise()
514+
integrator.ΔZtilde = add2 + scaling*integrator.noise(integrator)
515515
end
516516
end
517517
end

src/solve.jl

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -254,30 +254,15 @@ function init{uType,tType,isinplace,NoiseClass,algType<:Union{AbstractRODEAlgori
254254

255255
Ws = Vector{randType}(0)
256256
if !(uType <: AbstractArray)
257-
W = 0.0
258-
Z = 0.0
259-
ΔW = sqdt*noise()
260-
ΔZ = sqdt*noise()
261-
if save_noise
262-
push!(Ws,W)
263-
end
257+
W = zero(randType)
258+
Z = zero(randType)
259+
ΔW= zero(randType)
260+
ΔZ= zero(randType)
264261
else
265262
W = zeros(rand_prototype)
266263
Z = zeros(rand_prototype)
267-
if DiffEqBase.isinplace(prob.noise)
268-
ΔW = similar(rand_prototype)
269-
ΔZ = similar(rand_prototype)
270-
noise(ΔW)
271-
noise(ΔZ)
272-
ΔW .*= sqdt
273-
ΔZ .*= sqdt
274-
else
275-
ΔW = sqdt.*noise(size(rand_prototype))
276-
ΔZ = sqdt.*noise(size(rand_prototype))
277-
end
278-
if save_noise
279-
push!(Ws,copy(W))
280-
end
264+
ΔW = similar(rand_prototype)
265+
ΔZ = similar(rand_prototype)
281266
end
282267

283268
S₁ = DataStructures.Stack{}(Tuple{typeof(t),typeof(W),typeof(Z)})
@@ -324,6 +309,28 @@ function init{uType,tType,isinplace,NoiseClass,algType<:Union{AbstractRODEAlgori
324309
cache,sqdt,W,Z,ΔW,ΔZ,copy(ΔW),copy(ΔZ),copy(ΔW),copy(ΔZ),
325310
opts,iter,prog,S₁,S₂,EEst,q,
326311
tTypeNoUnits(qoldinit),q11)
312+
313+
if !(uType <: AbstractArray)
314+
ΔW = sqdt*noise(integrator)
315+
ΔZ = sqdt*noise(integrator)
316+
if save_noise
317+
push!(Ws,W)
318+
end
319+
else
320+
if DiffEqBase.isinplace(prob.noise)
321+
noise(ΔW,integrator)
322+
noise(ΔZ,integrator)
323+
ΔW .*= sqdt
324+
ΔZ .*= sqdt
325+
else
326+
ΔW = sqdt.*noise(size(rand_prototype),integrator)
327+
ΔZ = sqdt.*noise(size(rand_prototype),integrator)
328+
end
329+
if save_noise
330+
push!(Ws,copy(W))
331+
end
332+
end
333+
327334
if initialize_integrator
328335
initialize!(integrator,integrator.cache)
329336
initialize!(callbacks_internal,t,u,integrator)

test/sde/sde_convergence_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ dts = 1./2.^(10:-1:2) #14->7 good plot
55
prob = prob_sde_wave
66
sim = test_convergence(dts,prob,EM(),numMonte=Int(1e1))
77
@test abs(sim.𝒪est[:l2]-.5) < 0.1
8-
sim2 = test_convergence(dts,prob,RKMil(),numMonte=Int(1e2))
9-
@test abs(sim2.𝒪est[:l∞]-1) < 0.1
8+
sim2 = test_convergence(dts,prob,RKMil(),numMonte=Int(2e2))
9+
@test abs(sim2.𝒪est[:l∞]-1) < 0.2
1010
sim3 = test_convergence(dts,prob,SRI(),numMonte=Int(1e1))
1111
@test abs(sim3.𝒪est[:final]-1.5) < 0.3
1212
sim4 = test_convergence(dts,prob,SRIW1(),numMonte=Int(1e1))

0 commit comments

Comments
 (0)