Skip to content

refactor!: split out preparation #179

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek"]
version = "0.8.1"
version = "0.9.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"
Expand All @@ -33,11 +32,10 @@ Documenter = "1.12.0"
ExplicitImports = "1"
FiniteDiff = "2.27.0"
ForwardDiff = "0.10.36, 1"
IterativeSolvers = "0.9.4"
JET = "0.9, 0.10"
JuliaFormatter = "2.1.2"
Krylov = "0.9.6, 0.10"
LinearAlgebra = "1.10"
LinearAlgebra = "1"
LinearMaps = "3.11.4"
LinearOperators = "2.8.0"
NLsolve = "4.5.1"
Expand Down Expand Up @@ -65,6 +63,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -76,4 +75,27 @@ TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "DifferentiationInterface", "Documenter", "ExplicitImports", "FiniteDiff", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Random", "SparseArrays", "StaticArrays", "Test", "TestItems", "TestItemRunner", "Zygote"]
test = [
"ADTypes",
"Aqua",
"ChainRulesCore",
"ChainRulesTestUtils",
"ComponentArrays",
"DifferentiationInterface",
"Documenter",
"ExplicitImports",
"FiniteDiff",
"ForwardDiff",
"JET",
"JuliaFormatter",
"LinearAlgebra",
"NLsolve",
"Optim",
"Random",
"SparseArrays",
"StaticArrays",
"Test",
"TestItems",
"TestItemRunner",
"Zygote",
]
2 changes: 2 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ ImplicitFunction
MatrixRepresentation
OperatorRepresentation
IterativeLinearSolver
DirectLinearSolver
prepare_implicit
```

## Internals
Expand Down
3 changes: 1 addition & 2 deletions docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ However, this can be switched to any other "inner" backend compatible with [Diff
### Arrays

Functions that eat or spit out arbitrary arrays are supported, as long as the forward mapping _and_ conditions return arrays of the same size.

If you deal with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) for increased performance.
The array types involved should be mutable.

### Scalars

Expand Down
4 changes: 2 additions & 2 deletions examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ end

function conditions_components(x::ComponentVector, y::ComponentVector, _z)
c_d, c_e = conditions_components_aux(x.a, x.b, x.m, y.d, y.e)
c = ComponentVector(; c_d=c_d, c_e=c_e)
c = ComponentVector(; d=c_d, e=c_e)
return c
end;

Expand All @@ -50,7 +50,7 @@ implicit_components = ImplicitFunction(forward_components, conditions_components

a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0
x = ComponentVector(; a=a, b=b, m=m)
implicit_components(x)
y, z = implicit_components(x)

# And it works with both ForwardDiff.jl and Zygote.jl

Expand Down
18 changes: 12 additions & 6 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ChainRulesCore: unthunk, @not_implemented
using ImplicitDifferentiation:
ImplicitDifferentiation,
ImplicitFunction,
ImplicitFunctionPreparation,
build_Aᵀ,
build_Bᵀ,
chainrules_suggested_backend
Expand All @@ -14,28 +15,33 @@ using ImplicitDifferentiation:
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)

function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
rc::RuleConfig,
implicit::ImplicitFunction,
prep::ImplicitFunctionPreparation,
x::AbstractArray,
args::Vararg{Any,N};
) where {N}
y, z = implicit(x, args...)
c = implicit.conditions(x, y, z, args...)

suggested_backend = chainrules_suggested_backend(rc)
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
Bᵀ = build_Bᵀ(implicit, x, y, z, c, args...; suggested_backend)
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
project_x = ProjectTo(x)

function implicit_pullback((dy, dz))
function implicit_pullback_prepared((dy, dz))
dy = unthunk(dy)
dy_vec = vec(dy)
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
dx_vec = Bᵀ(dc_vec)
dx = reshape(dx_vec, size(x))
df = NoTangent()
dprep = @not_implemented("Tangents for mutable arguments are not defined")
dargs = ntuple(unimplemented_tangent, N)
return (df, project_x(dx), dargs...)
return (df, dprep, project_x(dx), dargs...)
end

return (y, z), implicit_pullback
return (y, z), implicit_pullback_prepared
end

function unimplemented_tangent(_)
Expand Down
19 changes: 12 additions & 7 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,36 @@ module ImplicitDifferentiationForwardDiffExt

using ADTypes: AutoForwardDiff
using ForwardDiff: Dual, Partials, partials, value
using ImplicitDifferentiation: ImplicitFunction, build_A, build_B
using ImplicitDifferentiation:
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B

function (implicit::ImplicitFunction)(
x_and_dx::AbstractArray{Dual{T,R,N}}, args...
prep::ImplicitFunctionPreparation, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
) where {T,R,N}
x = value.(x_and_dx)
y, z = implicit(x, args...)
c = implicit.conditions(x, y, z, args...)

suggested_backend = AutoForwardDiff()
A = build_A(implicit, x, y, z, c, args...; suggested_backend)
B = build_B(implicit, x, y, z, c, args...; suggested_backend)
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend)

dX = ntuple(Val(N)) do k
partials.(x_and_dx, k)
end
dC_mat = mapreduce(hcat, dX) do dₖx
dC_vec = map(dX) do dₖx
dₖx_vec = vec(dₖx)
dₖc_vec = B(dₖx_vec)
return dₖc_vec
end
dY_mat = implicit.linear_solver(A, -dC_mat)
dY = map(dC_vec) do dₖc_vec
dₖy_vec = implicit.linear_solver(A, -dₖc_vec)
dₖy = reshape(dₖy_vec, size(y))
return dₖy
end

y_and_dy = map(y, LinearIndices(y)) do yi, i
Dual{T}(yi, Partials(ntuple(k -> dY_mat[i, k], Val(N))))
Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N))))
end

return y_and_dy, z
Expand Down
9 changes: 5 additions & 4 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,21 @@ using DifferentiationInterface:
pullback,
pushforward!,
pushforward
using Krylov: Krylov
using IterativeSolvers: IterativeSolvers
using Krylov: Krylov, krylov_workspace, krylov_solve!, solution
using LinearOperators: LinearOperator
using LinearMaps: FunctionMap
using LinearAlgebra: factorize

include("utils.jl")
include("settings.jl")
include("preparation.jl")
include("implicit_function.jl")
include("preparation.jl")
include("execution.jl")
include("callable.jl")

export MatrixRepresentation, OperatorRepresentation
export IterativeLinearSolver
export IterativeLinearSolver, DirectLinearSolver
export ImplicitFunction
export prepare_implicit

end
9 changes: 9 additions & 0 deletions src/callable.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
function (implicit::ImplicitFunction)(x::AbstractArray, args::Vararg{Any,N}) where {N}
return implicit(ImplicitFunctionPreparation(), x, args...)
end

function (implicit::ImplicitFunction)(
::ImplicitFunctionPreparation, x::AbstractArray, args::Vararg{Any,N}
) where {N}
return implicit.solver(x, args...)
end
Loading
Loading