
Commit 9f1e4b9

Merge pull request #283 from FluxML/doc
Update docs and fix GATv2Conv
2 parents 63706e7 + 50f5c3f commit 9f1e4b9

File tree

13 files changed: +508 -72 lines changed


docs/bibliography.bib

Lines changed: 149 additions & 16 deletions
Large diffs are not rendered by default.

docs/make.jl

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,10 @@ makedocs(
        "Tutorials" => [
            "Semi-Supervised Learning with GCN" => "tutorials/semisupervised_gcn.md",
            "GCN with Fixed Graph" => "tutorials/gcn_fixed_graph.md",
+           "Graph Attention Network" => "tutorials/gat.md",
+           "DeepSet for Digit Sum" => "tutorials/deepset.md",
+           "Variational Graph Autoencoder" => "tutorials/vgae.md",
+           "Graph Embedding" => "tutorials/graph_embedding.md",
        ],
        "Abstractions" => [
            "Message passing scheme" => "abstractions/msgpass.md",

docs/src/introduction.md

Lines changed: 30 additions & 0 deletions
@@ -27,3 +27,33 @@ Graph signals include node signals, edge signals and global (or graph) signals.
<figcaption><em>Signals and graph signals.</em></figcaption>
</figure>
```
+
+## Variable Graph: `FeaturedGraph` as a Container for Graph and Features
+
+A GNN model accepts a graph and features as input. To this end, the `FeaturedGraph` object is designed as a container for a graph and various kinds of features, and it can be passed to a GNN model directly.
+
+```julia
+T = Float32
+fg = FeaturedGraph(g, nf=rand(T, 10, 5), ef=rand(T, 7, 11), gf=rand(T, 7))  # sizes are illustrative; the truncated `gf` argument is completed with an assumed value
+```
+
+It is worth noting that it is better to convert the element type of the graph and its features to `Float32` explicitly; doing so avoids type-related issues when training or running inference with a GNN model.
+
+```julia
+train_data = [(FeaturedGraph(g, nf=train_X), train_y) for _ in 1:N]
+```
+
+A collection of `FeaturedGraph`s can hold different graph structures `g` and different features `train_X`, and all of them can be fed to the same GNN model in order to train or infer on variable graphs.
+
+## Build a GNN Model
+
+```julia
+model = Chain(
+    GCNConv(input_dim=>hidden_dim, relu),
+    GraphParallel(node_layer=Dropout(0.5)),
+    GCNConv(hidden_dim=>target_dim),
+    node_feature,
+)
+```
+
+A GNN model can be built by stacking GNN layers, with or without regular Flux layers. Regular Flux layers should be wrapped in `GraphParallel` and specified as `node_layer`, which applies them to the node features.
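
For instance, a `FeaturedGraph` carrying node features can be passed through the model above; because the chain ends with `node_feature`, the call returns the node-feature matrix of the result. This is only a sketch under the assumptions already in scope (a graph `g` and the dimensions `input_dim`, `hidden_dim`, `target_dim`):

```julia
using Graphs: nv   # assumed; any graph type with a node count works

X = rand(Float32, input_dim, nv(g))   # one Float32 feature column per node
fg = FeaturedGraph(g, nf=X)
Ŷ = model(fg)                         # target_dim × number-of-nodes predictions
```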

docs/src/manual/conv.md

Lines changed: 12 additions & 9 deletions
@@ -13,7 +13,7 @@ where ``\hat{A} = A + I``, ``A`` denotes the adjacency matrix, and
GCNConv
```

-Reference: [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907)
+Reference: [Kipf2017](@cite)

---

@@ -37,7 +37,7 @@ and ``\hat{L} = \frac{2}{\lambda_{max}} L - I``.
ChebConv
```

-Reference: [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375)
+Reference: [Defferrard2016](@cite)

---

@@ -51,7 +51,7 @@ Reference: [Convolutional Neural Networks on Graphs with Fast Localized Spectral
GraphConv
```

-Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244)
+Reference: [Morris2019](@cite)

---

@@ -71,7 +71,7 @@ where the attention coefficient ``\alpha_{i,j}`` can be calculated from
GATConv
```

-Reference: [Graph Attention Networks](https://arxiv.org/abs/1710.10903)
+Reference: [GAT2018](@cite)

---

@@ -82,7 +82,8 @@ Reference: [Graph Attention Networks](https://arxiv.org/abs/1710.10903)
GATv2Conv
```

-Reference: [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491)
+Reference: [Brody2022](@cite)
+
---

## Gated Graph Convolution Layer

@@ -98,7 +99,7 @@ Reference: [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2
GatedGraphConv
```

-Reference: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493)
+Reference: [Li2016](@cite)

---

@@ -114,7 +115,7 @@ where ``f_{\Theta}`` denotes a neural network parametrized by ``\Theta``, *i.e.*
EdgeConv
```

-Reference: [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829)
+Reference: [Wang2019](@cite)

---

@@ -130,7 +131,9 @@ where ``f_{\Theta}`` denotes a neural network parametrized by ``\Theta``, *i.e.*
GINConv
```

-Reference: [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf)
+Reference: [Xu2019](@cite)
+
+---

## Crystal Graph Convolutional Network

@@ -144,4 +147,4 @@ where ``\textbf{z}_{i,j} = [\textbf{x}_i, \textbf{x}_j, \textbf{e}_{i,j}]`` den
CGConv
```

-Reference: [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://arxiv.org/pdf/1710.10324.pdf)
+Reference: [Xie2018](@cite)

docs/src/manual/embedding.md

Lines changed: 2 additions & 0 deletions
@@ -5,3 +5,5 @@
```@docs
GeometricFlux.node2vec
```
+
+Reference: [Grover2016](@cite)

docs/src/manual/models.md

Lines changed: 5 additions & 5 deletions
@@ -15,7 +15,7 @@ where ``A`` denotes the adjacency matrix.
GeometricFlux.GAE
```

-Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)
+Reference: [Kipf2016](@cite)

---

@@ -33,7 +33,7 @@ where ``A`` denotes the adjacency matrix, ``X`` denotes node features.
GeometricFlux.VGAE
```

-Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)
+Reference: [Kipf2016](@cite)

---

@@ -49,7 +49,7 @@ where ``\phi`` and ``\rho`` denote two neural networks and ``x_i`` is the node f
GeometricFlux.DeepSet
```

-Reference: [Deep Sets](https://papers.nips.cc/paper/2017/hash/f22e4747da1aa27e363d86d40ff442fe-Abstract.html)
+Reference: [Zaheer2017](@cite)

---

@@ -67,7 +67,7 @@ where ``Z`` denotes the input matrix from encoder.
GeometricFlux.InnerProductDecoder
```

-Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)
+Reference: [Kipf2016](@cite)

---

@@ -82,4 +82,4 @@ Z_{\mu}, Z_{logσ} = GCN_{\mu}(H, A), GCN_{\sigma}(H, A)
GeometricFlux.VariationalGraphEncoder
```

-Reference: [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308)
+Reference: [Kipf2016](@cite)

docs/src/manual/pool.md

Lines changed: 13 additions & 1 deletion
@@ -1,13 +1,25 @@
-# Pooling layers
+# Pooling Layers
+
+## Global Pooling Layer

```@docs
GlobalPool
```

+---
+
+## Local Pooling Layer
+
```@docs
LocalPool
```

+---
+
+## Top-k Pooling Layer
+
```@docs
TopKPool
```
+
+Reference: [Gao2019](@cite)

docs/src/tutorials/deepset.md

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
# Predicting Digit Sums with a DeepSet Model

Digit sum is the task of summing up the digits that appear in images or text. This example demonstrates summing the digits in an arbitrary number of MNIST images. A DeepSet model is well suited to this task: it takes a set of objects and reduces them to a single object.

## Step 1: Load MNIST Dataset

Since a DeepSet model predicts the sum from a set of images, we have to prepare a training dataset in which each example is a randomly sized set of images paired with the sum of their digits.

First, the whole dataset is loaded from MLDatasets.jl and shuffled before the training examples are generated.

```julia
train_X, train_y = MLDatasets.MNIST.traindata(Float32)
train_X, train_y = shuffle_data(train_X, train_y)
```
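
`shuffle_data` is a small helper defined in the example script; a minimal version might look like the following (a hypothetical sketch, shown only to make the snippet self-contained):

```julia
using Random: randperm

# Shuffle images and labels together; MNIST images are stacked along the third dimension.
function shuffle_data(X::AbstractArray, y::AbstractVector)
    perm = randperm(size(X, 3))
    return X[:, :, perm], y[perm]
end
```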

The `generate_featuredgraphs` helper generates a set of pairs, each containing a `FeaturedGraph` and the summed digits used as the prediction target. Inside each `FeaturedGraph`, an arbitrary number of MNIST images are collected as node features, and the corresponding nodes form a graph without edges.

```julia
train_data = generate_featuredgraphs(train_X, train_y, num_train_examples, 1:train_max_length)
```

`num_train_examples` specifies how many training examples to generate, and `1:train_max_length` specifies the range for the number of images contained in one example.
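
For reference, a rough sketch of what such a helper might do is shown below; this is a hypothetical reconstruction, and the actual implementation lives in the example script:

```julia
using Graphs: SimpleGraph

function generate_featuredgraphs(X, y, num_examples, len_range)
    data = []
    for _ in 1:num_examples
        k = rand(len_range)                  # how many images in this set
        idx = rand(1:size(X, 3), k)          # pick k random images
        g = SimpleGraph(k)                   # k nodes, no edges
        nf = reshape(X[:, :, idx], :, k)     # flatten each image into a node-feature column
        push!(data, (FeaturedGraph(g, nf=nf), sum(y[idx])))
    end
    return data
end
```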

## Step 2: Build a DeepSet Model

A DeepSet takes a set of objects and outputs a single object. For the model to accept a set, its output must be invariant to permutations of the input. The DeepSet model is composed of two parts: a ``\phi`` network and a ``\rho`` network.

```math
Z = \rho ( \sum_{x_i \in \mathcal{V}} \phi (x_i) )
```

The ``\phi`` network embeds each image, and the embeddings are summed into a single embedding. Permutation invariance comes from the use of summation; in general, any commutative binary operator can be used to reduce a set of embeddings into one. Finally, the ``\rho`` network decodes the embedding into a number.

```julia
ϕ = Chain(
    Dense(args.input_dim, args.hidden_dims[1], tanh),
    Dense(args.hidden_dims[1], args.hidden_dims[2], tanh),
    Dense(args.hidden_dims[2], args.hidden_dims[3], tanh),
)
ρ = Dense(args.hidden_dims[3], args.target_dim)
model = DeepSet(ϕ, ρ) |> device
```
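
As a quick sanity check, permuting the images inside a set should not change the prediction. The snippet below is a sketch only; it assumes `args.input_dim == 784` (flattened MNIST images), that `SimpleGraph` from Graphs.jl is available, and that the model lives on the CPU:

```julia
using Random: randperm

k = 5
nf = rand(Float32, 784, k)                  # five flattened, MNIST-sized images
perm = randperm(k)
fg1 = FeaturedGraph(SimpleGraph(k), nf=nf)
fg2 = FeaturedGraph(SimpleGraph(k), nf=nf[:, perm])
@assert global_feature(model(fg1)) ≈ global_feature(model(fg2))
```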

## Step 3: Loss Function

Mean absolute error is used as the loss function. Since the model outputs a `FeaturedGraph`, the prediction is stored as the global feature of that `FeaturedGraph`.

```julia
function model_loss(model, batch)
    ŷ = vcat(map(x -> global_feature(model(x[1])), batch)...)
    y = vcat(map(x -> x[2], batch)...)
    return mae(ŷ, y)
end
```

## Step 4: Training the DeepSet Model

```julia
# optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for batch in train_loader
        train_loss, back = Flux.pullback(ps) do
            model_loss(model, batch |> device)
        end
        # assumes the example script also defines a `model_loss(model, loader, device)` method for evaluation
        test_loss = model_loss(model, test_loader, device)
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)
    end
end
```

For a complete example, please check [examples/digitsum_deepsets.jl](../../examples/digitsum_deepsets.jl).

docs/src/tutorials/gat.md

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
# Graph Attention Network

The graph attention network (GAT) belongs to the message-passing network family: each node attends over the features of its neighbors, and the aggregated result is produced as the layer output.

## Step 1: Load Dataset

We load the Cora dataset from the Planetoid collection.

```julia
train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
```

## Step 2: Batch up Features and Labels

Batch up the features as usual.

```julia
add_all_self_loops!(g)
fg = FeaturedGraph(g)
train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
```

Notably, self-loops on all nodes are required for the GAT model.
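
`add_all_self_loops!` is a helper from the example script. Assuming `g` is a Graphs.jl graph that permits self-loops, a minimal version might look like this (a sketch, not the example's actual implementation):

```julia
using Graphs

# Add a self-loop to every vertex in-place, so each node also attends to itself.
function add_all_self_loops!(g)
    for v in vertices(g)
        add_edge!(g, v, v)
    end
    return g
end
```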

## Step 3: Build a GAT Model

```julia
model = Chain(
    WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
    Dropout(0.6),
    WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
) |> device
```

Note that a `GATConv` with `concat=true` concatenates its `heads` outputs along the feature dimension, so the next layer must take `args.hidden_dim*args.heads` input channels. In the final layer of a network, the `GATConv` layer should be given `concat=false` so that the heads are averaged instead.
38+
39+
## Step 4: Loss Functions and Accuracy
40+
41+
Cross entropy loss is used as loss function and accuracy is used to evaluate the model.
42+
43+
```julia
44+
model_loss(model, X, y, idx) =
45+
logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
46+
```
47+
48+
```julia
49+
accuracy(model, X::AbstractArray, y::AbstractArray, idx) =
50+
mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:])
51+
```
52+
53+
54+
## Step 5: Training GAT Model
55+
56+
```julia
57+
# ADAM optimizer
58+
opt = ADAM(args.η)
59+
60+
# parameters
61+
ps = Flux.params(model)
62+
63+
# training
64+
@info "Start Training, total $(args.epochs) epochs"
65+
for epoch = 1:args.epochs
66+
@info "Epoch $(epoch)"
67+
68+
for (X, y) in train_loader
69+
loss, back = Flux.pullback(ps) do
70+
model_loss(model, X |> device, y |> device, train_idx |> device)
71+
end
72+
train_acc = accuracy(model, train_loader, device, train_idx)
73+
test_acc = accuracy(model, test_loader, device, test_idx)
74+
grad = back(1f0)
75+
Flux.Optimise.update!(opt, ps, grad)
76+
end
77+
end
78+
```
79+
80+
For a complete example, please check [examples/gat.jl](../../examples/gat.jl).

docs/src/tutorials/graph_embedding.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# Graph Embedding Through Node2vec model
