
Commit 42ccbf2

Merge pull request #343 from FluxML/gatedgraphconv
Fix GatedGraphConv
2 parents f1e442e + dc71ab7

File tree

3 files changed: 23 additions & 17 deletions

src/layers/graph_conv.jl
test/cuda/graph_conv.jl
test/layers/graph_conv.jl

src/layers/graph_conv.jl

Lines changed: 10 additions & 4 deletions
@@ -556,17 +556,23 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real}
         H = vcat(H, Hpad)
     end
     for i = 1:l.num_layers
-        M = _matmul(view(l.weight, :, :, i), H)
+        M = _matmul(l.weight[:, :, i], H)
         _, M = propagate(l, el, nothing, M, nothing, l.aggr, nothing, nothing)
-        H, _ = l.gru(H, M)
+        H = apply_gru(l.gru, H, M)
     end
     return H
 end
 
+function apply_gru(gru, H::AbstractArray, M::AbstractArray)
+    H′ = apply_gru(gru, reshape(H, size(H, 1), :), reshape(M, size(M, 1), :))
+    return reshape(H′, size(H′, 1), size(H)[2:end]...)
+end
+
+apply_gru(gru, H::AbstractMatrix, M::AbstractMatrix) = gru(H, M)[1]
+
 function Base.show(io::IO, l::GatedGraphConv)
     print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)")
-    print(io, ", aggr=", l.aggr)
-    print(io, ")")
+    print(io, ", aggr=", l.aggr, ")")
 end
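The fix has two parts. Taking the weight slice with plain indexing (l.weight[:, :, i]) instead of view materializes a contiguous matrix, presumably to play better with GPU matmul and differentiation. More importantly, the GRU cell is now applied through apply_gru, which flattens a 3-D (features, nodes, batch) array into a matrix, runs the cell, and restores the original shape; this is what makes the batched static-graph tests below pass. A minimal standalone sketch of that reshape round-trip, where toy_gru is a stand-in for the layer's actual GRU cell (assumed, as in the code above, to take two matrices and return a tuple whose first element is the new state):

# Stand-in for l.gru: matrix-only, returns (new_state, output).
# NOT the package's actual GRU cell, just a shape-compatible toy.
toy_gru(H, M) = (tanh.(H .+ M), nothing)

# Same dispatch pattern as apply_gru in the diff: matrices hit the base
# case; higher-dimensional arrays are flattened, processed, reshaped back.
gru_nd(gru, H::AbstractMatrix, M::AbstractMatrix) = gru(H, M)[1]
function gru_nd(gru, H::AbstractArray, M::AbstractArray)
    H′ = gru_nd(gru, reshape(H, size(H, 1), :), reshape(M, size(M, 1), :))
    return reshape(H′, size(H′, 1), size(H)[2:end]...)
end

H = rand(Float32, 5, 4, 10)    # (out_channel, N, batch_size)
M = rand(Float32, 5, 4, 10)    # aggregated messages, same shape
@assert size(gru_nd(toy_gru, H, M)) == (5, 4, 10)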

test/cuda/graph_conv.jl

Lines changed: 9 additions & 9 deletions
@@ -3,7 +3,7 @@
 in_channel = 3
 out_channel = 5
 batch_size = 10
-
+
 N = 4
 adj = T[0 1 0 1;
         1 0 1 0;
@@ -48,7 +48,7 @@
     @test size(cc.weight) == (out_channel, in_channel, k)
     @test size(cc.bias) == (out_channel,)
     @test cc.k == k
-
+
     fg = FeaturedGraph(adj, nf=X) |> gpu
     fg_ = cc(fg)
     @test size(node_feature(fg_)) == (out_channel, N)
@@ -106,12 +106,12 @@
     @test size(gat.weight) == (out_channel * heads, in_channel)
     @test size(gat.bias) == (out_channel * heads,)
     @test size(gat.a) == (2*out_channel, heads)
-
+
     X = rand(T, in_channel, N)
     fg = FeaturedGraph(adj, nf=X) |> gpu
     fg_ = gat(fg)
     @test size(node_feature(fg_)) == (out_channel * heads, N)
-
+
     g = gradient(() -> sum(node_feature(gat(fg))), Flux.params(gat))
     @test length(g.grads) == 5
 end
@@ -121,7 +121,7 @@
     gat = WithGraph(fg, GATConv(in_channel=>out_channel, heads=2)) |> gpu
     Y = gat(X |> gpu)
     @test size(Y) == (out_channel * heads, N, batch_size)
-
+
     g = gradient(() -> sum(gat(X |> gpu)), Flux.params(gat))
     @test length(g.grads) == 4
 end
@@ -145,11 +145,11 @@
     @testset "layer with static graph" begin
         X = rand(T, in_channel, N, batch_size)
         ggc = WithGraph(fg, GatedGraphConv(out_channel, num_layers)) |> gpu
-        @test_broken Y = ggc(X |> gpu)
-        @test_broken size(Y) == (out_channel, N, batch_size)
+        Y = ggc(X |> gpu)
+        @test size(Y) == (out_channel, N, batch_size)
 
-        @test_broken g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
-        @test_broken length(g.grads) == 6
+        g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
+        @test length(g.grads) == 7
     end
 end
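These GPU tests flip from @test_broken to @test: in Julia's Test stdlib, @test_broken records a Broken result while the condition fails, but reports an unexpected pass and fails the suite once the condition becomes true, so fixed behavior must be promoted to plain @test. Note also that the updated GPU test expects 7 parameter gradients versus 6 in the CPU test below. A small illustration of the @test_broken semantics:

using Test

@testset "broken vs fixed" begin
    @test_broken 1 == 2   # failing condition: recorded as Broken, suite still passes
    @test 1 == 1          # ordinary passing test
end
# If the first condition became true, @test_broken would report an
# "unexpected pass" and fail the testset, hence the promotion to @test
# in this commit once GatedGraphConv was fixed.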

test/layers/graph_conv.jl

Lines changed: 4 additions & 4 deletions
@@ -265,11 +265,11 @@
     @testset "layer with static graph" begin
         X = rand(T, in_channel, N, batch_size)
         ggc = WithGraph(fg, GatedGraphConv(out_channel, num_layers))
-        @test_broken Y = ggc(X)
-        @test_broken size(Y) == (out_channel, N, batch_size)
+        Y = ggc(X)
+        @test size(Y) == (out_channel, N, batch_size)
 
-        @test_broken g = gradient(() -> sum(ggc(X)), Flux.params(ggc))
-        @test_broken length(g.grads) == 6
+        g = gradient(() -> sum(ggc(X)), Flux.params(ggc))
+        @test length(g.grads) == 6
     end
 end
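With both test files passing, the batched path can be used directly. A hedged end-to-end sketch mirroring the test setup above (assumes GeometricFlux's FeaturedGraph/WithGraph API as exercised in these tests; num_layers = 2 is an illustrative value since the suite defines it elsewhere):

using GeometricFlux, Flux

in_channel, out_channel = 3, 5
N, batch_size = 4, 10
num_layers = 2                 # illustrative; defined elsewhere in the test suite

adj = [0 1 0 1;                # the same 4-node graph as in the tests
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]
fg = FeaturedGraph(adj)

ggc = WithGraph(fg, GatedGraphConv(out_channel, num_layers))
X = rand(Float32, in_channel, N, batch_size)
Y = ggc(X)                     # 3-D batched input, previously @test_broken
@assert size(Y) == (out_channel, N, batch_size)

# Gradients w.r.t. all trainable parameters, as in the updated tests:
g = gradient(() -> sum(ggc(X)), Flux.params(ggc))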
