Skip to content

Commit 6858f37

Browse files
authored
Merge pull request #141 from yuehhua/gat
Correct type of negative slope in leaky relu again
2 parents 5765e50 + 841c50f commit 6858f37

File tree

3 files changed

+14
-14
lines changed

3 files changed

+14
-14
lines changed

src/graph/simplegraphs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ function GraphConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}, σ=ide
2828
end
2929

3030

31-
function GATConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}; heads=1,
32-
concat::Bool=true, negative_slope=0.2, init=glorot_uniform,
33-
T::DataType=Float32, bias::Bool=true)
31+
function GATConv(g::AbstractSimpleGraph, ch::Pair{<:Integer,<:Integer}; T::DataType=Float32,
32+
heads=1, concat::Bool=true, negative_slope=T(0.2),
33+
init=glorot_uniform, bias::Bool=true)
3434
w = T.(init(ch[2]*heads, ch[1]))
3535
b = bias ? T.(init(ch[2]*heads)) : zeros(T, ch[2]*heads)
3636
a = T.(init(2*ch[2], heads))

src/graph/weightedgraphs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ function GraphConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer}
2929
end
3030

3131

32-
function GATConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer}; heads=1,
33-
concat::Bool=true, negative_slope=0.2, init=glorot_uniform,
34-
T::DataType=Float32, bias::Bool=true)
32+
function GATConv(g::AbstractSimpleWeightedGraph, ch::Pair{<:Integer,<:Integer}; T::DataType=Float32,
33+
heads=1, concat::Bool=true, negative_slope=T(0.2),
34+
init=glorot_uniform, bias::Bool=true)
3535
w = T.(init(ch[2]*heads, ch[1]))
3636
b = bias ? T.(init(ch[2]*heads)) : zeros(T, ch[2]*heads)
3737
a = T.(init(2*ch[2], heads))

src/layers/conv.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,30 +255,30 @@ the layer instead of only the features.
255255
- `bias::Bool=true`: keyword argument, whether to learn the additive bias.
256256
- `negative_slope::Real=0.2`: keyword argument, the parameter of LeakyReLU.
257257
"""
258-
struct GATConv{V<:AbstractFeaturedGraph, T <: Real} <: MessagePassing
258+
struct GATConv{V<:AbstractFeaturedGraph,T<:Real} <: MessagePassing
259259
fg::V
260260
weight::AbstractMatrix{T}
261261
bias::AbstractVector{T}
262262
a::AbstractMatrix{T}
263-
negative_slope::Real
263+
negative_slope::T
264264
channel::Pair{<:Integer,<:Integer}
265265
heads::Integer
266266
concat::Bool
267267
end
268268

269-
function GATConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}; heads::Integer=1,
270-
concat::Bool=true, negative_slope::Real=0.2, init=glorot_uniform,
271-
bias::Bool=true, T::DataType=Float32)
269+
function GATConv(adj::AbstractMatrix, ch::Pair{<:Integer,<:Integer}; T::DataType=Float32,
270+
heads::Integer=1, concat::Bool=true, negative_slope::Real=T(0.2),
271+
init=glorot_uniform, bias::Bool=true)
272272
w = T.(init(ch[2]*heads, ch[1]))
273273
b = bias ? T.(init(ch[2]*heads)) : zeros(T, ch[2]*heads)
274274
a = T.(init(2*ch[2], heads))
275275
fg = FeaturedGraph(adjacency_list(adj))
276276
GATConv(fg, w, b, a, negative_slope, ch, heads, concat)
277277
end
278278

279-
function GATConv(ch::Pair{<:Integer,<:Integer}; heads::Integer=1,
280-
concat::Bool=true, negative_slope::Real=0.2, init=glorot_uniform,
281-
bias::Bool=true, T::DataType=Float32)
279+
function GATConv(ch::Pair{<:Integer,<:Integer}; T::DataType=Float32,
280+
heads::Integer=1, concat::Bool=true, negative_slope::Real=T(0.2),
281+
init=glorot_uniform, bias::Bool=true)
282282
w = T.(init(ch[2]*heads, ch[1]))
283283
b = bias ? T.(init(ch[2]*heads)) : zeros(T, ch[2]*heads)
284284
a = T.(init(2*ch[2], heads))

0 commit comments

Comments
 (0)