Skip to content

Commit 3635dba

Browse files
authored
Merge pull request #312 from FluxML/ignore
Replace Zygote.ignore with ChainRulesCore.ignore_derivatives
2 parents 1d069f5 + 21cf47c commit 3635dba

File tree

9 files changed

+44
-48
lines changed

9 files changed

+44
-48
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2323
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2424
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2525
Word2Vec = "c64b6f0f-98cd-51d1-af78-58ae84944834"
26-
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2726

2827
[compat]
2928
CUDA = "3"
@@ -40,7 +39,6 @@ Optimisers = "0.2"
4039
Reexport = "1.1"
4140
StatsBase = "0.33"
4241
Word2Vec = "0.5"
43-
Zygote = "0.6"
4442
julia = "1.6"
4543

4644
[extras]

src/GeometricFlux.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ using Flux: glorot_uniform, leakyrelu, GRUCell, @functor
1717
using Graphs
1818
using NNlib, NNlibCUDA
1919
using Optimisers
20-
using Zygote
2120

2221
import Word2Vec: word2vec, wordvectors, get_vector
2322

src/layers/conv.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ end
4545
# For variable graph
4646
function (l::GCNConv)(fg::AbstractFeaturedGraph)
4747
nf = node_feature(fg)
48-
Ã = Zygote.ignore() do
48+
Ã = ChainRulesCore.ignore_derivatives() do
4949
GraphSignals.normalized_adjacency_matrix(fg, eltype(nf); selfloop=true)
5050
end
5151
return ConcreteFeaturedGraph(fg, nf = l(Ã, nf))
@@ -127,7 +127,7 @@ function (l::ChebConv)(fg::AbstractFeaturedGraph)
127127
GraphSignals.check_num_nodes(fg, nf)
128128
@assert size(nf, 1) == size(l.weight, 2) "Input feature size must match input channel size."
129129

130-
L̃ = Zygote.ignore() do
130+
L̃ = ChainRulesCore.ignore_derivatives() do
131131
GraphSignals.scaled_laplacian(fg, eltype(nf))
132132
end
133133
return ConcreteFeaturedGraph(fg, nf = l(L̃, nf))
@@ -331,7 +331,7 @@ function (l::GATConv)(fg::AbstractFeaturedGraph)
331331
X = node_feature(fg)
332332
GraphSignals.check_num_nodes(fg, X)
333333
sg = graph(fg)
334-
@assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
334+
@assert ChainRulesCore.ignore_derivatives(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
335335
el = to_namedtuple(sg)
336336
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
337337
return ConcreteFeaturedGraph(fg, nf=V)
@@ -459,7 +459,7 @@ function (l::GATv2Conv)(fg::AbstractFeaturedGraph)
459459
X = node_feature(fg)
460460
GraphSignals.check_num_nodes(fg, X)
461461
sg = graph(fg)
462-
@assert Zygote.ignore(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
462+
@assert ChainRulesCore.ignore_derivatives(() -> GraphSignals.has_all_self_loops(sg)) "a vertex must have self loop (receive a message from itself)."
463463
el = to_namedtuple(sg)
464464
_, V, _ = propagate(l, el, nothing, X, nothing, hcat, nothing, nothing)
465465
return ConcreteFeaturedGraph(fg, nf=V)
@@ -546,7 +546,7 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real
546546
m, n = size(H)[1:2]
547547
@assert (m <= l.out_ch) "number of input features must less or equals to output features."
548548
if m < l.out_ch
549-
Hpad = Zygote.ignore() do
549+
Hpad = ChainRulesCore.ignore_derivatives() do
550550
fill!(similar(H, T, l.out_ch - m, n, size(H)[3:end]...), 0)
551551
end
552552
H = vcat(H, Hpad)

src/models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ function summarize(ve::VariationalGraphEncoder, X::AbstractArray)
134134
end
135135

136136
function sample(μ::AbstractArray{T}, logσ::AbstractArray{T}) where {T<:Real}
137-
R = Zygote.ignore(() -> randn!(similar(logσ)))
137+
R = ChainRulesCore.ignore_derivatives(() -> randn!(similar(logσ)))
138138
return μ + exp.(logσ) .* R
139139
end
140140

test/cuda/conv.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
fg_ = gc(fg)
2525
@test size(node_feature(fg_)) == (out_channel, N)
2626

27-
g = Zygote.gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc))
27+
g = gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc))
2828
@test length(g.grads) == 4
2929
end
3030

@@ -33,7 +33,7 @@
3333
Y = gc(X |> gpu)
3434
@test size(Y) == (out_channel, N)
3535

