Skip to content

Commit

Permalink
Update type annotations in hmm inference code.
Browse files Browse the repository at this point in the history
Major changes:
- Replace `jaxtyping.Int` with `dynamax.typing.IntScalar` or `int`
  - this reflects when integer scalar arrays are accepted
  - `jaxtyping.[Dtype]` cannot be used directly for type checking
    instead they must be used as part of an array.
- Fix the shape of `transition_matrix`:
  - if transition_matrix has a leading timestep axis it should be of
    length T-1 not of length T.
- Add annotation indicating that `transition_matrix` is an optional argument
- Raise ValueError when neither `transition_matrix` or `transition_fn`
  provided.
  • Loading branch information
gileshd committed Sep 24, 2024
1 parent da2395f commit 27008b7
Showing 1 changed file with 49 additions and 40 deletions.
89 changes: 49 additions & 40 deletions dynamax/hidden_markov_model/inference.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from functools import partial
from typing import Callable, NamedTuple, Optional, Tuple, Union
import jax.numpy as jnp
import jax.random as jr
from jax import jit, lax, vmap
from functools import partial

from typing import Callable, Optional, Tuple, Union, NamedTuple
from jaxtyping import Int, Float, Array

from dynamax.types import Scalar
from dynamax.types import IntScalar, Scalar

_get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x

def get_trans_mat(transition_matrix, transition_fn, t):
def get_trans_mat(
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]],
t: IntScalar
) -> Float[Array, "num_states num_states"]:
if transition_fn is not None:
return transition_fn(t)
else:
if transition_matrix.ndim == 3: # (T,K,K)
elif transition_matrix is not None:
if transition_matrix.ndim == 3: # (T-1,K,K)
return transition_matrix[t]
else:
return transition_matrix
else:
raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")

class HMMPosteriorFiltered(NamedTuple):
r"""Simple wrapper for properties of an HMM filtering posterior.
Expand Down Expand Up @@ -49,8 +55,8 @@ class HMMPosterior(NamedTuple):
predicted_probs: Float[Array, "num_timesteps num_states"]
smoothed_probs: Float[Array, "num_timesteps num_states"]
initial_probs: Float[Array, " num_states"]
trans_probs: Optional[Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]]] = None
trans_probs: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]] = None


def _normalize(u: Array, axis=0, eps=1e-15):
Expand Down Expand Up @@ -96,10 +102,10 @@ def _predict(probs, A):
@partial(jit, static_argnames=["transition_fn"])
def hmm_filter(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> HMMPosteriorFiltered:
r"""Forwards filtering
Expand Down Expand Up @@ -143,8 +149,8 @@ def _step(carry, t):

@partial(jit, static_argnames=["transition_fn"])
def hmm_backward_filter(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[int], Float[Array, "num_states num_states"]]]= None
) -> Tuple[Scalar, Float[Array, "num_timesteps num_states"]]:
Expand Down Expand Up @@ -190,10 +196,10 @@ def _step(carry, t):
@partial(jit, static_argnames=["transition_fn"])
def hmm_two_filter_smoother(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using the two-filter
Expand Down Expand Up @@ -244,10 +250,10 @@ def hmm_two_filter_smoother(
@partial(jit, static_argnames=["transition_fn"])
def hmm_smoother(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None,
compute_trans_probs: bool = True
) -> HMMPosterior:
r"""Computed the smoothed state probabilities using a general
Expand Down Expand Up @@ -324,11 +330,11 @@ def _step(carry, args):
@partial(jit, static_argnames=["transition_fn", "window_size"])
def hmm_fixed_lag_smoother(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
window_size: Int,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
window_size: int,
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
) -> HMMPosterior:
r"""Compute the smoothed state probabilities using the fixed-lag smoother.
Expand Down Expand Up @@ -438,10 +444,10 @@ def compute_posterior(filtered_probs, beta):
@partial(jit, static_argnames=["transition_fn"])
def hmm_posterior_mode(
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None
) -> Int[Array, " num_timesteps"]:
r"""Compute the most likely state sequence. This is called the Viterbi algorithm.
Expand Down Expand Up @@ -486,10 +492,10 @@ def _forward_pass(state, best_next_state):
def hmm_posterior_sample(
key: Array,
initial_distribution: Float[Array, " num_states"],
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
log_likelihoods: Float[Array, "num_timesteps num_states"],
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Tuple[Scalar, Int[Array, " num_timesteps"]]:
r"""Sample a latent sequence from the posterior.
Expand Down Expand Up @@ -542,6 +548,7 @@ def _compute_sum_transition_probs(
transition_matrix: Float[Array, "num_states num_states"],
hmm_posterior: HMMPosterior) -> Float[Array, "num_states num_states"]:
"""Compute the transition probabilities from the HMM posterior messages.
Args:
transition_matrix (_type_): _description_
hmm_posterior (_type_): _description_
Expand Down Expand Up @@ -578,11 +585,13 @@ def _step(carry, args: Tuple[Array, Array, Array, Int[Array, ""]]):


def _compute_all_transition_probs(
transition_matrix: Float[Array, "num_timesteps num_states num_states"],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
hmm_posterior: HMMPosterior,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Float[Array, "num_timesteps num_states num_states"]:
"""Compute the transition probabilities from the HMM posterior messages.
Args:
transition_matrix (_type_): _description_
hmm_posterior (_type_): _description_
Expand All @@ -600,14 +609,12 @@ def _compute_probs(t):
return transition_probs


# TODO: Consider alternative annotation for return type:
# Float[Array, "*num_timesteps num_states num_states"] I think this would allow multiple prepended dims.
# Float[Array, "#num_timesteps num_states num_states"] this might accept (1, sd, sd) but not (sd, sd).
# TODO: This is a candidate for @overloading.
def compute_transition_probs(
transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]],
transition_matrix: Optional[Union[Float[Array, "num_states num_states"],
Float[Array, "num_timesteps_minus_1 num_states num_states"]]],
hmm_posterior: HMMPosterior,
transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None
transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None
) -> Union[Float[Array, "num_timesteps num_states num_states"],
Float[Array, "num_states num_states"]]:
r"""Compute the posterior marginal distributions $p(z_{t+1}, z_t \mid y_{1:T}, u_{1:T}, \theta)$.
Expand All @@ -620,8 +627,10 @@ def compute_transition_probs(
Returns:
array of smoothed transition probabilities.
"""
reduce_sum = transition_matrix is not None and transition_matrix.ndim == 2
if reduce_sum:
if transition_matrix is None and transition_fn is None:
raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.")

if transition_matrix is not None and transition_matrix.ndim == 2:
return _compute_sum_transition_probs(transition_matrix, hmm_posterior)
else:
return _compute_all_transition_probs(transition_matrix, hmm_posterior, transition_fn=transition_fn)

0 comments on commit 27008b7

Please sign in to comment.