
Commit 85bbc5d

gcn with fixed graph example
1 parent b783ad7 commit 85bbc5d

File tree

4 files changed: +118 -20 lines

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ makedocs(
         "Tutorials" =>
             [
                 "Semi-supervised learning with GCN" => "tutorials/semisupervised_gcn.md",
+                "GCN with Fixed Graph" => "tutorials/gcn_fixed_graph.md",
             ],
         "Abstractions" =>
             ["Message passing scheme" => "abstractions/msgpass.md",

docs/src/tutorials/gcn_fixed_graph.md

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
# GCN with Fixed Graph

In the tutorial on semi-supervised learning with GCN, variable graphs are fed to the GNN as `FeaturedGraph` objects, each containing a graph together with its node features. Every `FeaturedGraph` can hold a different graph and different node features, and all of them can be trained against the same GNN model. However, a variable graph does not arrive in the form each GNN layer actually needs, so it must be converted on every call, which makes training and inference inefficient. When the graph structure is shared across samples, GeometricFlux's fixed-graph strategy avoids this overhead.

## Fixed Graph

A fixed graph is attached to a layer with the `WithGraph` wrapper, which takes a `FeaturedGraph` object and a GNN layer as its first and second arguments, respectively.

```julia
fg = FeaturedGraph(graph)
WithGraph(fg, GCNConv(1024=>256, relu))
```

This lets us bind a different graph to each layer, and the layer specializes the graph into the form it requires. For example, a `GCNConv` layer needs the graph as a normalized adjacency matrix. Once a graph is bound to a `GCNConv` layer, it is converted into a normalized adjacency matrix and cached in the `WithGraph` object, which speeds up training and inference by avoiding repeated conversion. Note that the node features stored in the `FeaturedGraph` inside a `WithGraph` are not used during training or inference.
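
The normalization itself is standard GCN preprocessing. A minimal sketch of the transformation (not GeometricFlux's internal code) looks like this:

```julia
using LinearAlgebra

# Â = D^(-1/2) (A + I) D^(-1/2): the normalized adjacency matrix a
# GCNConv layer needs. WithGraph computes an equivalent form once at
# binding time instead of on every forward pass.
function normalized_adjacency(A::AbstractMatrix)
    Ã = A + I                        # add self-loops
    d = vec(sum(Ã, dims=2))          # node degrees
    Dinv = Diagonal(inv.(sqrt.(d)))  # D^(-1/2)
    return Dinv * Ã * Dinv
end
```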

## Array in, Array out

With this approach, a GNN layer works on plain feature arrays: it takes an array as input and returns an array. A GNN layer wrapped in `WithGraph` can therefore be fed a feature array directly, just like a regular deep learning layer.
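
For instance, calling a wrapped layer looks like calling any Flux layer. This is a sketch with illustrative dimensions; 2708 is the number of nodes in the Cora graph.

```julia
layer = WithGraph(fg, GCNConv(1024=>256, relu))
X = rand(Float32, 1024, 2708)  # feature array: (feature_dim, num_nodes)
H = layer(X)                   # output array of size (256, 2708)
```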

## Batch Learning

Since the features come as arrays, they can be stacked along a third dimension for batch learning. The following steps demonstrate how.

## Step 1: Load Dataset

Unlike the semi-supervised learning example, here we use `alldata` for supervised learning and pass `padding=true` to pad the features of the labeled nodes out to the full node set. A padded feature array contains zeros at the nodes that are not meant to be trained on.

```julia
train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
```
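
For intuition, `padding=true` corresponds to the zero-filling the example script used to perform by hand: allocate a zero matrix over all nodes and copy the labeled features into their columns. In this sketch, `train_X_raw` and `num_all_nodes` are hypothetical names for the unpadded features and the total node count.

```julia
num_labeled = size(train_X_raw, 2)                         # labeled nodes only
tr_X = zeros(Float32, size(train_X_raw, 1), num_all_nodes) # zeros over all nodes
tr_X[:, 1:num_labeled] .= train_X_raw                      # other columns stay zero
```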

We also need the graph and the node indices used for training.

```julia
g = graphdata(Planetoid(), dataset)
train_idx = 1:size(train_X, 2)
```

## Step 2: Batch up Features and Labels

To enable batch learning, we keep the graph separate from the node features; no subgraphing is performed here. Since the Planetoid dataset has no built-in batch setting, we batch the node features for demonstration by repeating them along a third dimension. The repeat counts can be set through `train_repeats` and `test_repeats`.

```julia
fg = FeaturedGraph(g)
train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
```
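
As in the accompanying example script, the batched arrays are then wrapped in a Flux `DataLoader` for shuffled mini-batches:

```julia
train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
```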

## Step 3: Build a GCN Model

Now we build the GCN model. We construct it just as we would a regular Flux model, except that each `GCNConv` layer is wrapped in `WithGraph`.

```julia
model = Chain(
    WithGraph(fg, GCNConv(args.input_dim=>args.hidden_dim, relu)),
    Dropout(0.5),
    WithGraph(fg, GCNConv(args.hidden_dim=>args.target_dim)),
)
```
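
As a quick sanity check (a sketch; the shapes are illustrative), the model maps a batched feature array to a batched prediction array:

```julia
X = rand(Float32, args.input_dim, 2708, args.batch_size)
ŷ = model(X)  # size: (args.target_dim, 2708, args.batch_size)
```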

## Step 4: Loss Functions and Accuracy

Almost all of the code is the same as in the semi-supervised learning example, except that node indices are needed to slice out the relevant features when computing the loss.

```julia
l2norm(x) = sum(abs2, x)

function model_loss(model, λ, X, y, idx)
    loss = logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
    loss += λ*sum(l2norm, Flux.params(model[1]))
    return loss
end
```

The accuracy measurement also needs the indices.

```julia
function accuracy(model, X::AbstractArray, y::AbstractArray, idx)
    return mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))
end

accuracy(model, loader::DataLoader, device, idx) = mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)
```
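
A usage sketch, assuming the loaders and `device` from the example script are in scope:

```julia
train_acc = accuracy(model, train_loader, device, train_idx)
test_acc = accuracy(model, test_loader, device, test_idx)
```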

## Step 5: Training the GCN Model

```julia
train_loader, test_loader, fg, train_idx, test_idx = load_data(:cora, args.batch_size)

# optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training loop
train_steps = 0
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for (X, y) in train_loader
        grad = gradient(() -> model_loss(model, args.λ, X |> device, y |> device, train_idx |> device), ps)
        Flux.Optimise.update!(opt, ps, grad)
        train_steps += 1
    end
end
```
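
The loop assumes a `device` function from the example script's GPU setup; a minimal sketch of that setup:

```julia
using CUDA

# Move data and model to GPU when requested and available, else stay on CPU.
device = args.cuda && CUDA.functional() ? gpu : cpu
model = model |> device
```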

Now we can train the GCN model directly!

examples/gcn_with_fixed_graph.jl

Lines changed: 8 additions & 19 deletions
@@ -12,28 +12,18 @@ using ProgressMeter: Progress, next!
 using Statistics
 using Random

-function load_data(dataset, batch_size)
-    # (train_X, train_y) dim: (num_features, target_dim) × 1708
-    train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset))
-    # (test_X, test_y) dim: (num_features, target_dim) × 1000
-    test_X, test_y = map(x -> Matrix(x), testdata(Planetoid(), dataset))
+function load_data(dataset, batch_size, train_repeats=512, test_repeats=32)
+    # (train_X, train_y) dim: (num_features, target_dim) × 2708
+    train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
+    # (test_X, test_y) dim: (num_features, target_dim) × 2708
+    test_X, test_y = map(x -> Matrix(x), testdata(Planetoid(), dataset, padding=true))
     g = graphdata(Planetoid(), dataset)
     train_idx = 1:size(train_X, 2)
     test_idx = test_indices(Planetoid(), dataset)

-    # padding zeros
-    tr_X = zeros(Float32, size(train_X, 1), size(train_X, 2) + size(test_X, 2))
-    te_X = zeros(Float32, size(test_X, 1), size(train_X, 2) + size(test_X, 2))
-    tr_y = zeros(Float32, size(train_y, 1), size(train_y, 2) + size(test_y, 2))
-    te_y = zeros(Float32, size(test_y, 1), size(train_y, 2) + size(test_y, 2))
-    tr_X[:, train_idx] .= train_X
-    te_X[:, test_idx] .= test_X
-    tr_y[:, train_idx] .= train_y
-    te_y[:, test_idx] .= test_y
-
     fg = FeaturedGraph(g)
-    train_data = (repeat(tr_X, outer=(1,1,256)), repeat(tr_y, outer=(1,1,256)))
-    test_data = (repeat(te_X, outer=(1,1,32)), repeat(te_y, outer=(1,1,32)))
+    train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
+    test_data = (repeat(test_X, outer=(1,1,test_repeats)), repeat(test_y, outer=(1,1,test_repeats)))
     train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
     test_loader = DataLoader(test_data, batchsize=batch_size, shuffle=true)
     return train_loader, test_loader, fg, train_idx, test_idx
@@ -42,8 +32,7 @@ end
 @with_kw mutable struct Args
     η = 0.01            # learning rate
     λ = 5f-4            # regularization parameter
-    batch_size = 32     # batch size
-    num_nodes = 2708    # number of nodes for graph
+    batch_size = 64     # batch size
     epochs = 200        # number of epochs
     seed = 0            # random seed
     cuda = true         # use GPU

examples/semisupervised_gcn.jl

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ end
     η = 0.01            # learning rate
     λ = 5f-4            # regularization parameter
     batch_size = 32     # batch size
-    num_nodes = 2708    # number of nodes for graph
     epochs = 200        # number of epochs
     seed = 0            # random seed
     cuda = true         # use GPU
