From 4598cd4d582d34dab8e87c62f37807a483ad5b27 Mon Sep 17 00:00:00 2001 From: Gokul D Date: Sat, 1 Apr 2023 02:05:08 +0530 Subject: [PATCH] Added ICDF for the Kumaraswamy distribution. --- pymc/distributions/continuous.py | 10 ++++++++++ tests/distributions/test_continuous.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index e2368db2ee8..9d5df670747 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -1287,6 +1287,16 @@ def logcdf(value, a, b): msg="a > 0, b > 0", ) + def icdf(value, a, b): + res = (1 - (1 - value) ** pt.reciprocal(b)) ** pt.reciprocal(a) + res = check_icdf_value(res, value) + return check_icdf_parameters( + res, + a > 0, + b > 0, + msg="a > 0, b > 0", + ) + class Exponential(PositiveContinuous): r""" diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 5c0ad13de38..42a618fc0fa 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -424,6 +424,22 @@ def scipy_log_cdf(value, a, b): {"a": Rplus, "b": Rplus}, scipy_log_cdf, ) + check_icdf( + pm.Kumaraswamy, + {"a": Rplus, "b": Rplus}, + lambda q, a, b: (1 - (1 - q) ** (1 / b)) ** (1 / a), + ) + + # Custom logp / logcdf / icdf check for invalid parameters + for a, b in ((-2, 0.5), (0.5, -2), (-2, -2)): + invalid_dist = pm.Kumaraswamy.dist(a=a, b=b) + with pytensor.config.change_flags(mode=Mode("py")): + with pytest.raises(ParameterValueError): + logp(invalid_dist, np.array(0.5)).eval() + with pytest.raises(ParameterValueError): + logcdf(invalid_dist, np.array(0.5)).eval() + with pytest.raises(ParameterValueError): + icdf(invalid_dist, np.array(0.5)).eval() def test_exponential(self): check_logp(