Skip to content

Commit 9e04343

Browse files
committed
fix tutorial
1 parent b90258c commit 9e04343

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

examples/tutorial.jl

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -113,14 +113,15 @@ Thanks to this smoothing, we can now train our model with a standard gradient op
113113

114114
encoder = deepcopy(initial_encoder)
115115
opt = Flux.Adam();
116+
opt_state = Flux.setup(opt, encoder)
116117
losses = Float64[]
117118
for epoch in 1:100
118119
l = 0.0
119120
for (x, y) in zip(X_train, Y_train)
120-
grads = gradient(Flux.params(encoder)) do
121-
l += loss(encoder(x), y; directions=queen_directions)
121+
grads = Flux.gradient(encoder) do m
122+
l += loss(m(x), y; directions=queen_directions)
122123
end
123-
Flux.update!(opt, Flux.params(encoder), grads)
124+
Flux.update!(opt_state, encoder, grads[1])
124125
end
125126
push!(losses, l)
126127
end;

0 commit comments

Comments
 (0)