@@ -128,7 +128,7 @@ function (l::ChebConv)(fg::AbstractFeaturedGraph)
     nf = node_feature(fg)
     GraphSignals.check_num_nodes(fg, nf)
     @assert size(nf, 1) == size(l.weight, 2) "Input feature size must match input channel size."
-
+
     L̃ = ChainRulesCore.ignore_derivatives() do
         GraphSignals.scaled_laplacian(fg, eltype(nf))
     end
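The hunk above wraps the scaled Laplacian in `ChainRulesCore.ignore_derivatives`, so the operator is built outside the differentiated computation and treated as a constant by AD. A minimal hedged sketch of that pattern, with hypothetical standalone names (not the package's actual forward pass):

    using ChainRulesCore

    # Hypothetical sketch: build a graph operator inside ignore_derivatives so
    # no gradient flows through its construction, mirroring the hunk above.
    function spectral_forward(W, X, build_laplacian)
        L̃ = ChainRulesCore.ignore_derivatives() do
            build_laplacian(eltype(X))   # treated as a constant by AD
        end
        return W * (X * L̃)              # gradients flow only through W and X
    end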
@@ -245,7 +245,7 @@ Graph attentional layer.
 - `out`: The dimension of output features.
 - `bias::Bool`: Keyword argument, whether to learn the additive bias.
 - `σ`: Activation function.
-- `heads`: Number attention heads
+- `heads`: Number attention heads
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 
@@ -280,7 +280,7 @@
 
 
 function GATConv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true,
                  negative_slope=0.2f0, init=glorot_uniform, bias::Bool=true)
-    in, out = ch
+    in, out = ch
     W = init(out*heads, in)
     b = Flux.create_bias(W, bias, out*heads)
     a = init(2 * out, heads)
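For context, a hedged usage sketch of this constructor; the dimensions and activation are illustrative values, not part of the commit:

    using Flux, GeometricFlux

    # Illustrative only: map 64-dimensional node features to 32-dimensional
    # outputs with 4 attention heads; with concat=true the per-head outputs
    # are concatenated, giving 32*4 = 128 output features per node.
    gat = GATConv(64 => 32, relu; heads=4, concat=true, negative_slope=0.2f0)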
@@ -372,7 +372,7 @@ Graph attentional layer v2.
 - `in`: The dimension of input features.
 - `out`: The dimension of output features.
 - `σ`: Activation function.
-- `heads`: Number attention heads
+- `heads`: Number attention heads
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
 
@@ -661,7 +661,7 @@ GINConv(nn, eps=0f0) = GINConv(nn, eps)
 
 Flux.trainable(g::GINConv) = (g.nn,)
 
-message(g::GINConv, x_i::AbstractArray, x_j::AbstractArray) = x_j
+message(g::GINConv, x_i::AbstractArray, x_j::AbstractArray) = x_j
 update(g::GINConv, m::AbstractArray, x::AbstractArray) = g.nn((1 + g.eps) * x + m)
 
 # For variable graph
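The `message`/`update` pair above realizes the GIN node update h_v' = nn((1 + ε)·h_v + Σ_{u∈N(v)} h_u): each neighbour contributes its features unchanged, and the aggregated messages are combined with the scaled centre features before the MLP. A hedged sketch of that combination for a single node, with all names illustrative:

    using Flux

    # Illustrative only: the per-node combination performed by update above.
    # `nn` plays the role of g.nn, `eps` of g.eps; `x` is the node's feature
    # vector and `neighbors` a vector of its neighbours' feature vectors.
    gin_node(nn, eps, x, neighbors) = nn((1 + eps) .* x .+ sum(neighbors))

    nn  = Dense(8, 8, relu)                 # small MLP standing in for g.nn
    x   = rand(Float32, 8)
    nbs = [rand(Float32, 8) for _ in 1:3]
    gin_node(nn, 0f0, x, nbs)               # 8-element output vector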
@@ -705,30 +705,24 @@ CGConv(node dim=128, edge dim=32)
 
 See also [`WithGraph`](@ref) for training layer with static graph.
 """
-struct CGConv{A<:AbstractMatrix,B} <: MessagePassing
-    Wf::A
-    Ws::A
-    bf::B
-    bs::B
+struct CGConv{A,B} <: MessagePassing
+    f::A
+    s::B
 end
 
 @functor CGConv
 
-Flux.trainable(l::CGConv) = (l.Wf, l.Ws, l.bf, l.bs)
-
 function CGConv(dims::NTuple{2,Int}; init=glorot_uniform, bias=true)
     node_dim, edge_dim = dims
-    Wf = init(node_dim, 2 * node_dim + edge_dim)
-    Ws = init(node_dim, 2 * node_dim + edge_dim)
-    bf = Flux.create_bias(Wf, bias, node_dim)
-    bs = Flux.create_bias(Ws, bias, node_dim)
-    return CGConv(Wf, Ws, bf, bs)
+    f = Dense(2 * node_dim + edge_dim, node_dim; bias=bias, init=init)
+    s = Dense(2 * node_dim + edge_dim, node_dim; bias=bias, init=init)
+    return CGConv(f, s)
 end
 
-function message(c::CGConv, x_i::AbstractArray, x_j::AbstractArray, e::AbstractArray)
+function message(l::CGConv, x_i::AbstractArray, x_j::AbstractArray, e::AbstractArray)
     z = vcat(x_i, x_j, e)
-    return σ.(_matmul(c.Wf, z) .+ c.bf) .* softplus.(_matmul(c.Ws, z) .+ c.bs)
+    return σ.(l.f(z)) .* softplus.(l.s(z))
 end
 
 update(c::CGConv, m::AbstractArray, x) = x + m
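This hunk swaps the explicit `Wf`/`Ws` weight matrices and `bf`/`bs` biases for two `Dense` layers, so the gated message σ(Wf·z + bf) ⊙ softplus(Ws·z + bs) becomes `σ.(l.f(z)) .* softplus.(l.s(z))`, and the hand-written `Flux.trainable` method can be dropped because `@functor` already exposes the `Dense` parameters. A hedged usage sketch using the docstring's example dimensions; everything beyond the constructor call is illustrative:

    using Flux, GeometricFlux

    cg = CGConv((128, 32))      # 128-dimensional node features, 32-dimensional edge features

    # Each internal Dense maps z = vcat(x_i, x_j, e) of length 2*128 + 32 = 288
    # back to 128 node features; this weight size is what Base.show reads below.
    size(cg.f.weight)           # => (128, 288)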
@@ -752,7 +746,7 @@ function (l::CGConv)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
 end
 
 function Base.show(io::IO, l::CGConv)
-    node_dim, d = size(l.Wf)
+    node_dim, d = size(l.f.weight)
     edge_dim = d - 2 * node_dim
     print(io, "CGConv(node dim=", node_dim, ", edge dim=", edge_dim, ")")
 end