Skip to content

Commit 18c7e97

Browse files
github-actions[bot]CompatHelper Juliagdalle
authored
Bump compat for DifferentiationInterface to 0.4 (#146)
* CompatHelper: bump compat for DifferentiationInterface to 0.4, (keep existing compat) * Adapt to same point preparation * Remove DI 1.3 * Fix imports * Fix imports --------- Co-authored-by: CompatHelper Julia <compathelper_noreply@julialang.org> Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com>
1 parent 225eb46 commit 18c7e97

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
2323
[compat]
2424
ADTypes = "1.0"
2525
ChainRulesCore = "1.23.0"
26-
DifferentiationInterface = "0.3"
26+
DifferentiationInterface = "0.4"
2727
Enzyme = "0.11.20,0.12"
2828
ForwardDiff = "0.10.36"
2929
Krylov = "0.9.5"

src/ImplicitDifferentiation.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@ module ImplicitDifferentiation
99

1010
using ADTypes: AbstractADType
1111
using DifferentiationInterface:
12-
jacobian, prepare_pushforward, prepare_pullback, pushforward!, value_and_pullback!_split
12+
jacobian,
13+
prepare_pushforward_same_point,
14+
prepare_pullback_same_point,
15+
pullback!,
16+
pushforward!
1317
using Krylov: block_gmres, gmres
1418
using LinearOperators: LinearOperator
1519
using LinearAlgebra: factorize, lu

src/operators.jl

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,21 @@ function (po::PushforwardOperator!)(res, v, α, β)
9191
return res
9292
end
9393

94-
struct PullbackOperator!{PB,R}
95-
pullbackfunc!::PB
94+
struct PullbackOperator!{F,B,X,E,R}
95+
f::F
96+
backend::B
97+
x::X
98+
extras::E
9699
res_backup::R
97100
end
98101

99102
function (po::PullbackOperator!)(res, v, α, β)
100103
if iszero(β)
101-
po.pullbackfunc!(res, v)
104+
pullback!(po.f, res, po.backend, po.x, v, po.extras)
105+
res .= α .* res
102106
else
103107
po.res_backup .= res
104-
po.pullbackfunc!(res, v)
108+
pullback!(po.f, res, po.backend, po.x, v, po.extras)
105109
res .= α .* res .+ β .+ po.res_backup
106110
end
107111
return res
@@ -121,7 +125,7 @@ function build_A(
121125
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
122126
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
123127
if lazy
124-
extras = prepare_pushforward(cond_y, back_y, y, similar(y))
128+
extras = prepare_pushforward_same_point(cond_y, back_y, y, zero(y))
125129
A = LinearOperator(
126130
eltype(y),
127131
m,
@@ -152,15 +156,14 @@ function build_Aᵀ(
152156
back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend
153157
cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
154158
if lazy
155-
extras = prepare_pullback(cond_y, back_y, y, similar(y))
156-
_, pullbackfunc! = value_and_pullback!_split(cond_y, back_y, y, extras)
159+
extras = prepare_pullback_same_point(cond_y, back_y, y, zero(y))
157160
Aᵀ = LinearOperator(
158161
eltype(y),
159162
m,
160163
m,
161164
false,
162165
false,
163-
PullbackOperator!(pullbackfunc!, similar(y)),
166+
PullbackOperator!(cond_y, back_y, y, extras, similar(y)),
164167
typeof(y),
165168
)
166169
else
@@ -184,7 +187,7 @@ function build_B(
184187
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
185188
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
186189
if lazy
187-
extras = prepare_pushforward(cond_x, back_x, x, similar(x))
190+
extras = prepare_pushforward_same_point(cond_x, back_x, x, zero(x))
188191
B = LinearOperator(
189192
eltype(y),
190193
m,
@@ -214,15 +217,14 @@ function build_Bᵀ(
214217
back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend
215218
cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...)
216219
if lazy
217-
extras = prepare_pullback(cond_x, back_x, x, similar(y))
218-
_, pullbackfunc! = value_and_pullback!_split(cond_x, back_x, x, extras)
220+
extras = prepare_pullback_same_point(cond_x, back_x, x, zero(y))
219221
Bᵀ = LinearOperator(
220222
eltype(y),
221223
n,
222224
m,
223225
false,
224226
false,
225-
PullbackOperator!(pullbackfunc!, similar(y)),
227+
PullbackOperator!(cond_x, back_x, x, extras, similar(x)),
226228
typeof(x),
227229
)
228230
else

0 commit comments

Comments
 (0)