Logprob derivation of Max for Discrete IID distributions (#6790)
Dhruvanshu-Joshi authored Oct 24, 2023
1 parent c3f93ba commit efa0d34
Showing 2 changed files with 58 additions and 5 deletions.
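
For context (added commentary, not part of the commit), the log-probability implemented in this change follows from the standard order-statistics identity for the maximum of n i.i.d. discrete variables with common c.d.f. F, the same identity quoted in the new docstring below:

```latex
% Sketch of the identity implemented by max_logprob_discrete below,
% for n i.i.d. discrete variables with common c.d.f. F:
\begin{aligned}
P_{(n)}(x) = P\bigl(\max(X_1,\dots,X_n) = x\bigr)
  &= P(\max \le x) - P(\max \le x - 1) \\
  &= F(x)^n - F(x-1)^n, \\
\ln P_{(n)}(x)
  &= \ln\bigl(F(x)^n - F(x-1)^n\bigr)
   = \operatorname{logdiffexp}\bigl(n \ln F(x),\, n \ln F(x-1)\bigr).
\end{aligned}
```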
39 changes: 34 additions & 5 deletions pymc/logprob/order.py
@@ -56,6 +56,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold


@@ -66,6 +67,13 @@ class MeasurableMax(Max):
MeasurableVariable.register(MeasurableMax)


class MeasurableMaxDiscrete(Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""


MeasurableVariable.register(MeasurableMaxDiscrete)


@node_rewriter([Max])
def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -87,10 +95,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
return None

# TODO: We are currently only supporting continuous rvs
if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_var.owner.inputs[3:]:
if params.type.ndim != 0:
@@ -102,7 +106,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
if axis != base_var_dims:
return None

measurable_max = MeasurableMax(list(axis))
# distinguish measurable discrete and continuous (because logprob is different)
if base_var.owner.op.dtype.startswith("int"):
measurable_max = MeasurableMaxDiscrete(list(axis))
else:
measurable_max = MeasurableMax(list(axis))

max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

@@ -131,6 +140,26 @@ def max_logprob(op, values, base_rv, **kwargs):
return logprob


@_logprob.register(MeasurableMaxDiscrete)
def max_logprob_discrete(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
.. math::
\ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""
(value,) = values
logcdf = _logcdf_helper(base_rv, value)
logcdf_prev = _logcdf_helper(base_rv, value - 1)

[n] = constant_fold([base_rv.size])

logprob = logdiffexp(n * logcdf, n * logcdf_prev)

return logprob


class MeasurableMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is neg(max(neg(x)))."""
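
A side note on the `logdiffexp(n * logcdf, n * logcdf_prev)` form used above: staying in log space avoids the underflow that hits a direct `log(F(x)**n - F(x-1)**n)` for large `n`. A minimal NumPy sketch with illustrative values only; `logdiffexp_np` is a stand-in for `pymc.math.logdiffexp`, not commit code:

```python
# Illustration only: why max_logprob_discrete stays in log space.
import numpy as np

def logdiffexp_np(a, b):
    # NumPy stand-in for pymc.math.logdiffexp: log(exp(a) - exp(b)), assuming a > b.
    return a + np.log1p(-np.exp(b - a))

log_F, log_F_prev, n = np.log(0.3), np.log(0.1), 700

with np.errstate(divide="ignore"):
    naive = np.log(0.3**n - 0.1**n)  # 0.3**700 underflows to 0.0, so this is -inf
stable = logdiffexp_np(n * log_F, n * log_F_prev)  # ~= 700 * log(0.3), finite

print(naive, stable)  # -inf vs. roughly -842.8
```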
24 changes: 24 additions & 0 deletions tests/logprob/test_order.py
@@ -39,6 +39,7 @@
import numpy as np
import pytensor.tensor as pt
import pytest
import scipy.stats as sp

import pymc as pm

@@ -230,3 +231,26 @@ def test_min_non_mul_elemwise_fails():
x_min_value = pt.vector("x_min_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_min_logprob = logp(x_min, x_min_value)


@pytest.mark.parametrize(
"mu, size, value, axis",
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_max_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
x_max = pt.max(x, axis=axis)
x_max_value = pt.scalar("x_max_value")
x_max_logprob = logp(x_max, x_max_value)

test_value = value

n = size
exp_rv = sp.poisson(mu).cdf(test_value) ** n
exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n

np.testing.assert_allclose(
np.log(exp_rv - exp_rv_prev),
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)
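
To round this off, a rough end-to-end sketch of what the commit enables from the user side (assuming a PyMC build that includes this change), mirroring the parametrized test above:

```python
# Rough usage sketch: logp of the max of i.i.d. discrete draws is now derivable.
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as sp

x = pm.Poisson.dist(mu=2, size=3)
x_max = pt.max(x, axis=-1)
x_max_value = pt.scalar("x_max_value")
x_max_logp = pm.logp(x_max, x_max_value)

# Cross-check against F(x)**n - F(x-1)**n with mu=2, n=3, x=1
expected = np.log(sp.poisson(2).cdf(1) ** 3 - sp.poisson(2).cdf(0) ** 3)
print(x_max_logp.eval({x_max_value: 1}), expected)
```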
