Skip to content

Commit 4d8c3ff

Browse files
authored
Merge pull request #120 from JuliaDecisionFocusedLearning/perturbed-overhaul
General overhaul
2 parents 9419111 + cd594ee commit 4d8c3ff

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+1012
-1225
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
.DS_Store
12
*.jl.*.cov
23
*.jl.cov
34
*.jl.mem

CITATION.bib

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ @misc{InferOpt.jl
22
author = {Guillaume Dalle, Léo Baty, Louis Bouvier and Axel Parmentier},
33
title = {InferOpt.jl},
44
url = {https://github.com/axelparmentier/InferOpt.jl},
5-
version = {v0.6.1},
6-
year = {2023},
7-
month = {9}
5+
version = {v0.7.0},
6+
year = {2025},
7+
month = {4}
88
}

Project.toml

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,47 @@
11
name = "InferOpt"
22
uuid = "4846b161-c94e-4150-8dac-c7ae193c601f"
33
authors = ["Guillaume Dalle", "Léo Baty", "Louis Bouvier", "Axel Parmentier"]
4-
version = "0.6.1"
4+
version = "0.7.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
88
DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d"
9+
DifferentiableExpectations = "fc55d66b-b2a8-4ccc-9d64-c0c2166ceb36"
910
DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
11+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
12+
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
13+
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
14+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
1015
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1116
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1217
RequiredInterfaces = "97f35ef4-7bc5-4ec1-a41a-dcc69c7308c6"
1318
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1419
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1520
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
16-
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1721

1822
[weakdeps]
1923
DifferentiableFrankWolfe = "b383313e-5450-4164-a800-befbd27b574d"
24+
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
25+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
2026

2127
[extensions]
22-
InferOptFrankWolfeExt = "DifferentiableFrankWolfe"
28+
InferOptFrankWolfeExt = ["DifferentiableFrankWolfe", "FrankWolfe", "ImplicitDifferentiation"]
2329

2430
[compat]
2531
ChainRulesCore = "1"
2632
DensityInterface = "0.4.0"
27-
DifferentiableFrankWolfe = "0.2"
33+
DifferentiableExpectations = "0.2"
34+
DifferentiableFrankWolfe = "0.3"
35+
Distributions = "0.25"
36+
DocStringExtensions = "0.9"
37+
FrankWolfe = "0.3"
38+
ImplicitDifferentiation = "0.6"
2839
LinearAlgebra = "1"
2940
Random = "1"
3041
RequiredInterfaces = "0.1.3"
3142
Statistics = "1"
3243
StatsBase = "0.33, 0.34"
3344
StatsFuns = "1.3"
34-
ThreadsX = "0.1.11"
3545
julia = "1.10"
3646

3747
[extras]
@@ -45,6 +55,7 @@ FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
4555
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
4656
GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb"
4757
HiGHS = "87dc4568-4c63-4d18-b0c0-bb2238e4078b"
58+
ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207"
4859
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4960
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
5061
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
@@ -61,4 +72,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
6172
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
6273

6374
[targets]
64-
test = ["Aqua", "DifferentiableFrankWolfe", "Distributions", "Documenter", "FiniteDifferences", "Flux", "FrankWolfe", "Graphs", "GridGraphs", "HiGHS", "JET", "JuliaFormatter", "JuMP", "LinearAlgebra", "Literate", "Pkg", "ProgressMeter", "Random", "Revise", "Statistics", "Test", "TestItemRunner", "UnicodePlots", "Zygote"]
75+
test = ["Aqua", "DifferentiableFrankWolfe", "Distributions", "Documenter", "FiniteDifferences", "Flux", "FrankWolfe", "Graphs", "GridGraphs", "HiGHS", "ImplicitDifferentiation", "JET", "JuliaFormatter", "JuMP", "LinearAlgebra", "Literate", "Pkg", "ProgressMeter", "Random", "Revise", "Statistics", "Test", "TestItemRunner", "UnicodePlots", "Zygote"]

docs/make.jl

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,20 @@ DocMeta.setdocmeta!(InferOpt, :DocTestSetup, :(using InferOpt); recursive=true)
66

77
# Copy README.md into docs/src/index.md (overwriting)
88

9-
cp(
10-
joinpath(dirname(@__DIR__), "README.md"),
11-
joinpath(@__DIR__, "src", "index.md");
12-
force=true,
13-
)
9+
open(joinpath(@__DIR__, "src", "index.md"), "w") do io
10+
println(
11+
io,
12+
"""
13+
```@meta
14+
EditURL = "https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/main/README.md"
15+
```
16+
""",
17+
)
18+
# Write the contents out below the meta bloc
19+
for line in eachline(joinpath(dirname(@__DIR__), "README.md"))
20+
println(io, line)
21+
end
22+
end
1423

1524
# Parse test/tutorial.jl into docs/src/tutorial.md (overwriting)
1625

@@ -21,8 +30,14 @@ Literate.markdown(tuto_jl_file, tuto_md_dir; documenter=true, execute=false)
2130
makedocs(;
2231
modules=[InferOpt],
2332
authors="Guillaume Dalle, Léo Baty, Louis Bouvier, Axel Parmentier",
33+
repo="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl/blob/{commit}{path}#{line}",
2434
sitename="InferOpt.jl",
25-
format=Documenter.HTML(),
35+
format=Documenter.HTML(;
36+
prettyurls=get(ENV, "CI", "false") == "true",
37+
canonical="https://juliadecisionfocusedlearning.github.io/InferOpt.jl",
38+
assets=String[],
39+
repolink="https://github.com/JuliaDecisionFocusedLearning/InferOpt.jl",
40+
),
2641
pages=[
2742
"Home" => "index.md",
2843
"Background" => "background.md",

ext/InferOptFrankWolfeExt.jl

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,31 @@
11
module InferOptFrankWolfeExt
22

3+
using DifferentiableExpectations:
4+
DifferentiableExpectations, FixedAtomsProbabilityDistribution
35
using DifferentiableFrankWolfe: DifferentiableFrankWolfe, DiffFW
4-
using DifferentiableFrankWolfe: LinearMinimizationOracle # from FrankWolfe
5-
using DifferentiableFrankWolfe: IterativeLinearSolver # from ImplicitDifferentiation
6-
using InferOpt: InferOpt, RegularizedFrankWolfe, FixedAtomsProbabilityDistribution
7-
using InferOpt: compute_expectation, compute_probability_distribution
6+
using FrankWolfe: LinearMinimizationOracle
7+
using ImplicitDifferentiation: KrylovLinearSolver
8+
using InferOpt: InferOpt, RegularizedFrankWolfe
89
using LinearAlgebra: dot
910

11+
"""
12+
RegularizedFrankWolfe(linear_maximizer; Ω, Ω_grad, frank_wolfe_kwargs=(;), implicit_kwargs=(; linear_solver=KrylovLinearSolver(; verbose=false)))
13+
14+
Construct a `RegularizedFrankWolfe` struct with a linear maximizer and the necessary components for the Frank-Wolfe algorithm.
15+
Set `implicit_kwargs` to `(; linear_solver=KrylovLinearSolver(; verbose=true))` if you want to see the solver potential warnings.
16+
"""
17+
function RegularizedFrankWolfe(
18+
linear_maximizer;
19+
Ω,
20+
Ω_grad,
21+
frank_wolfe_kwargs=NamedTuple(),
22+
implicit_kwargs=(; linear_solver=KrylovLinearSolver(; verbose=false)),
23+
)
24+
return RegularizedFrankWolfe(
25+
linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs
26+
)
27+
end
28+
1029
"""
1130
LinearMaximizationOracleWithKwargs{F,K}
1231
Wraps a linear maximizer as a `FrankWolfe.LinearMinimizationOracle` with a sign switch and predefined keyword arguments.
@@ -40,14 +59,17 @@ Keyword arguments are passed to the underlying linear maximizer.
4059
function InferOpt.compute_probability_distribution(
4160
regularized::RegularizedFrankWolfe, θ::AbstractArray; kwargs...
4261
)
43-
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
62+
shape = size(θ)
63+
(; linear_maximizer, Ω, Ω_grad, frank_wolfe_kwargs, implicit_kwargs) = regularized
4464
f(y, θ) = Ω(y) - dot(θ, y)
4565
f_grad1(y, θ) = Ω_grad(y) - θ
46-
lmo = LinearMaximizationOracleWithKwargs(linear_maximizer, kwargs)
47-
implicit_kwargs = (; linear_solver=IterativeLinearSolver(; accept_inconsistent=true))
66+
maximizer(θ; shape, kwargs...) = vec(linear_maximizer(reshape(θ, shape); kwargs...))
67+
lmo = LinearMaximizationOracleWithKwargs(maximizer, (; shape, kwargs...))
4868
dfw = DiffFW(f, f_grad1, lmo; implicit_kwargs)
49-
weights, atoms = dfw.implicit(θ; frank_wolfe_kwargs=frank_wolfe_kwargs)
50-
probadist = FixedAtomsProbabilityDistribution(atoms, weights)
69+
weights, atoms = dfw.implicit(vec(θ); frank_wolfe_kwargs=frank_wolfe_kwargs)
70+
probadist = FixedAtomsProbabilityDistribution(
71+
map(atom -> reshape(atom, shape), atoms), weights
72+
)
5173
return probadist
5274
end
5375

src/InferOpt.jl

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,71 +10,80 @@ module InferOpt
1010
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, Tangent, ZeroTangent
1111
using ChainRulesCore: rrule, rrule_via_ad, unthunk
1212
using DensityInterface: logdensityof
13+
using DifferentiableExpectations:
14+
DifferentiableExpectations, Reinforce, empirical_predistribution, empirical_distribution
15+
using Distributions:
16+
Distributions,
17+
ContinuousUnivariateDistribution,
18+
LogNormal,
19+
Normal,
20+
product_distribution,
21+
logpdf
22+
using DocStringExtensions: TYPEDEF, TYPEDFIELDS, TYPEDSIGNATURES
1323
using LinearAlgebra: dot
14-
using Random: AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
24+
using Random: Random, AbstractRNG, GLOBAL_RNG, MersenneTwister, rand, seed!
1525
using Statistics: mean
1626
using StatsBase: StatsBase, sample
1727
using StatsFuns: logaddexp, softmax
18-
using ThreadsX: ThreadsX
1928
using RequiredInterfaces
2029

2130
include("interface.jl")
2231

32+
include("utils/utils.jl")
2333
include("utils/some_functions.jl")
24-
include("utils/probability_distribution.jl")
2534
include("utils/pushforward.jl")
26-
include("utils/generalized_maximizer.jl")
35+
include("utils/linear_maximizer.jl")
2736
include("utils/isotonic_regression/isotonic_l2.jl")
2837
include("utils/isotonic_regression/isotonic_kl.jl")
2938
include("utils/isotonic_regression/projection.jl")
3039

31-
include("simple/interpolation.jl")
32-
include("simple/identity.jl")
40+
# Layers
41+
include("layers/simple/interpolation.jl")
42+
include("layers/simple/identity.jl")
3343

34-
include("regularized/abstract_regularized.jl")
35-
include("regularized/soft_argmax.jl")
36-
include("regularized/sparse_argmax.jl")
37-
include("regularized/soft_rank.jl")
38-
include("regularized/regularized_frank_wolfe.jl")
44+
include("layers/perturbed/utils.jl")
45+
include("layers/perturbed/perturbation.jl")
46+
include("layers/perturbed/perturbed.jl")
3947

40-
include("perturbed/abstract_perturbed.jl")
41-
include("perturbed/additive.jl")
42-
include("perturbed/multiplicative.jl")
43-
include("perturbed/perturbed_oracle.jl")
44-
45-
include("imitation/spoplus_loss.jl")
46-
include("imitation/ssvm_loss.jl")
47-
include("imitation/fenchel_young_loss.jl")
48-
include("imitation/imitation_loss.jl")
49-
include("imitation/zero_one_loss.jl")
48+
include("layers/regularized/abstract_regularized.jl")
49+
include("layers/regularized/soft_argmax.jl")
50+
include("layers/regularized/sparse_argmax.jl")
51+
include("layers/regularized/soft_rank.jl")
52+
include("layers/regularized/regularized_frank_wolfe.jl")
5053

5154
if !isdefined(Base, :get_extension)
5255
include("../ext/InferOptFrankWolfeExt.jl")
5356
end
5457

58+
# Losses
59+
include("losses/fenchel_young_loss.jl")
60+
include("losses/spoplus_loss.jl")
61+
include("losses/ssvm_loss.jl")
62+
include("losses/zero_one_loss.jl")
63+
include("losses/imitation_loss.jl")
64+
65+
export compute_probability_distribution
66+
5567
export half_square_norm
5668
export shannon_entropy, negative_shannon_entropy
5769
export one_hot_argmax, ranking
58-
export GeneralizedMaximizer, objective_value
70+
export LinearMaximizer, apply_g, objective_value
5971

60-
export FixedAtomsProbabilityDistribution
61-
export compute_expectation
62-
export compute_probability_distribution
6372
export Pushforward
6473

6574
export IdentityRelaxation
6675
export Interpolation
6776

68-
export AbstractRegularized, AbstractRegularizedGeneralizedMaximizer
77+
export AbstractRegularized
6978
export SoftArgmax, soft_argmax
7079
export SparseArgmax, sparse_argmax
7180
export SoftRank, soft_rank, soft_rank_l2, soft_rank_kl
7281
export SoftSort, soft_sort, soft_sort_l2, soft_sort_kl
7382
export RegularizedFrankWolfe
7483

84+
export PerturbedOracle
7585
export PerturbedAdditive
7686
export PerturbedMultiplicative
77-
export PerturbedOracle
7887

7988
export FenchelYoungLoss
8089
export StructuredSVMLoss

0 commit comments

Comments
 (0)