Skip to content

Commit 18ee995

Browse files
authored
Correct rrule for perturbed (#75)
* Correct rrule for perturbed * Replace sum with mean and add proper tests
1 parent b9f84b9 commit 18ee995

File tree

4 files changed

+54
-37
lines changed

4 files changed

+54
-37
lines changed

src/perturbed/additive.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,11 @@ function ChainRulesCore.rrule(
6969
Z_samples = sample_perturbations(perturbed, θ)
7070
probadist = compute_probability_distribution(perturbed, θ, Z_samples; kwargs...)
7171
function perturbed_additive_probadist_pullback(probadist_tangent)
72-
weigths_tangent = probadist_tangent.weights
73-
dθ = inv(ε) * sum(wt * Z for (wt, Z) in zip(weigths_tangent, Z_samples))
72+
weights_tangent = probadist_tangent.weights
73+
if length(weights_tangent) != length(Z_samples)
74+
throw(ArgumentError("Probadist tangent has invalid number of atoms"))
75+
end
76+
dθ = inv(ε) * mean(wt * Z for (wt, Z) in zip(weights_tangent, Z_samples))
7477
return NoTangent(), NoTangent(), dθ
7578
end
7679
return probadist, perturbed_additive_probadist_pullback

src/perturbed/multiplicative.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,11 @@ function ChainRulesCore.rrule(
7171
Z_samples = sample_perturbations(perturbed, θ)
7272
probadist = compute_probability_distribution(perturbed, θ, Z_samples; kwargs...)
7373
function perturbed_multiplicative_probadist_pullback(probadist_tangent)
74-
weigths_tangent = probadist_tangent.weights
75-
dθ = inv.(ε .* θ) .* sum(wt * Z for (wt, Z) in zip(weigths_tangent, Z_samples))
74+
weights_tangent = probadist_tangent.weights
75+
if length(weights_tangent) != length(Z_samples)
76+
throw(ArgumentError("Probadist tangent has invalid number of atoms"))
77+
end
78+
dθ = inv.(ε .* θ) .* mean(wt * Z for (wt, Z) in zip(weights_tangent, Z_samples))
7679
return NoTangent(), NoTangent(), dθ
7780
end
7881
return probadist, perturbed_multiplicative_probadist_pullback

src/utils/probability_distribution.jl

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -36,33 +36,6 @@ end
3636

3737
Base.rand(probadist::FixedAtomsProbabilityDistribution) = rand(GLOBAL_RNG, probadist)
3838

39-
"""
40-
compress_distribution!(probadist[; atol])
41-
42-
Remove duplicated atoms in `probadist` (up to a tolerance on equality).
43-
"""
44-
function compress_distribution!(
45-
probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
46-
) where {A,W}
47-
(; atoms, weights) = probadist
48-
to_delete = Int[]
49-
for i in length(probadist):-1:1
50-
ai = atoms[i]
51-
for j in 1:(i - 1)
52-
aj = atoms[j]
53-
if isapprox(ai, aj; atol=atol)
54-
weights[j] += weights[i]
55-
push!(to_delete, i)
56-
break
57-
end
58-
end
59-
end
60-
sort!(to_delete)
61-
deleteat!(atoms, to_delete)
62-
deleteat!(weights, to_delete)
63-
return probadist
64-
end
65-
6639
"""
6740
apply_on_atoms(post_processing, probadist)
6841
@@ -121,3 +94,32 @@ The following layer types are supported:
12194
- [`RegularizedGeneric`](@ref)
12295
"""
12396
function compute_probability_distribution end
97+
98+
"""
99+
compress_distribution!(probadist[; atol])
100+
101+
Remove duplicated atoms in `probadist` (up to a tolerance on equality).
102+
103+
This function can break probabilistic layers if used during training. It is only meant for analyzing outputs.
104+
"""
105+
function compress_distribution!(
106+
probadist::FixedAtomsProbabilityDistribution{A,W}; atol=0
107+
) where {A,W}
108+
(; atoms, weights) = probadist
109+
to_delete = Int[]
110+
for i in length(probadist):-1:1
111+
ai = atoms[i]
112+
for j in 1:(i - 1)
113+
aj = atoms[j]
114+
if isapprox(ai, aj; atol=atol)
115+
weights[j] += weights[i]
116+
push!(to_delete, i)
117+
break
118+
end
119+
end
120+
end
121+
sort!(to_delete)
122+
deleteat!(atoms, to_delete)
123+
deleteat!(weights, to_delete)
124+
return probadist
125+
end

test/jacobian_approx.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,38 @@
44
using Test
55
using Zygote
66

7-
Random.seed!(63)
7+
# Random.seed!(63)
88

99
θ = [3, 5, 4, 2]
1010

11-
perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1000, seed=0)
12-
perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1000, seed=0)
11+
perturbed1 = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=1_000, seed=0)
12+
perturbed1_big = PerturbedAdditive(one_hot_argmax; ε=2, nb_samples=10_000, seed=0)
13+
perturbed2 = PerturbedMultiplicative(one_hot_argmax; ε=0.5, nb_samples=1_000, seed=0)
14+
perturbed2_big = PerturbedMultiplicative(
15+
one_hot_argmax; ε=0.5, nb_samples=10_000, seed=0
16+
)
1317

1418
@testset "PerturbedAdditive" begin
1519
# Compute jacobian with reverse mode
1620
jac1 = Zygote.jacobian(perturbed1, θ)[1]
21+
jac1_big = Zygote.jacobian(perturbed1_big, θ)[1]
22+
@show jac1 jac1_big
1723
# Only diagonal should be positive
1824
@test all(diag(jac1) .>= 0)
1925
@test all(jac1 - Diagonal(jac1) .<= 0)
2026
# Order of diagonal coefficients should follow order of θ
2127
@test sortperm(diag(jac1)) == sortperm(θ)
28+
# No scaling with nb of samples
29+
@test norm(jac1) ≈ norm(jac1_big) rtol = 1e-2
2230
end
2331

2432
@testset "PerturbedMultiplicative" begin
2533
jac2 = Zygote.jacobian(perturbed2, θ)[1]
34+
jac2_big = Zygote.jacobian(perturbed2_big, θ)[1]
2635
@test all(diag(jac2) .>= 0)
2736
@test all(jac2 - Diagonal(jac2) .<= 0)
28-
@test_broken sortperm(diag(jac2)) == sortperm(θ)
29-
# This is broken because the diagonal coefficient for θ₃ = 4 is often larger than the one for θ₂ = 5
30-
# Maybe because θ₃ has the opportunity to *become* the argmax (and hence switch from 0 to 1), whereas θ₂ already *is* the argmax?
37+
@test sortperm(diag(jac2)) != sortperm(θ)
38+
# This is not equal because the diagonal coefficient for θ₃ = 4 is often larger than the one for θ₂ = 5. It happens because θ₃ has the opportunity to *become* the argmax (and hence switch from 0 to 1), whereas θ₂ already *is* the argmax.
39+
@test norm(jac2) ≈ norm(jac2_big) rtol = 1e-2
3140
end
3241
end

0 commit comments

Comments
 (0)