@@ -765,3 +765,177 @@ function Base.show(io::IO, l::CGConv)
     edge_dim = d - 2 * node_dim
     print(io, "CGConv(node dim=", node_dim, ", edge dim=", edge_dim, ")")
 end
+
+"""
+    SAGEConv(in => out, σ=identity, aggr=mean; normalize=true, project=false,
+             bias=true, num_sample=10, init=glorot_uniform)
+
+SAmple and aggreGatE convolutional layer for the GraphSAGE network.
+
+# Arguments
+
+- `in`: The dimension of input features.
+- `out`: The dimension of output features.
+- `σ`: Activation function.
+- `aggr`: Aggregate function applied to the result of the message function. `mean`, `max`,
+    `LSTM` and `GCNConv` are available.
+- `normalize::Bool`: Whether to normalize features across all nodes or not.
+- `project::Bool`: Whether to project, i.e. `Dense(in, in)`, before aggregation.
+- `bias`: Whether to add a learnable bias.
+- `num_sample::Int`: Number of neighbors to sample for each node.
+- `init`: Weights' initializer.
+
+# Examples
+
+```jldoctest
+julia> SAGEConv(1024=>256, relu)
+SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=10)
+
+julia> SAGEConv(1024=>256, relu, num_sample=5)
+SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=5)
+
+julia> MeanAggregator(1024=>256, relu, normalize=false)
+SAGEConv(1024 => 256, relu, aggr=mean, normalize=false, #sample=10)
+
+julia> MeanPoolAggregator(1024=>256, relu)
+SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=mean, normalize=true, #sample=10)
+
+julia> MaxPoolAggregator(1024=>256, relu)
+SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=max, normalize=true, #sample=10)
+
+julia> LSTMAggregator(1024=>256, relu)
+SAGEConv(1024 => 256, relu, aggr=LSTMCell(1024 => 1024), normalize=true, #sample=10)
+```
+
+See also [`WithGraph`](@ref) for training the layer with a static graph, and [`MeanAggregator`](@ref),
+[`MeanPoolAggregator`](@ref), [`MaxPoolAggregator`](@ref) and [`LSTMAggregator`](@ref).
+"""
+struct SAGEConv{A,B,F,P,O} <: MessagePassing
+    weight1::A
+    weight2::A
+    bias::B
+    σ::F
+    proj::P
+    aggr::O
+    normalize::Bool
+    num_sample::Int
+end
+
+function SAGEConv(ch::Pair{Int,Int}, σ=identity, aggr=mean;
+                  normalize::Bool=true, project::Bool=false, bias::Bool=true,
+                  num_sample::Int=10, init=glorot_uniform)
+    in, out = ch
+    weight1 = init(out, in)    # transforms the node's own features in `update`
+    weight2 = init(out, in)    # transforms the aggregated neighbor features in `update`
+    bias = Flux.create_bias(weight1, bias, out)
+    proj = project ? Dense(in, in) : identity
+    return SAGEConv(weight1, weight2, bias, σ, proj, aggr, normalize, num_sample)
+end
+
+@functor SAGEConv
+
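+# Rough usage sketch for the variable-graph path (illustrative only; it assumes
+# GraphSignals' `FeaturedGraph(adj; nf=...)` constructor and random demo data):
+#
+#     using GraphSignals, Flux
+#     fg = FeaturedGraph(rand([0, 1], 100, 100); nf=rand(Float32, 1024, 100))
+#     layer = SAGEConv(1024=>256, relu)
+#     fg′ = layer(fg)    # FeaturedGraph whose node features are now 256×100
+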
+message(l::SAGEConv, x_i, x_j::AbstractArray, e) = l.proj(x_j)
+
+function aggregate_neighbors(l::SAGEConv, el::NamedTuple, aggr, E)
+    batch_size = size(E)[end]
+    sample_idx = sample_node_index(E, l.num_sample; dims=2)
+    idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
+    dstsize = (size(E, 1), el.N, batch_size)  # ensure outcome has the same dimension as x in update
+    xs = batched_index(el.xs[sample_idx], batch_size)
+    Ē = _scatter(aggr, E[idx...], xs, dstsize)
+    return Ē
+end
+
+function aggregate_neighbors(l::SAGEConv, el::NamedTuple, aggr, E::AbstractMatrix)
+    sample_idx = sample_node_index(E, l.num_sample; dims=2)
+    idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
+    dstsize = (size(E, 1), el.N)  # ensure outcome has the same dimension as x in update
+    Ē = _scatter(aggr, E[idx...], el.xs[sample_idx], dstsize)
+    return Ē
+end
857
+ aggregate_neighbors (:: SAGEConv , el:: NamedTuple , lstm:: Flux.LSTMCell , E:: AbstractArray ) =
858
+ throw (ArgumentError (" SAGEConv with LSTM aggregator does not support batch learning." ))
859
+
860
+ function aggregate_neighbors (:: SAGEConv , el:: NamedTuple , lstm:: Flux.LSTMCell , E:: AbstractMatrix )
861
+ sample_idx = sample_node_index (E, el. N; dims= 2 )
862
+ idx = ntuple (i -> (i == 2 ) ? sample_idx : Colon (), ndims (E))
863
+ state, Ē = lstm (lstm. state0, E[idx... ])
864
+ return Ē
865
+ end
866
+
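+# `update` applies the GraphSAGE update rule (Hamilton et al., 2017), with the
+# paper's feature concatenation expressed as two separate weight matrices:
+#
+#     hᵥ = σ(W₁ xᵥ + W₂ h̄_{N(v)} + b),   followed by L2 normalization of hᵥ,
+#
+# where h̄_{N(v)} is the aggregated message from the sampled neighborhood of v.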
+function update(l::SAGEConv, m::AbstractArray, x::AbstractArray)
+    y = l.σ.(_matmul(l.weight1, x) + _matmul(l.weight2, m) .+ l.bias)
+    l.normalize && (y = l2normalize(y; dims=2))  # across all nodes
+    return y
+end
+
+# For variable graph
+function (l::SAGEConv)(fg::AbstractFeaturedGraph)
+    nf = node_feature(fg)
+    GraphSignals.check_num_nodes(fg, nf)
+    _, V, _ = propagate(l, graph(fg), nothing, nf, nothing, l.aggr, nothing, nothing)
+    return ConcreteFeaturedGraph(fg, nf=V)
+end
+
+# For static graph
+function (l::SAGEConv)(el::NamedTuple, x::AbstractArray)
+    GraphSignals.check_num_nodes(el.N, x)
+    _, V, _ = propagate(l, el, nothing, x, nothing, l.aggr, nothing, nothing)
+    return V
+end
+
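+# Rough sketch of the static-graph path (illustrative only; it assumes the
+# `WithGraph(fg, layer)` wrapper referenced in the docstring, which fixes the
+# graph ahead of time so only node features are passed at call time):
+#
+#     fg = FeaturedGraph(rand([0, 1], 100, 100))
+#     model = WithGraph(fg, SAGEConv(1024=>256, relu))
+#     H = model(rand(Float32, 1024, 100))    # 256×100 node embeddings
+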
+function Base.show(io::IO, l::SAGEConv)
+    out_channel, in_channel = size(l.weight1)
+    print(io, "SAGEConv(", in_channel, " => ", out_channel)
+    l.σ == identity || print(io, ", ", l.σ)
+    l.proj == identity || print(io, ", project=", l.proj)
+    print(io, ", aggr=", l.aggr)
+    print(io, ", normalize=", l.normalize)
+    print(io, ", #sample=", l.num_sample)
+    print(io, ")")
+end
+
+"""
+    MeanAggregator(in => out, σ=identity; normalize=true, project=false,
+                   bias=true, num_sample=10, init=glorot_uniform)
+
+SAGEConv with mean aggregator.
+
+See also [`SAGEConv`](@ref).
+"""
+MeanAggregator(args...; kwargs...) = SAGEConv(args..., mean; kwargs...)
+
+"""
+    MeanPoolAggregator(in => out, σ=identity; normalize=true,
+                       bias=true, num_sample=10, init=glorot_uniform)
+
+SAGEConv with mean-pooling aggregator.
+
+See also [`SAGEConv`](@ref).
+"""
+MeanPoolAggregator(args...; kwargs...) = SAGEConv(args..., mean; project=true, kwargs...)
+
+"""
+    MaxPoolAggregator(in => out, σ=identity; normalize=true,
+                      bias=true, num_sample=10, init=glorot_uniform)
+
+SAGEConv with max-pooling aggregator.
+
+See also [`SAGEConv`](@ref).
+"""
+MaxPoolAggregator(args...; kwargs...) = SAGEConv(args..., max; project=true, kwargs...)
+
+"""
+    LSTMAggregator(in => out, σ=identity; normalize=true, project=false,
+                   bias=true, num_sample=10, init=glorot_uniform)
+
+SAGEConv with LSTM aggregator.
+
+See also [`SAGEConv`](@ref).
+"""
+function LSTMAggregator(args...; kwargs...)
+    in_ch = args[1][1]
+    return SAGEConv(args..., Flux.LSTMCell(in_ch, in_ch); kwargs...)
+end