Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use reverse=True keyword argument in lax.scan for smoothers #365

Merged
Merged 2 commits on Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions dynamax/generalized_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
_jacfwd_2d = lambda f, x: jnp.atleast_2d(jacfwd(f)(x))



class EKFIntegrals(NamedTuple):
""" Lightweight container for EKF Gaussian integrals."""
gaussian_expectation: Callable = lambda f, m, P: jnp.atleast_1d(f(m))
Expand Down Expand Up @@ -85,7 +84,7 @@ def compute_weights_and_sigmas(self, m, P):

def _predict(m, P, f, Q, u, g_ev, g_cov):
"""Predict next mean and covariance under an additive-noise Gaussian filter

p(x_{t+1}) = N(x_{t+1} | mu_pred, Sigma_pred)
where
mu_pred = gev(f, m, P)
Expand Down Expand Up @@ -337,13 +336,17 @@ def _step(carry, args):
return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov)

# Run the smoother
init_carry = (filtered_means[-1], filtered_covs[-1])
args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
_, (smoothed_means, smoothed_covs) = lax.scan(_step, init_carry, args)
_, (smoothed_means, smoothed_covs) = lax.scan(
_step,
(filtered_means[-1], filtered_covs[-1]),
(jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
reverse=True
)

# Concatenate the last smoothed mean and covariance
smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))

# Reverse the arrays and return
smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
Expand Down
37 changes: 20 additions & 17 deletions dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def _step(carry, t):
return post



@partial(jit, static_argnames=["transition_fn"])
def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Expand Down Expand Up @@ -184,9 +183,9 @@ def _step(carry, t):
next_backward_pred_probs = _predict(backward_filt_probs, A.T)
return (log_normalizer, next_backward_pred_probs), backward_pred_probs

carry = (0.0, jnp.ones(num_states))
(log_normalizer, _), rev_backward_pred_probs = lax.scan(_step, carry, jnp.arange(num_timesteps)[::-1])
backward_pred_probs = rev_backward_pred_probs[::-1]
(log_normalizer, _), backward_pred_probs = lax.scan(
_step, (0.0, jnp.ones(num_states)), jnp.arange(num_timesteps), reverse=True
)
return log_normalizer, backward_pred_probs


Expand Down Expand Up @@ -273,7 +272,7 @@ def hmm_smoother(
posterior distribution

"""
num_timesteps, num_states = log_likelihoods.shape
num_timesteps = log_likelihoods.shape[0]

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
Expand All @@ -298,12 +297,15 @@ def _step(carry, args):
return smoothed_probs, smoothed_probs

# Run the HMM smoother
carry = filtered_probs[-1]
args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_probs[:-1][::-1], predicted_probs[1:][::-1])
_, rev_smoothed_probs = lax.scan(_step, carry, args)
_, smoothed_probs = lax.scan(
_step,
filtered_probs[-1],
(jnp.arange(num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:]),
reverse=True,
)

# Reverse the arrays and return
smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]])
# Concatenate the arrays and return
smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])

# Package into a posterior
posterior = HMMPosterior(
Expand Down Expand Up @@ -467,10 +469,9 @@ def _backward_pass(best_next_score, t):
return best_next_score, best_next_state

num_states = log_likelihoods.shape[1]
best_second_score, rev_best_next_states = lax.scan(
_backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 2, -1, -1)
best_second_score, best_next_states = lax.scan(
_backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 1), reverse=True
)
best_next_states = rev_best_next_states[::-1]

# Run the forward pass
def _forward_pass(state, best_next_state):
Expand Down Expand Up @@ -530,11 +531,13 @@ def _step(carry, args):
# Run the HMM smoother
rngs = jr.split(rng, num_timesteps)
last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
args = (jnp.arange(num_timesteps - 1, 0, -1), rngs[:-1][::-1], filtered_probs[:-1][::-1])
_, rev_states = lax.scan(_step, last_state, args)
_, states = lax.scan(
_step, last_state, (jnp.arange(1, num_timesteps), rngs[:-1], filtered_probs[:-1]),
reverse=True
)

# Reverse the arrays and return
states = jnp.concatenate([rev_states[::-1], jnp.array([last_state])])
# Add the last state
states = jnp.concatenate([states, jnp.array([last_state])])
return log_normalizer, states

