Skip to content

Commit 63412a2

Browse files
committed
Split out preparation
1 parent 0a68e03 commit 63412a2

13 files changed

+241
-200
lines changed

docs/src/faq.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ However, this can be switched to any other "inner" backend compatible with [Diff
1818
### Arrays
1919

2020
Functions that take or return arbitrary arrays are supported, as long as the forward mapping _and_ the conditions return arrays of the same size.
21-
22-
If you deal with small arrays (say, fewer than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
21+
The array types involved should be mutable.
2322

2423
### Scalars
2524

examples/3_tricks.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,19 @@ end
3838

3939
function conditions_components(x::ComponentVector, y::ComponentVector, _z)
4040
c_d, c_e = conditions_components_aux(x.a, x.b, x.m, y.d, y.e)
41-
c = ComponentVector(; c_d=c_d, c_e=c_e)
41+
c = ComponentVector(; d=c_d, e=c_e)
4242
return c
4343
end;
4444

4545
# And build your implicit function like so:
4646

47-
implicit_components = ImplicitFunction(
48-
forward_components, conditions_components; strict=Val(false)
49-
);
47+
implicit_components = ImplicitFunction(forward_components, conditions_components);
5048

5149
# Now we're good to go.
5250

5351
a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0
5452
x = ComponentVector(; a=a, b=b, m=m)
55-
implicit_components(x)
53+
y, z = implicit_components(x)
5654

5755
# And it works with both ForwardDiff.jl and Zygote.jl.
5856

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ChainRulesCore: unthunk, @not_implemented
66
using ImplicitDifferentiation:
77
ImplicitDifferentiation,
88
ImplicitFunction,
9+
ImplicitFunctionPreparation,
910
build_Aᵀ,
1011
build_Bᵀ,
1112
chainrules_suggested_backend
@@ -14,29 +15,33 @@ using ImplicitDifferentiation:
1415
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1516

1617
function ChainRulesCore.rrule(
17-
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
18+
rc::RuleConfig,
19+
implicit::ImplicitFunction,
20+
prep::ImplicitFunctionPreparation,
21+
x::AbstractArray,
22+
args::Vararg{Any,N};
1823
) where {N}
1924
y, z = implicit(x, args...)
2025
c = implicit.conditions(x, y, z, args...)
2126

2227
suggested_backend = chainrules_suggested_backend(rc)
23-
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
24-
Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend)
28+
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
29+
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
2530
project_x = ProjectTo(x)
2631

27-
function implicit_pullback((dy, dz))
32+
function implicit_pullback_prepared((dy, dz))
2833
dy = unthunk(dy)
2934
dy_vec = vec(dy)
30-
dc_vec = similar(vec(c))
31-
implicit.linear_solver(dc_vec, Aᵀ, -dy_vec)
35+
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
3236
dx_vec = Bᵀ(dc_vec)
3337
dx = reshape(dx_vec, size(x))
3438
df = NoTangent()
39+
dprep = @not_implemented("Tangents for mutable arguments are not defined")
3540
dargs = ntuple(unimplemented_tangent, N)
36-
return (df, project_x(dx), dargs...)
41+
return (df, dprep, project_x(dx), dargs...)
3742
end
3843

39-
return (y, z), implicit_pullback
44+
return (y, z), implicit_pullback_prepared
4045
end
4146

4247
function unimplemented_tangent(_)

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@ module ImplicitDifferentiationForwardDiffExt
22

33
using ADTypes: AutoForwardDiff
44
using ForwardDiff: Dual, Partials, partials, value
5-
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B
5+
using ImplicitDifferentiation:
6+
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B
67

78
function (implicit::ImplicitFunction)(
8-
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
9+
prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
910
) where {T,R,N}
1011
x = value.(x_and_dx)
1112
y, z = implicit(x, args...)
1213
c = implicit.conditions(x, y, z, args...)
1314

1415
suggested_backend = AutoForwardDiff()
15-
A = build_A(implicit, x, y, z, c, args...; suggested_backend)
16-
B = build_B(implicit, x, y, z, c, args...; suggested_backend)
16+
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
17+
B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend)
1718

1819
dX = ntuple(Val(N)) do k
1920
partials.(x_and_dx, k)
@@ -24,8 +25,7 @@ function (implicit::ImplicitFunction)(
2425
return dₖc_vec
2526
end
2627
dY = map(dC_vec) do dₖc_vec
27-
dₖy_vec = similar(vec(y))
28-
implicit.linear_solver(dₖy_vec, A, -dₖc_vec)
28+
dₖy_vec = implicit.linear_solver(A, -dₖc_vec)
2929
dₖy = reshape(dₖy_vec, size(y))
3030
return dₖy
3131
end

src/ImplicitDifferentiation.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@ using DifferentiationInterface:
2121
pullback,
2222
pushforward!,
2323
pushforward
24-
using Krylov: KrylovConstructor, krylov_workspace, krylov_solve!, solution
25-
using LinearAlgebra: ldiv!
24+
using Krylov: Krylov, krylov_workspace, krylov_solve!, solution
2625
using LinearOperators: LinearOperator
2726
using LinearMaps: FunctionMap
2827
using LinearAlgebra: factorize
2928

3029
include("utils.jl")
3130
include("settings.jl")
32-
include("preparation.jl")
3331
include("implicit_function.jl")
32+
include("preparation.jl")
3433
include("execution.jl")
34+
include("callable.jl")
3535

3636
export MatrixRepresentation, OperatorRepresentation
3737
export IterativeLinearSolver, DirectLinearSolver
3838
export ImplicitFunction
39+
export prepare_implicit
3940

4041
end

src/callable.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
2+
return implicit(ImplicitFunctionPreparation(), x, args...)
3+
end
4+
5+
function (implicit::ImplicitFunction)(
6+
::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N}
7+
) where {N}
8+
return implicit.solver(x, args...)
9+
end

