Add support for inference with padded arrays to support vmap over variable-length time series #343

Open · wants to merge 1 commit into base: main
dynamax/hidden_markov_model/inference.py (57 additions, 30 deletions)
@@ -101,7 +101,8 @@ def hmm_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None,
num_timesteps: Optional[Int] = None,
) -> HMMPosteriorFiltered:
r"""Forwards filtering

@@ -115,27 +116,32 @@ def hmm_filter(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.

Returns:
filtered posterior distribution

"""
num_timesteps, num_states = log_likelihoods.shape
max_num_timesteps = log_likelihoods.shape[0]
num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

def _step(carry, t):
log_normalizer, predicted_probs = carry

A = get_trans_mat(transition_matrix, transition_fn, t)
ll = log_likelihoods[t]

# Ignore observations after specified number of timesteps
ll = jnp.where(t < num_timesteps, ll, 0.0)

filtered_probs, log_norm = _condition_on(predicted_probs, ll)
log_normalizer += log_norm
predicted_probs_next = _predict(filtered_probs, A)

return (log_normalizer, predicted_probs_next), (filtered_probs, predicted_probs)

carry = (0.0, initial_distribution)
(log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(num_timesteps))
(log_normalizer, _), (filtered_probs, predicted_probs) = lax.scan(_step, carry, jnp.arange(max_num_timesteps))

post = HMMPosteriorFiltered(marginal_loglik=log_normalizer,
filtered_probs=filtered_probs,
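Why the masking works: conditioning on an all-zero log-likelihood multiplies the predicted distribution by $\exp(0) = 1$, and since the predicted probabilities sum to one, the step contributes $\log 1 = 0$ to the marginal log-likelihood. A standalone sketch of that identity (illustrative, not part of the diff):

```python
import jax.numpy as jnp

# At a padded timestep the masked log-likelihood is all zeros, so
# conditioning is a no-op and the log-normalizer increment is
# log(sum(predicted_probs * exp(0))) = log(1) = 0.
predicted_probs = jnp.array([0.2, 0.8])
ll = jnp.zeros(2)                        # masked (padded) timestep
new_probs = predicted_probs * jnp.exp(ll)
log_norm = jnp.log(new_probs.sum())      # 0.0: padding adds nothing
```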
@@ -149,7 +155,8 @@ def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
num_timesteps: Optional[Int] = None
) -> Tuple[Float, Float[Array, "num_timesteps num_states"]]:
r"""Run the filter backwards in time. This is the second step of the forward-backward algorithm.

@@ -163,30 +170,33 @@ def hmm_backward_filter(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.

Returns:
marginal log likelihood and backward messages.

"""
num_timesteps, num_states = log_likelihoods.shape
max_num_timesteps, num_states = log_likelihoods.shape
num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

def _step(carry, t):
log_normalizer, backward_pred_probs = carry

A = get_trans_mat(transition_matrix, transition_fn, t)
ll = log_likelihoods[t]

# Ignore observations after specified number of timesteps
ll = jnp.where(t < num_timesteps, ll, 0.0)

# Condition on emission at time t, being careful not to overflow.
backward_filt_probs, log_norm = _condition_on(backward_pred_probs, ll)
# Update the log normalizer.
log_normalizer += log_norm

# Predict the next state (going backward in time).
next_backward_pred_probs = _predict(backward_filt_probs, A.T)
return (log_normalizer, next_backward_pred_probs), backward_pred_probs

carry = (0.0, jnp.ones(num_states))
(log_normalizer, _), rev_backward_pred_probs = lax.scan(_step, carry, jnp.arange(num_timesteps)[::-1])
backward_pred_probs = rev_backward_pred_probs[::-1]
(log_normalizer, _), backward_pred_probs = lax.scan(_step, carry, jnp.arange(max_num_timesteps), reverse=True)
return log_normalizer, backward_pred_probs
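Note that the rewrite swaps the manual index-reversal for `lax.scan(..., reverse=True)`, which consumes the inputs back-to-front while stacking the outputs in their original positions. A small standalone sketch of the equivalence:

```python
import jax.numpy as jnp
from jax import lax

xs = jnp.arange(5.0)
step = lambda carry, x: (carry + x, carry + x)

# Old pattern: reverse the inputs, scan, then reverse the outputs.
_, ys_manual = lax.scan(step, 0.0, xs[::-1])
ys_manual = ys_manual[::-1]

# New pattern: let scan run backward and keep the outputs in place.
_, ys_reverse = lax.scan(step, 0.0, xs, reverse=True)
assert jnp.allclose(ys_manual, ys_reverse)
```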


@@ -197,6 +207,7 @@ def hmm_two_filter_smoother(
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
num_timesteps: Optional[Int] = None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using the two-filter
@@ -212,16 +223,19 @@ def hmm_two_filter_smoother(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.

Returns:
posterior distribution

"""
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
# Forward
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
ll = post.marginal_loglik
filtered_probs, predicted_probs = post.filtered_probs, post.predicted_probs

_, backward_pred_probs = hmm_backward_filter(transition_matrix, log_likelihoods, transition_fn)
# Backward
_, backward_pred_probs = hmm_backward_filter(transition_matrix, log_likelihoods, transition_fn, num_timesteps)

# Compute smoothed probabilities
smoothed_probs = filtered_probs * backward_pred_probs
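For reference, the product above is the standard two-filter identity,

$$p(z_t \mid y_{1:T}) \propto \underbrace{p(z_t \mid y_{1:t})}_{\text{forward filter}}\;\underbrace{p(y_{t+1:T} \mid z_t)}_{\text{backward message}},$$

and the padding stays consistent because both passes zero out the same log-likelihoods beyond `num_timesteps`.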
@@ -251,6 +265,7 @@ def hmm_smoother(
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
num_timesteps: Optional[Int]=None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using a general
@@ -268,15 +283,17 @@ def hmm_smoother(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.

Returns:
posterior distribution

"""
num_timesteps, num_states = log_likelihoods.shape
max_num_timesteps, _ = log_likelihoods.shape
num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
ll = post.marginal_loglik
filtered_probs, predicted_probs = post.filtered_probs, post.predicted_probs

@@ -294,16 +311,15 @@ def _step(carry, args):
smoothed_probs_next / predicted_probs_next)
smoothed_probs = filtered_probs * (A @ relative_probs_next)
smoothed_probs /= smoothed_probs.sum()

return smoothed_probs, smoothed_probs

# Run the HMM smoother
carry = filtered_probs[-1]
args = (jnp.arange(num_timesteps - 2, -1, -1), filtered_probs[:-1][::-1], predicted_probs[1:][::-1])
_, rev_smoothed_probs = lax.scan(_step, carry, args)
args = (jnp.arange(max_num_timesteps - 1), filtered_probs[:-1], predicted_probs[1:])
_, smoothed_probs = lax.scan(_step, carry, args, reverse=True)

# Append the final timestep, where smoothed equals filtered, and return
smoothed_probs = jnp.vstack([rev_smoothed_probs[::-1], filtered_probs[-1]])
smoothed_probs = jnp.vstack([smoothed_probs, filtered_probs[-1]])

# Package into a posterior
posterior = HMMPosterior(
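For reference, `_step` implements the discrete-state Rauch-Tung-Striebel backward recursion,

$$p(z_t = i \mid y_{1:T}) = p(z_t = i \mid y_{1:t}) \sum_j A_{ij}\,\frac{p(z_{t+1} = j \mid y_{1:T})}{p(z_{t+1} = j \mid y_{1:t})},$$

now run with `reverse=True` over all `max_num_timesteps - 1` transitions; entries past `num_timesteps` are computed on padding and are meant to be sliced off by the caller, as the tests below do.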
@@ -352,6 +368,7 @@ def hmm_fixed_lag_smoother(
posterior distribution

"""
# TODO: Update to allow variable length time series
num_timesteps, num_states = log_likelihoods.shape

def _step(carry, t):
@@ -441,7 +458,8 @@ def hmm_posterior_mode(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
num_timesteps: Optional[Int]=None,
) -> Int[Array, "num_timesteps"]:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.

@@ -450,12 +468,14 @@ def hmm_posterior_mode(
transition_matrix: $p(z_{t+1} \mid z_t, u_t, \theta)$
log_likelihoods: $p(y_t \mid z_t, u_t, \theta)$ for $t=1,\ldots, T$.
transition_fn: function that takes in an integer time index and returns a $K \times K$ transition matrix.
num_timesteps: number of "valid" timesteps, to support vmapping with padded arrays.

Returns:
most likely state sequence

"""
num_timesteps, num_states = log_likelihoods.shape
max_num_timesteps, _ = log_likelihoods.shape
num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

# Run the backward pass
def _backward_pass(best_next_score, t):
@@ -464,14 +484,19 @@ def _backward_pass(best_next_score, t):
scores = jnp.log(A) + best_next_score + log_likelihoods[t + 1]
best_next_state = jnp.argmax(scores, axis=1)
best_next_score = jnp.max(scores, axis=1)

# Only update if log_likelihoods[t+1] is valid
best_next_score = jnp.where(t + 1 < num_timesteps, best_next_score, jnp.zeros(num_states))
best_next_state = jnp.where(t + 1 < num_timesteps, best_next_state, jnp.zeros(num_states, dtype=int))

return best_next_score, best_next_state

num_states = log_likelihoods.shape[1]
best_second_score, rev_best_next_states = lax.scan(
_backward_pass, jnp.zeros(num_states), jnp.arange(num_timesteps - 2, -1, -1)
best_second_score, best_next_states = lax.scan(
_backward_pass, jnp.zeros(num_states), jnp.arange(max_num_timesteps - 1),
reverse=True
)
best_next_states = rev_best_next_states[::-1]


# Run the forward pass
def _forward_pass(state, best_next_state):
next_state = best_next_state[state]
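A hypothetical end-to-end usage sketch (the names `pi0`, `A`, `padded_lls`, and `lengths` are illustrative, not from this PR) showing how the new argument enables batch decoding with `vmap`; keyword arguments of a vmapped function are mapped over their leading axis, so each sequence is decoded with its own valid length:

```python
import jax.numpy as jnp
import jax.random as jr
from jax import vmap
import dynamax.hidden_markov_model.inference as core

# Illustrative shapes: 4 sequences padded to length 10, 3 states.
key = jr.PRNGKey(0)
num_seqs, max_T, K = 4, 10, 3
pi0 = jnp.ones(K) / K
A = jnp.ones((K, K)) / K
padded_lls = jr.normal(key, (num_seqs, max_T, K))
lengths = jnp.array([10, 7, 5, 9])

modes = vmap(core.hmm_posterior_mode, in_axes=(None, None, 0))(
    pi0, A, padded_lls, num_timesteps=lengths)
# modes[i, :lengths[i]] is the Viterbi path for sequence i.
```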
@@ -490,7 +515,8 @@ def hmm_posterior_sample(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None,
num_timesteps: Optional[Int] = None,
) -> Int[Array, "num_timesteps"]:
r"""Sample a latent sequence from the posterior.

@@ -505,10 +531,11 @@
sample of the latent states, $z_{1:T}$

"""
num_timesteps, num_states = log_likelihoods.shape
max_num_timesteps, num_states = log_likelihoods.shape
num_timesteps = num_timesteps if num_timesteps is not None else max_num_timesteps

# Run the HMM filter
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn)
post = hmm_filter(initial_distribution, transition_matrix, log_likelihoods, transition_fn, num_timesteps)
log_normalizer, filtered_probs = post.marginal_loglik, post.filtered_probs

# Run the sampler backward in time
@@ -528,13 +555,13 @@ def _step(carry, args):
return state, state

# Sample the last state, then scan backward for the rest
rngs = jr.split(rng, num_timesteps)
rngs = jr.split(rng, max_num_timesteps)
last_state = jr.choice(rngs[-1], a=num_states, p=filtered_probs[-1])
args = (jnp.arange(num_timesteps - 1, 0, -1), rngs[:-1][::-1], filtered_probs[:-1][::-1])
_, rev_states = lax.scan(_step, last_state, args)
args = (jnp.arange(max_num_timesteps - 1), rngs[:-1], filtered_probs[:-1])
_, states = lax.scan(_step, last_state, args, reverse=True)

# Append the last state and return
states = jnp.concatenate([rev_states[::-1], jnp.array([last_state])])
states = jnp.concatenate([states, jnp.array([last_state])])
return log_normalizer, states
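For reference, this is forward-filtering backward-sampling: the last state is drawn from the final filtered distribution and each earlier state from

$$p(z_t \mid z_{t+1}, y_{1:T}) \propto p(z_t \mid y_{1:t})\, p(z_{t+1} \mid z_t),$$

with padded steps contributing no evidence because their log-likelihoods are zeroed.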

def _compute_sum_transition_probs(
dynamax/hidden_markov_model/inference_test.py (38 additions, 0 deletions)
@@ -5,6 +5,7 @@
import dynamax.hidden_markov_model.inference as core
import dynamax.hidden_markov_model.parallel_inference as parallel

from jax import vmap
from jax.scipy.special import logsumexp

def big_log_joint(initial_probs, transition_matrix, log_likelihoods):
@@ -259,6 +260,43 @@ def trans_mat_callable(t):
assert jnp.allclose(sample, sample2)


def test_hmm_padding(key=0, num_timesteps=10, num_states=5, padding=3):
if isinstance(key, int):
key = jr.PRNGKey(key)

initial_probs, transition_matrix, log_lkhds = random_hmm_args(key, num_timesteps + padding, num_states)

# Compare the HMM filter on the truncated array vs. the padded array with num_timesteps
post = core.hmm_filter(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
post2 = core.hmm_filter(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
assert jnp.allclose(post.marginal_loglik, post2.marginal_loglik, atol=1e-4)
assert jnp.allclose(post.filtered_probs, post2.filtered_probs[:num_timesteps], atol=1e-4)

# Compare the HMM smoother on the truncated vs. padded arrays
post = core.hmm_smoother(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
post2 = core.hmm_smoother(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
assert jnp.allclose(post.smoothed_probs, post2.smoothed_probs[:num_timesteps], atol=1e-4)

# Run Viterbi
mode = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds[:num_timesteps])
mode2 = core.hmm_posterior_mode(initial_probs, transition_matrix, log_lkhds, num_timesteps=num_timesteps)
assert jnp.allclose(mode, mode2[:num_timesteps])


# Test vmap
def test_hmm_variable_length_vmap(key=0, max_num_timesteps=10, num_states=5, num_seqs=10):
if isinstance(key, int):
key = jr.PRNGKey(key)

all_args = vmap(random_hmm_args, in_axes=(0, None, None))(
jr.split(key, num_seqs), max_num_timesteps, num_states)

all_num_timesteps = jr.randint(key, (num_seqs,), 1, max_num_timesteps)

# Just make sure vmap runs without throwing a concretization error
posteriors = vmap(core.hmm_filter)(*all_args, num_timesteps=all_num_timesteps)
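A hypothetical helper (not part of this PR) for building the padded inputs that these tests exercise, given variable-length per-sequence log-likelihoods:

```python
import jax.numpy as jnp

def pad_log_likelihoods(lls_list, max_len):
    """Zero-pad a list of (T_i, K) arrays to (N, max_len, K) and return lengths.

    Zero rows are safe padding because the inference functions mask
    log-likelihoods at or beyond each sequence's `num_timesteps`.
    """
    padded = jnp.stack(
        [jnp.pad(ll, ((0, max_len - ll.shape[0]), (0, 0))) for ll in lls_list])
    lengths = jnp.array([ll.shape[0] for ll in lls_list])
    return padded, lengths
```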


def test_parallel_filter(key=0, num_timesteps=100, num_states=3):
if isinstance(key, int):
key = jr.PRNGKey(key)
dynamax/nonlinear_gaussian_ssm/inference_ekf.py (1 addition, 0 deletions)
@@ -341,5 +341,6 @@ def _step(carry, _):
smoothed_posterior = extended_kalman_smoother(params, emissions, smoothed_prior, inputs)
return smoothed_posterior, None

# TODO: Does this even work with None as initial carry?
smoothed_posterior, _ = lax.scan(_step, None, jnp.arange(num_iter))
return smoothed_posterior
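On the TODO above: `lax.scan` requires the initial carry and the carry returned by the step function to share the same pytree structure, so a `None` init only typechecks when the step also returns `None` (an empty pytree) as its carry. A standalone sketch of the constraint:

```python
import jax.numpy as jnp
from jax import lax

def step_ok(carry, _):
    return carry, None            # carry structure (None) is preserved

def step_bad(carry, _):
    return jnp.zeros(3), None     # None -> array changes the structure

lax.scan(step_ok, None, jnp.arange(4))     # fine
# lax.scan(step_bad, None, jnp.arange(4))  # raises: carry pytree mismatch
```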