Skip to content

Commit c306989

Browse files
authored
Merge pull request #295 from FluxML/develop
Add SAGEConv
2 parents ed63093 + e15ef7b commit c306989

File tree

6 files changed

+239
-1
lines changed

6 files changed

+239
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2121
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2222
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2323
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
24+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2425
Word2Vec = "c64b6f0f-98cd-51d1-af78-58ae84944834"
2526
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2627

@@ -37,6 +38,7 @@ NNlib = "0.8"
3738
NNlibCUDA = "0.2"
3839
Optimisers = "0.2"
3940
Reexport = "1.1"
41+
StatsBase = "0.33"
4042
Word2Vec = "0.5"
4143
Zygote = "0.6"
4244
julia = "1.6"

src/GeometricFlux.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module GeometricFlux
22

33
using DelimitedFiles
44
using SparseArrays
5-
using Statistics: mean
5+
using Statistics, StatsBase
66
using LinearAlgebra
77
using Random
88
using Reexport
@@ -41,6 +41,9 @@ export
4141
EdgeConv,
4242
GINConv,
4343
CGConv,
44+
SAGEConv,
45+
MeanAggregator, MeanPoolAggregator, MaxPoolAggregator,
46+
LSTMAggregator,
4447

4548
# layer/pool
4649
GlobalPool,

src/layers/conv.jl

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -765,3 +765,177 @@ function Base.show(io::IO, l::CGConv)
765765
edge_dim = d - 2*node_dim
766766
print(io, "CGConv(node dim=", node_dim, ", edge dim=", edge_dim, ")")
767767
end
768+
769+
"""
770+
SAGEConv(in => out, σ=identity, aggr=mean; normalize=true, project=false,
771+
bias=true, num_sample=10, init=glorot_uniform)
772+
773+
SAmple and aggreGatE convolutional layer for GraphSAGE network.
774+
775+
# Arguments
776+
777+
- `in`: The dimension of input features.
778+
- `out`: The dimension of output features.
779+
- `σ`: Activation function.
780+
- `aggr`: An aggregate function applied to the result of message function. `mean`, `max`,
781+
`LSTM` and `GCNConv` are available.
782+
- `normalize::Bool`: Whether to normalize features across all nodes or not.
783+
- `project::Bool`: Whether to project, i.e. `Dense(in, in)`, before aggregation.
784+
- `bias`: Add learnable bias.
785+
- `num_sample::Int`: Number of samples for each node from their neighbors.
786+
- `init`: Weights' initializer.
787+
788+
# Examples
789+
790+
```jldoctest
791+
julia> SAGEConv(1024=>256, relu)
792+
SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=10)
793+
794+
julia> SAGEConv(1024=>256, relu, num_sample=5)
795+
SAGEConv(1024 => 256, relu, aggr=mean, normalize=true, #sample=5)
796+
797+
julia> MeanAggregator(1024=>256, relu, normalize=false)
798+
SAGEConv(1024 => 256, relu, aggr=mean, normalize=false, #sample=10)
799+
800+
julia> MeanPoolAggregator(1024=>256, relu)
801+
SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=mean, normalize=true, #sample=10)
802+
803+
julia> MaxPoolAggregator(1024=>256, relu)
804+
SAGEConv(1024 => 256, relu, project=Dense(1024 => 1024), aggr=max, normalize=true, #sample=10)
805+
806+
julia> LSTMAggregator(1024=>256, relu)
807+
SAGEConv(1024 => 256, relu, aggr=LSTMCell(1024 => 1024), normalize=true, #sample=10)
808+
```
809+
810+
See also [`WithGraph`](@ref) for training layer with static graph and [`MeanAggregator`](@ref),
811+
[`MeanPoolAggregator`](@ref), [`MaxPoolAggregator`](@ref) and [`LSTMAggregator`](@ref).
812+
"""
813+
struct SAGEConv{W,B,F,P,G} <: MessagePassing
    weight1::W       # multiplied with the node features `x` in `update`
    weight2::W       # multiplied with the aggregated messages `m` in `update`
    bias::B          # broadcast-added before the activation in `update`
    σ::F             # activation, broadcast over the combined output
    proj::P          # applied to `x_j` in `message` (`identity` when not projecting)
    aggr::G          # aggregator handed to `propagate`
    normalize::Bool  # when true, `update` applies `l2normalize(y; dims=2)`
    num_sample::Int  # sample count passed to `sample_node_index`
