|
47 | 47 |
|
48 | 48 | μ1 = -10.0
|
49 | 49 | μ2 = 10.0
|
50 |
| - w = 1.777 |
| 50 | + w1 = 3.777 |
| 51 | + w2 = 0.333 |
51 | 52 |
|
52 | 53 | z = rand(rng, Categorical(switch), n)
|
53 | 54 | y = Vector{Float64}(undef, n)
|
54 | 55 |
|
55 |
| - dists = [Normal(μ1, sqrt(inv(w))), Normal(μ2, sqrt(inv(w)))] |
| 56 | + dists = [Normal(μ1, sqrt(inv(w1))), Normal(μ2, sqrt(inv(w2)))] |
56 | 57 |
|
57 | 58 | for i in 1:n
|
58 | 59 | y[i] = rand(rng, dists[z[i]])
|
|
89 | 90 | @test length(mm2) === 10
|
90 | 91 | @test length(mw1) === 10
|
91 | 92 | @test length(mw2) === 10
|
92 |
| - @test length(fe) === 10 && all(filter(e -> abs(e) > 1e-3, diff(fe)) .< 0) |
93 |
| - @test abs(last(fe) - 139.74362) < 0.01 |
| 93 | + @test length(fe) === 10 && all(e -> e <= 1e-10, diff(fe)) |
| 94 | + @test last(fe) ≈ 284.76 atol = 1e-1 |
94 | 95 |
|
95 | 96 | ms = mean(last(mswitch))
|
96 | 97 |
|
|
100 | 101 | rms = sort([μ1, μ2])
|
101 | 102 |
|
102 | 103 | foreach(zip(rms, ems)) do (r, e)
|
103 |
| - @test abs(r - mean(e)) < 0.19 |
| 104 | + @test abs(r - mean(e)) < 3std(e) |
104 | 105 | end
|
105 | 106 |
|
106 | 107 | ews = sort([last(mw1), last(mw2)], by = mean)
|
107 |
| - rws = sort([w, w]) |
| 108 | + rws = sort([w1, w2]) |
108 | 109 |
|
109 | 110 | foreach(zip(rws, ews)) do (r, e)
|
110 |
| - @test abs(r - mean(e)) < 0.15 |
| 111 | + @test abs(r - mean(e)) < 3std(e) |
111 | 112 | end
|
112 | 113 |
|
113 | 114 | @test_throws "must be the naive mean-field" inference_univariate(y, 10, BetheFactorization())
|
|
117 | 118 |
|
118 | 119 | @test_plot "models" "gmm_univariate" begin
|
119 | 120 | dim(d) = (a) -> map(r -> r[d], a)
|
120 |
| - mp = plot(mean.(mm1), ribbon = var.(mm1) .|> sqrt, label = "m1 prediction") |
121 |
| - mp = plot!(mean.(mm2), ribbon = var.(mm2) .|> sqrt, label = "m2 prediction") |
| 121 | + mp = plot(mean.(mm1), ribbon = std.(mm1), label = "m1 prediction") |
| 122 | + mp = plot!(mean.(mm2), ribbon = std.(mm2), label = "m2 prediction") |
122 | 123 | mp = plot!(mp, [μ1], seriestype = :hline, label = "real m1")
|
123 | 124 | mp = plot!(mp, [μ2], seriestype = :hline, label = "real m2")
|
124 | 125 |
|
125 |
| - wp = plot(mean.(mw1), ribbon = var.(mw1) .|> sqrt, label = "w1 prediction", legend = :bottomleft, ylim = (-1, 3)) |
126 |
| - wp = plot!(wp, [w], seriestype = :hline, label = "real w1") |
127 |
| - wp = plot!(wp, mean.(mw2), ribbon = var.(mw2) .|> sqrt, label = "w2 prediction") |
128 |
| - wp = plot!(wp, [w], seriestype = :hline, label = "real w2") |
| 126 | + wp = plot(mean.(mw1), ribbon = std.(mw1), label = "w1 prediction", legend = :bottomleft, ylim = (0, 5)) |
| 127 | + wp = plot!(wp, [w1], seriestype = :hline, label = "real w1") |
| 128 | + wp = plot!(wp, mean.(mw2), ribbon = var.(mw2), label = "w2 prediction") |
| 129 | + wp = plot!(wp, [w2], seriestype = :hline, label = "real w2") |
129 | 130 |
|
130 |
| - swp = plot(mean.(mswitch), ribbon = var.(mswitch) .|> sqrt, label = "Switch prediction") |
| 131 | + swp = plot(mean.(mswitch), ribbon = std.(mswitch), label = "Switch prediction") |
131 | 132 |
|
132 | 133 | swp = plot!(swp, [switch[1]], seriestype = :hline, label = "switch[1]")
|
133 | 134 | swp = plot!(swp, [switch[2]], seriestype = :hline, label = "switch[2]")
|
|
0 commit comments