Skip to content

Commit 00f6ee1

Browse files
author
Hylke Donker
committed
Add support for inhomogeneous parameters
1 parent b02d892 commit 00f6ee1

File tree

2 files changed

+161
-36
lines changed

2 files changed

+161
-36
lines changed

dynamax/linear_gaussian_ssm/models.py

Lines changed: 107 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77

88
from fastprogress.fastprogress import progress_bar
99
from functools import partial
10-
from jax import jit, vmap
10+
from jax import jit, tree, vmap
1111
from jax.tree_util import tree_map
1212
from jaxtyping import Array, Float
1313
from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
1414
from typing import Any, Optional, Tuple, Union, runtime_checkable
15-
from typing_extensions import Protocol
15+
from typing_extensions import Protocol
1616

1717
from dynamax.ssm import SSM
1818
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
@@ -206,7 +206,7 @@ def sample(self,
206206
key: PRNGKeyT,
207207
num_timesteps: int,
208208
inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None) \
209-
-> Tuple[Float[Array, "num_timesteps state_dim"],
209+
-> Tuple[Float[Array, "num_timesteps state_dim"],
210210
Float[Array, "num_timesteps emission_dim"]]:
211211
"""Sample from the model.
212212
@@ -357,7 +357,7 @@ def forecast(self,
357357
input_weights=params.emissions.input_weights,
358358
cov=1e8 * jnp.ones(self.emission_dim)) # ignore dummy observations
359359
)
360-
360+
361361
dummy_emissions = jnp.zeros((num_forecast_timesteps, self.emission_dim))
362362
forecast_inputs = forecast_inputs if forecast_inputs is not None else \
363363
jnp.zeros((num_forecast_timesteps, 0))
@@ -367,7 +367,7 @@ def forecast(self,
367367
H = params.emissions.weights
368368
b = params.emissions.bias
369369
R = params.emissions.cov if params.emissions.cov.ndim == 2 else jnp.diag(params.emissions.cov)
370-
370+
371371
forecast_emissions = forecast_states.filtered_means @ H.T + b
372372
forecast_emissions_cov = H @ forecast_states.filtered_covariances @ H.T + R
373373
return forecast_states.filtered_means, \
@@ -643,6 +643,47 @@ def m_step(self,
643643
)
644644
return params, m_step_state
645645

646+
def _check_params(self, params: ParamsLGSSM, num_timesteps: int) -> ParamsLGSSM:
    """Return parameters whose missing biases and input weights are zero arrays.

    Only the `bias` and `input_weights` fields of the dynamics and emissions
    are inspected; weights, covariances and the initial distribution are
    passed through unchanged. If the dynamics weights have three dimensions
    the model is treated as inhomogeneous, and the zero arrays get a leading
    time axis: `num_timesteps - 1` entries for the dynamics and
    `num_timesteps` entries for the emissions.

    Args:
        params: model parameters, possibly with `None` biases or input weights.
        num_timesteps: number of timesteps in the emission sequence.

    Returns:
        A new `ParamsLGSSM` with every `None` bias / input weight replaced
        by zeros of the appropriate shape.
    """
    dyn = params.dynamics
    emi = params.emissions

    # A rank-3 weight array means one transition matrix per timestep.
    time_varying = dyn.weights.ndim == 3

    # There are T - 1 transitions but T emissions.
    dyn_lead = (num_timesteps - 1,) if time_varying else ()
    emi_lead = (num_timesteps,) if time_varying else ()

    def _or_zeros(value, shape):
        """Fall back to an all-zero array when `value` is absent."""
        return jnp.zeros(shape) if value is None else value

    dyn_bias_shape = dyn_lead + (self.state_dim,)
    dyn_input_shape = dyn_lead + (self.state_dim, self.input_dim)
    filled_dynamics = ParamsLGSSMDynamics(
        weights=dyn.weights,
        bias=_or_zeros(dyn.bias, dyn_bias_shape),
        input_weights=_or_zeros(dyn.input_weights, dyn_input_shape),
        cov=dyn.cov,
    )

    emi_bias_shape = emi_lead + (self.emission_dim,)
    emi_input_shape = emi_lead + (self.emission_dim, self.input_dim)
    filled_emissions = ParamsLGSSMEmissions(
        weights=emi.weights,
        bias=_or_zeros(emi.bias, emi_bias_shape),
        input_weights=_or_zeros(emi.input_weights, emi_input_shape),
        cov=emi.cov,
    )

    return ParamsLGSSM(
        initial=params.initial,
        dynamics=filled_dynamics,
        emissions=filled_emissions,
    )
685+
686+
646687
def fit_blocked_gibbs(self,
647688
key: PRNGKeyT,
648689
initial_params: ParamsLGSSM,
@@ -654,7 +695,8 @@ def fit_blocked_gibbs(self,
654695
655696
Args:
656697
key: random number key.
657-
initial_params: starting parameters.
698+
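initial_params: starting parameters. For an inhomogeneous (time-varying) model, give the
699+
dynamics parameters a leading time axis of length num_timesteps - 1 and the emissions parameters one of length num_timesteps; the model is treated as inhomogeneous when the dynamics weights have three dimensions.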
658700
sample_size: how many samples to draw.
659701
emissions: set of observation sequences.
660702
inputs: optional set of input sequences.
@@ -664,67 +706,97 @@ def fit_blocked_gibbs(self,
664706
"""
665707
num_timesteps = len(emissions)
666708

709+
# Inhomogeneous models have a leading time dimension.
710+
is_inhomogeneous = initial_params.dynamics.weights.ndim == 3
711+
667712
if inputs is None:
668713
inputs = jnp.zeros((num_timesteps, 0))
669714

715+
initial_params = self._check_params(initial_params, num_timesteps)
716+
670717
def sufficient_stats_from_sample(states):
671718
"""Convert samples of states to sufficient statistics."""
672719
inputs_joint = jnp.concatenate((inputs, jnp.ones((num_timesteps, 1))), axis=1)
673720
# Let xn[t] = x[t+1] for t = 0...T-2
674-
x, xp, xn = states, states[:-1], states[1:]
675-
u, up = inputs_joint, inputs_joint[:-1]
721+
x, xn = states, states[1:]
722+
u = inputs_joint
723+
# Let z[t] = [x[t], u[t]] for t = 0...T-1
724+
z = jnp.concatenate([x, u], axis=-1)
725+
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
726+
zp = z[:-1]
676727
y = emissions
677728

678729
init_stats = (x[0], jnp.outer(x[0], x[0]), 1)
679730

680731
# Quantities for the dynamics distribution
681-
# Let zp[t] = [x[t], u[t]] for t = 0...T-2
682-
sum_zpzpT = jnp.block([[xp.T @ xp, xp.T @ up], [up.T @ xp, up.T @ up]])
683-
sum_zpxnT = jnp.block([[xp.T @ xn], [up.T @ xn]])
684-
sum_xnxnT = xn.T @ xn
685-
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, num_timesteps - 1)
732+
sum_zpzpT = jnp.einsum('ti,tj->tij', zp, zp)
733+
sum_zpxnT = jnp.einsum('ti,tj->tij', zp, xn)
734+
sum_xnxnT = jnp.einsum('ti,tj->tij', xn, xn)
735+
z_is_observed = jnp.ones(num_timesteps - 1)
736+
# The dynamics stats have a leading time dimension.
737+
dynamics_stats = (sum_zpzpT, sum_zpxnT, sum_xnxnT, z_is_observed)
686738
if not self.has_dynamics_bias:
687-
dynamics_stats = (sum_zpzpT[:-1, :-1], sum_zpxnT[:-1, :], sum_xnxnT,
688-
num_timesteps - 1)
739+
dynamics_stats = (sum_zpzpT[:, :-1, :-1], sum_zpxnT[:, :-1, :], sum_xnxnT,
740+
z_is_observed)
689741

690742
# Quantities for the emissions
691-
# Let z[t] = [x[t], u[t]] for t = 0...T-1
692-
sum_zzT = jnp.block([[x.T @ x, x.T @ u], [u.T @ x, u.T @ u]])
693-
sum_zyT = jnp.block([[x.T @ y], [u.T @ y]])
694-
sum_yyT = y.T @ y
695-
emission_stats = (sum_zzT, sum_zyT, sum_yyT, num_timesteps)
743+
sum_zzT = jnp.einsum('ti,tj->tij', z, z)
744+
sum_zyT = jnp.einsum('ti,tj->tij', z, y)
745+
sum_yyT = jnp.einsum('ti,tj->tij', y, y)
746+
y_is_observed = jnp.ones(num_timesteps)
747+
# The emissions stats have a leading time dimension.
748+
emission_stats = (sum_zzT, sum_zyT, sum_yyT, y_is_observed)
696749
if not self.has_emissions_bias:
697-
emission_stats = (sum_zzT[:-1, :-1], sum_zyT[:-1, :], sum_yyT, num_timesteps)
750+
emission_stats = (sum_zzT[:, :-1, :-1], sum_zyT[:, :-1, :], sum_yyT, y_is_observed)
698751

699752
return init_stats, dynamics_stats, emission_stats
700753

701-
def lgssm_params_sample(rng, stats):
702-
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
703-
init_stats, dynamics_stats, emission_stats = stats
704-
rngs = iter(jr.split(rng, 3))
705-
706-
# Sample the initial params
754+
def _sample_initial_params(rng, init_stats):
707755
initial_posterior = niw_posterior_update(self.initial_prior, init_stats)
708-
S, m = initial_posterior.sample(seed=next(rngs))
756+
S, m = initial_posterior.sample(seed=rng)
757+
return ParamsLGSSMInitial(mean=m, cov=S)
709758

710-
# Sample the dynamics params
759+
def _sample_dynamics_params(rng, dynamics_stats):
711760
dynamics_posterior = mniw_posterior_update(self.dynamics_prior, dynamics_stats)
712-
Q, FB = dynamics_posterior.sample(seed=next(rngs))
761+
Q, FB = dynamics_posterior.sample(seed=rng)
713762
F = FB[:, :self.state_dim]
714763
B, b = (FB[:, self.state_dim:-1], FB[:, -1]) if self.has_dynamics_bias \
715764
else (FB[:, self.state_dim:], jnp.zeros(self.state_dim))
765+
return ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q)
716766

717-
# Sample the emission params
767+
def _sample_emission_params(rng, emission_stats):
718768
emission_posterior = mniw_posterior_update(self.emission_prior, emission_stats)
719-
R, HD = emission_posterior.sample(seed=next(rngs))
769+
R, HD = emission_posterior.sample(seed=rng)
720770
H = HD[:, :self.state_dim]
721771
D, d = (HD[:, self.state_dim:-1], HD[:, -1]) if self.has_emissions_bias \
722772
else (HD[:, self.state_dim:], jnp.zeros(self.emission_dim))
773+
return ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
774+
775+
def lgssm_params_sample(rng, stats):
776+
"""Sample parameters of the model given sufficient statistics from observed states and emissions."""
777+
init_stats, dynamics_stats, emission_stats = stats
778+
rngs = iter(jr.split(rng, 3))
779+
780+
# Sample the initial params
781+
initial_params = _sample_initial_params(next(rngs), init_stats)
782+
783+
# Sample the dynamics and emission params.
784+
if not is_inhomogeneous:
785+
# Aggregate summary statistics across time for homogeneous model.
786+
dynamics_stats = tree.map(lambda x: jnp.sum(x, axis=0), dynamics_stats)
787+
emission_stats = tree.map(lambda x: jnp.sum(x, axis=0), emission_stats)
788+
dynamics_params = _sample_dynamics_params(next(rngs), dynamics_stats)
789+
emission_params = _sample_emission_params(next(rngs), emission_stats)
790+
else:
791+
keys_dynamics = jr.split(next(rngs), num_timesteps - 1)
792+
keys_emission = jr.split(next(rngs), num_timesteps)
793+
dynamics_params = vmap(_sample_dynamics_params)(keys_dynamics, dynamics_stats)
794+
emission_params = vmap(_sample_emission_params)(keys_emission, emission_stats)
723795

724796
params = ParamsLGSSM(
725-
initial=ParamsLGSSMInitial(mean=m, cov=S),
726-
dynamics=ParamsLGSSMDynamics(weights=F, bias=b, input_weights=B, cov=Q),
727-
emissions=ParamsLGSSMEmissions(weights=H, bias=d, input_weights=D, cov=R)
797+
initial=initial_params,
798+
dynamics=dynamics_params,
799+
emissions=emission_params,
728800
)
729801
return params
730802

dynamax/linear_gaussian_ssm/models_test.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
"""
22
Tests for the linear Gaussian SSM models.
33
"""
4+
from itertools import count, product
45

5-
import pytest
6+
import jax.numpy as jnp
67
import jax.random as jr
8+
from jax import tree
9+
import pytest
710

811
from dynamax.linear_gaussian_ssm import LinearGaussianSSM
912
from dynamax.linear_gaussian_ssm import LinearGaussianConjugateSSM
13+
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM
1014
from dynamax.utils.utils import monotonically_increasing
1115

1216
NUM_TIMESTEPS = 100
@@ -29,3 +33,52 @@ def test_sample_and_fit(cls, kwargs, inputs):
2933
fitted_params, lps = model.fit_em(params, param_props, emissions, inputs=inputs, num_iters=3)
3034
assert monotonically_increasing(lps)
3135
fitted_params, lps = model.fit_sgd(params, param_props, emissions, inputs=inputs, num_epochs=3)
36+
37+
@pytest.mark.parametrize(["has_dynamics_bias", "has_emissions_bias"], product([True, False], repeat=2))
def test_inhomogeneous_lgcssm(has_dynamics_bias, has_emissions_bias):
    """
    Test a LinearGaussianConjugateSSM with time-varying dynamics and emission model.

    Homogeneous parameters are tiled along a new leading time axis, passed to
    `fit_blocked_gibbs`, and the shapes of the last sampled parameters are
    checked for every combination of dynamics / emissions bias settings.
    """
    state_dim = 2
    emission_dim = 3
    num_timesteps = 4
    # One stream of distinct keys; every random call below draws from it so
    # that no PRNG key is used twice.
    keys = map(jr.PRNGKey, count())

    model = LinearGaussianConjugateSSM(
        state_dim=state_dim,
        emission_dim=emission_dim,
        has_dynamics_bias=has_dynamics_bias,
        has_emissions_bias=has_emissions_bias,
    )
    params, _ = model.initialize(next(keys))

    # Repeat the parameters for each timestep: T - 1 transitions, T emissions.
    inhomogeneous_dynamics = tree.map(
        lambda x: jnp.repeat(x[None], num_timesteps - 1, axis=0), params.dynamics,
    )
    inhomogeneous_emissions = tree.map(
        lambda x: jnp.repeat(x[None], num_timesteps, axis=0), params.emissions,
    )

    # Data are generated from the homogeneous model; only shapes are checked.
    _, emissions = model.sample(params, next(keys), num_timesteps=num_timesteps)
    inhomogeneous_params = ParamsLGSSM(
        initial=params.initial,
        dynamics=inhomogeneous_dynamics,
        emissions=inhomogeneous_emissions,
    )
    params_trace = model.fit_blocked_gibbs(
        next(keys),
        inhomogeneous_params,
        sample_size=5,
        emissions=emissions,
    )

    # Arbitrarily check the last set of parameters from the Markov chain.
    last_params = tree.map(lambda x: x[-1], params_trace)
    assert last_params.initial.mean.shape == (state_dim,)
    assert last_params.initial.cov.shape == (state_dim, state_dim)
    assert last_params.dynamics.weights.shape == (num_timesteps - 1, state_dim, state_dim)
    assert last_params.emissions.weights.shape == (num_timesteps, emission_dim, state_dim)
    assert last_params.dynamics.bias.shape == (num_timesteps - 1, state_dim)
    assert last_params.emissions.bias.shape == (num_timesteps, emission_dim)
    assert last_params.dynamics.cov.shape == (num_timesteps - 1, state_dim, state_dim)
    assert last_params.emissions.cov.shape == (num_timesteps, emission_dim, emission_dim)

0 commit comments

Comments
 (0)