end
823+
824+
# Keyword constructor: builds both weight matrices with `init`, the bias via
# `Flux.create_bias`, and an optional `Dense(in, in)` projection, then forwards
# everything to the default field constructor.
function SAGEConv(ch::Pair{Int,Int}, σ=identity, aggr=mean;
                  normalize::Bool=true, project::Bool=false, bias::Bool=true,
                  num_sample::Int=10, init=glorot_uniform)
    in_dim, out_dim = ch
    # Two separate (out, in) matrices; `update` applies the first to `x` and the
    # second to the aggregated messages.
    w_self = init(out_dim, in_dim)
    w_neigh = init(out_dim, in_dim)
    # Keep the `bias` flag and the created bias value in distinct locals.
    b = Flux.create_bias(w_self, bias, out_dim)
    projection = project ? Dense(in_dim, in_dim) : identity
    return SAGEConv(w_self, w_neigh, b, σ, projection, aggr, normalize, num_sample)
end
834+
835+
@functor SAGEConv
836+
837+
# Message for SAGEConv: `x_j` passed through the layer's projection. The `x_i` and
# `e` arguments are accepted but not used.
function message(l::SAGEConv, x_i, x_j::AbstractArray, e)
    return l.proj(x_j)
end
838+
839+
# Aggregate a random subset of messages for batched input.
#
# `l.num_sample` indices are drawn along dimension 2 of `E`; the selected slices are
# scattered with `aggr` into an array of size `(size(E, 1), el.N, batch_size)`, which is
# returned.
function aggregate_neighbors(l::SAGEConv, el::NamedTuple, aggr, E)
    batch_size = size(E)[end]
    sample_idx = sample_node_index(E, l.num_sample; dims=2)
    # Index only dimension 2 with the sampled positions; keep every other dimension whole.
    idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
    dstsize = (size(E, 1), el.N, batch_size)  # ensure outcome has the same dimension as x in update
    xs = batched_index(el.xs[sample_idx], batch_size)
    # The scattered result is the function's value; the original text had lost both the
    # assignment target and the returned name.
    return _scatter(aggr, E[idx...], xs, dstsize)
end
848+
849+
# Aggregate a random subset of messages for a single (non-batched) matrix `E`.
#
# `l.num_sample` column indices are drawn from `E`; the selected columns are scattered
# with `aggr` into an array of size `(size(E, 1), el.N)`, which is returned.
function aggregate_neighbors(l::SAGEConv, el::NamedTuple, aggr, E::AbstractMatrix)
    sample_idx = sample_node_index(E, l.num_sample; dims=2)
    # Index only dimension 2 with the sampled positions; keep dimension 1 whole.
    idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
    dstsize = (size(E, 1), el.N)  # ensure outcome has the same dimension as x in update
    # The scattered result is the function's value; the original text had lost both the
    # assignment target and the returned name.
    return _scatter(aggr, E[idx...], el.xs[sample_idx], dstsize)
end
856+
857+
# Fallback for the LSTM aggregator on any array that is not a plain matrix (the
# `AbstractMatrix` method is more specific): always rejects the input.
function aggregate_neighbors(::SAGEConv, el::NamedTuple, lstm::Flux.LSTMCell,
                             E::AbstractArray)
    throw(ArgumentError("SAGEConv with LSTM aggregator does not support batch learning."))
end
859+
860+
# LSTM aggregation for a single (non-batched) matrix `E`.
#
# `el.N` column indices are drawn from `E` (note: `el.N`, not `l.num_sample`), the
# selected columns are fed through the cell starting from `lstm.state0`, and the cell
# output is returned.
# NOTE(review): drawing `el.N` columns presumably keeps one output column per node so the
# result lines up with `x` in `update` — confirm against `propagate`.
function aggregate_neighbors(::SAGEConv, el::NamedTuple, lstm::Flux.LSTMCell, E::AbstractMatrix)
    sample_idx = sample_node_index(E, el.N; dims=2)
    idx = ntuple(i -> (i == 2) ? sample_idx : Colon(), ndims(E))
    # The cell returns `(new_state, output)`; only the output is needed. The original
    # text had a bare `return`, which discarded it.
    _, output = lstm(lstm.state0, E[idx...])
    return output
end
866+
867+
# Combine node features `x` with aggregated messages `m`:
# `σ.(weight1 * x + weight2 * m .+ bias)`, optionally L2-normalized along dims=2.
function update(l::SAGEConv, m::AbstractArray, x::AbstractArray)
    self_term = _matmul(l.weight1, x)
    neigh_term = _matmul(l.weight2, m)
    y = l.σ.(self_term + neigh_term .+ l.bias)
    if l.normalize
        y = l2normalize(y; dims=2)  # across all nodes
    end
    return y
