Skip to content

Commit 7e19a7b

Browse files
authored
Merge pull request #69 from yuehhua/develop
Graph API enhancements
2 parents 78fe005 + 3ffe4f6 commit 7e19a7b

File tree

12 files changed

+73
-27
lines changed

12 files changed

+73
-27
lines changed

.gitlab-ci.yml

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,36 @@
11
variables:
2-
CI_IMAGE_TAG: 'cuda'
2+
CI_APT_INSTALL: 'libgomp1'
33
JULIA_NUM_THREADS: '4'
4+
NVIDIA_VISIBLE_DEVICES: 'all'
5+
NVIDIA_DRIVER_CAPABILITIES: 'compute,utility'
6+
CI_THOROUGH: 'true'
7+
JULIA_CUDA_VERSION: '10.2'
8+
JULIA_CUDA_USE_BINARYBUILDER: 'true'
49

510
include:
611
- 'https://raw.githubusercontent.com/JuliaGPU/gitlab-ci/master/templates/v6.yml'
712

8-
image: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04
13+
image: ubuntu:bionic
914

1015
test:v1.4:
1116
extends:
1217
- .julia:1.4
1318
- .test
1419
variables:
1520
CI_VERSION_TAG: 'v1.4'
21+
tags:
22+
- nvidia
23+
- cuda_10.2
24+
25+
test:v1.5:
26+
extends:
27+
- .julia:1.5
28+
- .test
29+
variables:
30+
CI_VERSION_TAG: 'v1.5'
31+
tags:
32+
- nvidia
33+
- cuda_10.2
1634

1735
# test:dev:
1836
# extends:

src/GeometricFlux.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ export
8585
pool,
8686

8787
# graph/index
88-
neighbors,
88+
adjacency_list,
8989
generate_cluster,
9090

9191
# graph/featuredgraphs
@@ -102,9 +102,6 @@ export
102102
has_global_feature,
103103
nv,
104104

105-
# graph/simplegraphs
106-
adjlist,
107-
108105
# utils
109106
gather,
110107
topk_index
@@ -115,12 +112,12 @@ include("operations/scatter.jl")
115112
include("operations/pool.jl")
116113
include("operations/linalg.jl")
117114

118-
include("utils.jl")
119-
120115
include("graph/index.jl")
121116
include("graph/featuredgraphs.jl")
122117
include("graph/linalg.jl")
123118

119+
include("utils.jl")
120+
124121
include("layers/gn.jl")
125122
include("layers/msgpass.jl")
126123

src/graph/featuredgraphs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ has_global_feature(::NullGraph) = false
8181
has_global_feature(fg::FeaturedGraph) = fg.gf[] != zeros(0)
8282

8383
"""
84-
neighbors(::AbstractFeaturedGraph)
84+
adjacency_list(::AbstractFeaturedGraph)
8585
8686
Get adjacency list of graph.
8787
"""
88-
neighbors(::NullGraph) = [zeros(0)]
89-
neighbors(fg::FeaturedGraph) = neighbors(fg.graph[])
88+
adjacency_list(::NullGraph) = [zeros(0)]
89+
adjacency_list(fg::FeaturedGraph) = adjacency_list(fg.graph[])
9090

