Skip to content

Commit 1dbd503

Browse files
authored
Merge pull request #347 from ReactiveBayes/dev-fix-335
Show variable name and suggestions if the resulting functional form is not supported by the inference backend
2 parents 244592f + 08b93ab commit 1dbd503

File tree

6 files changed

+139
-7
lines changed

6 files changed

+139
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ MacroTools = "0.5.6"
4141
Optim = "1.0.0"
4242
ProgressMeter = "1.0.0"
4343
Random = "1.9"
44-
ReactiveMP = "~4.3.0"
44+
ReactiveMP = "~4.4.0"
4545
Reexport = "1.2.0"
4646
Rocket = "1.8.0"
4747
TupleTools = "1.2.0"

src/RxInfer.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ include("model/plugins/reactivemp_free_energy.jl")
1616
include("model/plugins/initialization_plugin.jl")
1717
include("model/graphppl.jl")
1818

19+
include("constraints/form/form_ensure_supported.jl")
1920
include("constraints/form/form_fixed_marginal.jl")
2021
include("constraints/form/form_point_mass.jl")
2122
include("constraints/form/form_sample_list.jl")
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import ReactiveMP: AbstractFormConstraint
2+
3+
# This is an internal functional form constraint that only checks that the result
4+
# is of a supported form. Displays a user-friendly error message if the form is not supported.
5+
struct EnsureSupportedFunctionalForm <: AbstractFormConstraint
6+
prefix::Symbol
7+
name::Symbol
8+
index::Any
9+
end
10+
11+
ReactiveMP.default_form_check_strategy(::EnsureSupportedFunctionalForm) = FormConstraintCheckLast()
12+
13+
ReactiveMP.default_prod_constraint(::EnsureSupportedFunctionalForm) = GenericProd()
14+
15+
function ReactiveMP.constrain_form(constraint::EnsureSupportedFunctionalForm, something)
16+
if typeof(something) <: ProductOf || typeof(something) <: LinearizedProductOf
17+
expr = string(constraint.prefix, '(', constraint.name, isnothing(constraint.index) ? "" : string('[', constraint.index, ']'), ')')
18+
expr_noindex = string(constraint.prefix, '(', constraint.name, ')')
19+
error(lazy"""
20+
The expression `$expr` has an undefined functional form of type `$(typeof(something))`.
21+
This is likely because the inference backend does not support the product of these distributions.
22+
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `$expr`.
23+
24+
Possible solutions:
25+
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
26+
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
27+
```julia
28+
using ExponentialFamilyProjection
29+
30+
@constraints begin
31+
$(expr_noindex) :: ProjectedTo(NormalMeanVariance)
32+
end
33+
```
34+
Refer to the documentation for more details on functional form constraints.
35+
""")
36+
end
37+
return something
38+
end

