
Marginal log likelihood of LGSSM parallel filter does not match recursive filter #423

@gorold

Description


The marginal log likelihoods returned by the recursive filter and the parallel filter differ. The example below is a small modification of the test case in parallel_inference_test.py: it sets μ0[0] = 100.

Reproduce

from jax import numpy as jnp
from jax import random as jr

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_smoother, lgssm_filter, parallel_lgssm_filter, lgssm_joint_sample


def make_static_lgssm_params():
    """Create a static LGSSM with fixed parameters."""
    dt = 0.1
    F = jnp.eye(4) + dt * jnp.eye(4, k=2)
    Q = 1. * jnp.kron(jnp.array([[dt**3/3, dt**2/2],
                                 [dt**2/2, dt]]),
                      jnp.eye(2))
    H = jnp.eye(2, 4)
    R = 0.5 ** 2 * jnp.eye(2)
    μ0 = jnp.array([100.,0.,1.,-1.])
    Σ0 = jnp.eye(4)

    latent_dim = 4
    observation_dim = 2

    lgssm = LinearGaussianSSM(latent_dim, observation_dim)
    params, _ = lgssm.initialize(jr.PRNGKey(0),
                                 initial_mean=μ0,
                                 initial_covariance=Σ0,
                                 dynamics_weights=F,
                                 dynamics_covariance=Q,
                                 emission_weights=H,
                                 emission_covariance=R)
    return params, lgssm


num_timesteps = 50
key = jr.PRNGKey(1)
params, lgssm = make_static_lgssm_params()   
_, emissions = lgssm_joint_sample(params, key, num_timesteps)
posterior = lgssm_filter(params, emissions)
parallel_posterior = parallel_lgssm_filter(params, emissions)


print(posterior.marginal_loglik)
print(parallel_posterior.marginal_loglik)

Output

-106.0021
-4143.4688
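To check which value is correct independently of both dynamax filters: under a linear-Gaussian model the observations y_{1:T} are jointly Gaussian, so for a short sequence log p(y_{1:T}) is a single multivariate normal log density. Below is a minimal NumPy sketch of that brute-force computation. The function names are hypothetical (not part of dynamax), and it assumes x_1 ~ N(μ0, Σ0), x_t = F x_{t-1} + N(0, Q), y_t = H x_t + N(0, R), with no biases or inputs, as in the repro above.

```python
import numpy as np


def gaussian_logpdf(y, mean, cov):
    """Log density of N(mean, cov) evaluated at y, via a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, y - mean)  # whitened residual
    return -0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * y.size * np.log(2 * np.pi)


def batch_marginal_loglik(mu0, Sigma0, F, Q, H, R, ys):
    """Hypothetical brute-force log p(y_{1:T}): all T observations as one joint Gaussian.

    Assumes x_1 ~ N(mu0, Sigma0), x_t = F x_{t-1} + N(0, Q), y_t = H x_t + N(0, R),
    with no biases or inputs. Cost is O((T * D)^3), so keep T small.
    """
    T, D = ys.shape
    # Marginal mean and covariance of each latent state x_t.
    means, covs = [mu0], [Sigma0]
    for _ in range(T - 1):
        means.append(F @ means[-1])
        covs.append(F @ covs[-1] @ F.T + Q)
    # Cov(y_t, y_s) = H F^(t-s) Cov(x_s) H^T for t >= s; R sits on the diagonal blocks.
    big_cov = np.kron(np.eye(T), R)
    for s in range(T):
        F_pow = np.eye(len(mu0))
        for t in range(s, T):
            block = H @ F_pow @ covs[s] @ H.T
            big_cov[t*D:(t+1)*D, s*D:(s+1)*D] += block
            if t != s:
                big_cov[s*D:(s+1)*D, t*D:(t+1)*D] += block.T
            F_pow = F @ F_pow
    big_mean = np.concatenate([H @ m for m in means])
    return gaussian_logpdf(ys.reshape(-1), big_mean, big_cov)
```

Passing the F, Q, H, R, μ0, Σ0 from make_static_lgssm_params and the sampled emissions (each converted with np.asarray) should indicate which of the two printed values is correct; with T = 50 and 2-dimensional observations the joint covariance is only 100 × 100. The cubic cost makes this a test oracle, not a substitute for either filter.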

I have not worked through the math in detail, but I suspect the cause is this line:

logZ = _marginal_loglik_elem(P, H, R, innov)

My guess is that it should be replaced with something like:

logZ = -MVN(H @ m, H @ P @ H.T + R).log_prob(y)
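For reference, the recursive filter's marginal log likelihood is the prediction-error decomposition log p(y_{1:T}) = Σ_t log N(y_t | H m_{t|t-1}, H P_{t|t-1} H^T + R), with m_{1|0} = μ0 and P_{1|0} = Σ0. The proposed replacement has this per-step form when m and P are the one-step predicted mean and covariance (its leading minus sign suggests logZ stores a negative log likelihood; that is inferred from the snippet, not checked against the dynamax source). A minimal NumPy sketch of the sum, assuming no biases or inputs; mvn_logpdf and kalman_marginal_loglik are hypothetical names, not dynamax functions:

```python
import numpy as np


def mvn_logpdf(y, mean, cov):
    """log N(y | mean, cov) via a Cholesky factor (stands in for MVN(...).log_prob)."""
    L = np.linalg.cholesky(cov)
    z = np.linalg.solve(L, y - mean)
    return -0.5 * z @ z - np.log(np.diag(L)).sum() - 0.5 * y.size * np.log(2 * np.pi)


def kalman_marginal_loglik(mu0, Sigma0, F, Q, H, R, ys):
    """Sum of one-step predictive log densities log p(y_t | y_{1:t-1})."""
    m, P = mu0, Sigma0                  # predicted moments for t = 1 are the prior
    total = 0.0
    for y in ys:
        S = H @ P @ H.T + R             # predictive covariance of y_t
        total += mvn_logpdf(y, H @ m, S)  # per-step term of the proposed logZ (up to sign)
        # Measurement update
        K = np.linalg.solve(S, H @ P).T
        m = m + K @ (y - H @ m)
        P = P - K @ S @ K.T
        # Time update: predicted moments for the next step
        m = F @ m
        P = F @ P @ F.T + Q
    return total
```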