end
872+
873+
# For variable graph: the graph travels with the input. Node features are read from
# `fg`, propagated with the layer's aggregator, and written back into a new
# featured graph.
function (l::SAGEConv)(fg::AbstractFeaturedGraph)
    features = node_feature(fg)
    GraphSignals.check_num_nodes(fg, features)
    # Only the second value returned by `propagate` is used.
    _, node_out, _ = propagate(l, graph(fg), nothing, features, nothing, l.aggr,
                               nothing, nothing)
    return ConcreteFeaturedGraph(fg, nf=node_out)
end
880+
881+
# For static graph: the graph description `el` is supplied separately from the
# feature array `x`; the propagated node features are returned as a plain array.
function (l::SAGEConv)(el::NamedTuple, x::AbstractArray)
    GraphSignals.check_num_nodes(el.N, x)
    # Only the second value returned by `propagate` is used.
    _, node_out, _ = propagate(l, el, nothing, x, nothing, l.aggr, nothing, nothing)
    return node_out
end
887+
888+
# Compact one-line representation. The activation and the projection are printed
# only when they differ from `identity`.
function Base.show(io::IO, l::SAGEConv)
    n_out, n_in = size(l.weight1)
    print(io, "SAGEConv(", n_in, " => ", n_out)
    if l.σ != identity
        print(io, ", ", l.σ)
    end
    if l.proj != identity
        print(io, ", project=", l.proj)
    end
    print(io, ", aggr=", l.aggr, ", normalize=", l.normalize,
          ", #sample=", l.num_sample, ")")
end
898+
"""
    MeanAggregator(in => out, σ=identity; normalize=true, project=false,
                   bias=true, num_sample=10, init=glorot_uniform)

SAGEConv with mean aggregator.

The activation `σ` is an explicit optional argument so that `mean` always lands in the
aggregator position of `SAGEConv`, even when `σ` is omitted. Keyword arguments are
forwarded to `SAGEConv` unchanged.

See also [`SAGEConv`](@ref).
"""
function MeanAggregator(ch::Pair{Int,Int}, σ=identity; kwargs...)
    return SAGEConv(ch, σ, mean; kwargs...)
end
908+
"""
    MeanPoolAggregator(in => out, σ=identity; normalize=true,
                       bias=true, num_sample=10, init=glorot_uniform)

SAGEConv with meanpool aggregator.

Equivalent to `SAGEConv(in => out, σ, mean; project=true, ...)`. The activation `σ` is an
explicit optional argument so that `mean` always lands in the aggregator position, even
when `σ` is omitted. Keyword arguments are forwarded after `project=true`, so a
caller-supplied `project` takes precedence.

See also [`SAGEConv`](@ref).
"""
function MeanPoolAggregator(ch::Pair{Int,Int}, σ=identity; kwargs...)
    return SAGEConv(ch, σ, mean; project=true, kwargs...)
end
918+
"""
    MaxPoolAggregator(in => out, σ=identity; normalize=true,
                      bias=true, num_sample=10, init=glorot_uniform)

SAGEConv with maxpool aggregator.

Equivalent to `SAGEConv(in => out, σ, max; project=true, ...)`. The activation `σ` is an
explicit optional argument so that `max` always lands in the aggregator position; with a
splatted positional list, omitting `σ` would make `max` the activation and leave the
aggregator at its default. Keyword arguments are forwarded after `project=true`, so a
caller-supplied `project` takes precedence.

See also [`SAGEConv`](@ref).
"""
function MaxPoolAggregator(ch::Pair{Int,Int}, σ=identity; kwargs...)
    return SAGEConv(ch, σ, max; project=true, kwargs...)
end
928+
929+
"""
    LSTMAggregator(in => out, σ=identity; normalize=true, project=false,
                   bias=true, num_sample=10, init=glorot_uniform)

SAGEConv with LSTM aggregator.

The aggregator is a `Flux.LSTMCell(in, in)`. The activation `σ` is an explicit optional
argument so that the cell always lands in the aggregator position of `SAGEConv`, even
when `σ` is omitted. Keyword arguments are forwarded to `SAGEConv` unchanged.

See also [`SAGEConv`](@ref).
"""
function LSTMAggregator(ch::Pair{Int,Int}, σ=identity; kwargs...)
    in_ch = first(ch)
    return SAGEConv(ch, σ, Flux.LSTMCell(in_ch, in_ch); kwargs...)
