
Commit 2b8626a

merge: Merge branch 'main' of https://github.com/probml/dynamax into main

2 parents: 1fb965f + 58b9dbe

This merge renames the generic linear_solve wrapper to psd_solve throughout the inference, model, and utility code, making explicit that every call site solves against a symmetric positive semi-definite (PSD) matrix.

File tree: 9 files changed (+34 −37 lines)


dynamax/generalized_gaussian_ssm/inference.py

Lines changed: 4 additions & 5 deletions
@@ -2,13 +2,12 @@
 from numpy.polynomial.hermite_e import hermegauss
 from jax import jacfwd, vmap, lax
 import jax.numpy as jnp
-from jax import scipy as jsc
 from jax import lax
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from jaxtyping import Array, Float
 from typing import NamedTuple, Optional, Union, Callable

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve
 from dynamax.generalized_gaussian_ssm.models import ParamsGGSSM
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed

@@ -162,7 +161,7 @@ def _step(carry, _):
     S = g_ev(Cov_Y, prior_mean, prior_cov) + g_cov(m_Y, m_Y, prior_mean, prior_cov)
     log_likelihood = emission_dist(yhat, S).log_prob(jnp.atleast_1d(y)).sum()
     C = g_cov(identity_fn, m_Y, prior_mean, prior_cov)
-    K = linear_solve(S, C.T).T
+    K = psd_solve(S, C.T).T
     posterior_mean = prior_mean + K @ (y - yhat)
     posterior_cov = prior_cov - K @ S @ K.T
     return (posterior_mean, posterior_cov), log_likelihood
@@ -195,7 +194,7 @@ def _statistical_linear_regression(mu, Sigma, m, S, C):
         b (D_obs):
         Omega (D_obs, D_obs):
     """
-    A = linear_solve(Sigma.T, C).T
+    A = psd_solve(Sigma.T, C).T
     b = m - A @ mu
     Omega = S - A @ Sigma @ A.T
     return A, b, Omega
@@ -329,7 +328,7 @@ def _step(carry, args):

     # Prediction step
     pred_mean, pred_cov, pred_cross = _predict(filtered_mean, filtered_cov, f, Q, u, g_ev, g_cov)
-    G = linear_solve(pred_cov, pred_cross.T).T
+    G = psd_solve(pred_cov, pred_cross.T).T

     # Compute smoothed mean and covariance
     smoothed_mean = filtered_mean + G @ (smoothed_mean_next - pred_mean)
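
For context on the _statistical_linear_regression hunk above: given joint moments of (x, y) — mean mu and covariance Sigma of x, mean m and covariance S of y, and cross-covariance C — it returns the best affine approximation y ≈ A x + b with residual covariance Omega. Because Sigma is a symmetric covariance, psd_solve(Sigma.T, C).T computes A = C^T Sigma^{-1}. A minimal sketch (the moment values below are made up for illustration):

import jax.numpy as jnp
from dynamax.utils.utils import psd_solve

# Made-up joint moments of (x, y)
mu    = jnp.array([0.5, -1.0])                 # E[x]
Sigma = jnp.array([[1.0, 0.2], [0.2, 2.0]])    # Cov[x], symmetric PSD
m     = jnp.array([1.5])                       # E[y]
S     = jnp.array([[0.8]])                     # Cov[y]
C     = jnp.array([[0.3], [0.4]])              # Cov[x, y], shape (D_hid, D_obs)

A = psd_solve(Sigma.T, C).T                    # A = C^T Sigma^{-1}
b = m - A @ mu                                 # offset of the affine fit
Omega = S - A @ Sigma @ A.T                    # residual covariance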

dynamax/linear_gaussian_ssm/inference.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,5 @@
 import jax.numpy as jnp
 import jax.random as jr
-from jax import scipy as jsc
 from jax import lax
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from functools import wraps
@@ -9,7 +8,7 @@
 from jaxtyping import Array, Float
 from typing import NamedTuple, Optional, Union

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve
 from dynamax.parameters import ParameterProperties
 from dynamax.types import PRNGKey, Scalar

@@ -171,7 +170,7 @@ def _condition_on(m, P, H, D, d, R, u, y):
     """
     # Compute the Kalman gain
     S = R + H @ P @ H.T
-    K = linear_solve(S, H @ P).T
+    K = psd_solve(S, H @ P).T
     Sigma_cond = P - K @ S @ K.T
     mu_cond = m + K @ (y - D @ u - d - H @ m)
     return mu_cond, Sigma_cond
@@ -324,7 +323,7 @@ def _step(carry, args):

     # This is like the Kalman gain but in reverse
     # See Eq 8.11 of Saarka's "Bayesian Filtering and Smoothing"
-    G = linear_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T
+    G = psd_solve(Q + F @ filtered_cov @ F.T, F @ filtered_cov).T

     # Compute the smoothed mean and covariance
     smoothed_mean = filtered_mean + G @ (smoothed_mean_next - F @ filtered_mean - B @ u - b)
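
All of these call sites share one pattern: the Kalman gain K = P H^T S^{-1} is computed by solving against the symmetric innovation covariance S instead of forming its inverse explicitly. A minimal sketch with toy matrices (the numbers are illustrative only):

import jax.numpy as jnp
from dynamax.utils.utils import psd_solve

P = jnp.array([[2.0, 0.5], [0.5, 1.0]])   # filtered state covariance, PSD
H = jnp.array([[1.0, 0.0]])               # emission matrix
R = jnp.array([[0.1]])                    # emission noise covariance

S = R + H @ P @ H.T                       # innovation covariance, symmetric PSD
K = psd_solve(S, H @ P).T                 # K = P H^T S^{-1}, valid because S^T = S

# Sanity check against the explicit inverse
assert jnp.allclose(K, P @ H.T @ jnp.linalg.inv(S))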

dynamax/linear_gaussian_ssm/info_inference.py

Lines changed: 12 additions & 13 deletions
@@ -1,11 +1,10 @@
 import jax.numpy as jnp
-from jax import scipy as jsc
 from jax import lax, vmap, value_and_grad
 from jax.scipy.linalg import solve_triangular
 from jaxtyping import Array, Float
 from typing import NamedTuple, Optional

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve


 class ParamsLGSSMInfo(NamedTuple):
@@ -60,7 +59,7 @@ def info_to_moment_form(etas, Lambdas):
         means (N,D)
         covs (N,D,D)
     """
-    means = vmap(lambda A, b:linear_solve(A, b))(Lambdas, etas)
+    means = vmap(lambda A, b:psd_solve(A, b))(Lambdas, etas)
     covs = jnp.linalg.inv(Lambdas)
     return means, covs

@@ -82,7 +81,7 @@ def _mvn_info_log_prob(eta, Lambda, x):
     """
     D = len(Lambda)
     lp = x.T @ eta - 0.5 * x.T @ Lambda @ x
-    lp += -0.5 * eta.T @ linear_solve(Lambda, eta)
+    lp += -0.5 * eta.T @ psd_solve(Lambda, eta)
     sign, logdet = jnp.linalg.slogdet(Lambda)
     lp += -0.5 * (D * jnp.log(2 * jnp.pi) - sign * logdet)
     return lp
@@ -121,7 +120,7 @@ def _info_predict(eta, Lambda, F, Q_prec, B, u, b):
         eta_pred (D_hid,): predicted precision weighted mean.
         Lambda_pred (D_hid,D_hid): predicted precision.
     """
-    K = linear_solve(Lambda + F.T @ Q_prec @ F, F.T @ Q_prec).T
+    K = psd_solve(Lambda + F.T @ Q_prec @ F, F.T @ Q_prec).T
     I = jnp.eye(F.shape[0])
     ## This version should be more stable than:
     # Lambda_pred = (I - K @ F.T) @ Q_prec
@@ -263,7 +262,7 @@ def _smooth_step(carry, args):

     # This is the information form version of the 'reverse' Kalman gain
     # See Eq 8.11 of Saarka's "Bayesian Filtering and Smoothing"
-    G = linear_solve(Q_prec + smoothed_prec_next - pred_prec, Q_prec @ F)
+    G = psd_solve(Q_prec + smoothed_prec_next - pred_prec, Q_prec @ F)

     # Compute the smoothed parameter estimates
     smoothed_prec = filtered_prec + F.T @ Q_prec @ (F - G)
@@ -398,18 +397,18 @@ def lds_to_block_tridiag(lds, data, inputs):
     T = len(data)

     # diagonal blocks of precision matrix
-    J_diag = jnp.array([jnp.dot(C(t).T, linear_solve(R(t), C(t))) for t in range(T)])
+    J_diag = jnp.array([jnp.dot(C(t).T, psd_solve(R(t), C(t))) for t in range(T)])
     J_diag = J_diag.at[0].add(jnp.linalg.inv(Q0))
-    J_diag = J_diag.at[:-1].add(jnp.array([jnp.dot(A(t).T, linear_solve(Q(t), A(t))) for t in range(T - 1)]))
+    J_diag = J_diag.at[:-1].add(jnp.array([jnp.dot(A(t).T, psd_solve(Q(t), A(t))) for t in range(T - 1)]))
     J_diag = J_diag.at[1:].add(jnp.array([jnp.linalg.inv(Q(t)) for t in range(0, T - 1)]))

     # lower diagonal blocks of precision matrix
-    J_lower_diag = jnp.array([-linear_solve(Q(t), A(t)) for t in range(T - 1)])
+    J_lower_diag = jnp.array([-psd_solve(Q(t), A(t)) for t in range(T - 1)])

     # linear potential
-    h = jnp.array([jnp.dot(data[t] - D(t) @ inputs[t], linear_solve(R(t), C(t))) for t in range(T)])
-    h = h.at[0].add(linear_solve(Q0, m0))
-    h = h.at[:-1].add(jnp.array([-jnp.dot(A(t).T, linear_solve(Q(t), B(t) @ inputs[t])) for t in range(T - 1)]))
-    h = h.at[1:].add(jnp.array([linear_solve(Q(t), B(t) @ inputs[t]) for t in range(T - 1)]))
+    h = jnp.array([jnp.dot(data[t] - D(t) @ inputs[t], psd_solve(R(t), C(t))) for t in range(T)])
+    h = h.at[0].add(psd_solve(Q0, m0))
+    h = h.at[:-1].add(jnp.array([-jnp.dot(A(t).T, psd_solve(Q(t), B(t) @ inputs[t])) for t in range(T - 1)]))
+    h = h.at[1:].add(jnp.array([psd_solve(Q(t), B(t) @ inputs[t]) for t in range(T - 1)]))

     return J_diag, J_lower_diag, h
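
In information form the Gaussian is parameterized by the precision Lambda and the precision-weighted mean eta, so converting back to moment form is itself a PSD solve, mu = Lambda^{-1} eta, as in info_to_moment_form above. A minimal sketch with made-up values:

import jax.numpy as jnp
from dynamax.utils.utils import psd_solve

Lambda = jnp.array([[4.0, 1.0], [1.0, 3.0]])  # precision matrix, symmetric PSD
eta = jnp.array([1.0, 2.0])                   # precision-weighted mean

mean = psd_solve(Lambda, eta)                 # mu = Lambda^{-1} eta
cov = jnp.linalg.inv(Lambda)                  # Sigma = Lambda^{-1}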

dynamax/linear_gaussian_ssm/models.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from dynamax.utils.distributions import MatrixNormalInverseWishart as MNIW
 from dynamax.utils.distributions import NormalInverseWishart as NIW
 from dynamax.utils.distributions import mniw_posterior_update, niw_posterior_update
-from dynamax.utils.utils import pytree_stack, linear_solve
+from dynamax.utils.utils import pytree_stack, psd_solve

 class SuffStatsLGSSM(Protocol):
     """A :class:`NamedTuple` with sufficient statistics for LGSSM parameter estimation."""
@@ -339,7 +339,7 @@ def m_step(

 def fit_linear_regression(ExxT, ExyT, EyyT, N):
     # Solve a linear regression given sufficient statistics
-    W = linear_solve(ExxT, ExyT).T
+    W = psd_solve(ExxT, ExyT).T
     Sigma = (EyyT - W @ ExyT - ExyT.T @ W.T + W @ ExxT @ W.T) / N
     return W, Sigma
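
fit_linear_regression above recovers regression weights from expected sufficient statistics, W = (ExxT^{-1} ExyT)^T, where ExxT is symmetric PSD by construction. A usage sketch with noiseless synthetic data (the data and the top-level import are assumptions for illustration):

import jax.numpy as jnp
import jax.random as jr
from dynamax.linear_gaussian_ssm.models import fit_linear_regression

key_x, key_w = jr.split(jr.PRNGKey(0))
X = jr.normal(key_x, (100, 3))        # inputs, (N, D_in)
W_true = jr.normal(key_w, (2, 3))     # ground-truth weights, (D_out, D_in)
Y = X @ W_true.T                      # noiseless targets, (N, D_out)

# Empirical sufficient statistics
ExxT = X.T @ X                        # (D_in, D_in), symmetric PSD
ExyT = X.T @ Y                        # (D_in, D_out)
EyyT = Y.T @ Y                        # (D_out, D_out)

W, Sigma = fit_linear_regression(ExxT, ExyT, EyyT, N=100)
# W recovers W_true (up to numerical error); Sigma is ~0 since there is no noise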

dynamax/linear_gaussian_ssm/parallel_inference.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 from tensorflow_probability.substrates.jax.distributions import MultivariateNormalFullCovariance as MVN
 from jaxtyping import Array, Float

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM

 def _make_associative_filtering_elements(params, emissions):
@@ -26,7 +26,7 @@ def _first_filtering_element(params, y):
     m1 = params.initial.mean
     P1 = params.initial.cov
     S1 = H @ P1 @ H.T + R
-    K1 = linear_solve(S1, H @ P1).T
+    K1 = psd_solve(S1, H @ P1).T

     A = jnp.zeros_like(F)
     b = m1 + K1 @ (y - H @ m1)
@@ -131,7 +131,7 @@ def _generic_smoothing_element(params, m, P):

     Pp = F @ P @ F.T + Q

-    E = linear_solve(Pp, F @ P).T
+    E = psd_solve(Pp, F @ P).T
     g = m - E @ F @ m
     L = P - E @ Pp @ E.T
     return E, g, L

dynamax/nonlinear_gaussian_ssm/inference_ekf.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 from jaxtyping import Array, Float
 from typing import Optional

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve
 from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM
 from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed

@@ -72,7 +72,7 @@ def _step(carry, _):
     prior_mean, prior_cov = carry
     H_x = H(prior_mean, u)
     S = R + H_x @ prior_cov @ H_x.T
-    K = linear_solve(S, H_x @ prior_cov).T
+    K = psd_solve(S, H_x @ prior_cov).T
     posterior_cov = prior_cov - K @ S @ K.T
     posterior_mean = prior_mean + K @ (y - h(prior_mean, u))
     return (posterior_mean, posterior_cov), None
@@ -204,7 +204,7 @@ def _step(carry, args):
     # Prediction step
     m_pred = f(filtered_mean, u)
     S_pred = Q + F_x @ filtered_cov @ F_x.T
-    G = linear_solve(S_pred, F_x @ filtered_cov).T
+    G = psd_solve(S_pred, F_x @ filtered_cov).T

     # Compute smoothed mean and covariance
     smoothed_mean = filtered_mean + G @ (smoothed_mean_next - m_pred)
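
In the EKF hunks, H_x and F_x are Jacobians of the emission and dynamics functions evaluated at the current mean; such Jacobians come from JAX autodiff (the GGSSM module above imports jax.jacfwd for this purpose). A minimal, self-contained sketch in which the emission function is made up:

import jax.numpy as jnp
from jax import jacfwd

def h(x, u):
    # Made-up nonlinear emission function for illustration
    return jnp.sin(x) + u

H = jacfwd(h)                       # Jacobian of h w.r.t. its first argument
prior_mean = jnp.array([0.1, 0.2])
u = jnp.zeros(2)
H_x = H(prior_mean, u)              # (2, 2) Jacobian evaluated at the mean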

dynamax/nonlinear_gaussian_ssm/inference_ukf.py

Lines changed: 3 additions & 3 deletions
@@ -5,7 +5,7 @@
 from jaxtyping import Array, Float
 from typing import NamedTuple, Optional

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve
 from dynamax.nonlinear_gaussian_ssm.models import ParamsNLGSSM
 from dynamax.linear_gaussian_ssm.models import PosteriorGSSMFiltered, PosteriorGSSMSmoothed

@@ -130,7 +130,7 @@ def _condition_on(m, P, h, R, lamb, w_mean, w_cov, u, y):
     ll = MVN(pred_mean, pred_cov).log_prob(y)

     # Compute filtered mean and covariace
-    K = linear_solve(pred_cov, pred_cross.T).T # Filter gain
+    K = psd_solve(pred_cov, pred_cross.T).T # Filter gain
     m_cond = m + K @ (y - pred_mean)
     P_cond = P - K @ pred_cov @ K.T
     return ll, m_cond, P_cond
@@ -244,7 +244,7 @@ def _step(carry, args):

     # Prediction step
     m_pred, S_pred, S_cross = _predict(filtered_mean, filtered_cov, f, Q, lamb, w_mean, w_cov, u)
-    G = linear_solve(S_pred, S_cross.T).T
+    G = psd_solve(S_pred, S_cross.T).T

     # Compute smoothed mean and covariance
     smoothed_mean = filtered_mean + G @ (smoothed_mean_next - m_pred)

dynamax/utils/distributions.py

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 tfd = tfp.distributions
 tfb = tfp.bijectors

-from dynamax.utils.utils import linear_solve
+from dynamax.utils.utils import psd_solve


 class InverseWishart(tfd.TransformedDistribution):
@@ -319,7 +319,7 @@ def mniw_posterior_update(mniw_prior, sufficient_stats):
     Sxx = V_pri + SxxT
     Sxy = SxyT + V_pri @ M_pri.T
     Syy = SyyT + M_pri @ V_pri @ M_pri.T
-    M_pos = linear_solve(Sxx, Sxy).T
+    M_pos = psd_solve(Sxx, Sxy).T
     V_pos = Sxx
     nu_pos = nu_pri + N
     Psi_pos = Psi_pri + Syy - M_pos @ Sxy

dynamax/utils/utils.py

Lines changed: 2 additions & 2 deletions
@@ -198,6 +198,6 @@ def find_permutation(
     return perm


-def linear_solve(A,b):
-    """A wrapper for coordinating the linalg solvers used in the library."""
+def psd_solve(A,b):
+    """A wrapper for coordinating the linalg solvers used in the library for psd matrices."""
     return jnp.linalg.solve(A,b)
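
Note that the renamed wrapper still delegates to the generic jnp.linalg.solve; for now the new name documents a contract (the argument A is PSD) rather than changing the algorithm. One possible follow-up — an assumption, not part of this commit — would exploit that contract with a Cholesky factorization:

import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def psd_solve_via_cholesky(A, b):
    # Hypothetical alternative: factor the symmetric PSD matrix A once,
    # then solve by triangular back-substitution, which is cheaper and
    # typically more stable than a general LU-based solve.
    factor = cho_factor(A, lower=True)
    return cho_solve(factor, b)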
