Skip to content

Commit 89a3f66

Browse files
no broadcast with out of place
1 parent 46dfce6 commit 89a3f66

File tree

7 files changed

+93
-244
lines changed

7 files changed

+93
-244
lines changed

src/integrators/iif.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ mutable struct RHS_IIF1M_Scalar{F,CType,tType} <: Function
66
end
77

88
function (p::RHS_IIF1M_Scalar)(u,resid)
9-
resid[1] .= u[1] .- p.tmp .- p.dt.*p.f[2](p.t+p.dt,u[1])[1]
9+
resid[1] .= u[1] - p.tmp - p.dt*p.f[2](p.t+p.dt,u[1])[1]
1010
end
1111

1212
mutable struct RHS_IIF2M_Scalar{F,CType,tType} <: Function
@@ -17,7 +17,7 @@ mutable struct RHS_IIF2M_Scalar{F,CType,tType} <: Function
1717
end
1818

1919
function (p::RHS_IIF2M_Scalar)(u,resid)
20-
resid[1] .= u[1] .- p.tmp .- 0.5p.dt.*p.f[2](p.t+p.dt,u[1])[1]
20+
resid[1] = u[1] - p.tmp - 0.5p.dt*p.f[2](p.t+p.dt,u[1])[1]
2121
end
2222

2323
@inline function initialize!(integrator,cache::Union{IIF1MConstantCache,IIF2MConstantCache,IIF1MilConstantCache},f=integrator.f)
@@ -31,9 +31,9 @@ end
3131
if typeof(cache) <: IIF1MilConstantCache
3232
error("Milstein correction does not work.")
3333
elseif typeof(cache) <: IIF1MConstantCache
34-
tmp = expm(A*dt)*(uprev .+ integrator.g(t,uprev).*W.dW)
34+
tmp = expm(A*dt)*(uprev + integrator.g(t,uprev)*W.dW)
3535
elseif typeof(cache) <: IIF2MConstantCache
36-
tmp = expm(A*dt)*(uprev .+ 0.5dt.*integrator.f[2](t,uprev) .+ integrator.g(t,uprev).*W.dW)
36+
tmp = expm(A*dt)*(uprev + 0.5dt*integrator.f[2](t,uprev) + integrator.g(t,uprev)*W.dW)
3737
end
3838

3939
if integrator.iter > 1 && !integrator.u_modified

src/integrators/kencarp.jl

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
W = 1 - γdt*J
3333
end
3434

35-
z₁ = dt.*f(t, uprev)
35+
z₁ = dt*f(t, uprev)
3636

3737
##### Step 2
3838

@@ -44,28 +44,28 @@
4444

4545
g1 = g(t,uprev)
4646

47-
tmp = @. uprev + γ*z₁ + chi2*nb021*g1
47+
tmp = uprev + γ*z₁ + chi2*nb021*g1
4848

4949
if typeof(integrator.f) <: SplitFunction
5050
# This assumes the implicit part is cheaper than the explicit part
5151
k1 = dt*f2(t,uprev)
5252
tmp += ea21*k1
5353
end
5454

55-
u = @. tmp + γ*z₂
56-
b = dt.*f(tstep,u) .- z₂
55+
u = tmp + γ*z₂
56+
b = dt*f(tstep,u) - z₂
5757
dz = W\b
5858
ndz = integrator.opts.internalnorm(dz)
59-
z₂ = z₂ .+ dz
59+
z₂ = z₂ + dz
6060

6161
η = max(cache.ηold,eps(eltype(integrator.opts.reltol)))^(0.8)
6262
do_newton = integrator.success_iter == 0 || η*ndz > κtol
6363

