Skip to content

Commit dc71ab7

Browse files
committed
fix cuda
1 parent 1c41b77 commit dc71ab7

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

src/layers/graph_conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ function (l::GatedGraphConv)(el::NamedTuple, H::AbstractArray{T}) where {T<:Real
556556
H = vcat(H, Hpad)
557557
end
558558
for i = 1:l.num_layers
559-
M = _matmul(selectdim(l.weight, 3, i), H)
559+
M = _matmul(l.weight[:, :, i], H)
560560
_, M = propagate(l, el, nothing, M, nothing, l.aggr, nothing, nothing)
561561
H = apply_gru(l.gru, H, M)
562562
end

test/cuda/graph_conv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@
149149
@test size(Y) == (out_channel, N, batch_size)
150150

151151
g = gradient(() -> sum(ggc(X |> gpu)), Flux.params(ggc))
152-
@test length(g.grads) == 6
152+
@test length(g.grads) == 7
153153
end
154154
end
155155

0 commit comments

Comments
 (0)