diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index b17199aa74a..2d32ac4ce1c 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -166,11 +166,11 @@ class MeasurableMaxNeg(Max): MeasurableVariable.register(MeasurableMaxNeg) -class MeasurableMaxNegDiscrete(Max): +class DiscreteMeasurableMaxNeg(Max): """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables""" -MeasurableVariable.register(MeasurableMaxNegDiscrete) +MeasurableVariable.register(DiscreteMeasurableMaxNeg) @node_rewriter(tracks=[Max]) @@ -215,10 +215,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[ # distinguish measurable discrete and continuous (because logprob is different) if base_rv.owner.op.dtype.startswith("int"): - if isinstance(base_rv.owner.op, RandomVariable): - measurable_min = MeasurableMaxNegDiscrete(list(axis)) - else: - return None + measurable_min = DiscreteMeasurableMaxNeg(list(axis)) else: measurable_min = MeasurableMaxNeg(list(axis)) @@ -253,8 +250,8 @@ def max_neg_logprob(op, values, base_rv, **kwargs): return logprob -@_logprob.register(MeasurableMaxNegDiscrete) -def max_neg_logprob_discrete(op, values, base_rv, **kwargs): +@_logprob.register(DiscreteMeasurableMaxNeg) +def discrete_max_neg_logprob(op, values, base_rv, **kwargs): r"""Compute the log-likelihood graph for the `Max` operation. The formula that we use here is : diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 6cd61f58180..5e17ab0df9e 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -443,7 +443,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li [measurable_input] = measurable_inputs [measurable_output] = node.outputs - # Do not apply rewrite to discrete variables + # Do not apply rewrite to discrete variables except for their addition and negation if measurable_input.type.dtype.startswith("int"): if not ( find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add) diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index 0db767b3090..4c95a0b7d30 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -50,6 +50,7 @@ from pytensor.raise_op import CheckAndRaise from pytensor.scalar.basic import Mul from pytensor.tensor.basic import get_underlying_scalar_constant_value +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.random.op import RandomVariable from pytensor.tensor.variable import TensorVariable @@ -308,29 +309,24 @@ def expand(r): } -def find_negated_var(base_var): - """Make sure that the base variable involves a multiplication with -1""" +def find_negated_var(var): + """Return a variable that is being multiplied by -1 or None otherwise.""" - scalar_op = base_var.owner.op.scalar_op - - if len(base_var.owner.inputs) == 2: - if base_var.owner.inputs[0] is None: - base_rv = base_var.owner.inputs[0] - scalar_constant = base_var.owner.inputs[1] - else: - base_rv = base_var.owner.inputs[1] - scalar_constant = base_var.owner.inputs[0] - else: + if ( + not (var.owner) + and isinstance(var.owner.op, Elemwise) + and isinstance(var.owner.op.scalar_op, Mul) + ): return None - - try: - if not ( - isinstance(scalar_op, Mul) - and get_underlying_scalar_constant_value(scalar_constant) == -1 - ): - return None - - except NotScalarConstantError: + if len(var.owner.inputs) != 2: return None - return base_rv + inputs = var.owner.inputs + for mul_var, mul_const in (inputs, reversed(inputs)): + try: + if get_underlying_scalar_constant_value(mul_const) == -1: + return mul_var + except NotScalarConstantError: + continue + + return None