Skip to content

Commit ef92ff4

Browse files
committed
enable variance reduction by default for
1 parent c9c68b1 commit ef92ff4

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

src/perturbed/abstract_perturbed.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ end
9797
Forward pass. Compute the expectation of the underlying distribution.
9898
"""
9999
function (perturbed::AbstractPerturbed)(
100-
θ::AbstractArray; autodiff_variance_reduction::Bool=false, kwargs...
100+
θ::AbstractArray; autodiff_variance_reduction::Bool=true, kwargs...
101101
)
102102
probadist = compute_probability_distribution(
103103
perturbed, θ; autodiff_variance_reduction, kwargs...
@@ -118,7 +118,7 @@ function ChainRulesCore.rrule(
118118
::typeof(compute_probability_distribution),
119119
perturbed::AbstractPerturbed,
120120
θ::AbstractArray;
121-
autodiff_variance_reduction::Bool=false,
121+
autodiff_variance_reduction::Bool=true,
122122
kwargs...,
123123
)
124124
η_samples = sample_perturbations(perturbed, θ)

test/jacobian_approx.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
@testset "PerturbedAdditive" begin
1717
# Compute jacobian with reverse mode
18-
jac1 = Zygote.jacobian(perturbed1, θ)[1]
19-
jac1_big = Zygote.jacobian(perturbed1_big, θ)[1]
18+
jac1 = Zygote.jacobian-> perturbed1(θ; autodiff_variance_reduction=false), θ)[1]
19+
jac1_big = Zygote.jacobian(
20+
θ -> perturbed1_big(θ; autodiff_variance_reduction=false), θ
21+
)[1]
2022
# Only diagonal should be positive
2123
@test all(diag(jac1) .>= 0)
2224
@test all(jac1 - Diagonal(jac1) .<= 0)
@@ -27,8 +29,10 @@
2729
end
2830

2931
@testset "PerturbedMultiplicative" begin
30-
jac2 = Zygote.jacobian(perturbed2, θ)[1]
31-
jac2_big = Zygote.jacobian(perturbed2_big, θ)[1]
32+
jac2 = Zygote.jacobian-> perturbed2(θ; autodiff_variance_reduction=false), θ)[1]
33+
jac2_big = Zygote.jacobian(
34+
θ -> perturbed2_big(θ; autodiff_variance_reduction=false), θ
35+
)[1]
3236
@test all(diag(jac2) .>= 0)
3337
@test all(jac2 - Diagonal(jac2) .<= 0)
3438
@test sortperm(diag(jac2)) != sortperm(θ)

test/perturbed_oracle.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ end
3131
n = 10
3232
θ = randn(10)
3333

34-
Ja = jacobian(pa, θ)[1]
35-
Ja_reduced_variance = jacobian(x -> pa(x; autodiff_variance_reduction=true), θ)[1]
34+
Ja = jacobian(θ -> pa(θ; autodiff_variance_reduction=false), θ)[1]
35+
Ja_reduced_variance = jacobian(pa, θ)[1]
3636

37-
Jm = jacobian(pm, θ)[1]
38-
Jm_reduced_variance = jacobian(x -> pm(x; autodiff_variance_reduction=true), θ)[1]
37+
Jm = jacobian(x -> pm(x; autodiff_variance_reduction=false), θ)[1]
38+
Jm_reduced_variance = jacobian(pm, θ)[1]
3939

4040
J_true = Matrix(I, n, n) # exact jacobian is the identity matrix
4141

0 commit comments

Comments
 (0)