6464
fail_convergence = false
6565
while (do_newton || iter < integrator.alg.min_newton_iter) && iter < integrator.alg.max_newton_iter
6666
iter += 1
67-
u = @. tmp + γ*z₂
68-
b = dt.*f(tstep,u) .- z₂
67+
u = tmp + γ*z₂
68+
b = dt*f(tstep,u) - z₂
6969
dz = W\b
7070
ndzprev = ndz
7171
ndz = integrator.opts.internalnorm(dz)
@@ -76,7 +76,7 @@
7676
end
7777
η = θ/(1-θ)
7878
do_newton =*ndz > κtol)
79-
z₂ = z₂ .+ dz
79+
z₂ = z₂ + dz
8080
end
8181

8282
if (iter >= integrator.alg.max_newton_iter && do_newton) || fail_convergence
@@ -91,30 +91,30 @@
9191

9292
if typeof(integrator.f) <: SplitFunction
9393
z₃ = z₂
94-
u = @. tmp + γ*z₂
94+
u = tmp + γ*z₂
9595
k2 = dt*f2(t + 2γ*dt, u)
96-
tmp = @. uprev + a31*z₁ + a32*z₂ + ea31*k1 + ea32*k2
96+
tmp = uprev + a31*z₁ + a32*z₂ + ea31*k1 + ea32*k2
9797
else
9898
# Guess is from Hermite derivative on z₁ and z₂
99-
#z₃ = @. α31*z₁ + α32*z₂
99+
#z₃ = α31*z₁ + α32*z₂
100100
z₃ = z₂
101-
tmp = @. uprev + a31*z₁ + a32*z₂
101+
tmp = uprev + a31*z₁ + a32*z₂
102102
end
103103

104-
u = @. tmp + γ*z₃
105-
b = dt.*f(tstep,u) .- z₃
104+
u = tmp + γ*z₃
105+
b = dt*f(tstep,u) - z₃
106106
dz = W\b
107107
ndz = integrator.opts.internalnorm(dz)
108-
z₃ = z₃ .+ dz
108+
z₃ = z₃ + dz
109109

110110
η = max(η,eps(eltype(integrator.opts.reltol)))^(0.8)
111111
do_newton =*ndz > κtol)
112112

113113
fail_convergence = false
114114
while (do_newton || iter < integrator.alg.min_newton_iter) && iter < integrator.alg.max_newton_iter
115115
iter += 1
116-
u = @. tmp + γ*z₃
117-
b = dt.*f(tstep,u) .- z₃
116+
u = tmp + γ*z₃
117+
b = dt*f(tstep,u) - z₃
118118
dz = W\b
119119
ndzprev = ndz
120120
ndz = integrator.opts.internalnorm(dz)
@@ -125,7 +125,7 @@
125125
end
126126
η = θ/(1-θ)
127127
do_newton =*ndz > κtol)
128-
z₃ = z₃ .+ dz
128+
z₃ = z₃ + dz
129129
end
130130

131131
if (iter >= integrator.alg.max_newton_iter && do_newton) || fail_convergence
@@ -142,30 +142,30 @@
142142

143143
if typeof(integrator.f) <: SplitFunction
144144
z₄ = z₂
145-
u = @. tmp + γ*z₃
145+
u = tmp + γ*z₃
146146
k3 = dt*f2(t + c3*dt, u)
147-
tmp = @. uprev + a41*z₁ + a42*z₂ + a43*z₃ + ea41*k1 + ea42*k2 + ea43*k3 + chi2*nb043*g1
147+
tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ + ea41*k1 + ea42*k2 + ea43*k3 + chi2*nb043*g1
148148
else
149149
@unpack α41,α42 = cache.tab
150-
#z₄ = @. α41*z₁ + α42*z₂
150+
#z₄ = α41*z₁ + α42*z₂
151151
z₄ = z₂
152-
tmp = @. uprev + a41*z₁ + a42*z₂ + a43*z₃ + chi2*nb043*g1
152+
tmp = uprev + a41*z₁ + a42*z₂ + a43*z₃ + chi2*nb043*g1
153153
end
154154

