Skip to content

Commit

Permalink
Add further type annotations to poisson hmm
Browse files Browse the repository at this point in the history
  • Loading branch information
gileshd committed Sep 23, 2024
1 parent b0877ed commit c11f50f
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions dynamax/hidden_markov_model/models/poisson_hmm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import NamedTuple, Optional, Tuple, Union
from typing import Dict, NamedTuple, Optional, Tuple, Union

import jax.numpy as jnp
import jax.random as jr
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd
from jaxtyping import Array, Float

from dynamax.hidden_markov_model.inference import HMMPosterior
from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions
from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState
Expand All @@ -23,10 +24,10 @@ class ParamsPoissonHMMEmissions(NamedTuple):
class PoissonHMMEmissions(HMMEmissions):

def __init__(self,
num_states,
emission_dim,
emission_prior_concentration=1.1,
emission_prior_rate=0.1):
num_states: int,
emission_dim: int,
emission_prior_concentration: Scalar = 1.1,
emission_prior_rate: Scalar = 0.1):
"""_summary_
Args:
Expand All @@ -40,12 +41,13 @@ def __init__(self,
self.emission_prior_rate = emission_prior_rate

@property
def emission_shape(self):
def emission_shape(self) -> Tuple[int]:
return (self.emission_dim,)

def initialize(self, key=jr.PRNGKey(0),
method="prior",
emission_rates=None):
def initialize(self, key: Array=jr.PRNGKey(0),
method: str = "prior",
emission_rates: Optional[Float[Array, "num_states emission_dim"]] = None
) -> Tuple[ParamsPoissonHMMEmissions, ParamsPoissonHMMEmissions]:
# Initialize the emission probabilities
if emission_rates is None:
if method.lower() == "prior":
Expand All @@ -64,24 +66,41 @@ def initialize(self, key=jr.PRNGKey(0),
props = ParamsPoissonHMMEmissions(rates=ParameterProperties(constrainer=tfb.Softplus()))
return params, props

def distribution(self, params, state, inputs=None):
def distribution(
self,
params: ParamsPoissonHMMEmissions,
state: int,
inputs: Optional[Array] = None
) -> tfd.Distribution:
return tfd.Independent(tfd.Poisson(rate=params.rates[state]),
reinterpreted_batch_ndims=1)

def log_prior(self, params):
def log_prior(self, params: ParamsPoissonHMMEmissions) -> Float[Array, ""]:
prior = tfd.Gamma(self.emission_prior_concentration, self.emission_prior_rate)
return prior.log_prob(params.rates).sum()

def collect_suff_stats(self, params, posterior, emissions, inputs=None):
def collect_suff_stats(
self,
params: ParamsPoissonHMMEmissions,
posterior: HMMPosterior,
emissions: Float[Array, "num_timesteps emission_dim"],
inputs: Optional[Array] = None
) -> Dict[str, Float[Array, "..."]]:
expected_states = posterior.smoothed_probs
sum_w = jnp.einsum("tk->k", expected_states)[:, None]
sum_x = jnp.einsum("tk, ti->ki", expected_states, emissions)
return dict(sum_w=sum_w, sum_x=sum_x)

def initialize_m_step_state(self, params, props):
def initialize_m_step_state(self, params: ParamsPoissonHMMEmissions, props: ParamsPoissonHMMEmissions) -> None:
return None

def m_step(self, params, props, batch_stats, m_step_state):
def m_step(
self,
params: ParamsPoissonHMMEmissions,
props: ParamsPoissonHMMEmissions,
batch_stats: Dict[str, Float[Array, "..."]],
m_step_state: Any
) -> Tuple[ParamsPoissonHMMEmissions, Any]:
if props.rates.trainable:
emission_stats = pytree_sum(batch_stats, axis=0)
post_concentration = self.emission_prior_concentration + emission_stats['sum_x']
Expand Down Expand Up @@ -132,7 +151,7 @@ def __init__(self,
emission_component = PoissonHMMEmissions(num_states, emission_dim, emission_prior_concentration=emission_prior_concentration, emission_prior_rate=emission_prior_rate)
super().__init__(num_states, initial_component, transition_component, emission_component)

def initialize(self, key=jr.PRNGKey(0),
def initialize(self, key: Array=jr.PRNGKey(0),
method="prior",
initial_probs: Optional[Float[Array, " num_states"]]=None,
transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,
Expand Down

0 comments on commit c11f50f

Please sign in to comment.