@@ -27,16 +27,16 @@ struct GCNConv{T,F,S<:AbstractFeaturedGraph}
27
27
end
28
28
29
29
function GCNConv (ch:: Pair{<:Integer,<:Integer} , σ = identity;
30
- init= glorot_uniform, T:: DataType = Float32, bias:: Bool = true , cache :: Bool = true )
30
+ init= glorot_uniform, T:: DataType = Float32, bias:: Bool = true )
31
31
b = bias ? T .(init (ch[2 ])) : zeros (T, ch[2 ])
32
- fg = cache ? FeaturedGraph () : NullGraph ()
32
+ fg = NullGraph ()
33
33
GCNConv (T .(init (ch[2 ], ch[1 ])), b, σ, fg)
34
34
end
35
35
36
36
function GCNConv (adj:: AbstractMatrix , ch:: Pair{<:Integer,<:Integer} , σ = identity;
37
- init= glorot_uniform, T:: DataType = Float32, bias:: Bool = true , cache :: Bool = true )
37
+ init= glorot_uniform, T:: DataType = Float32, bias:: Bool = true )
38
38
b = bias ? T .(init (ch[2 ])) : zeros (T, ch[2 ])
39
- fg = cache ? FeaturedGraph (adj) : NullGraph ( )
39
+ fg = FeaturedGraph (adj)
40
40
GCNConv (T .(init (ch[2 ], ch[1 ])), b, σ, fg)
41
41
end
42
42
56
56
57
57
function (g:: GCNConv )(fg:: FeaturedGraph )
58
58
X = node_feature (fg)
59
- A = adjacency_matrix (fg)
60
- g. fg isa NullGraph || (g. fg. graph = A)
59
+ A = adjacency_matrix (fg) # TODO : choose graph from g or fg
60
+ Zygote. ignore () do
61
+ g. fg isa NullGraph || (g. fg. graph = A)
62
+ end
61
63
X_ = g (A, X)
62
64
FeaturedGraph (A, X_)
63
65
end
@@ -97,16 +99,16 @@ struct ChebConv{T,S<:AbstractFeaturedGraph}
97
99
end
98
100
99
101
function ChebConv (adj:: AbstractMatrix , ch:: Pair{<:Integer,<:Integer} , k:: Integer ;
100
- init = glorot_uniform, T:: DataType = Float32, bias:: Bool = true , cache :: Bool = true )
102
+ init = glorot_uniform, T:: DataType = Float32, bias:: Bool = true )
101
103
b = bias ? init (ch[2 ]) : zeros (T, ch[2 ])
102
- fg = cache ? FeaturedGraph (adj) : NullGraph ( )
104
+ fg = FeaturedGraph (adj)
103
105
ChebConv (init (ch[2 ], ch[1 ], k), b, fg, k, ch[1 ], ch[2 ])
104
106
end
105
107
106
108
function ChebConv (ch:: Pair{<:Integer,<:Integer} , k:: Integer ;
107
- init = glorot_uniform, T:: DataType = Float32, bias:: Bool = true , cache :: Bool = true )
109
+ init = glorot_uniform, T:: DataType = Float32, bias:: Bool = true )
108
110
b = bias ? init (ch[2 ]) : zeros (T, ch[2 ])
109
- fg = cache ? FeaturedGraph () : NullGraph ()
111
+ fg = NullGraph ()
110
112
ChebConv (init (ch[2 ], ch[1 ], k), b, fg, k, ch[1 ], ch[2 ])
111
113
end
112
114
138
140
function (c:: ChebConv )(fg:: FeaturedGraph )
139
141
@assert has_graph (fg) " A given FeaturedGraph must contain a graph."
140
142
g = graph (fg)
141
- c. fg isa NullGraph || (c. fg. graph = g)
143
+ Zygote. ignore () do
144
+ c. fg isa NullGraph || (c. fg. graph = g)
145
+ end
142
146
X = node_feature (fg)
143
147
L̃ = scaled_laplacian (adjacency_matrix (fg))
144
148
L̃ = convert (typeof (X), L̃)
0 commit comments