Skip to content

Commit

Permalink
Support logp derivation of power(base, rv) (#6962)
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <[email protected]>

Co-authored-by: Luke LB <[email protected]>
  • Loading branch information
LukeLB and Luke LB authored Oct 26, 2023
1 parent c53277b commit 419af06
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 2 deletions.
25 changes: 23 additions & 2 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
cleanup_ir_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.logprob.utils import check_potential_measurability
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability


class TransformedVariable(Op):
Expand Down Expand Up @@ -617,6 +617,21 @@ def measurable_special_exp_to_exp(fgraph, node):
return [1 / (1 + pt.exp(-inp))]


@node_rewriter([pow])
def measurable_power_exponent_to_exp(fgraph, node):
"""Convert power(base, rv) of `MeasurableVariable`s to exp(log(base) * rv) form."""
base, inp_exponent = node.inputs

# When the base is measurable we have `power(rv, exponent)`, which should be handled by `PowerTransform` and needs no further rewrite.
# Here we change only the cases where exponent is measurable `power(base, rv)` which is not supported by the `PowerTransform`
if check_potential_measurability([base], fgraph.preserve_rv_mappings.rv_values.keys()):
return None

base = CheckParameterValue("base >= 0")(base, pt.all(pt.ge(base, 0.0)))

return [pt.exp(pt.log(base) * inp_exponent)]


@node_rewriter(
[
exp,
Expand Down Expand Up @@ -693,7 +708,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
try:
(power,) = other_inputs
power = pt.get_underlying_scalar_constant_value(power).item()
# Power needs to be a constant
# Power needs to be a constant, if not then proceed to the other case power(base, rv)
except NotScalarConstantError:
return None
transform_inputs = (measurable_input, power)
Expand Down Expand Up @@ -769,6 +784,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_power_expotent_to_exp",
measurable_power_exponent_to_exp,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"find_measurable_transforms",
Expand Down
56 changes: 56 additions & 0 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
TransformValuesMapping,
TransformValuesRewrite,
)
from pymc.logprob.utils import ParameterValueError
from pymc.testing import Rplusbig, Vector, assert_no_rvs
from tests.distributions.test_transform import check_jacobian_det

Expand Down Expand Up @@ -1159,6 +1160,61 @@ def test_special_log_exp_transforms(transform):
assert equal_computations([logp_test], [logp_ref])


def test_measurable_power_exponent_with_constant_base():
# test power(2, rv) = exp2(rv)
# test negative base fails
x_rv_pow = pt.pow(2, pt.random.normal())
x_rv_exp2 = pt.exp2(pt.random.normal())

x_vv_pow = x_rv_pow.clone()
x_vv_exp2 = x_rv_exp2.clone()

x_logp_fn_pow = pytensor.function([x_vv_pow], pt.sum(logp(x_rv_pow, x_vv_pow)))
x_logp_fn_exp2 = pytensor.function([x_vv_exp2], pt.sum(logp(x_rv_exp2, x_vv_exp2)))

np.testing.assert_allclose(x_logp_fn_pow(0.1), x_logp_fn_exp2(0.1))

with pytest.raises(ParameterValueError, match="base >= 0"):
x_rv_neg = pt.pow(-2, pt.random.normal())
x_vv_neg = x_rv_neg.clone()
logp(x_rv_neg, x_vv_neg)


def test_measurable_power_exponent_with_variable_base():
# test with RV when logp(<0) we raise error
base_rv = pt.random.normal([2])
x_raw_rv = pt.random.normal()
x_rv = pt.power(base_rv, x_raw_rv)

x_rv.name = "x"
base_rv.name = "base"
base_vv = base_rv.clone()
x_vv = x_rv.clone()

res = conditional_logp({base_rv: base_vv, x_rv: x_vv})
x_logp = res[x_vv]
logp_vals_fn = pytensor.function([base_vv, x_vv], x_logp)

with pytest.raises(ParameterValueError, match="base >= 0"):
logp_vals_fn(np.array([-2]), np.array([2]))


def test_base_exponent_non_measurable():
# test dual sources of measuravility fails
base_rv = pt.random.normal([2])
x_raw_rv = pt.random.normal()
x_rv = pt.power(base_rv, x_raw_rv)
x_rv.name = "x"

x_vv = x_rv.clone()

with pytest.raises(
RuntimeError,
match="The logprob terms of the following value variables could not be derived: {x}",
):
conditional_logp({x_rv: x_vv})


@pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])])
@pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])])
def test_multivariate_rv_transform(shift, scale):
Expand Down

0 comments on commit 419af06

Please sign in to comment.