diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index b786aa52d43..9b54a535b9c 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -90,34 +90,17 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0): return None -<<<<<<< HEAD - # TODO: We are currently only supporting continuous rvs - if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): - return None - - # univariate i.i.d. test which also rules out other distributions - for params in base_var.owner.inputs[3:]: - if params.type.ndim != 0: - return None - -======= # univariate i.i.d. test which also rules out other distributions for params in base_var.owner.inputs[3:]: if params.type.ndim != 0: return None ->>>>>>> Adding support for Discrete distribution for max logprob # Check whether axis covers all dimensions axis = set(node.op.axis) base_var_dims = set(range(base_var.ndim)) if axis != base_var_dims: return None -<<<<<<< HEAD - measurable_max = MeasurableMax(list(axis)) - max_rv_node = measurable_max.make_node(base_var) - max_rv = max_rv_node.outputs -======= # logprob for discrete distribution if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): measurable_max = MeasurableMaxDiscrete(list(axis)) @@ -130,7 +113,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs ->>>>>>> Adding support for Discrete distribution for max logprob return max_rv @@ -156,8 +138,6 @@ def max_logprob(op, values, base_rv, **kwargs): logprob = (n - 1) * logcdf + logprob + pt.math.log(n) return logprob -<<<<<<< HEAD -======= @_logprob.register(MeasurableMaxDiscrete) @@ -178,4 +158,3 @@ def max_logprob_discrete(op, values, base_rv, **kwargs): logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n) return logprob ->>>>>>> Adding support for Discrete distribution for max logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 6c577ebc26e..76ed6d44e7d 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -147,8 +147,6 @@ def test_max_logprob(shape, value, axis): (x_max_logprob.eval({x_max_value: test_value})), rtol=1e-06, ) -<<<<<<< HEAD -======= def test_max_discrete(): @@ -159,4 +157,3 @@ def test_max_discrete(): x_max_logprob = logp(x_max, x_max_value) x_max_logprob.eval({x_max_value: 0.85}) ->>>>>>> Adding support for Discrete distribution for max logprob