Skip to content

Commit 63706e7

Browse files
authored
Merge pull request #282 from FluxML/gat
Memory pre-allocation fix for GAT example
2 parents b6b6717 + 07b8b56 commit 63706e7

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

examples/gat.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ function train(; kws...)
7979
# build model
8080
model = Chain(
8181
WithGraph(fg, GATConv(args.input_dim=>args.hidden_dim, heads=args.heads)),
82+
Dropout(0.6),
8283
WithGraph(fg, GATConv(args.hidden_dim*args.heads=>args.target_dim, heads=args.heads, concat=false)),
8384
) |> device
8485

src/operation.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,12 @@ function incidence_matrix(xs::AbstractVector{T}, N) where {T}
2929
end
3030

3131
"""
    indexed_softmax(x::AbstractArray, xs, N; dims=1)

Apply a softmax to `x` separately within each of `N` index groups.
`xs` assigns a group id in `1:N` to every position along dimension `dims`;
the softmax for group `i` is computed only over the slices where `xs .== i`.
Positions whose group id is outside `1:N` are left as copies of `x`.

Returns a new array; `x` is not modified.
"""
function indexed_softmax(x::AbstractArray, xs, N; dims=1)
    # Work on a copy so each group's softmax can be applied in place
    # without allocating per-group temporaries (memory pre-allocation).
    out = copy(x)
    for group in 1:N
        # Select the slices along `dims` belonging to this group;
        # every other dimension is taken whole.
        sel = ntuple(d -> d == dims ? (xs .== group) : Colon(), ndims(out))
        # In-place softmax on the logical-indexed view of this group.
        NNlib.softmax!(view(out, sel...); dims)
    end
    return out
end
4039

4140
function ∇indexed_softmax(dy::AbstractArray{T}, y::AbstractArray{S}, xs, N; dims=1) where {T,S}

0 commit comments

Comments (0)