Commit 50f5c3f ("add GAT tutorial", parent b31b348)

1 file changed: docs/src/tutorials/gat.md (+79, -0)
# Graph Attention Network

The graph attention network (GAT) belongs to the message-passing network family. A GAT layer lets each node query its neighbors' features through an attention mechanism and aggregates those features, weighted by the attention scores, into the layer output.
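Concretely, a GAT layer follows the attention mechanism of the original GAT paper (Veličković et al., 2018): the attention coefficient between node ``i`` and neighbor ``j``, and the resulting node update, are

```math
\alpha_{ij} = \frac{\exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^\top [\mathbf{W}\mathbf{h}_i \,\|\, \mathbf{W}\mathbf{h}_j]\right)\right)}{\sum_{k \in \mathcal{N}(i)} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^\top [\mathbf{W}\mathbf{h}_i \,\|\, \mathbf{W}\mathbf{h}_k]\right)\right)},
\qquad
\mathbf{h}_i' = \sigma\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, \mathbf{W}\mathbf{h}_j\Big)
```

where ``\|`` denotes feature concatenation and ``\mathcal{N}(i)`` includes node ``i`` itself, which is why self loops are added in Step 2 below.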
## Step 1: Load Dataset

We load the Cora citation dataset from the Planetoid collection.
```julia
dataset = :cora  # Planetoid subset; here we use Cora
train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
```
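Later snippets refer to a hyperparameter container `args` and a few loose variables (`train_repeats`, `batch_size`, `device`) that this diff does not define. A minimal, hypothetical sketch of plausible definitions, together with the imports the snippets rely on, with dimensions chosen to match Cora (1433 input features, 7 classes):

```julia
using GeometricFlux               # GATConv, WithGraph, FeaturedGraph
using Flux
using Flux: onecold, cpu          # swap in `gpu` when CUDA is available
using Flux.Losses: logitcrossentropy
using Flux.Data: DataLoader
using Statistics: mean

# Hypothetical hyperparameters; the real example defines its own container.
Base.@kwdef struct Args
    η = 0.01            # learning rate
    batch_size = 8
    epochs = 10
    heads = 8           # number of attention heads
    input_dim = 1433    # Cora node-feature dimension
    hidden_dim = 16
    target_dim = 7      # number of Cora classes
end

args = Args()
device = cpu
train_repeats = 32      # copies of the graph per epoch (assumed value)
```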
## Step 2: Batch up Features and Labels

Batch up the features and labels as usual.
```julia
add_all_self_loops!(g)  # GAT needs a self loop on every node (see the note and sketch below)
fg = FeaturedGraph(g)
train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
```
Note that the GAT model requires a self loop on every node, so that each node attends to its own features as well as to those of its neighbors.
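`add_all_self_loops!` is not a Graphs.jl built-in. A minimal sketch of such a helper, assuming `g` is a mutable Graphs.jl graph, could be:

```julia
using Graphs

# Add a self loop to every vertex of `g` in place.
function add_all_self_loops!(g)
    for v in vertices(g)
        add_edge!(g, v, v)
    end
    return g
end
```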
## Step 3: Build a GAT model

```julia
model = Chain(
    WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
    Dropout(0.6),
    WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
) |> device
```
Note that a `GATConv` layer with `concat=true` (the default) concatenates the outputs of its `heads` attention heads along the feature dimension, so the next layer must accept `args.hidden_dim*args.heads` input features. For example, with `args.hidden_dim = 16` and `args.heads = 8`, the first layer outputs `16 * 8 = 128` features per node. The final `GATConv` layer of a network should instead be given `concat=false`, which averages over the heads.
## Step 4: Loss Function and Accuracy

Cross entropy is used as the loss function, and accuracy is used to evaluate the model. Both take a set of node indices `idx`, so that the metric is computed only over the training or test split.
```julia
model_loss(model, X, y, idx) =
    logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
```
```julia
accuracy(model, X::AbstractArray, y::AbstractArray, idx) =
    mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))
```
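The training loop in Step 5 also calls `accuracy` with a `DataLoader`. That method is not shown in this diff; a minimal sketch, assuming it simply averages the per-batch accuracy, could be:

```julia
accuracy(model, loader::DataLoader, device, idx) =
    mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)
```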
## Step 5: Training the GAT Model

```julia
# ADAM optimizer
opt = ADAM(args.η)

# trainable parameters
ps = Flux.params(model)

# training
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for (X, y) in train_loader
        loss, back = Flux.pullback(ps) do
            model_loss(model, X |> device, y |> device, train_idx |> device)
        end
        grad = back(1f0)  # seed the pullback with 1 to obtain the gradients
        Flux.Optimise.update!(opt, ps, grad)

        train_acc = accuracy(model, train_loader, device, train_idx)
        test_acc = accuracy(model, test_loader, device, test_idx)
    end
end
```
For the complete example, see [examples/gat.jl](../../examples/gat.jl).
