|
| 1 | + |
| 2 | +@testitem "MultiAgentTrajectoryPlanning model should terminate and give results" begin |
| 3 | + # https://github.com/biaslab/MultiAgentTrajectoryPlanning/issues/4 |
| 4 | + using RxInfer, BenchmarkTools, Random, Plots, Dates, LinearAlgebra, StableRNGs |
| 5 | + |
| 6 | + # `include(test/utiltests.jl)` |
| 7 | + include(joinpath(@__DIR__, "..", "..", "utiltests.jl")) |
| 8 | + |
| 9 | + # half space specification |
| 10 | + struct Halfspace end |
| 11 | + @node Halfspace Stochastic [out, a, σ2, γ] |
| 12 | + |
| 13 | + # rule specification |
| 14 | + @rule Halfspace(:out, Marginalisation) (q_a::PointMass, q_σ2::PointMass, q_γ::PointMass) = begin |
| 15 | + return NormalMeanVariance(mean(q_a) + mean(q_γ) * mean(q_σ2), mean(q_σ2)) |
| 16 | + end |
| 17 | + |
| 18 | + struct ForcePointMass{V} |
| 19 | + v::V |
| 20 | + end |
| 21 | + |
| 22 | + @rule Halfspace(:σ2, Marginalisation) (q_out::UnivariateNormalDistributionsFamily, q_a::PointMass, q_γ::PointMass) = begin |
| 23 | + return ForcePointMass(1 / mean(q_γ) * sqrt(abs2(mean(q_out) - mean(q_a)) + var(q_out))) |
| 24 | + end |
| 25 | + |
| 26 | + BayesBase.prod(::GenericProd, p::ForcePointMass, any) = PointMass(p.v) |
| 27 | + BayesBase.prod(::GenericProd, any, p::ForcePointMass) = PointMass(p.v) |
| 28 | + |
| 29 | + ReactiveMP.to_marginal(p::ForcePointMass) = PointMass(p.v) |
| 30 | + |
| 31 | + function h(y1, y2) |
| 32 | + r1 = 15 |
| 33 | + r2 = 15 |
| 34 | + return norm(y1 - y2) - r1 - r2 |
| 35 | + end |
| 36 | + |
| 37 | + @model function switching_model(nr_steps, γ, ΔT, goals) |
| 38 | + |
| 39 | + # transition model |
| 40 | + A = [1 ΔT 0 0; 0 1 0 0; 0 0 1 ΔT; 0 0 0 1] |
| 41 | + B = [0 0; ΔT 0; 0 0; 0 ΔT] |
| 42 | + C = [1 0 0 0; 0 0 1 0] |
| 43 | + |
| 44 | + local y |
| 45 | + |
| 46 | + # single agent models |
| 47 | + for k in 1:2 |
| 48 | + |
| 49 | + # prior on state |
| 50 | + x[k, 1] ~ MvNormalMeanCovariance(zeros(4), 1e2I) |
| 51 | + |
| 52 | + for t in 1:nr_steps |
| 53 | + |
| 54 | + # prior on controls |
| 55 | + u[k, t] ~ MvNormalMeanCovariance(zeros(2), 1e-2I) |
| 56 | + |
| 57 | + # state transition |
| 58 | + x[k, t + 1] ~ A * x[k, t] + B * u[k, t] |
| 59 | + |
| 60 | + # observation model |
| 61 | + y[k, t] ~ C * x[k, t + 1] |
| 62 | + end |
| 63 | + |
| 64 | + # goal priors (indexing reverse due to definition) |
| 65 | + goals[1, k] ~ MvNormalMeanCovariance(x[k, 1], 1e-5I) |
| 66 | + goals[2, k] ~ MvNormalMeanCovariance(x[k, nr_steps + 1], 1e-5I) |
| 67 | + end |
| 68 | + |
| 69 | + # multi-agent models |
| 70 | + for t in 1:nr_steps |
| 71 | + |
| 72 | + # observation constraint |
| 73 | + σ2[t] ~ GammaShapeRate(3 / 2, γ^2 / 2) |
| 74 | + d[t] ~ h(y[1, t], y[2, t]) |
| 75 | + d[t] ~ Halfspace(0, σ2[t], γ) |
| 76 | + end |
| 77 | + end |
| 78 | + |
| 79 | + @constraints function switching_constraints() |
| 80 | + q(d, σ2) = q(d)q(σ2) |
| 81 | + end |
| 82 | + |
| 83 | + @meta function switching_meta() |
| 84 | + h() -> Linearization() |
| 85 | + end |
| 86 | + |
| 87 | + goals = hcat([ |
| 88 | + # agent 1: start at (0,0) with 0 velocity, end at (0, 50) with 0 velocity |
| 89 | + [[0, 0, 0, 0], [0, 0, 50, 0]], |
| 90 | + # agent 2: start at (0,50) with 0 velocity, end at (0, 0) with 0 velocity |
| 91 | + [[0, 0, 50, 0], [0, 0, 0, 0]] |
| 92 | + ]...) |
| 93 | + |
| 94 | + @initialization function switching_initialization_1() |
| 95 | + q(σ2) = PointMass(1) |
| 96 | + μ(x) = MvNormalMeanCovariance(randn(4), 100I) |
| 97 | + end |
| 98 | + |
| 99 | + @initialization function switching_initialization_2() |
| 100 | + q(σ2) = PointMass(1) |
| 101 | + μ(y) = MvNormalMeanCovariance(randn(2), 100I) |
| 102 | + end |
| 103 | + |
| 104 | + for nr_steps in (50, 100), init in [switching_initialization_1, switching_initialization_2] |
| 105 | + result = infer( |
| 106 | + model = switching_model(nr_steps = nr_steps, γ = 1, ΔT = 1), |
| 107 | + data = (goals = goals,), |
| 108 | + constraints = switching_constraints(), |
| 109 | + meta = switching_meta(), |
| 110 | + initialization = init(), |
| 111 | + iterations = 100, |
| 112 | + returnvars = KeepLast(), |
| 113 | + showprogress = false, |
| 114 | + options = (limit_stack_depth = 100,) |
| 115 | + ) |
| 116 | + @test mean(result.posteriors[:x][1, 1]) ≈ [0, 0, 0, 0] atol = 5e-1 |
| 117 | + @test mean(result.posteriors[:x][1, nr_steps]) ≈ [0, 0, 50, 0] atol = 5e-1 |
| 118 | + @test mean(result.posteriors[:x][2, 1]) ≈ [0, 0, 50, 0] atol = 5e-1 |
| 119 | + @test mean(result.posteriors[:x][2, nr_steps]) ≈ [0, 0, 0, 0] atol = 5e-1 |
| 120 | + end |
| 121 | +end |
0 commit comments