Skip to content

Commit d8d97ae

Browse files
authored
Merge pull request #403 from ReactiveBayes/binomial_regression
Add Binomial Regression Tests and Example Notebook
2 parents f771180 + 2891509 commit d8d97ae

File tree

4 files changed

+91
-5
lines changed

4 files changed

+91
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RxInfer"
22
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
33
authors = ["Bagaev Dmitry <d.v.bagaev@tue.nl> and contributors"]
4-
version = "3.8.4"
4+
version = "3.9.0"
55

66
[deps]
77
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
@@ -42,7 +42,7 @@ MacroTools = "0.5.6"
4242
Optim = "1.0.0"
4343
ProgressMeter = "1.0.0"
4444
Random = "1.9"
45-
ReactiveMP = "~4.4.4"
45+
ReactiveMP = "~4.5.0"
4646
Reexport = "1.2.0"
4747
Rocket = "1.8.0"
4848
Static = "0.8.10, 1"

codemeta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
"downloadUrl": "https://github.com/reactivebayes/RxInfer.jl/releases",
1010
"issueTracker": "https://github.com/reactivebayes/RxInfer.jl/issues",
1111
"name": "RxInfer.jl",
12-
"version": "3.8.4",
12+
"version": "3.9.0",
1313
"description": "Julia package for automated, scalable and efficient Bayesian inference on factor graphs with reactive message passing. ",
1414
"applicationCategory": "Statistics",
1515
"developmentStatus": "active",
1616
"readme": "https://reactivebayes.github.io/RxInfer.jl/stable/",
17-
"softwareVersion": "3.8.4",
17+
"softwareVersion": "3.9.0",
1818
"keywords": [
1919
"Bayesian inference",
2020
"message passing",

scripts/format.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ s = ArgParseSettings()
1010
end
1111

1212
commandline_args = parse_args(s)
13-
folders_to_format = ["scripts", "src", "test"]
13+
folders_to_format = ["scripts", "src", "test", "ext"]
1414

1515
overwrite = commandline_args["overwrite"]
1616
formatted = all(map(folder -> JuliaFormatter.format(folder, overwrite = overwrite, verbose = true), folders_to_format))
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
@testitem "Linear regression with BinomialPolya node" begin
2+
using BenchmarkTools, Plots, Dates, LinearAlgebra, StableRNGs
3+
4+
include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))
5+
6+
function generate_synthetic_binomial_data(n_samples::Int, true_beta::Vector{Float64}; seed::Int = 42)
7+
rng = StableRNG(seed)
8+
n_features = length(true_beta)
9+
# Generate design matrix X
10+
X = randn(rng, n_samples, n_features)
11+
12+
# Generate number of trials for each observation
13+
n_trials = rand(rng, 5:20, n_samples)
14+
15+
# Compute logits and probabilities
16+
logits = X * true_beta
17+
probs = 1 ./ (1 .+ exp.(-logits))
18+
19+
# Generate binomial outcomes
20+
y = [rand(rng, Binomial(n_trials[i], probs[i])) for i in 1:n_samples]
21+
22+
return X, y, n_trials
23+
end
24+
n_samples = 1000
25+
n_features = 2
26+
true_beta = [-1.0, 0.6]
27+
n_iterations = 100
28+
n_sims = 20
29+
30+
@model function binomial_model(prior_xi, prior_precision, n_trials, X, y)
31+
β ~ MvNormalWeightedMeanPrecision(prior_xi, prior_precision)
32+
for i in eachindex(y)
33+
y[i] ~ BinomialPolya(X[i], n_trials[i], β) where {dependencies = RequireMessageFunctionalDependencies= MvNormalWeightedMeanPrecision(prior_xi, prior_precision))}
34+
end
35+
end
36+
37+
function binomial_inference(binomial_model, iterations, X, y, n_trials, n_features)
38+
return infer(
39+
model = binomial_model(prior_xi = zeros(n_features), prior_precision = diageye(n_features)),
40+
data = (X = X, y = y, n_trials = n_trials),
41+
iterations = iterations,
42+
free_energy = true,
43+
options = (limit_stack_depth = 100,)
44+
)
45+
end
46+
47+
function run_simulation(n_sims::Int, n_samples::Int, true_beta::Vector{Float64}; iterations = n_iterations)
48+
# Storage for results
49+
n_features = length(true_beta)
50+
coverage = Vector{Vector{Float64}}(undef, n_sims)
51+
fes = Vector{Vector{Float64}}(undef, n_sims)
52+
for sim in 1:n_sims
53+
# Generate new dataset
54+
X, y, n_trials = generate_synthetic_binomial_data(n_samples, true_beta, seed = sim)
55+
X = [collect(row) for row in eachrow(X)]
56+
57+
# Run inference
58+
results = binomial_inference(binomial_model, iterations, X, y, n_trials, n_features)
59+
# Extract posterior parameters
60+
post = results.posteriors[][end]
61+
m = mean(post)
62+
v = var(post)
63+
estimates = map((x, y) -> Normal(x, sqrt(y)), m, v)
64+
coverage[sim] = map((d, b) -> cdf(d, b), estimates, true_beta)
65+
fes[sim] = results.free_energy
66+
end
67+
68+
return coverage, fes
69+
end
70+
71+
function in_credible_interval(x, lwr = 0.025, upr = 0.975)
72+
return x >= lwr && x <= upr
73+
end
74+
75+
coverage, fes = run_simulation(n_sims, n_samples, true_beta)
76+
for i in 1:n_sims
77+
@test fes[i][end] < fes[i][1]
78+
end
79+
coverages = Vector{Float64}(undef, n_features)
80+
for i in 1:n_features
81+
coverages[i] = sum(in_credible_interval.(getindex.(coverage, i))) / n_sims
82+
@test coverages[i] >= 0.8
83+
end
84+
85+
@test_benchmark "models" "binomialreg" binomial_inference(binomial_model, $n_iterations, $X, $y, $n_trials)
86+
end

0 commit comments

Comments
 (0)