|
21 | 21 | @testset "without aggregation" begin
|
22 | 22 | function (l::NewGNLayer)(fg::AbstractFeaturedGraph)
|
23 | 23 | nf = node_feature(fg)
|
24 |
| - ef = edge_feature(fg) |
25 | 24 | GraphSignals.check_num_nodes(fg, nf)
|
26 |
| - GraphSignals.check_num_edges(fg, ef) |
27 |
| - return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), nothing, nothing, nothing) |
| 25 | + return GeometricFlux.propagate(l, graph(fg), nothing, nf, nothing, nothing, nothing, nothing) |
28 | 26 | end
|
29 | 27 |
|
30 | 28 | fg = FeaturedGraph(adj, nf=nf)
|
31 | 29 | l = NewGNLayer()
|
32 | 30 | ef_, nf_, gf_ = l(fg)
|
33 | 31 |
|
34 | 32 | @test nf_ == nf
|
35 |
| - @test size(ef_) == (0, 2E) |
36 |
| - @test size(gf_) == (0,) |
| 33 | + @test isnothing(ef_) |
| 34 | + @test isnothing(gf_) |
37 | 35 | end
|
38 | 36 |
|
39 | 37 | @testset "with neighbor aggregation" begin
|
|
42 | 40 | ef = edge_feature(fg)
|
43 | 41 | GraphSignals.check_num_nodes(fg, nf)
|
44 | 42 | GraphSignals.check_num_edges(fg, ef)
|
45 |
| - return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, nothing, nothing) |
| 43 | + return GeometricFlux.propagate(l, graph(fg), ef, nf, nothing, +, nothing, nothing) |
46 | 44 | end
|
47 | 45 |
|
48 | 46 | fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=zeros(0))
|
49 | 47 | l = NewGNLayer()
|
50 | 48 | ef_, nf_, gf_ = l(fg)
|
51 | 49 |
|
52 | 50 | @test size(nf_) == (in_channel, V)
|
53 |
| - @test size(ef_) == (0, 2E) |
54 |
| - @test size(gf_) == (0,) |
| 51 | + @test size(ef_) == (in_channel, 2E) |
| 52 | + @test isnothing(gf_) |
55 | 53 | end
|
56 | 54 |
|
57 | 55 | GeometricFlux.update_edge(l::NewGNLayer, e, vi, vj, u) = similar(e, out_channel, size(e)[2:end]...)
|
|
61 | 59 | ef = edge_feature(fg)
|
62 | 60 | GraphSignals.check_num_nodes(fg, nf)
|
63 | 61 | GraphSignals.check_num_edges(fg, ef)
|
64 |
| - return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, nothing, nothing) |
| 62 | + return GeometricFlux.propagate(l, graph(fg), ef, nf, nothing, +, nothing, nothing) |
65 | 63 | end
|
66 | 64 |
|
67 | 65 | fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=zeros(0))
|
|
70 | 68 |
|
71 | 69 | @test size(nf_) == (in_channel, V)
|
72 | 70 | @test size(ef_) == (out_channel, 2E)
|
73 |
| - @test size(gf_) == (0,) |
| 71 | + @test isnothing(gf_) |
74 | 72 | end
|
75 | 73 |
|
76 | 74 | GeometricFlux.update_vertex(l::NewGNLayer, ē, vi, u) = similar(vi, out_channel, size(vi)[2:end]...)
|
77 | 75 | @testset "update edge/vertex with all aggregation" begin
|
78 | 76 | function (l::NewGNLayer)(fg::AbstractFeaturedGraph)
|
79 | 77 | nf = node_feature(fg)
|
80 | 78 | ef = edge_feature(fg)
|
| 79 | + gf = global_feature(fg) |
81 | 80 | GraphSignals.check_num_nodes(fg, nf)
|
82 | 81 | GraphSignals.check_num_edges(fg, ef)
|
83 |
| - return GeometricFlux.propagate(l, graph(fg), ef, nf, global_feature(fg), +, +, +) |
| 82 | + return GeometricFlux.propagate(l, graph(fg), ef, nf, gf, +, +, +) |
84 | 83 | end
|
85 | 84 |
|
86 | 85 | fg = FeaturedGraph(adj, nf=nf, ef=ef, gf=gf)
|
|
0 commit comments