@@ -91,17 +91,21 @@ function (po::PushforwardOperator!)(res, v, α, β)
91
91
return res
92
92
end
93
93
94
- struct PullbackOperator!{PB,R}
95
- pullbackfunc!:: PB
94
+ struct PullbackOperator!{F,B,X,E,R}
95
+ f:: F
96
+ backend:: B
97
+ x:: X
98
+ extras:: E
96
99
res_backup:: R
97
100
end
98
101
99
102
function (po:: PullbackOperator! )(res, v, α, β)
100
103
if iszero (β)
101
- po. pullbackfunc! (res, v)
104
+ pullback! (po. f, res, po. backend, po. x, v, po. extras)
105
+ res .= α .* res
102
106
else
103
107
po. res_backup .= res
104
- po . pullbackfunc! ( res, v )
108
+ pullback! (po . f, res, po . backend, po . x, v, po . extras )
105
109
res .= α .* res .+ β .+ po. res_backup
106
110
end
107
111
return res
@@ -121,7 +125,7 @@ function build_A(
121
125
back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
122
126
cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
123
127
if lazy
124
- extras = prepare_pushforward (cond_y, back_y, y, similar (y))
128
+ extras = prepare_pushforward_same_point (cond_y, back_y, y, zero (y))
125
129
A = LinearOperator (
126
130
eltype (y),
127
131
m,
@@ -152,15 +156,14 @@ function build_Aᵀ(
152
156
back_y = isnothing (conditions_y_backend) ? suggested_backend : conditions_y_backend
153
157
cond_y = ConditionsY (conditions, x, y_or_yz, args... ; kwargs... )
154
158
if lazy
155
- extras = prepare_pullback (cond_y, back_y, y, similar (y))
156
- _, pullbackfunc! = value_and_pullback!_split (cond_y, back_y, y, extras)
159
+ extras = prepare_pullback_same_point (cond_y, back_y, y, zero (y))
157
160
Aᵀ = LinearOperator (
158
161
eltype (y),
159
162
m,
160
163
m,
161
164
false ,
162
165
false ,
163
- PullbackOperator! (pullbackfunc! , similar (y)),
166
+ PullbackOperator! (cond_y, back_y, y, extras , similar (y)),
164
167
typeof (y),
165
168
)
166
169
else
@@ -184,7 +187,7 @@ function build_B(
184
187
back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
185
188
cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
186
189
if lazy
187
- extras = prepare_pushforward (cond_x, back_x, x, similar (x))
190
+ extras = prepare_pushforward_same_point (cond_x, back_x, x, zero (x))
188
191
B = LinearOperator (
189
192
eltype (y),
190
193
m,
@@ -214,15 +217,14 @@ function build_Bᵀ(
214
217
back_x = isnothing (conditions_x_backend) ? suggested_backend : conditions_x_backend
215
218
cond_x = ConditionsX (conditions, x, y_or_yz, args... ; kwargs... )
216
219
if lazy
217
- extras = prepare_pullback (cond_x, back_x, x, similar (y))
218
- _, pullbackfunc! = value_and_pullback!_split (cond_x, back_x, x, extras)
220
+ extras = prepare_pullback_same_point (cond_x, back_x, x, zero (y))
219
221
Bᵀ = LinearOperator (
220
222
eltype (y),
221
223
n,
222
224
m,
223
225
false ,
224
226
false ,
225
- PullbackOperator! (pullbackfunc!, similar (y )),
227
+ PullbackOperator! (cond_x, back_x, x, extras, similar (x )),
226
228
typeof (x),
227
229
)
228
230
else
0 commit comments