155-
u = @. tmp + γ*z₄
156-
b = dt.*f(tstep,u) .- z₄
155+
u = tmp + γ*z₄
156+
b = dt*f(tstep,u) - z₄
157157
dz = W\b
158158
ndz = integrator.opts.internalnorm(dz)
159-
z₄ = z₄ .+ dz
159+
z₄ = z₄ + dz
160160

161161
η = max(η,eps(eltype(integrator.opts.reltol)))^(0.8)
162162
do_newton =*ndz > κtol)
163163

164164
fail_convergence = false
165165
while (do_newton || iter < integrator.alg.min_newton_iter) && iter < integrator.alg.max_newton_iter
166166
iter += 1
167-
u = @. tmp + γ*z₄
168-
b = dt.*f(tstep,u) .- z₄
167+
u = tmp + γ*z₄
168+
b = dt*f(tstep,u) - z₄
169169
dz = W\b
170170
ndzprev = ndz
171171
ndz = integrator.opts.internalnorm(dz)
@@ -176,24 +176,24 @@
176176
end
177177
η = θ/(1-θ)
178178
do_newton =*ndz > κtol)
179-
z₄ = z₄ .+ dz
179+
z₄ = z₄ + dz
180180
end
181181

182182
if (iter >= integrator.alg.max_newton_iter && do_newton) || fail_convergence
183183
integrator.force_stepfail = true
184184
return
185185
end
186186

187-
u = @. tmp + γ*z₄
187+
u = tmp + γ*z₄
188188
g4 = g(t+dt,uprev)
189189

190190
E₂ = chi2*(g1-g4)
191191

192192
if typeof(integrator.f) <: SplitFunction
193193
k4 = dt*f2(t+dt, u)
194-
u = @. uprev + a41*z₁ + a42*z₂ + a43*z₃ + γ*z₄ + eb1*k1 + eb2*k2 + eb3*k3 + eb4*k4 + integrator.W.dW*g4 + E₂
194+
u = uprev + a41*z₁ + a42*z₂ + a43*z₃ + γ*z₄ + eb1*k1 + eb2*k2 + eb3*k3 + eb4*k4 + integrator.W.dW*g4 + E₂
195195
else
196-
u = @. uprev + a41*z₁ + a42*z₂ + a43*z₃ + γ*z₄ + integrator.W.dW*g4 + E₂
196+
u = uprev + a41*z₁ + a42*z₂ + a43*z₃ + γ*z₄ + integrator.W.dW*g4 + E₂
197197
end
198198

199199
################################### Finalize
@@ -205,9 +205,9 @@
205205

206206
#=
207207
if typeof(integrator.f) <: SplitFunction
208-
tmp = @. btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + ebtilde1*k1 + ebtilde2*k2 + ebtilde3*k3 + ebtilde4*k4 + chi2*(g1-g4)
208+
tmp = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + ebtilde1*k1 + ebtilde2*k2 + ebtilde3*k3 + ebtilde4*k4 + chi2*(g1-g4)
209209
else
210-
tmp = @. btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + chi2*(g1-g4)
210+
tmp = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + chi2*(g1-g4)
211211
end
212212
if integrator.alg.smooth_est # From Shampine
213213
est = W\tmp
@@ -248,7 +248,7 @@ end
248248
end
249249

250250
if typeof(integrator.W.dW) <: Union{SArray,Number}
251-
chi2 = @. (integrator.W.dW + integrator.W.dZ/sqrt(3))/2 #I_(1,0)/h
251+
chi2 = (integrator.W.dW + integrator.W.dZ/sqrt(3))/2 #I_(1,0)/h
252252
else
253253
@. chi2 = (integrator.W.dW + integrator.W.dZ/sqrt(3))/2 #I_(1,0)/h
254254
end

src/integrators/low_order.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
else
66
noise = integrator.g(t,uprev)*W.dW
77
end
8-
u = @muladd uprev .+ dt.*integrator.f(t,uprev) .+ noise
8+
u = @muladd uprev + dt*integrator.f(t,uprev) + noise
99
@pack integrator = t,dt,u
1010
end
1111

