@@ -36,6 +36,8 @@ GCNConv(ch::Pair{Int,Int}, σ = identity; kwargs...) =
 
 @functor GCNConv
 
+Flux.trainable(l::GCNConv) = (l.weight, l.bias)
+
 function (l::GCNConv)(fg::FeaturedGraph, x::AbstractMatrix)
     Ã = Zygote.ignore() do
         GraphSignals.normalized_adjacency_matrix(fg, eltype(x); selfloop=true)
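For context on the pattern this commit applies to every layer: overloading Flux.trainable tells Flux which fields of a layer an optimiser may update, so Flux.params collects only those arrays and leaves everything else (for example a stored graph) untouched. A minimal sketch of the idea with a toy layer — the ToyConv type and its field names are illustrative, not part of GeometricFlux — assuming the implicit-parameter API of Flux 0.11/0.12:

    using Flux

    # Stand-in for a GNN layer: two trainable arrays plus a fixed graph.
    struct ToyConv
        weight::Matrix{Float32}
        bias::Vector{Float32}
        adj::Matrix{Float32}   # graph structure; must not be optimised
    end

    Flux.@functor ToyConv

    # Same pattern as the hunk above: without this overload,
    # Flux.params(l) would also collect `adj`.
    Flux.trainable(l::ToyConv) = (l.weight, l.bias)

    l = ToyConv(randn(Float32, 4, 3), zeros(Float32, 4), rand(Float32, 3, 3))
    @assert length(Flux.params(l)) == 2   # weight and bias only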
@@ -87,6 +89,8 @@ ChebConv(ch::Pair{Int,Int}, k::Int; kwargs...) =
 
 @functor ChebConv
 
+Flux.trainable(l::ChebConv) = (l.weight, l.bias)
+
 function (c::ChebConv)(fg::FeaturedGraph, X::AbstractMatrix{T}) where T
     GraphSignals.check_num_nodes(fg, X)
     @assert size(X, 1) == size(c.weight, 2) "Input feature size must match input channel size."
@@ -155,6 +159,8 @@ GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+; kwargs...) =
 
 @functor GraphConv
 
+Flux.trainable(l::GraphConv) = (l.weight1, l.weight2, l.bias)
+
 message(gc::GraphConv, x_i, x_j::AbstractVector, e_ij) = gc.weight2 * x_j
 
 update(gc::GraphConv, m::AbstractVector, x::AbstractVector) = gc.σ.(gc.weight1*x .+ m .+ gc.bias)
@@ -224,6 +230,8 @@ GATConv(ch::Pair{Int,Int}; kwargs...) = GATConv(NullGraph(), ch; kwargs...)
 
 @functor GATConv
 
+Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
+
 # Here the α that has not been softmaxed is the first number of the output message
 function message(gat::GATConv, x_i::AbstractVector, x_j::AbstractVector)
     x_i = reshape(gat.weight*x_i, :, gat.heads)
@@ -319,6 +327,8 @@ GatedGraphConv(out_ch::Int, num_layers::Int; kwargs...) =
 
 @functor GatedGraphConv
 
+Flux.trainable(l::GatedGraphConv) = (l.weight, l.gru)
+
 message(ggc::GatedGraphConv, x_i, x_j::AbstractVector, e_ij) = x_j
 
 update(ggc::GatedGraphConv, m::AbstractVector, x) = m
@@ -376,6 +386,8 @@ EdgeConv(nn; kwargs...) = EdgeConv(NullGraph(), nn; kwargs...)
 
 @functor EdgeConv
 
+Flux.trainable(l::EdgeConv) = (l.nn,)
+
 message(ec::EdgeConv, x_i::AbstractVector, x_j::AbstractVector, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
 
 update(ec::EdgeConv, m::AbstractVector, x) = m
@@ -423,13 +435,13 @@ function GINConv(nn, eps::Real=0f0)
     GINConv(NullGraph(), nn, eps)
 end
 
+@functor GINConv
+
 Flux.trainable(g::GINConv) = (fg=g.fg, nn=g.nn)
 
 message(g::GINConv, x_i::AbstractVector, x_j::AbstractVector) = x_j
 update(g::GINConv, m::AbstractVector, x) = g.nn((1 + g.eps) * x + m)
 
-@functor GINConv
-
 function (g::GINConv)(fg::FeaturedGraph, X::AbstractMatrix)
     gf = graph(fg)
     GraphSignals.check_num_nodes(gf, X)
@@ -474,6 +486,8 @@
 
 @functor CGConv
 
+Flux.trainable(l::CGConv) = (l.Wf, l.Ws, l.bf, l.bs)
+
 function CGConv(fg::G, dims::NTuple{2,Int};
                 init=glorot_uniform, bias=true, as_edge=false) where {G<:AbstractFeaturedGraph}
     node_dim, edge_dim = dims
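One difference worth noting: GINConv keeps its pre-existing trainable overload, which returns a NamedTuple that includes the wrapped graph (fg=g.fg, nn=g.nn), whereas the overloads added here return plain tuples of arrays; Flux.params walks either form the same way. A hypothetical smoke test for the GCNConv hunk — the constructor and call signatures are taken from the context lines above, while the FeaturedGraph construction from an adjacency matrix is an assumption about the GraphSignals API of this era:

    using Flux, GraphSignals, GeometricFlux

    adj = Float32[0 1 1; 1 0 1; 1 1 0]   # small fully connected 3-node graph
    fg = FeaturedGraph(adj)
    l = GCNConv(3=>2)

    ps = Flux.params(l)
    @assert length(ps) == 2   # weight and bias, per the new trainable overload

    # Gradients now touch only the listed fields, not the stored graph.
    x = rand(Float32, 3, 3)   # in_channels × num_nodes
    gs = gradient(() -> sum(l(fg, x)), ps)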