diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 87079ee9..453de651 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -11,7 +11,7 @@ from typing_extensions import Protocol from dynamax.ssm import SSM -from dynamax.linear_gaussian_ssm.inference import lgssm_filter, lgssm_smoother, lgssm_posterior_sample +from dynamax.linear_gaussian_ssm.inference import lgssm_joint_sample, lgssm_filter, lgssm_smoother, lgssm_posterior_sample from dynamax.linear_gaussian_ssm.inference import ParamsLGSSM, ParamsLGSSMInitial, ParamsLGSSMDynamics, ParamsLGSSMEmissions from dynamax.linear_gaussian_ssm.inference import PosteriorGSSMFiltered, PosteriorGSSMSmoothed from dynamax.parameters import ParameterProperties, ParameterSet @@ -198,6 +198,15 @@ def emission_distribution( if self.has_emissions_bias: mean += params.emissions.bias return MVN(mean, params.emissions.cov) + + def sample( + self, + params: ParamsLGSSM, + key: PRNGKey, + num_timesteps: int, + inputs: Optional[Float[Array, "ntime input_dim"]] = None + ) -> PosteriorGSSMFiltered: + return lgssm_joint_sample(params, key, num_timesteps, inputs) def marginal_log_prob( self,