From ad38ee09485211f9a8d32a80a039e4e86e1afe87 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 6 Sep 2023 17:30:10 +0200
Subject: [PATCH] Allow batched parameters in MvNormal and MvStudentT distributions

---
 pymc/distributions/multivariate.py       |  95 +++++++------------
 tests/distributions/test_mixture.py      |  27 +++---
 tests/distributions/test_multivariate.py | 113 +++++++++++++++++++----
 3 files changed, 140 insertions(+), 95 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index f104f084d34..f697ce5893e 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -123,20 +123,17 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
 
     if cov is not None:
         cov = pt.as_tensor_variable(cov)
-        if cov.ndim != 2:
-            raise ValueError("cov must be two dimensional.")
+        if cov.ndim < 2:
+            raise ValueError("cov must be at least two dimensional.")
     elif tau is not None:
         tau = pt.as_tensor_variable(tau)
-        if tau.ndim != 2:
-            raise ValueError("tau must be two dimensional.")
-        # TODO: What's the correct order/approach (in the non-square case)?
-        # `pytensor.tensor.nlinalg.tensorinv`?
+        if tau.ndim < 2:
+            raise ValueError("tau must be at least two dimensional.")
         cov = matrix_inverse(tau)
     else:
-        # TODO: What's the correct order/approach (in the non-square case)?
         chol = pt.as_tensor_variable(chol)
-        if chol.ndim != 2:
-            raise ValueError("chol must be two dimensional.")
+        if chol.ndim < 2:
+            raise ValueError("chol must be at least two dimensional.")
 
         # tag as lower triangular to enable pytensor rewrites of chol(l.l') -> l
         chol.tag.lower_triangular = True
@@ -145,10 +142,10 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
     return cov
 
 
-def quaddist_parse(value, mu, cov, mat_type="cov"):
+def quaddist_chol(value, mu, cov):
     """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
-    if value.ndim > 2 or value.ndim == 0:
-        raise ValueError("Invalid dimension for value: %s" % value.ndim)
+    if value.ndim == 0:
+        raise ValueError(f"Invalid dimension for value: {value.ndim}")
     if value.ndim == 1:
         onedim = True
         value = value[None, :]
@@ -157,42 +154,21 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
     delta = value - mu
 
     chol_cov = nan_lower_cholesky(cov)
-    if mat_type != "tau":
-        dist, logdet, ok = quaddist_chol(delta, chol_cov)
-    else:
-        dist, logdet, ok = quaddist_tau(delta, chol_cov)
-    if onedim:
-        return dist[0], logdet, ok
-
-    return dist, logdet, ok
-
-
-def quaddist_chol(delta, chol_mat):
-    diag = pt.diag(chol_mat)
+    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
     # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0)
+    ok = pt.all(diag > 0, axis=-1)
     # If not, replace the diagonal. We return -inf later, but
     # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = solve_lower(chol_cov, delta.T).T
+    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
     quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
-
+    logdet = pt.log(diag).sum(axis=-1)
 
-def quaddist_tau(delta, chol_mat):
-    diag = pt.nlinalg.diag(chol_mat)
-    # Check if the precision matrix is positive definite.
-    ok = pt.all(diag > 0)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_tau = pt.switch(ok, chol_mat, 1)
-
-    delta_trans = pt.dot(delta, chol_tau)
-    quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = -pt.sum(pt.log(diag))
-    return quaddist, logdet, ok
+    if onedim:
+        return quaddist[0], logdet, ok
+    else:
+        return quaddist, logdet, ok
 
 
 class MvNormal(Continuous):
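The unified `quaddist_chol` treats the trailing two axes of `cov` as the matrix dimensions and any leading axes as batch dimensions: `pt.diagonal(..., axis1=-2, axis2=-1)`, the `ok[..., None, None]` switch, and `solve_lower(..., b_ndim=1)` all broadcast over the leading axes. A minimal NumPy sketch of the same batched computation, checked against scipy (names mirror the patch, but the code is illustrative and not part of it):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
mu = rng.normal(size=(4, 3))                      # batch of 4 means, support dim 3
A = rng.normal(size=(4, 3, 3))
cov = A @ np.swapaxes(A, -1, -2) + 3 * np.eye(3)  # batch of positive definite covariances
value = rng.normal(size=(4, 3))

L = np.linalg.cholesky(cov)                       # batched lower Cholesky factor
delta = value - mu
z = np.linalg.solve(L, delta[..., None])[..., 0]  # batched solve of L z = delta
quaddist = (z**2).sum(-1)                         # Mahalanobis term per batch member
logdet = np.log(np.diagonal(L, axis1=-2, axis2=-1)).sum(-1)  # 0.5 * log|cov|
k = value.shape[-1]
logp = -0.5 * (quaddist + k * np.log(2 * np.pi)) - logdet

ref = np.array([stats.multivariate_normal.logpdf(v, m, c) for v, m, c in zip(value, mu, cov)])
np.testing.assert_allclose(logp, ref)
```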
@@ -290,7 +266,7 @@ def logp(value, mu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, cov)
+        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
         k = floatX(value.shape[-1])
         norm = -0.5 * k * pm.floatX(np.log(2 * np.pi))
         return check_parameters(
@@ -307,22 +283,6 @@ class MvStudentTRV(RandomVariable):
     dtype = "floatX"
     _print_name = ("MvStudentT", "\\operatorname{MvStudentT}")
 
-    def make_node(self, rng, size, dtype, nu, mu, cov):
-        nu = pt.as_tensor_variable(nu)
-        if not nu.ndim == 0:
-            raise ValueError("nu must be a scalar (ndim=0).")
-
-        return super().make_node(rng, size, dtype, nu, mu, cov)
-
-    def __call__(self, nu, mu=None, cov=None, size=None, **kwargs):
-        dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
-
-        if mu is None:
-            mu = np.array([0.0], dtype=dtype)
-        if cov is None:
-            cov = np.array([[1.0]], dtype=dtype)
-        return super().__call__(nu, mu, cov, size=size, **kwargs)
-
     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         return supp_shape_from_ref_param_shape(
             ndim_supp=self.ndim_supp,
@@ -333,14 +293,21 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
 
     @classmethod
     def rng_fn(cls, rng, nu, mu, cov, size):
+        if size is None:
+            # When size is implicit, we need to broadcast parameters correctly,
+            # so that the MvNormal draws and the chisquare draws have the same number of batch dimensions.
+            # nu broadcasts mu and cov
+            if np.ndim(nu) > max(mu.ndim - 1, cov.ndim - 2):
+                _, mu, cov = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+            # nu is broadcast by either mu or cov
+            elif np.ndim(nu) < max(mu.ndim - 1, cov.ndim - 2):
+                nu, _, _ = broadcast_params((nu, mu, cov), ndims_params=cls.ndims_params)
+
         mv_samples = multivariate_normal.rng_fn(rng=rng, mean=np.zeros_like(mu), cov=cov, size=size)
 
         # Take chi2 draws and add an axis of length 1 to the right for correct broadcasting below
         chi2_samples = np.sqrt(rng.chisquare(nu, size=size) / nu)[..., None]
 
-        if size:
-            mu = np.broadcast_to(mu, size + (mu.shape[-1],))
-
         return (mv_samples / chi2_samples) + mu
 
@@ -390,7 +357,7 @@ class MvStudentT(Continuous):
     rv_op = mv_studentt
 
     @classmethod
-    def dist(cls, nu, Sigma=None, mu=None, scale=None, tau=None, chol=None, lower=True, **kwargs):
+    def dist(cls, nu, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
         cov = kwargs.pop("cov", None)
         if cov is not None:
             warnings.warn(
@@ -432,7 +399,7 @@ def logp(value, nu, mu, scale):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_parse(value, mu, scale)
+        quaddist, logdet, ok = quaddist_chol(value, mu, scale)
         k = floatX(value.shape[-1])
 
         norm = gammaln((nu + k) / 2.0) - gammaln(nu / 2.0) - 0.5 * k * pt.log(nu * np.pi)
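The `size is None` branch in `rng_fn` exists because the chi-square mixing draws follow only `nu`'s shape: when `nu` carries batch dimensions that `mu`/`cov` lack, or vice versa, the normal and chi-square draws would otherwise disagree on batch shape. A NumPy sketch of the scale-mixture draw with a batched `nu` and unbatched `mu`/`cov` (illustrative; `batch_shape` and the manual broadcasting stand in for `broadcast_params`):

```python
import numpy as np

rng = np.random.default_rng(0)
nu = np.array([2.0, 30.0])  # batched degrees of freedom, shape (2,)
mu = np.zeros(3)            # unbatched mean, support dim 3
cov = np.eye(3)

# Align batch dimensions: nu has one batch dim, mu/cov have none.
batch_shape = np.broadcast_shapes(nu.shape, mu.shape[:-1], cov.shape[:-2])  # (2,)
mu_b = np.broadcast_to(mu, batch_shape + mu.shape[-1:])
cov_b = np.broadcast_to(cov, batch_shape + cov.shape[-2:])

# MvNormal draws via Cholesky, one per batch member
z = rng.standard_normal(batch_shape + mu.shape[-1:])
mv = (np.linalg.cholesky(cov_b) @ z[..., None])[..., 0]

# One chi-square mixing draw per batch member, broadcast against the support axis
chi2 = np.sqrt(rng.chisquare(nu) / nu)[..., None]
draws = mv / chi2 + mu_b
print(draws.shape)  # (2, 3): the nu=2 row is heavy-tailed, the nu=30 row near-normal
```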
diff --git a/tests/distributions/test_mixture.py b/tests/distributions/test_mixture.py
index c90c4809775..9632efd859f 100644
--- a/tests/distributions/test_mixture.py
+++ b/tests/distributions/test_mixture.py
@@ -333,21 +333,18 @@ def test_list_multivariate_components_deterministic_weights(self, weights, compo
         assert not repetitions
 
         # Test logp
-        # MvNormal logp is currently limited to 2d values
-        expectation = pytest.raises(ValueError) if mix_eval.ndim > 2 else does_not_raise()
-        with expectation:
-            mix_logp_eval = logp(mix, mix_eval).eval()
-            assert mix_logp_eval.shape == expected_shape[:-1]
-            bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
-            expected_logp = np.stack(
-                (
-                    logp(components[0], mix_eval).eval(),
-                    logp(components[1], mix_eval).eval(),
-                ),
-                axis=-1,
-            )[bcast_weights == 1]
-            expected_logp = expected_logp.reshape(expected_shape[:-1])
-            assert np.allclose(mix_logp_eval, expected_logp)
+        mix_logp_eval = logp(mix, mix_eval).eval()
+        assert mix_logp_eval.shape == expected_shape[:-1]
+        bcast_weights = np.broadcast_to(weights, (*expected_shape[:-1], 2))
+        expected_logp = np.stack(
+            (
+                logp(components[0], mix_eval).eval(),
+                logp(components[1], mix_eval).eval(),
+            ),
+            axis=-1,
+        )[bcast_weights == 1]
+        expected_logp = expected_logp.reshape(expected_shape[:-1])
+        assert np.allclose(mix_logp_eval, expected_logp)
 
     def test_component_choice_random(self):
         """Test that mixture choices change over evaluations"""
diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py
index 4f34b0c1037..6a75ba29bd5 100644
--- a/tests/distributions/test_multivariate.py
+++ b/tests/distributions/test_multivariate.py
@@ -330,6 +330,36 @@ def test_mvnormal_init_fail(self):
         with pytest.raises(ValueError):
             x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), tau=np.eye(3), size=3)
 
+    @pytest.mark.parametrize("batch_mu", (False, True))
+    @pytest.mark.parametrize("batch_cov", (False, True))
+    @pytest.mark.parametrize("use_tau", (False, True))
+    def test_mvnormal_batched_dims(self, batch_mu, batch_cov, use_tau):
+        def ref_logp_core(value, mu, cov):
+            return st.multivariate_normal.logpdf(value, mu, cov)
+
+        ref_logp = np.vectorize(ref_logp_core, signature="(a),(a),(a,a)->()")
+
+        mu = np.arange(5 * 3 * 2).reshape(5, 3, 2) + 1
+        cov = np.eye(2) * mu[..., None]
+        value = mu - np.mean(mu)
+
+        if not batch_mu:
+            mu = mu[0, 0]
+            assert mu.ndim == 1
+        if not batch_cov:
+            cov = cov[0, 0]
+            assert cov.ndim == 2
+
+        if use_tau:
+            dist = pm.MvNormal.dist(mu=mu, tau=np.linalg.inv(cov))
+        else:
+            dist = pm.MvNormal.dist(mu=mu, cov=cov)
+
+        np.testing.assert_allclose(
+            pm.logp(dist, value).eval(),
+            ref_logp(value, mu, cov),
+        )
+
     @pytest.mark.parametrize("n", [1, 2, 3])
     def test_matrixnormal(self, n):
         mat_scale = 1e3  # To reduce logp magnitude
@@ -472,6 +502,40 @@ def test_mvt(self, n):
             extra_args={"size": 2},
         )
 
+    @pytest.mark.parametrize("batch_nu", (False, True))
+    @pytest.mark.parametrize("batch_mu", (False, True))
+    @pytest.mark.parametrize("batch_cov", (False, True))
+    @pytest.mark.parametrize("use_tau", (False, True))
+    def test_mvt_batched_dims(self, batch_nu, batch_mu, batch_cov, use_tau):
+        def ref_logp_core(value, nu, mu, cov):
+            return st.multivariate_t.logpdf(value, mu, cov, df=nu)
+
+        ref_logp = np.vectorize(ref_logp_core, signature="(a),(),(a),(a,a)->()")
+
+        nu = np.arange(5 * 3).reshape(5, 3) + 2
+        mu = np.arange(5 * 3 * 2).reshape(5, 3, 2) + 1
+        cov = np.eye(2) * mu[..., None]
+        value = mu - np.mean(mu)
+
+        if not batch_nu:
+            nu = nu[0, 0]
+        if not batch_mu:
+            mu = mu[0, 0]
+            assert mu.ndim == 1
+        if not batch_cov:
+            cov = cov[0, 0]
+            assert cov.ndim == 2
+
+        if use_tau:
+            dist = pm.MvStudentT.dist(nu=nu, mu=mu, tau=np.linalg.inv(cov))
+        else:
+            dist = pm.MvStudentT.dist(nu=nu, mu=mu, cov=cov)
+
+        np.testing.assert_allclose(
+            pm.logp(dist, value).eval(),
+            ref_logp(value, nu, mu, cov),
+        )
+
     @pytest.mark.parametrize("n", [2, 3])
     def test_wishart(self, n):
         with pytest.warns(UserWarning, match="Wishart distribution can currently not be used"):
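Both new tests build a batched reference density by wrapping scipy's single-point `logpdf` in `np.vectorize` with a gufunc signature: the signature declares the core shapes (vector value, vector mean, matrix cov, scalar output) and `np.vectorize` loops over all leading batch dimensions. The pattern in isolation (illustrative only):

```python
import numpy as np
import scipy.stats as st

ref_logp = np.vectorize(
    lambda value, mu, cov: st.multivariate_normal.logpdf(value, mu, cov),
    signature="(a),(a),(a,a)->()",  # core dims: vector, vector, matrix -> scalar
)

mu = np.arange(5 * 3 * 2).reshape(5, 3, 2) + 1.0
cov = np.eye(2) * mu[..., None]  # (5, 3, 2, 2): diagonal entries >= 1, so positive definite
value = mu - np.mean(mu)
print(ref_logp(value, mu, cov).shape)  # (5, 3): one logp per batch member
```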
@@ -1038,8 +1102,7 @@ def test_mv_normal_moment(self, mu, cov, size, expected):
         with pm.Model() as model:
             x = pm.MvNormal("x", mu=mu, cov=cov, size=size)
 
-        # MvNormal logp is only implemented for up to 2D variables
-        assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
+        assert_moment_is_expected(model, expected)
 
     @pytest.mark.parametrize(
         "shape, n_zerosum_axes, expected",
@@ -1109,8 +1172,7 @@ def test_mvstudentt_moment(self, nu, mu, cov, size, expected):
         with pm.Model() as model:
             x = pm.MvStudentT("x", nu=nu, mu=mu, scale=cov, size=size)
 
-        # MvStudentT logp is only implemented for up to 2D variables
-        assert_moment_is_expected(model, expected, check_finite_logp=x.ndim < 3)
+        assert_moment_is_expected(model, expected)
 
     @pytest.mark.parametrize(
         "mu, rowchol, colchol, size, expected",
@@ -1671,21 +1733,10 @@ def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):
         "check_pymc_params_match_rv_op",
         "check_pymc_draws_match_reference",
         "check_rv_size",
-        "check_errors",
         "check_mu_broadcast_helper",
+        "check_batched_nu",
     ]
 
-    def check_errors(self):
-        msg = "nu must be a scalar (ndim=0)."
-        with pm.Model():
-            with pytest.raises(ValueError, match=re.escape(msg)):
-                mvstudentt = pm.MvStudentT(
-                    "mvstudentt",
-                    nu=np.array([1, 2]),
-                    mu=np.ones(2),
-                    scale=np.full((2, 2), np.ones(2)),
-                )
-
     def check_mu_broadcast_helper(self):
         """Test that mu is broadcasted to the shape of cov"""
         x = pm.MvStudentT.dist(nu=4, mu=1, scale=np.eye(3))
@@ -1709,6 +1760,36 @@ def check_mu_broadcast_helper(self):
         #     mu = x.owner.inputs[4]
         #     assert mu.eval().shape == (10, 2, 3)
 
+    def check_batched_nu(self):
+        rng = np.random.default_rng(sum(map(ord, "batched_nu")))
+        a = (
+            pm.draw(
+                pm.MvStudentT.dist(nu=2, mu=[1, 2, 3], cov=np.eye(3), size=(5_000,)),
+                random_seed=rng,
+            )
+            .std(-1)
+            .mean()
+        )
+        b = (
+            pm.draw(
+                pm.MvStudentT.dist(nu=30, mu=[1, 2, 3], cov=np.eye(3), size=(5_000,)),
+                random_seed=rng,
+            )
+            .std(-1)
+            .mean()
+        )
+        ab = (
+            pm.draw(
+                pm.MvStudentT.dist(nu=[2, 30], mu=[1, 2, 3], cov=np.eye(3), size=(5_000, 2)),
+                random_seed=rng,
+            )
+            .std(-1)
+            .mean(0)
+        )
+
+        assert not np.isclose(ab[0], ab[1], rtol=0.3), "Test is not informative"
+        np.testing.assert_allclose([a, b], ab, rtol=0.1)
+
 
 class TestMvStudentTChol(BaseTestDistributionRandom):
     pymc_dist = pm.MvStudentT
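`check_batched_nu` is informative because `nu` controls the spread of the draws: the empirical standard deviation of a Student-t grows sharply as `nu` approaches 2, so a wrongly broadcast `nu=[2, 30]` would show up in the per-column sample stds. A quick univariate illustration of that contrast (an aside, not part of the test suite):

```python
import numpy as np

rng = np.random.default_rng(0)
for nu in (2, 30):
    draws = rng.standard_t(nu, size=5_000)
    # For nu > 2 the std is sqrt(nu / (nu - 2)); at nu = 2 the variance diverges,
    # so finite samples show a much larger empirical spread than at nu = 30.
    print(nu, draws.std())
```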