src/execution.jl

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,40 @@
1-
struct JVP!{F,P,B,I,C}
1+
struct JVP!{F,P,B,I,V,C}
22
f::F
33
prep::P
44
backend::B
55
input::I
6+
v_buffer::V
67
contexts::C
78
end
89

9-
struct VJP!{F,P,B,I,C}
10+
struct VJP!{F,P,B,I,V,C}
1011
f::F
1112
prep::P
1213
backend::B
1314
input::I
15+
v_buffer::V
1416
contexts::C
1517
end
1618

17-
function (po::JVP!)(res::AbstractVector, v::AbstractVector)
18-
(; f, backend, input, contexts, prep) = po
19-
pushforward!(f, (res,), prep, backend, input, (v,), contexts...)
19+
function (po::JVP!)(res::AbstractVector, v_wrongtype::AbstractVector)
20+
(; f, backend, input, v_buffer, contexts, prep) = po
21+
copyto!(v_buffer, v_wrongtype)
22+
pushforward!(f, (res,), prep, backend, input, (v_buffer,), contexts...)
2023
return res
2124
end
2225

23-
function (po::VJP!)(res::AbstractVector, v::AbstractVector)
24-
(; f, backend, input, contexts, prep) = po
25-
pullback!(f, (res,), prep, backend, input, (v,), contexts...)
26+
function (po::VJP!)(res::AbstractVector, v_wrongtype::AbstractVector)
27+
(; f, backend, input, v_buffer, contexts, prep) = po
28+
copyto!(v_buffer, v_wrongtype)
29+
pullback!(f, (res,), prep, backend, input, (v_buffer,), contexts...)
2630
return res
2731
end
2832

2933
## A
3034

