Skip to content

Commit

Permalink
Merge pull request #356 from probml/timevarying_emission_weights_bug
Browse files Browse the repository at this point in the history
Addressing issue 347: time varying weights in LGSSM
  • Loading branch information
slinderman committed Feb 1, 2024
2 parents a1cb385 + b351698 commit 08d1381
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion dynamax/linear_gaussian_ssm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing_extensions import Protocol

from dynamax.ssm import SSM
from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions
from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed
from dynamax.parameters import ParameterProperties, ParameterSet
Expand Down Expand Up @@ -198,6 +198,15 @@ def emission_distribution(
if self.has_emissions_bias:
mean += params.emissions.bias
return MVN(mean, params.emissions.cov)

def sample(
self,
params: ParamsLGSSM,
key: PRNGKey,
num_timesteps: int,
inputs: Optional[Float[Array, "ntime input_dim"]] = None
) -> PosteriorGSSMFiltered:
return lgssm_joint_sample(params, key, num_timesteps, inputs)

def marginal_log_prob(
self,
Expand Down

0 comments on commit 08d1381

Please sign in to comment.