Skip to content

Commit 0ab0449

Browse files
performance boosts
1 parent 80ff0e0 commit 0ab0449

File tree

6 files changed

+37
-21
lines changed

6 files changed

+37
-21
lines changed

src/integrators/integrator_utils.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232

3333
@inline function modify_dt_for_tstops!(integrator)
3434
tstops = integrator.opts.tstops
35-
if !isempty(tstops)
35+
@fastmath if !isempty(tstops)
3636
if integrator.opts.adaptive
3737
if integrator.tdir > 0
3838
integrator.dt = min(abs(integrator.dt),abs(top(tstops)-integrator.t)) # step! to the end
@@ -43,7 +43,7 @@ end
4343
integrator.dt = integrator.tdir*abs(top(tstops)-integrator.t)
4444
elseif integrator.dtchangeable && !integrator.force_stepfail
4545
# always try to step! with dtcache, but lower if a tstops
46-
integrator.dt = integrator.tdir*min(abs(integrator.dtcache),abs(top(tstops)-integrator.t)) # step! to the end
46+
integrator.dt = @fastmath integrator.tdir*min(abs(integrator.dtcache),abs(top(tstops)-integrator.t)) # step! to the end
4747
end
4848
end
4949
end
@@ -119,7 +119,8 @@ end
119119
end
120120
end
121121
end
122-
if force_save || (integrator.opts.save_everystep && integrator.iter%integrator.opts.timeseries_steps==0)
122+
if force_save || (integrator.opts.save_everystep &&
123+
integrator.iter%integrator.opts.timeseries_steps==0)
123124
integrator.saveiter += 1
124125
if integrator.opts.save_idxs == nothing
125126
copyat_or_push!(integrator.sol.u,integrator.saveiter,integrator.u)
@@ -155,7 +156,7 @@ end
155156
integrator.tprev = integrator.t
156157
if typeof(integrator.t)<:AbstractFloat && !isempty(integrator.opts.tstops)
157158
tstop = top(integrator.opts.tstops)
158-
abs(ttmp - tstop) < 10eps(typeof(integrator.EEst)) ? (integrator.t = tstop) : (integrator.t = ttmp)
159+
@fastmath abs(ttmp - tstop) < 10eps(typeof(integrator.EEst)) ? (integrator.t = tstop) : (integrator.t = ttmp)
159160
else
160161
integrator.t = ttmp
161162
end
@@ -166,7 +167,9 @@ end
166167
integrator.tprev = integrator.t
167168
if typeof(integrator.t)<:AbstractFloat && !isempty(integrator.opts.tstops)
168169
tstop = top(integrator.opts.tstops)
169-
abs(ttmp - tstop) < 10eps(integrator.t) ? (integrator.t = tstop) : (integrator.t = ttmp)
170+
# For some reason 10eps(integrator.t) is slow here
171+
# TODO: Allow higher precision but profile
172+
@fastmath abs(ttmp - tstop) < 1e-15 ? (integrator.t = tstop) : (integrator.t = ttmp)
170173
else
171174
integrator.t = ttmp
172175
end
@@ -260,7 +263,7 @@ end
260263
modify_dt_for_tstops!(integrator)
261264
accept_step!(integrator.W,integrator.dt)
262265
integrator.dt = integrator.W.dt
263-
integrator.sqdt = sqrt(abs(integrator.dt)) # It can change dt, like in RSwM1
266+
integrator.sqdt = @fastmath sqrt(abs(integrator.dt)) # It can change dt, like in RSwM1
264267
end
265268

266269
@inline function handle_tstop!(integrator)

src/perform_step/low_order.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
@muladd function perform_step!(integrator,cache::EMConstantCache,f=integrator.f)
22
@unpack t,dt,uprev,u,W = integrator
3-
if is_diagonal_noise(integrator.sol.prob)
4-
noise = integrator.g(t,uprev).*W.dW
5-
else
3+
if !is_diagonal_noise(integrator.sol.prob) || typeof(W.dW) <: Number
64
noise = integrator.g(t,uprev)*W.dW
5+
else
6+
noise = integrator.g(t,uprev).*W.dW
77
end
88
u = @muladd uprev + dt*integrator.f(t,uprev) + noise
99
integrator.u = u

src/perform_step/sra.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ end
191191

192192
if integrator.opts.adaptive
193193
E₁ = dt*(k1 + k2 + k3)
194-
integrator.EEst = integrator.opts.internalnorm(@muladd(integrator.opts.delta*E₁+E₂)./@muladd(integrator.opts.abstol + max.(integrator.opts.internalnorm.(uprev),integrator.opts.internalnorm.(u))*integrator.opts.reltol))
194+
integrator.EEst = integrator.opts.internalnorm(@muladd(integrator.opts.delta*E₁+E₂)/@muladd(integrator.opts.abstol + max(integrator.opts.internalnorm(uprev),integrator.opts.internalnorm(u))*integrator.opts.reltol))
195195
end
196196
integrator.u = u
197197
end

