Skip to content

Commit 343c95e

Browse files
authored
Merge pull request #116 from JuliaDecisionFocusedLearning/frankwolfe-tests
Replace Agnostic with Adaptive in FrankWolfe tests
2 parents 59ab4fe + 9e04343 commit 343c95e

File tree

4 files changed

+13
-12
lines changed

4 files changed

+13
-12
lines changed

examples/tutorial.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,15 @@ Thanks to this smoothing, we can now train our model with a standard gradient op
113113

114114
encoder = deepcopy(initial_encoder)
115115
opt = Flux.Adam();
116+
opt_state = Flux.setup(opt, encoder)
116117
losses = Float64[]
117118
for epoch in 1:100
118119
l = 0.0
119120
for (x, y) in zip(X_train, Y_train)
120-
grads = gradient(Flux.params(encoder)) do
121-
l += loss(encoder(x), y; directions=queen_directions)
121+
grads = Flux.gradient(encoder) do m
122+
l += loss(m(x), y; directions=queen_directions)
122123
end
123-
Flux.update!(opt, Flux.params(encoder), grads)
124+
Flux.update!(opt_state, encoder, grads[1])
124125
end
125126
push!(losses, l)
126127
end;

test/argmax.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ end
116116
one_hot_argmax;
117117
Ω=half_square_norm,
118118
Ω_grad=identity_kw,
119-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
119+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
120120
),
121121
loss=mse_kw,
122122
error_function=hamming_distance,
@@ -198,7 +198,7 @@ end
198198
one_hot_argmax;
199199
Ω=half_square_norm,
200200
Ω_grad=identity_kw,
201-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
201+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
202202
),
203203
),
204204
error_function=hamming_distance,
@@ -263,7 +263,7 @@ end
263263
one_hot_argmax;
264264
Ω=half_square_norm,
265265
Ω_grad=identity_kw,
266-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
266+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
267267
),
268268
cost,
269269
),

test/paths.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ end
101101
shortest_path_maximizer;
102102
Ω=half_square_norm,
103103
Ω_grad=identity_kw,
104-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
104+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
105105
),
106106
loss=mse_kw,
107107
error_function=mse_kw,
@@ -177,7 +177,7 @@ end
177177
shortest_path_maximizer;
178178
Ω=half_square_norm,
179179
Ω_grad=identity_kw,
180-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
180+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
181181
),
182182
),
183183
error_function=mse_kw,
@@ -247,7 +247,7 @@ end
247247
shortest_path_maximizer;
248248
Ω=half_square_norm,
249249
Ω_grad=identity_kw,
250-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
250+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
251251
),
252252
cost,
253253
),

test/ranking.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ end
101101
ranking;
102102
Ω=half_square_norm,
103103
Ω_grad=identity_kw,
104-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
104+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
105105
),
106106
loss=mse_kw,
107107
error_function=hamming_distance,
@@ -170,7 +170,7 @@ end
170170
ranking;
171171
Ω=half_square_norm,
172172
Ω_grad=identity_kw,
173-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
173+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
174174
),
175175
),
176176
error_function=hamming_distance,
@@ -303,7 +303,7 @@ end
303303
ranking;
304304
Ω=half_square_norm,
305305
Ω_grad=identity_kw,
306-
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Agnostic()),
306+
frank_wolfe_kwargs=(; max_iteration=10, line_search=FrankWolfe.Adaptive()),
307307
),
308308
cost,
309309
),

0 commit comments

Comments
 (0)