From c9d469780d0f447454d2418c3feb07d2637a67a8 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 13 Aug 2023 13:09:41 +0530 Subject: [PATCH] Test for discrete --- tests/logprob/test_order.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 76ed6d44e7d..1b600205bb2 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -43,7 +43,7 @@ import pymc as pm from pymc import logp -from pymc.logprob import conditional_logp +from pymc.logprob.abstract import _logcdf_helper, _logprob_helper from pymc.testing import assert_no_rvs @@ -156,4 +156,16 @@ def test_max_discrete(): x_max_value = pt.scalar("x_max_value") x_max_logprob = logp(x_max, x_max_value) - x_max_logprob.eval({x_max_value: 0.85}) + discrete_logprob = _logprob_helper(x, x_max_value) + discrete_logcdf = _logcdf_helper(x, x_max_value) + discrete_logcdf_prev = _logcdf_helper(x, x_max_value - 1) + n = x.size + discrete_logprob = pt.log((pt.exp(discrete_logcdf)) ** n - (pt.exp(discrete_logcdf_prev)) ** n) + + test_value = 0.85 + + np.testing.assert_allclose( + discrete_logprob.eval({x_max_value: test_value}), + (x_max_logprob.eval({x_max_value: test_value})), + rtol=1e-06, + )