Skip to content

Commit b284df8

Browse files
Merge pull request #509 from rmsrosa/use-copy-for-noise
Use copy for noise
2 parents 119f70d + ac8a903 commit b284df8

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

src/solve.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem,
66
integrator = DiffEqBase.__init(prob,alg,timeseries,ts,recompile;kwargs...)
77
solve!(integrator)
88
if typeof(prob) <: DiffEqBase.AbstractRODEProblem && typeof(prob.noise) == typeof(integrator.sol.W) && (!haskey(kwargs, :alias_noise) || kwargs[:alias_noise] === true)
9-
# would be better to make the following a function `noise_deepcopy!(W::T, Z::T) where {T <: AbstractNoiseProcess}` in `DiffEqNoiseProcess.jl` or a proper `copy` overload, but this should do it for the moment
10-
for x in fieldnames(typeof(prob.noise))
11-
setfield!(prob.noise, x, deepcopy(getfield(integrator.sol.W, x)))
12-
end
9+
copy!(prob.noise, integrator.sol.W)
1310
end
1411
integrator.sol
1512
end
@@ -416,7 +413,7 @@ function DiffEqBase.__init(
416413
=#
417414
end
418415
elseif typeof(prob) <: DiffEqBase.AbstractRODEProblem
419-
W = (!haskey(kwargs, :alias_noise) || kwargs[:alias_noise] === true) ? deepcopy(prob.noise) : prob.noise
416+
W = (!haskey(kwargs, :alias_noise) || kwargs[:alias_noise] === true) ? copy(prob.noise) : prob.noise
420417
if W.reset
421418
# Reseed
422419
if typeof(W) <: Union{NoiseProcess, NoiseTransport} && W.reseed

test/noise_type_test.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,19 @@ sol = solve(prob,EM(),dt=1/100)
4646

4747
@test sol.W == prob.noise
4848
@test objectid(prob.noise) != objectid(sol.W)
49+
@test objectid(prob.noise.u) == objectid(prob.noise.W) != objectid(sol.W.W) == objectid(sol.W.u)
4950

5051
sol = solve(prob,EM(),dt=1/1000,alias_noise=false)
5152

5253
@test sol.W == prob.noise
5354
@test objectid(prob.noise) == objectid(sol.W)
55+
@test objectid(prob.noise.u) == objectid(prob.noise.W) == objectid(sol.W.W) == objectid(sol.W.u)
5456

5557
sol = solve(prob,EM(),dt=1/1000, alias_noise=true)
5658

5759
@test sol.W == prob.noise
5860
@test objectid(prob.noise) != objectid(sol.W)
61+
@test objectid(prob.noise.u) == objectid(prob.noise.W) != objectid(sol.W.W) == objectid(sol.W.u)
5962

6063
function g(du,u,p,t)
6164
@test typeof(du) <: SparseMatrixCSC
@@ -96,6 +99,9 @@ sol = solve(prob,EM(),dt=0.01)
9699
@test typeof(sol.W) == typeof(prob.noise) <: NoiseFunction
97100
@test objectid(prob.noise) != objectid(sol.W)
98101

102+
sol = solve(prob,EM(),dt=1/1000,alias_noise=false)
103+
@test objectid(prob.noise) == objectid(sol.W)
104+
99105
sol = solve(prob,EM(),dt=0.01,alias_noise=true)
100106
@test sol.W.curt last(tspan)
101107

0 commit comments

Comments
 (0)