Skip to content

Commit 23ba06b

Browse files
authored
Merge pull request #64 from yuehhua/develop
GCN example works again
2 parents 493b4bb + 2e93878 commit 23ba06b

File tree

4 files changed

+32
-15
lines changed

4 files changed

+32
-15
lines changed

examples/gcn.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using GeometricFlux
22
using Flux
3-
using Flux: onehotbatch, onecold, crossentropy, throttle
3+
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
4+
using Flux: @epochs
45
using JLD2 # use v0.1.2
56
using Statistics: mean
67
using SparseArrays
@@ -31,7 +32,7 @@ model = Chain(GCNConv(adj_mat, num_features=>hidden, relu),
3132
softmax) |> gpu
3233

3334
## Loss
34-
loss(x, y) = crossentropy(model(x), y)
35+
loss(x, y) = logitcrossentropy(model(x), y)
3536
accuracy(x, y) = mean(onecold(model(x)) .== onecold(y))
3637

3738
## Training
@@ -40,6 +41,4 @@ train_data = [(train_X, train_y)]
4041
opt = ADAM(0.01)
4142
evalcb() = @show(accuracy(train_X, train_y))
4243

43-
for i = 1:epochs
44-
Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
45-
end
44+
@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))

src/layers/conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ end
6262
function Base.show(io::IO, l::GCNConv)
6363
in_channel = size(l.weight, ndims(l.weight))
6464
out_channel = size(l.weight, ndims(l.weight)-1)
65-
print(io, "GCNConv(G(V=", nv(l.graph))
65+
print(io, "GCNConv(G(V=", nv(l.fg))
6666
print(io, ", E), ", in_channel, "=>", out_channel)
6767
print(io, "GCNConv(", in_channel, "=>", out_channel)
6868
l.σ == identity || print(io, ", ", l.σ)

src/operations/linalg.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ Normalized Laplacian matrix of graph `g`.
120120
"""
121121
function normalized_laplacian(adj::AbstractMatrix, T::DataType=eltype(adj); selfloop::Bool=false)
122122
selfloop && (adj += I)
123+
_normalized_laplacian(adj, T)
124+
end
125+
126+
# nograd can only used without keyword arguments
127+
Zygote.@nograd function _normalized_laplacian(adj::AbstractMatrix, T::DataType=eltype(adj))
123128
inv_sqrtD = inv_sqrt_degree_matrix(adj, T, dir=:both)
124129
T.(I - inv_sqrtD * adj * inv_sqrtD)
125130
end

test/operations/linalg.jl

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,27 @@
1616
-.5 1. -.5 0.;
1717
0. -.5 1. -.5;
1818
-.5 0. -.5 1.]
19-
scaled_lap = [0 -0.5 0 -0.5;
20-
-0.5 0 -0.5 -0;
21-
0 -0.5 0 -0.5;
19+
scaled_lap = [0 -0.5 0 -0.5;
20+
-0.5 0 -0.5 -0;
21+
0 -0.5 0 -0.5;
2222
-0.5 0 -0.5 0]
2323

24-
for T in [Int8, Float64]
24+
for T in [Int8, Int16, Int32, Int64, Float16, Float32, Float64]
2525
@test degree_matrix(adj, T, dir=:out) == T.(deg)
2626
@test degree_matrix(adj, T, dir=:out) == degree_matrix(adj, T, dir=:in)
2727
@test degree_matrix(adj, T, dir=:out) == degree_matrix(adj, T, dir=:both)
28+
@test eltype(degree_matrix(adj, T, dir=:out)) == T
29+
2830
@test laplacian_matrix(adj, T) == T.(lap)
31+
@test eltype(laplacian_matrix(adj, T)) == T
32+
end
33+
for T in [Float16, Float32, Float64]
34+
@test normalized_laplacian(adj, T) T.(norm_lap)
35+
@test eltype(normalized_laplacian(adj, T)) == T
36+
37+
@test scaled_laplacian(adj, T) T.(scaled_lap)
38+
@test eltype(scaled_laplacian(adj, T)) == T
2939
end
30-
@test normalized_laplacian(adj, Float64) norm_lap
31-
@test eltype(normalized_laplacian(adj, Float32)) == Float32
32-
@test scaled_laplacian(adj, Float64) scaled_lap
33-
@test eltype(scaled_laplacian(adj, Float32)) == Float32
3440
@test neighbors(adj) == [[2,4], [1,3], [2,4], [1,3]]
3541
end
3642

@@ -52,14 +58,21 @@
5258
0 0 7 0;
5359
0 0 0 4]
5460

55-
for T in [Int8, Float64]
61+
for T in [Int8, Int16, Int32, Int64, Float16, Float32, Float64]
5662
@test degree_matrix(adj, T, dir=:out) == T.(deg_out)
5763
@test degree_matrix(adj, T, dir=:in) == T.(deg_in)
5864
@test degree_matrix(adj, T, dir=:both) == T.(deg_both)
65+
@test eltype(degree_matrix(adj, T, dir=:out)) == T
66+
@test eltype(degree_matrix(adj, T, dir=:in)) == T
67+
@test eltype(degree_matrix(adj, T, dir=:both)) == T
5968
@test_throws DomainError degree_matrix(adj, dir=:other)
69+
6070
@test laplacian_matrix(adj, T, dir=:out) == T.(deg_out .- adj)
6171
@test laplacian_matrix(adj, T, dir=:in) == T.(deg_in .- adj)
6272
@test laplacian_matrix(adj, T, dir=:both) == T.(deg_both .- adj)
73+
@test eltype(laplacian_matrix(adj, T, dir=:out)) == T
74+
@test eltype(laplacian_matrix(adj, T, dir=:in)) == T
75+
@test eltype(laplacian_matrix(adj, T, dir=:both)) == T
6376
end
6477
@test neighbors(adj) == [[2,3,4], [1,3], [1,2,4], [1,3]]
6578
end

0 commit comments

Comments
 (0)