Skip to content

Commit

Permalink
[WIP] Add further type annotations to multinomialhmm
Browse files Browse the repository at this point in the history
See #TODOs for final steps.
  • Loading branch information
gileshd committed Sep 21, 2024
1 parent 6197671 commit b9003e3
Showing 1 changed file with 41 additions and 16 deletions.
57 changes: 41 additions & 16 deletions dynamax/hidden_markov_model/models/multinomial_hmm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union

import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Array, Float

from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions
from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState
Expand All @@ -22,12 +23,13 @@ class ParamsMultinomialHMMEmissions(NamedTuple):

class MultinomialHMMEmissions(HMMEmissions):

#TODO: [GHD] Should the emission_prior_concentration allow an array of length num_classes?
def __init__(self,
num_states,
emission_dim,
num_classes,
num_trials,
emission_prior_concentration=1.1):
num_states: int,
emission_dim: int,
num_classes: int,
num_trials: int,
emission_prior_concentration: Scalar = 1.1):
self.num_states = num_states
self.emission_dim = emission_dim
self.num_classes = num_classes
Expand All @@ -38,7 +40,11 @@ def __init__(self,
def emission_shape(self):
return (self.emission_dim, self.num_classes)

def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
def initialize(self,
key: Array = jr.PRNGKey(0),
method: str = "prior",
emission_probs: Optional[Float[Array, "num_states emission_dim num_classes"]] = None
) -> Tuple[ParamsMultinomialHMMEmissions, ParamsMultinomialHMMEmissions]:
# Initialize the emission probabilities
if emission_probs is None:
if method.lower() == "prior":
Expand All @@ -58,22 +64,41 @@ def initialize(self, key=jr.PRNGKey(0), method="prior", emission_probs=None):
props = ParamsMultinomialHMMEmissions(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
return params, props

def distribution(self, params, state, inputs=None):
def distribution(
self,
params: ParamsMultinomialHMMEmissions,
state: int,
inputs: Optional[Array] = None
) -> tfd.Distribution:
return tfd.Independent(
tfd.Multinomial(self.num_trials, probs=params.probs[state]),
reinterpreted_batch_ndims=1)

def log_prior(self, params):
def log_prior(self, params: ParamsMultinomialHMMEmissions) -> Float[Array, ""]:
return tfd.Dirichlet(self.emission_prior_concentration).log_prob(params.probs).sum()

def collect_suff_stats(self, params, posterior, emissions, inputs=None):
# TODO: [GHD] Specify the shapes of emissions, inputs, and return array.
def collect_suff_stats(
self,
params: ParamsMultinomialHMMEmissions,
posterior: HMMPosterior,
emissions: Array,
inputs: Optional[Array] = None
) -> Dict[str, Array]:
expected_states = posterior.smoothed_probs
return dict(sum_x=jnp.einsum("tk, tdi->kdi", expected_states, emissions))

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(self, params, props) -> None:
return None

def m_step(self, params, props, batch_stats, m_step_state):
# TODO: [GHD] Specify the shapes of batch_stats
def m_step(
self,
params: ParamsMultinomialHMMEmissions,
props: ParamsMultinomialHMMEmissions,
batch_stats: Dict[str, Array],
m_step_state: Any
) -> Tuple[ParamsMultinomialHMMEmissions, Any]:
if props.probs.trainable:
emission_stats = pytree_sum(batch_stats, axis=0)
probs = tfd.Dirichlet(
Expand Down Expand Up @@ -111,10 +136,10 @@ class MultinomialHMM(HMM):
"""
def __init__(self,
num_states,
emission_dim,
num_classes,
num_trials,
num_states: int,
emission_dim: int,
num_classes: int,
num_trials: int,
initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
transition_matrix_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1,
transition_matrix_stickiness: Scalar=0.0,
Expand Down

0 comments on commit b9003e3

Please sign in to comment.