@@ -39,14 +39,14 @@ end
3939
else
4040
noise = gtmp*W.dW
4141
end
42-
tmp = @. @muladd uprev + ftmp.*dt + noise
42+
tmp = @muladd uprev + ftmp*dt + noise
4343
gtmp2 = (1/2).*(gtmp.+integrator.g(t+dt,tmp))
4444
if is_diagonal_noise(integrator.sol.prob)
4545
noise2 = gtmp2.*W.dW
4646
else
4747
noise2 = gtmp2*W.dW
4848
end
49-
u = @muladd uprev .+ (1/2).*dt.*(ftmp.+integrator.f(t+dt,tmp)) .+ noise2
49+
u = @muladd uprev + (1/2)*dt*(ftmp+integrator.f(t+dt,tmp)) + noise2
5050
@pack integrator = t,dt,u
5151
end
5252

@@ -97,7 +97,7 @@ end
9797

9898
@inline function perform_step!(integrator,cache::RandomEMConstantCache,f=integrator.f)
9999
@unpack t,dt,uprev,u,W = integrator
100-
u = @muladd uprev .+ dt.*integrator.f(t,uprev,W.dW)
100+
u = @muladd uprev + dt*integrator.f(t,uprev,W.dW)
101101
@pack integrator = t,dt,u
102102
end
103103

@@ -114,19 +114,19 @@ end
114114

115115
@inline function perform_step!(integrator,cache::RKMilConstantCache,f=integrator.f)
116116
@unpack t,dt,uprev,u,W = integrator
117-
K = @muladd uprev .+ dt.*integrator.f(t,uprev)
117+
K = @muladd uprev + dt*integrator.f(t,uprev)
118118
L = integrator.g(t,uprev)
119119
mil_correction = zero(u)
120120
if alg_interpretation(integrator.alg) == :Ito
121-
utilde = @. K + L*integrator.sqdt
121+
utilde = K + L*integrator.sqdt
122122
mil_correction = (integrator.g(t,utilde).-L)./(2 .* integrator.sqdt).*(W.dW.^2 .- dt)
123123
elseif alg_interpretation(integrator.alg) == :Stratonovich
124-
utilde = @. uprev + L*integrator.sqdt
124+
utilde = uprev + L*integrator.sqdt
125125
mil_correction = (integrator.g(t,utilde).-L)./(2 .* integrator.sqdt).*(W.dW.^2)
126126
end
127-
u = @. K+L*W.dW+mil_correction
127+
u = K+L*W.dW+mil_correction
128128
if integrator.opts.adaptive
129-
integrator.EEst = integrator.opts.internalnorm(@.(mil_correction/(@muladd(integrator.opts.abstol + max.(abs(uprev),abs(u))*integrator.opts.reltol))))
129+
integrator.EEst = integrator.opts.internalnorm(mil_correction/(@muladd(integrator.opts.abstol + max.(abs(uprev),abs(u))*integrator.opts.reltol)))
130130
end
131131
@pack integrator = t,dt,u
132132
end

src/integrators/sdirk.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,19 @@
2929
gtmp = L.*integrator.W.dW
3030

3131
if typeof(cache) <: ImplicitEulerHeunConstantCache
32-
utilde = @. uprev + gtmp
33-
gtmp = @. ((integrator.g(t,utilde) + L)/2)*integrator.W.dW
32+
utilde = uprev + gtmp
33+
gtmp = ((integrator.g(t,utilde) + L)/2)*integrator.W.dW
3434
end
3535

