Skip to content

Commit 9facd12

Browse files
authored
Merge pull request #354 from ReactiveBayes/dev-eqnodes
Add MultiAgentTrajectoryPlanning to tests
2 parents 55dda63 + ddc26a1 commit 9facd12

File tree

4 files changed

+127
-6
lines changed

4 files changed

+127
-6
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "RxInfer"
22
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
33
authors = ["Bagaev Dmitry <d.v.bagaev@tue.nl> and contributors"]
4-
version = "3.6.0"
4+
version = "3.6.1"
55

66
[deps]
77
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
@@ -41,7 +41,7 @@ MacroTools = "0.5.6"
4141
Optim = "1.0.0"
4242
ProgressMeter = "1.0.0"
4343
Random = "1.9"
44-
ReactiveMP = "~4.4.0"
44+
ReactiveMP = "~4.4.1"
4545
Reexport = "1.2.0"
4646
Rocket = "1.8.0"
4747
TupleTools = "1.2.0"

codemeta.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99
"downloadUrl": "https://github.com/reactivebayes/RxInfer.jl/releases",
1010
"issueTracker": "https://github.com/reactivebayes/RxInfer.jl/issues",
1111
"name": "RxInfer.jl",
12-
"version": "3.6.0",
12+
"version": "3.6.1",
1313
"description": "Julia package for automated, scalable and efficient Bayesian inference on factor graphs with reactive message passing. ",
1414
"applicationCategory": "Statistics",
1515
"developmentStatus": "active",
1616
"readme": "https://reactivebayes.github.io/RxInfer.jl/stable/",
17-
"softwareVersion": "3.6.0",
17+
"softwareVersion": "3.6.1",
1818
"keywords": [
1919
"Bayesian inference",
2020
"message passing",

test/inference/inference_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ end
808808
end
809809

810810
# Product of `DistributionA` & `LikelihoodDistribution` in the posterior
811-
P = typeof(prod(GenericProd(), DistributionA(1.0), LikelihoodDistribution(1.0)))
811+
P = typeof(prod(GenericProd(), DistributionA(1.0), LikelihoodDistribution(1.0))) # the actual order may change though
812812
@test_throws """
813813
The expression `q(θ)` has an undefined functional form of type `$(P)`.
814814
This is likely because the inference backend does not support the product of these distributions.
@@ -835,7 +835,7 @@ end
835835
out ~ LikelihoodDistribution(θ)
836836
end
837837

838-
T = typeof(prod(GenericProd(), DistributionA(1.0), DistributionB(1.0)))
838+
T = typeof(prod(GenericProd(), DistributionB(1.0), DistributionA(1.0))) # the actual order may change though
839839
@test_throws """
840840
The expression `μ(input[1])` has an undefined functional form of type `$(T)`.
841841
This is likely because the inference backend does not support the product of these distributions.
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
2+
@testitem "MultiAgentTrajectoryPlanning model should terminate and give results" begin
3+
# https://github.com/biaslab/MultiAgentTrajectoryPlanning/issues/4
4+
using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs
5+
6+
# `include(test/utiltests.jl)`
7+
include(joinpath(@__DIR__, "..", "..", "utiltests.jl"))
8+
9+
# half space specification
10+
struct Halfspace end
11+
@node Halfspace Stochastic [out, a, σ2, γ]
12+
13+
# rule specification
14+
@rule Halfspace(:out, Marginalisation) (q_a::PointMass, q_σ2::PointMass, q_γ::PointMass) = begin
15+
return NormalMeanVariance(mean(q_a) + mean(q_γ) * mean(q_σ2), mean(q_σ2))
16+
end
17+
18+
struct ForcePointMass{V}
19+
v::V
20+
end
21+
22+
@rule Halfspace(:σ2, Marginalisation) (q_out::UnivariateNormalDistributionsFamily, q_a::PointMass, q_γ::PointMass) = begin
23+
return ForcePointMass(1 / mean(q_γ) * sqrt(abs2(mean(q_out) - mean(q_a)) + var(q_out)))
24+
end
25+
26+
BayesBase.prod(::GenericProd, p::ForcePointMass, any) = PointMass(p.v)
27+
BayesBase.prod(::GenericProd, any, p::ForcePointMass) = PointMass(p.v)
28+
29+
ReactiveMP.to_marginal(p::ForcePointMass) = PointMass(p.v)
30+
31+
function h(y1, y2)
32+
r1 = 15
33+
r2 = 15
34+
return norm(y1 - y2) - r1 - r2
35+
end
36+
37+
@model function switching_model(nr_steps, γ, ΔT, goals)
38+
39+
# transition model
40+
A = [1 ΔT 0 0; 0 1 0 0; 0 0 1 ΔT; 0 0 0 1]
41+
B = [0 0; ΔT 0; 0 0; 0 ΔT]
42+
C = [1 0 0 0; 0 0 1 0]
43+
44+
local y
45+
46+
# single agent models
47+
for k in 1:2
48+
49+
# prior on state
50+
x[k, 1] ~ MvNormalMeanCovariance(zeros(4), 1e2I)
51+
52+
for t in 1:nr_steps
53+
54+
# prior on controls
55+
u[k, t] ~ MvNormalMeanCovariance(zeros(2), 1e-2I)
56+
57+
# state transition
58+
x[k, t + 1] ~ A * x[k, t] + B * u[k, t]
59+
60+
# observation model
61+
y[k, t] ~ C * x[k, t + 1]
62+
end
63+
64+
# goal priors (indexing reverse due to definition)
65+
goals[1, k] ~ MvNormalMeanCovariance(x[k, 1], 1e-5I)
66+
goals[2, k] ~ MvNormalMeanCovariance(x[k, nr_steps + 1], 1e-5I)
67+
end
68+
69+
# multi-agent models
70+
for t in 1:nr_steps
71+
72+
# observation constraint
73+
σ2[t] ~ GammaShapeRate(3 / 2, γ^2 / 2)
74+
d[t] ~ h(y[1, t], y[2, t])
75+
d[t] ~ Halfspace(0, σ2[t], γ)
76+
end
77+
end
78+
79+
@constraints function switching_constraints()
80+
q(d, σ2) = q(d)q(σ2)
81+
end
82+
83+
@meta function switching_meta()
84+
h() -> Linearization()
85+
end
86+
87+
goals = hcat([
88+
# agent 1: start at (0,0) with 0 velocity, end at (0, 50) with 0 velocity
89+
[[0, 0, 0, 0], [0, 0, 50, 0]],
90+
# agent 2: start at (0,50) with 0 velocity, end at (0, 0) with 0 velocity
91+
[[0, 0, 50, 0], [0, 0, 0, 0]]
92+
]...)
93+
94+
@initialization function switching_initialization_1()
95+
q(σ2) = PointMass(1)
96+
μ(x) = MvNormalMeanCovariance(randn(4), 100I)
97+
end
98+
99+
@initialization function switching_initialization_2()
100+
q(σ2) = PointMass(1)
101+
μ(y) = MvNormalMeanCovariance(randn(2), 100I)
102+
end
103+
104+
for nr_steps in (50, 100), init in [switching_initialization_1, switching_initialization_2]
105+
result = infer(
106+
model = switching_model(nr_steps = nr_steps, γ = 1, ΔT = 1),
107+
data = (goals = goals,),
108+
constraints = switching_constraints(),
109+
meta = switching_meta(),
110+
initialization = init(),
111+
iterations = 100,
112+
returnvars = KeepLast(),
113+
showprogress = false,
114+
options = (limit_stack_depth = 100,)
115+
)
116+
@test mean(result.posteriors[:x][1, 1]) [0, 0, 0, 0] atol = 5e-1
117+
@test mean(result.posteriors[:x][1, nr_steps]) [0, 0, 50, 0] atol = 5e-1
118+
@test mean(result.posteriors[:x][2, 1]) [0, 0, 50, 0] atol = 5e-1
119+
@test mean(result.posteriors[:x][2, nr_steps]) [0, 0, 0, 0] atol = 5e-1
120+
end
121+
end

0 commit comments

Comments
 (0)