Skip to content

Commit 757d2cc

Browse files
committed
add defining GNN layer to doc
1 parent b5dbe98 commit 757d2cc

File tree

5 files changed

+60
-6
lines changed

5 files changed

+60
-6
lines changed

docs/src/basics/layers.md

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,50 @@ When using GNN layers, the general guidelines are:
3636
* With static graph strategy: you should pass in a ``d \times n \times batch`` matrix for node features, and the layer maps node features ``\mathbb{R}^d \rightarrow \mathbb{R}^k`` then the output will be in matrix with dimensions ``k \times n \times batch``. The same ostensibly goes for edge features but as of now no layer type supports outputting new edge features.
3737
* With variable graph strategy: you should pass in a `FeaturedGraph`, the output will be also be a `FeaturedGraph` with modified node (and/or edge) features. Add `node_feature` as the following entry in the Flux chain (or simply call `node_feature()` on the output) if you wish to subsequently convert them to matrix form.
3838

39-
## Create Custom GNN Layers
39+
## Define Your Own GNN Layer
4040

41-
Customizing your own GNN layers are the same as customizing layers in Flux. You may want to reference [Flux documentation](https://fluxml.ai/Flux.jl/stable/models/basics/#Building-Layers-1).
41+
Customizing your own GNN layers are the same as defining a layer in Flux. You may want to check [Flux documentation](https://fluxml.ai/Flux.jl/stable/models/basics/#Building-Layers-1) first.
42+
43+
To define a customized GNN layer, for example, we take a simple `GCNConv` layer as example here.
44+
45+
```julia
46+
struct GCNConv <: AbstractGraphLayer
47+
weight
48+
bias
49+
σ
50+
end
51+
52+
@functor GCNConv
53+
```
54+
55+
We first should define a `GCNConv` type and let it be the subtype of `AbstractGraphLayer`. In this type, it holds parameters that a layer operate on. Don't forget to add `@functor` macro to `GCNConv` type.
56+
57+
```julia
58+
(l::GCNConv)(Ã::AbstractMatrix, x::AbstractMatrix) = l.σ.(l.weight * x *.+ l.bias)
59+
```
60+
61+
Then, we can define the operation for `GCNConv` layer.
62+
63+
```julia
64+
function (l::GCNConv)(fg::AbstractFeaturedGraph)
65+
nf = node_feature(fg)
66+
= Zygote.ignore() do
67+
GraphSignals.normalized_adjacency_matrix(fg, eltype(nf); selfloop=true)
68+
end
69+
return ConcreteFeaturedGraph(fg, nf = l(Ã, nf))
70+
end
71+
```
72+
73+
Here comes to the GNN-specific behaviors. A GNN layer should accept object of subtype of `AbstractFeaturedGraph` to support variable graph strategy. A variable graph strategy should fetch node/edge/global features from `fg` and transform graph in `fg` into required form for layer operation, e.g. `GCNConv` layer needs a normalized adjacency matrix with self loop. Then, normalized adjacency matrix `` and node features `nf` are pass through `GCNConv` layer `l(Ã, nf)` to give a new node feature. Finally, a `ConcreteFeaturedGraph` wrap graph in `fg` and new node features into a new object of subtype of `AbstractFeaturedGraph`.
74+
75+
```julia
76+
layer = GCNConv(10=>5, relu)
77+
new_fg = layer(fg)
78+
gradient(() -> sum(node_feature(layer(fg))), Flux.params(layer))
79+
```
80+
81+
Now we complete a simple version of `GCNConv` layer. One can test the forward pass and gradient if they work properly.
82+
83+
```@docs
84+
GeometricFlux.AbstractGraphLayer
85+
```

docs/src/basics/subgraph.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
# Subgraph
22

3+
## Subgraph of `FeaturedGraph`
4+
5+
A `FeaturedGraph` object can derive a subgraph from a selected subset of the vertices of the graph.
6+
37
```julia
48
train_idx = train_indices(Planetoid(), :cora)
59
fg = FeaturedGraph(g)
610
fsg = subgraph(fg, train_idx)
7-
layer = WithGraph(fsg, GCNConv(in_channel=>out_channel), ) |> gpu
8-
train_X = train_X |> Matrix |> gpu
9-
H = layer(train_X)
1011
```
12+
13+
A `FeaturedSubgraph` object is returned from `subgraph` by selected vertices `train_idx`.

docs/src/manual/featuredgraph.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@ GraphSignals.edge_feature
99
GraphSignals.has_edge_feature
1010
GraphSignals.global_feature
1111
GraphSignals.has_global_feature
12+
GraphSignals.subgraph
13+
GraphSignals.ConcreteFeaturedGraph
1214
```

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function GCNConv(ch::Pair{Int,Int}, σ=identity;
3434
in, out = ch
3535
W = init(out, in)
3636
b = Flux.create_bias(W, bias, out)
37-
GCNConv(W, b, σ)
37+
return GCNConv(W, b, σ)
3838
end
3939

4040
@functor GCNConv

src/layers/graphlayers.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
"""
2+
AbstractGraphLayer
3+
4+
An abstract type of graph neural network layer for GeometricFlux.
5+
"""
16
abstract type AbstractGraphLayer end
27

38
(l::AbstractGraphLayer)(x::AbstractMatrix) = l(l.fg, x)

0 commit comments

Comments
 (0)