From b9003e36d5f8039f2ea8138252825c438b313d2f Mon Sep 17 00:00:00 2001 From: gileshd Date: Sat, 21 Sep 2024 12:50:20 +0100 Subject: [PATCH] [WIP] Add further type annotations to multinomialhmm See #TODOs for final steps. --- .../models/multinomial_hmm.py | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/dynamax/hidden_markov_model/models/multinomial_hmm.py b/dynamax/hidden_markov_model/models/multinomial_hmm.py index 6493f001..ce76114a 100644 --- a/dynamax/hidden_markov_model/models/multinomial_hmm.py +++ b/dynamax/hidden_markov_model/models/multinomial_hmm.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, NamedTuple, Optional, Tuple, Union import jax.numpy as jnp import jax.random as jr @@ -6,6 +6,7 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float +from dynamax.hidden_markov_model.inference import HMMPosterior from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState @@ -22,12 +23,13 @@ class ParamsMultinomialHMMEmissions(NamedTuple): class MultinomialHMMEmissions(HMMEmissions): + #TODO: [GHD] Should the emission_prior_concentration allow an array of length num_classes? def __init__(self, - num_states, - emission_dim, - num_classes, - num_trials, - emission_prior_concentration=1.1): + num_states: int, + emission_dim: int, + num_classes: int, + num_trials: int, + emission_prior_concentration: Scalar = 1.1): self.num_states = num_states self.emission_dim = emission_dim self.num_classes = num_classes @@ -38,7 +40,11 @@ def __init__(self, def emission_shape(self): return (self.emission_dim, self.num_classes) - def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None): + def initialize(self, + key: Array = jr.PRNGKey(0), + method: str = "prior", + emission_probs: Optional[Float[Array, "num_states emission_dim num_classes"]] = None + ) -> Tuple[ParamsMultinomialHMMEmissions, ParamsMultinomialHMMEmissions]: # Initialize the emission probabilities if emission_probs is None: if method.lower() == "prior": @@ -58,22 +64,41 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None): props = ParamsMultinomialHMMEmissions(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered())) return params, props - def distribution(self, params, state, inputs=None): + def distribution( + self, + params: ParamsMultinomialHMMEmissions, + state: int, + inputs: Optional[Array] = None + ) -> tfd.Distribution: return tfd.Independent( tfd.Multinomial(self.num_trials, probs=params.probs[state]), reinterpreted_batch_ndims=1) - def log_prior(self, params): + def log_prior(self, params: ParamsMultinomialHMMEmissions) -> Float[Array, ""]: return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum() - def collect_suff_stats(self, params, posterior, emissions, inputs=None): + # TODO: [GHD] Specify the shapes of emissions, inputs, and return array. + def collect_suff_stats( + self, + params: ParamsMultinomialHMMEmissions, + posterior: HMMPosterior, + emissions: Array, + inputs: Optional[Array] = None + ) -> Dict[str, Array]: expected_states = posterior.smoothed_probs return dict(sum_x=jnp.einsum("tk, tdi->kdi", expected_states, emissions)) - def initialize_m_step_state(self, params, props): + def initialize_m_step_state(self, params, props) -> None: return None - def m_step(self, params, props, batch_stats, m_step_state): + # TODO: [GHD] Specify the shapes of batch_stats + def m_step( + self, + params: ParamsMultinomialHMMEmissions, + props: ParamsMultinomialHMMEmissions, + batch_stats: Dict[str, Array], + m_step_state: Any + ) -> Tuple[ParamsMultinomialHMMEmissions, Any]: if props.probs.trainable: emission_stats = pytree_sum(batch_stats, axis=0) probs = tfd.Dirichlet( @@ -111,10 +136,10 @@ class MultinomialHMM(HMM): """ def __init__(self, - num_states, - emission_dim, - num_classes, - num_trials, + num_states: int, + emission_dim: int, + num_classes: int, + num_trials: int, initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1, transition_matrix_stickiness: Scalar=0.0,