
Cast ZeroSumNormal shape operations to config.floatX #6889

Merged: 5 commits into pymc-devs:main, Sep 5, 2023

Conversation

@thomasjpfan (Contributor) commented Sep 2, 2023

What is this PR about?
This PR prevents ZeroSumNormal from upcasting to float64 when pytensor is configured with floatX = float32.
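
As background, here is a minimal sketch of the upcasting mechanism (illustrative only, not the code changed in this PR): pytensor shape operations yield int64 scalars, and mixing those into float32 arithmetic promotes the result to float64, which an explicit cast to config.floatX avoids.

import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float32")
n = x.shape[0]  # symbolic int64 scalar derived from the shape

# int64 mixed into float32 arithmetic promotes the result to float64
print((x / pt.sqrt(n)).dtype)  # float64

# casting the shape term to config.floatX keeps the graph in float32
n = pt.cast(n, pytensor.config.floatX)
print((x / pt.sqrt(n)).dtype)  # float32 when floatX == "float32"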

Checklist

Bugfixes


📚 Documentation preview 📚: https://pymc--6889.org.readthedocs.build/en/6889/

codecov bot commented Sep 2, 2023

Codecov Report

Merging #6889 (2d1546d) into main (dfb05b6) will decrease coverage by 0.37%.
Report is 2 commits behind head on main.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##             main    #6889      +/-   ##
==========================================
- Coverage   92.05%   91.69%   -0.37%     
==========================================
  Files          96      100       +4     
  Lines       16446    16851     +405     
==========================================
+ Hits        15140    15451     +311     
- Misses       1306     1400      +94     
Files Changed                        Coverage Δ
pymc/distributions/multivariate.py   92.55% <100.00%> (ø)
pymc/distributions/transforms.py     99.37% <100.00%> (+<0.01%) ⬆️

... and 12 files with indirect coverage changes

@ricardoV94 (Member) left a comment:

Thanks, just a small suggestion to make the test faster.

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        pm.sample(1, chains=1, tune=1)
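
Note: with warn_float64="raise", pytensor raises as soon as a float64 variable is created, so any accidental upcast anywhere in the graph fails the test immediately rather than passing silently.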
@ricardoV94 (Member): Should be enough to call model.logp()

@thomasjpfan (Contributor, Author):

I do not think it is enough, because logp does not call ZeroSumTransform.backward. pm.sample is a way to test both ZeroSumTransform.backward and logp. I added a comment about this interaction in 7906292 (#6889).

Note that the new test is mostly a non-regression test for #6886.

@thomasjpfan (Contributor, Author): Ah, I see how m.logp works. Thanks for the suggestion.

@ricardoV94 changed the title from "FIX Prevents ZeroSumNormal from upcasting float64" to "Cast ZeroSumNormal shape operations to config.floatX" on Sep 2, 2023
@ricardoV94 (Member) commented Sep 4, 2023

I suggested just calling model.logp because it's enough to identify the original problem and it's much faster than compiling and sampling a model. This test is too heavy for the change being made.

@thomasjpfan (Contributor, Author): Okay, I updated the test to use logp.

import numpy as np
import pymc as pm
import pytensor

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        zsn = pm.ZeroSumNormal("b", sigma=1, shape=(2,))
        pm.logp(zsn, value=np.zeros((2,)))
@ricardoV94 (Member) commented Sep 4, 2023:

Sorry I wasn't clear. I meant you should call model.logp().

import pymc as pm
import pytensor 

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model() as m:
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
    m.logp()

@ricardoV94 (Member): That should call both backward and forward.
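
A minimal sketch of why building and evaluating model.logp suffices, assuming PyMC's standard Model.initial_point and Model.compile_logp APIs (these names are not from this thread):

import pymc as pm
import pytensor

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model() as m:
        pm.ZeroSumNormal("b", sigma=1, shape=(2,))
    # initial_point() is keyed by the unconstrained value variables; the
    # logp graph maps them back through ZeroSumTransform.backward, so a
    # float64 upcast inside the transform raises under these flags.
    print(m.compile_logp()(m.initial_point()))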

@ricardoV94 (Member) left a comment:

Thanks!!!

@ricardoV94 merged commit 5bba69a into pymc-devs:main on Sep 5, 2023
21 checks passed
@ricardoV94 (Member): Thanks @thomasjpfan!


Successfully merging this pull request may close these issues:

BUG: ZeroSumNormal unintentionally upcasts to float64