def _compute_sum_transition_probs(
Expand Down
101 changes: 53 additions & 48 deletions dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,21 +45,21 @@ class ParamsLGSSMDynamics(NamedTuple):
:param cov: dynamics covariance $Q$

"""
weights: Union[ParameterProperties,
Float[Array, "state_dim state_dim"],
weights: Union[ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"]]

bias: Union[ParameterProperties,
Float[Array, "state_dim"],
Float[Array, "state_dim"],
Float[Array, "ntime state_dim"]]

input_weights: Union[ParameterProperties,
Float[Array, "state_dim input_dim"],
Float[Array, "state_dim input_dim"],
Float[Array, "ntime state_dim input_dim"]]
cov: Union[ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"],

cov: Union[ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"],
Float[Array, "state_dim_triu"]]


Expand All @@ -77,22 +77,22 @@ class ParamsLGSSMEmissions(NamedTuple):

"""
weights: Union[ParameterProperties,
Float[Array, "emission_dim state_dim"],
Float[Array, "emission_dim state_dim"],
Float[Array, "ntime emission_dim state_dim"]]

bias: Union[ParameterProperties,
Float[Array, "emission_dim"],
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"]]

input_weights: Union[ParameterProperties,
Float[Array, "emission_dim input_dim"],
Float[Array, "emission_dim input_dim"],
Float[Array, "ntime emission_dim input_dim"]]

cov: Union[ParameterProperties,
Float[Array, "emission_dim emission_dim"],
Float[Array, "ntime emission_dim emission_dim"],
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"],
Float[Array, "emission_dim emission_dim"],
Float[Array, "ntime emission_dim emission_dim"],
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"],
Float[Array, "emission_dim_triu"]]


Expand Down Expand Up @@ -166,9 +166,9 @@ def _get_params(params, num_timesteps, t):
D = _get_one_param(params.emissions.input_weights, 2, t)
d = _get_one_param(params.emissions.bias, 1, t)

if len(params.emissions.cov.shape) == 1:
if len(params.emissions.cov.shape) == 1:
R = _get_one_param(params.emissions.cov, 1, t)
elif len(params.emissions.cov.shape) > 2:
elif len(params.emissions.cov.shape) > 2:
R = _get_one_param(params.emissions.cov, 2, t)
elif params.emissions.cov.shape[0] != num_timesteps:
R = _get_one_param(params.emissions.cov, 2, t)
Expand Down Expand Up @@ -278,20 +278,20 @@ def _condition_on(m, P, H, D, d, R, u, y):
if R.ndim == 2:
S = R + H @ P @ H.T
K = psd_solve(S, H @ P).T
else:
else:
# Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I
# (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity)
I = jnp.eye(P.shape[0])
U = H @ jnp.linalg.cholesky(P)
X = U / R[:, None]
S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
"""
# Could alternatively use U=H and C=P
R_inv = jnp.diag(1.0 / R)
P_inv = psd_solve(P, jnp.eye(P.shape[0]))
S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv)
"""
K = P @ H.T @ S_inv
K = P @ H.T @ S_inv
S = jnp.diag(R) + H @ P @ H.T

Sigma_cond = P - K @ S @ K.T
Expand Down Expand Up @@ -361,8 +361,6 @@ def wrapper(*args, **kwargs):
return wrapper




def lgssm_joint_sample(
params: ParamsLGSSM,
key: PRNGKey,
Expand All @@ -371,7 +369,7 @@ def lgssm_joint_sample(
)-> Tuple[Float[Array, "num_timesteps state_dim"],
Float[Array, "num_timesteps emission_dim"]]:
r"""Sample from the joint distribution to produce state and emission trajectories.

Args:
params: model parameters
inputs: optional array of inputs.
Expand All @@ -390,7 +388,7 @@ def _sample_emission(key, H, D, d, R, x, u):
mean = H @ x + D @ u + d
R = jnp.diag(R) if R.ndim==1 else R
return MVN(mean, R).sample(seed=key)

def _sample_initial(key, params, inputs):
key1, key2 = jr.split(key)

Expand All @@ -417,7 +415,7 @@ def _step(prev_state, args):

# Sample the initial state
key1, key2 = jr.split(key)

initial_state, initial_emission = _sample_initial(key1, params, inputs)

# Sample the remaining emissions and states
Expand Down Expand Up @@ -462,7 +460,7 @@ def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
else:
L = H @ jnp.linalg.cholesky(pred_cov)
return MVNLowRank(m, R, L).log_prob(y)


def _step(carry, t):
ll, pred_mean, pred_cov = carry
Expand Down Expand Up @@ -539,14 +537,17 @@ def _step(carry, args):
return (smoothed_mean, smoothed_cov), (smoothed_mean, smoothed_cov, smoothed_cross)

# Run the Kalman smoother
init_carry = (filtered_means[-1], filtered_covs[-1])
args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_means[:-1][::-1], filtered_covs[:-1][::-1])
_, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(_step, init_carry, args)

# Reverse the arrays and return
smoothed_means = jnp.vstack((smoothed_means[::-1], filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs[::-1], filtered_covs[-1][None, ...]))
smoothed_cross = smoothed_cross[::-1]
_, (smoothed_means, smoothed_covs, smoothed_cross) = lax.scan(
_step,
(filtered_means[-1], filtered_covs[-1]),
(jnp.arange(num_timesteps - 1), filtered_means[:-1], filtered_covs[:-1]),
reverse=True,
)

# Concatenate the arrays and return
smoothed_means = jnp.vstack((smoothed_means, filtered_means[-1][None, ...]))
smoothed_covs = jnp.vstack((smoothed_covs, filtered_covs[-1][None, ...]))

return PosteriorGSSMSmoothed(
marginal_loglik=ll,
filtered_means=filtered_means,
Expand All @@ -563,7 +564,7 @@ def lgssm_posterior_sample(
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None,
jitter: Optional[Scalar]=0

) -> Float[Array, "ntime state_dim"]:
r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.

Expand Down Expand Up @@ -603,12 +604,16 @@ def _step(carry, args):
key, this_key = jr.split(key, 2)
last_state = MVN(filtered_means[-1], filtered_covs[-1]).sample(seed=this_key)

args = (
jr.split(key, num_timesteps - 1),
filtered_means[:-1][::-1],
filtered_covs[:-1][::-1],
jnp.arange(num_timesteps - 2, -1, -1),
_, states = lax.scan(
_step,
last_state,
(
jr.split(key, num_timesteps - 1),
filtered_means[:-1],
filtered_covs[:-1],
jnp.arange(num_timesteps - 1),
),
reverse=True,
)
_, reversed_states = lax.scan(_step, last_state, args)
states = jnp.vstack([reversed_states[::-1], last_state])
return states

return jnp.vstack([states, last_state])
16 changes: 10 additions & 6 deletions dynamax/linear_gaussian_ssm/info_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,13 +271,17 @@ def _smooth_step(carry, args):
return (smoothed_eta, smoothed_prec), (smoothed_eta, smoothed_prec)

# Run the Kalman smoother
init_carry = (filtered_etas[-1], filtered_precisions[-1])
args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_etas[:-1][::-1], filtered_precisions[:-1][::-1])
_, (smoothed_etas, smoothed_precisions) = lax.scan(_smooth_step, init_carry, args)
_, (smoothed_etas, smoothed_precisions) = lax.scan(
_smooth_step,
(filtered_etas[-1], filtered_precisions[-1]),
(jnp.arange(num_timesteps - 1), filtered_etas[:-1], filtered_precisions[:-1]),
reverse=True
)

# Concatenate the arrays and return
smoothed_etas = jnp.vstack((smoothed_etas, filtered_etas[-1][None, ...]))
smoothed_precisions = jnp.vstack((smoothed_precisions, filtered_precisions[-1][None, ...]))

# Reverse the arrays and return
smoothed_etas = jnp.vstack((smoothed_etas[::-1], filtered_etas[-1][None, ...]))
smoothed_precisions = jnp.vstack((smoothed_precisions[::-1], filtered_precisions[-1][None, ...]))
return PosteriorGSSMInfoSmoothed(
marginal_loglik=ll,
filtered_etas=filtered_etas,
Expand Down
Loading
Loading