Skip to content

Commit 78a51cd

Browse files
Merge pull request #47 from JuliaDiffEq/difftools
update for difftools v0.2.0
2 parents 4951b96 + 3345d4d commit 78a51cd

File tree

3 files changed

+29
-58
lines changed

3 files changed

+29
-58
lines changed

REQUIRE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,4 @@ StaticArrays
1313
Reexport
1414
RandomNumbers
1515
MuladdMacro
16-
DiffEqDiffTools 0.1.0
16+
DiffEqDiffTools 0.2.0

src/caches/sdirk_caches.jl

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,8 @@ function alg_cache(alg::ImplicitEM,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_prot
3232
fsalfirst = zeros(rate_prototype)
3333
k = zeros(rate_prototype)
3434

35-
uf = UJacobianWrapper(f,t,tmp,dz)
36-
if alg_autodiff(alg)
37-
jac_config = ForwardDiff.JacobianConfig(uf,du1,uprev,
38-
ForwardDiff.Chunk{determine_chunksize(u,alg)}())
39-
else
40-
jac_config = nothing
41-
end
35+
uf = DiffEqDiffTools.UJacobianWrapper(f,t)
36+
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
4237
ηold = one(uEltypeNoUnits)
4338

4439
if alg.κ != nothing
@@ -73,7 +68,7 @@ end
7368

7469
function alg_cache(alg::ImplicitEM,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_prototype,
7570
uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}})
76-
uf = UDerivativeWrapper(f,t)
71+
uf = DiffEqDiffTools.UDerivativeWrapper(f,t)
7772
ηold = one(uEltypeNoUnits)
7873

7974
if alg.κ != nothing
@@ -125,13 +120,8 @@ function alg_cache(alg::ImplicitEulerHeun,prob,u,ΔW,ΔZ,rate_prototype,noise_ra
125120
fsalfirst = zeros(rate_prototype)
126121
k = zeros(rate_prototype)
127122

128-
uf = UJacobianWrapper(f,t,tmp,dz)
129-
if alg_autodiff(alg)
130-
jac_config = ForwardDiff.JacobianConfig(uf,du1,uprev,
131-
ForwardDiff.Chunk{determine_chunksize(u,alg)}())
132-
else
133-
jac_config = nothing
134-
end
123+
uf = DiffEqDiffTools.UJacobianWrapper(f,t)
124+
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
135125
ηold = one(uEltypeNoUnits)
136126

137127
if alg.κ != nothing
@@ -162,7 +152,7 @@ end
162152

163153
function alg_cache(alg::ImplicitEulerHeun,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_prototype,
164154
uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}})
165-
uf = UDerivativeWrapper(f,t)
155+
uf = DiffEqDiffTools.UDerivativeWrapper(f,t)
166156
ηold = one(uEltypeNoUnits)
167157

168158
if alg.κ != nothing
@@ -215,13 +205,8 @@ function alg_cache(alg::ImplicitRKMil,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_p
215205
fsalfirst = zeros(rate_prototype)
216206
k = zeros(rate_prototype)
217207

218-
uf = UJacobianWrapper(f,t,tmp,dz)
219-
if alg_autodiff(alg)
220-
jac_config = ForwardDiff.JacobianConfig(uf,du1,uprev,
221-
ForwardDiff.Chunk{determine_chunksize(u,alg)}())
222-
else
223-
jac_config = nothing
224-
end
208+
uf = DiffEqDiffTools.UJacobianWrapper(f,t)
209+
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
225210
ηold = one(uEltypeNoUnits)
226211

227212
if alg.κ != nothing
@@ -253,7 +238,7 @@ end
253238

254239
function alg_cache(alg::ImplicitRKMil,prob,u,ΔW,ΔZ,rate_prototype,noise_rate_prototype,
255240
uEltypeNoUnits,tTypeNoUnits,uprev,f,t,::Type{Val{false}})
256-
uf = UDerivativeWrapper(f,t)
241+
uf = DiffEqDiffTools.UDerivativeWrapper(f,t)
257242
ηold = one(uEltypeNoUnits)
258243

259244
if alg.κ != nothing

src/derivative_wrappers.jl

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,9 @@
1-
mutable struct TimeGradientWrapper{fType,uType} <: Function
2-
f::fType
3-
uprev::uType
4-
fx1::uType
5-
end
6-
(p::TimeGradientWrapper)(t) = (du2 = similar(p.uprev); p.f(t,p.uprev,du2); du2)
7-
(p::TimeGradientWrapper)(du2,t) = p.f(t,p.uprev,du2)
8-
9-
mutable struct UJacobianWrapper{fType,tType,CacheType} <: Function
10-
f::fType
11-
t::tType
12-
x1::CacheType
13-
fx1::CacheType
14-
end
15-
16-
(p::UJacobianWrapper)(du1,uprev) = p.f(p.t,uprev,du1)
17-
(p::UJacobianWrapper)(uprev) = (du1 = similar(uprev); p.f(p.t,uprev,du1); du1)
18-
19-
mutable struct TimeDerivativeWrapper{F,uType} <: Function
20-
f::F
21-
u::uType
22-
end
23-
(p::TimeDerivativeWrapper)(t) = p.f(t,p.u)
24-
25-
mutable struct UDerivativeWrapper{F,tType} <: Function
26-
f::F
27-
t::tType
28-
end
29-
(p::UDerivativeWrapper)(u) = p.f(p.t,u)
30-
311
function derivative!(df::AbstractArray{<:Number}, f, x::Union{Number,AbstractArray{<:Number}}, fx::AbstractArray{<:Number}, integrator::DEIntegrator)
322
if alg_autodiff(integrator.alg)
333
ForwardDiff.derivative!(df, f, fx, x)
344
else
355
RealOrComplex = eltype(integrator.u) <: Complex ? Val{:Complex} : Val{:Real}
36-
DiffEqDiffTools.finite_difference!(df, f, x, integrator.alg.diff_type, RealOrComplex, Val{:DiffEqDerivativeWrapper}, fx)
6+
DiffEqDiffTools.finite_difference!(df, f, x, integrator.alg.diff_type, RealOrComplex, fx)
377
end
388
nothing
399
end
@@ -42,8 +12,24 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, f
4212
if alg_autodiff(integrator.alg)
4313
ForwardDiff.jacobian!(J, f, fx, x, jac_config)
4414
else
45-
RealOrComplex = eltype(integrator.u) <: Complex ? Val{:Complex} : Val{:Real}
46-
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, integrator.alg.diff_type, RealOrComplex, Val{:JacobianWrapper}, fx)
15+
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, jac_config)
4716
end
4817
nothing
4918
end
19+
20+
function build_jac_config(alg,f,uf,du1,uprev,u,tmp,du2)
21+
if !has_jac(f)
22+
if alg_autodiff(alg)
23+
jac_config = ForwardDiff.JacobianConfig(uf,du1,uprev,ForwardDiff.Chunk{determine_chunksize(u,alg)}())
24+
else
25+
if alg.diff_type != Val{:complex}
26+
jac_config = DiffEqDiffTools.JacobianCache(tmp,du1,du2,alg.diff_type)
27+
else
28+
jac_config = DiffEqDiffTools.JacobianCache(Complex{eltype(tmp)}.(tmp),Complex{eltype(du1)}.(du1),nothing,alg.diff_type)
29+
end
30+
end
31+
else
32+
jac_config = nothing
33+
end
34+
jac_config
35+
end

0 commit comments

Comments
 (0)