Skip to content

Commit 5380155

Browse files
authored
feat: add least-squares linear solver (#185)
* feat: add least-squares solver * Fixes * Kwargs * Fix type-stability * Codecov v5 * Better tests * Allow Factorization
1 parent a884a90 commit 5380155

11 files changed

+115
-38
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
- uses: julia-actions/julia-buildpkg@v1
3333
- uses: julia-actions/julia-runtest@v1
3434
- uses: julia-actions/julia-processcoverage@v1
35-
- uses: codecov/codecov-action@v4
35+
- uses: codecov/codecov-action@v5
3636
with:
3737
files: lcov.info
3838
token: ${{ secrets.CODECOV_TOKEN }}

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ ImplicitFunction
2222
MatrixRepresentation
2323
OperatorRepresentation
2424
IterativeLinearSolver
25+
IterativeLeastSquaresSolver
2526
DirectLinearSolver
2627
```
2728

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,36 @@
11
module ImplicitDifferentiationChainRulesCoreExt
22

3-
using ADTypes: AutoChainRules
3+
using ADTypes: AutoChainRules, AutoForwardDiff
44
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig
55
using ChainRulesCore: unthunk, @not_implemented
66
using ImplicitDifferentiation:
77
ImplicitDifferentiation,
88
ImplicitFunction,
99
ImplicitFunctionPreparation,
10+
IterativeLeastSquaresSolver,
11+
build_A,
1012
build_Aᵀ,
1113
build_Bᵀ,
12-
chainrules_suggested_backend
14+
suggested_forward_backend,
15+
suggested_reverse_backend
1316

1417
# not covered by Codecov for now
15-
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
18+
ImplicitDifferentiation.suggested_forward_backend(rc::RuleConfig) = AutoForwardDiff()
19+
ImplicitDifferentiation.suggested_reverse_backend(rc::RuleConfig) = AutoChainRules(rc)
1620

17-
struct ImplicitPullback{TA,TB,TL,TC,TP,Nargs}
21+
struct ImplicitPullback{Nargs,TA,TB,TA2,TL,TC,TP}
1822
Aᵀ::TA
1923
Bᵀ::TB
24+
A::TA2
2025
linear_solver::TL
2126
c0::TC
2227
project_x::TP
2328
_Nargs::Val{Nargs}
2429
end
2530

26-
function (pb::ImplicitPullback{TA,TB,TL,TC,TP,Nargs})((dy, dz)) where {TA,TB,TL,TP,TC,Nargs}
27-
(; Aᵀ, Bᵀ, linear_solver, c0, project_x) = pb
28-
dc = linear_solver(Aᵀ, -unthunk(dy), c0)
31+
function (pb::ImplicitPullback{Nargs})((dy, dz)) where {Nargs}
32+
(; Aᵀ, Bᵀ, A, linear_solver, c0, project_x) = pb
33+
dc = linear_solver(Aᵀ, A, -unthunk(dy), c0)
2934
dx = Bᵀ(dc)
3035
df = NoTangent()
3136
dargs = ntuple(unimplemented_tangent, Val(Nargs))
@@ -40,13 +45,19 @@ function ChainRulesCore.rrule(
4045
c = conditions(x, y, z, args...)
4146
c0 = zero(c)
4247

43-
suggested_backend = chainrules_suggested_backend(rc)
48+
forward_backend = suggested_forward_backend(rc)
49+
reverse_backend = suggested_reverse_backend(rc)
4450
prep = ImplicitFunctionPreparation(eltype(x))
45-
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
46-
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
51+
Aᵀ = build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
52+
Bᵀ = build_Bᵀ(implicit, prep, x, y, z, c, args...; suggested_backend=reverse_backend)
53+
if linear_solver isa IterativeLeastSquaresSolver
54+
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend=forward_backend)
55+
else
56+
A = nothing
57+
end
4758
project_x = ProjectTo(x)
4859

49-
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, linear_solver, c0, project_x, Val(N))
60+
implicit_pullback = ImplicitPullback(Aᵀ, Bᵀ, A, linear_solver, c0, project_x, Val(N))
5061
return (y, z), implicit_pullback
5162
end
5263

ext/ImplicitDifferentiationForwardDiffExt.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,36 @@ module ImplicitDifferentiationForwardDiffExt
33
using ADTypes: AutoForwardDiff
44
using ForwardDiff: Dual, Partials, partials, value
55
using ImplicitDifferentiation:
6-
ImplicitFunction, ImplicitFunctionPreparation, build_A, build_B
6+
ImplicitFunction,
7+
ImplicitFunctionPreparation,
8+
IterativeLeastSquaresSolver,
9+
build_A,
10+
build_Aᵀ,
11+
build_B
712

813
function (implicit::ImplicitFunction)(
914
prep::ImplicitFunctionPreparation{R}, x_and_dx::AbstractArray{Dual{T,R,N}}, args...
1015
) where {T,R,N}
16+
(; conditions, linear_solver) = implicit
1117
x = value.(x_and_dx)
1218
y, z = implicit(x, args...)
13-
c = implicit.conditions(x, y, z, args...)
19+
c = conditions(x, y, z, args...)
1420
y0 = zero(y)
1521

1622
suggested_backend = AutoForwardDiff()
1723
A = build_A(implicit, prep, x, y, z, c, args...; suggested_backend)
1824
B = build_B(implicit, prep, x, y, z, c, args...; suggested_backend)
19-
20-
dX = ntuple(Val(N)) do k
21-
partials.(x_and_dx, k)
25+
Aᵀ = if linear_solver isa IterativeLeastSquaresSolver
26+
build_Aᵀ(implicit, prep, x, y, z, c, args...; suggested_backend)
27+
else
28+
nothing
2229
end
30+
31+
dX = ntuple(k -> partials.(x_and_dx, k), Val(N))
2332
dC = map(B, dX)
2433
dY = map(dC) do dₖc
25-
dₖy = implicit.linear_solver(A, -dₖc, y0)
26-
return dₖy
34+
linear_solver(A, Aᵀ, -dₖc, y0)
2735
end
28-
2936
y_and_dy = map(y, LinearIndices(y)) do yi, i
3037
Dual{T}(yi, Partials(ntuple(k -> dY[k][i], Val(N))))
3138
end
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
module ImplicitDifferentiationZygoteExt
22

3-
using ADTypes: AutoZygote
3+
using ADTypes: AutoForwardDiff, AutoZygote
44
using ImplicitDifferentiation: ImplicitDifferentiation
55
using Zygote: ZygoteRuleConfig
66

7-
ImplicitDifferentiation.chainrules_suggested_backend(::ZygoteRuleConfig) = AutoZygote()
7+
ImplicitDifferentiation.suggested_forward_backend(::ZygoteRuleConfig) = AutoForwardDiff()
8+
ImplicitDifferentiation.suggested_reverse_backend(::ZygoteRuleConfig) = AutoZygote()
89

910
end

src/ImplicitDifferentiation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ using DifferentiationInterface:
1919
prepare_pushforward_same_point,
2020
pullback,
2121
pushforward
22-
using KrylovKit: linsolve
23-
using LinearAlgebra: factorize
22+
using KrylovKit: linsolve, lssolve
23+
using LinearAlgebra: Factorization, factorize
2424

2525
include("utils.jl")
2626
include("settings.jl")
@@ -30,7 +30,7 @@ include("execution.jl")
3030
include("callable.jl")
3131

3232
export MatrixRepresentation, OperatorRepresentation
33-
export IterativeLinearSolver, DirectLinearSolver
33+
export IterativeLinearSolver, IterativeLeastSquaresSolver, DirectLinearSolver
3434
export ImplicitFunction
3535

3636
end

src/implicit_function.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@ This requires solving a linear system `A * J = -B` where `A = ∂₂c`, `B = ∂
3434
## Keyword arguments
3535
3636
- `representation`: defines how the partial Jacobian `A` of the conditions with respect to the output is represented. It can be either [`MatrixRepresentation`](@ref) or [`OperatorRepresentation`](@ref).
37-
- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref) or [`IterativeLinearSolver`](@ref).
37+
- `linear_solver`: specifies how the linear system `A * J = -B` will be solved in the implicit function theorem. It can be either [`DirectLinearSolver`](@ref), [`IterativeLinearSolver`](@ref) or [`IterativeLeastSquaresSolver`](@ref).
3838
- `backends::AbstractADType`: specifies how the `conditions` will be differentiated with respect to `x` and `y`. It can be either `nothing`, which means that the external autodiff system will be used, or a named tuple `(; x=AutoSomething(), y=AutoSomethingElse())` of backend objects from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
3939
- `strict::Val`: specifies whether preparation inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl) should enforce a strict match between the primal variables and the provided tangents.
4040
"""
4141
struct ImplicitFunction{
4242
F,
4343
C,
44-
L,
44+
L<:AbstractSolver,
4545
R<:AbstractRepresentation,
4646
B<:Union{
4747
Nothing, #

src/settings.jl

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,96 @@
11
## Linear solver
22

3+
abstract type AbstractSolver end
4+
35
"""
46
DirectLinearSolver
57
68
Specify that linear systems `Ax = b` should be solved with a direct method.
79
10+
!!! warning
11+
Can only be used when the `solver` and the `conditions` both output an `AbstractVector`.
12+
813
# See also
914
1015
- [`ImplicitFunction`](@ref)
1116
- [`IterativeLinearSolver`](@ref)
17+
- [`IterativeLeastSquaresSolver`](@ref)
1218
"""
13-
struct DirectLinearSolver end
19+
struct DirectLinearSolver <: AbstractSolver end
1420

15-
function (solver::DirectLinearSolver)(A, b::AbstractVector, x0::AbstractVector)
21+
function (solver::DirectLinearSolver)(
22+
A::Union{AbstractMatrix,Factorization}, _Aᵀ, b::AbstractVector, x0::AbstractVector
23+
)
1624
return A \ b
1725
end
1826

27+
abstract type AbstractIterativeSolver <: AbstractSolver end
28+
1929
"""
2030
IterativeLinearSolver
2131
2232
Specify that linear systems `Ax = b` should be solved with an iterative method.
2333
34+
!!! warning
35+
Can only be used when the `solver` and the `conditions` both output `AbstractArray`s with the same type and length.
36+
2437
# See also
2538
2639
- [`ImplicitFunction`](@ref)
2740
- [`DirectLinearSolver`](@ref)
41+
- [`IterativeLeastSquaresSolver`](@ref)
2842
"""
29-
struct IterativeLinearSolver{K}
43+
struct IterativeLinearSolver{K} <: AbstractIterativeSolver
3044
kwargs::K
3145
function IterativeLinearSolver(; kwargs...)
3246
return new{typeof(kwargs)}(kwargs)
3347
end
3448
end
3549

36-
function (solver::IterativeLinearSolver)(A, b, x0)
50+
function (solver::IterativeLinearSolver)(A, _Aᵀ, b, x0)
3751
sol, info = linsolve(A, b, x0; solver.kwargs...)
3852
@assert info.converged == 1
3953
return sol
4054
end
4155

42-
function Base.show(io::IO, linear_solver::IterativeLinearSolver)
56+
"""
57+
IterativeLeastSquaresSolver
58+
59+
Specify that linear systems `Ax = b` should be solved with an iterative least-squares method.
60+
61+
!!! tip
62+
Can be used when the `solver` and the `conditions` output `AbstractArray`s with different types or different lengths.
63+
64+
!!! warning
65+
To ensure performance, remember to specify both `backends` used to differentiate `conditions`.
66+
67+
# See also
68+
69+
- [`ImplicitFunction`](@ref)
70+
- [`DirectLinearSolver`](@ref)
71+
- [`IterativeLinearSolver`](@ref)
72+
"""
73+
struct IterativeLeastSquaresSolver{K} <: AbstractIterativeSolver
74+
kwargs::K
75+
function IterativeLeastSquaresSolver(; kwargs...)
76+
return new{typeof(kwargs)}(kwargs)
77+
end
78+
end
79+
80+
function (solver::IterativeLeastSquaresSolver)(A, Aᵀ, b, x0)
81+
sol, info = lssolve((A, Aᵀ), b; solver.kwargs...)
82+
@assert info.converged == 1
83+
return sol
84+
end
85+
86+
function Base.show(io::IO, linear_solver::AbstractIterativeSolver)
4387
(; kwargs) = linear_solver
44-
print(io, repr(IterativeLinearSolver; context=io), "(;")
88+
T = if linear_solver isa IterativeLinearSolver
89+
IterativeLinearSolver
90+
else
91+
IterativeLeastSquaresSolver
92+
end
93+
print(io, repr(T; context=io), "(;")
4594
for p in pairs(kwargs)
4695
print(io, " ", p[1], "=", repr(p[2]; context=io), ",")
4796
end
@@ -76,4 +125,7 @@ Specify that the matrix `A` involved in the implicit function theorem should be
76125
"""
77126
struct OperatorRepresentation <: AbstractRepresentation end
78127

