Skip to content

Commit 534c8a2

Browse files
authored
Merge pull request #302 from ReactiveBayes/fix-tests
Fix broken tests
2 parents 238bf5c + 411b5e7 commit 534c8a2

File tree

3 files changed

+37
-21
lines changed

3 files changed

+37
-21
lines changed

test/models/mixtures/gamma_mixture_tests.jl

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
using Distributions
33
using BenchmarkTools, LinearAlgebra, StableRNGs, Plots
44

5-
# `include(test/utiltests.jl)`
65
include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))
76

87
@model function gamma_mixture_model(y, nmixtures, priors_as, priors_bs, prior_s)
@@ -74,8 +73,24 @@
7473
# create model from inferred parameters
7574
_mixture = MixtureModel(_dists, _mixing)
7675

77-
@test mean(_dists[1]) 0.32261559907078213
78-
@test mean(_dists[2]) 0.3346638123092099
79-
@test _mixing [0.7988294835972645, 0.20117051640273556]
80-
@test last(gresult.free_energy) -138.7724253019069
76+
@test mean(_dists[1]) 0.32 atol = 1e-2
77+
@test mean(_dists[2]) 0.33 atol = 1e-2
78+
@test _mixing [0.8, 0.2] atol = 1e-2
79+
@test last(gresult.free_energy) -146.8 atol = 1e-1
80+
81+
@test_plot "models" "gamma_mixture" begin
82+
# plot results
83+
p1 = histogram(dataset, ylim = (0, 13), xlim = (0, 1), normalize = :pdf, label = "data", title = "Generated mixtures", opacity = 0.3)
84+
p1 = plot!(range(0.0, 1.0, length = 100), (x) -> mixing[1] * pdf(mixtures[1], x), label = "component 1", linewidth = 3.0)
85+
p1 = plot!(range(0.0, 1.0, length = 100), (x) -> mixing[2] * pdf(mixtures[2], x), label = "component 2", linewidth = 3.0)
86+
87+
p2 = histogram(dataset, ylim = (0, 13), xlim = (0, 1), normalize = :pdf, label = "data", title = "Inferred mixtures", opacity = 0.3)
88+
p2 = plot!(range(0.0, 1.0, length = 100), (x) -> _mixing[1] * pdf(_dists[1], x), label = "component 1", linewidth = 3.0)
89+
p2 = plot!(range(0.0, 1.0, length = 100), (x) -> _mixing[2] * pdf(_dists[2], x), label = "component 2", linewidth = 3.0)
90+
91+
# evaluate the convergence of the algorithm by monitoring the BFE
92+
p3 = plot(gresult.free_energy, label = false, xlabel = "iterations", title = "Bethe FE")
93+
94+
plot(plot(p1, p2, layout = @layout([a; b])), plot(p3), layout = @layout([a b]), size = (800, 400))
95+
end
8196
end

test/models/mixtures/gmm_multivariate_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
@testitem "Multivariate Gaussian Mixture model" begin
22
using BenchmarkTools, Plots, StableRNGs, LinearAlgebra
33

4-
# `include(test/utiltests.jl)`
54
include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))
65

76
@model function multivariate_gaussian_mixture_model(rng, L, nmixtures, y)
@@ -120,14 +119,15 @@
120119
m = fresult.posteriors[:m]
121120
w = fresult.posteriors[:w]
122121
fe = fresult.free_energy
122+
123123
## -------------------------------------------- ##
124124
# Test inference results
125125
@test length(s) === 25
126126
@test length(m) === 25
127127
@test length(w) === 25
128128
@test length(fe) === 25
129129
@test all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0)
130-
@test abs(last(fe) - 3442.4015524445967) < 0.01
130+
@test last(fe) 3436.7 atol = 1e-1
131131

132132
ems = sort(mean.(last(m)), by = x -> atan(x[2] / x[1]))
133133
rms = sort(mean.(gaussians), by = x -> atan(x[2] / x[1]))

test/models/mixtures/gmm_univariate_tests.jl

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,13 @@
4747

