Skip to content

Commit 5584bbe

Browse files
authored
Merge pull request #97 from axelparmentier/fix-fyl-gm
Fix missing kwargs in FenchelYoungLoss with generalized maximizer
2 parents 0493719 + 48aec0d commit 5584bbe

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

src/imitation/fenchel_young_loss.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ function fenchel_young_F_and_first_part_of_grad(
141141
η = θ .+ ε .* Z
142142
y = oracle(η; kwargs...)
143143
F = objective_value(oracle, η, y; kwargs...)
144-
return F, oracle.g(y)
144+
return F, oracle.g(y; kwargs...)
145145
end
146146

147147
function fenchel_young_F_and_first_part_of_grad(
@@ -165,5 +165,5 @@ function fenchel_young_F_and_first_part_of_grad(
165165
y = oracle(η; kwargs...)
166166
F = objective_value(oracle, η, y; kwargs...)
167167
y_scaled = y .* eZ
168-
return F, oracle.g(y_scaled)
168+
return F, oracle.g(y_scaled; kwargs...)
169169
end

test/InferOptTestUtils/src/maximizers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function max_pricing(θ::AbstractVector; instance::AbstractMatrix)
1212
return weights .>= 0
1313
end
1414

15-
g(y; kwargs...) = vec(sum(y; dims=2))
15+
g(y; instance, kwargs...) = vec(sum(y; dims=2))
1616
h(y; instance) = -sum(dij * yij for (dij, yij) in zip(instance, y))
1717

1818
identity_kw(x; kwargs...) = identity(x)

test/generalized_maximizer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
val = InferOpt.objective_value(generalized_maximizer, θ, y; instance)
2121

22-
@test val == θ' * g(y) + h(y; instance)
22+
@test val == θ' * g(y; instance) + h(y; instance)
2323
end
2424

2525
@testitem "Generalized maximizer - imit - MSE PerturbedAdditive" default_imports = false begin

0 commit comments

Comments
 (0)