Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cast ZeroSumNormal shape operations to config.floatX #6889

Merged
merged 5 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,8 +2783,8 @@
n_zerosum_axes = op.ndim_supp

_deg_free_support_shape = pt.inc_subtensor(shape[-n_zerosum_axes:], -1)
_full_size = pt.prod(shape)
_degrees_of_freedom = pt.prod(_deg_free_support_shape)
_full_size = pm.floatX(pt.prod(shape))
_degrees_of_freedom = pm.floatX(pt.prod(_deg_free_support_shape))

Check warning on line 2787 in pymc/distributions/multivariate.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/multivariate.py#L2786-L2787

Added lines #L2786 - L2787 were not covered by tests

zerosums = [
pt.all(pt.isclose(pt.mean(value, axis=-axis - 1), 0, atol=1e-9))
Expand Down
6 changes: 4 additions & 2 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from pytensor.graph import Op
from pytensor.tensor import TensorVariable

import pymc as pm

from pymc.logprob.transforms import (
CircularTransform,
IntervalTransform,
Expand Down Expand Up @@ -330,7 +332,7 @@


def extend_axis(array, axis):
n = array.shape[axis] + 1
n = pm.floatX(array.shape[axis] + 1)

Check warning on line 335 in pymc/distributions/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/transforms.py#L335

Added line #L335 was not covered by tests
sum_vals = array.sum(axis, keepdims=True)
norm = sum_vals / (pt.sqrt(n) + n)
fill_val = norm - sum_vals / pt.sqrt(n)
Expand All @@ -342,7 +344,7 @@
def extend_axis_rev(array, axis):
normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

n = array.shape[normalized_axis]
n = pm.floatX(array.shape[normalized_axis])

Check warning on line 347 in pymc/distributions/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/distributions/transforms.py#L347

Added line #L347 was not covered by tests
last = pt.take(array, [-1], axis=normalized_axis)

sum_vals = -last * pt.sqrt(n)
Expand Down
6 changes: 6 additions & 0 deletions tests/distributions/test_multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,6 +1634,12 @@ def logp_norm(value, sigma, axes):

np.testing.assert_allclose(zsn_logp, mvn_logp)

def test_does_not_upcast_to_float64(self):
with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
with pm.Model():
pm.ZeroSumNormal("b", sigma=1, shape=(2,))
pm.sample(1, chains=1, tune=1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be enough to call model.logp()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not think it is enough because logp does not call ZeroSumTransform.backward. pm.sample is a way to test both ZeroSumTransform.backward and logp. I added a comment about this interaction in 7906292 (#6889)

Note that the new test is mostly a non-regression test for #6886.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see how m.logp works. Thanks for the suggestion.



class TestMvStudentTCov(BaseTestDistributionRandom):
def mvstudentt_rng_fn(self, size, nu, mu, scale, rng):
Expand Down
Loading