9191
"""
9292
nv(::AbstractFeaturedGraph)

src/graph/index.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
"""
2-
neighbors(adj)
2+
adjacency_list(adj)
33
44
Transform a adjacency matrix into a adjacency list.
55
"""
6-
function neighbors(adj::AbstractMatrix{T}) where {T}
6+
function adjacency_list(adj::AbstractMatrix{T}) where {T}
77
n = size(adj,1)
88
@assert n == size(adj,2) "adjacency matrix is not a square matrix."
99
A = (adj .!= zero(T))
@@ -15,7 +15,9 @@ function neighbors(adj::AbstractMatrix{T}) where {T}
1515
return ne
1616
end
1717

18-
neighbors(adj::AbstractVector{<:AbstractVector{<:Integer}}) = adj
18+
adjacency_list(adj::AbstractVector{<:AbstractVector{<:Integer}}) = adj
19+
20+
Zygote.@nograd adjacency_list
1921

2022
"""
2123
accumulated_edges(adj[, num_V])
@@ -30,6 +32,8 @@ function accumulated_edges(adj::AbstractVector{<:AbstractVector{<:Integer}},
3032
y
3133
end
3234

35+
Zygote.@nograd accumulated_edges
36+
3337
Zygote.@nograd function generate_cluster(M::AbstractArray{T,N}, accu_edge, V, E) where {T,N}
3438
cluster = similar(M, Int, E)
3539
@inbounds for i = 1:V
@@ -68,6 +72,8 @@ function vertex_pair_table(eidx::Dict)
6872
table
6973
end
7074

75+
Zygote.@nograd vertex_pair_table
76+
7177
"""
7278
edge_index_table(adj[, num_E])
7379
@@ -96,6 +102,8 @@ function edge_index_table(vpair::AbstractVector{<:Tuple})
96102
table
97103
end
98104

105+
Zygote.@nograd edge_index_table
106+
99107
function transform(X::AbstractArray, vpair::AbstractVector{<:Tuple}, num_V)
100108
dims = size(X)[1:end-1]..., num_V, num_V
101109
Y = similar(X, dims)

src/graph/simplegraphs.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1-
using LightGraphs: AbstractSimpleGraph, nv, adjacency_matrix, inneighbors,
1+
using LightGraphs: AbstractSimpleGraph, nv, adjacency_matrix, inneighbors, outneighbors,
22
all_neighbors
33

4+
function adjacency_list(g::AbstractSimpleGraph)
5+
N = nv(g)
6+
Vector{Int}[outneighbors(g, i) for i = 1:N]
7+
end
8+
49
## Convolution layers accepting AbstractSimpleGraph
510

611
function GCNConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}, σ = identity;

src/graph/weightedgraphs.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
using SimpleWeightedGraphs: AbstractSimpleWeightedGraph, nv
1+
using SimpleWeightedGraphs: AbstractSimpleWeightedGraph, nv, outneighbors
2+
3+
function adjacency_list(g::AbstractSimpleWeightedGraph)
4+
N = nv(g)
5+
Vector{Int}[outneighbors(g, i) for i = 1:N]
6+
end
27

38
## Convolution layers accepting AbstractSimpleWeightedGraph
49

src/layers/conv.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ function GraphConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, aggr=:add
198198
w1 = T.(init(ch[2], ch[1]))
199199
w2 = T.(init(ch[2], ch[1]))
200200
b = bias ? T.(init(ch[2])) : zeros(T, ch[2])
201-
fg = FeaturedGraph(neighbors(adj))
201+
fg = FeaturedGraph(adjacency_list(adj))
202202
GraphConv(fg, w1, w2, b, aggr)
203203
end
204204

@@ -264,7 +264,7 @@ function GATConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}; heads::Inte
264264
w = T.(init(ch[2]*heads, ch[1]))
265265
b = bias ? T.(init(ch[2]*heads)) : zeros(T, ch[2]*heads)
266266
a = T.(init(2*ch[2], heads, 1))
267-
fg = FeaturedGraph(neighbors(adj))
267+
fg = FeaturedGraph(adjacency_list(adj))
268268
GATConv(fg, w, b, a, negative_slope, ch, heads, concat)
269269
end
270270

@@ -349,7 +349,7 @@ function GatedGraphConv(adj::AbstractMatrix, out_ch::Integer, num_layers::Intege
349349
aggr=:add, init=glorot_uniform, T::DataType=Float32)
350350
w = T.(init(out_ch, out_ch, num_layers))
351351
gru = GRUCell(out_ch, out_ch)
352-
fg = FeaturedGraph(neighbors(adj))
352+
fg = FeaturedGraph(adjacency_list(adj))
353353
GatedGraphConv(fg, w, gru, out_ch, num_layers, aggr)
354354
end
355355

@@ -415,7 +415,7 @@ struct EdgeConv{V<:AbstractFeaturedGraph} <: MessagePassing
415415
end
416416

417417
function EdgeConv(adj::AbstractMatrix, nn; aggr::Symbol=:max)
418-
fg = FeaturedGraph(neighbors(adj))
418+
fg = FeaturedGraph(adjacency_list(adj))
419419
EdgeConv(fg, nn, aggr)
420420
end
421421

src/layers/gn.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
end
6363

6464
function propagate(gn::T, fg::FeaturedGraph, naggr=nothing, eaggr=nothing, vaggr=nothing) where {T<:GraphNet}
65-
adj = neighbors(fg)
65+
adj = adjacency_list(fg)
6666
num_V = nv(fg)
6767
accu_edge = accumulated_edges(adj)
6868
num_E = accu_edge[end]

src/layers/msgpass.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ end
5656
end
5757

5858
function propagate(mp::T, fg::FeaturedGraph, aggr::Symbol=:add) where {T<:MessagePassing}
59-
adj = neighbors(fg)
59+
adj = adjacency_list(fg)
6060
num_V = nv(fg)
6161
accu_edge = accumulated_edges(adj)
6262
num_E = accu_edge[end]

src/operations/linalg.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
Zygote.@nograd issymmetric
44

5-
function adjacency_matrix(adj::AbstractMatrix, T::DataType=eltype(adj))
5+
Zygote.@nograd function adjacency_matrix(adj::AbstractMatrix, T::DataType=eltype(adj))
66
m, n = size(adj)
77
(m == n) || throw(DimensionMismatch("adjacency matrix is not a square matrix: ($m, $n)"))
88
T.(adj)
@@ -33,6 +33,11 @@ julia> GeometricFlux.degrees(m)
3333
```
3434
"""
3535
function degrees(adj::AbstractMatrix, T::DataType=eltype(adj); dir::Symbol=:out)
36+
_degrees(T.(adj), dir)
37+
end
38+
39+
# nograd can only used without keyword arguments
40+
Zygote.@nograd function _degrees(adj::AbstractMatrix, dir::Symbol=:out)
3641
if issymmetric(adj)
3742
d = vec(sum(adj, dims=1))
3843
else
@@ -90,7 +95,7 @@ The values other than diagonal are zeros.
9095
- `dir`: direction of degree; should be `:in`, `:out`, or `:both` (optional).
9196
"""
9297
function inv_sqrt_degree_matrix(adj::AbstractMatrix, T::DataType=eltype(adj); dir::Symbol=:out)
93-
d = inv.(sqrt.(degrees(adj, T, dir=dir)))
98+
d = inv.(sqrt.(degrees(adj, T, dir=dir)))
9499
return Diagonal(d)
95100
end
96101

0 commit comments

Comments
 (0)