Skip to content

Commit

Permalink
Default zero mu for MvNormal and MvStudentT
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 4, 2023
1 parent ce543da commit 602234b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class MvNormal(Continuous):
rv_op = multivariate_normal

@classmethod
def dist(cls, mu, cov=None, tau=None, chol=None, lower=True, **kwargs):
def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
mu = pt.as_tensor_variable(mu)
cov = quaddist_matrix(cov, chol, tau, lower)
# PyTensor is stricter about the shape of mu, than PyMC used to be
Expand Down Expand Up @@ -358,7 +358,7 @@ class MvStudentT(Continuous):
rv_op = mv_studentt

@classmethod
def dist(cls, nu, *, Sigma=None, mu, scale=None, tau=None, chol=None, lower=True, **kwargs):
def dist(cls, nu, *, Sigma=None, mu=0, scale=None, tau=None, chol=None, lower=True, **kwargs):
cov = kwargs.pop("cov", None)
if cov is not None:
warnings.warn(
Expand Down
12 changes: 10 additions & 2 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,7 +2300,11 @@ def test_mvnormal_no_cholesky_in_model_logp():


def test_mvnormal_mu_convenience():
"""Test that mu is broadcasted to the length of cov"""
"""Test that mu is broadcasted to the length of cov and provided a default of zero"""
x = pm.MvNormal.dist(cov=np.eye(3))
mu = x.owner.inputs[3]
np.testing.assert_allclose(mu.eval(), np.zeros((3,)))

x = pm.MvNormal.dist(mu=1, cov=np.eye(3))
mu = x.owner.inputs[3]
np.testing.assert_allclose(mu.eval(), np.ones((3,)))
Expand All @@ -2325,7 +2329,11 @@ def test_mvnormal_mu_convenience():


def test_mvstudentt_mu_convenience():
"""Test that mu is broadcasted to the length of scale"""
"""Test that mu is broadcasted to the length of scale and provided a default of zero"""
x = pm.MvStudentT.dist(nu=4, scale=np.eye(3))
mu = x.owner.inputs[4]
np.testing.assert_allclose(mu.eval(), np.zeros((3,)))

x = pm.MvStudentT.dist(nu=4, mu=1, scale=np.eye(3))
mu = x.owner.inputs[4]
np.testing.assert_allclose(mu.eval(), np.ones((3,)))
Expand Down

0 comments on commit 602234b

Please sign in to comment.