3135
function build_A(
3236
implicit::ImplicitFunction,
37+
prep::ImplicitFunctionPreparation,
3338
x::AbstractArray,
3439
y::AbstractArray,
3540
z,
@@ -38,14 +43,15 @@ function build_A(
3843
suggested_backend::AbstractADType,
3944
)
4045
return build_A_aux(
41-
implicit.representation, implicit, x, y, z, c, args...; suggested_backend
46+
implicit.representation, implicit, prep, x, y, z, c, args...; suggested_backend
4247
)
4348
end
4449

4550
function build_A_aux(
46-
::MatrixRepresentation, implicit, x, y, z, c, args...; suggested_backend
51+
::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend
4752
)
48-
(; conditions, backends, prep_A) = implicit
53+
(; conditions, backends) = implicit
54+
(; prep_A) = prep
4955
actual_backend = isnothing(backends) ? suggested_backend : backends.y
5056
contexts = (Constant(x), Constant(z), map(Constant, args)...)
5157
if isnothing(prep_A)
@@ -59,6 +65,7 @@ end
5965
function build_A_aux(
6066
::OperatorRepresentation{package,symmetric,hermitian,posdef},
6167
implicit,
68+
prep,
6269
x,
6370
y,
6471
z,
@@ -67,7 +74,8 @@ function build_A_aux(
6774
suggested_backend,
6875
) where {package,symmetric,hermitian,posdef}
6976
T = Base.promote_eltype(x, y, c)
70-
(; conditions, backends, prep_A) = implicit
77+
(; conditions, backends) = implicit
78+
(; prep_A) = prep
7179
actual_backend = isnothing(backends) ? suggested_backend : backends.y
7280
contexts = (Constant(x), Constant(z), map(Constant, args)...)
7381
f_vec = VecToVec(Switch12(conditions), y)
@@ -82,9 +90,9 @@ function build_A_aux(
8290
f_vec, prep_A, actual_backend, y_vec, (dy_vec,), contexts...
8391
)
8492
end
85-
prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, contexts)
93+
prod! = JVP!(f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts)
8694
if package == :LinearOperators
87-
return LinearOperator(T, length(c), length(y), symmetric, hermitian, prod!;)
95+
return LinearOperator(T, length(c), length(y), symmetric, hermitian, prod!)
8896
elseif package == :LinearMaps
8997
return FunctionMap{T}(
9098
prod!,
@@ -102,6 +110,7 @@ end
102110

103111
function build_Aᵀ(
104112
implicit::ImplicitFunction,
113+
prep::ImplicitFunctionPreparation,
105114
x::AbstractArray,
106115
y::AbstractArray,
107116
z,
@@ -110,14 +119,15 @@ function build_Aᵀ(
110119
suggested_backend::AbstractADType,
111120
)
112121
return build_Aᵀ_aux(
113-
implicit.representation, implicit, x, y, z, c, args...; suggested_backend
122+
implicit.representation, implicit, prep, x, y, z, c, args...; suggested_backend
114123
)
115124
end
116125

117126
function build_Aᵀ_aux(
118-
::MatrixRepresentation, implicit, x, y, z, c, args...; suggested_backend
127+
::MatrixRepresentation, implicit, prep, x, y, z, c, args...; suggested_backend
119128
)
120-
(; conditions, backends, prep_Aᵀ) = implicit
129+
(; conditions, backends) = implicit
130+
(; prep_Aᵀ) = prep
121131
actual_backend = isnothing(backends) ? suggested_backend : backends.y
122132
contexts = (Constant(x), Constant(z), map(Constant, args)...)
123133
if isnothing(prep_Aᵀ)
@@ -133,6 +143,7 @@ end
133143
function build_Aᵀ_aux(
134144
::OperatorRepresentation{package,symmetric,hermitian,posdef},
135145
implicit,
146+
prep,
136147
x,
137148
y,
138149
z,
@@ -141,7 +152,8 @@ function build_Aᵀ_aux(
141152
suggested_backend,
142153
) where {package,symmetric,hermitian,posdef}
143154
T = Base.promote_eltype(x, y, c)
144-
(; conditions, backends, prep_Aᵀ) = implicit
155+
(; conditions, backends) = implicit
156+
(; prep_Aᵀ) = prep
145157
actual_backend = isnothing(backends) ? suggested_backend : backends.y
146158
contexts = (Constant(x), Constant(z), map(Constant, args)...)
147159
f_vec = VecToVec(Switch12(conditions), y)
@@ -156,9 +168,9 @@ function build_Aᵀ_aux(
156168
f_vec, prep_Aᵀ, actual_backend, y_vec, (dc_vec,), contexts...
157169
)
158170
end
159-
prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, contexts)
171+
prod! = VJP!(f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts)
160172
if package == :LinearOperators
161-
return LinearOperator(T, length(y), length(c), symmetric, hermitian, prod!;)
173+
return LinearOperator(T, length(y), length(c), symmetric, hermitian, prod!)
162174
elseif package == :LinearMaps
163175
return FunctionMap{T}(
164176
prod!,
@@ -176,14 +188,16 @@ end
176188

177189
function build_B(
178190
implicit::ImplicitFunction,
191+
prep::ImplicitFunctionPreparation,
179192
x::AbstractArray,
180193
y::AbstractArray,
181194
z,
182195
c,
183196
args...;
184197
suggested_backend::AbstractADType,
185198
)
186-
(; conditions, backends, prep_B) = implicit
199+
(; conditions, backends) = implicit
200+
(; prep_B) = prep
187201
actual_backend = isnothing(backends) ? suggested_backend : backends.x
188202
contexts = (Constant(y), Constant(z), map(Constant, args)...)
189203
f_vec = VecToVec(conditions, x)
@@ -198,7 +212,8 @@ function build_B(
198212
f_vec, prep_B, actual_backend, x_vec, (dx_vec,), contexts...
199213
)
200214
end
201-
function B_fun(dx_vec)
215+
function B_fun(dx_vec_wrongtype)
216+
copyto!(dx_vec, dx_vec_wrongtype)
202217
return pushforward(
203218
f_vec, prep_B_same, actual_backend, x_vec, (dx_vec,), contexts...
204219
)[1]
@@ -210,14 +225,16 @@ end
210225

211226
function build_Bᵀ(
212227
implicit::ImplicitFunction,
228+
prep::ImplicitFunctionPreparation,
213229
x::AbstractArray,
214230
y::AbstractArray,
215231
z,
216232
c,
217233
args...;
218234
suggested_backend::AbstractADType,
219235
)
220-
(; conditions, backends, prep_Bᵀ) = implicit
236+
(; conditions, backends) = implicit
237+
(; prep_Bᵀ) = prep
221238
actual_backend = isnothing(backends) ? suggested_backend : backends.x
222239
contexts = (Constant(y), Constant(z), map(Constant, args)...)
223240
f_vec = VecToVec(conditions, x)
@@ -232,7 +249,8 @@ function build_Bᵀ(
232249
f_vec, prep_Bᵀ, actual_backend, x_vec, (dc_vec,), contexts...
233250
)
234251
end
235-
function Bᵀ_fun(dc_vec)
252+
function Bᵀ_fun(dc_vec_wrongtype)
253+
copyto!(dc_vec, dc_vec_wrongtype)
236254
return pullback(f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc_vec,), contexts...)[1]
237255
end
238256
return Bᵀ_fun

0 commit comments

Comments
 (0)