Skip to content

Commit

Permalink
Don't use check_parameters in get_tau_sigma.
Browse files Browse the repository at this point in the history
The default use of `check_parameters` indicates that an expression can be replaced by -inf, if the constraints aren't met. Instead, if `can_be_replaced_by_ninf=False`, sampling would fail for negative tau/sigma.

To avoid this, the conversion now returns the right value for positive tau or sigma, but negative images if the inputs were negative. The methods that then validate the paramters (such as logp, logcdf, random), can later catch this.
  • Loading branch information
ricardoV94 committed Mar 31, 2023
1 parent 3f2a1da commit 4c64eb9
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 18 deletions.
10 changes: 6 additions & 4 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,24 +234,26 @@ def get_tau_sigma(tau=None, sigma=None):
tau = 1.0
else:
if isinstance(sigma, Variable):
sigma_ = check_parameters(sigma, sigma > 0, msg="sigma > 0")
# Keep tau negative, if sigma was negative, so that it will fail when used
tau = (sigma**-2.0) * pt.sgn(sigma)
else:
sigma_ = np.asarray(sigma)
if np.any(sigma_ <= 0):
raise ValueError("sigma must be positive")
tau = sigma_**-2.0
tau = sigma_**-2.0

else:
if sigma is not None:
raise ValueError("Can't pass both tau and sigma")
else:
if isinstance(tau, Variable):
tau_ = check_parameters(tau, tau > 0, msg="tau > 0")
# Keep sigma negative, if tau was negative, so that it will fail when used
sigma = pt.abs(tau) ** (-0.5) * pt.sgn(tau)
else:
tau_ = np.asarray(tau)
if np.any(tau_ <= 0):
raise ValueError("tau must be positive")
sigma = tau_**-0.5
sigma = tau_**-0.5

return floatX(tau), floatX(sigma)

Expand Down
31 changes: 17 additions & 14 deletions tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import functools as ft
import warnings

import numpy as np
import numpy.testing as npt
Expand Down Expand Up @@ -890,24 +891,26 @@ def scipy_logp(value, mu, sigma, lower, upper):
assert np.isinf(logp[2])

def test_get_tau_sigma(self):
sigma = np.array(2)
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])
# Fail on warnings
with warnings.catch_warnings():
warnings.simplefilter("error")

tau = np.array(2)
npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])
sigma = np.array(2)
npt.assert_almost_equal(get_tau_sigma(sigma=sigma), [1.0 / sigma**2, sigma])

tau, _ = get_tau_sigma(sigma=pt.constant(-2))
with pytest.raises(ParameterValueError):
tau.eval()
tau = np.array(2)
npt.assert_almost_equal(get_tau_sigma(tau=tau), [tau, tau**-0.5])

_, sigma = get_tau_sigma(tau=pt.constant(-2))
with pytest.raises(ParameterValueError):
sigma.eval()
tau, _ = get_tau_sigma(sigma=pt.constant(-2))
npt.assert_almost_equal(tau.eval(), -0.25)

sigma = [1, 2]
npt.assert_almost_equal(
get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
)
_, sigma = get_tau_sigma(tau=pt.constant(-2))
npt.assert_almost_equal(sigma.eval(), -np.sqrt(1 / 2))

sigma = [1, 2]
npt.assert_almost_equal(
get_tau_sigma(sigma=sigma), [1.0 / np.array(sigma) ** 2, np.array(sigma)]
)

@pytest.mark.parametrize(
"value,mu,sigma,nu,logp",
Expand Down

0 comments on commit 4c64eb9

Please sign in to comment.