@@ -159,8 +159,12 @@ function activate_rmp_variable!(plugin::ReactiveMPInferencePlugin, model::Model,
159
159
# By default it is `UnspecifiedFormConstraint` which means that the form of the resulting distribution is not specified in advance
160
160
# and follows from the computation, but users may override it with other form constraints, e.g. `PointMassFormConstraint`, which
161
161
# constraints the resulting distribution to be of a point mass form
162
- messages_form_constraint = getextra (nodedata, GraphPPL. VariationalConstraintsMessagesFormConstraintKey, ReactiveMP. UnspecifiedFormConstraint ())
163
- marginal_form_constraint = getextra (nodedata, GraphPPL. VariationalConstraintsMarginalFormConstraintKey, ReactiveMP. UnspecifiedFormConstraint ())
162
+ messages_form_constraint = ReactiveMP. preprocess_form_constraints (
163
+ plugin, model, getextra (nodedata, GraphPPL. VariationalConstraintsMessagesFormConstraintKey, ReactiveMP. UnspecifiedFormConstraint ())
164
+ )
165
+ marginal_form_constraint = ReactiveMP. preprocess_form_constraints (
166
+ plugin, model, getextra (nodedata, GraphPPL. VariationalConstraintsMarginalFormConstraintKey, ReactiveMP. UnspecifiedFormConstraint ())
167
+ )
164
168
# Fetch "prod-constraint" for messages and marginals. The prod-constraint usually defines the constraints for a single product of messages
165
169
# It can for example preserve a specific parametrization of distribution
166
170
messages_prod_constraint = getextra (nodedata, :messages_prod_constraint , ReactiveMP. default_prod_constraint (messages_form_constraint))
@@ -301,3 +305,11 @@ ReactiveMP.setmarginals!(collection::AbstractArray{GraphVariableRef}, marginal)
301
305
302
306
ReactiveMP. setmessage! (ref:: GraphVariableRef , marginal) = setmessage! (ref. variable, marginal)
303
307
ReactiveMP. setmessages! (collection:: AbstractArray{GraphVariableRef} , marginal) = ReactiveMP. setmessages! (map (ref -> ref. variable, collection), marginal)
308
+
309
+ # Form constraint preprocessing
310
+
311
+ function ReactiveMP. preprocess_form_constraints (backend:: ReactiveMPInferencePlugin , model:: Model , constraints)
312
+ # It is a simple pass-through for now, but can be extended in the future to preprocess constraints that
313
+ # are defined in other packages, e.g. in `Distributions` and to support constraints, such as `q(x) :: Normal`
314
+ return ReactiveMP. preprocess_form_constraints (constraints)
315
+ end
0 commit comments