Skip to content

Commit

Permalink
Added Logprob for discrete minimum
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Dec 5, 2023
1 parent 69f1b32 commit 4962e85
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 31 deletions.
13 changes: 5 additions & 8 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ class MeasurableMaxNeg(Max):
MeasurableVariable.register(MeasurableMaxNeg)


class MeasurableMaxNegDiscrete(Max):
class DiscreteMeasurableMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


MeasurableVariable.register(MeasurableMaxNegDiscrete)
MeasurableVariable.register(DiscreteMeasurableMaxNeg)


@node_rewriter(tracks=[Max])
Expand Down Expand Up @@ -215,10 +215,7 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[

# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
if isinstance(base_rv.owner.op, RandomVariable):
measurable_min = MeasurableMaxNegDiscrete(list(axis))
else:
return None
measurable_min = DiscreteMeasurableMaxNeg(list(axis))
else:
measurable_min = MeasurableMaxNeg(list(axis))

Expand Down Expand Up @@ -253,8 +250,8 @@ def max_neg_logprob(op, values, base_rv, **kwargs):
return logprob


@_logprob.register(MeasurableMaxNegDiscrete)
def max_neg_logprob_discrete(op, values, base_rv, **kwargs):
@_logprob.register(DiscreteMeasurableMaxNeg)
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
Expand Down
2 changes: 1 addition & 1 deletion pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
[measurable_input] = measurable_inputs
[measurable_output] = node.outputs

# Do not apply rewrite to discrete variables
# Do not apply rewrite to discrete variables except for their addition and negation
if measurable_input.type.dtype.startswith("int"):
if not (
find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
Expand Down
40 changes: 18 additions & 22 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable
Expand Down Expand Up @@ -308,29 +309,24 @@ def expand(r):
}


def find_negated_var(base_var):
"""Make sure that the base variable involves a multiplication with -1"""
def find_negated_var(var):
"""Return a variable that is being multiplied by -1 or None otherwise."""

scalar_op = base_var.owner.op.scalar_op

if len(base_var.owner.inputs) == 2:
if base_var.owner.inputs[0] is None:
base_rv = base_var.owner.inputs[0]
scalar_constant = base_var.owner.inputs[1]
else:
base_rv = base_var.owner.inputs[1]
scalar_constant = base_var.owner.inputs[0]
else:
if (
not (var.owner)
and isinstance(var.owner.op, Elemwise)
and isinstance(var.owner.op.scalar_op, Mul)
):
return None

Check warning on line 320 in pymc/logprob/utils.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/utils.py#L320

Added line #L320 was not covered by tests

try:
if not (
isinstance(scalar_op, Mul)
and get_underlying_scalar_constant_value(scalar_constant) == -1
):
return None

except NotScalarConstantError:
if len(var.owner.inputs) != 2:
return None

return base_rv
inputs = var.owner.inputs
for mul_var, mul_const in (inputs, reversed(inputs)):
try:
if get_underlying_scalar_constant_value(mul_const) == -1:
return mul_var
except NotScalarConstantError:
continue

return None

0 comments on commit 4962e85

Please sign in to comment.