Skip to content

Commit

Permalink
Cast ZeroSumNormal shape operations to config.floatX (#6889)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasjpfan authored Sep 5, 2023
1 parent e6e0fed commit 5bba69a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
4 changes: 2 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,8 +2783,8 @@ def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
n_zerosum_axes = op.ndim_supp

_deg_free_support_shape = pt.inc_subtensor(shape[-n_zerosum_axes:], -1)
_full_size = pt.prod(shape)
_degrees_of_freedom = pt.prod(_deg_free_support_shape)
_full_size = pm.floatX(pt.prod(shape))
_degrees_of_freedom = pm.floatX(pt.prod(_deg_free_support_shape))

zerosums = [
pt.all(pt.isclose(pt.mean(value, axis=-axis - 1), 0, atol=1e-9))
Expand Down
6 changes: 4 additions & 2 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pytensor.graph import Op
from pytensor.tensor import TensorVariable

import pymc as pm

from pymc.logprob.transforms import (
CircularTransform,
IntervalTransform,
Expand Down Expand Up @@ -330,7 +332,7 @@ def log_jac_det(self, value, *rv_inputs):


def extend_axis(array, axis):
n = array.shape[axis] + 1
n = pm.floatX(array.shape[axis] + 1)
sum_vals = array.sum(axis, keepdims=True)
norm = sum_vals / (pt.sqrt(n) + n)
fill_val = norm - sum_vals / pt.sqrt(n)
Expand All @@ -342,7 +344,7 @@ def extend_axis(array, axis):
def extend_axis_rev(array, axis):
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

n = array.shape[normalized_axis]
n = pm.floatX(array.shape[normalized_axis])
last = pt.take(array, [-1], axis=normalized_axis)

sum_vals = -last * pt.sqrt(n)
Expand Down
6 changes: 6 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,12 @@ def logp_norm(value, sigma, axes):

np.testing.assert_allclose(zsn_logp, mvn_logp)

def test_does_not_upcast_to_float64(self):
    """Regression test: building the ZeroSumNormal logp graph must not create float64 ops when floatX="float32".

    ``warn_float64="raise"`` makes PyTensor raise as soon as any float64
    tensor is created, so simply constructing the model and calling
    ``m.logp()`` without an exception is the whole assertion — there is
    deliberately no explicit ``assert`` here.
    """
    # NOTE(review): guards the floatX casts added in commit 5bba69a (#6889) —
    # the shape products in zerosumnormal_logp and the `n` terms in
    # extend_axis/extend_axis_rev, which previously upcast to float64.
    with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
        with pm.Model() as m:
            # shape=(2,) is the minimal case exercising the zero-sum axis
            pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        m.logp()


class TestMvStudentTCov(BaseTestDistributionRandom):
def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):
Expand Down

0 comments on commit 5bba69a

Please sign in to comment.