
Commit ceac5ee

Merge pull request #277 from FluxML/develop
Update docstring and add deprecate
2 parents: fb73d23 + 8f15492

File tree

- docs/make.jl
- docs/src/manual/conv.md
- src/layers/conv.jl
- src/layers/msgpass.jl

4 files changed: +114 -9 lines
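
The deprecations added in `src/layers/conv.jl` below rewrite the old graph-in-constructor calls into the new `WithGraph` wrapper. A minimal migration sketch (the adjacency matrix and layer sizes are illustrative, and `FeaturedGraph` is assumed here to accept an adjacency matrix):

```julia
using GeometricFlux, Flux

# A small example graph given as an adjacency matrix (illustrative only).
adj = [0 1 1;
       1 0 1;
       1 1 0]
fg = FeaturedGraph(adj)

# Old style, now deprecated by this commit: GraphConv(fg, 1024=>256, relu)
# New style: build the plain layer, then bind it to a static graph with WithGraph.
layer = WithGraph(fg, GraphConv(1024=>256, relu))
```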

docs/make.jl
Lines changed: 2 additions & 0 deletions

````diff
@@ -4,6 +4,8 @@ using GeometricFlux

 bib = CitationBibliography(joinpath(@__DIR__, "bibliography.bib"), sorting=:nyt)

+DocMeta.setdocmeta!(GeometricFlux, :DocTestSetup, :(using GeometricFlux, Flux); recursive=true)
+
 makedocs(
     bib,
     sitename = "GeometricFlux.jl",
````

docs/src/manual/conv.md
Lines changed: 1 addition & 0 deletions

````diff
@@ -119,6 +119,7 @@ where ``f_{\Theta}`` denotes a neural network parametrized by ``\Theta``, *i.e.*
 ```@docs
 GINConv
 ```
+
 Reference: [How Powerful are Graph Neural Networks?](https://arxiv.org/pdf/1810.00826.pdf)

 ## Crystal Graph Convolutional Network
````

src/layers/conv.jl
Lines changed: 105 additions & 9 deletions

````diff
@@ -12,11 +12,9 @@ of size `(num_features, num_nodes)`.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.

-# Example
+# Examples

 ```jldoctest
-julia> using GeometricFlux, Flux
-
 julia> gc = GCNConv(1024=>256, relu)
 GCNConv(1024 => 256, relu)
 ```
@@ -83,7 +81,7 @@ Chebyshev spectral graph convolutional layer.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.

-# Example
+# Examples

 ```jldoctest
 julia> cc = ChebConv(1024=>256, 5, relu)
@@ -107,6 +105,8 @@ function ChebConv(ch::Pair{Int,Int}, k::Int, σ=identity;
     ChebConv(W, b, k, σ)
 end

+@deprecate ChebConv(fg, args...; kwargs...) WithGraph(fg, ChebConv(args...; kwargs...))
+
 @functor ChebConv

 Flux.trainable(l::ChebConv) = (l.weight, l.bias)
@@ -167,6 +167,18 @@ Graph neural network layer.
   `*`, `/`, `max`, `min` and `mean` are available.
 - `bias`: Add learnable bias.
 - `init`: Weights' initializer.
+
+# Examples
+
+```jldoctest
+julia> GraphConv(1024=>256, relu)
+GraphConv(1024 => 256, relu, aggr=+)
+
+julia> GraphConv(1024=>256, relu, *)
+GraphConv(1024 => 256, relu, aggr=*)
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
 struct GraphConv{A<:AbstractMatrix,B,F,O} <: MessagePassing
     weight1::A
@@ -185,6 +197,8 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     GraphConv(W1, W2, b, σ, aggr)
 end

+@deprecate GraphConv(fg, args...; kwargs...) WithGraph(fg, GraphConv(args...; kwargs...))
+
 @functor GraphConv

 Flux.trainable(l::GraphConv) = (l.weight1, l.weight2, l.bias)
@@ -234,6 +248,24 @@ Graph attentional layer.
 - `heads`: Number attention heads
 - `concat`: Concatenate layer output or not. If not, layer output is averaged.
 - `negative_slope::Real`: Keyword argument, the parameter of LeakyReLU.
+
+# Examples
+
+```jldoctest
+julia> GATConv(1024=>256, relu)
+GATConv(1024=>256, heads=1, concat=true, LeakyReLU(λ=0.2))
+
+julia> GATConv(1024=>256, relu, heads=4)
+GATConv(1024=>1024, heads=4, concat=true, LeakyReLU(λ=0.2))
+
+julia> GATConv(1024=>256, relu, heads=4, concat=false)
+GATConv(1024=>1024, heads=4, concat=false, LeakyReLU(λ=0.2))
+
+julia> GATConv(1024=>256, relu, negative_slope=0.1f0)
+GATConv(1024=>256, heads=1, concat=true, LeakyReLU(λ=0.1))
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
 struct GATConv{T,A<:AbstractMatrix{T},B,F} <: MessagePassing
     weight::A
@@ -255,6 +287,8 @@ function GATConv(ch::Pair{Int,Int}, σ=identity; heads::Int=1, concat::Bool=true
     GATConv(W, b, a, σ, negative_slope, ch, heads, concat)
 end

+@deprecate GATConv(fg, args...; kwargs...) WithGraph(fg, GATConv(args...; kwargs...))
+
 @functor GATConv

 Flux.trainable(l::GATConv) = (l.weight, l.bias, l.a)
@@ -342,6 +376,8 @@ function Base.show(io::IO, l::GATConv)
     in_channel = size(l.weight, ndims(l.weight))
     out_channel = size(l.weight, ndims(l.weight)-1)
     print(io, "GATConv(", in_channel, "=>", out_channel)
+    print(io, ", heads=", l.heads)
+    print(io, ", concat=", l.concat)
     print(io, ", LeakyReLU(λ=", l.negative_slope)
     print(io, "))")
 end
@@ -358,6 +394,18 @@ Gated graph convolution layer.
 - `num_layers`: The number of gated recurrent unit.
 - `aggr`: An aggregate function applied to the result of message function. `+`, `-`,
   `*`, `/`, `max`, `min` and `mean` are available.
+
+# Examples
+
+```jldoctest
+julia> GatedGraphConv(256, 4)
+GatedGraphConv((256 => 256)^4, aggr=+)
+
+julia> GatedGraphConv(256, 4, aggr=*)
+GatedGraphConv((256 => 256)^4, aggr=*)
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
 struct GatedGraphConv{A<:AbstractArray{<:Number,3},R,O} <: MessagePassing
     weight::A
@@ -373,6 +421,8 @@ function GatedGraphConv(out_ch::Int, num_layers::Int; aggr=+, init=glorot_unifor
     GatedGraphConv(w, gru, out_ch, num_layers, aggr)
 end

+@deprecate GatedGraphConv(fg, args...; kwargs...) WithGraph(fg, GatedGraphConv(args...; kwargs...))
+
 @functor GatedGraphConv

 Flux.trainable(l::GatedGraphConv) = (l.weight, l.gru)
@@ -424,7 +474,20 @@ Edge convolutional layer.
 # Arguments

 - `nn`: A neural network (e.g. a Dense layer or a MLP).
-- `aggr`: An aggregate function applied to the result of message function. `+`, `max` and `mean` are available.
+- `aggr`: An aggregate function applied to the result of message function.
+  `+`, `max` and `mean` are available.
+
+# Examples
+
+```jldoctest
+julia> EdgeConv(Dense(1024, 256, relu))
+EdgeConv(Dense(1024, 256, relu), aggr=max)
+
+julia> EdgeConv(Dense(1024, 256, relu), aggr=+)
+EdgeConv(Dense(1024, 256, relu), aggr=+)
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
 struct EdgeConv{N,O} <: MessagePassing
     nn::N
@@ -433,6 +496,8 @@ end

 EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)

+@deprecate EdgeConv(fg, args...; kwargs...) WithGraph(fg, EdgeConv(args...; kwargs...))
+
 @functor EdgeConv

 Flux.trainable(l::EdgeConv) = (l.nn,)
@@ -470,8 +535,17 @@ end
 - `nn`: A neural network/layer.
 - `eps`: Weighting factor.

-The definition of this is as defined in the original paper,
-Xu et. al. (2018) https://arxiv.org/abs/1810.00826.
+# Examples
+
+```jldoctest
+julia> GINConv(Dense(1024, 256, relu))
+GINConv(Dense(1024, 256, relu), ϵ=0.0)
+
+julia> GINConv(Dense(1024, 256, relu), 1.f-6)
+GINConv(Dense(1024, 256, relu), ϵ=1.0e-6)
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
 struct GINConv{N,R<:Real} <: MessagePassing
     nn::N
@@ -480,6 +554,8 @@ end

 GINConv(nn, eps=0f0) = GINConv(nn, eps)

+@deprecate GINConv(fg, args...; kwargs...) WithGraph(fg, GINConv(args...; kwargs...))
+
 @functor GINConv

 Flux.trainable(g::GINConv) = (g.nn,)
@@ -502,19 +578,31 @@ function (l::GINConv)(el::NamedTuple, x::AbstractArray)
     return V
 end

+function Base.show(io::IO, l::GINConv)
+    print(io, "GINConv(", l.nn, ", ϵ=", l.eps, ")")
+end
+

 """
-    CGConv((node_dim, edge_dim), out, init, bias=true)
+    CGConv((node_dim, edge_dim), init, bias=true)

 Crystal Graph Convolutional network. Uses both node and edge features.

 # Arguments

 - `node_dim`: Dimensionality of the input node features. Also is necessarily the output dimensionality.
 - `edge_dim`: Dimensionality of the input edge features.
-- `out`: Dimensionality of the output features.
 - `init`: Initialization algorithm for each of the weight matrices
 - `bias`: Whether or not to learn an additive bias parameter.
+
+# Examples
+
+```jldoctest
+julia> CGConv((128, 32))
+CGConv(node dim=128, edge dim=32)
+```
+
+See also [`WithGraph`](@ref) for training layer with static graph.
 """
 struct CGConv{A<:AbstractMatrix,B} <: MessagePassing
     Wf::A
@@ -523,6 +611,8 @@ struct CGConv{A<:AbstractMatrix,B} <: MessagePassing
     bs::B
 end

+@deprecate CGConv(fg, args...; kwargs...) WithGraph(fg, CGConv(args...; kwargs...))
+
 @functor CGConv

 Flux.trainable(l::CGConv) = (l.Wf, l.Ws, l.bf, l.bs)
@@ -560,3 +650,9 @@ function (l::CGConv)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
     _, V, _ = propagate(l, el, E, X, nothing, +, nothing, nothing)
     return V
 end
+
+function Base.show(io::IO, l::CGConv)
+    node_dim, d = size(l.Wf)
+    edge_dim = d - 2*node_dim
+    print(io, "CGConv(node dim=", node_dim, ", edge dim=", edge_dim, ")")
+end
````
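
The new `CGConv` `show` method recovers `edge_dim` from the shape of `Wf`: the gate weight acts on the concatenation of the two node feature vectors and the edge feature vector, so `Wf` has `2*node_dim + edge_dim` columns. A quick sanity check of that arithmetic, assuming `CGConv((128, 32))` allocates `Wf` with that shape:

```julia
# Hypothetical shape for CGConv((128, 32)): Wf is node_dim × (2*node_dim + edge_dim).
node_dim, edge_dim = 128, 32
Wf_size = (node_dim, 2 * node_dim + edge_dim)   # (128, 288)

# Invert the relation the same way the new Base.show does.
n, d = Wf_size
@assert d - 2n == edge_dim    # 288 - 2*128 == 32, so "edge dim=32" is printed
```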

src/layers/msgpass.jl
Lines changed: 6 additions & 0 deletions

````diff
@@ -67,3 +67,9 @@ WithGraph(fg::AbstractFeaturedGraph, mp::MessagePassing) =
     WithGraph(to_namedtuple(fg), mp)

 (wg::WithGraph{<:MessagePassing})(args...) = wg.layer(wg.graph, args...)
+
+function Base.show(io::IO, l::WithGraph{<:MessagePassing})
+    print(io, "WithGraph(Graph(#V=", l.graph.N)
+    print(io, ", #E=", l.graph.E, "), ")
+    print(io, l.layer, ")")
+end
````
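
With the `show` method above, printing a graph-bound layer reports the static graph's size along with the wrapped layer. A rough usage sketch (sizes are made up, `FeaturedGraph` is assumed to accept an adjacency matrix, and the exact `#E` count depends on how `to_namedtuple` stores edges):

```julia
using GeometricFlux, Flux

# A small cycle graph as an adjacency matrix (illustrative only).
adj = [0 1 0 1;
       1 0 1 0;
       0 1 0 1;
       1 0 1 0]

wg = WithGraph(FeaturedGraph(adj), EdgeConv(Dense(1024, 256, relu)))

# Goes through the Base.show method added above, printing something like
# WithGraph(Graph(#V=4, #E=...), EdgeConv(Dense(1024, 256, relu), aggr=max))
println(wg)
```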
