Skip to content

Commit 839400a

Browse files
authored
Merge pull request #325 from ReactiveBayes/dev-ef-projection
Initial integration with ExponentialFamilyProjection
2 parents 34afa89 + b3c1ef6 commit 839400a

File tree

14 files changed

+1022
-43
lines changed

14 files changed

+1022
-43
lines changed

Project.toml

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RxInfer"
22
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
33
authors = ["Bagaev Dmitry <d.v.bagaev@tue.nl> and contributors"]
4-
version = "3.4.0"
4+
version = "3.5.0"
55

66
[deps]
77
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
@@ -21,20 +21,27 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2121
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
2222
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
2323

24+
[weakdeps]
25+
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
26+
27+
[extensions]
28+
ProjectionExt = "ExponentialFamilyProjection"
29+
2430
[compat]
2531
BayesBase = "1.1"
2632
DataStructures = "0.18"
2733
Distributions = "0.25"
2834
DomainSets = "0.5.2, 0.6, 0.7"
29-
ExponentialFamily = "1.2"
35+
ExponentialFamily = "1.5"
36+
ExponentialFamilyProjection = "1.1"
3037
FastCholesky = "1.3.0"
3138
GraphPPL = "~4.3.0"
3239
LinearAlgebra = "1.9"
3340
MacroTools = "0.5.6"
3441
Optim = "1.0.0"
3542
ProgressMeter = "1.0.0"
3643
Random = "1.9"
37-
ReactiveMP = "~4.2.0"
44+
ReactiveMP = "~4.3.0"
3845
Reexport = "1.2.0"
3946
Rocket = "1.8.0"
4047
TupleTools = "1.2.0"
@@ -49,6 +56,7 @@ CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
4956
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
5057
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
5158
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
59+
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
5260
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
5361
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
5462
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -62,4 +70,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6270
TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04"
6371

