Description
The marginal log-likelihood computed by the sequential filter (`lgssm_filter`) differs from the one computed by the parallel filter (`parallel_lgssm_filter`). The example below is a small modification of the test case in `parallel_inference_test.py`: it sets `μ0[0] = 100`.
Reproduce
```python
from jax import numpy as jnp
from jax import random as jr
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_smoother, lgssm_filter, parallel_lgssm_filter, lgssm_joint_sample


def make_static_lgssm_params():
    """Create a static LGSSM with fixed parameters."""
    dt = 0.1
    F = jnp.eye(4) + dt * jnp.eye(4, k=2)
    Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
                                 [dt**2/2, dt]]),
                      jnp.eye(2))
    H = jnp.eye(2, 4)
    R = 0.5 ** 2 * jnp.eye(2)
    μ0 = jnp.array([100., 0., 1., -1.])  # μ0[0] changed to 100
    Σ0 = jnp.eye(4)

    latent_dim = 4
    observation_dim = 2

    lgssm = LinearGaussianSSM(latent_dim, observation_dim)
    params, _ = lgssm.initialize(jr.PRNGKey(0),
                                 initial_mean=μ0,
                                 initial_covariance=Σ0,
                                 dynamics_weights=F,
                                 dynamics_covariance=Q,
                                 emission_weights=H,
                                 emission_covariance=R)
    return params, lgssm


num_timesteps = 50
key = jr.PRNGKey(1)

params, lgssm = make_static_lgssm_params()
_, emissions = lgssm_joint_sample(params, key, num_timesteps)

posterior = lgssm_filter(params, emissions)                    # sequential filter
parallel_posterior = parallel_lgssm_filter(params, emissions)  # parallel filter

print(posterior.marginal_loglik)
print(parallel_posterior.marginal_loglik)
```
Output
```
-106.0021
-4143.4688
```
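To tell which of the two numbers is right, the exact marginal log-likelihood of a short sequence can be computed by brute force. The approach is to stack all emissions into one Gaussian vector and evaluate its log density directly. The NumPy sketch below does this and also includes a minimal sequential Kalman filter for comparison. It assumes dynamax's convention that the first emission is generated from z_0 ~ N(μ0, Σ0) (the filter conditions on y_0 before predicting), with zero biases and no inputs. The helper names `gaussian_logpdf`, `brute_force_marginal_loglik` and `kalman_marginal_loglik` are hypothetical and not part of dynamax.

```python
import numpy as np


def gaussian_logpdf(y, mean, cov):
    """log N(y | mean, cov), computed through a Cholesky factor of cov."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, y - mean)
    return (-0.5 * z @ z
            - np.sum(np.log(np.diag(L)))
            - 0.5 * len(y) * np.log(2 * np.pi))


def brute_force_marginal_loglik(mu0, Sigma0, F, Q, H, R, ys):
    """Exact log p(y_0, ..., y_{T-1}) by treating all emissions as one Gaussian.

    Assumes z_0 ~ N(mu0, Sigma0), z_{t+1} = F z_t + N(0, Q),
    y_t = H z_t + N(0, R), with no biases or inputs.
    """
    T, D = ys.shape
    # Marginal covariance Cov(z_t) of each latent state, and powers F^k.
    covs, F_pow = [Sigma0], [np.eye(F.shape[0])]
    for _ in range(T - 1):
        covs.append(F @ covs[-1] @ F.T + Q)
        F_pow.append(F @ F_pow[-1])
    # The mean of y_t is H F^t mu0.
    y_mean = np.concatenate([H @ F_pow[t] @ mu0 for t in range(T)])
    # Cov(y_t, y_s) = H F^(t-s) Cov(z_s) H^T for t >= s, plus R when t == s.
    y_cov = np.zeros((T * D, T * D))
    for t in range(T):
        for s in range(t + 1):
            block = H @ F_pow[t - s] @ covs[s] @ H.T
            if t == s:
                block = block + R
            y_cov[t * D:(t + 1) * D, s * D:(s + 1) * D] = block
            y_cov[s * D:(s + 1) * D, t * D:(t + 1) * D] = block.T
    return gaussian_logpdf(ys.reshape(-1), y_mean, y_cov)


def kalman_marginal_loglik(mu0, Sigma0, F, Q, H, R, ys):
    """Sequential reference: sum_t log N(y_t | H m_t, H P_t H^T + R), update then predict."""
    m, P, ll = mu0, Sigma0, 0.0
    for y in ys:
        S = H @ P @ H.T + R
        ll += gaussian_logpdf(y, H @ m, S)
        K = np.linalg.solve(S, H @ P).T   # Kalman gain P H^T S^{-1} (S and P symmetric)
        m = m + K @ (y - H @ m)
        P = P - K @ S @ K.T
        m, P = F @ m, F @ P @ F.T + Q     # predict the next state
    return ll
```

To use it, pass μ0, Σ0, F, Q, H and R from `make_static_lgssm_params` (converted with `np.asarray`) together with `np.asarray(emissions)`. This gives a float64 reference value to compare with both printed numbers. The joint covariance is (2T)x(2T), so this check is only practical for short sequences such as the 50 steps used here.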
I have not worked through the math in detail, but I suspect the discrepancy comes from this line in the parallel filter:
```python
logZ = _marginal_loglik_elem(P, H, R, innov)
```
My guess is that it should be replaced with something like:
```python
logZ = -MVN(H @ m, H @ P @ H.T + R).log_prob(y)
```
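The first element should contribute log p(y_0) = log N(y_0 | H μ0, H Σ0 Hᵀ + R). The suggested line computes this quantity, negated to match what appears to be a negative log-likelihood convention for `logZ` (an assumption, not confirmed here). The NumPy sketch below compares that term with the same observation evaluated under a zero-mean Gaussian. It only shows how much a missing H μ0 would cost. It does not reproduce dynamax's `_marginal_loglik_elem`, and whether that function actually drops the mean is an assumption to check against the source. `mvn_logpdf` and the observation `y0` are hypothetical.

Two observations follow from the setup. First, with H = `jnp.eye(2, 4)`, H μ0 is just (μ0[0], μ0[1]), so this kind of error is invisible whenever both entries are zero. Second, with μ0[0] = 100 and S = 1.25 I, the quadratic term alone is about -0.5 * 100**2 / 1.25 = -4000. That is on the order of the 4037.4667 gap between the two printed values.

```python
import numpy as np


def mvn_logpdf(y, mean, cov):
    """log N(y | mean, cov) via a Cholesky factor (hypothetical helper)."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, y - mean)
    return (-0.5 * z @ z
            - np.sum(np.log(np.diag(L)))
            - 0.5 * len(y) * np.log(2 * np.pi))


# First-step quantities, using the values from make_static_lgssm_params.
H = np.eye(2, 4)
R = 0.5 ** 2 * np.eye(2)
mu0 = np.array([100., 0., 1., -1.])
Sigma0 = np.eye(4)
S = H @ Sigma0 @ H.T + R          # predictive covariance of y_0, here 1.25 * I

y0 = np.array([100.3, -0.7])      # hypothetical first observation
innov = y0 - H @ mu0              # innovation y_0 - H mu0

correct = mvn_logpdf(y0, H @ mu0, S)            # log N(y_0 | H mu0, S)
same = mvn_logpdf(innov, np.zeros(2), S)        # equal: y and mean shifted together
dropped_mean = mvn_logpdf(y0, np.zeros(2), S)   # H mu0 missing from the mean

# Only the quadratic term differs: 0.5 * (y0 @ y0 - innov @ innov) / 1.25 = 4024.0
print(correct, same, dropped_mean, correct - dropped_mean)
```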