@@ -1026,3 +1026,46 @@ end
10261026
10271027 @test all (result. predictions[:y ] .== Bernoulli (mean (Beta (1.0 , 1.0 ))))
10281028end
1029+
1030+ @testitem " Test misspecified types in infer function" begin
1031+ @model function rolling_die (n)
1032+ y = datavar (Vector{Float64}, n)
1033+
1034+ θ ~ Dirichlet (ones (6 ))
1035+ for i in 1 : n
1036+ y[i] ~ Categorical (θ)
1037+ end
1038+ end
1039+
1040+ observations = [[1.0 ; zeros (5 )], [zeros (5 ); 1.0 ]]
1041+
1042+ @testset " Test misspecified data" begin
1043+ @test_throws " Keyword argument `data` expects either `Dict` or `NamedTuple` as an input" infer (model = rolling_die (2 ), data = (y = observations))
1044+ result = infer (model = rolling_die (2 ), data = (y = observations,))
1045+ @test isequal (first (mean (result. posteriors[:θ ])), last (mean (result. posteriors[:θ ])))
1046+ end
1047+
1048+ @testset " Test misspecified initmarginals" begin
1049+ @test_throws " Keyword argument `initmarginals` expects either `Dict` or `NamedTuple` as an input" infer (
1050+ model = rolling_die (2 ), data = (y = observations,), initmarginals = (θ = Dirichlet (ones (6 )))
1051+ )
1052+ result = infer (model = rolling_die (2 ), data = (y = observations,), initmarginals = (θ = Dirichlet (ones (6 )),))
1053+ @test isequal (first (mean (result. posteriors[:θ ])), last (mean (result. posteriors[:θ ])))
1054+ end
1055+
1056+ @testset " Test misspecified initmessages" begin
1057+ @test_throws " Keyword argument `initmessages` expects either `Dict` or `NamedTuple` as an input" infer (
1058+ model = rolling_die (2 ), data = (y = observations,), initmessages = (θ = Dirichlet (ones (6 )))
1059+ )
1060+ result = infer (model = rolling_die (2 ), data = (y = observations,), initmessages = (θ = Dirichlet (ones (6 )),))
1061+ @test isequal (first (mean (result. posteriors[:θ ])), last (mean (result. posteriors[:θ ])))
1062+ end
1063+
1064+ @testset " Test misspecified callbacks" begin
1065+ @test_throws " Keyword argument `callbacks` expects either `Dict` or `NamedTuple` as an input" infer (
1066+ model = rolling_die (2 ), data = (y = observations,), callbacks = (before_model_creation = (args... ) -> nothing )
1067+ )
1068+ result = infer (model = rolling_die (2 ), data = (y = observations,), callbacks = (before_model_creation = (args... ) -> nothing ,))
1069+ @test isequal (first (mean (result. posteriors[:θ ])), last (mean (result. posteriors[:θ ])))
1070+ end
1071+ end
0 commit comments