From 419af0688353292c0a356cddbb9271737e89a723 Mon Sep 17 00:00:00 2001 From: Luke Lewis-Borrell <35955390+LukeLB@users.noreply.github.com> Date: Thu, 26 Oct 2023 06:24:39 +0100 Subject: [PATCH] Support logp derivation of `power(base, rv)` (#6962) Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> Co-authored-by: Luke LB --- pymc/logprob/transforms.py | 25 ++++++++++++-- tests/logprob/test_transforms.py | 56 ++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 588f350fb7..c2f038ad59 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -127,7 +127,7 @@ cleanup_ir_rewrites_db, measurable_ir_rewrites_db, ) -from pymc.logprob.utils import check_potential_measurability +from pymc.logprob.utils import CheckParameterValue, check_potential_measurability class TransformedVariable(Op): @@ -617,6 +617,21 @@ def measurable_special_exp_to_exp(fgraph, node): return [1 / (1 + pt.exp(-inp))] +@node_rewriter([pow]) +def measurable_power_exponent_to_exp(fgraph, node): + """Convert power(base, rv) of `MeasurableVariable`s to exp(log(base) * rv) form.""" + base, inp_exponent = node.inputs + + # When the base is measurable we have `power(rv, exponent)`, which should be handled by `PowerTransform` and needs no further rewrite. + # Here we change only the cases where exponent is measurable `power(base, rv)` which is not supported by the `PowerTransform` + if check_potential_measurability([base], fgraph.preserve_rv_mappings.rv_values.keys()): + return None + + base = CheckParameterValue("base >= 0")(base, pt.all(pt.ge(base, 0.0))) + + return [pt.exp(pt.log(base) * inp_exponent)] + + @node_rewriter( [ exp, @@ -693,7 +708,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li try: (power,) = other_inputs power = pt.get_underlying_scalar_constant_value(power).item() - # Power needs to be a constant + # Power needs to be a constant, if not then proceed to the other case power(base, rv) except NotScalarConstantError: return None transform_inputs = (measurable_input, power) @@ -769,6 +784,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) +measurable_ir_rewrites_db.register( + "measurable_power_expotent_to_exp", + measurable_power_exponent_to_exp, + "basic", + "transform", +) measurable_ir_rewrites_db.register( "find_measurable_transforms", diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index dfb9bc8770..32924a37d2 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -72,6 +72,7 @@ TransformValuesMapping, TransformValuesRewrite, ) +from pymc.logprob.utils import ParameterValueError from pymc.testing import Rplusbig, Vector, assert_no_rvs from tests.distributions.test_transform import check_jacobian_det @@ -1159,6 +1160,61 @@ def test_special_log_exp_transforms(transform): assert equal_computations([logp_test], [logp_ref]) +def test_measurable_power_exponent_with_constant_base(): + # test power(2, rv) = exp2(rv) + # test negative base fails + x_rv_pow = pt.pow(2, pt.random.normal()) + x_rv_exp2 = pt.exp2(pt.random.normal()) + + x_vv_pow = x_rv_pow.clone() + x_vv_exp2 = x_rv_exp2.clone() + + x_logp_fn_pow = pytensor.function([x_vv_pow], pt.sum(logp(x_rv_pow, x_vv_pow))) + x_logp_fn_exp2 = pytensor.function([x_vv_exp2], pt.sum(logp(x_rv_exp2, x_vv_exp2))) + + np.testing.assert_allclose(x_logp_fn_pow(0.1), x_logp_fn_exp2(0.1)) + + with pytest.raises(ParameterValueError, match="base >= 0"): + x_rv_neg = pt.pow(-2, pt.random.normal()) + x_vv_neg = x_rv_neg.clone() + logp(x_rv_neg, x_vv_neg) + + +def test_measurable_power_exponent_with_variable_base(): + # test with RV when logp(<0) we raise error + base_rv = pt.random.normal([2]) + x_raw_rv = pt.random.normal() + x_rv = pt.power(base_rv, x_raw_rv) + + x_rv.name = "x" + base_rv.name = "base" + base_vv = base_rv.clone() + x_vv = x_rv.clone() + + res = conditional_logp({base_rv: base_vv, x_rv: x_vv}) + x_logp = res[x_vv] + logp_vals_fn = pytensor.function([base_vv, x_vv], x_logp) + + with pytest.raises(ParameterValueError, match="base >= 0"): + logp_vals_fn(np.array([-2]), np.array([2])) + + +def test_base_exponent_non_measurable(): + # test dual sources of measuravility fails + base_rv = pt.random.normal([2]) + x_raw_rv = pt.random.normal() + x_rv = pt.power(base_rv, x_raw_rv) + x_rv.name = "x" + + x_vv = x_rv.clone() + + with pytest.raises( + RuntimeError, + match="The logprob terms of the following value variables could not be derived: {x}", + ): + conditional_logp({x_rv: x_vv}) + + @pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])]) @pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])]) def test_multivariate_rv_transform(shift, scale):