diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index bd332d9f75..c8d2a74a41 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -42,7 +42,9 @@ from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.scalar.basic import Mul
+from pytensor.tensor.basic import get_underlying_scalar_constant_value
 from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Max
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.var import TensorVariable
 
@@ -130,14 +132,15 @@ def max_logprob(op, values, base_rv, **kwargs):
 
 
 class MeasurableMaxNeg(Max):
-    """A placeholder used to specify a log-likelihood for a min sub-graph."""
+    """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
+    This shows up in the graph of min, which is neg(max(neg(x)))."""
 
 
 MeasurableVariable.register(MeasurableMaxNeg)
 
 
 @node_rewriter(tracks=[Max])
-def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
+def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
 
     if rv_map_feature is None:
@@ -154,12 +157,19 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
     if not rv_map_feature.request_measurable(node.inputs):
         return None
 
-    # Min is the Max of the negation of the same distribution. Hence, op must be Elemiwise
+    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
     if not isinstance(base_var.owner.op, Elemwise):
         return None
 
-    # negation is -1*(rv). Hence the scalar_op must be Mul
-    if not isinstance(base_var.owner.op.scalar_op, Mul):
+    # negation is rv * (-1). Hence the scalar_op must be Mul
+    try:
+        if not (
+            isinstance(base_var.owner.op.scalar_op, Mul)
+            and len(base_var.owner.inputs) == 2
+            and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
+        ):
+            return None
+    except NotScalarConstantError:
         return None
 
     base_rv = base_var.owner.inputs[0]
@@ -191,8 +201,8 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
 
 
 measurable_ir_rewrites_db.register(
-    "find_measurable_min",
-    find_measurable_min,
+    "find_measurable_max_neg",
+    find_measurable_max_neg,
     "basic",
     "min",
 )
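For context on the new guard: `get_underlying_scalar_constant_value` either returns the constant underlying a tensor or raises `NotScalarConstantError`, which is why the check is wrapped in `try`/`except` rather than a plain `if`, and the extra `len(base_var.owner.inputs) == 2` condition ensures the `Mul` really is a binary `rv * (-1)` rather than a product of several terms. A minimal sketch of both behaviors (not part of the diff, assuming a recent PyTensor; the module paths are taken from the imports added above):

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError

# A literal constant: the underlying scalar value is recoverable.
neg_one = pt.as_tensor_variable(-1)
print(get_underlying_scalar_constant_value(neg_one))  # -1

# A symbolic input: there is no underlying constant, so the helper raises.
y = pt.scalar("y")
try:
    get_underlying_scalar_constant_value(y)
except NotScalarConstantError:
    # This is the path on which find_measurable_max_neg returns None:
    # a Mul by a non-constant is not a negation the rewrite can handle.
    print("not a scalar constant")
```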
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index 27b700ba02..ff51199491 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -58,104 +58,87 @@ def test_argmax():
 
 
 @pytest.mark.parametrize(
-    "if_max",
+    "pt_op",
     [
-        True,
-        False,
+        pt.max,
+        pt.min,
     ],
 )
-def test_non_iid_fails(if_max):
+def test_non_iid_fails(pt_op):
     """Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
     x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
     x.name = "x"
-    if if_max == True:
-        x_m = pt.max(x, axis=-1)
-        x_m_value = pt.vector("x_max_value")
-    else:
-        x_m = pt.min(x, axis=-1)
-        x_m_value = pt.vector("x_min_value")
+    x_m = pt_op(x, axis=-1)
+    x_m_value = pt.vector("x_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
 
 
 @pytest.mark.parametrize(
-    "if_max",
-    [True, False],
+    "pt_op",
+    [
+        pt.max,
+        pt.min,
+    ],
 )
-def test_non_rv_fails(if_max):
+def test_non_rv_fails(pt_op):
     """Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
     x = pt.exp(pt.random.beta(0, 1, size=(3,)))
     x.name = "x"
-    if if_max == True:
-        x_m = pt.max(x, axis=-1)
-        x_m_value = pt.vector("x_max_value")
-    else:
-        x_m = pt.min(x, axis=-1)
-        x_m_value = pt.vector("x_min_value")
+    x_m = pt_op(x, axis=-1)
+    x_m_value = pt.vector("x_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
 
 
 @pytest.mark.parametrize(
-    "if_max",
+    "pt_op",
     [
-        True,
-        False,
+        pt.max,
+        pt.min,
     ],
 )
-def test_multivariate_rv_fails(if_max):
+def test_multivariate_rv_fails(pt_op):
     _alpha = pt.scalar()
     _k = pt.iscalar()
     x = pm.StickBreakingWeights.dist(_alpha, _k)
     x.name = "x"
-    if if_max == True:
-        x_m = pt.max(x, axis=-1)
-        x_m_value = pt.vector("x_max_value")
-    else:
-        x_m = pt.min(x, axis=-1)
-        x_m_value = pt.vector("x_min_value")
+    x_m = pt_op(x, axis=-1)
+    x_m_value = pt.vector("x_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
 
 
 @pytest.mark.parametrize(
-    "if_max",
+    "pt_op",
     [
-        True,
-        False,
+        pt.max,
+        pt.min,
     ],
 )
-def test_categorical(if_max):
+def test_categorical(pt_op):
     """Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
     x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
     x.name = "x"
-    if if_max == True:
-        x_m = pt.max(x, axis=-1)
-        x_m_value = pt.vector("x_max_value")
-    else:
-        x_m = pt.min(x, axis=-1)
-        x_m_value = pt.vector("x_min_value")
+    x_m = pt_op(x, axis=-1)
+    x_m_value = pt.vector("x_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
 
 
 @pytest.mark.parametrize(
-    "if_max",
+    "pt_op",
     [
-        True,
-        False,
+        pt.max,
+        pt.min,
    ],
 )
-def test_non_supp_axis(if_max):
+def test_non_supp_axis(pt_op):
     """Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
     x = pt.random.normal(0, 1, size=(3, 3))
     x.name = "x"
-    if if_max == True:
-        x_m = pt.max(x, axis=-1)
-        x_m_value = pt.vector("x_max_value")
-    else:
-        x_m = pt.min(x, axis=-1)
-        x_m_value = pt.vector("x_min_value")
+    x_m = pt_op(x, axis=-1)
+    x_m_value = pt.vector("x_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_max_logprob = logp(x_m, x_m_value)
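As a complement to these rejection tests, here is a quick positive check of the path that the renamed `find_measurable_max_neg` enables: for n i.i.d. variables with cdf F and density f, the log-density of the minimum should follow the order-statistic formula log n + (n - 1) log(1 - F(v)) + log f(v). The snippet below is an illustrative sketch, not part of the diff; it assumes a PyMC build with this patch applied and uses the standard-normal symmetry 1 - F(v) = F(-v):

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc import logp

n = 5
x = pm.Normal.dist(0, 1, shape=(n,))
x_min = pt.min(x, axis=-1)

value = pt.scalar("value")
x_min_logp = logp(x_min, value)  # routed through find_measurable_max_neg

v = 0.3
ref = pm.Normal.dist(0, 1)
expected = (
    np.log(n)
    + (n - 1) * pm.logcdf(ref, -v).eval()  # log(1 - F(v)) by symmetry
    + pm.logp(ref, v).eval()
)
np.testing.assert_allclose(x_min_logp.eval({value: v}), expected, rtol=1e-6)
```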