Skip to content

Commit

Permalink
logprob for maximum derived
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 26, 2023
1 parent 7e79bfa commit 93c5051
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions tests/logprob/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import re

import numpy as np
import pytensor
Expand All @@ -43,7 +44,7 @@
from scipy import stats

from pymc.distributions import Dirichlet
from pymc.logprob.joint_logprob import factorized_joint_logprob
from pymc.logprob.basic import conditional_logp
from tests.distributions.test_multivariate import dirichlet_logpdf


Expand All @@ -58,7 +59,7 @@ def test_specify_shape_logprob():

# 2. Request logp
x_vv = x_rv.clone()
[x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()
[x_logp] = conditional_logp({x_rv: x_vv}).values()

# 3. Test logp
x_logp_fn = pytensor.function([last_dim, x_vv], x_logp)
Expand All @@ -80,17 +81,19 @@ def test_assert_logprob():
rv = pt.random.normal()
assert_op = Assert("Test assert")
# Example: Add assert that rv must be positive
assert_rv = assert_op(rv > 0, rv)
assert_rv = assert_op(rv, rv > 0)
assert_rv.name = "assert_rv"

assert_vv = assert_rv.clone()
assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]
assert_logp = conditional_logp({assert_rv: assert_vv})[assert_vv]

# Check valid value is correct and doesn't raise
# Since here the value to the rv satisfies the condition, no error is raised.
valid_value = 3.0
with pytest.raises(AssertionError, match="Test assert"):
assert_logp.eval({assert_vv: valid_value})
np.testing.assert_allclose(
assert_logp.eval({assert_vv: valid_value}),
stats.norm.logpdf(valid_value),
)

# Check invalid value
# Since here the value to the rv is negative, an exception is raised as the condition is not met
Expand Down

0 comments on commit 93c5051

Please sign in to comment.