
Commit 9d474e4

Merge pull request #119 from yuehhua/refactor

Refactor

2 parents dceaec8 + 37b76ee

8 files changed: +41 −56 lines changed


src/GeometricFlux.jl

Lines changed: 1 addition & 1 deletion

@@ -5,6 +5,7 @@ using SparseArrays: SparseMatrixCSC
 using LinearAlgebra: Adjoint, norm
 using Reexport
 
+using CUDA
 using FillArrays: Fill
 using Flux
 using Flux: glorot_uniform, leakyrelu, GRUCell
@@ -105,7 +106,6 @@ using .Datasets
 
 function __init__()
     @require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin
-        using CUDA
         include("cuda/msgpass.jl")
        include("cuda/conv.jl")
        include("cuda/pool.jl")

src/cuda/msgpass.jl

Lines changed: 7 additions & 14 deletions

@@ -9,16 +9,13 @@ end
 end
 
 @inline function update_batch_edge(mp::T, adj, E::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
+    n = size(adj, 1)
     edge_idx = edge_index_table(adj)
-    E_ = Vector[]
-    for (i, js) = enumerate(adj)
-        for j = js
-            k = edge_idx[(i,j)]
-            m = message(mp, get_feature(X, i), get_feature(X, j), get_feature(E, k))
-            push!(E_, m)
-        end
-    end
-    hcat(E_...)
+    hcat([apply_batch_message(mp, i, adj[i], edge_idx, E, X) for i in 1:n]...)
+end
+
+@inline function apply_batch_message(mp::T, i, js, edge_idx, E::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
+    hcat([message(mp, get_feature(X, i), get_feature(X, j), get_feature(E, edge_idx[(i,j)])) for j = js]...)
 end
 
 @inline function update_batch_vertex(mp::T, M::AbstractMatrix, X::CuMatrix) where {T<:MessagePassing}
@@ -32,10 +29,6 @@ end
 end
 
 @inline function update_batch_vertex(mp::T, M::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
-    X_ = Vector[]
-    for i = 1:size(X,2)
-        x = update(mp, get_feature(M, i), get_feature(X, i))
-        push!(X_, x)
-    end
+    X_ = [update(mp, get_feature(M, i), get_feature(X, i)) for i = 1:size(X,2)]
     hcat(X_...)
 end
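
The rewrite above trades `push!`-based accumulation for array comprehensions. A small CPU-side equivalence check, independent of GeometricFlux (illustrative only):

# Both styles assemble the same 4×3 matrix; the comprehension avoids
# mutation, which is generally friendlier to Zygote and to CuArrays.
cols_loop = Vector[]
for i in 1:3
    push!(cols_loop, fill(Float32(i), 4))
end
cols_comp = [fill(Float32(i), 4) for i in 1:3]
hcat(cols_loop...) == hcat(cols_comp...)    # true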

src/graph/index.jl

Lines changed: 21 additions & 8 deletions

@@ -53,20 +53,31 @@ end
 Zygote.@nograd vertex_pair_table
 
 """
-    edge_index_table(adj[, num_E])
+    edge_index_table(adj[, directed])
 
 Generate a mapping from vertex pair (i, j) to edge index. The edge indices are determined by
 the sorted vertex indices.
 """
-function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}},
-                          num_E=sum(map(length, adj)))
+function edge_index_table(adj::AbstractVector{<:AbstractVector{<:Integer}}, directed::Bool=is_directed(adj))
     table = Dict{Tuple{UInt32,UInt32},UInt64}()
     e = one(UInt64)
-    for (i, js) = enumerate(adj)
-        js = sort(js)
-        for j = js
-            table[(i, j)] = e
-            e += one(UInt64)
+    if directed
+        for (i, js) = enumerate(adj)
+            js = sort(js)
+            for j = js
+                table[(i, j)] = e
+                e += one(UInt64)
+            end
+        end
+    else
+        for (i, js) = enumerate(adj)
+            js = sort(js)
+            js = js[i .≤ js]
+            for j = js
+                table[(i, j)] = e
+                table[(j, i)] = e
+                e += one(UInt64)
+            end
         end
     end
     table
@@ -80,6 +91,8 @@ function edge_index_table(vpair::AbstractVector{<:Tuple})
     table
 end
 
+edge_index_table(fg::FeaturedGraph) = edge_index_table(fg.graph, fg.directed)
+
 Zygote.@nograd edge_index_table
 
 function transform(X::AbstractArray, vpair::AbstractVector{<:Tuple}, num_V)
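
A usage sketch for the new `directed` branch (hypothetical inputs; the flag is passed explicitly so the sketch does not depend on `is_directed`):

adj = [[2, 3], [1], [1]]            # symmetric adjacency list: edges 1–2 and 1–3

t = edge_index_table(adj, true)     # directed: each ordered pair gets its own index
t[(1, 2)] == t[(2, 1)]              # false

t = edge_index_table(adj, false)    # undirected: (i, j) and (j, i) share one index
t[(1, 2)] == t[(2, 1)]              # true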

src/layers/gn.jl

Lines changed: 2 additions & 2 deletions

@@ -10,11 +10,11 @@ abstract type GraphNet end
 @inline function update_batch_edge(gn::T, adj, E, V, u) where {T<:GraphNet}
     n = size(adj, 1)
     edge_idx = edge_index_table(adj)
-    E_ = [_apply_batch_message(gn, i, adj[i], edge_idx, E, V, u) for i in 1:n]
+    E_ = [apply_batch_message(gn, i, adj[i], edge_idx, E, V, u) for i in 1:n]
     hcat(E_...)
 end
 
-@inline function _apply_batch_message(gn::T, i, js, edge_idx, E, V, u) where {T<:GraphNet}
+@inline function apply_batch_message(gn::T, i, js, edge_idx, E, V, u) where {T<:GraphNet}
     E_ = [update_edge(gn, get_feature(E, edge_idx[(i,j)]), get_feature(V, i), get_feature(V, j), u) for j = js]
     hcat(E_...)
 end
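
A note on the rename (my reading, not stated in the PR): dropping the underscore turns `apply_batch_message` into one generic function that both this file and `src/cuda/msgpass.jl` extend, letting multiple dispatch route `CuMatrix` inputs to the GPU methods. A toy illustration of the dispatch pattern (hypothetical names, only the mechanics matter):

f(E::AbstractMatrix) = "generic fallback"     # plays the role of the gn.jl method
f(E::BitMatrix)      = "specialized method"   # plays the role of the CuMatrix method

f(zeros(2, 2))   # "generic fallback"
f(trues(2, 2))   # "specialized method": the most specific signature wins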

src/utils.jl

Lines changed: 3 additions & 29 deletions

@@ -1,30 +1,3 @@
-## Indexing
-
-function range_indecies(idx::Tuple)
-    x = Vector{Any}(undef, length(idx))
-    for (i,n) in enumerate(idx)
-        x[i] = 1:n
-    end
-    x
-end
-
-replace_last_index!(idx::Vector, x) = (idx[end] = x; idx)
-
-function assign!(A::AbstractArray, B::AbstractArray{T,N}; last_dim=1:size(B,N)) where {T,N}
-    A_dims, B_dims = size(A), size(B)
-    @assert A_dims[1:end-1] == B_dims[1:end-1] "Inconsistent dimensions with $(A_dims[1:end-1]) and $(B_dims[1:end-1])"
-    A_dims = replace_last_index!(range_indecies(A_dims), last_dim)
-    B_dims = range_indecies(B_dims)
-    A_idxs = CartesianIndices(Tuple(A_dims))
-    B_idxs = CartesianIndices(Tuple(B_dims))
-    for (Aidx, Bidx) = zip(A_idxs, B_idxs)
-        A[Aidx] = B[Bidx]
-    end
-    A
-end
-
-
-
 ## Top-k pooling
 
 function topk_index(y::AbstractVector, k::Integer)
@@ -38,8 +11,9 @@ topk_index(y::Adjoint, k::Integer) = topk_index(y', k)
 
 ## Get feature with defaults
 
-get_feature(::Nothing, i::Integer) = zeros(0)
-get_feature(A::AbstractMatrix, i::Integer) = (i ≤ size(A,2)) ? view(A, :, i) : zeros(0)
+get_feature(::Nothing, i) = nothing
+get_feature(A::Fill{T,2,Axes}, i::Integer) where {T,Axes} = view(A, :, 1)
+get_feature(A::AbstractMatrix, i::Integer) = view(A, :, i)
 
 """
     bypass_graph(nf_func, ef_func, gf_func)
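
A quick sketch of the new `get_feature` behavior (illustrative; `get_feature` is an internal helper, hence the explicit import):

using FillArrays: Fill
using GeometricFlux: get_feature

E = Fill(0.0f0, 3, 10)                   # every edge shares one constant feature column
get_feature(E, 7) == zeros(Float32, 3)   # true: any index returns the shared column
get_feature(nothing, 7) === nothing      # absent features now yield nothing, not zeros(0)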

test/cuda/msgpass.jl

Lines changed: 3 additions & 1 deletion

@@ -1,6 +1,7 @@
 in_channel = 10
 out_channel = 5
 N = 6
+T = Float32
 adj = [0. 1. 0. 0. 0. 0.;
        1. 0. 0. 1. 1. 1.;
        0. 0. 0. 0. 0. 1.;
@@ -18,8 +19,9 @@ NewCudaLayer(m, n) = NewCudaLayer(randn(m,n))
 GeometricFlux.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
 GeometricFlux.update(::NewCudaLayer, m, x) = m
 
-X = rand(Float32, in_channel, N) |> gpu
+X = rand(T, in_channel, N) |> gpu
 fg = FeaturedGraph(adj, X)
+fg.ef = Fill(zero(T), 0, 2num_E)
 l = NewCudaLayer(out_channel, in_channel) |> gpu
 
 @testset "cuda/msgpass" begin

test/layers/msgpass.jl

Lines changed: 3 additions & 1 deletion

@@ -2,6 +2,7 @@ in_channel = 10
 out_channel = 5
 num_V = 6
 num_E = 7
+T = Float32
 adj = [0. 1. 0. 0. 0. 0.;
        1. 0. 0. 1. 1. 1.;
        0. 0. 0. 0. 0. 1.;
@@ -16,8 +17,9 @@ NewLayer(m, n) = NewLayer(randn(m,n))
 
 (l::NewLayer)(fg) = propagate(l, fg, :add)
 
-X = Array(reshape(1:num_V*in_channel, in_channel, num_V))
+X = Array{T}(reshape(1:num_V*in_channel, in_channel, num_V))
 fg = FeaturedGraph(adj, X)
+fg.ef = Fill(zero(T), 0, 2num_E)
 
 l = NewLayer(out_channel, in_channel)
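
On the new `fg.ef` line: the symmetric adjacency matrix presumably stores each undirected edge in both directions, so per-edge containers are sized `2num_E`; with no edge features in this test, a `0 × 2num_E` `Fill` is a cheap all-zero placeholder. A standalone sketch:

using FillArrays: Fill

num_E = 7                      # undirected edges in the test graph above
ef = Fill(0.0f0, 0, 2num_E)    # zero-length feature column for all 14 directed edges
size(ef)                       # (0, 14)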

test/runtests.jl

Lines changed: 1 addition & 0 deletions

@@ -2,6 +2,7 @@ using GeometricFlux
 using GeometricFlux.Datasets
 using Flux
 using Flux: @functor
+using FillArrays
 using GraphSignals
 using StaticArrays: @MMatrix, @MArray
 using LightGraphs: SimpleGraph, SimpleDiGraph, add_edge!, nv, ne
