Skip to content

Commit bc7267a

Browse files
authored
refactor!: split out preparation (#179)
* refactor!: move to in-place linear solve, remove IterativeLinearSolvers * Import * Nostrict * Docs * Fix coverage * Split out preparation * Plop * Dep * Dep * Fixes * Fixes again * Cov
1 parent 42fa4a1 commit bc7267a

16 files changed

+337
-295
lines changed

Project.toml

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
name = "ImplicitDifferentiation"
22
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
33
authors = ["Guillaume Dalle", "Mohamed Tarek"]
4-
version = "0.8.1"
4+
version = "0.9.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
9-
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
109
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1211
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
@@ -33,11 +32,10 @@ Documenter = "1.12.0"
3332
ExplicitImports = "1"
3433
FiniteDiff = "2.27.0"
3534
ForwardDiff = "0.10.36, 1"
36-
IterativeSolvers = "0.9.4"
3735
JET = "0.9, 0.10"
3836
JuliaFormatter = "2.1.2"
3937
Krylov = "0.9.6, 0.10"
40-
LinearAlgebra = "1.10"
38+
LinearAlgebra = "1"
4139
LinearMaps = "3.11.4"
4240
LinearOperators = "2.8.0"
4341
NLsolve = "4.5.1"
@@ -65,6 +63,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
6563
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
6664
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
6765
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
66+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6867
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
6968
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
7069
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -76,4 +75,27 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
7675
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
7776

7877
[targets]
79-
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"]
78+
test = [
79+
"ADTypes",
80+
"Aqua",
81+
"ChainRulesCore",
82+
"ChainRulesTestUtils",
83+
"ComponentArrays",
84+
"DifferentiationInterface",
85+
"Documenter",
86+
"ExplicitImports",
87+
"FiniteDiff",
88+
"ForwardDiff",
89+
"JET",
90+
"JuliaFormatter",
91+
"LinearAlgebra",
92+
"NLsolve",
93+
"Optim",
94+
"Random",
95+
"SparseArrays",
96+
"StaticArrays",
97+
"Test",
98+
"TestItems",
99+
"TestItemRunner",
100+
"Zygote",
101+
]

docs/src/api.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ ImplicitFunction
2222
MatrixRepresentation
2323
OperatorRepresentation
2424
IterativeLinearSolver
25+
DirectLinearSolver
26+
prepare_implicit
2527
```
2628

2729
## Internals

docs/src/faq.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ However, this can be switched to any other "inner" backend compatible with [Diff
1818
### Arrays
1919

2020
Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size.
21-
22-
If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
21+
The array types involved should be mutable.
2322

2423
### Scalars
2524

examples/3_tricks.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ end
3838

3939
function conditions_components(x::ComponentVector, y::ComponentVector, _z)
4040
c_d, c_e = conditions_components_aux(x.a, x.b, x.m, y.d, y.e)
41-
c = ComponentVector(; c_d=c_d, c_e=c_e)
41+
c = ComponentVector(; d=c_d, e=c_e)
4242
return c
4343
end;
4444

@@ -50,7 +50,7 @@ implicit_components = ImplicitFunction(forward_components, conditions_components
5050

5151
a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0
5252
x = ComponentVector(; a=a, b=b, m=m)
53-
implicit_components(x)
53+
y, z = implicit_components(x)
5454

5555
# And it works with both ForwardDiff.jl and Zygote.jl
5656

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using ChainRulesCore: unthunk, @not_implemented
66
using ImplicitDifferentiation:
77
ImplicitDifferentiation,
88
ImplicitFunction,
9+
ImplicitFunctionPreparation,
910
build_Aᵀ,
1011
build_Bᵀ,
1112
chainrules_suggested_backend
@@ -14,28 +15,33 @@ using ImplicitDifferentiation:
1415
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1516

1617
function ChainRulesCore.rrule(
17-
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
18+
rc::RuleConfig,
19+
implicit::ImplicitFunction,
20+
prep::ImplicitFunctionPreparation,
21+
x::AbstractArray,
22+
args::Vararg{Any,N};
1823
) where {N}
1924
y, z = implicit(x, args...)
2025
c = implicit.conditions(x, y, z, args...)
2126

2227
suggested_backend = chainrules_suggested_backend(rc)
23-
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
24-
Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend)
28+
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
29+
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
2530
project_x = ProjectTo(x)
2631

27-
function implicit_pullback((dy, dz))
32+
function implicit_pullback_prepared((dy, dz))
2833
dy = unthunk(dy)
2934
dy_vec = vec(dy)
3035
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
3136
dx_vec = Bᵀ(dc_vec)
3237
dx = reshape(dx_vec, size(x))
3338
df = NoTangent()
39+
dprep = @not_implemented("Tangents for mutable arguments are not defined")
3440
dargs = ntuple(unimplemented_tangent, N)
35-
return (df, project_x(dx), dargs...)
41+
return (df, dprep, project_x(dx), dargs...)
3642
end
3743

38-
return (y, z), implicit_pullback
44+
return (y, z), implicit_pullback_prepared
3945
end
4046

4147
function unimplemented_tangent(_)

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,36 @@ module ImplicitDifferentiationForwardDiffExt
22

33
using ADTypes: AutoForwardDiff
44
using ForwardDiff: Dual, Partials, partials, value
5-
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B
5+
using ImplicitDifferentiation:
6+
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B
67

78
function (implicit::ImplicitFunction)(
8-
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
9+
prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
910
) where {T,R,N}
1011
x = value.(x_and_dx)
1112
y, z = implicit(x, args...)
1213
c = implicit.conditions(x, y, z, args...)
1314

1415
suggested_backend = AutoForwardDiff()
15-
A = build_A(implicit, x, y, z, c, args...; suggested_backend)
16-
B = build_B(implicit, x, y, z, c, args...; suggested_backend)
16+
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
17+
B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend)
1718

1819
dX = ntuple(Val(N)) do k
1920
partials.(x_and_dx, k)
2021
end
21-
dC_mat = mapreduce(hcat, dX) do dₖx
22+
dC_vec = map(dX) do dₖx
2223
dₖx_vec = vec(dₖx)
2324
dₖc_vec = B(dₖx_vec)
2425
return dₖc_vec
2526
end
26-
dY_mat = implicit.linear_solver(A, -dC_mat)
27+
dY = map(dC_vec) do dₖc_vec
28+
dₖy_vec = implicit.linear_solver(A, -dₖc_vec)
29+
dₖy = reshape(dₖy_vec, size(y))
30+
return dₖy
31+
end
2732

2833
y_and_dy = map(y, LinearIndices(y)) do yi, i
29-
Dual{T}(yi, Partials(ntuple(k -> dY_mat[i, k], Val(N))))
34+
Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N))))
3035
end
3136

3237
return y_and_dy, z

src/ImplicitDifferentiation.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,21 @@ using DifferentiationInterface:
2121
pullback,
2222
pushforward!,
2323
pushforward
24-
using Krylov: Krylov
25-
using IterativeSolvers: IterativeSolvers
24+
using Krylov: Krylov, krylov_workspace, krylov_solve!, solution
2625
using LinearOperators: LinearOperator
2726
using LinearMaps: FunctionMap
2827
using LinearAlgebra: factorize
2928

3029
include("utils.jl")
3130
include("settings.jl")
32-
include("preparation.jl")
3331
include("implicit_function.jl")
32+
include("preparation.jl")
3433
include("execution.jl")
34+
include("callable.jl")
3535

3636
export MatrixRepresentation, OperatorRepresentation
37-
export IterativeLinearSolver
37+
export IterativeLinearSolver, DirectLinearSolver
3838
export ImplicitFunction
39+
export prepare_implicit
3940

4041
end

src/callable.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
2+
return implicit(ImplicitFunctionPreparation(), x, args...)
3+
end
4+
5+
function (implicit::ImplicitFunction)(
6+
::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N}
7+
) where {N}
8+
return implicit.solver(x, args...)
9+
end

0 commit comments

Comments
 (0)