Skip to content

Commit 87c2dcf

Browse files
authored
Merge pull request #269 from FluxML/gae
Fix GAE example
2 parents ceb6eab + caebf43 commit 87c2dcf

File tree

5 files changed

+157
-70
lines changed

5 files changed

+157
-70
lines changed

examples/gae.jl

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,45 +1,112 @@
1-
using GeometricFlux
2-
using GraphSignals
1+
using CUDA
32
using Flux
4-
using Flux: throttle
3+
using Flux: onecold
54
using Flux.Losses: logitbinarycrossentropy
6-
using Flux: @epochs
7-
using JLD2
8-
using Statistics: mean
9-
using SparseArrays
10-
using Graphs.SimpleGraphs
11-
using CUDA
5+
using Flux.Data: DataLoader
6+
using GeometricFlux
7+
using GeometricFlux.Datasets
8+
using GraphSignals
9+
using Parameters: @with_kw
10+
using ProgressMeter: Progress, next!
11+
using Statistics
12+
using Random
13+
14+
function load_data(dataset, batch_size, train_repeats=128)
15+
# (train_X, train_y) dim: (num_features, target_dim) × 1708
16+
train_X, _ = map(x -> Matrix(x), alldata(Planetoid(), dataset))
17+
# (test_X, test_y) dim: (num_features, target_dim) × 1000
18+
test_X, _ = map(x -> Matrix(x), testdata(Planetoid(), dataset))
19+
g = graphdata(Planetoid(), dataset)
20+
21+
X = hcat(train_X, test_X)
22+
fg = FeaturedGraph(g)
23+
A = GraphSignals.adjacency_matrix(fg)
24+
data = (repeat(X, outer=(1,1,train_repeats)), repeat(A, outer=(1,1,train_repeats)))
25+
loader = DataLoader(data, batchsize=batch_size, shuffle=true)
26+
return loader, fg
27+
end
28+
29+
@with_kw mutable struct Args
30+
η = 0.01 # learning rate
31+
batch_size = 16 # batch size
32+
epochs = 200 # number of epochs
33+
seed = 0 # random seed
34+
cuda = true # use GPU
35+
input_dim = 1433 # input dimension
36+
hidden1_dim = 32 # hidden1 dimension
37+
hidden2_dim = 16 # hidden2 dimension
38+
end
39+
40+
## Loss: binary cross entropy
41+
model_loss(model, X, A) = logitbinarycrossentropy(model(X), A)
42+
43+
function precision(model, X::AbstractArray, A::AbstractArray)
44+
ŷ = onecold(softmax(cpu(model(X))))
45+
y = onecold(cpu(A))
46+
return mean(y[ŷ .== true])
47+
end
48+
49+
precision(model, loader::DataLoader, device) =
50+
mean(precision(model, X |> device, A |> device) for (X, A) in loader)
51+
52+
function train(; kws...)
53+
# load hyperparameters
54+
args = Args(; kws...)
55+
args.seed > 0 && Random.seed!(args.seed)
56+
57+
# GPU config
58+
if args.cuda && CUDA.has_cuda()
59+
device = gpu
60+
@info "Training on GPU"
61+
else
62+
device = cpu
63+
@info "Training on CPU"
64+
end
65+
66+
# load Cora from Planetoid dataset
67+
loader, fg = load_data(:cora, args.batch_size)
68+
69+
# build model
70+
encoder = Chain(
71+
WithGraph(fg, GCNConv(args.input_dim=>args.hidden1_dim, relu)),
72+
Dropout(0.5),
73+
WithGraph(fg, GCNConv(args.hidden1_dim=>args.hidden2_dim)),
74+
)
1275

13-
CUDA.allowscalar(false)
76+
model = GAE(encoder, σ) |> device
1477

15-
@load "data/cora_features.jld2" features
16-
@load "data/cora_graph.jld2" g
78+
# ADAM optimizer
79+
opt = ADAM(args.η)
80+
81+
# parameters
82+
ps = Flux.params(model)
1783

18-
num_nodes = 2708
19-
num_features = 1433
20-
hidden1 = 32
21-
hidden2 = 16
22-
target_catg = 7
23-
epochs = 200
84+
# training
85+
train_steps = 0
86+
@info "Start Training, total $(args.epochs) epochs"
87+
for epoch = 1:args.epochs
88+
@info "Epoch $(epoch)"
89+
progress = Progress(length(loader))
2490

25-
## Preprocessing data
26-
fg = FeaturedGraph(g) # pass to gpu together in model layers
27-
train_X = Matrix{Float32}(features) |> gpu # dim: num_features * num_nodes
28-
train_y = fg |> GraphSignals.adjacency_matrix |> gpu # dim: num_nodes * num_nodes
91+
for (X, A) in loader
92+
loss, back = Flux.pullback(ps) do
93+
model_loss(model, X |> device, A |> device)
94+
end
95+
prec = precision(model, loader, device)
96+
grad = back(1f0)
97+
Flux.Optimise.update!(opt, ps, grad)
2998

