Skip to content

Commit 2e93878

Browse files
committed
GCN example works
1 parent 6ba5cbf commit 2e93878

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

examples/gcn.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using GeometricFlux
22
using Flux
33
using Flux: onehotbatch, onecold, logitcrossentropy, throttle
4+
using Flux: @epochs
45
using JLD2 # use v0.1.2
56
using Statistics: mean
67
using SparseArrays
@@ -40,6 +41,4 @@ train_data = [(train_X, train_y)]
4041
opt = ADAM(0.01)
4142
evalcb() = @show(accuracy(train_X, train_y))
4243

43-
for i = 1:epochs
44-
Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))
45-
end
44+
@epochs epochs Flux.train!(loss, ps, train_data, opt, cb=throttle(evalcb, 10))

src/operations/linalg.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,11 @@ Normalized Laplacian matrix of graph `g`.
120120
"""
121121
function normalized_laplacian(adj::AbstractMatrix, T::DataType=eltype(adj); selfloop::Bool=false)
122122
selfloop && (adj += I)
123+
_normalized_laplacian(adj, T)
124+
end
125+
126+
# nograd can only used without keyword arguments
127+
Zygote.@nograd function _normalized_laplacian(adj::AbstractMatrix, T::DataType=eltype(adj))
123128
inv_sqrtD = inv_sqrt_degree_matrix(adj, T, dir=:both)
124129
T.(I - inv_sqrtD * adj * inv_sqrtD)
125130
end

0 commit comments

Comments
 (0)