src/solve.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ function init(
1414
recompile::Type{Val{recompile_flag}}=Val{true};
1515
dt = tType(0),
1616
timeseries_steps::Int = 1,
17-
save_noise = true,
1817
saveat = tType[],tstops = tType[],d_discontinuities= tType[],
1918
save_timeseries = nothing,
2019
save_everystep = isempty(saveat),
20+
save_noise = save_everystep && typeof(prob.f) <: Tuple ?
21+
has_analytic(prob.f[1]) : has_analytic(prob.f),
2122
save_idxs = nothing,
2223
save_start = true,save_end = true,
2324
dense = save_everystep,
@@ -151,6 +152,16 @@ function init(
151152
timeseries = convert(Vector{typeof(u_initial)},timeseries_init)
152153
end
153154
ts = convert(Vector{tType},ts_init)
155+
156+
if !adaptive
157+
dt == 0 ? steps = length(tstops) : steps = round(Int,(tspan[2]-tspan[1])/dt,RoundUp)
158+
sizehint!(timeseries,steps+1)
159+
sizehint!(ts,steps+1)
160+
else
161+
sizehint!(timeseries,50)
162+
sizehint!(ts,50)
163+
end
164+
154165
#ks = convert(Vector{ksEltype},ks_init)
155166
alg_choice = Int[]
156167

@@ -220,31 +231,33 @@ function init(
220231
seed == 0 ? (prob.seed == 0 ? _seed = rand(UInt64) : _seed = prob.seed) : _seed = seed
221232

222233
if typeof(prob.noise) <: Void
234+
isadaptive(alg) ? rswm = RSWM(adaptivealg=:RSwM3) : rswm = RSWM(adaptivealg=:RSwM1)
223235
if isinplace
224-
isadaptive(alg) ? rswm = RSWM(adaptivealg=:RSwM3) : rswm = RSWM(adaptivealg=:RSwM1)
225236
if alg_needs_extra_process(alg)
226237
W = WienerProcess!(t,rand_prototype,rand_prototype,
227-
save_everystep=save_everystep,
238+
save_everystep=save_noise,
228239
timeseries_steps=timeseries_steps,
229240
rswm=rswm,
230241
rng = Xorshifts.Xoroshiro128Plus(_seed))
231242
else
232243
W = WienerProcess!(t,rand_prototype,
233-
save_everystep=save_everystep,
244+
save_everystep=save_noise,
234245
timeseries_steps=timeseries_steps,
235246
rswm=rswm,
236247
rng = Xorshifts.Xoroshiro128Plus(_seed))
237248
end
238249
else
239250
if alg_needs_extra_process(alg)
240251
W = WienerProcess(t,rand_prototype,rand_prototype,
241-
save_everystep=save_everystep,
252+
save_everystep=save_noise,
242253
timeseries_steps=timeseries_steps,
254+
rswm=rswm,
243255
rng = Xorshifts.Xoroshiro128Plus(_seed))
244256
else
245257
W = WienerProcess(t,rand_prototype,
246-
save_everystep=save_everystep,
258+
save_everystep=save_noise,
247259
timeseries_steps=timeseries_steps,
260+
rswm=rswm,
248261
rng = Xorshifts.Xoroshiro128Plus(_seed))
249262
end
250263
end

test/first_rand_test.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@ f1(t,u) = 0.
44
g1(t,u) = 1.
55
dt = 1//2^(4)
66
prob1 = SDEProblem{false}(f1,g1,0.,(0.0,1.0))
7-
integrator = init(prob1,EM(),dt=dt)
7+
integrator = init(prob1,EM(),dt=dt,save_noise=true)
88

99
k = integrator.W.dW
1010
@test integrator.W.dW != 0
1111
solve!(integrator)
1212
@test integrator.sol.W[2] == k
1313

1414
prob1 = SDEProblem{false}(f1,g1,zeros(4),(0.0,1.0))
15-
sol = solve(prob1,EM(),dt=dt)
15+
sol = solve(prob1,EM(),dt=dt,save_noise=true)
1616
@test sol.W[2] != zeros(4)
1717

1818

1919
f1(t,u,du) = du.=0.
2020
g1(t,u,du) = du.=1.
2121
dt = 1//2^(4)
2222
prob1 = SDEProblem(f1,g1,zeros(4),(0.0,1.0))
23-
sol = solve(prob1,EM(),dt=dt)
23+
sol = solve(prob1,EM(),dt=dt,save_noise=true)
2424
@test sol.W[2] != zeros(4)

test/split_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ f2(t,u) = (1.01)/2 * u
88

99
prob = SDEProblem{false}((f1,f2),σ,1/2,(0.0,1.0))
1010

11-
sol = solve(prob,SplitEM(),dt=1/10)
11+
sol = solve(prob,SplitEM(),dt=1/10,save_noise=true)
1212

1313
prob = SDEProblem{false}(f,σ,1/2,(0.0,1.0),noise = NoiseWrapper(sol.W))
1414

@@ -19,7 +19,7 @@ sol2 = solve(prob,EM(),dt=1/10)
1919
u0 = rand(4)
2020
prob = SDEProblem{false}((f1,f2),σ,u0,(0.0,1.0))
2121

22-
sol = solve(prob,SplitEM(),dt=1/10)
22+
sol = solve(prob,SplitEM(),dt=1/10,save_noise=true)
2323

2424
prob = SDEProblem{false}(f,σ,u0,(0.0,1.0),noise = NoiseWrapper(sol.W))
2525

0 commit comments

Comments
 (0)