Skip to content

Commit

Permalink
Adding support for Discrete distribution for max logprob
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Aug 6, 2023
1 parent ccad4c8 commit 813166a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
55 changes: 54 additions & 1 deletion pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ class MeasurableMax(Max):
MeasurableVariable.register(MeasurableMax)


class MeasurableMaxDiscrete(Max):
    """A placeholder used to specify a log-likelihood for a max sub-graph.

    The rewrite in this module uses this class, rather than ``MeasurableMax``,
    when the dtype of the base random variable's op starts with ``"int"``,
    so that a separate log-probability implementation can be dispatched on it.
    """


MeasurableVariable.register(MeasurableMaxDiscrete)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
Expand All @@ -83,6 +90,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
return None

<<<<<<< HEAD
# TODO: We are currently only supporting continuous rvs
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
return None
Expand All @@ -92,17 +100,39 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if params.type.ndim != 0:
return None

=======
# univariate i.i.d. test which also rules out other distributions
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
return None

>>>>>>> Adding support for Discrete distribution for max logprob
# Check whether axis covers all dimensions
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None

<<<<<<< HEAD
measurable_max = MeasurableMax(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs
=======
# logprob for discrete distribution
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

return max_rv
# logprob for continuous distribution
else:
measurable_max = MeasurableMax(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs
>>>>>>> Adding support for Discrete distribution for max logprob

return max_rv
return max_rv


measurable_ir_rewrites_db.register(
Expand All @@ -126,3 +156,26 @@ def max_logprob(op, values, base_rv, **kwargs):
logprob = (n - 1) * logcdf + logprob + pt.math.log(n)

return logprob
<<<<<<< HEAD
=======


@_logprob.register(MeasurableMaxDiscrete)
def max_logprob_discrete(op, values, base_rv, **kwargs):
    r"""Compute the log-likelihood graph for the `Max` of discrete variables.

    The formula that we use here is:

    .. math::
        \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x - 1)^n)

    where :math:`P_{(n)}(x)` is the p.m.f. of the maximum of the ``n``
    variables and :math:`F(x)` is the c.d.f. of the base distribution.

    Parameters
    ----------
    op
        The op this implementation is dispatched on (not used in the body).
    values
        Sequence holding exactly one value variable at which the
        log-probability is evaluated.
    base_rv
        The variable whose maximum is being taken; ``base_rv.size`` is used
        as ``n`` in the formula above.
    **kwargs
        Accepted for dispatch compatibility and ignored.

    Returns
    -------
    The symbolic log-probability of the maximum evaluated at the given value.
    """
    (value,) = values

    # Log-c.d.f. of the base variable at the value and one step below it; the
    # difference of their n-th powers is the mass the maximum puts on `value`.
    logcdf = _logcdf_helper(base_rv, value)
    logcdf_prev = _logcdf_helper(base_rv, value - 1)

    n = base_rv.size

    # NOTE(review): this leaves log-space before taking the difference, which
    # may lose precision for very small c.d.f. values -- a log-space
    # difference of ``n * logcdf`` and ``n * logcdf_prev`` is worth considering.
    return pt.log(pt.exp(logcdf) ** n - pt.exp(logcdf_prev) ** n)
>>>>>>> Adding support for Discrete distribution for max logprob
13 changes: 13 additions & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,16 @@ def test_max_logprob(shape, value, axis):
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)
<<<<<<< HEAD
=======


def test_max_discrete():
    """Check the log-probability of the max of three DiscreteUniform(0, 1) draws."""
    # Local import so this test does not rely on a module-level alias.
    import numpy as np

    x = pm.DiscreteUniform.dist(0, 1, size=(3,))
    x.name = "x"
    x_max = pt.max(x, axis=-1)
    x_max_value = pt.scalar("x_max_value")
    x_max_logprob = logp(x_max, x_max_value)

    # For a uniform variable on {0, 1}: F(0.85) = 1/2 and F(0.85 - 1) = 0, so
    # the max of three draws has probability F(x)^3 - F(x - 1)^3 = (1/2)^3.
    np.testing.assert_allclose(
        x_max_logprob.eval({x_max_value: 0.85}),
        np.log(0.5**3),
        rtol=1e-06,
    )
>>>>>>> Adding support for Discrete distribution for max logprob

0 comments on commit 813166a

Please sign in to comment.