Skip to content

Commit 120bee2

Browse files
committed
simplify MessagePassing methods
fix bug for GATConv and ChebConv
1 parent dc53610 commit 120bee2

File tree

17 files changed

+224
-327
lines changed

17 files changed

+224
-327
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.7.7"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
89
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
910
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1011
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
@@ -20,9 +21,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2021
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2122

2223
[compat]
23-
CUDA = "3.3"
24+
CUDA = "3"
25+
ChainRulesCore = "1.7"
2426
DataStructures = "0.18"
25-
FillArrays = "0.11, 0.12"
27+
FillArrays = "0.12"
2628
Flux = "0.12"
2729
GraphMLDatasets = "0.1"
2830
GraphSignals = "0.3"

docs/make.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ makedocs(
1313
pages = ["Home" => "index.md",
1414
"Get started" => "start.md",
1515
"Basics" =>
16-
["Building layers" => "basics/layers.md",
16+
["Graph convolutions" => "basics/conv.md",
17+
"Building layers" => "basics/layers.md",
1718
"Graph passing" => "basics/passgraph.md"],
1819
"Cooperate with Flux layers" => "cooperate.md",
1920
"Abstractions" =>

docs/src/basics/conv.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Graph convolutions
2+
3+
Graph convolution can be classified into spectral-based graph convolution and spatial-based graph convolution. Spectral-based graph convolution, such as `GCNConv` and `ChebConv`, performs operation on features of *whole* graph at one time. Spatial-based graph convolution, such as `GraphConv` and `GATConv`, performs operation on features of *local* graph instead. Message-passing scheme is an abstraction for spatial-based graph convolutional layers. Any spatial-based graph convolutional layer can be implemented under the framework of message-passing scheme.

docs/src/basics/layers.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
Building GNN is as simple as building neural network in Flux. The syntax here is the same as Flux. `Chain` is used to stack layers into a GNN. A simple example is shown here:
44

5-
```
5+
```julia
66
model = Chain(GCNConv(adj_mat, feat=>h1),
77
GCNConv(adj_mat, h1=>h2, relu))
88
```
@@ -21,7 +21,6 @@ When using GNN layers, the general guidelines are:
2121
* If you pass in a ``n \times d`` matrix of node features, and the layer maps node features ``\mathbb{R}^d \rightarrow \mathbb{R}^k`` then the output will be in matrix with dimensions ``n \times k``. The same ostensibly goes for edge features but as of now no layer type supports outputting new edge features.
2222
* If you pass in a `FeaturedGraph`, the output will be also be a `FeaturedGraph` with modified node (and/or edge) features. Add `node_feature` as the following entry in the Flux chain (or simply call `node_feature()` on the output) if you wish to subsequently convert them to matrix form.
2323

24-
25-
## Customize layers
24+
## Create custom layers
2625

2726
Customizing your own GNN layers are the same as customizing layers in Flux. You may want to reference [Flux documentation](https://fluxml.ai/Flux.jl/stable/models/basics/#Building-Layers-1).

src/GeometricFlux.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using LinearAlgebra: Adjoint, norm, Transpose
55
using Reexport
66

77
using CUDA
8+
using ChainRulesCore: @non_differentiable
89
using FillArrays: Fill
910
using Flux
1011
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
@@ -66,7 +67,6 @@ include("layers/pool.jl")
6667
include("models.jl")
6768
include("layers/misc.jl")
6869

69-
include("cuda/msgpass.jl")
7070
include("cuda/conv.jl")
7171

7272
using .Datasets

src/cuda/msgpass.jl

Lines changed: 0 additions & 41 deletions
This file was deleted.

src/layers/conv.jl

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
3737
@functor GCNConv
3838

3939
function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
40-
= normalized_laplacian(fg, eltype(x); selfloop=true)
40+
= Zygote.ignore() do
41+
GraphSignals.normalized_laplacian(fg, eltype(x); selfloop=true)
42+
end
4143
l.σ.(l.weight * x *.+ l.bias)
4244
end
4345

@@ -86,10 +88,12 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
8688
@functor ChebConv
8789

8890
function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
89-
check_num_nodes(fg, X)
91+
GraphSignals.check_num_nodes(fg, X)
9092
@assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
9193

92-
= scaled_laplacian(fg, eltype(X))
94+
= Zygote.ignore() do
95+
GraphSignals.scaled_laplacian(fg, eltype(X))
96+
end
9397

9498
Z_prev = X
9599
Z = X *
@@ -156,12 +160,13 @@ message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
156160
update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)
157161

158162
function (gc::GraphConv)(fg::FeaturedGraph, x::AbstractMatrix)
159-
check_num_nodes(fg, x)
160-
_, x = propagate(gc, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), x, +)
163+
GraphSignals.check_num_nodes(fg, x)
164+
_, x, _ = propagate(gc, graph(fg), edge_feature(fg), x, global_feature(fg), +)
161165
x
162166
end
163167

164168
(l::GraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
169+
# (l::GraphConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
165170

166171
function Base.show(io::IO, l::GraphConv)
167172
in_channel = size(l.weight1, ndims(l.weight1))
@@ -239,21 +244,21 @@ function apply_batch_message(gat::GATConv, i, js, X::AbstractMatrix)
239244
reshape(msgs, (n-1)*gat.heads, :)
240245
end
241246

242-
update_batch_edge(gat::GATConv, adj, E::AbstractMatrix, X::AbstractMatrix, u) = update_batch_edge(gat, adj, X)
247+
function update_batch_edge(gat::GATConv, sg::SparseGraph, E::AbstractMatrix, X::AbstractMatrix, u)
248+
@assert check_self_loops(sg) "a vertex must have self loop (receive a message from itself)."
249+
mapreduce(i -> apply_batch_message(gat, i, neighbors(sg, i), X), hcat, 1:nv(sg))
250+
end
243251

244-
function update_batch_edge(gat::GATConv, adj, X::AbstractMatrix)
245-
n = size(adj, 1)
246-
# a vertex must always receive a message from itself
247-
Zygote.ignore() do
248-
GraphLaplacians.add_self_loop!(adj, n)
252+
function check_self_loops(sg::SparseGraph)
253+
for i in 1:nv(sg)
254+
if !(i in GraphSignals.rowvalview(sg.S, i))
255+
return false
256+
end
249257
end
250-
mapreduce(i -> apply_batch_message(gat, i, adj[i], X), hcat, 1:n)
258+
return true
251259
end
252260

253-
# The same as update function in batch manner
254-
update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u) = update_batch_vertex(gat, M)
255-
256-
function update_batch_vertex(gat::GATConv, M::AbstractMatrix)
261+
function update_batch_vertex(gat::GATConv, M::AbstractMatrix, X::AbstractMatrix, u)
257262
M = M .+ gat.bias
258263
if !gat.concat
259264
N = size(M, 2)
@@ -263,12 +268,13 @@ function update_batch_vertex(gat::GATConv, M::AbstractMatrix)
263268
end
264269

265270
function (gat::GATConv)(fg::FeaturedGraph, X::AbstractMatrix)
266-
check_num_nodes(fg, X)
267-
_, X = propagate(gat, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, +)
268-
X
271+
GraphSignals.check_num_nodes(fg, X)
272+
_, X, _ = propagate(gat, graph(fg), edge_feature(fg), X, global_feature(fg), +)
273+
return X
269274
end
270275

271276
(l::GATConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
277+
# (l::GATConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
272278

273279
function Base.show(io::IO, l::GATConv)
274280
in_channel = size(l.weight, ndims(l.weight))
@@ -319,10 +325,9 @@ update(ggc::GatedGraphConv, m::AbstractVector, x) = m
319325

320326

321327
function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
322-
check_num_nodes(fg, H)
328+
GraphSignals.check_num_nodes(fg, H)
323329
m, n = size(H)
324330
@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
325-
adj = adjacency_list(fg)
326331
if m < ggc.out_ch
327332
Hpad = Zygote.ignore() do
328333
fill!(similar(H, S, ggc.out_ch - m, n), 0)
@@ -331,13 +336,14 @@ function (ggc::GatedGraphConv)(fg::FeaturedGraph, H::AbstractMatrix{S}) where {T
331336
end
332337
for i = 1:ggc.num_layers
333338
M = view(ggc.weight, :, :, i) * H
334-
_, M = propagate(ggc, adj, Fill(0.f0, 0, ne(fg)), M, +)
339+
_, M = propagate(ggc, graph(fg), edge_feature(fg), M, global_feature(fg), +)
335340
H, _ = ggc.gru(H, M)
336341
end
337342
H
338343
end
339344

340345
(l::GatedGraphConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
346+
# (l::GatedGraphConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
341347

342348

343349
function Base.show(io::IO, l::GatedGraphConv)
@@ -374,12 +380,13 @@ message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vc
374380
update(ec::EdgeConv, m::AbstractVector, x) = m
375381

376382
function (ec::EdgeConv)(fg::FeaturedGraph, X::AbstractMatrix)
377-
check_num_nodes(fg, X)
378-
_, X = propagate(ec, adjacency_list(fg), Fill(0.f0, 0, ne(fg)), X, ec.aggr)
383+
GraphSignals.check_num_nodes(fg, X)
384+
_, X, _ = propagate(ec, graph(fg), edge_feature(fg), X, global_feature(fg), ec.aggr)
379385
X
380386
end
381387

382388
(l::EdgeConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
389+
# (l::EdgeConv)(fg::FeaturedGraph) = propagate(l, fg, l.aggr) # edge number check break this
383390

384391
function Base.show(io::IO, l::EdgeConv)
385392
print(io, "EdgeConv(", l.nn)
@@ -425,12 +432,13 @@ update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
425432

426433
function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
427434
gf = graph(fg)
428-
GraphSignals.check_num_node(gf, X)
429-
_, X = propagate(g, adjacency_list(gf), Fill(0.f0, 0, ne(gf)), X, +)
435+
GraphSignals.check_num_nodes(gf, X)
436+
_, X, _ = propagate(g, graph(fg), edge_feature(fg), X, global_feature(fg), +)
430437
X
431438
end
432439

433-
(l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg.graph, nf = l(fg, node_feature(fg)))
440+
(l::GINConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf = l(fg, node_feature(fg)))
441+
# (l::GINConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
434442

435443

436444
"""
@@ -490,15 +498,16 @@ end
490498
update(c::CGConv, m::AbstractVector, x) = x + m
491499

492500
function (c::CGConv)(fg::FeaturedGraph, X::AbstractMatrix, E::AbstractMatrix)
493-
check_num_nodes(fg, X)
494-
check_num_edges(fg, E)
495-
_, Y = propagate(c, adjacency_list(fg), E, X, +)
501+
GraphSignals.check_num_nodes(fg, X)
502+
GraphSignals.check_num_edges(fg, E)
503+
_, Y, _ = propagate(c, graph(fg), E, X, global_feature(fg), +)
496504
Y
497505
end
498506

499-
(l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg, nf=l(fg, node_feature(fg),
500-
edge_feature(fg)),
507+
(l::CGConv)(fg::FeaturedGraph) = FeaturedGraph(fg,
508+
nf=l(fg, node_feature(fg), edge_feature(fg)),
501509
ef=edge_feature(fg))
510+
# (l::CGConv)(fg::FeaturedGraph) = propagate(l, fg, +) # edge number check break this
502511

503512
(l::CGConv)(X::AbstractMatrix, E::AbstractMatrix) = l(l.fg, X, E)
504513

src/layers/gn.jl

Lines changed: 12 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,52 +16,32 @@ abstract type GraphNet <: AbstractGraphLayer end
1616
@inline update_vertex(gn::GraphNet, ē, vi, u) = vi
1717
@inline update_global(gn::GraphNet, ē, v̄, u) = u
1818

19-
@inline function update_batch_edge(gn::GraphNet, adj, E, V, u)
20-
n = size(adj, 1)
21-
edge_idx = edge_index_table(adj)
22-
mapreduce(i -> apply_batch_message(gn, i, adj[i], edge_idx, E, V, u), hcat, 1:n)
23-
end
19+
@inline update_batch_edge(gn::GraphNet, sg::SparseGraph, E, V, u) =
20+
mapreduce(i -> apply_batch_message(gn, sg, i, neighbors(sg, i), E, V, u), hcat, vertices(sg))
2421

25-
@inline apply_batch_message(gn::GraphNet, i, js, edge_idx, E, V, u) =
26-
mapreduce(j -> update_edge(gn, _view(E, edge_idx[(i,j)]), _view(V, i), _view(V, j), u), hcat, js)
22+
@inline apply_batch_message(gn::GraphNet, sg::SparseGraph, i, js, E, V, u) =
23+
mapreduce(j -> update_edge(gn, _view(E, edge_index(sg, i, j)), _view(V, i), _view(V, j), u), hcat, js)
2724

2825
@inline update_batch_vertex(gn::GraphNet, Ē, V, u) =
2926
mapreduce(i -> update_vertex(gn, _view(Ē, i), _view(V, i), u), hcat, 1:size(V,2))
3027

31-
@inline function aggregate_neighbors(gn::GraphNet, aggr, E, accu_edge)
32-
@assert !iszero(accu_edge) "accumulated edge must not be zero."
33-
cluster = generate_cluster(E, accu_edge)
34-
NNlib.scatter(aggr, E, cluster)
35-
end
36-
37-
@inline function aggregate_neighbors(gn::GraphNet, aggr::Nothing, E, accu_edge)
38-
@nospecialize E accu_edge
39-
return nothing
40-
end
28+
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr, E) = neighbor_scatter(aggr, E, sg)
29+
@inline aggregate_neighbors(gn::GraphNet, sg::SparseGraph, aggr::Nothing, @nospecialize E) = nothing
4130

4231
@inline aggregate_edges(gn::GraphNet, aggr, E) = aggregate(aggr, E)
43-
44-
@inline function aggregate_edges(gn::GraphNet, aggr::Nothing, E)
45-
@nospecialize E
46-
return nothing
47-
end
32+
@inline aggregate_edges(gn::GraphNet, aggr::Nothing, @nospecialize E) = nothing
4833

4934
@inline aggregate_vertices(gn::GraphNet, aggr, V) = aggregate(aggr, V)
50-
51-
@inline function aggregate_vertices(gn::GraphNet, aggr::Nothing, V)
52-
@nospecialize V
53-
return nothing
54-
end
35+
@inline aggregate_vertices(gn::GraphNet, aggr::Nothing, @nospecialize V) = nothing
5536

5637
function propagate(gn::GraphNet, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing)
57-
E, V, u = propagate(gn, adjacency_list(fg), fg.ef, fg.nf, fg.gf, naggr, eaggr, vaggr)
38+
E, V, u = propagate(gn, graph(fg), edge_feature(fg), node_feature(fg), global_feature(fg), naggr, eaggr, vaggr)
5839
FeaturedGraph(fg, nf=V, ef=E, gf=u)
5940
end
6041

61-
function propagate(gn::GraphNet, adj::AbstractVector{S}, E::R, V::Q, u::P,
62-
naggr=nothing, eaggr=nothing, vaggr=nothing) where {S<:AbstractVector,R,Q,P}
63-
E = update_batch_edge(gn, adj, E, V, u)
64-
= aggregate_neighbors(gn, naggr, E, accumulated_edges(adj))
42+
function propagate(gn::GraphNet, sg::SparseGraph, E, V, u, naggr=nothing, eaggr=nothing, vaggr=nothing)
43+
E = update_batch_edge(gn, sg, E, V, u)
44+
= aggregate_neighbors(gn, sg, naggr, E)
6545
V = update_batch_vertex(gn, Ē, V, u)
6646
= aggregate_edges(gn, eaggr, E)
6747
= aggregate_vertices(gn, vaggr, V)

src/layers/msgpass.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ specialize this method with custom behavior.
2222
2323
See also [`update`](@ref).
2424
"""
25+
function message end
26+
2527
@inline message(mp::MessagePassing, x_i, x_j, e_ij) = x_j
2628
@inline message(mp::MessagePassing, i::Integer, j::Integer, x_i, x_j, e_ij) = x_j
2729

@@ -45,21 +47,10 @@ specialize this method with custom behavior.
4547
4648
See also [`message`](@ref).
4749
"""
50+
function update end
51+
4852
@inline update(mp::MessagePassing, m, x) = m
4953
@inline update(mp::MessagePassing, i::Integer, m, x) = m
5054

51-
@inline apply_batch_message(mp::MessagePassing, i, js, edge_idx, E::AbstractMatrix, X::AbstractMatrix, u) =
52-
mapreduce(j -> GeometricFlux.message(mp, _view(X, i), _view(X, j), _view(E, edge_idx[(i,j)])), hcat, js)
53-
54-
@inline update_batch_vertex(mp::MessagePassing, M::AbstractMatrix, X::AbstractMatrix, u) =
55-
mapreduce(i -> GeometricFlux.update(mp, _view(M, i), _view(X, i)), hcat, 1:size(X,2))
56-
57-
function propagate(mp::MessagePassing, fg::FeaturedGraph, aggr=+)
58-
E, X = propagate(mp, adjacency_list(fg), fg.ef, fg.nf, aggr)
59-
FeaturedGraph(fg, nf=X, ef=E, gf=Fill(0.f0, 0))
60-
end
61-
62-
function propagate(mp::MessagePassing, adj::AbstractVector{S}, E::R, X::Q, aggr) where {S<:AbstractVector,R,Q}
63-
E, X, u = propagate(mp, adj, E, X, Fill(0.f0, 0), aggr, nothing, nothing)
64-
E, X
65-
end
55+
@inline update_edge(mp::MessagePassing, e, vi, vj, u) = GeometricFlux.message(mp, vi, vj, e)
56+
@inline update_vertex(mp::MessagePassing, ē, vi, u) = GeometricFlux.update(mp, ē, vi)

0 commit comments

Comments
 (0)