Skip to content

Commit

Permalink
Merge pull request #385 from dominikstrb/parallel-inputs
Browse files Browse the repository at this point in the history
Parallel Kalman filter and smoother with inputs
  • Loading branch information
slinderman authored Nov 8, 2024
2 parents f8fca63 + 27c643d commit 51b7dc5
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 33 deletions.
81 changes: 50 additions & 31 deletions dynamax/linear_gaussian_ssm/parallel_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import jax.numpy as jnp
from jax import vmap, lax
from jaxtyping import Array, Float
from typing import NamedTuple
from typing import NamedTuple, Optional
from dynamax.types import PRNGKey
from functools import partial
import warnings
Expand All @@ -45,6 +45,7 @@
from jax.scipy.linalg import cho_solve, cho_factor
from dynamax.utils.utils import symmetrize, psd_solve
from dynamax.linear_gaussian_ssm import PosteriorGSSMFiltered, PosteriorGSSMSmoothed, ParamsLGSSM
from dynamax.linear_gaussian_ssm.inference import _zeros_if_none


def _get_one_param(x, dim, t):
Expand All @@ -56,14 +57,16 @@ def _get_one_param(x, dim, t):
else:
return x

def _get_params(params, num_timesteps, t):
def _get_params(params: ParamsLGSSM, num_timesteps, t):
"""Helper function to get parameters at time t."""
assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."

F = _get_one_param(params.dynamics.weights, 2, t)
B = _get_one_param(params.dynamics.input_weights, 2, t)
b = _get_one_param(params.dynamics.bias, 1, t)
Q = _get_one_param(params.dynamics.cov, 2, t)
H = _get_one_param(params.emissions.weights, 2, t+1)
D = _get_one_param(params.emissions.input_weights, 2, t+1)
d = _get_one_param(params.emissions.bias, 1, t+1)

if len(params.emissions.cov.shape) == 1:
Expand All @@ -81,7 +84,7 @@ def _get_params(params, num_timesteps, t):
"The covariance will be interpreted as static and non-diagonal. To "
"specify a dynamic and diagonal covariance, pass it as a 3D array.")

return F, b, Q, H, d, R
return F, B, b, Q, H, D, d, R


#---------------------------------------------------------------------------#
Expand Down Expand Up @@ -151,13 +154,18 @@ class FilterMessage(NamedTuple):
logZ: Float[Array, "ntime"]


def _initialize_filtering_messages(params, emissions):
def _initialize_filtering_messages(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
):
"""Preprocess observations to construct input for filtering assocative scan."""

num_timesteps = emissions.shape[0]
inputs = _zeros_if_none(inputs, (num_timesteps, 0))

def _first_message(params, y):
H, d, R = _get_params(params, num_timesteps, -1)[3:]
def _first_message(params, y, u):
H, D, d, R = _get_params(params, num_timesteps, -1)[4:]
m = params.initial.mean
P = params.initial.cov

Expand All @@ -166,34 +174,35 @@ def _first_message(params, y):
K = P @ H.T @ S_inv

A = jnp.zeros_like(P)
b = m + K @ (y - H @ m - d)
b = m + K @ (y - H @ m - D @ u - d)
C = symmetrize(P - K @ S @ K.T)
eta = jnp.zeros_like(b)
J = jnp.eye(len(b))

logZ = _marginal_loglik_elem(P, H, R, y)
logZ = _marginal_loglik_elem(P, H, R, y - H @ m - D @ u - d)
return A, b, C, J, eta, logZ


@partial(vmap, in_axes=(None, 0, 0))
def _generic_message(params, y, t):
F, b, Q, H, d, R = _get_params(params, num_timesteps, t)
@partial(vmap, in_axes=(None, 0, 0, 0))
def _generic_message(params, y, u, t):
F, B, b, Q, H, D, d, R = _get_params(params, num_timesteps, t)

S_inv = _emissions_scale(Q, H, R)
K = Q @ H.T @ S_inv

eta = F.T @ H.T @ S_inv @ (y - H @ b - d)
innov = (y - H @ b - D @ u - d)
eta = F.T @ H.T @ S_inv @ innov
J = symmetrize(F.T @ H.T @ S_inv @ H @ F)

A = F - K @ H @ F
b = b + K @ (y - H @ b - d)
b = b + B @ u + K @ innov
C = symmetrize(Q - K @ H @ Q)

logZ = _marginal_loglik_elem(Q, H, R, y)
logZ = _marginal_loglik_elem(Q, H, R, innov)
return A, b, C, J, eta, logZ

A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0])
At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], jnp.arange(len(emissions)-1))
A0, b0, C0, J0, eta0, logZ0 = _first_message(params, emissions[0], inputs[0])
At, bt, Ct, Jt, etat, logZt = _generic_message(params, emissions[1:], inputs[1:], jnp.arange(len(emissions)-1))