6472
[targets]
65-
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "CpuId", "Dates", "Distributed", "Documenter", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"]
73+
test = ["Test", "Pkg", "Logging", "InteractiveUtils", "TestSetExtensions", "Coverage", "CpuId", "Dates", "Distributed", "Documenter", "ExponentialFamilyProjection", "Plots", "BenchmarkCI", "BenchmarkTools", "PkgBenchmark", "Aqua", "StableRNGs", "StatsFuns", "Optimisers", "ReTestItems"]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ Our high-level project roadmap outlines the key milestones and focus areas for t
167167
| Q1/Q2 2024 | Q3/Q4 2024 | 2025 |
168168
|---------------------|---------------------------|--------------------|
169169
| 🧩 **Nested models with [GraphPPL.jl](https://github.com/reactivebayes/GraphPPL.jl)**| 🌐 **Graph structure visualization** | 🔀 **Stochastic Processes** |
170-
| 🔄 **Development of [ExponentialFamilyProjection.jl]()** | 🧠 **Automated inference with [ExponentialFamilyProjection.jl](https://github.com/reactivebayes/ExponentialFamilyProjection.jl)** | 🚀 **Robustness & Memory-efficiency** |
170+
| 🔄 **Development of [ExponentialFamilyProjection.jl]()** | 🧠 **Automated inference with [ExponentialFamilyProjection.jl](https://github.com/reactivebayes/ExponentialFamilyProjection.jl)** | 🚀 **Robustness & Memory-efficiency** |
171171

172172
For a more granular view of our progress and ongoing tasks, check out our [project board](https://github.com/orgs/reactivebayes/projects/2/views/4) or join our 4-weekly [public meetings](https://dynalist.io/d/F4aA-Z2c8X-M1iWTn9hY_ndN).
173173

docs/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
44
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
55
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
ExponentialFamily = "62312e5e-252a-4322-ace9-a5f4bf9b357b"
7+
ExponentialFamilyProjection = "17f509fa-9a96-44ba-99b2-1c5f01f0931b"
78
GraphPPL = "b3f8163a-e979-4e85-b43e-1f63d8c8b42c"
89
GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
910
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
@@ -12,6 +13,7 @@ ReactiveMP = "a194aa59-28ba-4574-a09c-4a745416d6e3"
1213
Rocket = "df971d30-c9d6-4b37-b8ff-e965b2cb3a40"
1314
RxInfer = "86711068-29c9-4ff7-b620-ae75d7495b3d"
1415
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
16+
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
1517

1618
[compat]
1719
Documenter = "1.0.0"

docs/make.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ foreach(vcat(ExamplesOverviewPath, ExamplesCategoriesOverviewPaths)) do path
6969
@warn "`$(path)` does not exist. Generating an empty overview. Use the `make examples` command to generate the overview and all examples."
7070
mkpath(dirname(path))
7171
open(path, "w") do f
72-
write(f, "The overview is missing. Use the `make examples` command to generate the overview and all examples.")
72+
write(f, """
73+
$(isequal(path, ExamplesOverviewPath) ? "# [Examples overview](@id examples-overview)" : "")
74+
The overview is missing. Use the `make examples` command to generate the overview and all examples.
75+
""")
7376
end
7477
end
7578
end
@@ -108,7 +111,9 @@ makedocs(;
108111
"Streamline inference" => "manuals/inference/streamlined.md",
109112
"Initialization" => "manuals/inference/initialization.md",
110113
"Auto-updates" => "manuals/inference/autoupdates.md",
111-
"Deterministic nodes" => "manuals/inference/delta-node.md"
114+
"Deterministic nodes" => "manuals/inference/delta-node.md",
115+
"Non-conjugate inference" => "manuals/inference/nonconjugate.md",
116+
"Undefined message update rules" => "manuals/inference/undefinedrules.md"
112117
],
113118
"Inference customization" => [
114119
"Defining a custom node and rules" => "manuals/customization/custom-node.md",

docs/src/manuals/customization/custom-node.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
Welcome to the `RxInfer` documentation on creating custom factor graph nodes. In `RxInfer`, factor nodes represent functional relationships between variables, also known as factors. Together, these factors define your probabilistic model. Quite often these factors represent distributions, denoting how a certain parameter affects another. However, other factors are also possible, such as ones specifying linear or non-linear relationships. `RxInfer` already supports a lot of factor nodes, however, depending on the problem that you are trying to solve, you may need to create a custom node that better fits the specific requirements of your model. This tutorial will guide you through the process of defining a custom node in `RxInfer`, step by step. By the end of this tutorial, you will be able to create your own custom node and integrate it into your model.
44

5+
In addition, read another section on a different way of running inference with custom stochastic nodes without explicit rule specification [here](@ref inference-undefinedrules).
6+
57
---
68

79
To create a custom node in `RxInfer`, 4 steps are required:

docs/src/manuals/getting-started.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,4 @@ result.posteriors[:θ]
305305

306306
## Where to go next?
307307

308-
There are a set of [examples](@ref examples-overview) available in `RxInfer` repository that demonstrate the more advanced features of the package for various problems. Alternatively, you can head to the [Model specification](@ref user-guide-model-specification) which provides more detailed information of how to use `RxInfer` to specify probabilistic models. [Inference execution](@ref user-guide-inference-execution) section provides a documentation about `RxInfer` API for running reactive Bayesian inference. Also read the [Comparison](@ref comparison) to compare `RxInfer` with other probabilistic programming libraries.
308+
There are a set of [examples](@ref examples-overview) available in `RxInfer` repository that demonstrate the more advanced features of the package for various problems. Alternatively, you can head to the [Model specification](@ref user-guide-model-specification) which provides more detailed information of how to use `RxInfer` to specify probabilistic models. [Inference execution](@ref user-guide-inference-execution) section provides a documentation about `RxInfer` API for running reactive Bayesian inference. Also read the [Comparison](@ref comparison) to compare `RxInfer` with other probabilistic programming libraries. For advances use cases refer to the [Non-conjugate inference](@ref inference-nonconjugate) tutorial and inference [without defining the message update rules explicitly](@ref inference-undefinedrules).

docs/src/manuals/inference/delta-node.md

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,20 @@ RxInfer.jl offers a comprehensive set of stochastic nodes, primarily emphasizing
66

77
The delta node supports several approximation methods for probabilistic inference. The desired approximation method depends on the nodes connected to the delta node. We differentiate the following deterministic transformation scenarios:
88

9-
1. **Gaussian Nodes**: For delta nodes linked to strictly multivariate or univariate Gaussian distributions, the recommended methods are Linearization or Unscented transforms.
10-
2. **Exponential Family Nodes**: For the delta node connected to nodes from the exponential family, the CVI (Conjugate Variational Inference) is the method of choice.
11-
3. **Stacking Delta Nodes**: For scenarios where delta nodes are stacked, either Linearization or Unscented transforms are suitable.
9+
1. **Gaussian Nodes**: For delta nodes linked to strictly multivariate or univariate Gaussian distributions, the recommended methods are `Linearization` or `Unscented` transforms.
10+
2. **Exponential Family Nodes**: For the delta node connected to nodes from the exponential family, the `CVIProjection` (Conjugate Variational Inference) is the method of choice.
11+
3. **Stacking Delta Nodes**: For scenarios where delta nodes are stacked, either `Linearization`, `Unscented` or `CVIProjection` are suitable.
12+
4. **Support for Inverse Functions**: For scenarious, where an inverse function is available
1213

1314
The table below summarizes the features of the delta node in RxInfer.jl, categorized by the approximation method:
1415

15-
| Methods | Gaussian Nodes | Exponential Family Nodes | Stacking Delta Nodes
16-
|---------------|----------------|--------------------------|----------------------
17-
| Linearization | ✓ | ✗ | ✓
18-
| Unscented | ✓ | ✗ | ✓
19-
| CVI | ✓ | ✓ | ✗
16+
| Methods | Gaussian Nodes | Exponential Family Nodes | Stacking Delta Nodes | Inverse functions
17+
|------------------|----------------|--------------------------|----------------------|----------------------
18+
| Linearization | ✓ | ✗ | ✓ | ✓
19+
| Unscented | ✓ | ✗ | ✓ | ✓
20+
| CVI (deprecated) | ✓ | ✓ | ✗ | ✗
21+
| CVI Projection | ✓ | ✓ | ✓ | ✗
22+
2023

2124
## Gaussian Case
2225

@@ -29,12 +32,15 @@ For clarity, consider the following example:
2932
using RxInfer
3033
3134
@model function delta_node_example(z)
32-
x ~ Normal(mean=0.0, var=1.0)
35+
x ~ Normal(mean = 0.0, var = 1.0)
3336
y := tanh(x)
34-
z ~ Normal(mean=y, var=1.0)
37+
z ~ Normal(mean = y, var = 1.0)
3538
end
3639
```
3740

41+
!!! note
42+
While not strictly required, it is advised to use `:=` to define a deterministic relationship within the `@model` macro.
43+
3844
To perform inference on this model, designate the approximation method for the delta node (here, the `tanh` function) using the `@meta` specification:
3945

4046
```@example delta_node_example
@@ -62,21 +68,25 @@ end
6268
To execute the inference procedure:
6369

6470
```@example delta_node_example
65-
infer(model = delta_node_example(), meta=delta_meta, data = (z = 1.0,))
71+
result = infer(
72+
model = delta_node_example(),
73+
meta = delta_meta,
74+
data = (z = 1.0,)
75+
)
6676
```
6777

68-
This methodology is consistent even when the delta node is associated with multiple nodes. For instance:
78+
This methodology is consistent even when the delta node is associated with multiple inputs. For instance:
6979

7080
```@example delta_node_example
7181
f(x, g) = x*tanh(g)
7282
```
7383

7484
```@example delta_node_example
7585
@model function delta_node_example(z)
76-
x ~ Normal(mean=1.0, var=1.0)
77-
g ~ Normal(mean=1.0, var=1.0)
86+
x ~ Normal(mean = 1.0, var = 1.0)
87+
g ~ Normal(mean = 1.0, var = 1.0)
7888
y := f(x, g)
79-
z ~ Normal(mean=y, var=0.1)
89+
z ~ Normal(mean = y, var = 0.1)
8090
end
8191
```
8292

@@ -112,11 +122,14 @@ end
112122

113123
When the delta node is associated with nodes from the exponential family (excluding Gaussians), the `Linearization` and `Unscented` methods are not applicable. In such cases, the CVI (Conjugate Variational Inference) is available. Here's a modified example:
114124

125+
!!! note
126+
The `CVIProjection` method is available only if `ExponentialFamilyProjection` package is installed in the current environment.
127+
115128
```@example delta_node_example_cvi
116-
using RxInfer
129+
using RxInfer, ExponentialFamilyProjection
117130
118131
@model function delta_node_example1(z)
119-
x ~ Gamma(shape=1.0, rate=1.0)
132+
x ~ Gamma(shape = 1.0, rate = 1.0)
120133
y := tanh(x)
121134
z ~ Bernoulli(y)
122135
end
@@ -125,12 +138,16 @@ end
125138
The corresponding meta specification can be represented as:
126139

127140
```@example delta_node_example_cvi
128-
using StableRNGs
129-
using Optimisers
130-
131141
delta_meta = @meta begin
132-
tanh() -> DeltaMeta(method = CVI(StableRNG(42), 100, 100, Optimisers.Descent(0.01)))
142+
tanh() -> CVIProjection()
133143
end
134144
```
135145

136-
Consult the `ProdCVI` docstrings for a detailed explanation of these parameters.
146+
Consult the `CVIProjection` docstrings for a detailed explanation of its hyper-parameters. Additionally, read the [Non-conjugate Inference](@ref inference-nonconjugate) section.
147+
148+
!!! note
149+
The `CVIProjection` method is an improved version of the now-deprecated `CVI` method. This new implementation features different hyperparameters, better accuracy, and improved stability.
150+
151+
## Fuse deterministic nodes with stochastic nodes
152+
153+
Read how to circumvent the need to define the meta structure and, instead, fuse the deterministic relation with a neighboring stochastic factor node in [this section](@ref inference-undefinedrules-fusedelta).

0 commit comments

Comments
 (0)