79-
function chainrules_suggested_backend end
128+
## Backends
129+
130+
function suggested_forward_backend end
131+
function suggested_reverse_backend end

test/printing.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@ using TestItems
44
@test contains(string(ImplicitFunction(nothing, nothing)), "ImplicitFunction")
55
@test contains(string(IterativeLinearSolver()), "IterativeLinearSolver")
66
@test contains(string(IterativeLinearSolver(; rtol=1e-3)), "IterativeLinearSolver")
7+
@test contains(
8+
string(IterativeLeastSquaresSolver(; rtol=1e-3)), "IterativeLeastSquaresSolver"
9+
)
710
end

test/systematic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end;
2828
[
2929
IterativeLinearSolver(),
3030
IterativeLinearSolver(; rtol=1e-8),
31-
IterativeLinearSolver(; issymmetric=true, isposdef=true),
31+
IterativeLeastSquaresSolver(),
3232
],
3333
[nothing, (; x=AutoForwardDiff(), y=AutoZygote())],
3434
[float.(1:3), reshape(float.(1:6), 3, 2)],
@@ -53,7 +53,7 @@ end;
5353
solver=default_solver,
5454
conditions=default_conditions,
5555
x=x,
56-
implicit_kwargs=(; strict=Val(false)),
56+
implicit_kwargs=(; linear_solver=IterativeLeastSquaresSolver()),
5757
)
5858
scen2 = add_arg_mult(scen)
5959
test_implicit(scen)

0 commit comments

Comments
 (0)