Skip to content

Commit 2801a05

Browse files
authored
Merge pull request #333 from FluxML/compact
Support GraphSignals to 0.7
2 parents cb1642c + f46cc74 commit 2801a05

File tree

3 files changed

+14
-15
lines changed

3 files changed

+14
-15
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ DataStructures = "0.18"
3131
FillArrays = "0.13"
3232
Flux = "0.12 - 0.13"
3333
GraphMLDatasets = "0.1"
34-
GraphSignals = "0.6"
34+
GraphSignals = "0.7"
3535
Graphs = "1"
3636
NNlib = "0.8"
3737
NNlibCUDA = "0.2"

test/layers/gn.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,17 @@
2121
@testset "without aggregation" begin
2222
function (l::NewGNLayer)(fg::AbstractFeaturedGraph)
2323
nf = node_feature(fg)
24-
ef = edge_feature(fg)
2524
GraphSignals.check_num_nodes(fg, nf)
26-
GraphSignals.check_num_edges(fg, ef)
27-
return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), nothing, nothing, nothing)
25+
return GeometricFlux.propagate(l, graph(fg), nothing, nf, nothing, nothing, nothing, nothing)
2826
end
2927

3028
fg = FeaturedGraph(adj, nf=nf)
3129
l = NewGNLayer()
3230
ef_, nf_, gf_ = l(fg)
3331

3432
@test nf_ == nf
35-
@test size(ef_) == (0, 2E)
36-
@test size(gf_) == (0,)
33+
@test isnothing(ef_)
34+
@test isnothing(gf_)
3735
end
3836

3937
@testset "with neighbor aggregation" begin
@@ -42,16 +40,16 @@
4240
ef = edge_feature(fg)
4341
GraphSignals.check_num_nodes(fg, nf)
4442
GraphSignals.check_num_edges(fg, ef)
45-
return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, nothing, nothing)
43+
return GeometricFlux.propagate(l, graph(fg), ef, nf, nothing, +, nothing, nothing)
4644
end
4745

4846
fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=zeros(0))
4947
l = NewGNLayer()
5048
ef_, nf_, gf_ = l(fg)
5149

5250
@test size(nf_) == (in_channel, V)
53-
@test size(ef_) == (0, 2E)
54-
@test size(gf_) == (0,)
51+
@test size(ef_) == (in_channel, 2E)
52+
@test isnothing(gf_)
5553
end
5654

5755
GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = similar(e, out_channel, size(e)[2:end]...)
@@ -61,7 +59,7 @@
6159
ef = edge_feature(fg)
6260
GraphSignals.check_num_nodes(fg, nf)
6361
GraphSignals.check_num_edges(fg, ef)
64-
return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, nothing, nothing)
62+
return GeometricFlux.propagate(l, graph(fg), ef, nf, nothing, +, nothing, nothing)
6563
end
6664

6765
fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=zeros(0))
@@ -70,17 +68,18 @@
7068

7169
@test size(nf_) == (in_channel, V)
7270
@test size(ef_) == (out_channel, 2E)
73-
@test size(gf_) == (0,)
71+
@test isnothing(gf_)
7472
end
7573

7674
GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = similar(vi, out_channel, size(vi)[2:end]...)
7775
@testset "update edge/vertex with all aggregation" begin
7876
function (l::NewGNLayer)(fg::AbstractFeaturedGraph)
7977
nf = node_feature(fg)
8078
ef = edge_feature(fg)
79+
gf = global_feature(fg)
8180
GraphSignals.check_num_nodes(fg, nf)
8281
GraphSignals.check_num_edges(fg, ef)
83-
return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, +, +)
82+
return GeometricFlux.propagate(l, graph(fg), ef, nf, gf, +, +, +)
8483
end
8584

8685
fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=gf)

test/layers/msgpass.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
@test GraphSignals.adjacency_matrix(fg_) == adj
3838
@test size(node_feature(fg_)) == (in_channel, num_V)
3939
@test size(edge_feature(fg_)) == (0, num_E)
40-
@test size(global_feature(fg_)) == (0,)
40+
@test !has_global_feature(fg_)
4141
end
4242

4343
GeometricFlux.message(l::NewLayer, x_i, x_j::AbstractMatrix, e_ij) = l.weight * x_j
@@ -47,7 +47,7 @@
4747
@test GraphSignals.adjacency_matrix(fg_) == adj
4848
@test size(node_feature(fg_)) == (out_channel, num_V)
4949
@test size(edge_feature(fg_)) == (0, num_E)
50-
@test size(global_feature(fg_)) == (0,)
50+
@test !has_global_feature(fg_)
5151
end
5252

5353
GeometricFlux.update(l::NewLayer, m::AbstractMatrix, x) = l.weight * x + m
@@ -57,6 +57,6 @@
5757
@test GraphSignals.adjacency_matrix(fg_) == adj
5858
@test size(node_feature(fg_)) == (out_channel, num_V)
5959
@test size(edge_feature(fg_)) == (0, num_E)
60-
@test size(global_feature(fg_)) == (0,)
60+
@test !has_global_feature(fg_)
6161
end
6262
end

0 commit comments

Comments
 (0)