Skip to content

Commit 6814650

Browse files
authored
Merge pull request #108 from yuehhua/develop
Refactor
2 parents e4f7dd7 + cc3f9de commit 6814650

File tree

4 files changed

+27
-28
lines changed

4 files changed

+27
-28
lines changed

src/layers/conv.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@ struct GCNConv{T,F,S<:AbstractFeaturedGraph}
2727
end
2828

2929
function GCNConv(ch::Pair{<:Integer,<:Integer}, σ = identity;
30-
init=glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
30+
init=glorot_uniform, T::DataType=Float32, bias::Bool=true)
3131
b = bias ? T.(init(ch[2])) : zeros(T, ch[2])
32-
fg = cache ? FeaturedGraph() : NullGraph()
32+
fg = NullGraph()
3333
GCNConv(T.(init(ch[2], ch[1])), b, σ, fg)
3434
end
3535

3636
function GCNConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, σ = identity;
37-
init=glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
37+
init=glorot_uniform, T::DataType=Float32, bias::Bool=true)
3838
b = bias ? T.(init(ch[2])) : zeros(T, ch[2])
39-
fg = cache ? FeaturedGraph(adj) : NullGraph()
39+
fg = FeaturedGraph(adj)
4040
GCNConv(T.(init(ch[2], ch[1])), b, σ, fg)
4141
end
4242

@@ -56,8 +56,10 @@ end
5656

5757
function (g::GCNConv)(fg::FeaturedGraph)
5858
X = node_feature(fg)
59-
A = adjacency_matrix(fg)
60-
g.fg isa NullGraph || (g.fg.graph = A)
59+
A = adjacency_matrix(fg) # TODO: choose graph from g or fg
60+
Zygote.ignore() do
61+
g.fg isa NullGraph || (g.fg.graph = A)
62+
end
6163
X_ = g(A, X)
6264
FeaturedGraph(A, X_)
6365
end
@@ -97,16 +99,16 @@ struct ChebConv{T,S<:AbstractFeaturedGraph}
9799
end
98100

99101
function ChebConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}, k::Integer;
100-
init = glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
102+
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
101103
b = bias ? init(ch[2]) : zeros(T, ch[2])
102-
fg = cache ? FeaturedGraph(adj) : NullGraph()
104+
fg = FeaturedGraph(adj)
103105
ChebConv(init(ch[2], ch[1], k), b, fg, k, ch[1], ch[2])
104106
end
105107

106108
function ChebConv(ch::Pair{<:Integer,<:Integer}, k::Integer;
107-
init = glorot_uniform, T::DataType=Float32, bias::Bool=true, cache::Bool=true)
109+
init = glorot_uniform, T::DataType=Float32, bias::Bool=true)
108110
b = bias ? init(ch[2]) : zeros(T, ch[2])
109-
fg = cache ? FeaturedGraph() : NullGraph()
111+
fg = NullGraph()
110112
ChebConv(init(ch[2], ch[1], k), b, fg, k, ch[1], ch[2])
111113
end
112114

@@ -138,7 +140,9 @@ end
138140
function (c::ChebConv)(fg::FeaturedGraph)
139141
@assert has_graph(fg) "A given FeaturedGraph must contain a graph."
140142
g = graph(fg)
141-
c.fg isa NullGraph || (c.fg.graph = g)
143+
Zygote.ignore() do
144+
c.fg isa NullGraph || (c.fg.graph = g)
145+
end
142146
X = node_feature(fg)
143147
= scaled_laplacian(adjacency_matrix(fg))
144148
= convert(typeof(X), L̃)

src/layers/gn.jl

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,19 @@ abstract type GraphNet end
88
@inline update_global(gn::T, ē, v̄, u) where {T<:GraphNet} = u
99

1010
@inline function update_batch_edge(gn::T, adj, E, V, u) where {T<:GraphNet}
11+
n = size(adj, 1)
1112
edge_idx = edge_index_table(adj)
12-
E_ = Vector[]
13-
for (i, js) = enumerate(adj)
14-
for j = js
15-
k = edge_idx[(i,j)]
16-
e = update_edge(gn, get_feature(E, k), get_feature(V, i), get_feature(V, j), u)
17-
push!(E_, e)
18-
end
19-
end
13+
E_ = [_apply_batch_message(gn, i, adj[i], edge_idx, E, V, u) for i in 1:n]
14+
hcat(E_...)
15+
end
16+
17+
@inline function _apply_batch_message(gn::T, i, js, edge_idx, E, V, u) where {T<:GraphNet}
18+
E_ = [update_edge(gn, get_feature(E, edge_idx[(i,j)]), get_feature(V, i), get_feature(V, j), u) for j = js]
2019
hcat(E_...)
2120
end
2221

2322
@inline function update_batch_vertex(gn::T, Ē, V, u) where {T<:GraphNet}
24-
V_ = Vector[]
25-
for i = 1:size(V,2)
26-
v = update_vertex(gn, get_feature(Ē, i), get_feature(V, i), u)
27-
push!(V_, v)
28-
end
23+
V_ = [update_vertex(gn, get_feature(Ē, i), get_feature(V, i), u) for i = 1:size(V,2)]
2924
hcat(V_...)
3025
end
3126

src/models.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ end
9696

9797
function VariationalEncoder(nn, h_dim::Integer, z_dim::Integer)
9898
VariationalEncoder(nn,
99-
GCNConv(h_dim=>z_dim; cache=false),
100-
GCNConv(h_dim=>z_dim; cache=false),
99+
GCNConv(h_dim=>z_dim),
100+
GCNConv(h_dim=>z_dim),
101101
z_dim)
102102
end
103103

test/layers/conv.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ adj = [0. 1. 0. 1.;
2626
end
2727

2828
@testset "layer without graph" begin
29-
gc = GCNConv(in_channel=>out_channel, cache=false)
29+
gc = GCNConv(in_channel=>out_channel)
3030
@test size(gc.weight) == (out_channel, in_channel)
3131
@test size(gc.bias) == (out_channel,)
3232
@test !has_graph(gc.fg)
@@ -62,7 +62,7 @@ adj = [0. 1. 0. 1.;
6262
end
6363

6464
@testset "layer without graph" begin
65-
cc = ChebConv(in_channel=>out_channel, k, cache=false)
65+
cc = ChebConv(in_channel=>out_channel, k)
6666
@test size(cc.weight) == (out_channel, in_channel, k)
6767
@test size(cc.bias) == (out_channel,)
6868
@test !has_graph(cc.fg)

0 commit comments

Comments
 (0)