# Graph Attention Network

Graph attention network (GAT) belongs to the message-passing network family. Each node's feature vector acts as a query over its neighbors' features, and the attention-weighted aggregation of those neighbor features forms the layer output.
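
Concretely, a GAT layer scores each neighbor against the query node and aggregates the projected neighbor features with those scores. Below is a minimal single-head sketch with made-up dimensions and random weights — an illustration of the mechanism, not the `GATConv` internals verbatim:

```julia
using LinearAlgebra

leakyrelu(x, α=0.2f0) = max(α * x, x)

W = randn(Float32, 8, 4)                  # learnable projection: 4 -> 8 features
a = randn(Float32, 16)                    # attention vector over [Wh_i; Wh_j]

h_i  = randn(Float32, 4)                  # features of the query node
h_js = [randn(Float32, 4) for _ in 1:3]   # neighbor features (self-loop included)

# unnormalized score e_ij = LeakyReLU(aᵀ [W h_i; W h_j])
e = [leakyrelu(dot(a, vcat(W * h_i, W * h_j))) for h_j in h_js]
α = exp.(e) ./ sum(exp.(e))               # softmax over the neighborhood

h_i′ = sum(α[j] * (W * h_js[j]) for j in eachindex(h_js))  # layer output for node i
```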
## Step 1: Load Dataset

We load the dataset from the Planetoid collection; here the Cora dataset is used.

```julia
train_X, train_y = map(x -> Matrix(x), alldata(Planetoid(), dataset, padding=true))
```
## Step 2: Batch up Features and Labels

Batch up the features and labels as usual. Since the graph itself is fixed, the feature and label matrices are simply repeated along a third (batch) dimension.

```julia
add_all_self_loops!(g)
fg = FeaturedGraph(g)
train_data = (repeat(train_X, outer=(1,1,train_repeats)), repeat(train_y, outer=(1,1,train_repeats)))
train_loader = DataLoader(train_data, batchsize=batch_size, shuffle=true)
```

Notably, self-loops on all nodes are required by the GAT model, as the sketch below illustrates: each node must appear in its own attention neighborhood.
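
For instance, with a `SimpleGraph` from Graphs.jl, the self-loops could be added along these lines (a hypothetical stand-in for the `add_all_self_loops!` helper used above):

```julia
using Graphs

# Connect every vertex to itself so each node attends to its own features.
function add_self_loops!(g::SimpleGraph)
    for v in vertices(g)
        add_edge!(g, v, v)
    end
    return g
end

g = add_self_loops!(path_graph(4))
has_self_loops(g)   # true
```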
## Step 3: Build a GAT Model

```julia
model = Chain(
    WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
    Dropout(0.6),
    WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
) |> device
```
Note that a `GATConv` layer with `concat=true` concatenates the outputs of its `heads` along the feature dimension, so the next layer must accept `args.hidden_dim*args.heads` input features. The final `GATConv` layer of a network should be given `concat=false` so that it averages over the heads instead.
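
A quick dimension check makes this concrete. The numbers below are illustrative (Cora has 1433 input features and 7 classes; `hidden_dim` and `heads` are hypothetical choices, not taken from the example configuration):

```julia
input_dim, hidden_dim, heads, target_dim = 1433, 8, 8, 7

# First GATConv (concat=true): each head emits hidden_dim features, and the
# head outputs are concatenated along the feature dimension.
first_layer_out = hidden_dim * heads   # 64 features per node

# Final GATConv (concat=false): head outputs are averaged, so the feature
# dimension stays at target_dim.
final_layer_out = target_dim           # 7 features per node
```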
## Step 4: Loss Functions and Accuracy

Cross entropy on the model's logits serves as the loss function, and accuracy is used to evaluate the model; `idx` selects the node subset (the training or test split) over which both are computed.

```julia
model_loss(model, X, y, idx) =
    logitcrossentropy(model(X)[:,idx,:], y[:,idx,:])
```
```julia
accuracy(model, X::AbstractArray, y::AbstractArray, idx) =
    mean(onecold(softmax(cpu(model(X))[:,idx,:])) .== onecold(cpu(y)[:,idx,:]))

# averaged over a whole DataLoader, as called in the training loop below
accuracy(model, loader::DataLoader, device, idx) =
    mean(accuracy(model, X |> device, y |> device, idx) for (X, y) in loader)
```
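
As a quick sanity check, both definitions can be exercised on random data. The shapes here are hypothetical (7 classes, 5 selected nodes, a batch of one graph):

```julia
using Flux
using Flux: onecold, onehotbatch
using Flux.Losses: logitcrossentropy
using Statistics: mean

ŷ = randn(Float32, 7, 5, 1)                    # raw logits, shaped like a model output
y = Float32.(onehotbatch(rand(1:7, 5), 1:7))   # one-hot ground-truth labels
y = reshape(y, 7, 5, 1)

loss = logitcrossentropy(ŷ, y)                 # applies softmax internally
acc = mean(onecold(ŷ[:, :, 1]) .== onecold(y[:, :, 1]))
```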
## Step 5: Training the GAT Model

```julia
# ADAM optimizer
opt = ADAM(args.η)

# parameters
ps = Flux.params(model)

# training
@info "Start Training, total $(args.epochs) epochs"
for epoch = 1:args.epochs
    @info "Epoch $(epoch)"

    for (X, y) in train_loader
        loss, back = Flux.pullback(ps) do
            model_loss(model, X |> device, y |> device, train_idx |> device)
        end
        train_acc = accuracy(model, train_loader, device, train_idx)
        # test_loader and test_idx come from the full example's data loading
        test_acc = accuracy(model, test_loader, device, test_idx)
        grad = back(1f0)
        Flux.Optimise.update!(opt, ps, grad)
        @info "loss = $loss, train_acc = $train_acc, test_acc = $test_acc"
    end
end
```
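
`Flux.pullback` evaluates the loss and returns a closure that maps a seed gradient (the `1f0` above) back onto the parameters. In miniature, as a standalone sketch outside the GAT example:

```julia
using Flux

w = [2.0f0, 3.0f0]
ps = Flux.params(w)

# the closure `back` turns the seed gradient 1f0 into parameter gradients
loss, back = Flux.pullback(() -> sum(w .^ 2), ps)
grads = back(1f0)                              # grads[w] == [4.0f0, 6.0f0]

Flux.Optimise.update!(Flux.Optimise.Descent(0.1), ps, grads)
```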
For a complete example, please check [examples/gat.jl](../../examples/gat.jl).