4848
μ1 = -10.0
4949
μ2 = 10.0
50-
w = 1.777
50+
w1 = 3.777
51+
w2 = 0.333
5152

5253
z = rand(rng, Categorical(switch), n)
5354
y = Vector{Float64}(undef, n)
5455

55-
dists = [Normal(μ1, sqrt(inv(w))), Normal(μ2, sqrt(inv(w)))]
56+
dists = [Normal(μ1, sqrt(inv(w1))), Normal(μ2, sqrt(inv(w2)))]
5657

5758
for i in 1:n
5859
y[i] = rand(rng, dists[z[i]])
@@ -89,8 +90,8 @@
8990
@test length(mm2) === 10
9091
@test length(mw1) === 10
9192
@test length(mw2) === 10
92-
@test length(fe) === 10 && all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0)
93-
@test abs(last(fe) - 139.74362) < 0.01
93+
@test length(fe) === 10 && all(e -> e <= 1e-10, diff(fe))
94+
@test last(fe) 284.76 atol = 1e-1
9495

9596
ms = mean(last(mswitch))
9697

@@ -100,14 +101,14 @@
100101
rms = sort([μ1, μ2])
101102

102103
foreach(zip(rms, ems)) do (r, e)
103-
@test abs(r - mean(e)) < 0.19
104+
@test abs(r - mean(e)) < 3std(e)
104105
end
105106

106107
ews = sort([last(mw1), last(mw2)], by = mean)
107-
rws = sort([w, w])
108+
rws = sort([w1, w2])
108109

109110
foreach(zip(rws, ews)) do (r, e)
110-
@test abs(r - mean(e)) < 0.15
111+
@test abs(r - mean(e)) < 3std(e)
111112
end
112113

113114
@test_throws "must be the naive mean-field" inference_univariate(y, 10, BetheFactorization())
@@ -117,17 +118,17 @@
117118

118119
@test_plot "models" "gmm_univariate" begin
119120
dim(d) = (a) -> map(r -> r[d], a)
120-
mp = plot(mean.(mm1), ribbon = var.(mm1) .|> sqrt, label = "m1 prediction")
121-
mp = plot!(mean.(mm2), ribbon = var.(mm2) .|> sqrt, label = "m2 prediction")
121+
mp = plot(mean.(mm1), ribbon = std.(mm1), label = "m1 prediction")
122+
mp = plot!(mean.(mm2), ribbon = std.(mm2), label = "m2 prediction")
122123
mp = plot!(mp, [μ1], seriestype = :hline, label = "real m1")
123124
mp = plot!(mp, [μ2], seriestype = :hline, label = "real m2")
124125

125-
wp = plot(mean.(mw1), ribbon = var.(mw1) .|> sqrt, label = "w1 prediction", legend = :bottomleft, ylim = (-1, 3))
126-
wp = plot!(wp, [w], seriestype = :hline, label = "real w1")
127-
wp = plot!(wp, mean.(mw2), ribbon = var.(mw2) .|> sqrt, label = "w2 prediction")
128-
wp = plot!(wp, [w], seriestype = :hline, label = "real w2")
126+
wp = plot(mean.(mw1), ribbon = std.(mw1), label = "w1 prediction", legend = :bottomleft, ylim = (0, 5))
127+
wp = plot!(wp, [w1], seriestype = :hline, label = "real w1")
128+
wp = plot!(wp, mean.(mw2), ribbon = var.(mw2), label = "w2 prediction")
129+
wp = plot!(wp, [w2], seriestype = :hline, label = "real w2")
129130

130-
swp = plot(mean.(mswitch), ribbon = var.(mswitch) .|> sqrt, label = "Switch prediction")
131+
swp = plot(mean.(mswitch), ribbon = std.(mswitch), label = "Switch prediction")
131132

132133
swp = plot!(swp, [switch[1]], seriestype = :hline, label = "switch[1]")
133134
swp = plot!(swp, [switch[2]], seriestype = :hline, label = "switch[2]")

0 commit comments

Comments
 (0)