30-
## Model
31-
encoder = Chain(GCNConv(fg, num_features=>hidden1, relu),
32-
GCNConv(fg, hidden1=>hidden2))
33-
model = Chain(GAE(encoder, σ)) |> gpu;
34-
# do not show model architecture, showing CuSparseMatrix will trigger errors
99+
# progress meter
100+
next!(progress; showvalues=[
101+
(:loss, loss),
102+
(:precision, prec),
103+
])
35104

36-
## Loss
37-
loss(x, y) = logitbinarycrossentropy(model(x), y)
105+
train_steps += 1
106+
end
107+
end
38108

39-
## Training
40-
ps = Flux.params(model)
41-
train_data = [(train_X, train_y)]
42-
opt = ADAM(0.01)
43-
evalcb() = @show(loss(train_X, train_y))
109+
return model, args
110+
end
44111

45-
@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
112+
model, args = train()

src/GeometricFlux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ export
5050
GAE,
5151
VGAE,
5252
InnerProductDecoder,
53-
VariationalEncoder,
53+
VariationalGraphEncoder,
5454

5555
# layer/utils
5656
WithGraph,

src/models.jl

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""
2-
GAE(enc[, σ])
2+
GAE(enc, [σ=identity])
33
44
Graph autoencoder.
55
66
# Arguments
77
- `enc`: encoder. It can be any graph convolutional layer.
8+
- `σ`: Activation function for decoder.
89
910
The encoder is specified by the user; the decoder is an `InnerProductDecoder` layer.
1011
"""
11-
struct GAE{T,S}
12+
struct GAE{T,S} <: AbstractGraphLayer
1213
encoder::T
1314
decoder::S
1415
end
@@ -17,11 +18,14 @@ GAE(enc, σ::Function=identity) = GAE(enc, InnerProductDecoder(σ))
1718

1819
@functor GAE
1920

20-
function (g::GAE)(X::AbstractMatrix)
21-
Z = g.encoder(X)
22-
A = g.decoder(Z)
23-
A
24-
end
21+
# For variable graph
22+
(l::GAE)(fg::AbstractFeaturedGraph) = fg |> l.encoder |> l.decoder
23+
24+
# For static graph
25+
WithGraph(fg::AbstractFeaturedGraph, l::GAE) = GAE(WithGraph(fg, l.encoder), l.decoder)
26+
27+
(l::GAE)(X::AbstractMatrix) = X |> l.encoder |> l.decoder
28+
(l::GAE)(X::AbstractArray) = X |> l.encoder |> l.decoder
2529

2630

2731
"""
@@ -34,22 +38,24 @@ Variational graph autoencoder.
3438
3539
The encoder is specified by the user; the decoder is an `InnerProductDecoder` layer.
3640
"""
37-
struct VGAE{T,S}
41+
struct VGAE{T,S} <: AbstractGraphLayer
3842
encoder::T
3943
decoder::S
4044
end
4145

4246
function VGAE(enc, h_dim::Integer, z_dim::Integer, σ::Function=identity)
43-
VGAE(VariationalEncoder(enc, h_dim, z_dim), InnerProductDecoder(σ))
47+
VGAE(VariationalGraphEncoder(enc, h_dim, z_dim), InnerProductDecoder(σ))
4448
end
4549

4650
@functor VGAE
4751

48-
function (g::VGAE)(fg::FeaturedGraph)
49-
fg_ = g.encoder(fg)
50-
fg_ = g.decoder(fg_)
51-
fg_
52-
end
52+
# For variable graph
53+
(l::VGAE)(fg::AbstractFeaturedGraph) = fg |> l.encoder |> l.decoder
54+
55+
# For static graph
56+
WithGraph(fg::AbstractFeaturedGraph, l::VGAE) = VGAE(WithGraph(fg, l.encoder), l.decoder)
57+
58+
(l::VGAE)(X::AbstractArray) = X |> l.encoder |> l.decoder
5359

5460

5561
"""
@@ -60,13 +66,15 @@ Inner-product decoder layer.
6066
# Arguments
6167
- `σ`: activation function.
6268
"""
63-
struct InnerProductDecoder
64-
σ
69+
struct InnerProductDecoder{F}
70+
σ::F
6571
end
6672

6773
@functor InnerProductDecoder
6874

6975
(i::InnerProductDecoder)(Z::AbstractMatrix)::AbstractMatrix = i.σ.(Z'*Z)
76+
(i::InnerProductDecoder)(Z::AbstractArray)::AbstractArray =
77+
i.σ.(NNlib.batched_mul(NNlib.batched_transpose(Z), Z))
7078

7179
function (i::InnerProductDecoder)(fg::FeaturedGraph)
7280
Z = node_feature(fg)
@@ -76,9 +84,9 @@ end
7684

7785

7886
"""
79-
VariationalEncoder(nn, h_dim, z_dim)
87+
VariationalGraphEncoder(nn, h_dim, z_dim)
8088
81-
Variational encoder layer.
89+
Variational graph encoder layer.
8290
8391
# Arguments
8492
- `nn`: neural network. It can be any graph convolutional layer.
@@ -87,33 +95,44 @@ Variational encoder layer.
8795
8896
Encoder can be any graph convolutional layer.
8997
"""
90-
struct VariationalEncoder
91-
nn
92-
μ
93-
logσ
94-
z_dim::Integer
98+
struct VariationalGraphEncoder{L,M,S,T<:Integer} <: AbstractGraphLayer
99+
nn::L
100+
μ::M
101+
logσ::S
102+
z_dim::T
95103
end
96104

97-
function VariationalEncoder(nn, h_dim::Integer, z_dim::Integer)
98-
VariationalEncoder(nn,
105+
function VariationalGraphEncoder(nn, h_dim::Integer, z_dim::Integer)
106+
VariationalGraphEncoder(nn,
99107
GCNConv(h_dim=>z_dim),
100108
GCNConv(h_dim=>z_dim),
101109
z_dim)
102110
end
103111

104-
@functor VariationalEncoder
112+
@functor VariationalGraphEncoder
105113

106-
function (ve::VariationalEncoder)(fg::FeaturedGraph)::FeaturedGraph
114+
function (ve::VariationalGraphEncoder)(fg::FeaturedGraph)::FeaturedGraph
107115
μ, logσ = summarize(ve, fg)
108116
Z = sample(μ, logσ)
109117
FeaturedGraph(fg, nf=Z)
110118
end
111119

112-
function summarize(ve::VariationalEncoder, fg::FeaturedGraph)
120+
function summarize(ve::VariationalGraphEncoder, fg::FeaturedGraph)
113121
fg_ = ve.nn(fg)
114122
fg_μ, fg_logσ = ve.μ(fg_), ve.logσ(fg_)
115123
node_feature(fg_μ), node_feature(fg_logσ)
116124
end
117125

118126
sample(μ::AbstractArray{T}, logσ::AbstractArray{T}) where {T<:Real} =
119127
μ + exp.(logσ) .* randn(T, size(logσ))
128+
129+
# For static graph
130+
WithGraph(fg::AbstractFeaturedGraph, l::VariationalGraphEncoder) =
131+
VariationalGraphEncoder(
132+
WithGraph(fg, l.nn),
133+
WithGraph(fg, l.μ),
134+
WithGraph(fg, l.logσ),
135+
l.z_dim
136+
)
137+
138+
# (l::VariationalGraphEncoder)(X::AbstractArray) = X |> l.encoder |> l.decoder

test/models.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
in_channel = 3
33
out_channel = 5
44
N = 4
5+
batch = 4
56
T = Float32
67
adj = T[0. 1. 0. 1.;
78
1. 0. 1. 0.;
@@ -11,7 +12,7 @@
1112
fg = FeaturedGraph(adj)
1213

1314
@testset "GAE" begin
14-
gc = GCNConv(fg, in_channel=>out_channel)
15+
gc = WithGraph(fg, GCNConv(in_channel=>out_channel))
1516
gae = GAE(gc)
1617
X = rand(T, in_channel, N)
1718
Y = gae(X)
@@ -21,9 +22,9 @@
2122
@testset "VGAE" begin
2223
@testset "InnerProductDecoder" begin
2324
ipd = InnerProductDecoder(identity)
24-
X = rand(T, 1, N)
25+
X = rand(T, 1, N, batch)
2526
Y = ipd(X)
26-
@test size(Y) == (N, N)
27+
@test size(Y) == (N, N, batch)
2728

2829
X = rand(T, 1, N)
2930
fg = FeaturedGraph(adj, nf=X)
@@ -38,10 +39,10 @@
3839
@test size(Y) == (N, N)
3940
end
4041

41-
@testset "VariationalEncoder" begin
42+
@testset "VariationalGraphEncoder" begin
4243
z_dim = 2
4344
gc = GCNConv(in_channel=>out_channel)
44-
ve = VariationalEncoder(gc, out_channel, z_dim)
45+
ve = VariationalGraphEncoder(gc, out_channel, z_dim)
4546
X = rand(T, in_channel, N)
4647
fg = FeaturedGraph(adj, nf=X)
4748
fg_ = ve(fg)

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ tests = [
2727
"layers/utils",
2828
"sampling",
2929
"embedding/node2vec",
30-
# "models",
30+
"models",
3131
]
3232

3333
if CUDA.functional()

0 commit comments

Comments
 (0)