Skip to content

Commit 09ddd17

Browse files
committed
Refactor CUDA message-passing
1 parent e0282b0 commit 09ddd17

File tree

2 files changed

+10
-15
lines changed

2 files changed

+10
-15
lines changed

src/cuda/msgpass.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,13 @@ end
99
end
1010

1111
@inline function update_batch_edge(mp::T, adj, E::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
12+
n = size(adj, 1)
1213
edge_idx = edge_index_table(adj)
13-
E_ = Vector[]
14-
for (i, js) = enumerate(adj)
15-
for j = js
16-
k = edge_idx[(i,j)]
17-
m = message(mp, get_feature(X, i), get_feature(X, j), get_feature(E, k))
18-
push!(E_, m)
19-
end
20-
end
21-
hcat(E_...)
14+
hcat([apply_batch_message(mp, i, adj[i], edge_idx, E, X) for i in 1:n]...)
15+
end
16+
17+
@inline function apply_batch_message(mp::T, i, js, edge_idx, E::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
18+
hcat([message(mp, get_feature(X, i), get_feature(X, j), get_feature(E, edge_idx[(i,j)])) for j = js]...)
2219
end
2320

2421
@inline function update_batch_vertex(mp::T, M::AbstractMatrix, X::CuMatrix) where {T<:MessagePassing}
@@ -32,10 +29,6 @@ end
3229
end
3330

3431
@inline function update_batch_vertex(mp::T, M::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
35-
X_ = Vector[]
36-
for i = 1:size(X,2)
37-
x = update(mp, get_feature(M, i), get_feature(X, i))
38-
push!(X_, x)
39-
end
32+
X_ = [update(mp, get_feature(M, i), get_feature(X, i)) for i = 1:size(X,2)]
4033
hcat(X_...)
4134
end

test/cuda/msgpass.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
in_channel = 10
22
out_channel = 5
33
N = 6
4+
T = Float32
45
adj = [0. 1. 0. 0. 0. 0.;
56
1. 0. 0. 1. 1. 1.;
67
0. 0. 0. 0. 0. 1.;
@@ -18,8 +19,9 @@ NewCudaLayer(m, n) = NewCudaLayer(randn(m,n))
1819
GeometricFlux.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
1920
GeometricFlux.update(::NewCudaLayer, m, x) = m
2021

21-
X = rand(Float32, in_channel, N) |> gpu
22+
X = rand(T, in_channel, N) |> gpu
2223
fg = FeaturedGraph(adj, X)
24+
fg.ef = Fill(zero(T), 0, 2num_E)
2325
l = NewCudaLayer(out_channel, in_channel) |> gpu
2426

2527
@testset "cuda/msgpass" begin

0 commit comments

Comments
 (0)