From a0e811d17ebf3cbbedf07b9f202f05be8551af12 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Wed, 3 Jan 2024 12:22:23 +0530 Subject: [PATCH] adding changes final changes --- pymc/logprob/order.py | 12 +++++------- pymc/logprob/transforms.py | 9 ++++----- tests/logprob/test_transforms.py | 2 -- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index 7c2b911d7e4..17d2c414279 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -166,11 +166,11 @@ class MeasurableMaxNeg(Max): MeasurableVariable.register(MeasurableMaxNeg) -class DiscreteMeasurableMaxNeg(Max): +class MeasurableDiscreteMaxNeg(Max): """A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables""" -MeasurableVariable.register(DiscreteMeasurableMaxNeg) +MeasurableVariable.register(MeasurableDiscreteMaxNeg) @node_rewriter(tracks=[Max]) @@ -215,13 +215,11 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[ # distinguish measurable discrete and continuous (because logprob is different) if base_rv.owner.op.dtype.startswith("int"): - measurable_min = DiscreteMeasurableMaxNeg(list(axis)) + measurable_min = MeasurableDiscreteMaxNeg(list(axis)) else: measurable_min = MeasurableMaxNeg(list(axis)) - min_rv_node = measurable_min.make_node(base_rv) - min_rv = min_rv_node.outputs - return min_rv + return measurable_min.make_node(base_rv).outputs measurable_ir_rewrites_db.register( @@ -250,7 +248,7 @@ def max_neg_logprob(op, values, base_rv, **kwargs): return logprob -@_logprob.register(DiscreteMeasurableMaxNeg) +@_logprob.register(MeasurableDiscreteMaxNeg) def discrete_max_neg_logprob(op, values, base_rv, **kwargs): r"""Compute the log-likelihood graph for the `Max` operation. diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index ef67cdeb6fe..df178930215 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -235,7 +235,7 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg # Do not apply rewrite to discrete variables if measurable_input.type.dtype.startswith("int"): - return NotImplementedError + return NotImplementedError("logcdf of transformed discrete variables not implemented") backward_value = op.transform_elemwise.backward(value, *other_inputs) @@ -283,7 +283,7 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs) # Do not apply rewrite to discrete variables if measurable_input.type.dtype.startswith("int"): - return NotImplementedError + return NotImplementedError("icdf of transformed discrete variables not implemented") if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS): pass @@ -445,10 +445,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li # Do not apply rewrite to discrete variables except for their addition and negation if measurable_input.type.dtype.startswith("int"): - if not ( - find_negated_var(measurable_output) is not None or isinstance(node.op.scalar_op, Add) - ): + if not (find_negated_var(measurable_output) or isinstance(node.op.scalar_op, Add)): return None + # Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable if not measurable_output.type.dtype.startswith("int"): return None diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index ca5e31a9520..3a66f2ba28c 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -43,8 +43,6 @@ from pytensor.graph.basic import equal_computations -import pymc as pm - from pymc.distributions.continuous import Cauchy, ChiSquared from pymc.distributions.discrete import Bernoulli from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp