Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified lgssm_filter to allow it to processes data with NaNs #306

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,8 @@ def _step(prev_state, args):
def lgssm_filter(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
inputs: Optional[Float[Array, "ntime input_dim"]]=None,
nan_fill_multiplier: float=1e8
) -> PosteriorGSSMFiltered:
r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates.

Expand All @@ -386,6 +387,18 @@ def lgssm_filter(
num_timesteps = len(emissions)
inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs

# Create a vector to replace nans in the emissions
# if an entire time trace is nan, replace with the mean across all values
# if the entire emission is nan, replace with 0s
nan_fill_mean = jnp.nanmean(emissions, axis=0)
nan_fill_mean = jnp.where(jnp.isnan(nan_fill_mean), jnp.nanmean(nan_fill_mean), nan_fill_mean)
nan_fill_mean = jnp.where(jnp.isnan(nan_fill_mean), 0, nan_fill_mean)

# Create a vector to set the diagonal of the covariance of the emissions to a large number
# wherever there are NaNs in the emissions
nan_fill_cov = jnp.nanmax(jnp.nanvar(emissions)) * nan_fill_multiplier
nan_fill_cov = jnp.where(jnp.isnan(nan_fill_cov), nan_fill_multiplier, nan_fill_cov)

def _step(carry, t):
ll, pred_mean, pred_cov = carry

Expand All @@ -401,6 +414,12 @@ def _step(carry, t):
u = inputs[t]
y = emissions[t]

# Find NaNs in the emissions and replace them with nan_fill_mean
# Then, set the emission covariance to nan_fill_cov to push the filter to ignore these emissions
nan_loc = jnp.isnan(y)
y = jnp.where(nan_loc, nan_fill_mean, y)
R = jnp.where(jnp.diag(nan_loc), nan_fill_cov, R)

# Update the log likelihood
ll += MVN(H @ pred_mean + D @ u + d, H @ pred_cov @ H.T + R).log_prob(y)

Expand Down
20 changes: 16 additions & 4 deletions dynamax/linear_gaussian_ssm/inference_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from dynamax.utils.utils import has_tpu

if has_tpu():
def allclose(x, y):
return jnp.allclose(x, y, atol=1e-1)
def allclose(x, y, atol=1e-1):
return jnp.allclose(x, y, atol=atol)
else:
def allclose(x,y):
return jnp.allclose(x, y, atol=1e-1)
def allclose(x, y, atol=1e-1):
return jnp.allclose(x, y, atol=atol)

def joint_posterior_mvn(params, emissions):
"""Construct the joint posterior MVN of a LGSSM, by inverting the joint precision matrix which
Expand Down Expand Up @@ -165,6 +165,12 @@ class TestFilteringAndSmoothing():
print(ssm_posterior.filtered_means.shape)
print(ssm_posterior.smoothed_means.shape)

# repeat sampling with NaNs in the emissions
nan_x = (0, emissions.shape[0], 0, emissions.shape[0])
nan_y = (0, emissions.shape[0], emissions.shape[1], emissions.shape[1])
emissions_nan = emissions.at[nan_x, nan_y].set(jnp.nan)
ssm_posterior_nan = lgssm.smoother(params, emissions_nan)

# TensorFlow Probability posteriors
tfp_lgssm = lgssm_dynamax_to_tfp(num_timesteps, params)
tfp_lls, tfp_filtered_means, tfp_filtered_covs, *_ = tfp_lgssm.forward_filter(emissions)
Expand Down Expand Up @@ -200,6 +206,12 @@ def test_kalman_tfp(self):
assert allclose(self.ssm_posterior.smoothed_covariances, self.tfp_smoothed_covs)
assert allclose(self.ssm_posterior.marginal_loglik, self.tfp_lls.sum())

def test_kalman_tfp_nan(self):
assert allclose(self.ssm_posterior_nan.filtered_means, self.tfp_filtered_means, atol=1e0)
assert allclose(self.ssm_posterior_nan.filtered_covariances, self.tfp_filtered_covs, atol=1e0)
assert allclose(self.ssm_posterior_nan.smoothed_means, self.tfp_smoothed_means, atol=1e0)
assert allclose(self.ssm_posterior_nan.smoothed_covariances, self.tfp_smoothed_covs, atol=1e0)

def test_kalman_vs_joint(self):
assert allclose(self.ssm_posterior.smoothed_means, self.joint_means)
assert allclose(self.ssm_posterior.smoothed_covariances, self.joint_covs)
Expand Down