7
7
8
8
from fastprogress .fastprogress import progress_bar
9
9
from functools import partial
10
- from jax import jit , vmap
10
+ from jax import jit , tree , vmap
11
11
from jax .tree_util import tree_map
12
12
from jaxtyping import Array , Float
13
13
from tensorflow_probability .substrates .jax .distributions import MultivariateNormalFullCovariance as MVN
14
14
from typing import Any , Optional , Tuple , Union , runtime_checkable
15
- from typing_extensions import Protocol
15
+ from typing_extensions import Protocol
16
16
17
17
from dynamax .ssm import SSM
18
18
from dynamax .linear_gaussian_ssm .inference import lgssm_joint_sample , lgssm_filter , lgssm_smoother , lgssm_posterior_sample
@@ -206,7 +206,7 @@ def sample(self,
206
206
key : PRNGKeyT ,
207
207
num_timesteps : int ,
208
208
inputs : Optional [Float [Array , "num_timesteps input_dim" ]] = None ) \
209
- -> Tuple [Float [Array , "num_timesteps state_dim" ],
209
+ -> Tuple [Float [Array , "num_timesteps state_dim" ],
210
210
Float [Array , "num_timesteps emission_dim" ]]:
211
211
"""Sample from the model.
212
212
@@ -357,7 +357,7 @@ def forecast(self,
357
357
input_weights = params .emissions .input_weights ,
358
358
cov = 1e8 * jnp .ones (self .emission_dim )) # ignore dummy observatiosn
359
359
)
360
-
360
+
361
361
dummy_emissions = jnp .zeros ((num_forecast_timesteps , self .emission_dim ))
362
362
forecast_inputs = forecast_inputs if forecast_inputs is not None else \
363
363
jnp .zeros ((num_forecast_timesteps , 0 ))
@@ -367,7 +367,7 @@ def forecast(self,
367
367
H = params .emissions .weights
368
368
b = params .emissions .bias
369
369
R = params .emissions .cov if params .emissions .cov .ndim == 2 else jnp .diag (params .emissions .cov )
370
-
370
+
371
371
forecast_emissions = forecast_states .filtered_means @ H .T + b
372
372
forecast_emissions_cov = H @ forecast_states .filtered_covariances @ H .T + R
373
373
return forecast_states .filtered_means , \
@@ -643,6 +643,47 @@ def m_step(self,
643
643
)
644
644
return params , m_step_state
645
645
646
+ def _check_params (self , params : ParamsLGSSM , num_timesteps : int ) -> ParamsLGSSM :
647
+ """Replace None parameters with zeros."""
648
+ dynamics , emissions = params .dynamics , params .emissions
649
+ is_inhomogeneous = dynamics .weights .ndim == 3
650
+
651
+ def _zeros_if_none (x , shape ):
652
+ if x is None :
653
+ return jnp .zeros (shape )
654
+ return x
655
+
656
+ shape_prefix = ()
657
+ if is_inhomogeneous :
658
+ shape_prefix = (num_timesteps - 1 ,)
659
+
660
+ clean_dynamics = ParamsLGSSMDynamics (
661
+ weights = dynamics .weights ,
662
+ bias = _zeros_if_none (dynamics .bias , shape = shape_prefix + (self .state_dim ,)),
663
+ input_weights = _zeros_if_none (
664
+ dynamics .input_weights , shape = shape_prefix + (self .state_dim , self .input_dim )
665
+ ),
666
+ cov = dynamics .cov
667
+ )
668
+ shape_prefix = ()
669
+ if is_inhomogeneous :
670
+ shape_prefix = (num_timesteps ,)
671
+
672
+ clean_emissions = ParamsLGSSMEmissions (
673
+ weights = emissions .weights ,
674
+ bias = _zeros_if_none (emissions .bias , shape = shape_prefix + (self .emission_dim ,)),
675
+ input_weights = _zeros_if_none (
676
+ emissions .input_weights , shape = shape_prefix + (self .emission_dim , self .input_dim )
677
+ ),
678
+ cov = emissions .cov
679
+ )
680
+ return ParamsLGSSM (
681
+ initial = params .initial ,
682
+ dynamics = clean_dynamics ,
683
+ emissions = clean_emissions ,
684
+ )
685
+
686
+
646
687
def fit_blocked_gibbs (self ,
647
688
key : PRNGKeyT ,
648
689
initial_params : ParamsLGSSM ,
@@ -654,7 +695,8 @@ def fit_blocked_gibbs(self,
654
695
655
696
Args:
656
697
key: random number key.
657
- initial_params: starting parameters.
698
+ initial_params: starting parameters. Include a leading time axis for
699
+ the dynamics and emissions parameters in inhomogeneous models.
658
700
sample_size: how many samples to draw.
659
701
emissions: set of observation sequences.
660
702
inputs: optional set of input sequences.
@@ -664,67 +706,97 @@ def fit_blocked_gibbs(self,
664
706
"""
665
707
num_timesteps = len (emissions )
666
708
709
+ # Inhomogeneous models have a leading time dimension.
710
+ is_inhomogeneous = initial_params .dynamics .weights .ndim == 3
711
+
667
712
if inputs is None :
668
713
inputs = jnp .zeros ((num_timesteps , 0 ))
669
714
715
+ initial_params = self ._check_params (initial_params , num_timesteps )
716
+
670
717
def sufficient_stats_from_sample (states ):
671
718
"""Convert samples of states to sufficient statistics."""
672
719
inputs_joint = jnp .concatenate ((inputs , jnp .ones ((num_timesteps , 1 ))), axis = 1 )
673
720
# Let xn[t] = x[t+1] for t = 0...T-2
674
- x , xp , xn = states , states [:- 1 ], states [1 :]
675
- u , up = inputs_joint , inputs_joint [:- 1 ]
721
+ x , xn = states , states [1 :]
722
+ u = inputs_joint
723
+ # Let z[t] = [x[t], u[t]] for t = 0...T-1
724
+ z = jnp .concatenate ([x , u ], axis = - 1 )
725
+ # Let zp[t] = [x[t], u[t]] for t = 0...T-2
726
+ zp = z [:- 1 ]
676
727
y = emissions
677
728
678
729
init_stats = (x [0 ], jnp .outer (x [0 ], x [0 ]), 1 )
679
730
680
731
# Quantities for the dynamics distribution
681
- # Let zp[t] = [x[t], u[t]] for t = 0...T-2
682
- sum_zpzpT = jnp .block ([[xp .T @ xp , xp .T @ up ], [up .T @ xp , up .T @ up ]])
683
- sum_zpxnT = jnp .block ([[xp .T @ xn ], [up .T @ xn ]])
684
- sum_xnxnT = xn .T @ xn
685
- dynamics_stats = (sum_zpzpT , sum_zpxnT , sum_xnxnT , num_timesteps - 1 )
732
+ sum_zpzpT = jnp .einsum ('ti,tj->tij' , zp , zp )
733
+ sum_zpxnT = jnp .einsum ('ti,tj->tij' , zp , xn )
734
+ sum_xnxnT = jnp .einsum ('ti,tj->tij' , xn , xn )
735
+ z_is_observed = jnp .ones (num_timesteps - 1 )
736
+ # The dynamics stats have a leading time dimension.
737
+ dynamics_stats = (sum_zpzpT , sum_zpxnT , sum_xnxnT , z_is_observed )
686
738
if not self .has_dynamics_bias :
687
- dynamics_stats = (sum_zpzpT [:- 1 , :- 1 ], sum_zpxnT [:- 1 , :], sum_xnxnT ,
688
- num_timesteps - 1 )
739
+ dynamics_stats = (sum_zpzpT [:, : - 1 , :- 1 ], sum_zpxnT [:, :- 1 , :], sum_xnxnT ,
740
+ z_is_observed )
689
741
690
742
# Quantities for the emissions
691
- # Let z[t] = [x[t], u[t]] for t = 0...T-1
692
- sum_zzT = jnp .block ([[x .T @ x , x .T @ u ], [u .T @ x , u .T @ u ]])
693
- sum_zyT = jnp .block ([[x .T @ y ], [u .T @ y ]])
694
- sum_yyT = y .T @ y
695
- emission_stats = (sum_zzT , sum_zyT , sum_yyT , num_timesteps )
743
+ sum_zzT = jnp .einsum ('ti,tj->tij' , z , z )
744
+ sum_zyT = jnp .einsum ('ti,tj->tij' , z , y )
745
+ sum_yyT = jnp .einsum ('ti,tj->tij' , y , y )
746
+ y_is_observed = jnp .ones (num_timesteps )
747
+ # The emissions stats have a leading time dimension.
748
+ emission_stats = (sum_zzT , sum_zyT , sum_yyT , y_is_observed )
696
749
if not self .has_emissions_bias :
697
- emission_stats = (sum_zzT [:- 1 , :- 1 ], sum_zyT [:- 1 , :], sum_yyT , num_timesteps )
750
+ emission_stats = (sum_zzT [:, : - 1 , :- 1 ], sum_zyT [:, : - 1 , :], sum_yyT , y_is_observed )
698
751
699
752
return init_stats , dynamics_stats , emission_stats
700
753
701
- def lgssm_params_sample (rng , stats ):
702
- """Sample parameters of the model given sufficient statistics from observed states and emissions."""
703
- init_stats , dynamics_stats , emission_stats = stats
704
- rngs = iter (jr .split (rng , 3 ))
705
-
706
- # Sample the initial params
754
+ def _sample_initial_params (rng , init_stats ):
707
755
initial_posterior = niw_posterior_update (self .initial_prior , init_stats )
708
- S , m = initial_posterior .sample (seed = next (rngs ))
756
+ S , m = initial_posterior .sample (seed = rng )
757
+ return ParamsLGSSMInitial (mean = m , cov = S )
709
758
710
- # Sample the dynamics params
759
+ def _sample_dynamics_params ( rng , dynamics_stats ):
711
760
dynamics_posterior = mniw_posterior_update (self .dynamics_prior , dynamics_stats )
712
- Q , FB = dynamics_posterior .sample (seed = next ( rngs ) )
761
+ Q , FB = dynamics_posterior .sample (seed = rng )
713
762
F = FB [:, :self .state_dim ]
714
763
B , b = (FB [:, self .state_dim :- 1 ], FB [:, - 1 ]) if self .has_dynamics_bias \
715
764
else (FB [:, self .state_dim :], jnp .zeros (self .state_dim ))
765
+ return ParamsLGSSMDynamics (weights = F , bias = b , input_weights = B , cov = Q )
716
766
717
- # Sample the emission params
767
+ def _sample_emission_params ( rng , emission_stats ):
718
768
emission_posterior = mniw_posterior_update (self .emission_prior , emission_stats )
719
- R , HD = emission_posterior .sample (seed = next ( rngs ) )
769
+ R , HD = emission_posterior .sample (seed = rng )
720
770
H = HD [:, :self .state_dim ]
721
771
D , d = (HD [:, self .state_dim :- 1 ], HD [:, - 1 ]) if self .has_emissions_bias \
722
772
else (HD [:, self .state_dim :], jnp .zeros (self .emission_dim ))
773
+ return ParamsLGSSMEmissions (weights = H , bias = d , input_weights = D , cov = R )
774
+
775
+ def lgssm_params_sample (rng , stats ):
776
+ """Sample parameters of the model given sufficient statistics from observed states and emissions."""
777
+ init_stats , dynamics_stats , emission_stats = stats
778
+ rngs = iter (jr .split (rng , 3 ))
779
+
780
+ # Sample the initial params
781
+ initial_params = _sample_initial_params (next (rngs ), init_stats )
782
+
783
+ # Sample the dynamics and emission params.
784
+ if not is_inhomogeneous :
785
+ # Aggregate summary statistics across time for homogeneous model.
786
+ dynamics_stats = tree .map (lambda x : jnp .sum (x , axis = 0 ), dynamics_stats )
787
+ emission_stats = tree .map (lambda x : jnp .sum (x , axis = 0 ), emission_stats )
788
+ dynamics_params = _sample_dynamics_params (next (rngs ), dynamics_stats )
789
+ emission_params = _sample_emission_params (next (rngs ), emission_stats )
790
+ else :
791
+ keys_dynamics = jr .split (next (rngs ), num_timesteps - 1 )
792
+ keys_emission = jr .split (next (rngs ), num_timesteps )
793
+ dynamics_params = vmap (_sample_dynamics_params )(keys_dynamics , dynamics_stats )
794
+ emission_params = vmap (_sample_emission_params )(keys_emission , emission_stats )
723
795
724
796
params = ParamsLGSSM (
725
- initial = ParamsLGSSMInitial ( mean = m , cov = S ) ,
726
- dynamics = ParamsLGSSMDynamics ( weights = F , bias = b , input_weights = B , cov = Q ) ,
727
- emissions = ParamsLGSSMEmissions ( weights = H , bias = d , input_weights = D , cov = R )
797
+ initial = initial_params ,
798
+ dynamics = dynamics_params ,
799
+ emissions = emission_params ,
728
800
)
729
801
return params
730
802
0 commit comments