36-
g = Zygote.gradient(() -> sum(gc(X |> gpu)), Flux.params(gc))
36+
g = gradient(() -> sum(gc(X |> gpu)), Flux.params(gc))
3737
@test length(g.grads) == 3
3838
end
3939
end
@@ -53,7 +53,7 @@
5353
fg_ = cc(fg)
5454
@test size(node_feature(fg_)) == (out_channel, N)
5555

56-
g = Zygote.gradient(() -> sum(node_feature(cc(fg))), Flux.params(cc))
56+
g = gradient(() -> sum(node_feature(cc(fg))), Flux.params(cc))
5757
@test length(g.grads) == 4
5858
end
5959

@@ -62,7 +62,7 @@
6262
Y = cc(X |> gpu)
6363
@test size(Y) == (out_channel, N)
6464

65-
g = Zygote.gradient(() -> sum(cc(X |> gpu)), Flux.params(cc))
65+
g = gradient(() -> sum(cc(X |> gpu)), Flux.params(cc))
6666
@test length(g.grads) == 3
6767
end
6868
end
@@ -79,7 +79,7 @@
7979
fg_ = gc(fg)
8080
@test size(node_feature(fg_)) == (out_channel, N)
8181

82-
g = Zygote.gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc))
82+
g = gradient(() -> sum(node_feature(gc(fg))), Flux.params(gc))
8383
@test length(g.grads) == 5
8484
end
8585

@@ -89,7 +89,7 @@
8989
Y = gc(X |> gpu)
9090
@test size(Y) == (out_channel, N, batch_size)
9191

92-
g = Zygote.gradient(() -> sum(gc(X |> gpu)), Flux.params(gc))
92+
g = gradient(() -> sum(gc(X |> gpu)), Flux.params(gc))
9393
@test length(g.grads) == 4
9494
end
9595
end
@@ -112,7 +112,7 @@
112112
fg_ = gat(fg)
113113
@test size(node_feature(fg_)) == (out_channel * heads, N)
114114

115-
g = Zygote.gradient(() -> sum(node_feature(gat(fg))), Flux.params(gat))
115+
g = gradient(() -> sum(node_feature(gat(fg))), Flux.params(gat))
116116
@test length(g.grads) == 5
117117
end
118118

@@ -122,7 +122,7 @@
122122
Y = gat(X |> gpu)
123123
@test size(Y) == (out_channel * heads, N, batch_size)
124124

125-
g = Zygote.gradient(() -> sum(gat(X |> gpu)), Flux.params(gat))
125+
g = gradient(() -> sum(gat(X |> gpu)), Flux.params(gat))
126126
@test length(g.grads) == 4
127127
end
128128
end
@@ -138,7 +138,7 @@
138138
fg_ = ggc(fg)
139139
@test size(node_feature(fg_)) == (out_channel, N)
140140

141-
g = Zygote.gradient(() -> sum(node_feature(ggc(fg))), Flux.params(ggc))
141+
g = gradient(() -> sum(node_feature(ggc(fg))), Flux.params(ggc))
142142
@test length(g.grads) == 8
143143
end
144144

@@ -148,7 +148,7 @@
148148
@test_broken Y = ggc(X |> gpu)
149149
@test_broken size(Y) == (out_channel, N, batch_size)
150150

151-
@test_broken g = Zygote.gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
151+
@test_broken g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
152152
@test_broken length(g.grads) == 6
153153
end
154154
end
@@ -162,7 +162,7 @@
162162
fg_ = ec(fg)
163163
@test size(node_feature(fg_)) == (out_channel, N)
164164

165-
g = Zygote.gradient(() -> sum(node_feature(ec(fg))), Flux.params(ec))
165+
g = gradient(() -> sum(node_feature(ec(fg))), Flux.params(ec))
166166
@test length(g.grads) == 4
167167
end
168168

@@ -172,7 +172,7 @@
172172
Y = ec(X |> gpu)
173173
@test size(Y) == (out_channel, N, batch_size)
174174

175-
g = Zygote.gradient(() -> sum(ec(X |> gpu)), Flux.params(ec))
175+
g = gradient(() -> sum(ec(X |> gpu)), Flux.params(ec))
176176
@test length(g.grads) == 3
177177
end
178178
end

test/cuda/msgpass.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
fg_ = l(fg)
4545
@test size(node_feature(fg_)) == (out_channel, N)
4646

47-
g = Zygote.gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
47+
g = gradient(() -> sum(node_feature(l(fg))), Flux.params(l))
4848
@test length(g.grads) == 3
4949
end
5050

@@ -56,7 +56,7 @@
5656
Y = l(X |> gpu)
5757
@test size(Y) == (out_channel, N, batch_size)
5858

59-
g = Zygote.gradient(() -> sum(l(X |> gpu)), Flux.params(l))
59+
g = gradient(() -> sum(l(X |> gpu)), Flux.params(l))
6060
@test length(g.grads) == 2
6161
end
6262
end

0 commit comments

Comments
 (0)