diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py index be8d688d809..b786aa52d43 100644 --- a/pymc/logprob/order.py +++ b/pymc/logprob/order.py @@ -62,6 +62,13 @@ class MeasurableMax(Max): MeasurableVariable.register(MeasurableMax) +class MeasurableMaxDiscrete(Max): + """A placeholder used to specify a log-likelihood for a cmax sub-graph.""" + + +MeasurableVariable.register(MeasurableMaxDiscrete) + + @node_rewriter([Max]) def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]: rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None) @@ -83,6 +90,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0): return None +<<<<<<< HEAD # TODO: We are currently only supporting continuous rvs if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): return None @@ -92,17 +100,39 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens if params.type.ndim != 0: return None +======= + # univariate i.i.d. 
test which also rules out other distributions + for params in base_var.owner.inputs[3:]: + if params.type.ndim != 0: + return None + +>>>>>>> Adding support for Discrete distribution for max logprob # Check whether axis covers all dimensions axis = set(node.op.axis) base_var_dims = set(range(base_var.ndim)) if axis != base_var_dims: return None +<<<<<<< HEAD measurable_max = MeasurableMax(list(axis)) max_rv_node = measurable_max.make_node(base_var) max_rv = max_rv_node.outputs +======= + # logprob for discrete distribution + if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"): + measurable_max = MeasurableMaxDiscrete(list(axis)) + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs + + return max_rv + # logprob for continuous distribution + else: + measurable_max = MeasurableMax(list(axis)) + max_rv_node = measurable_max.make_node(base_var) + max_rv = max_rv_node.outputs +>>>>>>> Adding support for Discrete distribution for max logprob - return max_rv + return max_rv measurable_ir_rewrites_db.register( @@ -126,3 +156,26 @@ def max_logprob(op, values, base_rv, **kwargs): logprob = (n - 1) * logcdf + logprob + pt.math.log(n) return logprob +<<<<<<< HEAD +======= + + +@_logprob.register(MeasurableMaxDiscrete) +def max_logprob_discrete(op, values, base_rv, **kwargs): + r"""Compute the log-likelihood graph for the `Max` operation. + + The formula that we use here is : + \ln(f_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n) + where f(x) represents the p.m.f and F(x) represents the c.d.f of the distribution respectively. 
+ """ + (value,) = values + logprob = _logprob_helper(base_rv, value) + logcdf = _logcdf_helper(base_rv, value) + logcdf_prev = _logcdf_helper(base_rv, value - 1) + + n = base_rv.size + + logprob = pt.log((pt.exp(logcdf)) ** n - (pt.exp(logcdf_prev)) ** n) + + return logprob +>>>>>>> Adding support for Discrete distribution for max logprob diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index 5a3818716dc..6c577ebc26e 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -147,3 +147,16 @@ def test_max_logprob(shape, value, axis): (x_max_logprob.eval({x_max_value: test_value})), rtol=1e-06, ) +<<<<<<< HEAD +======= + + +def test_max_discrete(): + x = pm.DiscreteUniform.dist(0, 1, size=(3,)) + x.name = "x" + x_max = pt.max(x, axis=-1) + x_max_value = pt.scalar("x_max_value") + x_max_logprob = logp(x_max, x_max_value) + + x_max_logprob.eval({x_max_value: 0.85}) +>>>>>>> Adding support for Discrete distribution for max logprob