src/model/plugins/reactivemp_inference.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,12 +167,14 @@ function activate_rmp_variable!(plugin::ReactiveMPInferencePlugin, model::Model,
167167
# By default it is `UnspecifiedFormConstraint` which means that the form of the resulting distribution is not specified in advance
168168
# and follows from the computation, but users may override it with other form constraints, e.g. `PointMassFormConstraint`, which
169169
# constraints the resulting distribution to be of a point mass form
170-
messages_form_constraint = ReactiveMP.preprocess_form_constraints(
171-
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMessagesFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
172-
)
173-
marginal_form_constraint = ReactiveMP.preprocess_form_constraints(
174-
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMarginalFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
175-
)
170+
messages_form_constraint =
171+
ReactiveMP.preprocess_form_constraints(
172+
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMessagesFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
173+
) + EnsureSupportedFunctionalForm(, GraphPPL.getname(nodeproperties), GraphPPL.index(nodeproperties))
174+
marginal_form_constraint =
175+
ReactiveMP.preprocess_form_constraints(
176+
plugin, model, getextra(nodedata, GraphPPL.VariationalConstraintsMarginalFormConstraintKey, ReactiveMP.UnspecifiedFormConstraint())
177+
) + EnsureSupportedFunctionalForm(:q, GraphPPL.getname(nodeproperties), GraphPPL.index(nodeproperties))
176178
# Fetch "prod-constraint" for messages and marginals. The prod-constraint usually defines the constraints for a single product of messages
177179
# It can for example preserve a specific parametrization of distribution
178180
messages_prod_constraint = getextra(nodedata, :messages_prod_constraint, ReactiveMP.default_prod_constraint(messages_form_constraint))
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
@testitem "Tests for `EnsureSupportedFunctionalForm" begin
2+
import RxInfer: EnsureSupportedFunctionalForm
3+
import ReactiveMP: default_form_check_strategy, default_prod_constraint, constrain_form
4+
import BayesBase: PointMass, ProductOf, LinearizedProductOf
5+
6+
# In principle any object is supported except `ProductOf` and `LinearizedProductOf` from `BayesBase`
7+
# Those are supposed to be passed to the functional form constraint
8+
9+
for prefix in (:q, ), index in (nothing, (1,)), name in (:a, :b)
10+
@test default_form_check_strategy(EnsureSupportedFunctionalForm(prefix, name, index)) === FormConstraintCheckLast()
11+
@test default_prod_constraint(EnsureSupportedFunctionalForm(prefix, name, index)) === GenericProd()
12+
13+
@testset let constraint = EnsureSupportedFunctionalForm(prefix, name, index)
14+
@test constrain_form(constraint, PointMass(1)) === PointMass(1)
15+
@test_throws Exception constrain_form(constraint, ProductOf(PointMass(1), PointMass(2)))
16+
@test_throws Exception constrain_form(constraint, LinearizedProductOf([PointMass(1), PointMass(2)], 2))
17+
end
18+
end
19+
end

test/inference/inference_tests.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,3 +782,75 @@ end
782782
model = beta_bernoulli(), data = (y = 1,), initmessages = (t = Normal(0.0, 1.0))
783783
)
784784
end
785+
786+
@testitem "Unsupported functional forms (e.g. `ProductOf`) should display the name of the variable and suggestions" begin
787+
struct DistributionA
788+
a
789+
end
790+
struct DistributionB
791+
b
792+
end
793+
struct LikelihoodDistribution
794+
input
795+
end
796+
797+
@node DistributionA Stochastic [out, a]
798+
@node DistributionB Stochastic [out, b]
799+
@node LikelihoodDistribution Stochastic [out, input]
800+
801+
@rule DistributionA(:out, Marginalisation) (q_a::Any,) = DistributionA(mean(q_a))
802+
@rule DistributionB(:out, Marginalisation) (q_b::Any,) = DistributionB(mean(q_b))
803+
@rule LikelihoodDistribution(:input, Marginalisation) (q_out::Any,) = LikelihoodDistribution(mean(q_out))
804+
805+
@model function invalid_product_posterior(out)
806+
θ ~ DistributionA(1.0)
807+
out ~ LikelihoodDistribution(θ)
808+
end
809+
810+
# Product of `DistributionA` & `LikelihoodDistribution` in the posterior
811+
P = typeof(prod(GenericProd(), DistributionA(1.0), LikelihoodDistribution(1.0)))
812+
@test_throws """
813+
The expression `q(θ)` has an undefined functional form of type `$(P)`.
814+
This is likely because the inference backend does not support the product of these distributions.
815+
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `q(θ)`.
816+
817+
Possible solutions:
818+
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
819+
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
820+
```julia
821+
using ExponentialFamilyProjection
822+
823+
@constraints begin
824+
q(θ) :: ProjectedTo(NormalMeanVariance)
825+
end
826+
```
827+
Refer to the documentation for more details on functional form constraints.
828+
""" result = infer(model = invalid_product_posterior(), data = (out = 1.0,))
829+
830+
# Product of `DistributionA` & `DistributionB` in the message
831+
@model function invalid_product_message(out)
832+
input[1] ~ DistributionA(1.0)
833+
input[1] ~ DistributionB(1.0)
834+
θ ~ DistributionA(input[1])
835+
out ~ LikelihoodDistribution(θ)
836+
end
837+
838+
T = typeof(prod(GenericProd(), DistributionA(1.0), DistributionB(1.0)))
839+
@test_throws """
840+
The expression `μ(input[1])` has an undefined functional form of type `$(T)`.
841+
This is likely because the inference backend does not support the product of these distributions.
842+
As a result, `RxInfer` cannot compute key quantities such as the `mean` or `var` of `μ(input[1])`.
843+
844+
Possible solutions:
845+
- Implement the `BayesBase.prod` method (refer to the `BayesBase` documentation for guidance).
846+
- Use a functional form constraint to specify the posterior form with the `@constraints` macro. For example:
847+
```julia
848+
using ExponentialFamilyProjection
849+
850+
@constraints begin
851+
μ(input) :: ProjectedTo(NormalMeanVariance)
852+
end
853+
```
854+
Refer to the documentation for more details on functional form constraints.
855+
""" result = infer(model = invalid_product_message(), data = (out = 1.0,), returnvars == KeepEach(),))
856+
end

0 commit comments

Comments
 (0)