|
| 1 | +@testitem "Multinomial regression with MultinomialPolya (offline inference) node" begin |
| 2 | + using BenchmarkTools, Plots, Distributions, LinearAlgebra, StableRNGs, ExponentialFamily.LogExpFunctions |
| 3 | + |
| 4 | + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) |
| 5 | + |
| 6 | + N = 20 |
| 7 | + k = 10 |
| 8 | + nsamples = 1000 |
| 9 | + X, ψ, p = generate_multinomial_data(; N = N, k = k, nsamples = nsamples) |
| 10 | + |
| 11 | + @model function multinomial_model(y, N, ξ_ψ, W_ψ) |
| 12 | + ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ) |
| 13 | + for i in eachindex(y) |
| 14 | + y[i] ~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ))} |
| 15 | + end |
| 16 | + end |
| 17 | + |
| 18 | + result = infer( |
| 19 | + model = multinomial_model(ξ_ψ = zeros(k - 1), W_ψ = rand(Wishart(k, diageye(k - 1))), N = N), |
| 20 | + data = (y = X,), |
| 21 | + iterations = 100, |
| 22 | + free_energy = true, |
| 23 | + showprogress = false, |
| 24 | + returnvars = KeepLast(), |
| 25 | + options = (limit_stack_depth = 100,) |
| 26 | + ) |
| 27 | + |
| 28 | + m = mean(result.posteriors[:ψ]) |
| 29 | + pest = logistic_stic_breaking(m) |
| 30 | + |
| 31 | + mse = mean((pest - p) .^ 2) |
| 32 | + @test mse < 2e-5 |
| 33 | + |
| 34 | + @test result.free_energy[end] < result.free_energy[1] |
| 35 | + @test result.free_energy[end] <= result.free_energy[end - 1] |
| 36 | + @test abs(result.free_energy[end - 1] - result.free_energy[end]) < 1e-8 |
| 37 | +end |
| 38 | + |
| 39 | +@testitem "Multinomial regression - online inference" begin |
| 40 | + using BenchmarkTools, Plots, Distributions, LinearAlgebra, StableRNGs, ExponentialFamily.LogExpFunctions |
| 41 | + |
| 42 | + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) |
| 43 | + |
| 44 | + N = 50 |
| 45 | + k = 40 |
| 46 | + nsamples = 5000 |
| 47 | + X, ψ, p = generate_multinomial_data(; N = N, k = k, nsamples = nsamples) |
| 48 | + |
| 49 | + @model function multinomial_model(y, N, ξ_ψ, W_ψ, k) |
| 50 | + ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ) |
| 51 | + y ~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies(ψ = MvNormalWeightedMeanPrecision(zeros(k - 1), diageye(k - 1)))} |
| 52 | + end |
| 53 | + |
| 54 | + @autoupdates function auto() |
| 55 | + ξ_ψ, W_ψ = weightedmean_precision(q(ψ)) |
| 56 | + end |
| 57 | + init = @initialization begin |
| 58 | + q(ψ) = MvNormalWeightedMeanPrecision(zeros(k - 1), rand(Wishart(k, diageye(k - 1)))) |
| 59 | + end |
| 60 | + |
| 61 | + result = infer( |
| 62 | + model = multinomial_model(N = N, k = k), |
| 63 | + data = (y = X,), |
| 64 | + initialization = init, |
| 65 | + iterations = 1, |
| 66 | + autoupdates = auto(), |
| 67 | + keephistory = length(X), |
| 68 | + free_energy = true, |
| 69 | + showprogress = false |
| 70 | + ) |
| 71 | + |
| 72 | + m = result.history[:ψ][end] |
| 73 | + |
| 74 | + pest = logistic_stic_breaking(mean(m)) |
| 75 | + mse = mean((pest - p) .^ 2) |
| 76 | + @test mse < 1e-3 |
| 77 | + |
| 78 | + @test result.free_energy_final_only_history[end] < result.free_energy_final_only_history[1] |
| 79 | + #Free energy over time decreases in a noisy way. It is not a monotonic decrease. |
| 80 | + |
| 81 | +end |
0 commit comments