From 4c64eb9737dc460ae2abda8648eb3c85451b6ca0 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 30 Mar 2023 15:34:54 +0200 Subject: [PATCH] Don't use `check_parameters` in `get_tau_sigma`. The default use of `check_parameters` indicates that an expression can be replaced by -inf, if the constraints aren't met. Instead, if `can_be_replaced_by_ninf=False`, sampling would fail for negative tau/sigma. To avoid this, the conversion now returns the right value for positive tau or sigma, but negative images if the inputs were negative. The methods that then validate the paramters (such as logp, logcdf, random), can later catch this. --- pymc/distributions/continuous.py | 10 +++++---- tests/distributions/test_continuous.py | 31 ++++++++++++++------------ 2 files changed, 23 insertions(+), 18 deletions(-) diff --git a/pymc/distributions/continuous.py b/pymc/distributions/continuous.py index 8688cb04c2..e2368db2ee 100644 --- a/pymc/distributions/continuous.py +++ b/pymc/distributions/continuous.py @@ -234,24 +234,26 @@ def get_tau_sigma(tau=None, sigma=None): tau = 1.0 else: if isinstance(sigma, Variable): - sigma_ = check_parameters(sigma, sigma > 0, msg="sigma > 0") + # Keep tau negative, if sigma was negative, so that it will fail when used + tau = (sigma**-2.0) * pt.sgn(sigma) else: sigma_ = np.asarray(sigma) if np.any(sigma_ <= 0): raise ValueError("sigma must be positive") - tau = sigma_**-2.0 + tau = sigma_**-2.0 else: if sigma is not None: raise ValueError("Can't pass both tau and sigma") else: if isinstance(tau, Variable): - tau_ = check_parameters(tau, tau > 0, msg="tau > 0") + # Keep sigma negative, if tau was negative, so that it will fail when used + sigma = pt.abs(tau) ** (-0.5) * pt.sgn(tau) else: tau_ = np.asarray(tau) if np.any(tau_ <= 0): raise ValueError("tau must be positive") - sigma = tau_**-0.5 + sigma = tau_**-0.5 return floatX(tau), floatX(sigma) diff --git a/tests/distributions/test_continuous.py b/tests/distributions/test_continuous.py index 8857a569fa..5c0ad13de3 100644 --- a/tests/distributions/test_continuous.py +++ b/tests/distributions/test_continuous.py @@ -13,6 +13,7 @@ # limitations under the License. import functools as ft +import warnings import numpy as np import numpy.testing as npt @@ -890,24 +891,26 @@ def scipy_logp(value, mu, sigma, lower, upper): assert np.isinf(logp[2]) def test_get_tau_sigma(self): - sigma = np.array(2) - npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma]) + # Fail on warnings + with warnings.catch_warnings(): + warnings.simplefilter("error") - tau = np.array(2) - npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5]) + sigma = np.array(2) + npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma]) - tau, _ = get_tau_sigma(sigma=pt.constant(-2)) - with pytest.raises(ParameterValueError): - tau.eval() + tau = np.array(2) + npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5]) - _, sigma = get_tau_sigma(tau=pt.constant(-2)) - with pytest.raises(ParameterValueError): - sigma.eval() + tau, _ = get_tau_sigma(sigma=pt.constant(-2)) + npt.assert_almost_equal(tau.eval(), -0.25) - sigma = [1, 2] - npt.assert_almost_equal( - get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] - ) + _, sigma = get_tau_sigma(tau=pt.constant(-2)) + npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2)) + + sigma = [1, 2] + npt.assert_almost_equal( + get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)] + ) @pytest.mark.parametrize( "value,mu,sigma,nu,logp",