Allow batched parameters in MvNormal and MvStudentT distributions
ricardoV94 committed Oct 2, 2023
1 parent 434f55c commit 5f1fcc4
Showing 4 changed files with 185 additions and 133 deletions.
103 changes: 35 additions & 68 deletions pymc/distributions/multivariate.py
@@ -115,40 +115,37 @@ def simplex_cont_transform(op, rv):


 def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
-    if chol is not None and not lower:
-        chol = chol.T
-
     if len([i for i in [tau, cov, chol] if i is not None]) != 1:
         raise ValueError("Incompatible parameterization. Specify exactly one of tau, cov, or chol.")

     if cov is not None:
         cov = pt.as_tensor_variable(cov)
-        if cov.ndim != 2:
-            raise ValueError("cov must be two dimensional.")
+        if cov.ndim < 2:
+            raise ValueError("cov must be at least two dimensional.")
     elif tau is not None:
         tau = pt.as_tensor_variable(tau)
-        if tau.ndim != 2:
-            raise ValueError("tau must be two dimensional.")
+        # TODO: What's the correct order/approach (in the non-square case)?
+        # `pytensor.tensor.nlinalg.tensorinv`?
+        if tau.ndim < 2:
+            raise ValueError("tau must be at least two dimensional.")
         cov = matrix_inverse(tau)
     else:
+        # TODO: What's the correct order/approach (in the non-square case)?
         chol = pt.as_tensor_variable(chol)
-        if chol.ndim != 2:
-            raise ValueError("chol must be two dimensional.")
+        if chol.ndim < 2:
+            raise ValueError("chol must be at least two dimensional.")

+        if not lower:
+            chol = pt.swapaxes(chol, -1, -2)

         # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
         chol.tag.lower_triangular = True
-        cov = chol.dot(chol.T)
+        cov = pt.matmul(chol, pt.swapaxes(chol, -1, -2))

     return cov
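The `pt.matmul`/`pt.swapaxes` pair is what lets a whole stack of Cholesky factors be turned back into a stack of covariances at once. A minimal NumPy stand-in for the expression above (an illustrative sketch, not part of the commit):

import numpy as np

# Two lower-triangular factors stacked on a leading batch axis: shape (2, 2, 2).
chol = np.array([[[1.0, 0.0], [0.5, 2.0]],
                 [[2.0, 0.0], [0.0, 1.0]]])

# Batched L @ L.T: matmul and swapaxes map over the leading axis, unlike the
# old `chol.dot(chol.T)`, which only handled a single 2d matrix.
cov = np.matmul(chol, np.swapaxes(chol, -1, -2))
assert np.allclose(cov[0], chol[0] @ chol[0].T)
assert np.allclose(cov[1], chol[1] @ chol[1].T)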


-def quaddist_parse(value, mu, cov, mat_type="cov"):
+def quaddist_chol(value, mu, cov):
     """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
-    if value.ndim > 2 or value.ndim == 0:
-        raise ValueError("Invalid dimension for value: %s" % value.ndim)
+    if value.ndim == 0:
+        raise ValueError("Value can't be a scalar")
     if value.ndim == 1:
         onedim = True
         value = value[None, :]
@@ -157,42 +154,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):

     delta = value - mu
     chol_cov = nan_lower_cholesky(cov)
-    if mat_type != "tau":
-        dist, logdet, ok = quaddist_chol(delta, chol_cov)
-    else:
-        dist, logdet, ok = quaddist_tau(delta, chol_cov)
-    if onedim:
-        return dist[0], logdet, ok
-
-    return dist, logdet, ok
-
-
-def quaddist_chol(delta, chol_mat):
-    diag = pt.diag(chol_mat)
+    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
     # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0)
+    ok = pt.all(diag > 0, axis=-1)
     # If not, replace the diagonal. We return -inf later, but
     # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = solve_lower(chol_cov, delta.T).T
+    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
     quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
-
-
-def quaddist_tau(delta, chol_mat):
-    diag = pt.nlinalg.diag(chol_mat)
-    # Check if the precision matrix is positive definite.
-    ok = pt.all(diag > 0)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_tau = pt.switch(ok, chol_mat, 1)
+    logdet = pt.log(diag).sum(axis=-1)
-
-    delta_trans = pt.dot(delta, chol_tau)
-    quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = -pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
+    if onedim:
+        return quaddist[0], logdet, ok
+    else:
+        return quaddist, logdet, ok
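The merged helper computes the Mahalanobis term via a triangular solve rather than an explicit inverse: with Sigma = L L^T and delta_trans = L^-1 delta, sum(delta_trans**2) equals delta^T Sigma^-1 delta, and sum(log(diag(L))) is half of logdet(Sigma). A NumPy/SciPy sanity sketch of that identity (illustrative, not part of the commit):

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(0)
A = rng.normal(size=(3, 3))
sigma = A @ A.T + 3 * np.eye(3)  # a positive-definite covariance
delta = rng.normal(size=3)

L = np.linalg.cholesky(sigma)
delta_trans = solve_triangular(L, delta, lower=True)  # L^{-1} delta

# The quadratic form matches delta^T Sigma^{-1} delta ...
assert np.isclose((delta_trans**2).sum(), delta @ np.linalg.solve(sigma, delta))
# ... and sum(log(diag(L))) is half the log-determinant of Sigma.
assert np.isclose(2 * np.log(np.diag(L)).sum(), np.linalg.slogdet(sigma)[1])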


class MvNormal(Continuous):
@@ -290,7 +266,7 @@ def logp(value, mu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, cov)
+        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
         k = floatX(value.shape[-1])
         norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))
         return check_parameters(
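With `quaddist_chol` handling leading batch dimensions, a single MvNormal can now carry a batch of covariance matrices. A quick sketch of the intended usage (assuming a PyMC version containing this commit; shapes are illustrative):

import numpy as np
import pymc as pm

covs = np.stack([np.eye(2), 4 * np.eye(2)])   # batch of two 2x2 covariances
x = pm.MvNormal.dist(mu=np.zeros(2), cov=covs)  # batch shape (2,), support dim 2

print(pm.draw(x).shape)                     # (2, 2): one 2-vector per covariance
print(pm.logp(x, np.zeros((2, 2))).eval())  # one log-density per batch member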
@@ -307,22 +283,6 @@ class MvStudentTRV(RandomVariable):
     dtype = "floatX"
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")

-    def make_node(self, rng, size, dtype, nu, mu, cov):
-        nu = pt.as_tensor_variable(nu)
-        if not nu.ndim == 0:
-            raise ValueError("nu must be a scalar (ndim=0).")
-
-        return super().make_node(rng, size, dtype, nu, mu, cov)
-
-    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
-        dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
-
-        if mu is None:
-            mu = np.array([0.0], dtype=dtype)
-        if cov is None:
-            cov = np.array([[1.0]], dtype=dtype)
-        return super().__call__(nu, mu, cov, size=size, **kwargs)

     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         return supp_shape_from_ref_param_shape(
             ndim_supp=self.ndim_supp,
@@ -333,14 +293,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):

     @classmethod
     def rng_fn(cls, rng, nu, mu, cov, size):
+        if size is None:
+            # When size is implicit, we need to broadcast parameters correctly,
+            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
+            # nu broadcasts mu and cov
+            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
+                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+            # nu is broadcasted by either mu or cov
+            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
+                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)

         mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)

         # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
         chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]

         if size:
             mu = np.broadcast_to(mu, size + (mu.shape[-1],))

         return (mv_samples / chi2_samples) + mu
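`rng_fn` uses the standard scale-mixture construction: x = mu + z / sqrt(w / nu) with z ~ N(0, cov) and w ~ chi^2(nu) is multivariate Student-t. A NumPy sketch of that identity (illustrative values, not part of the commit; for nu > 2 the sample covariance should approach nu / (nu - 2) * cov):

import numpy as np

rng = np.random.default_rng(1)
nu = 4.0
mu = np.array([1.0, -1.0])
cov = np.array([[2.0, 0.5], [0.5, 1.0]])

z = rng.multivariate_normal(np.zeros(2), cov, size=200_000)  # MvNormal draws
w = rng.chisquare(nu, size=200_000)[:, None]                 # chi^2 draws, axis added on the right
x = mu + z / np.sqrt(w / nu)                                 # MvStudentT draws

print(np.cov(x, rowvar=False))  # ~ nu / (nu - 2) * cov = 2 * cov here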


@@ -390,7 +357,7 @@ class MvStudentT(Continuous):
     rv_op = mv_studentt

     @classmethod
-    def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=True, **kwargs):
+    def dist(cls, nu, *, Sigma=None, mu, scale=None, tau=None, chol=None, lower=True, **kwargs):
         cov = kwargs.pop("cov", None)
         if cov is not None:
             warnings.warn(
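Making everything after `nu` keyword-only guards against positional mix-ups now that several matrix parameterizations are accepted. A sketch of the calling convention this enforces (assuming a version with this commit):

import numpy as np
import pymc as pm

x = pm.MvStudentT.dist(nu=4, mu=np.zeros(2), scale=np.eye(2))  # OK: named parameters

# pm.MvStudentT.dist(4, np.eye(2)) now raises a TypeError instead of silently
# binding the matrix to whichever parameter happens to come second.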
@@ -432,7 +399,7 @@ def logp(value, nu, mu, scale):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, scale)
+        quaddist, logdet, ok = quaddist_chol(value, mu, scale)
         k = floatX(value.shape[-1])

         norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * pt.log(nu * np.pi)
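The MvStudentT log-density assembled here can be cross-checked against SciPy's multivariate t; both implement the standard density, so the values should agree. A sketch of that check (assumes SciPy >= 1.6 for `multivariate_t`, and a PyMC version with this commit):

import numpy as np
import pymc as pm
from scipy import stats

value = np.array([0.5, -0.2])
dist = pm.MvStudentT.dist(nu=5, mu=np.zeros(2), scale=np.eye(2))

logp_pm = pm.logp(dist, value).eval()
logp_sp = stats.multivariate_t(loc=np.zeros(2), shape=np.eye(2), df=5).logpdf(value)
assert np.isclose(logp_pm, logp_sp)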
27 changes: 12 additions & 15 deletions tests/distributions/test_mixture.py
@@ -333,21 +333,18 @@ def test_list_multivariate_components_deterministic_weights(self, weights, compo
         assert not repetitions

         # Test logp
-        # MvNormal logp is currently limited to 2d values
-        expectation = pytest.raises(ValueError) if mix_eval.ndim > 2 else does_not_raise()
-        with expectation:
-            mix_logp_eval = logp(mix, mix_eval).eval()
-            assert mix_logp_eval.shape == expected_shape[:-1]
-            bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
-            expected_logp = np.stack(
-                (
-                    logp(components[0], mix_eval).eval(),
-                    logp(components[1], mix_eval).eval(),
-                ),
-                axis=-1,
-            )[bcast_weights == 1]
-            expected_logp = expected_logp.reshape(expected_shape[:-1])
-            assert np.allclose(mix_logp_eval, expected_logp)
+        mix_logp_eval = logp(mix, mix_eval).eval()
+        assert mix_logp_eval.shape == expected_shape[:-1]
+        bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
+        expected_logp = np.stack(
+            (
+                logp(components[0], mix_eval).eval(),
+                logp(components[1], mix_eval).eval(),
+            ),
+            axis=-1,
+        )[bcast_weights == 1]
+        expected_logp = expected_logp.reshape(expected_shape[:-1])
+        assert np.allclose(mix_logp_eval, expected_logp)

     def test_component_choice_random(self):
         """Test that mixture choices change over evaluations"""
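The deleted guard existed because MvNormal's logp used to reject values with more than two dimensions; the test now evaluates the mixture logp unconditionally. A sketch of the newly supported case outside the test suite (assuming this commit; shapes are illustrative):

import numpy as np
import pymc as pm

x = pm.MvNormal.dist(mu=np.zeros(2), cov=np.eye(2))
value = np.random.default_rng(2).normal(size=(3, 4, 2))  # two batch dims + support dim

print(pm.logp(x, value).eval().shape)  # (3, 4): one log-density per batch entry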
