Skip to content

Commit 80d5fe1

Browse files
authored
Merge pull request #31 from axelparmentier/clean-parallel-perturbed
Parallelize perturbed with ThreadsX
2 parents 9629f21 + 8103978 commit 80d5fe1

File tree

7 files changed

+130
-16
lines changed

7 files changed

+130
-16
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1515
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1616
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1717
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
18+
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
1819

1920
[compat]
2021
ChainRulesCore = "1"
@@ -23,6 +24,7 @@ Krylov = "0.8"
2324
LinearOperators = "2.3"
2425
SimpleTraits = "0.9"
2526
StatsBase = "0.33"
27+
ThreadsX = "0.1.11"
2628
julia = "1.7"
2729

2830
[extras]

src/InferOpt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using SparseArrays
1414
using Statistics
1515
using StatsBase: StatsBase, sample
1616
using Test
17+
using ThreadsX
1718

1819
include("utils/probability_distribution.jl")
1920
include("utils/pushforward.jl")

src/fenchel_young/perturbed.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,26 @@
1-
function fenchel_young_F_and_first_part_of_grad(
2-
perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; kwargs...
1+
function compute_F_and_y_samples(
2+
perturbed::AbstractPerturbed{false}, θ::AbstractArray{<:Real}, Z_samples; kwargs...
33
)
4-
Z_samples = sample_perturbations(perturbed, θ)
54
F_and_y_samples = [
65
fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for
76
Z in Z_samples
87
]
8+
return F_and_y_samples
9+
end
10+
11+
function compute_F_and_y_samples(
12+
perturbed::AbstractPerturbed{true}, θ::AbstractArray{<:Real}, Z_samples; kwargs...
13+
)
14+
return ThreadsX.map(
15+
Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples
16+
)
17+
end
18+
19+
function fenchel_young_F_and_first_part_of_grad(
20+
perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; kwargs...
21+
)
22+
Z_samples = sample_perturbations(perturbed, θ)
23+
F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...)
924
return mean(first, F_and_y_samples), mean(last, F_and_y_samples)
1025
end
1126

src/perturbed/abstract_perturbed.jl

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""
2-
AbstractPerturbed
2+
AbstractPerturbed{parallel}
33
44
Differentiable perturbation of a black box optimizer.
5+
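The type parameter `parallel` is a `Bool`: `true` if the computations for the individual perturbation samples are run in parallel (multithreaded with `ThreadsX.map`), `false` if they are run sequentially.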
56
67
# Applicable functions
78
@@ -21,7 +22,7 @@ These subtypes share the following fields:
2122
- `rng::AbstractRNG`: random number generator
2223
- `seed::Union{Nothing,Int}`: random seed
2324
"""
24-
abstract type AbstractPerturbed end
25+
abstract type AbstractPerturbed{parallel} end
2526

2627
"""
2728
sample_perturbations(perturbed::AbstractPerturbed, θ)
@@ -35,13 +36,31 @@ function sample_perturbations(perturbed::AbstractPerturbed, θ::AbstractArray{<:
3536
return Z_samples
3637
end
3738

39+
function compute_atoms(
40+
perturbed::AbstractPerturbed{false},
41+
θ::AbstractArray{<:Real},
42+
Z_samples::Vector{<:AbstractArray{<:Real}};
43+
kwargs...,
44+
)
45+
return [perturb_and_optimize(perturbed, θ, Z; kwargs...) for Z in Z_samples]
46+
end
47+
48+
function compute_atoms(
49+
perturbed::AbstractPerturbed{true},
50+
θ::AbstractArray{<:Real},
51+
Z_samples::Vector{<:AbstractArray{<:Real}};
52+
kwargs...,
53+
)
54+
return ThreadsX.map(Z -> perturb_and_optimize(perturbed, θ, Z; kwargs...), Z_samples)
55+
end
56+
3857
function compute_probability_distribution(
3958
perturbed::AbstractPerturbed,
4059
θ::AbstractArray{<:Real},
4160
Z_samples::Vector{<:AbstractArray{<:Real}};
4261
kwargs...,
4362
)
44-
atoms = [perturb_and_optimize(perturbed, θ, Z; kwargs...) for Z in Z_samples]
63+
atoms = compute_atoms(perturbed, θ, Z_samples; kwargs...)
4564
weights = ones(length(atoms)) ./ length(atoms)
4665
probadist = FixedAtomsProbabilityDistribution(atoms, weights)
4766
return probadist

src/perturbed/additive.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@ See also: [`AbstractPerturbed`](@ref).
77
88
Reference: <https://arxiv.org/abs/2002.08676>
99
"""
10-
struct PerturbedAdditive{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed
10+
struct PerturbedAdditive{F,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <:
11+
AbstractPerturbed{parallel}
1112
maximizer::F
1213
ε::Float64
1314
nb_samples::Int
1415
rng::R
1516
seed::S
17+
18+
function PerturbedAdditive{F,R,S,parallel}(
19+
maximizer::F, ε::Float64, nb_samples::Int, rng::R, seed::S
20+
) where {F,R<:AbstractRNG,S<:Union{Nothing,Int},parallel}
21+
@assert parallel isa Bool
22+
return new{F,R,S,parallel}(maximizer, ε, nb_samples, rng, seed)
23+
end
1624
end
1725

1826
function Base.show(io::IO, perturbed::PerturbedAdditive)
@@ -28,12 +36,22 @@ end
2836
Shorter constructor with defaults.
2937
"""
3038
function PerturbedAdditive(
31-
maximizer; ε=1.0, epsilon=nothing, nb_samples=1, rng=MersenneTwister(0), seed=nothing
32-
)
39+
maximizer::F;
40+
ε=1.0,
41+
epsilon=nothing,
42+
nb_samples=1,
43+
rng::R=MersenneTwister(0),
44+
seed::S=nothing,
45+
is_parallel=false,
46+
) where {F,R,S}
3347
if isnothing(epsilon)
34-
return PerturbedAdditive(maximizer, float(ε), nb_samples, rng, seed)
48+
return PerturbedAdditive{F,R,S,is_parallel}(
49+
maximizer, float(ε), nb_samples, rng, seed
50+
)
3551
else
36-
return PerturbedAdditive(maximizer, float(epsilon), nb_samples, rng, seed)
52+
return PerturbedAdditive{F,R,S,is_parallel}(
53+
maximizer, float(epsilon), nb_samples, rng, seed
54+
)
3755
end
3856
end
3957

src/perturbed/multiplicative.jl

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@ See also: [`AbstractPerturbed`](@ref).
77
88
Reference: preprint coming soon.
99
"""
10-
struct PerturbedMultiplicative{F,R<:AbstractRNG,S<:Union{Nothing,Int}} <: AbstractPerturbed
10+
struct PerturbedMultiplicative{F,R<:AbstractRNG,S<:Union{Nothing,Int},parallel} <:
11+
AbstractPerturbed{parallel}
1112
maximizer::F
1213
ε::Float64
1314
nb_samples::Int
1415
rng::R
1516
seed::S
17+
18+
function PerturbedMultiplicative{F,R,S,parallel}(
19+
maximizer::F, ε::Float64, nb_samples::Int, rng::R, seed::S
20+
) where {F,R<:AbstractRNG,S<:Union{Nothing,Int},parallel}
21+
@assert parallel isa Bool
22+
return new{F,R,S,parallel}(maximizer, ε, nb_samples, rng, seed)
23+
end
1624
end
1725

1826
function Base.show(io::IO, perturbed::PerturbedMultiplicative)
@@ -28,12 +36,22 @@ end
2836
Shorter constructor with defaults.
2937
"""
3038
function PerturbedMultiplicative(
31-
maximizer; ε=1.0, epsilon=nothing, nb_samples=1, rng=MersenneTwister(0), seed=nothing
32-
)
39+
maximizer::F;
40+
ε=1.0,
41+
epsilon=nothing,
42+
nb_samples=1,
43+
rng::R=MersenneTwister(0),
44+
seed::S=nothing,
45+
is_parallel=false,
46+
) where {F,R,S}
3347
if isnothing(epsilon)
34-
return PerturbedMultiplicative(maximizer, float(ε), nb_samples, rng, seed)
48+
return PerturbedMultiplicative{F,R,S,is_parallel}(
49+
maximizer, float(ε), nb_samples, rng, seed
50+
)
3551
else
36-
return PerturbedMultiplicative(maximizer, float(epsilon), nb_samples, rng, seed)
52+
return PerturbedMultiplicative{F,R,S,is_parallel}(
53+
maximizer, float(epsilon), nb_samples, rng, seed
54+
)
3755
end
3856
end
3957

test/paths.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,20 @@ pipelines_imitation_y = [
4848
maximizer=identity,
4949
loss=FenchelYoungLoss(PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5)),
5050
),
51+
(
52+
encoder=encoder_factory(),
53+
maximizer=identity,
54+
loss=FenchelYoungLoss(
55+
PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=5, is_parallel=true)
56+
),
57+
),
58+
(
59+
encoder=encoder_factory(),
60+
maximizer=identity,
61+
loss=FenchelYoungLoss(
62+
PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=5, is_parallel=true)
63+
),
64+
),
5165
# Perturbed + other loss
5266
(
5367
encoder=encoder_factory(),
@@ -59,6 +73,18 @@ pipelines_imitation_y = [
5973
maximizer=PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10),
6074
loss=Flux.Losses.mse,
6175
),
76+
(
77+
encoder=encoder_factory(),
78+
maximizer=PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10, is_parallel=true),
79+
loss=Flux.Losses.mse,
80+
),
81+
(
82+
encoder=encoder_factory(),
83+
maximizer=PerturbedMultiplicative(
84+
true_maximizer; ε=1.0, nb_samples=10, is_parallel=true
85+
),
86+
loss=Flux.Losses.mse,
87+
),
6288
# Generic regularized + FYL
6389
(
6490
encoder=encoder_factory(),
@@ -88,6 +114,21 @@ pipelines_experience = [
88114
PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10), cost
89115
),
90116
),
117+
(
118+
encoder=encoder_factory(),
119+
maximizer=identity,
120+
loss=Pushforward(
121+
PerturbedAdditive(true_maximizer; ε=1.0, nb_samples=10, is_parallel=true), cost
122+
),
123+
),
124+
(
125+
encoder=encoder_factory(),
126+
maximizer=identity,
127+
loss=Pushforward(
128+
PerturbedMultiplicative(true_maximizer; ε=1.0, nb_samples=10, is_parallel=true),
129+
cost,
130+
),
131+
),
91132
(
92133
encoder=encoder_factory(),
93134
maximizer=identity,

0 commit comments

Comments
 (0)