end

src/operation.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ aggregate(::typeof(max), X) = maximum(X, dims=2)
2222
# Reduce `X` along its second dimension with the reducer selected by the first
# argument's type.
function aggregate(::typeof(min), X)
    return minimum(X; dims=2)
end

function aggregate(::typeof(mean), X)
    return mean(X; dims=2)
end
2424

25+
# Scale `X` so that every slice along `dims` has unit Euclidean norm.
#
# The squared norm is floored at the smallest positive normal float of the element type
# before taking the square root. A slice that is entirely zero is therefore returned as
# zeros rather than NaN (0/0), while any slice with a representable norm gets the same
# result as a plain `X ./ sqrt.(sum(abs2, X; dims))`.
function l2normalize(X::AbstractArray; dims=1)
    sqnorm = sum(abs2, X; dims=dims)
    tiny = floatmin(float(real(eltype(X))))
    l2norm = sqrt.(max.(sqnorm, tiny))
    return X ./ l2norm
end
29+
2530
function incidence_matrix(xs::AbstractVector{T}, N) where {T}
2631
A = similar(xs, T, size(xs, 1), N)
2732
copyto!(A, Array(I(N))[Array(xs), :])

src/sampling.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,10 @@ function alias_sample(J::AbstractVector{<:Integer}, q::AbstractVector{<:Real})
5353
return J[small_index]
5454
end
5555
end
56+
57+
# Draw distinct random indices along dimension `dims` of `X`.
#
# Returns at most `num_sample` indices from `1:size(X, dims)`, sampled without
# replacement. The request is capped at the number of available indices, because
# sampling without replacement cannot return more items than exist; without the cap a
# small input (fewer entries than `num_sample`) raises an error instead of using
# everything that is available.
function sample_node_index(X::AbstractArray, num_sample::Int; dims::Int=1)
    n = size(X, dims)
    k = min(num_sample, n)
    return StatsBase.sample(1:n, k, replace=false)
end
61+
62+
@non_differentiable sample_node_index(x...)

test/layers/conv.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,4 +377,51 @@
377377
@test length(g.grads) == 4
378378
end
379379
end
380+
381+
# Runs the same checks for every SAGEConv aggregator constructor: once on a
# FeaturedGraph input and once wrapped in WithGraph with batched input.
# `T`, `N`, `adj`, `fg`, `in_channel`, `out_channel` and `batch_size` are defined
# outside this testset.
@testset "SAGEConv" begin
382+
aggregators = [MeanAggregator, MeanPoolAggregator, MaxPoolAggregator,
383+
LSTMAggregator]
384+
@testset "layer without graph" begin
385+
for conv in aggregators
386+
l = conv(in_channel=>out_channel, relu, num_sample=3)
387+
388+
X = rand(T, in_channel, N)
389+
fg = FeaturedGraph(adj, nf=X)
390+
fg_ = l(fg)
391+
@test size(node_feature(fg_)) == (out_channel, N)
392+
# A bare feature array without a graph is expected to have no matching method.
@test_throws MethodError l(X)
393+
394+
g = Zygote.gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
395+
# The expected number of gradient entries depends on whether the layer has a
# projection and whether it uses the LSTM aggregator.
# NOTE(review): the exact counts (10 / 5 / 7) cannot be derived from this view — confirm.
if l.proj == identity
396+
if conv == LSTMAggregator
397+
@test length(g.grads) == 10
398+
else
399+
@test length(g.grads) == 5
400+
end
401+
else
402+
@test length(g.grads) == 7
403+
end
404+
end
405+
end
406+
407+
@testset "layer with static graph" begin
408+
for conv in aggregators
409+
X = rand(T, in_channel, N, batch_size)
410+
l = WithGraph(fg, conv(in_channel=>out_channel, relu, num_sample=3))
411+
# The LSTM aggregator throws an ArgumentError for batched (non-matrix) input.
if conv == LSTMAggregator
412+
@test_throws ArgumentError l(X)
413+
else
414+
Y = l(X)
415+
@test size(Y) == (out_channel, N, batch_size)
416+
417+
g = Zygote.gradient(() -> sum(l(X)), Flux.params(l))
418+
# NOTE(review): the expected counts (3 without projection, 5 with) cannot be derived
# from this view — confirm.
if l.layer.proj == identity
419+
@test length(g.grads) == 3
420+
else
421+
@test length(g.grads) == 5
422+
end
423+
end
424+
end
425+
end
426+
end
380427
end

0 commit comments

Comments
 (0)