
Commit 2763c22

Merge pull request #251 from FluxML/examples
Update GCN model
2 parents 84f61a9 + 3faac7b

File tree

2 files changed: +25 -9 lines changed

examples/gcn.jl

Lines changed: 9 additions & 7 deletions
@@ -1,4 +1,5 @@
 using GeometricFlux
+using GraphSignals
 using Flux
 using Flux: onehotbatch, onecold, logitcrossentropy, throttle
 using Flux: @epochs
@@ -7,9 +8,6 @@ using Statistics
 using SparseArrays
 using Graphs.SimpleGraphs
 using CUDA
-using Random
-
-Random.seed!([0x6044b4da, 0xd873e4f9, 0x59d90c0a, 0xde01aa81])
 
 @load "data/cora_features.jld2" features
 @load "data/cora_labels.jld2" labels
@@ -19,21 +17,25 @@ num_nodes = 2708
 num_features = 1433
 hidden = 16
 target_catg = 7
-epochs = 100
+epochs = 200
+λ = 5e-4
 
 ## Preprocessing data
 train_X = Matrix{Float32}(features) |> gpu  # dim: num_features * num_nodes
 train_y = Matrix{Float32}(labels) |> gpu  # dim: target_catg * num_nodes
-fg = FeaturedGraph(g) |> gpu
+fg = FeaturedGraph(g)  # passed to gpu together with the model layers
 
 ## Model
 model = Chain(GCNConv(fg, num_features=>hidden, relu),
              Dropout(0.5),
              GCNConv(fg, hidden=>target_catg),
-             ) |> gpu
+             ) |> gpu;
+# do not show the model architecture; showing a CuSparseMatrix would trigger errors
 
 ## Loss
-loss(x, y) = logitcrossentropy(model(x), y)
+l2norm(x) = sum(abs2, x)
+# cross entropy with L2 regularization on the first layer
+loss(x, y) = logitcrossentropy(model(x), y) + λ*sum(l2norm, Flux.params(model[1]))
 accuracy(x, y) = mean(onecold(softmax(cpu(model(x)))) .== onecold(cpu(y)))
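The training loop itself sits outside the hunks shown above. For context, a minimal sketch of how the updated example can be driven, assuming full-batch training on Cora; the optimizer choice, learning rate, and the names `ps`, `train_data`, `opt`, and `evalcb` are illustrative assumptions, while `epochs`, `loss`, `accuracy`, `@epochs`, and `throttle` come from the example itself:

```julia
## Training (sketch, not part of this diff)
ps = Flux.params(model)            # weights and biases of both GCNConv layers
train_data = [(train_X, train_y)]  # full-batch: a single (features, labels) pair
opt = ADAM(0.01)                   # assumed optimizer and learning rate
evalcb() = @show(accuracy(train_X, train_y))

@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
```

Note that the new loss penalizes only `model[1]`, the first `GCNConv`; the dropout layer and the output layer are left unregularized, and λ = 5e-4 scales the squared-norm penalty on the first layer's weight and bias.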

src/layers/conv.jl

Lines changed: 16 additions & 2 deletions
@@ -36,6 +36,8 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
 
 @functor GCNConv
 
+Flux.trainable(l::GCNConv) = (l.weight, l.bias)
+
 function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
     Ã = Zygote.ignore() do
         GraphSignals.normalized_adjacency_matrix(fg, eltype(x); selfloop=true)
@@ -87,6 +89,8 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
 
 @functor ChebConv
 
+Flux.trainable(l::ChebConv) = (l.weight, l.bias)
+
 function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
     GraphSignals.check_num_nodes(fg, X)
     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
@@ -155,6 +159,8 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) =
 
 @functor GraphConv
 
+Flux.trainable(l::GraphConv) = (l.weight1, l.weight2, l.bias)
+
 message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
 
 update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)
@@ -224,6 +230,8 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...)
 
 @functor GATConv
 
+Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
+
 # Here the α that has not been softmaxed is the first number of the output message
 function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector)
     x_i = reshape(gat.weight*x_i, :, gat.heads)
@@ -319,6 +327,8 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) =
 
 @functor GatedGraphConv
 
+Flux.trainable(l::GatedGraphConv) = (l.weight, l.gru)
+
 message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
 
 update(ggc::GatedGraphConv, m::AbstractVector, x) = m
@@ -376,6 +386,8 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...)
 
 @functor EdgeConv
 
+Flux.trainable(l::EdgeConv) = (l.nn,)
+
 message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
 update(ec::EdgeConv, m::AbstractVector, x) = m
 
@@ -423,13 +435,13 @@ function GINConv(nn, eps::Real=0f0)
     GINConv(NullGraph(), nn, eps)
 end
 
+@functor GINConv
+
 Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
 
 message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
 update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
 
-@functor GINConv
-
 function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
     gf = graph(fg)
     GraphSignals.check_num_nodes(gf, X)
@@ -474,6 +486,8 @@ end
 
 @functor CGConv
 
+Flux.trainable(l::CGConv) = (l.Wf, l.Ws, l.bf, l.bs)
+
 function CGConv(fg::G, dims::NTuple{2,Int};
                 init=glorot_uniform, bias=true, as_edge=false) where {G<:AbstractFeaturedGraph}
     node_dim, edge_dim = dims
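The recurring change in this file is a `Flux.trainable` method for each layer type. A minimal sketch of the pattern, using a hypothetical layer (`MyConv` is not from this diff): `@functor` exposes all fields to Flux so that `gpu`/`cpu` can move them, while `Flux.trainable` restricts which fields `Flux.params` collects for optimization, keeping the cached graph out of gradient updates.

```julia
using Flux

# Hypothetical layer illustrating the @functor + trainable split (sketch only)
struct MyConv{W,B,G}
    weight::W
    bias::B
    fg::G   # cached FeaturedGraph: moved by gpu/cpu, but never optimized
end

Flux.@functor MyConv                            # all fields follow gpu/cpu
Flux.trainable(l::MyConv) = (l.weight, l.bias)  # only these get gradients
```

This split is what lets the updated example build `FeaturedGraph(g)` on the CPU and transfer it via `model |> gpu`: the graph travels with the layer, but the optimizer only ever touches `weight` and `bias`.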