return FilterMessage(
A=jnp.concatenate([A0[None], At]),
Expand All @@ -208,7 +217,8 @@ def _generic_message(params, y, t):

def lgssm_filter(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMFiltered:
"""A parallel version of the lgssm filtering algorithm.
Expand Down Expand Up @@ -238,7 +248,7 @@ def _operator(elem1, elem2):
logZ = (logZ1 + logZ2 + 0.5 * jnp.linalg.slogdet(I_C1J2)[1] + 0.5 * t1)
return FilterMessage(A, b, C, J, eta, logZ)

initial_messages = _initialize_filtering_messages(params, emissions)
initial_messages = _initialize_filtering_messages(params, emissions, inputs)
final_messages = lax.associative_scan(_operator, initial_messages)

return PosteriorGSSMFiltered(
Expand All @@ -265,25 +275,30 @@ class SmoothMessage(NamedTuple):
L: Float[Array, "ntime state_dim state_dim"]


def _initialize_smoothing_messages(params, filtered_means, filtered_covariances):
def _initialize_smoothing_messages(params: ParamsLGSSM,
filtered_means: Float[Array, "ntime state_dim"],
filtered_covariances: Float[Array, "ntime state_dim state_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> SmoothMessage:
"""Preprocess filtering output to construct input for smoothing assocative scan."""

def _last_message(m, P):
return jnp.zeros_like(P), m, P

num_timesteps = filtered_means.shape[0]
inputs = _zeros_if_none(inputs, (num_timesteps, 0))

@partial(vmap, in_axes=(None, 0, 0, 0))
def _generic_message(params, m, P, t):
F, b, Q = _get_params(params, num_timesteps, t)[:3]
@partial(vmap, in_axes=(None, 0, 0, 0, 0))
def _generic_message(params, m, P, u, t):
F, B, b, Q = _get_params(params, num_timesteps, t)[:4]
CF, low = cho_factor(F @ P @ F.T + Q)
E = cho_solve((CF, low), F @ P).T
g = m - E @ (F @ m + b)
g = m - E @ (F @ m + b + B @ u)
L = symmetrize(P - E @ F @ P)
return E, g, L

En, gn, Ln = _last_message(filtered_means[-1], filtered_covariances[-1])
Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], jnp.arange(len(filtered_means)-1))
Et, gt, Lt = _generic_message(params, filtered_means[:-1], filtered_covariances[:-1], inputs[:-1], jnp.arange(len(filtered_means)-1))

return SmoothMessage(
E=jnp.concatenate([Et, En[None]]),
Expand All @@ -294,15 +309,16 @@ def _generic_message(params, m, P, t):

def lgssm_smoother(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> PosteriorGSSMSmoothed:
"""A parallel version of the lgssm smoothing algorithm.
See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.
Note: this function now supports `inputs` to the system; when `inputs` is `None` it defaults to zeros.
"""
filtered_posterior = lgssm_filter(params, emissions)
filtered_posterior = lgssm_filter(params, emissions, inputs)
filtered_means = filtered_posterior.filtered_means
filtered_covs = filtered_posterior.filtered_covariances

Expand All @@ -315,7 +331,7 @@ def _operator(elem1, elem2):
L = symmetrize(E2 @ L1 @ E2.T + L2)
return E, g, L

initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs)
initial_messages = _initialize_smoothing_messages(params, filtered_means, filtered_covs, inputs)
final_messages = lax.associative_scan(_operator, initial_messages, reverse=True)

return PosteriorGSSMSmoothed(
Expand Down Expand Up @@ -343,7 +359,9 @@ class SampleMessage(NamedTuple):
h: Float[Array, "ntime state_dim"]


def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances):
def _initialize_sampling_messages(key, params, filtered_means, filtered_covariances,
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> SampleMessage:
"""A parallel version of the lgssm sampling algorithm.
Given parallel smoothing messages `z_i ~ N(E_i z_{i+1} + g_i, L_i)`,
Expand All @@ -356,15 +374,16 @@ def _initialize_sampling_messages(key, params, filtered_means, filtered_covarian
def lgssm_posterior_sample(
key: PRNGKey,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"]
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
) -> Float[Array, "ntime state_dim"]:
"""A parallel version of the lgssm sampling algorithm.
See S. Särkkä and Á. F. García-Fernández (2021) - https://arxiv.org/abs/1905.13002.
Note: this function now supports `inputs` to the system; when `inputs` is `None` it defaults to zeros.
"""
filtered_posterior = lgssm_filter(params, emissions)
filtered_posterior = lgssm_filter(params, emissions, inputs)
filtered_means = filtered_posterior.filtered_means
filtered_covs = filtered_posterior.filtered_covariances

Expand All @@ -377,6 +396,6 @@ def _operator(elem1, elem2):
h = E2 @ h1 + h2
return E, h

initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs)
initial_messages = _initialize_sampling_messages(key, params, filtered_means, filtered_covs, inputs)
_, samples = lax.associative_scan(_operator, initial_messages, reverse=True)
return samples
70 changes: 68 additions & 2 deletions dynamax/linear_gaussian_ssm/parallel_inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from dynamax.linear_gaussian_ssm import LinearGaussianSSM
from dynamax.linear_gaussian_ssm import lgssm_joint_sample
from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother
from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother
from dynamax.linear_gaussian_ssm import lgssm_smoother as serial_lgssm_smoother, lgssm_filter as serial_lgssm_filter
from dynamax.linear_gaussian_ssm import parallel_lgssm_smoother, parallel_lgssm_filter
from dynamax.linear_gaussian_ssm import lgssm_posterior_sample as serial_lgssm_posterior_sample
from dynamax.linear_gaussian_ssm import parallel_lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference_test import flatten_diagonal_emission_cov
Expand Down Expand Up @@ -45,6 +45,37 @@ def make_static_lgssm_params():
emission_weights=H,
emission_covariance=R)
return params, lgssm


def make_lgssm_params_with_inputs():
    """Build a 4D constant-velocity LGSSM whose dynamics and emissions both
    depend on a 2D control input.

    Returns:
        params: the initialized ``ParamsLGSSM`` pytree.
        lgssm: the ``LinearGaussianSSM`` model object that produced it.
    """
    step = 0.1

    # Double-integrator dynamics: positions integrate velocities, and the
    # inputs push directly on the two velocity components.
    dynamics_weights = jnp.eye(4) + step * jnp.eye(4, k=2)
    dynamics_input_weights = step * jnp.array([[0., 0.],
                                               [1., 0.],
                                               [0., 0.],
                                               [0., 1.]])
    # Discretized white-noise-acceleration process covariance.
    dynamics_cov = 1. * jnp.kron(jnp.array([[step**3 / 3, step**2 / 2],
                                            [step**2 / 2, step]]),
                                 jnp.eye(2))

    # Observe the two position components; inputs also feed through to the
    # emissions so that D is exercised by the tests.
    emission_weights = jnp.eye(2, 4)
    emission_input_weights = jnp.ones((2, 2))
    emission_cov = 0.5 ** 2 * jnp.eye(2)

    init_mean = jnp.array([0., 0., 1., -1.])
    init_cov = jnp.eye(4)

    # latent_dim=4, emission_dim=2, input_dim=2
    lgssm = LinearGaussianSSM(4, 2, 2)
    params, _ = lgssm.initialize(jr.PRNGKey(0),
                                 initial_mean=init_mean,
                                 initial_covariance=init_cov,
                                 dynamics_weights=dynamics_weights,
                                 dynamics_input_weights=dynamics_input_weights,
                                 dynamics_covariance=dynamics_cov,
                                 emission_weights=emission_weights,
                                 emission_input_weights=emission_input_weights,
                                 emission_covariance=emission_cov)
    return params, lgssm


def make_dynamic_lgssm_params(num_timesteps, latent_dim=4, observation_dim=2, seed=0):
Expand Down Expand Up @@ -114,6 +145,41 @@ def test_marginal_loglik(self):
assert jnp.allclose(self.serial_posterior.marginal_loglik, self.parallel_posterior_diag.marginal_loglik, atol=2e-1)


class TestParallelLGSSMSmootherWithInputs:
    """Check that the parallel filter/smoother agrees with the serial
    implementation on a model whose dynamics and emissions have inputs."""

    num_timesteps = 50
    key = jr.PRNGKey(1)

    params, lgssm = make_lgssm_params_with_inputs()
    params_diag = flatten_diagonal_emission_cov(params)
    inputs = jnp.ones((num_timesteps, 2))
    _, emissions = lgssm_joint_sample(params, key, num_timesteps, inputs)

    serial_posterior = serial_lgssm_smoother(params, emissions, inputs)
    parallel_posterior = parallel_lgssm_smoother(params, emissions, inputs)
    parallel_posterior_diag = parallel_lgssm_smoother(params_diag, emissions, inputs)

    def _parallel_posteriors(self):
        # Both the full-covariance and the diagonal-covariance parallel runs
        # must match the serial reference.
        return (self.parallel_posterior, self.parallel_posterior_diag)

    def test_filtered_means(self):
        for posterior in self._parallel_posteriors():
            assert allclose(self.serial_posterior.filtered_means, posterior.filtered_means)

    def test_filtered_covariances(self):
        for posterior in self._parallel_posteriors():
            assert allclose(self.serial_posterior.filtered_covariances, posterior.filtered_covariances)

    def test_smoothed_means(self):
        for posterior in self._parallel_posteriors():
            assert allclose(self.serial_posterior.smoothed_means, posterior.smoothed_means)

    def test_smoothed_covariances(self):
        for posterior in self._parallel_posteriors():
            assert allclose(self.serial_posterior.smoothed_covariances, posterior.smoothed_covariances)

    def test_marginal_loglik(self):
        # Log-normalizers accumulate more float error, hence the loose tolerance.
        for posterior in self._parallel_posteriors():
            assert jnp.allclose(self.serial_posterior.marginal_loglik, posterior.marginal_loglik, atol=2e-1)


class TestTimeVaryingParallelLGSSMSmoother:
Expand Down

0 comments on commit 51b7dc5

Please sign in to comment.