From 662edf0ba5d7073bccad331229595d0cc2b182dd Mon Sep 17 00:00:00 2001 From: Gokul D Date: Fri, 31 Mar 2023 23:44:15 +0530 Subject: [PATCH] Added ICDF for the continuous exponential distribution. --- pymc/distributions/continuous.py | 9 +++++++++ tests/distributions/test_continuous.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index e2368db2ee..87d7e61c29 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1363,6 +1363,15 @@ def logcdf(value, mu): msg="lam >= 0", ) + def icdf(value, mu): + res = -mu * pt.log(1 - value) + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + mu >= 0, + msg="mu >= 0", + ) + class Laplace(Continuous): r""" diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 5c0ad13de3..b23b39d366 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -438,6 +438,20 @@ def test_exponential(self): {"lam": Rplus}, lambda value, lam: st.expon.logcdf(value, 0, 1 / lam), ) + check_icdf( + pm.Exponential, + {"lam": Rplus}, + lambda q, lam: st.expon.ppf(q, loc=0, scale=1 / lam), + ) + # Custom logp / logcdf / icdf check for invalid parameters + invalid_dist = pm.Exponential.dist(lam=-1) + with pytensor.config.change_flags(mode=Mode("py")): + with pytest.raises(ParameterValueError): + logp(invalid_dist, np.array(0.5)).eval() + with pytest.raises(ParameterValueError): + logcdf(invalid_dist, np.array(0.5)).eval() + with pytest.raises(ParameterValueError): + icdf(invalid_dist, np.array(0.5)).eval() def test_laplace(self): check_logp(