|
9 | 9 | end
|
10 | 10 |
|
11 | 11 | @inline function update_batch_edge(mp::T, adj, E::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
|
| 12 | + n = size(adj, 1) |
12 | 13 | edge_idx = edge_index_table(adj)
|
13 |
| - E_ = Vector[] |
14 |
| - for (i, js) = enumerate(adj) |
15 |
| - for j = js |
16 |
| - k = edge_idx[(i,j)] |
17 |
| - m = message(mp, get_feature(X, i), get_feature(X, j), get_feature(E, k)) |
18 |
| - push!(E_, m) |
19 |
| - end |
20 |
| - end |
21 |
| - hcat(E_...) |
| 14 | + hcat([apply_batch_message(mp, i, adj[i], edge_idx, E, X) for i in 1:n]...) |
| 15 | +end |
| 16 | + |
| 17 | +@inline function apply_batch_message(mp::T, i, js, edge_idx, E::CuMatrix, X::CuMatrix) where {T<:MessagePassing} |
| 18 | + hcat([message(mp, get_feature(X, i), get_feature(X, j), get_feature(E, edge_idx[(i,j)])) for j = js]...) |
22 | 19 | end
|
23 | 20 |
|
24 | 21 | @inline function update_batch_vertex(mp::T, M::AbstractMatrix, X::CuMatrix) where {T<:MessagePassing}
|
|
32 | 29 | end
|
33 | 30 |
|
34 | 31 | @inline function update_batch_vertex(mp::T, M::CuMatrix, X::CuMatrix) where {T<:MessagePassing}
|
35 |
| - X_ = Vector[] |
36 |
| - for i = 1:size(X,2) |
37 |
| - x = update(mp, get_feature(M, i), get_feature(X, i)) |
38 |
| - push!(X_, x) |
39 |
| - end |
| 32 | + X_ = [update(mp, get_feature(M, i), get_feature(X, i)) for i = 1:size(X,2)] |
40 | 33 | hcat(X_...)
|
41 | 34 | end
|
0 commit comments