Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Format the codebase of LGSSM by black #337

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 106 additions & 88 deletions dynamax/linear_gaussian_ssm/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from tensorflow_probability.substrates.jax.distributions import (
MultivariateNormalDiagPlusLowRankCovariance as MVNLowRank,
MultivariateNormalFullCovariance as MVN)
MultivariateNormalFullCovariance as MVN,
)

from jax.tree_util import tree_map
from jaxtyping import Array, Float
Expand All @@ -16,6 +17,7 @@
from dynamax.parameters import ParameterProperties
from dynamax.types import PRNGKey, Scalar


class ParamsLGSSMInitial(NamedTuple):
r"""Parameters of the initial distribution

Expand Down Expand Up @@ -45,22 +47,30 @@ class ParamsLGSSMDynamics(NamedTuple):
:param cov: dynamics covariance $Q$

"""
weights: Union[ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"]]

bias: Union[ParameterProperties,
Float[Array, "state_dim"],
Float[Array, "ntime state_dim"]]

input_weights: Union[ParameterProperties,
Float[Array, "state_dim input_dim"],
Float[Array, "ntime state_dim input_dim"]]

cov: Union[ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"],
Float[Array, "state_dim_triu"]]
weights: Union[
ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"],
]

bias: Union[
ParameterProperties,
Float[Array, "state_dim"],
Float[Array, "ntime state_dim"],
]

input_weights: Union[
ParameterProperties,
Float[Array, "state_dim input_dim"],
Float[Array, "ntime state_dim input_dim"],
]

cov: Union[
ParameterProperties,
Float[Array, "state_dim state_dim"],
Float[Array, "ntime state_dim state_dim"],
Float[Array, "state_dim_triu"],
]


class ParamsLGSSMEmissions(NamedTuple):
Expand All @@ -76,24 +86,32 @@ class ParamsLGSSMEmissions(NamedTuple):
:param cov: emission covariance $R$

"""
weights: Union[ParameterProperties,
Float[Array, "emission_dim state_dim"],
Float[Array, "ntime emission_dim state_dim"]]

bias: Union[ParameterProperties,
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"]]

input_weights: Union[ParameterProperties,
Float[Array, "emission_dim input_dim"],
Float[Array, "ntime emission_dim input_dim"]]

cov: Union[ParameterProperties,
Float[Array, "emission_dim emission_dim"],
Float[Array, "ntime emission_dim emission_dim"],
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"],
Float[Array, "emission_dim_triu"]]
weights: Union[
ParameterProperties,
Float[Array, "emission_dim state_dim"],
Float[Array, "ntime emission_dim state_dim"],
]

bias: Union[
ParameterProperties,
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"],
]

input_weights: Union[
ParameterProperties,
Float[Array, "emission_dim input_dim"],
Float[Array, "ntime emission_dim input_dim"],
]

cov: Union[
ParameterProperties,
Float[Array, "emission_dim emission_dim"],
Float[Array, "ntime emission_dim emission_dim"],
Float[Array, "emission_dim"],
Float[Array, "ntime emission_dim"],
Float[Array, "emission_dim_triu"],
]


class ParamsLGSSM(NamedTuple):
Expand Down Expand Up @@ -145,6 +163,7 @@ class PosteriorGSSMSmoothed(NamedTuple):

# Helper functions


def _get_one_param(x, dim, t):
"""Helper function to get one parameter at time t."""
if callable(x):
Expand All @@ -154,6 +173,7 @@ def _get_one_param(x, dim, t):
else:
return x


def _get_params(params, num_timesteps, t):
"""Helper function to get parameters at time t."""
assert not callable(params.emissions.cov), "Emission covariance cannot be a callable."
Expand All @@ -166,9 +186,9 @@ def _get_params(params, num_timesteps, t):
D = _get_one_param(params.emissions.input_weights, 2, t)
d = _get_one_param(params.emissions.bias, 1, t)

if len(params.emissions.cov.shape) == 1:
if len(params.emissions.cov.shape) == 1:
R = _get_one_param(params.emissions.cov, 1, t)
elif len(params.emissions.cov.shape) > 2:
elif len(params.emissions.cov.shape) > 2:
R = _get_one_param(params.emissions.cov, 2, t)
elif params.emissions.cov.shape[0] != num_timesteps:
R = _get_one_param(params.emissions.cov, 2, t)
Expand All @@ -179,47 +199,49 @@ def _get_params(params, num_timesteps, t):
warnings.warn(
"Emission covariance has shape (N,N) where N is the number of timesteps. "
"The covariance will be interpreted as static and non-diagonal. To "
"specify a dynamic and diagonal covariance, pass it as a 3D array.")
"specify a dynamic and diagonal covariance, pass it as a 3D array."
)

return F, B, b, Q, H, D, d, R


_zeros_if_none = lambda x, shape: x if x is not None else jnp.zeros(shape)


def make_lgssm_params(initial_mean,
initial_cov,
dynamics_weights,
dynamics_cov,
emissions_weights,
emissions_cov,
dynamics_bias=None,
dynamics_input_weights=None,
emissions_bias=None,
emissions_input_weights=None):
def make_lgssm_params(
initial_mean,
initial_cov,
dynamics_weights,
dynamics_cov,
emissions_weights,
emissions_cov,
dynamics_bias=None,
dynamics_input_weights=None,
emissions_bias=None,
emissions_input_weights=None,
):
"""Helper function to construct a ParamsLGSSM object from arguments."""
state_dim = len(initial_mean)
emission_dim = emissions_cov.shape[-1]
input_dim = max(dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0,
emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0)
input_dim = max(
dynamics_input_weights.shape[-1] if dynamics_input_weights is not None else 0,
emissions_input_weights.shape[-1] if emissions_input_weights is not None else 0,
)

params = ParamsLGSSM(
initial=ParamsLGSSMInitial(
mean=initial_mean,
cov=initial_cov
),
initial=ParamsLGSSMInitial(mean=initial_mean, cov=initial_cov),
dynamics=ParamsLGSSMDynamics(
weights=dynamics_weights,
bias=_zeros_if_none(dynamics_bias,state_dim),
bias=_zeros_if_none(dynamics_bias, state_dim),
input_weights=_zeros_if_none(dynamics_input_weights, (state_dim, input_dim)),
cov=dynamics_cov
cov=dynamics_cov,
),
emissions=ParamsLGSSMEmissions(
weights=emissions_weights,
bias=_zeros_if_none(emissions_bias, emission_dim),
input_weights=_zeros_if_none(emissions_input_weights, (emission_dim, input_dim)),
cov=emissions_cov
)
cov=emissions_cov,
),
)
return params

Expand Down Expand Up @@ -278,20 +300,20 @@ def _condition_on(m, P, H, D, d, R, u, y):
if R.ndim == 2:
S = R + H @ P @ H.T
K = psd_solve(S, H @ P).T
else:
else:
# Optimization using Woodbury identity with A=R, U=H@chol(P), V=U.T, C=I
# (see https://en.wikipedia.org/wiki/Woodbury_matrix_identity)
I = jnp.eye(P.shape[0])
U = H @ jnp.linalg.cholesky(P)
X = U / R[:, None]
S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
S_inv = jnp.diag(1.0 / R) - X @ psd_solve(I + U.T @ X, X.T)
"""
# Could alternatively use U=H and C=P
R_inv = jnp.diag(1.0 / R)
P_inv = psd_solve(P, jnp.eye(P.shape[0]))
S_inv = R_inv - R_inv @ H @ psd_solve(P_inv + H.T @ R_inv @ H, H.T @ R_inv)
"""
K = P @ H.T @ S_inv
K = P @ H.T @ S_inv
S = jnp.diag(R) + H @ P @ H.T

Sigma_cond = P - K @ S @ K.T
Expand Down Expand Up @@ -324,20 +346,20 @@ def preprocess_params_and_inputs(params, num_timesteps, inputs):
emissions_bias = _zeros_if_none(params.emissions.bias, (emission_dim,))

full_params = ParamsLGSSM(
initial=ParamsLGSSMInitial(
mean=params.initial.mean,
cov=params.initial.cov),
initial=ParamsLGSSMInitial(mean=params.initial.mean, cov=params.initial.cov),
dynamics=ParamsLGSSMDynamics(
weights=params.dynamics.weights,
bias=dynamics_bias,
input_weights=dynamics_input_weights,
cov=params.dynamics.cov),
cov=params.dynamics.cov,
),
emissions=ParamsLGSSMEmissions(
weights=params.emissions.weights,
bias=emissions_bias,
input_weights=emissions_input_weights,
cov=params.emissions.cov)
)
cov=params.emissions.cov,
),
)
return full_params, inputs


Expand All @@ -350,28 +372,26 @@ def wrapper(*args, **kwargs):
# Extract the arguments by name
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
params = bound_args.arguments['params']
emissions = bound_args.arguments['emissions']
inputs = bound_args.arguments['inputs']
params = bound_args.arguments["params"]
emissions = bound_args.arguments["emissions"]
inputs = bound_args.arguments["inputs"]

num_timesteps = len(emissions)
full_params, inputs = preprocess_params_and_inputs(params, num_timesteps, inputs)

return f(full_params, emissions, inputs=inputs)
return wrapper


return wrapper


def lgssm_joint_sample(
params: ParamsLGSSM,
key: PRNGKey,
num_timesteps: int,
inputs: Optional[Float[Array, "num_timesteps input_dim"]]=None
)-> Tuple[Float[Array, "num_timesteps state_dim"],
Float[Array, "num_timesteps emission_dim"]]:
inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None,
) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]:
r"""Sample from the joint distribution to produce state and emission trajectories.

Args:
params: model parameters
inputs: optional array of inputs.
Expand All @@ -388,9 +408,9 @@ def _sample_transition(key, F, B, b, Q, x_tm1, u):

def _sample_emission(key, H, D, d, R, x, u):
mean = H @ x + D @ u + d
R = jnp.diag(R) if R.ndim==1 else R
R = jnp.diag(R) if R.ndim == 1 else R
return MVN(mean, R).sample(seed=key)

def _sample_initial(key, params, inputs):
key1, key2 = jr.split(key)

Expand All @@ -417,7 +437,7 @@ def _step(prev_state, args):

# Sample the initial state
key1, key2 = jr.split(key)

initial_state, initial_emission = _sample_initial(key1, params, inputs)

# Sample the remaining emissions and states
Expand All @@ -437,8 +457,8 @@ def _step(prev_state, args):
@preprocess_args
def lgssm_filter(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
) -> PosteriorGSSMFiltered:
r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates.

Expand All @@ -456,13 +476,12 @@ def lgssm_filter(

def _log_likelihood(pred_mean, pred_cov, H, D, d, R, u, y):
m = H @ pred_mean + D @ u + d
if R.ndim==2:
if R.ndim == 2:
S = R + H @ pred_cov @ H.T
return MVN(m, S).log_prob(y)
else:
L = H @ jnp.linalg.cholesky(pred_cov)
return MVNLowRank(m, R, L).log_prob(y)


def _step(carry, t):
ll, pred_mean, pred_cov = carry
Expand Down Expand Up @@ -493,7 +512,7 @@ def _step(carry, t):
def lgssm_smoother(
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
) -> PosteriorGSSMSmoothed:
r"""Run forward-filtering, backward-smoother to compute expectations
under the posterior distribution on latent states. Technically, this
Expand Down Expand Up @@ -560,10 +579,9 @@ def _step(carry, args):
def lgssm_posterior_sample(
key: PRNGKey,
params: ParamsLGSSM,
emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]]=None,
jitter: Optional[Scalar]=0

emissions: Float[Array, "ntime emission_dim"],
inputs: Optional[Float[Array, "ntime input_dim"]] = None,
jitter: Optional[Scalar] = 0,
) -> Float[Array, "ntime state_dim"]:
r"""Run forward-filtering, backward-sampling to draw samples from $p(z_{1:T} \mid y_{1:T}, u_{1:T})$.

Expand Down
Loading