|
32 | 32 | fgt_ = gc(fgt)
|
33 | 33 | @test size(node_feature(fgt_)) == (out_channel, N)
|
34 | 34 |
|
35 |
| - g = Zygote.gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc)) |
| 35 | + g = gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc)) |
36 | 36 | @test length(g.grads) == 4
|
37 | 37 | end
|
38 | 38 |
|
|
45 | 45 | Y = gc(Xt)
|
46 | 46 | @test size(Y) == (out_channel, N)
|
47 | 47 |
|
48 |
| - g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc)) |
| 48 | + g = gradient(() -> sum(gc(X)), Flux.params(gc)) |
49 | 49 | @test length(g.grads) == 3
|
50 | 50 | end
|
51 | 51 |
|
|
85 | 85 | fgt_ = cc(fgt)
|
86 | 86 | @test size(node_feature(fgt_)) == (out_channel, N)
|
87 | 87 |
|
88 |
| - g = Zygote.gradient(() -> sum(node_feature(cc(fg))), Flux.params(cc)) |
| 88 | + g = gradient(() -> sum(node_feature(cc(fg))), Flux.params(cc)) |
89 | 89 | @test length(g.grads) == 4
|
90 | 90 | end
|
91 | 91 |
|
|
98 | 98 | Y = cc(Xt)
|
99 | 99 | @test size(Y) == (out_channel, N)
|
100 | 100 |
|
101 |
| - g = Zygote.gradient(() -> sum(cc(X)), Flux.params(cc)) |
| 101 | + g = gradient(() -> sum(cc(X)), Flux.params(cc)) |
102 | 102 | @test length(g.grads) == 2
|
103 | 103 | end
|
104 | 104 |
|
|
121 | 121 | @test size(node_feature(fg_)) == (out_channel, N)
|
122 | 122 | @test_throws MethodError gc(X)
|
123 | 123 |
|
124 |
| - g = Zygote.gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc)) |
| 124 | + g = gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc)) |
125 | 125 | @test length(g.grads) == 5
|
126 | 126 | end
|
127 | 127 |
|
|
131 | 131 | Y = gc(X)
|
132 | 132 | @test size(Y) == (out_channel, N, batch_size)
|
133 | 133 |
|
134 |
| - g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc)) |
| 134 | + g = gradient(() -> sum(gc(X)), Flux.params(gc)) |
135 | 135 | @test length(g.grads) == 3
|
136 | 136 | end
|
137 | 137 |
|
|
168 | 168 | @test size(node_feature(fg_)) == (concat ? (out_channel * heads, N) : (out_channel, N))
|
169 | 169 | @test_throws MethodError gat(X)
|
170 | 170 |
|
171 |
| - g = Zygote.gradient(() -> sum(node_feature(gat(fg_gat))), Flux.params(gat)) |
| 171 | + g = gradient(() -> sum(node_feature(gat(fg_gat))), Flux.params(gat)) |
172 | 172 | @test length(g.grads) == 5
|
173 | 173 | end
|
174 | 174 | end
|
|
181 | 181 | Y = gat(X)
|
182 | 182 | @test size(Y) == (concat ? (out_channel * heads, N, batch_size) : (out_channel, N, batch_size))
|
183 | 183 |
|
184 |
| - g = Zygote.gradient(() -> sum(gat(X)), Flux.params(gat)) |
| 184 | + g = gradient(() -> sum(gat(X)), Flux.params(gat)) |
185 | 185 | @test length(g.grads) == 3
|
186 | 186 | end
|
187 | 187 | end
|
|
221 | 221 | @test size(node_feature(fg_)) == (concat ? (out_channel * heads, N) : (out_channel, N))
|
222 | 222 | @test_throws MethodError gat2(X)
|
223 | 223 |
|
224 |
| - g = Zygote.gradient(() -> sum(node_feature(gat2(fg_gat))), Flux.params(gat2)) |
| 224 | + g = gradient(() -> sum(node_feature(gat2(fg_gat))), Flux.params(gat2)) |
225 | 225 | @test length(g.grads) == 7
|
226 | 226 | end
|
227 | 227 | end
|
|
234 | 234 | Y = gat2(X)
|
235 | 235 | @test size(Y) == (concat ? (out_channel * heads, N, batch_size) : (out_channel, N, batch_size))
|
236 | 236 |
|
237 |
| - g = Zygote.gradient(() -> sum(gat2(X)), Flux.params(gat2)) |
| 237 | + g = gradient(() -> sum(gat2(X)), Flux.params(gat2)) |
238 | 238 | @test length(g.grads) == 5
|
239 | 239 | end
|
240 | 240 | end
|
|
258 | 258 | @test size(node_feature(fg_)) == (out_channel, N)
|
259 | 259 | @test_throws MethodError ggc(X)
|
260 | 260 |
|
261 |
| - g = Zygote.gradient(() -> sum(node_feature(ggc(fg))), Flux.params(ggc)) |
| 261 | + g = gradient(() -> sum(node_feature(ggc(fg))), Flux.params(ggc)) |
262 | 262 | @test length(g.grads) == 8
|
263 | 263 | end
|
264 | 264 |
|
|
268 | 268 | @test_broken Y = ggc(X)
|
269 | 269 | @test_broken size(Y) == (out_channel, N, batch_size)
|
270 | 270 |
|
271 |
| - @test_broken g = Zygote.gradient(() -> sum(ggc(X)), Flux.params(ggc)) |
| 271 | + @test_broken g = gradient(() -> sum(ggc(X)), Flux.params(ggc)) |
272 | 272 | @test_broken length(g.grads) == 6
|
273 | 273 | end
|
274 | 274 | end
|
|
283 | 283 | @test size(node_feature(fg_)) == (out_channel, N)
|
284 | 284 | @test_throws MethodError ec(X)
|
285 | 285 |
|
286 |
| - g = Zygote.gradient(() -> sum(node_feature(ec(fg))), Flux.params(ec)) |
| 286 | + g = gradient(() -> sum(node_feature(ec(fg))), Flux.params(ec)) |
287 | 287 | @test length(g.grads) == 4
|
288 | 288 | end
|
289 | 289 |
|
|
293 | 293 | Y = ec(X)
|
294 | 294 | @test size(Y) == (out_channel, N, batch_size)
|
295 | 295 |
|
296 |
| - g = Zygote.gradient(() -> sum(ec(X)), Flux.params(ec)) |
| 296 | + g = gradient(() -> sum(ec(X)), Flux.params(ec)) |
297 | 297 | @test length(g.grads) == 2
|
298 | 298 | end
|
299 | 299 |
|
|
303 | 303 | Y = ec(X)
|
304 | 304 | @test size(Y) == (out_channel, N)
|
305 | 305 |
|
306 |
| - g = Zygote.gradient(() -> sum(ec(X)), Flux.params(ec)) |
| 306 | + g = gradient(() -> sum(ec(X)), Flux.params(ec)) |
307 | 307 | @test length(g.grads) == 2
|
308 | 308 | end
|
309 | 309 |
|
|
313 | 313 | Y = ec(X)
|
314 | 314 | @test size(Y) == (out_channel, N, batch_size)
|
315 | 315 |
|
316 |
| - g = Zygote.gradient(() -> sum(ec(X)), Flux.params(ec)) |
| 316 | + g = gradient(() -> sum(ec(X)), Flux.params(ec)) |
317 | 317 | @test length(g.grads) == 2
|
318 | 318 | end
|
319 | 319 | end
|
|
333 | 333 | @test size(node_feature(fg_)) == (out_channel, N)
|
334 | 334 | @test_throws MethodError gc(X)
|
335 | 335 |
|
336 |
| - g = Zygote.gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc)) |
| 336 | + g = gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc)) |
337 | 337 | @test length(g.grads) == 5
|
338 | 338 | end
|
339 | 339 |
|
|
343 | 343 | Y = gc(X)
|
344 | 344 | @test size(Y) == (out_channel, N, batch_size)
|
345 | 345 |
|
346 |
| - g = Zygote.gradient(() -> sum(gc(X)), Flux.params(gc)) |
| 346 | + g = gradient(() -> sum(gc(X)), Flux.params(gc)) |
347 | 347 | @test length(g.grads) == 2
|
348 | 348 | end
|
349 | 349 | end
|
|
362 | 362 | fg_ = cgc(fg)
|
363 | 363 | @test_throws MethodError cgc(nf)
|
364 | 364 |
|
365 |
| - g = Zygote.gradient(() -> sum(node_feature(cgc(fg))), Flux.params(cgc)) |
| 365 | + g = gradient(() -> sum(node_feature(cgc(fg))), Flux.params(cgc)) |
366 | 366 | @test length(g.grads) == 6
|
367 | 367 | end
|
368 | 368 |
|
|
373 | 373 | Y = cgc(nf, ef)
|
374 | 374 | @test size(Y) == (in_channel, N, batch_size)
|
375 | 375 |
|
376 |
| - g = Zygote.gradient(() -> sum(cgc(nf, ef)), Flux.params(cgc)) |
| 376 | + g = gradient(() -> sum(cgc(nf, ef)), Flux.params(cgc)) |
377 | 377 | @test length(g.grads) == 4
|
378 | 378 | end
|
379 | 379 | end
|
|
391 | 391 | @test size(node_feature(fg_)) == (out_channel, N)
|
392 | 392 | @test_throws MethodError l(X)
|
393 | 393 |
|
394 |
| - g = Zygote.gradient(() -> sum(node_feature(l(fg))), Flux.params(l)) |
| 394 | + g = gradient(() -> sum(node_feature(l(fg))), Flux.params(l)) |
395 | 395 | if l.proj == identity
|
396 | 396 | if conv == LSTMAggregator
|
397 | 397 | @test length(g.grads) == 10
|
|
414 | 414 | Y = l(X)
|
415 | 415 | @test size(Y) == (out_channel, N, batch_size)
|
416 | 416 |
|
417 |
| - g = Zygote.gradient(() -> sum(l(X)), Flux.params(l)) |
| 417 | + g = gradient(() -> sum(l(X)), Flux.params(l)) |
418 | 418 | if l.layer.proj == identity
|
419 | 419 | @test length(g.grads) == 3
|
420 | 420 | else
|
|
0 commit comments