Skip to content

Commit

Permalink
adding changes final changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jan 3, 2024
1 parent 81dfd8c commit a0e811d
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 14 deletions.
12 changes: 5 additions & 7 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ class MeasurableMaxNeg(Max):
MeasurableVariable.register(MeasurableMaxNeg)


class DiscreteMeasurableMaxNeg(Max):
class MeasurableDiscreteMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


MeasurableVariable.register(DiscreteMeasurableMaxNeg)
MeasurableVariable.register(MeasurableDiscreteMaxNeg)


@node_rewriter(tracks=[Max])
Expand Down Expand Up @@ -215,13 +215,11 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[

# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
measurable_min = DiscreteMeasurableMaxNeg(list(axis))
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
else:
measurable_min = MeasurableMaxNeg(list(axis))

min_rv_node = measurable_min.make_node(base_rv)
min_rv = min_rv_node.outputs
return min_rv
return measurable_min.make_node(base_rv).outputs


measurable_ir_rewrites_db.register(
Expand Down Expand Up @@ -250,7 +248,7 @@ def max_neg_logprob(op, values, base_rv, **kwargs):
return logprob


@_logprob.register(DiscreteMeasurableMaxNeg)
@_logprob.register(MeasurableDiscreteMaxNeg)
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
Expand Down
9 changes: 4 additions & 5 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
return NotImplementedError
return NotImplementedError("logcdf of transformed discrete variables not implemented")

Check warning on line 238 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L238

Added line #L238 was not covered by tests

backward_value = op.transform_elemwise.backward(value, *other_inputs)

Expand Down Expand Up @@ -283,7 +283,7 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
return NotImplementedError
return NotImplementedError("icdf of transformed discrete variables not implemented")

Check warning on line 286 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L286

Added line #L286 was not covered by tests

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
Expand Down Expand Up @@ -445,10 +445,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li

# Do not apply rewrite to discrete variables except for their addition and negation
if measurable_input.type.dtype.startswith("int"):
if not (
find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add)
):
if not (find_negated_var(measurable_output) or isinstance(node.op.scalar_op, Add)):
return None
# Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
if not measurable_output.type.dtype.startswith("int"):
return None

Check warning on line 452 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L452

Added line #L452 was not covered by tests

Expand Down
2 changes: 0 additions & 2 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@

from pytensor.graph.basic import equal_computations

import pymc as pm

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
Expand Down

0 comments on commit a0e811d

Please sign in to comment.