Commit f1e442e

Merge pull request #340 from FluxML/refactor
Fix tests and refactor
2 parents: a864b61 + bda4e84

File tree: 2 files changed (+25 −33 lines)

src/layers/graph_conv.jl

Lines changed: 14 additions & 20 deletions
@@ -128,7 +128,7 @@ function (l::ChebConv)(fg::AbstractFeaturedGraph)
     nf = node_feature(fg)
     GraphSignals.check_num_nodes(fg, nf)
     @assert size(nf, 1) == size(l.weight, 2) "Input feature size must match input channel size."
-
+
     L̃ = ChainRulesCore.ignore_derivatives() do
         GraphSignals.scaled_laplacian(fg, eltype(nf))
     end
@@ -245,7 +245,7 @@ Graph attentional layer.
 - `out`: The dimension of output features.
 - `bias::Bool`: Keyword argument, whether to learn the additive bias.
 - `σ`: Activation function.
-- `heads`: Number attention heads
+- `heads`: Number attention heads
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
@@ -280,7 +280,7 @@ end
 
 function GATConv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true,
                  negative_slope=0.2f0, init=glorot_uniform, bias::Bool=true)
-    in, out = ch
+    in, out = ch
     W = init(out*heads, in)
     b = Flux.create_bias(W, bias, out*heads)
     a = init(2*out, heads)
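
For context, the constructor above allocates a single projection W of size (out*heads, in) plus per-head attention parameters a. A minimal construction sketch (dimensions are hypothetical; assumes the package is GeometricFlux.jl):

    using GeometricFlux

    # 4 attention heads; with concat=true the layer emits out*heads = 32 features per node.
    gat = GATConv(16=>8, heads=4, concat=true)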
@@ -372,7 +372,7 @@ Graph attentional layer v2.
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
 - `σ`: Activation function.
-- `heads`: Number attention heads
+- `heads`: Number attention heads
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 
@@ -661,7 +661,7 @@ GINConv(nn, eps=0f0) = GINConv(nn, eps)
 
 Flux.trainable(g::GINConv) = (g.nn,)
 
-message(g::GINConv, x_i::AbstractArray, x_j::AbstractArray) = x_j
+message(g::GINConv, x_i::AbstractArray, x_j::AbstractArray) = x_j
 update(g::GINConv, m::AbstractArray, x::AbstractArray) = g.nn((1 + g.eps) * x + m)
 
 # For variable graph
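
Together these two methods implement the GIN update: message passes neighbor features through unchanged, and update applies nn((1 + eps) * x + m) to the aggregated sum. A minimal usage sketch mirroring the test setup below (graph and dimensions are hypothetical; assumes GeometricFlux.jl and GraphSignals.jl):

    using Flux, GeometricFlux, GraphSignals

    adj = [0 1 1; 1 0 1; 1 1 0]            # hypothetical 3-node graph
    nn = Flux.Chain(Dense(16, 8))          # learnable transform applied after aggregation
    gin = GINConv(nn, 0.001f0)             # eps weights the node's own feature

    fg = FeaturedGraph(adj, nf=rand(Float32, 16, 3))
    size(node_feature(gin(fg)))            # (8, 3): one 8-dim output feature per node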
@@ -705,30 +705,24 @@ CGConv(node dim=128, edge dim=32)
 
 See also [`WithGraph`](@ref) for training layer with static graph.
 """
-struct CGConv{A<:AbstractMatrix,B} <: MessagePassing
-    Wf::A
-    Ws::A
-    bf::B
-    bs::B
+struct CGConv{A,B} <: MessagePassing
+    f::A
+    s::B
 end
 
 @functor CGConv
 
-Flux.trainable(l::CGConv) = (l.Wf, l.Ws, l.bf, l.bs)
-
 function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true)
     node_dim, edge_dim = dims
-    Wf = init(node_dim, 2*node_dim + edge_dim)
-    Ws = init(node_dim, 2*node_dim + edge_dim)
-    bf = Flux.create_bias(Wf, bias, node_dim)
-    bs = Flux.create_bias(Ws, bias, node_dim)
-    return CGConv(Wf, Ws, bf, bs)
+    f = Dense(2*node_dim + edge_dim, node_dim; bias=bias, init=init)
+    s = Dense(2*node_dim + edge_dim, node_dim; bias=bias, init=init)
+    return CGConv(f, s)
 end
 
-function message(c::CGConv, x_i::AbstractArray, x_j::AbstractArray, e::AbstractArray)
+function message(l::CGConv, x_i::AbstractArray, x_j::AbstractArray, e::AbstractArray)
     z = vcat(x_i, x_j, e)
 
-    return σ.(_matmul(c.Wf, z) .+ c.bf) .* softplus.(_matmul(c.Ws, z) .+ c.bs)
+    return σ.(l.f(z)) .* softplus.(l.s(z))
 end
 
 update(c::CGConv, m::AbstractArray, x) = x + m
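
The refactor above folds the four explicit parameters Wf, Ws, bf, bs into two Dense layers, so @functor CGConv alone lets Flux collect all trainable parameters and the manual Flux.trainable override becomes unnecessary. A sketch of why the two formulations agree (shapes are hypothetical):

    using Flux

    node_dim, edge_dim = 4, 2
    z = rand(Float32, 2*node_dim + edge_dim, 5)   # vcat(x_i, x_j, e) for 5 edges

    # Before: explicit weight and bias, applied by hand.
    Wf = Flux.glorot_uniform(node_dim, 2*node_dim + edge_dim)
    bf = zeros(Float32, node_dim)
    gate_old = Wf * z .+ bf

    # After: the same affine map packaged as a Dense layer.
    f = Dense(2*node_dim + edge_dim, node_dim)
    gate_new = f(z)                               # computes f.weight * z .+ f.bias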
@@ -752,7 +746,7 @@ function (l::CGConv)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
 end
 
 function Base.show(io::IO, l::CGConv)
-    node_dim, d = size(l.Wf)
+    node_dim, d = size(l.f.weight)
     edge_dim = d - 2*node_dim
     print(io, "CGConv(node dim=", node_dim, ", edge dim=", edge_dim, ")")
 end
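
A minimal forward-pass sketch against the new field names (graph and dimensions are hypothetical; assumes GraphSignals exports ne for FeaturedGraph, used here to size the edge-feature matrix):

    using Flux, GeometricFlux, GraphSignals

    adj = [0 1 1; 1 0 1; 1 1 0]
    cgc = CGConv((16, 4))                  # node dim 16, edge dim 4
    E = ne(FeaturedGraph(adj))             # edge count GraphSignals derives from adj
    fg = FeaturedGraph(adj, nf=rand(Float32, 16, 3), ef=rand(Float32, 4, E))
    size(node_feature(cgc(fg)))            # (16, 3): update is x + m, so node dims are preserved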

test/layers/graph_conv.jl

Lines changed: 11 additions & 13 deletions
@@ -74,7 +74,7 @@
     @test size(cc.weight) == (out_channel, in_channel, k)
     @test size(cc.bias) == (out_channel,)
     @test cc.k == k
-
+
     fg = FeaturedGraph(adj, nf=X)
     fg_ = cc(fg)
     @test size(node_feature(fg_)) == (out_channel, N)
@@ -90,7 +90,7 @@
 end
 
 @testset "layer with static graph" begin
-    cc = WithGraph(fg, ChebConv(in_channel=>out_channel, k))
+    cc = WithGraph(fg, ChebConv(in_channel=>out_channel, k))
     Y = cc(X)
     @test size(Y) == (out_channel, N)
 
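For reference, WithGraph bakes a static graph into a layer so it can be applied directly to a feature matrix, as this test does. A sketch (dimensions are hypothetical; assumes GeometricFlux.jl):

    using Flux, GeometricFlux

    adj = [0 1 1; 1 0 1; 1 1 0]
    fg = FeaturedGraph(adj)
    cc = WithGraph(fg, ChebConv(16=>8, 3))   # Chebyshev polynomial order k = 3
    X = rand(Float32, 16, 3)
    Y = cc(X)                                # (8, 3); no per-call FeaturedGraph wrapping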
@@ -286,7 +286,7 @@
     g = gradient(() -> sum(node_feature(ec(fg))), Flux.params(ec))
     @test length(g.grads) == 4
 end
-
+
 @testset "layer with static graph" begin
     X = rand(T, in_channel, N, batch_size)
     ec = WithGraph(fg, EdgeConv(Dense(2*in_channel, out_channel)))
@@ -322,10 +322,8 @@
     nn = Flux.Chain(Dense(in_channel, out_channel))
     eps = 0.001
     @testset "layer without graph" begin
-        gc = GraphConv(in_channel=>out_channel)
-        @test size(gc.weight1) == (out_channel, in_channel)
-        @test size(gc.weight2) == (out_channel, in_channel)
-        @test size(gc.bias) == (out_channel,)
+        gc = GINConv(nn, eps)
+        @test gc.nn == nn
 
         X = rand(T, in_channel, N)
         fg = FeaturedGraph(adj, nf=X)
@@ -334,7 +332,7 @@
         @test_throws MethodError gc(X)
 
         g = gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc))
-        @test length(g.grads) == 5
+        @test length(g.grads) == 4
     end
 
     @testset "layer with static graph" begin
@@ -351,10 +349,10 @@
 @testset "CGConv" begin
     @testset "layer without graph" begin
         cgc = CGConv((in_channel, in_channel_edge))
-        @test size(cgc.Wf) == (in_channel, 2 * in_channel + in_channel_edge)
-        @test size(cgc.Ws) == (in_channel, 2 * in_channel + in_channel_edge)
-        @test size(cgc.bf) == (in_channel,)
-        @test size(cgc.bs) == (in_channel,)
+        @test size(cgc.f.weight) == (in_channel, 2 * in_channel + in_channel_edge)
+        @test size(cgc.s.weight) == (in_channel, 2 * in_channel + in_channel_edge)
+        @test size(cgc.f.bias) == (in_channel,)
+        @test size(cgc.s.bias) == (in_channel,)
 
         nf = rand(T, in_channel, N)
         ef = rand(T, in_channel_edge, E)
@@ -403,7 +401,7 @@
         end
     end
 end
-
+
 @testset "layer with static graph" begin
     for conv in aggregators
         X = rand(T, in_channel, N, batch_size)
