
Commit 84f61a9

Merge pull request #234 from FluxML/develop

Correct GCNConv with normalized_adjacency_matrix

2 parents 12f3908 + d23ffe4
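For context, the commit title refers to GCN's renormalized adjacency, Ã = D̃^(−1/2) (A + I) D̃^(−1/2), where D̃ is the degree matrix of A + I. A minimal sketch of that normalization (illustrative only; `normalized_adjacency` below is a hypothetical helper, not GeometricFlux's `normalized_adjacency_matrix` implementation):

```julia
# Sketch of the GCN normalization A_hat = D^(-1/2) (A + I) D^(-1/2),
# where D is the degree matrix of A + I (adjacency with self-loops).
# Illustrative only; not GeometricFlux's implementation.
using LinearAlgebra

function normalized_adjacency(A::AbstractMatrix)
    A_hat = A + I                      # add self-loops
    d = vec(sum(A_hat, dims=2))        # degrees of A + I
    Dinv_sqrt = Diagonal(1 ./ sqrt.(d))
    return Dinv_sqrt * A_hat * Dinv_sqrt
end

A = Float32[0 1 0; 1 0 1; 0 1 0]       # path graph on 3 nodes
S = normalized_adjacency(A)
```

The result is symmetric, which is what lets spectral layers like `GCNConv` work with a real eigendecomposition.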

File tree

27 files changed: +538 −649 lines

Project.toml

Lines changed: 7 additions & 7 deletions

@@ -5,13 +5,13 @@ version = "0.7.7"
 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-GraphLaplacians = "a1251efa-393a-423f-9d7b-faaecba535dc"
 GraphMLDatasets = "21828b05-d3b3-40ad-870e-a4bc2f52d5e8"
 GraphSignals = "3ebe565e-a4b5-49c6-aed2-300248c3a9c1"
-LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
+Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
@@ -21,14 +21,14 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

 [compat]
-CUDA = "3.3"
+CUDA = "3"
+ChainRulesCore = "1.7"
 DataStructures = "0.18"
-FillArrays = "0.11, 0.12"
+FillArrays = "0.12"
 Flux = "0.12"
-GraphLaplacians = "0.1"
 GraphMLDatasets = "0.1"
-GraphSignals = "0.2"
-LightGraphs = "1.3"
+GraphSignals = "0.3"
+Graphs = "1.4"
 NNlib = "0.7"
 NNlibCUDA = "0.1"
 Reexport = "1.1"

README.md

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ GeometricFlux handles graph data (the topology plus node/vertex/graph features)
 thanks to the type `FeaturedGraph`.

 A `FeaturedGraph` can be constructed out of
-adjacency matrices, adjacency lists, LightGraphs' types...
+adjacency matrices, adjacency lists, Graphs' types...

 ```julia
 fg = FeaturedGraph(adj_list)

docs/make.jl

Lines changed: 2 additions & 1 deletion

@@ -13,7 +13,8 @@ makedocs(
 pages = ["Home" => "index.md",
          "Get started" => "start.md",
          "Basics" =>
-           ["Building layers" => "basics/layers.md",
+           ["Graph convolutions" => "basics/conv.md",
+            "Building layers" => "basics/layers.md",
             "Graph passing" => "basics/passgraph.md"],
          "Cooperate with Flux layers" => "cooperate.md",
          "Abstractions" =>

docs/src/basics/conv.md

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+# Graph convolutions
+
+Graph convolutions can be classified into spectral-based and spatial-based graph convolutions. Spectral-based graph convolutions, such as `GCNConv` and `ChebConv`, operate on the features of the *whole* graph at once. Spatial-based graph convolutions, such as `GraphConv` and `GATConv`, operate on the features of a *local* neighborhood instead. The message-passing scheme is an abstraction for spatial-based graph convolutional layers: any spatial-based layer can be implemented within this framework.
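The message-passing scheme described in that page can be sketched as follows. The names `message`, `aggregate`, `update`, and `propagate` follow the usual message-passing convention and are illustrative here, not necessarily GeometricFlux's exact API:

```julia
# Minimal message-passing sketch over an adjacency list: each node collects
# messages from its neighbors, aggregates them, and updates its own feature.
message(xi, xj) = xj                      # message from neighbor j to node i
aggregate(msgs) = sum(msgs)               # sum aggregation over neighbors
update(xi, m) = xi .+ m                   # combine old feature with messages

function propagate(adj_list, X)           # X is a d × n feature matrix
    X′ = similar(X)
    for i in 1:length(adj_list)
        msgs = [message(X[:, i], X[:, j]) for j in adj_list[i]]
        m = isempty(msgs) ? zero(X[:, i]) : aggregate(msgs)
        X′[:, i] = update(X[:, i], m)
    end
    return X′
end

adj_list = [[2], [1, 3], [2]]             # path graph 1 - 2 - 3
X = Float32[1 2 3; 1 1 1]                 # 2 features × 3 nodes
propagate(adj_list, X)
```

Layers like `GraphConv` and `GATConv` differ only in which `message`, `aggregate`, and `update` they plug into this loop.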

docs/src/basics/layers.md

Lines changed: 2 additions & 3 deletions

@@ -2,7 +2,7 @@

 Building a GNN is as simple as building a neural network in Flux. The syntax here is the same as Flux: `Chain` is used to stack layers into a GNN. A simple example is shown here:

-```
+```julia
 model = Chain(GCNConv(adj_mat, feat=>h1),
               GCNConv(adj_mat, h1=>h2, relu))
 ```
@@ -21,7 +21,6 @@ When using GNN layers, the general guidelines are:
 * If you pass in an ``n \times d`` matrix of node features, and the layer maps node features ``\mathbb{R}^d \rightarrow \mathbb{R}^k``, then the output will be a matrix with dimensions ``n \times k``. The same ostensibly goes for edge features, but as of now no layer type supports outputting new edge features.
 * If you pass in a `FeaturedGraph`, the output will also be a `FeaturedGraph` with modified node (and/or edge) features. Add `node_feature` as the following entry in the Flux chain (or simply call `node_feature()` on the output) if you wish to subsequently convert them to matrix form.

-
-## Customize layers
+## Create custom layers

 Customizing your own GNN layers is the same as customizing layers in Flux. You may want to reference the [Flux documentation](https://fluxml.ai/Flux.jl/stable/models/basics/#Building-Layers-1).
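As the page says, a custom layer follows the same callable-struct pattern as Flux's own layers. A minimal sketch (illustrative names; a real layer would also call `Flux.@functor` on the struct so its parameters are trainable):

```julia
# A hypothetical custom layer as a callable struct, mirroring how Flux layers
# are built. W is the only parameter; the forward pass first mixes node
# features X (in × n) over the graph via adjacency A, then projects with W.
struct GraphDense{M<:AbstractMatrix}
    W::M
end
GraphDense(in::Integer, out::Integer) =
    GraphDense(randn(Float32, out, in) ./ sqrt(Float32(in)))

(l::GraphDense)(A::AbstractMatrix, X::AbstractMatrix) = l.W * (X * A)

l = GraphDense(4, 2)
A = Float32[0 1 0; 1 0 1; 0 1 0]          # path graph on 3 nodes
X = rand(Float32, 4, 3)                   # 4 features × 3 nodes
size(l(A, X))                             # (2, 3)
```

The callable-struct pattern is exactly what the linked Flux documentation describes; Flux supplies the machinery for collecting and training the parameters.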

docs/src/basics/passgraph.md

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ A static graph is used to reduce redundant computation during passing through la
 GCNConv(adj_mat, feat=>h1, relu)
 ```

-`Simple(Di)Graph`, `SimpleWeighted(Di)Graph` or `Meta(Di)Graph`, provided by the packages LightGraphs, SimpleWeightedGraphs and MetaGraphs, respectively, are valid arguments for passing as a static graph to this layer. An adjacency list of type `Vector{Vector}` is also accepted.
+`Simple(Di)Graph`, `SimpleWeighted(Di)Graph` or `Meta(Di)Graph`, provided by the packages Graphs, SimpleWeightedGraphs and MetaGraphs, respectively, are valid arguments for passing as a static graph to this layer. An adjacency list of type `Vector{Vector}` is also accepted.

 ## Variable graph

@@ -20,7 +20,7 @@ Variable graphs are supported through `FeaturedGraph`, which contains both the g
 FeaturedGraph(adj_mat, features)
 ```

-`Simple(Di)Graph`, `SimpleWeighted(Di)Graph` or `Meta(Di)Graph`, provided by the packages LightGraphs, SimpleWeightedGraphs and MetaGraphs, respectively, are acceptable for constructing a `FeaturedGraph`. An adjacency list is accepted, too.
+`Simple(Di)Graph`, `SimpleWeighted(Di)Graph` or `Meta(Di)Graph`, provided by the packages Graphs, SimpleWeightedGraphs and MetaGraphs, respectively, are acceptable for constructing a `FeaturedGraph`. An adjacency list is accepted, too.

 ## Cached graph in layers

examples/gae.jl

Lines changed: 5 additions & 6 deletions

@@ -6,8 +6,7 @@ using Flux: @epochs
 using JLD2
 using Statistics: mean
 using SparseArrays
-using LightGraphs.SimpleGraphs
-using LightGraphs: adjacency_matrix
+using Graphs.SimpleGraphs
 using CUDA

 @load "data/cora_features.jld2" features
@@ -21,13 +20,13 @@ target_catg = 7
 epochs = 200

 ## Preprocessing data
-adj_mat = Matrix{Float32}(adjacency_matrix(g)) |> gpu
+fg = FeaturedGraph(g) |> gpu
 train_X = Matrix{Float32}(features) |> gpu  # dim: num_features * num_nodes
-train_y = adj_mat  # dim: num_nodes * num_nodes
+train_y = fg  # dim: num_nodes * num_nodes

 ## Model
-encoder = Chain(GCNConv(adj_mat, num_features=>hidden1, relu),
-                GCNConv(adj_mat, hidden1=>hidden2))
+encoder = Chain(GCNConv(fg, num_features=>hidden1, relu),
+                GCNConv(fg, hidden1=>hidden2))
 model = Chain(GAE(encoder, σ)) |> gpu

 ## Loss

examples/gat.jl

Lines changed: 2 additions & 2 deletions

@@ -6,8 +6,8 @@ using Flux.Data: DataLoader
 using JLD2
 using Statistics: mean
 using SparseArrays
-using LightGraphs.SimpleGraphs
-using LightGraphs: adjacency_matrix
+using Graphs.SimpleGraphs
+using Graphs: adjacency_matrix
 using CUDA

 @load "data/cora_features.jld2" features

examples/gcn.jl

Lines changed: 4 additions & 5 deletions

@@ -5,8 +5,7 @@ using Flux: @epochs
 using JLD2
 using Statistics
 using SparseArrays
-using LightGraphs.SimpleGraphs
-using LightGraphs: adjacency_matrix
+using Graphs.SimpleGraphs
 using CUDA
 using Random

@@ -25,12 +24,12 @@ epochs = 100
 ## Preprocessing data
 train_X = Matrix{Float32}(features) |> gpu  # dim: num_features * num_nodes
 train_y = Matrix{Float32}(labels) |> gpu  # dim: target_catg * num_nodes
-adj_mat = Matrix{Float32}(adjacency_matrix(g)) |> gpu
+fg = FeaturedGraph(g) |> gpu

 ## Model
-model = Chain(GCNConv(adj_mat, num_features=>hidden, relu),
+model = Chain(GCNConv(fg, num_features=>hidden, relu),
              Dropout(0.5),
-             GCNConv(adj_mat, hidden=>target_catg),
+             GCNConv(fg, hidden=>target_catg),
              ) |> gpu

 ## Loss

examples/gde.jl

Lines changed: 4 additions & 5 deletions

@@ -2,7 +2,6 @@ using GeometricFlux, Flux, JLD2, SparseArrays, DiffEqFlux, DifferentialEquations
 using Flux: onehotbatch, onecold, logitcrossentropy, throttle
 using Flux: @epochs
 using Statistics: mean
-using LightGraphs: adjacency_matrix

 # Load the dataset
 @load "data/cora_features.jld2" features
@@ -19,22 +18,22 @@ epochs = 40
 # Preprocess the data and compute adjacency matrix
 train_X = Matrix{Float32}(features)  # dim: num_features * num_nodes
 train_y = Float32.(labels)  # dim: target_catg * num_nodes
-adj_mat = Matrix{Float32}(adjacency_matrix(g))
+fg = FeaturedGraph(g)

 # Define the Neural GDE
 diffeqarray_to_array(x) = reshape(cpu(x), size(x)[1:2])

 node = NeuralODE(
-    GCNConv(adj_mat, hidden=>hidden),
+    GCNConv(fg, hidden=>hidden),
     (0.f0, 1.f0), Tsit5(), save_everystep = false,
     reltol = 1e-3, abstol = 1e-3, save_start = false
 )

-model = Chain(GCNConv(adj_mat, num_features=>hidden, relu),
+model = Chain(GCNConv(fg, num_features=>hidden, relu),
              Dropout(0.5),
              node,
              diffeqarray_to_array,
-             GCNConv(adj_mat, hidden=>target_catg),
+             GCNConv(fg, hidden=>target_catg),
              softmax)

 # Loss
