diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index df175af4b3..864fa2ad9c 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -2783,8 +2783,8 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs): n_zerosum_axes = op.ndim_supp _deg_free_support_shape = pt.inc_subtensor(shape[-n_zerosum_axes:], -1) - _full_size = pt.prod(shape) - _degrees_of_freedom = pt.prod(_deg_free_support_shape) + _full_size = pm.floatX(pt.prod(shape)) + _degrees_of_freedom = pm.floatX(pt.prod(_deg_free_support_shape)) zerosums = [ pt.all(pt.isclose(pt.mean(value, axis=-axis - 1), 0, atol=1e-9)) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index b873eba235..f3aff75fed 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -22,6 +22,8 @@ from pytensor.graph import Op from pytensor.tensor import TensorVariable +import pymc as pm + from pymc.logprob.transforms import ( CircularTransform, IntervalTransform, @@ -330,7 +332,7 @@ def log_jac_det(self, value, *rv_inputs): def extend_axis(array, axis): - n = array.shape[axis] + 1 + n = pm.floatX(array.shape[axis] + 1) sum_vals = array.sum(axis, keepdims=True) norm = sum_vals / (pt.sqrt(n) + n) fill_val = norm - sum_vals / pt.sqrt(n) @@ -342,7 +344,7 @@ def extend_axis(array, axis): def extend_axis_rev(array, axis): normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] - n = array.shape[normalized_axis] + n = pm.floatX(array.shape[normalized_axis]) last = pt.take(array, [-1], axis=normalized_axis) sum_vals = -last * pt.sqrt(n) diff --git a/tests/distributions/test_multivariate.py b/tests/distributions/test_multivariate.py index fbe5c09fe5..e7a205c715 100644 --- a/tests/distributions/test_multivariate.py +++ b/tests/distributions/test_multivariate.py @@ -1634,6 +1634,12 @@ def logp_norm(value, sigma, axes): np.testing.assert_allclose(zsn_logp, mvn_logp) + def test_does_not_upcast_to_float64(self): + with pytensor.config.change_flags(floatX="float32", warn_float64="raise"): + with pm.Model(): + pm.ZeroSumNormal("b", sigma=1, shape=(2,)) + pm.sample(1, chains=1, tune=1) + class TestMvStudentTCov(BaseTestDistributionRandom): def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):