|
| 1 | +@testitem "Linear regression with BinomialPolya node" begin |
| 2 | + using BenchmarkTools, Plots, Dates, LinearAlgebra, StableRNGs |
| 3 | + |
| 4 | + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) |
| 5 | + |
| 6 | + function generate_synthetic_binomial_data(n_samples::Int, true_beta::Vector{Float64}; seed::Int = 42) |
| 7 | + rng = StableRNG(seed) |
| 8 | + n_features = length(true_beta) |
| 9 | + # Generate design matrix X |
| 10 | + X = randn(rng, n_samples, n_features) |
| 11 | + |
| 12 | + # Generate number of trials for each observation |
| 13 | + n_trials = rand(rng, 5:20, n_samples) |
| 14 | + |
| 15 | + # Compute logits and probabilities |
| 16 | + logits = X * true_beta |
| 17 | + probs = 1 ./ (1 .+ exp.(-logits)) |
| 18 | + |
| 19 | + # Generate binomial outcomes |
| 20 | + y = [rand(rng, Binomial(n_trials[i], probs[i])) for i in 1:n_samples] |
| 21 | + |
| 22 | + return X, y, n_trials |
| 23 | + end |
| 24 | + n_samples = 1000 |
| 25 | + n_features = 2 |
| 26 | + true_beta = [-1.0, 0.6] |
| 27 | + n_iterations = 100 |
| 28 | + n_sims = 20 |
| 29 | + |
| 30 | + @model function binomial_model(prior_xi, prior_precision, n_trials, X, y) |
| 31 | + β ~ MvNormalWeightedMeanPrecision(prior_xi, prior_precision) |
| 32 | + for i in eachindex(y) |
| 33 | + y[i] ~ BinomialPolya(X[i], n_trials[i], β) where {dependencies = RequireMessageFunctionalDependencies(β = MvNormalWeightedMeanPrecision(prior_xi, prior_precision))} |
| 34 | + end |
| 35 | + end |
| 36 | + |
| 37 | + function binomial_inference(binomial_model, iterations, X, y, n_trials, n_features) |
| 38 | + return infer( |
| 39 | + model = binomial_model(prior_xi = zeros(n_features), prior_precision = diageye(n_features)), |
| 40 | + data = (X = X, y = y, n_trials = n_trials), |
| 41 | + iterations = iterations, |
| 42 | + free_energy = true, |
| 43 | + options = (limit_stack_depth = 100,) |
| 44 | + ) |
| 45 | + end |
| 46 | + |
| 47 | + function run_simulation(n_sims::Int, n_samples::Int, true_beta::Vector{Float64}; iterations = n_iterations) |
| 48 | + # Storage for results |
| 49 | + n_features = length(true_beta) |
| 50 | + coverage = Vector{Vector{Float64}}(undef, n_sims) |
| 51 | + fes = Vector{Vector{Float64}}(undef, n_sims) |
| 52 | + for sim in 1:n_sims |
| 53 | + # Generate new dataset |
| 54 | + X, y, n_trials = generate_synthetic_binomial_data(n_samples, true_beta, seed = sim) |
| 55 | + X = [collect(row) for row in eachrow(X)] |
| 56 | + |
| 57 | + # Run inference |
| 58 | + results = binomial_inference(binomial_model, iterations, X, y, n_trials, n_features) |
| 59 | + # Extract posterior parameters |
| 60 | + post = results.posteriors[:β][end] |
| 61 | + m = mean(post) |
| 62 | + v = var(post) |
| 63 | + estimates = map((x, y) -> Normal(x, sqrt(y)), m, v) |
| 64 | + coverage[sim] = map((d, b) -> cdf(d, b), estimates, true_beta) |
| 65 | + fes[sim] = results.free_energy |
| 66 | + end |
| 67 | + |
| 68 | + return coverage, fes |
| 69 | + end |
| 70 | + |
| 71 | + function in_credible_interval(x, lwr = 0.025, upr = 0.975) |
| 72 | + return x >= lwr && x <= upr |
| 73 | + end |
| 74 | + |
| 75 | + coverage, fes = run_simulation(n_sims, n_samples, true_beta) |
| 76 | + for i in 1:n_sims |
| 77 | + @test fes[i][end] < fes[i][1] |
| 78 | + end |
| 79 | + coverages = Vector{Float64}(undef, n_features) |
| 80 | + for i in 1:n_features |
| 81 | + coverages[i] = sum(in_credible_interval.(getindex.(coverage, i))) / n_sims |
| 82 | + @test coverages[i] >= 0.8 |
| 83 | + end |
| 84 | + |
| 85 | + @test_benchmark "models" "binomialreg" binomial_inference(binomial_model, $n_iterations, $X, $y, $n_trials) |
| 86 | +end |
0 commit comments