File tree Expand file tree Collapse file tree 3 files changed +14
-10
lines changed Expand file tree Collapse file tree 3 files changed +14
-10
lines changed Original file line number Diff line number Diff line change 97
97
Forward pass. Compute the expectation of the underlying distribution.
98
98
"""
99
99
function (perturbed:: AbstractPerturbed )(
100
- θ:: AbstractArray ; autodiff_variance_reduction:: Bool = false , kwargs...
100
+ θ:: AbstractArray ; autodiff_variance_reduction:: Bool = true , kwargs...
101
101
)
102
102
probadist = compute_probability_distribution (
103
103
perturbed, θ; autodiff_variance_reduction, kwargs...
@@ -118,7 +118,7 @@ function ChainRulesCore.rrule(
118
118
:: typeof (compute_probability_distribution),
119
119
perturbed:: AbstractPerturbed ,
120
120
θ:: AbstractArray ;
121
- autodiff_variance_reduction:: Bool = false ,
121
+ autodiff_variance_reduction:: Bool = true ,
122
122
kwargs... ,
123
123
)
124
124
η_samples = sample_perturbations (perturbed, θ)
Original file line number Diff line number Diff line change 15
15
16
16
@testset " PerturbedAdditive" begin
17
17
# Compute jacobian with reverse mode
18
- jac1 = Zygote. jacobian (perturbed1, θ)[1 ]
19
- jac1_big = Zygote. jacobian (perturbed1_big, θ)[1 ]
18
+ jac1 = Zygote. jacobian (θ -> perturbed1 (θ; autodiff_variance_reduction= false ), θ)[1 ]
19
+ jac1_big = Zygote. jacobian (
20
+ θ -> perturbed1_big (θ; autodiff_variance_reduction= false ), θ
21
+ )[1 ]
20
22
# Only diagonal should be positive
21
23
@test all (diag (jac1) .>= 0 )
22
24
@test all (jac1 - Diagonal (jac1) .<= 0 )
27
29
end
28
30
29
31
@testset " PerturbedMultiplicative" begin
30
- jac2 = Zygote. jacobian (perturbed2, θ)[1 ]
31
- jac2_big = Zygote. jacobian (perturbed2_big, θ)[1 ]
32
+ jac2 = Zygote. jacobian (θ -> perturbed2 (θ; autodiff_variance_reduction= false ), θ)[1 ]
33
+ jac2_big = Zygote. jacobian (
34
+ θ -> perturbed2_big (θ; autodiff_variance_reduction= false ), θ
35
+ )[1 ]
32
36
@test all (diag (jac2) .>= 0 )
33
37
@test all (jac2 - Diagonal (jac2) .<= 0 )
34
38
@test sortperm (diag (jac2)) != sortperm (θ)
Original file line number Diff line number Diff line change 31
31
n = 10
32
32
θ = randn (10 )
33
33
34
- Ja = jacobian (pa , θ)[1 ]
35
- Ja_reduced_variance = jacobian (x -> pa (x; autodiff_variance_reduction = true ) , θ)[1 ]
34
+ Ja = jacobian (θ -> pa (θ; autodiff_variance_reduction = false ) , θ)[1 ]
35
+ Ja_reduced_variance = jacobian (pa , θ)[1 ]
36
36
37
- Jm = jacobian (pm , θ)[1 ]
38
- Jm_reduced_variance = jacobian (x -> pm (x; autodiff_variance_reduction = true ) , θ)[1 ]
37
+ Jm = jacobian (x -> pm (x; autodiff_variance_reduction = false ) , θ)[1 ]
38
+ Jm_reduced_variance = jacobian (pm , θ)[1 ]
39
39
40
40
J_true = Matrix (I, n, n) # exact jacobian is the identity matrix
41
41
You can’t perform that action at this time.
0 commit comments