Skip to content

Commit 495baa7

Browse files
authored
Merge pull request #228 from ReactiveBayes/fix_infer_dicttype
Fix missing check for infer function
2 parents 23f95bb + 9cd0d84 commit 495baa7

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

src/inference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,6 +1799,7 @@ function infer(;
17991799
__infer_check_dicttype(:initmarginals, initmarginals)
18001800
__infer_check_dicttype(:initmessages, initmessages)
18011801
__infer_check_dicttype(:callbacks, callbacks)
1802+
__infer_check_dicttype(:data, data)
18021803

18031804
if isnothing(autoupdates)
18041805
__check_available_callbacks(warn, callbacks, available_callbacks(__inference))

test/inference_test.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1026,3 +1026,46 @@ end
10261026

10271027
@test all(result.predictions[:y] .== Bernoulli(mean(Beta(1.0, 1.0))))
10281028
end
1029+
1030+
@testitem "Test misspecified types in infer function" begin
1031+
@model function rolling_die(n)
1032+
y = datavar(Vector{Float64}, n)
1033+
1034+
θ ~ Dirichlet(ones(6))
1035+
for i in 1:n
1036+
y[i] ~ Categorical(θ)
1037+
end
1038+
end
1039+
1040+
observations = [[1.0; zeros(5)], [zeros(5); 1.0]]
1041+
1042+
@testset "Test misspecified data" begin
1043+
@test_throws "Keyword argument `data` expects either `Dict` or `NamedTuple` as an input" infer(model = rolling_die(2), data = (y = observations))
1044+
result = infer(model = rolling_die(2), data = (y = observations,))
1045+
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
1046+
end
1047+
1048+
@testset "Test misspecified initmarginals" begin
1049+
@test_throws "Keyword argument `initmarginals` expects either `Dict` or `NamedTuple` as an input" infer(
1050+
model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6)))
1051+
)
1052+
result = infer(model = rolling_die(2), data = (y = observations,), initmarginals == Dirichlet(ones(6)),))
1053+
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
1054+
end
1055+
1056+
@testset "Test misspecified initmessages" begin
1057+
@test_throws "Keyword argument `initmessages` expects either `Dict` or `NamedTuple` as an input" infer(
1058+
model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6)))
1059+
)
1060+
result = infer(model = rolling_die(2), data = (y = observations,), initmessages == Dirichlet(ones(6)),))
1061+
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
1062+
end
1063+
1064+
@testset "Test misspecified callbacks" begin
1065+
@test_throws "Keyword argument `callbacks` expects either `Dict` or `NamedTuple` as an input" infer(
1066+
model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing)
1067+
)
1068+
result = infer(model = rolling_die(2), data = (y = observations,), callbacks = (before_model_creation = (args...) -> nothing,))
1069+
@test isequal(first(mean(result.posteriors[])), last(mean(result.posteriors[])))
1070+
end
1071+
end

0 commit comments

Comments
 (0)