1
- struct JVP!{F,P,B,I,C}
1
+ struct JVP!{F,P,B,I,V, C}
2
2
f:: F
3
3
prep:: P
4
4
backend:: B
5
5
input:: I
6
+ v_buffer:: V
6
7
contexts:: C
7
8
end
8
9
9
- struct VJP!{F,P,B,I,C}
10
+ struct VJP!{F,P,B,I,V, C}
10
11
f:: F
11
12
prep:: P
12
13
backend:: B
13
14
input:: I
15
+ v_buffer:: V
14
16
contexts:: C
15
17
end
16
18
17
- function (po:: JVP! )(res:: AbstractVector , v:: AbstractVector )
18
- (; f, backend, input, contexts, prep) = po
19
- pushforward! (f, (res,), prep, backend, input, (v,), contexts... )
19
+ function (po:: JVP! )(res:: AbstractVector , v_wrongtype:: AbstractVector )
20
+ (; f, backend, input, v_buffer, contexts, prep) = po
21
+ copyto! (v_buffer, v_wrongtype)
22
+ pushforward! (f, (res,), prep, backend, input, (v_buffer,), contexts... )
20
23
return res
21
24
end
22
25
23
- function (po:: VJP! )(res:: AbstractVector , v:: AbstractVector )
24
- (; f, backend, input, contexts, prep) = po
25
- pullback! (f, (res,), prep, backend, input, (v,), contexts... )
26
+ function (po:: VJP! )(res:: AbstractVector , v_wrongtype:: AbstractVector )
27
+ (; f, backend, input, v_buffer, contexts, prep) = po
28
+ copyto! (v_buffer, v_wrongtype)
29
+ pullback! (f, (res,), prep, backend, input, (v_buffer,), contexts... )
26
30
return res
27
31
end
28
32
29
33
# # A
30
34
31
35
function build_A (
32
36
implicit:: ImplicitFunction ,
37
+ prep:: ImplicitFunctionPreparation ,
33
38
x:: AbstractArray ,
34
39
y:: AbstractArray ,
35
40
z,
@@ -38,14 +43,15 @@ function build_A(
38
43
suggested_backend:: AbstractADType ,
39
44
)
40
45
return build_A_aux (
41
- implicit. representation, implicit, x, y, z, c, args... ; suggested_backend
46
+ implicit. representation, implicit, prep, x, y, z, c, args... ; suggested_backend
42
47
)
43
48
end
44
49
45
50
function build_A_aux (
46
- :: MatrixRepresentation , implicit, x, y, z, c, args... ; suggested_backend
51
+ :: MatrixRepresentation , implicit, prep, x, y, z, c, args... ; suggested_backend
47
52
)
48
- (; conditions, backends, prep_A) = implicit
53
+ (; conditions, backends) = implicit
54
+ (; prep_A) = prep
49
55
actual_backend = isnothing (backends) ? suggested_backend : backends. y
50
56
contexts = (Constant (x), Constant (z), map (Constant, args)... )
51
57
if isnothing (prep_A)
59
65
function build_A_aux (
60
66
:: OperatorRepresentation{package,symmetric,hermitian,posdef} ,
61
67
implicit,
68
+ prep,
62
69
x,
63
70
y,
64
71
z,
@@ -67,7 +74,8 @@ function build_A_aux(
67
74
suggested_backend,
68
75
) where {package,symmetric,hermitian,posdef}
69
76
T = Base. promote_eltype (x, y, c)
70
- (; conditions, backends, prep_A) = implicit
77
+ (; conditions, backends) = implicit
78
+ (; prep_A) = prep
71
79
actual_backend = isnothing (backends) ? suggested_backend : backends. y
72
80
contexts = (Constant (x), Constant (z), map (Constant, args)... )
73
81
f_vec = VecToVec (Switch12 (conditions), y)
@@ -82,9 +90,9 @@ function build_A_aux(
82
90
f_vec, prep_A, actual_backend, y_vec, (dy_vec,), contexts...
83
91
)
84
92
end
85
- prod! = JVP! (f_vec, prep_A_same, actual_backend, y_vec, contexts)
93
+ prod! = JVP! (f_vec, prep_A_same, actual_backend, y_vec, dy_vec, contexts)
86
94
if package == :LinearOperators
87
- return LinearOperator (T, length (c), length (y), symmetric, hermitian, prod!; )
95
+ return LinearOperator (T, length (c), length (y), symmetric, hermitian, prod!)
88
96
elseif package == :LinearMaps
89
97
return FunctionMap {T} (
90
98
prod!,
102
110
103
111
function build_Aᵀ (
104
112
implicit:: ImplicitFunction ,
113
+ prep:: ImplicitFunctionPreparation ,
105
114
x:: AbstractArray ,
106
115
y:: AbstractArray ,
107
116
z,
@@ -110,14 +119,15 @@ function build_Aᵀ(
110
119
suggested_backend:: AbstractADType ,
111
120
)
112
121
return build_Aᵀ_aux (
113
- implicit. representation, implicit, x, y, z, c, args... ; suggested_backend
122
+ implicit. representation, implicit, prep, x, y, z, c, args... ; suggested_backend
114
123
)
115
124
end
116
125
117
126
function build_Aᵀ_aux (
118
- :: MatrixRepresentation , implicit, x, y, z, c, args... ; suggested_backend
127
+ :: MatrixRepresentation , implicit, prep, x, y, z, c, args... ; suggested_backend
119
128
)
120
- (; conditions, backends, prep_Aᵀ) = implicit
129
+ (; conditions, backends) = implicit
130
+ (; prep_Aᵀ) = prep
121
131
actual_backend = isnothing (backends) ? suggested_backend : backends. y
122
132
contexts = (Constant (x), Constant (z), map (Constant, args)... )
123
133
if isnothing (prep_Aᵀ)
133
143
function build_Aᵀ_aux (
134
144
:: OperatorRepresentation{package,symmetric,hermitian,posdef} ,
135
145
implicit,
146
+ prep,
136
147
x,
137
148
y,
138
149
z,
@@ -141,7 +152,8 @@ function build_Aᵀ_aux(
141
152
suggested_backend,
142
153
) where {package,symmetric,hermitian,posdef}
143
154
T = Base. promote_eltype (x, y, c)
144
- (; conditions, backends, prep_Aᵀ) = implicit
155
+ (; conditions, backends) = implicit
156
+ (; prep_Aᵀ) = prep
145
157
actual_backend = isnothing (backends) ? suggested_backend : backends. y
146
158
contexts = (Constant (x), Constant (z), map (Constant, args)... )
147
159
f_vec = VecToVec (Switch12 (conditions), y)
@@ -156,9 +168,9 @@ function build_Aᵀ_aux(
156
168
f_vec, prep_Aᵀ, actual_backend, y_vec, (dc_vec,), contexts...
157
169
)
158
170
end
159
- prod! = VJP! (f_vec, prep_Aᵀ_same, actual_backend, y_vec, contexts)
171
+ prod! = VJP! (f_vec, prep_Aᵀ_same, actual_backend, y_vec, dc_vec, contexts)
160
172
if package == :LinearOperators
161
- return LinearOperator (T, length (y), length (c), symmetric, hermitian, prod!; )
173
+ return LinearOperator (T, length (y), length (c), symmetric, hermitian, prod!)
162
174
elseif package == :LinearMaps
163
175
return FunctionMap {T} (
164
176
prod!,
@@ -176,14 +188,16 @@ end
176
188
177
189
function build_B (
178
190
implicit:: ImplicitFunction ,
191
+ prep:: ImplicitFunctionPreparation ,
179
192
x:: AbstractArray ,
180
193
y:: AbstractArray ,
181
194
z,
182
195
c,
183
196
args... ;
184
197
suggested_backend:: AbstractADType ,
185
198
)
186
- (; conditions, backends, prep_B) = implicit
199
+ (; conditions, backends) = implicit
200
+ (; prep_B) = prep
187
201
actual_backend = isnothing (backends) ? suggested_backend : backends. x
188
202
contexts = (Constant (y), Constant (z), map (Constant, args)... )
189
203
f_vec = VecToVec (conditions, x)
@@ -198,7 +212,8 @@ function build_B(
198
212
f_vec, prep_B, actual_backend, x_vec, (dx_vec,), contexts...
199
213
)
200
214
end
201
- function B_fun (dx_vec)
215
+ function B_fun (dx_vec_wrongtype)
216
+ copyto! (dx_vec, dx_vec_wrongtype)
202
217
return pushforward (
203
218
f_vec, prep_B_same, actual_backend, x_vec, (dx_vec,), contexts...
204
219
)[1 ]
@@ -210,14 +225,16 @@ end
210
225
211
226
function build_Bᵀ (
212
227
implicit:: ImplicitFunction ,
228
+ prep:: ImplicitFunctionPreparation ,
213
229
x:: AbstractArray ,
214
230
y:: AbstractArray ,
215
231
z,
216
232
c,
217
233
args... ;
218
234
suggested_backend:: AbstractADType ,
219
235
)
220
- (; conditions, backends, prep_Bᵀ) = implicit
236
+ (; conditions, backends) = implicit
237
+ (; prep_Bᵀ) = prep
221
238
actual_backend = isnothing (backends) ? suggested_backend : backends. x
222
239
contexts = (Constant (y), Constant (z), map (Constant, args)... )
223
240
f_vec = VecToVec (conditions, x)
@@ -232,7 +249,8 @@ function build_Bᵀ(
232
249
f_vec, prep_Bᵀ, actual_backend, x_vec, (dc_vec,), contexts...
233
250
)
234
251
end
235
- function Bᵀ_fun (dc_vec)
252
+ function Bᵀ_fun (dc_vec_wrongtype)
253
+ copyto! (dc_vec, dc_vec_wrongtype)
236
254
return pullback (f_vec, prep_Bᵀ_same, actual_backend, x_vec, (dc_vec,), contexts... )[1 ]
237
255
end
238
256
return Bᵀ_fun
0 commit comments