Skip to content

Commit 5d55c53

Browse files
authored
Merge pull request #418 from ReactiveBayes/multinomial_polya_model
Multinomial polya model
2 parents c07338a + ab2337d commit 5d55c53

File tree

4 files changed

+106
-4
lines changed

4 files changed

+106
-4
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RxInfer"
22
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
33
authors = ["Bagaev Dmitry <d.v.bagaev@tue.nl> and contributors"]
4-
version = "4.0.1"
4+
version = "4.1.0"
55

66
[deps]
77
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
@@ -56,7 +56,7 @@ Preferences = "1.4.3"
5656
PrettyTables = "2"
5757
ProgressMeter = "1.0.0"
5858
Random = "1.9"
59-
ReactiveMP = "~5.0.0"
59+
ReactiveMP = "~5.1.0"
6060
Reexport = "1.2.0"
6161
Rocket = "1.8.0"
6262
Static = "0.8.10, 1"

codemeta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
"downloadUrl": "https://github.com/reactivebayes/RxInfer.jl/releases",
1010
"issueTracker": "https://github.com/reactivebayes/RxInfer.jl/issues",
1111
"name": "RxInfer.jl",
12-
"version": "4.0.1",
12+
"version": "4.1.0",
1313
"description": "Julia package for automated, scalable and efficient Bayesian inference on factor graphs with reactive message passing. ",
1414
"applicationCategory": "Statistics",
1515
"developmentStatus": "active",
1616
"readme": "https://rxinfer.ml",
17-
"softwareVersion": "4.0.1",
17+
"softwareVersion": "4.1.0",
1818
"keywords": [
1919
"Bayesian inference",
2020
"message passing",
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
@testitem "Multinomial regression with MultinomialPolya (offline inference) node" begin
2+
using BenchmarkTools, Plots, Distributions, LinearAlgebra, StableRNGs, ExponentialFamily.LogExpFunctions
3+
4+
include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))
5+
6+
N = 20
7+
k = 10
8+
nsamples = 1000
9+
X, ψ, p = generate_multinomial_data(; N = N, k = k, nsamples = nsamples)
10+
11+
@model function multinomial_model(y, N, ξ_ψ, W_ψ)
12+
ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
13+
for i in eachindex(y)
14+
y[i] ~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies= MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ))}
15+
end
16+
end
17+
18+
result = infer(
19+
model = multinomial_model(ξ_ψ = zeros(k - 1), W_ψ = rand(Wishart(k, diageye(k - 1))), N = N),
20+
data = (y = X,),
21+
iterations = 100,
22+
free_energy = true,
23+
showprogress = false,
24+
returnvars = KeepLast(),
25+
options = (limit_stack_depth = 100,)
26+
)
27+
28+
m = mean(result.posteriors[])
29+
pest = logistic_stic_breaking(m)
30+
31+
mse = mean((pest - p) .^ 2)
32+
@test mse < 2e-5
33+
34+
@test result.free_energy[end] < result.free_energy[1]
35+
@test result.free_energy[end] <= result.free_energy[end - 1]
36+
@test abs(result.free_energy[end - 1] - result.free_energy[end]) < 1e-8
37+
end
38+
39+
@testitem "Multinomial regression - online inference" begin
40+
using BenchmarkTools, Plots, Distributions, LinearAlgebra, StableRNGs, ExponentialFamily.LogExpFunctions
41+
42+
include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))
43+
44+
N = 50
45+
k = 40
46+
nsamples = 5000
47+
X, ψ, p = generate_multinomial_data(; N = N, k = k, nsamples = nsamples)
48+
49+
@model function multinomial_model(y, N, ξ_ψ, W_ψ, k)
50+
ψ ~ MvNormalWeightedMeanPrecision(ξ_ψ, W_ψ)
51+
y ~ MultinomialPolya(N, ψ) where {dependencies = RequireMessageFunctionalDependencies= MvNormalWeightedMeanPrecision(zeros(k - 1), diageye(k - 1)))}
52+
end
53+
54+
@autoupdates function auto()
55+
ξ_ψ, W_ψ = weightedmean_precision(q(ψ))
56+
end
57+
init = @initialization begin
58+
q(ψ) = MvNormalWeightedMeanPrecision(zeros(k - 1), rand(Wishart(k, diageye(k - 1))))
59+
end
60+
61+
result = infer(
62+
model = multinomial_model(N = N, k = k),
63+
data = (y = X,),
64+
initialization = init,
65+
iterations = 1,
66+
autoupdates = auto(),
67+
keephistory = length(X),
68+
free_energy = true,
69+
showprogress = false
70+
)
71+
72+
m = result.history[][end]
73+
74+
pest = logistic_stic_breaking(mean(m))
75+
mse = mean((pest - p) .^ 2)
76+
@test mse < 1e-3
77+
78+
@test result.free_energy_final_only_history[end] < result.free_energy_final_only_history[1]
79+
#Free energy over time decreases in a noisy way. It is not a monotonic decrease.
80+
81+
end

test/utiltests.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,24 @@ macro test_expression_generating(lhs, rhs)
7979
end
8080
)
8181
end
82+
83+
function generate_multinomial_data(rng = StableRNG(123); N = 3, k = 3, nsamples = 5000)
84+
ψ = randn(rng, k)
85+
p = ReactiveMP.softmax(ψ)
86+
87+
X = rand(rng, Multinomial(N, p), nsamples)
88+
X = [X[:, i] for i in axes(X, 2)]
89+
return X, ψ, p
90+
end
91+
92+
function logistic_stic_breaking(m)
93+
Km1 = length(m)
94+
95+
p = Array{Float64}(undef, Km1 + 1)
96+
p[1] = logistic(m[1])
97+
for i in 2:Km1
98+
p[i] = logistic(m[i]) * (1 - sum(p[1:(i - 1)]))
99+
end
100+
p[end] = 1 - sum(p[1:(end - 1)])
101+
return p
102+
end

0 commit comments

Comments
 (0)