Skip to content

Commit b9f84b9

Browse files
authored
Remove all <:Real type constraints (#71)
1 parent 2036cf1 commit b9f84b9

File tree

18 files changed

+61
-97
lines changed

18 files changed

+61
-97
lines changed

ext/InferOptFrankWolfeExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ using LinearAlgebra: dot
88
## Forward pass
99

1010
function InferOpt.compute_probability_distribution(
11-
dfw::DiffFW, θ::AbstractArray{<:Real}; frank_wolfe_kwargs=NamedTuple()
11+
dfw::DiffFW, θ::AbstractArray; frank_wolfe_kwargs=NamedTuple()
1212
)
1313
weights, atoms = dfw.implicit(θ; frank_wolfe_kwargs=frank_wolfe_kwargs)
1414
probadist = FixedAtomsProbabilityDistribution(atoms, weights)
@@ -23,7 +23,7 @@ Construct a `DifferentiableFrankWolfe.DiffFW` struct and call `compute_probabili
2323
Keyword arguments are passed to the underlying linear maximizer.
2424
"""
2525
function InferOpt.compute_probability_distribution(
26-
regularized::RegularizedGeneric, θ::AbstractArray{<:Real}; kwargs...
26+
regularized::RegularizedGeneric, θ::AbstractArray; kwargs...
2727
)
2828
(; maximizer, Ω, Ω_grad, frank_wolfe_kwargs) = regularized
2929
f(y, θ) = Ω(y) - dot(θ, y)
@@ -41,7 +41,7 @@ Apply `compute_probability_distribution(regularized, θ)` and return the expecta
4141
4242
Keyword arguments are passed to the underlying linear maximizer.
4343
"""
44-
function (regularized::RegularizedGeneric)(θ::AbstractArray{<:Real}; kwargs...)
44+
function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...)
4545
probadist = compute_probability_distribution(regularized, θ; kwargs...)
4646
return compute_expectation(probadist)
4747
end

src/fenchel_young/fenchel_young.jl

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,13 @@ end
1919

2020
## Forward pass
2121

22-
function (fyl::FenchelYoungLoss)(
23-
θ::AbstractArray{<:Real}, y_true::AbstractArray{<:Real}; kwargs...
24-
)
22+
function (fyl::FenchelYoungLoss)(θ::AbstractArray, y_true::AbstractArray; kwargs...)
2523
l, _ = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...)
2624
return l
2725
end
2826

2927
@traitfn function fenchel_young_loss_and_grad(
30-
fyl::FenchelYoungLoss{P},
31-
θ::AbstractArray{<:Real},
32-
y_true::AbstractArray{<:Real};
33-
kwargs...,
28+
fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs...
3429
) where {P; IsRegularized{P}}
3530
(; predictor) = fyl
3631
= predictor(θ; kwargs...)
@@ -42,10 +37,7 @@ end
4237
end
4338

4439
function fenchel_young_loss_and_grad(
45-
fyl::FenchelYoungLoss{P},
46-
θ::AbstractArray{<:Real},
47-
y_true::AbstractArray{<:Real};
48-
kwargs...,
40+
fyl::FenchelYoungLoss{P}, θ::AbstractArray, y_true::AbstractArray; kwargs...
4941
) where {P<:AbstractPerturbed}
5042
(; predictor) = fyl
5143
F, almost_ŷ = fenchel_young_F_and_first_part_of_grad(predictor, θ; kwargs...)
@@ -57,10 +49,7 @@ end
5749
## Backward pass
5850

5951
function ChainRulesCore.rrule(
60-
fyl::FenchelYoungLoss,
61-
θ::AbstractArray{<:Real},
62-
y_true::AbstractArray{<:Real};
63-
kwargs...,
52+
fyl::FenchelYoungLoss, θ::AbstractArray, y_true::AbstractArray; kwargs...
6453
)
6554
l, g = fenchel_young_loss_and_grad(fyl, θ, y_true; kwargs...)
6655
fyl_pullback(dl) = NoTangent(), dl * g, NoTangent()

src/fenchel_young/perturbed.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
function compute_F_and_y_samples(
2-
perturbed::AbstractPerturbed{false}, θ::AbstractArray{<:Real}, Z_samples; kwargs...
2+
perturbed::AbstractPerturbed{false}, θ::AbstractArray, Z_samples; kwargs...
33
)
44
F_and_y_samples = [
55
fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...) for
@@ -9,26 +9,23 @@ function compute_F_and_y_samples(
99
end
1010

1111
function compute_F_and_y_samples(
12-
perturbed::AbstractPerturbed{true}, θ::AbstractArray{<:Real}, Z_samples; kwargs...
12+
perturbed::AbstractPerturbed{true}, θ::AbstractArray, Z_samples; kwargs...
1313
)
1414
return ThreadsX.map(
1515
Z -> fenchel_young_F_and_first_part_of_grad(perturbed, θ, Z; kwargs...), Z_samples
1616
)
1717
end
1818

1919
function fenchel_young_F_and_first_part_of_grad(
20-
perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; kwargs...
20+
perturbed::AbstractPerturbed, θ::AbstractArray; kwargs...
2121
)
2222
Z_samples = sample_perturbations(perturbed, θ)
2323
F_and_y_samples = compute_F_and_y_samples(perturbed, θ, Z_samples; kwargs...)
2424
return mean(first, F_and_y_samples), mean(last, F_and_y_samples)
2525
end
2626

2727
function fenchel_young_F_and_first_part_of_grad(
28-
perturbed::PerturbedAdditive,
29-
θ::AbstractArray{<:Real},
30-
Z::AbstractArray{<:Real};
31-
kwargs...,
28+
perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs...
3229
)
3330
(; maximizer, ε) = perturbed
3431
θ_perturbed = θ .+ ε .* Z
@@ -38,10 +35,7 @@ function fenchel_young_F_and_first_part_of_grad(
3835
end
3936

4037
function fenchel_young_F_and_first_part_of_grad(
41-
perturbed::PerturbedMultiplicative,
42-
θ::AbstractArray{<:Real},
43-
Z::AbstractArray{<:Real};
44-
kwargs...,
38+
perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs...
4539
)
4640
(; maximizer, ε) = perturbed
4741
eZ = exp.(ε .* Z .- ε^2)

src/imitation_loss/imitation_loss.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Retrieve `y_true` from `t_true`. `t_true` must contain an `y_true` field.
5353
"""
5454
get_y_true(t_true::NamedTuple) = t_true.y_true
5555

56-
function prediction_and_loss(l::ImitationLoss, θ::AbstractArray{<:Real}, t_true; kwargs...)
56+
function prediction_and_loss(l::ImitationLoss, θ::AbstractArray, t_true; kwargs...)
5757
(; base_loss, Ω, maximizer, α) = l
5858
y_true = get_y_true(t_true)
5959
= maximizer(θ, t_true; kwargs...)
@@ -63,14 +63,14 @@ end
6363

6464
## Forward pass
6565

66-
function (l::ImitationLoss)(θ::AbstractArray{<:Real}, t_true; kwargs...)
66+
function (l::ImitationLoss)(θ::AbstractArray, t_true; kwargs...)
6767
_, l = prediction_and_loss(l, θ, t_true; kwargs...)
6868
return l
6969
end
7070

7171
## Backward pass
7272

73-
function ChainRulesCore.rrule(l::ImitationLoss, θ::AbstractArray{<:Real}, t_true; kwargs...)
73+
function ChainRulesCore.rrule(l::ImitationLoss, θ::AbstractArray, t_true; kwargs...)
7474
(; α) = l
7575
y_true = get_y_true(t_true)
7676
ŷ, l = prediction_and_loss(l, θ, t_true; kwargs...)

src/interpolation/interpolation.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@ end
2121

2222
Interpolation(maximizer; λ=1.0) = Interpolation(maximizer, float(λ))
2323

24-
function (interpolation::Interpolation)(θ::AbstractArray{<:Real}; kwargs...)
24+
function (interpolation::Interpolation)(θ::AbstractArray; kwargs...)
2525
return interpolation.maximizer(θ; kwargs...)
2626
end
2727

28-
function ChainRulesCore.rrule(
29-
interpolation::Interpolation, θ::AbstractArray{<:Real}; kwargs...
30-
)
28+
function ChainRulesCore.rrule(interpolation::Interpolation, θ::AbstractArray; kwargs...)
3129
(; maximizer, λ) = interpolation
3230
y = maximizer(θ; kwargs...)
3331
function interpolation_pullback(dy)

src/perturbed/abstract_perturbed.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ abstract type AbstractPerturbed{parallel} end
2929
3030
Draw random perturbations `Z` which will be applied to the objective direction `θ`.
3131
"""
32-
function sample_perturbations(perturbed::AbstractPerturbed, θ::AbstractArray{<:Real})
32+
function sample_perturbations(perturbed::AbstractPerturbed, θ::AbstractArray)
3333
(; rng, seed, nb_samples) = perturbed
3434
seed!(rng, seed)
3535
Z_samples = [randn(rng, size(θ)) for _ in 1:nb_samples]
@@ -38,26 +38,26 @@ end
3838

3939
function compute_atoms(
4040
perturbed::AbstractPerturbed{false},
41-
θ::AbstractArray{<:Real},
42-
Z_samples::Vector{<:AbstractArray{<:Real}};
41+
θ::AbstractArray,
42+
Z_samples::Vector{<:AbstractArray};
4343
kwargs...,
4444
)
4545
return [perturb_and_optimize(perturbed, θ, Z; kwargs...) for Z in Z_samples]
4646
end
4747

4848
function compute_atoms(
4949
perturbed::AbstractPerturbed{true},
50-
θ::AbstractArray{<:Real},
51-
Z_samples::Vector{<:AbstractArray{<:Real}};
50+
θ::AbstractArray,
51+
Z_samples::Vector{<:AbstractArray};
5252
kwargs...,
5353
)
5454
return ThreadsX.map(Z -> perturb_and_optimize(perturbed, θ, Z; kwargs...), Z_samples)
5555
end
5656

5757
function compute_probability_distribution(
5858
perturbed::AbstractPerturbed,
59-
θ::AbstractArray{<:Real},
60-
Z_samples::Vector{<:AbstractArray{<:Real}};
59+
θ::AbstractArray,
60+
Z_samples::Vector{<:AbstractArray};
6161
kwargs...,
6262
)
6363
atoms = compute_atoms(perturbed, θ, Z_samples; kwargs...)
@@ -72,7 +72,7 @@ end
7272
Turn random perturbations of `θ` into a distribution on polytope vertices.
7373
"""
7474
function compute_probability_distribution(
75-
perturbed::AbstractPerturbed, θ::AbstractArray{<:Real}; kwargs...
75+
perturbed::AbstractPerturbed, θ::AbstractArray; kwargs...
7676
)
7777
Z_samples = sample_perturbations(perturbed, θ)
7878
return compute_probability_distribution(perturbed, θ, Z_samples; kwargs...)
@@ -83,7 +83,7 @@ end
8383
8484
Apply `compute_probability_distribution(perturbed, θ)` and return the expectation.
8585
"""
86-
function (perturbed::AbstractPerturbed)(θ::AbstractArray{<:Real}; kwargs...)
86+
function (perturbed::AbstractPerturbed)(θ::AbstractArray; kwargs...)
8787
probadist = compute_probability_distribution(perturbed, θ; kwargs...)
8888
return compute_expectation(probadist)
8989
end

src/perturbed/additive.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,7 @@ end
4949
## Forward pass
5050

5151
function perturb_and_optimize(
52-
perturbed::PerturbedAdditive,
53-
θ::AbstractArray{<:Real},
54-
Z::AbstractArray{<:Real};
55-
kwargs...,
52+
perturbed::PerturbedAdditive, θ::AbstractArray, Z::AbstractArray; kwargs...
5653
)
5754
(; maximizer, ε) = perturbed
5855
θ_perturbed = θ .+ ε .* Z
@@ -65,7 +62,7 @@ end
6562
function ChainRulesCore.rrule(
6663
::typeof(compute_probability_distribution),
6764
perturbed::PerturbedAdditive,
68-
θ::AbstractArray{<:Real};
65+
θ::AbstractArray;
6966
kwargs...,
7067
)
7168
(; ε) = perturbed

src/perturbed/multiplicative.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ end
5151
## Forward pass
5252

5353
function perturb_and_optimize(
54-
perturbed::PerturbedMultiplicative,
55-
θ::AbstractArray{<:Real},
56-
Z::AbstractArray{<:Real};
57-
kwargs...,
54+
perturbed::PerturbedMultiplicative, θ::AbstractArray, Z::AbstractArray; kwargs...
5855
)
5956
(; maximizer, ε) = perturbed
6057
θ_perturbed = θ .* exp.(ε .* Z .- ε^2)
@@ -67,7 +64,7 @@ end
6764
function ChainRulesCore.rrule(
6865
::typeof(compute_probability_distribution),
6966
perturbed::PerturbedMultiplicative,
70-
θ::AbstractArray{<:Real};
67+
θ::AbstractArray;
7168
kwargs...,
7269
)
7370
(; ε) = perturbed

src/plus_identity/plus_identity.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ function Base.show(io::IO, plusid::PlusIdentity)
1818
return print(io, "PlusIdentity($(plusid.maximizer)")
1919
end
2020

21-
function (plusid::PlusIdentity)(θ::AbstractArray{<:Real}; kwargs...)
21+
function (plusid::PlusIdentity)(θ::AbstractArray; kwargs...)
2222
return plusid.maximizer(θ; kwargs...)
2323
end
2424

25-
function ChainRulesCore.rrule(plusid::PlusIdentity, θ::AbstractArray{<:Real}; kwargs...)
25+
function ChainRulesCore.rrule(plusid::PlusIdentity, θ::AbstractArray; kwargs...)
2626
y = plusid.maximizer(θ; kwargs...)
2727
plusid_pullback(dy) = NoTangent(), dy
2828
return y, plusid_pullback

src/regularized/regularized_generic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646

4747
@traitimpl IsRegularized{RegularizedGeneric}
4848

49-
function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray{<:Real})
49+
function compute_regularization(regularized::RegularizedGeneric, y::AbstractArray)
5050
return regularized.Ω(y)
5151
end
5252

@@ -55,7 +55,7 @@ end
5555
5656
Apply `compute_probability_distribution(regularized, θ, kwargs...)` and return the expectation.
5757
"""
58-
function (regularized::RegularizedGeneric)(θ::AbstractArray{<:Real}; kwargs...)
58+
function (regularized::RegularizedGeneric)(θ::AbstractArray; kwargs...)
5959
probadist = compute_probability_distribution(regularized, θ; kwargs...)
6060
return compute_expectation(probadist)
6161
end

0 commit comments

Comments
 (0)