3636
if typeof(cache) <: ImplicitRKMilConstantCache
3737
if alg_interpretation(integrator.alg) == :Ito
3838
K = @muladd uprev .+ dt.*ftmp
39-
utilde = @. K + L*integrator.sqdt
39+
utilde = K + L*integrator.sqdt
4040
mil_correction = (integrator.g(t,utilde).-L)./(2 .* integrator.sqdt).*
4141
(integrator.W.dW.^2 .- dt)
4242
gtmp += mil_correction
4343
elseif alg_interpretation(integrator.alg) == :Stratonovich
44-
utilde = @. uprev + L*integrator.sqdt
44+
utilde = uprev + L*integrator.sqdt
4545
mil_correction = (integrator.g(t,utilde).-L)./(2 .* integrator.sqdt).*
4646
(integrator.W.dW.^2)
4747
gtmp += mil_correction
@@ -57,9 +57,9 @@
5757
iter += 1
5858
if integrator.alg.symplectic
5959
# u = uprev + z then u = (uprev+u)/2 = (uprev+uprev+z)/2 = uprev + z/2
60-
u = @. uprev + z/2 + gtmp/2
60+
u = uprev + z/2 + gtmp/2
6161
else
62-
u = @. uprev + dt*(1-theta)*ftmp + theta*z + gtmp
62+
u = uprev + dt*(1-theta)*ftmp + theta*z + gtmp
6363
end
6464
b = -z .+ dt.*f(t+a,u)
6565
dz = W\b
@@ -78,9 +78,9 @@
7878
iter += 1
7979
if integrator.alg.symplectic
8080
# u = uprev + z then u = (uprev+u)/2 = (uprev+uprev+z)/2 = uprev + z/2
81-
u = @. uprev + z/2 + gtmp/2
81+
u = uprev + z/2 + gtmp/2
8282
else
83-
u = @. uprev + dt*(1-theta)*ftmp + theta*z + gtmp
83+
u = uprev + dt*(1-theta)*ftmp + theta*z + gtmp
8484
end
8585
b = -z .+ dt.*f(t+a,u)
8686
dz = W\b
@@ -97,9 +97,9 @@
9797
end
9898

9999
if integrator.alg.symplectic
100-
u = @. uprev + z + gtmp
100+
u = uprev + z + gtmp
101101
else
102-
u = @. uprev + dt*(1-theta)*ftmp + theta*z + gtmp
102+
u = uprev + dt*(1-theta)*ftmp + theta*z + gtmp
103103
end
104104

105105
if (iter >= integrator.alg.max_newton_iter && do_newton) || fail_convergence
@@ -109,7 +109,7 @@
109109

110110
cache.ηold = η
111111
cache.newton_iters = iter
112-
u = @. uprev + dt*(1-theta)*ftmp + theta*z + gtmp
112+
u = uprev + dt*(1-theta)*ftmp + theta*z + gtmp
113113

114114
#=
115115
if integrator.opts.adaptive && integrator.success_iter > 0
@@ -118,7 +118,7 @@
118118
tprev = integrator.tprev
119119
DD3 = ((u - uprev)/((dt)*(t+dt-tprev)) + (uprev-uprev2)/((t-tprev)*(t+dt-tprev)))
120120
dEst = (dt^2)*abs(DD3/6)
121-
integrator.EEst = @. dEst/(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol)
121+
integrator.EEst = dEst/(integrator.opts.abstol+max(abs(uprev),abs(u))*integrator.opts.reltol)
122122
else
123123
integrator.EEst = 1
124124
end

src/integrators/split.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
@inline function perform_step!(integrator,cache::SplitEMConstantCache,f=integrator.f)
22
@unpack t,dt,uprev,u,W = integrator
3-
u = dt.*(integrator.f[1](t,uprev) .+
4-
integrator.f[2](t,uprev)) .+
5-
integrator.g(t,uprev).*W.dW .+ uprev
3+
u = dt*(integrator.f[1](t,uprev) +
4+
integrator.f[2](t,uprev)) +
5+
integrator.g(t,uprev).*W.dW + uprev
66
@pack integrator = t,dt,u
77
end
88

0 commit comments

Comments
 (0)