Skip to content

Commit 12f3908

Browse files
authored
Merge pull request #250 from FluxML/fix
Resolve gradient bug for GatedGraphConv
2 parents c8396e4 + fbc699d commit 12f3908

File tree

3 files changed

+15
-15
lines changed

3 files changed

+15
-15
lines changed

src/layers/conv.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,13 +324,15 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
324324
@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
325325
adj = adjacency_list(fg)
326326
if m < ggc.out_ch
327-
Hpad = similar(H, S, ggc.out_ch - m, n)
328-
H = vcat(H, fill!(Hpad, 0))
327+
Hpad = Zygote.ignore() do
328+
fill!(similar(H, S, ggc.out_ch - m, n), 0)
329+
end
330+
H = vcat(H, Hpad)
329331
end
330332
for i = 1:ggc.num_layers
331333
M = view(ggc.weight, :, :, i) * H
332334
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +)
333-
H, _ = ggc.gru(H, M) # BUG: FluxML/Flux.jl#1381
335+
H, _ = ggc.gru(H, M)
334336
end
335337
H
336338
end

test/layers/conv.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
6060
@test size(node_feature(fgt_)) == (out_channel, N)
6161

6262
g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1]
63-
@test size(g[].nf) == size(X)
63+
@test size(g.nf) == size(X)
6464

6565
g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1]
6666
@test size(g.weight) == size(gc.weight)
@@ -118,7 +118,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
118118
@test size(node_feature(fgt_)) == (out_channel, N)
119119

120120
g = Zygote.gradient(x -> sum(node_feature(cc(x))), fg)[1]
121-
@test size(g[].nf) == size(X)
121+
@test size(g.nf) == size(X)
122122

123123
g = Zygote.gradient(model -> sum(node_feature(model(fg))), cc)[1]
124124
@test size(g.weight) == size(cc.weight)
@@ -174,7 +174,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
174174
@test size(node_feature(fgt_)) == (out_channel, N)
175175

176176
g = Zygote.gradient(x -> sum(node_feature(gc(x))), fg)[1]
177-
@test size(g[].nf) == size(X)
177+
@test size(g.nf) == size(X)
178178

179179
g = Zygote.gradient(model -> sum(node_feature(model(fg))), gc)[1]
180180
@test size(g.weight1) == size(gc.weight1)
@@ -245,7 +245,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
245245
@test size(node_feature(fgt_)) == (concat ? (out_channel*heads, N) : (out_channel, N))
246246

247247
g = Zygote.gradient(x -> sum(node_feature(gat(x))), fg_gat)[1]
248-
@test size(g[].nf) == size(X)
248+
@test size(g.nf) == size(X)
249249

250250
g = Zygote.gradient(model -> sum(node_feature(model(fg_gat))), gat)[1]
251251
@test size(g.weight) == size(gat.weight)
@@ -299,7 +299,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
299299
@test size(node_feature(fgt_)) == (out_channel, N)
300300

301301
g = Zygote.gradient(x -> sum(node_feature(ggc(x))), fg)[1]
302-
@test size(g[].nf) == size(X)
302+
@test size(g.nf) == size(X)
303303

304304
g = Zygote.gradient(model -> sum(node_feature(model(fg))), ggc)[1]
305305
@test size(g.weight) == size(ggc.weight)
@@ -342,7 +342,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
342342
@test size(node_feature(fgt_)) == (out_channel, N)
343343

344344
g = Zygote.gradient(x -> sum(node_feature(ec(x))), fg)[1]
345-
@test size(g[].nf) == size(X)
345+
@test size(g.nf) == size(X)
346346

347347
g = Zygote.gradient(model -> sum(node_feature(model(fg))), ec)[1]
348348
@test size(g.nn.weight) == size(ec.nn.weight)
@@ -371,7 +371,7 @@ fg_single_vertex = FeaturedGraph(adj_single_vertex)
371371

372372
g = Zygote.gradient(x -> sum(node_feature(gc(x))),
373373
FeaturedGraph(adj, nf=X))[1]
374-
@test size(g.x.nf) == size(X)
374+
@test size(g.nf) == size(X)
375375

376376
g = Zygote.gradient(model -> sum(node_feature(model(FeaturedGraph(adj, nf=X)))),
377377
gc)[1]

test/runtests.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
using GeometricFlux
22
using GeometricFlux.Datasets
3+
using CUDA
34
using Flux
45
using Flux: @functor
56
using FillArrays
67
using GraphSignals
78
using LightGraphs: SimpleGraph, SimpleDiGraph, add_edge!, nv, ne
89
using LinearAlgebra
9-
using NNlib
10+
using NNlib, NNlibCUDA
1011
using SparseArrays: SparseMatrixCSC
1112
using Statistics: mean
1213
using Zygote
@@ -26,10 +27,7 @@ tests = [
2627
"models",
2728
]
2829

29-
if Flux.use_cuda[]
30-
using CUDA
31-
using Flux: gpu
32-
using NNlibCUDA
30+
if CUDA.functional()
3331
append!(tests, cuda_tests)
3432
else
3533
@warn "CUDA unavailable, not testing GPU support"

0 commit comments

Comments
 (0)