From efa0d3494267463606a6fe89ebaee507dbf785b9 Mon Sep 17 00:00:00 2001
From: Dhruvanshu-Joshi <104030847+Dhruvanshu-Joshi@users.noreply.github.com>
Date: Tue, 24 Oct 2023 18:53:37 +0530
Subject: [PATCH] Logprob derivation of Max for Discrete IID distributions (#6790)

---
 pymc/logprob/order.py       | 39 ++++++++++++++++++++++++++++++++-----
 tests/logprob/test_order.py | 24 +++++++++++++++++++++++
 2 files changed, 58 insertions(+), 5 deletions(-)

diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index f76428f83c..35b84542db 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -56,6 +56,7 @@
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
+from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
 
 
@@ -66,6 +67,13 @@ class MeasurableMax(Max):
 MeasurableVariable.register(MeasurableMax)
 
 
+class MeasurableMaxDiscrete(Max):
+    """A placeholder used to specify a log-likelihood for sub-graphs of maxima of discrete variables"""
+
+
+MeasurableVariable.register(MeasurableMaxDiscrete)
+
+
 @node_rewriter([Max])
 def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
     rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -87,10 +95,6 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
     if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
         return None
 
-    # TODO: We are currently only supporting continuous rvs
-    if isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.dtype.startswith("int"):
-        return None
-
     # univariate i.i.d. test which also rules out other distributions
     for params in base_var.owner.inputs[3:]:
         if params.type.ndim != 0:
@@ -102,7 +106,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
     if axis != base_var_dims:
         return None
 
-    measurable_max = MeasurableMax(list(axis))
+    # distinguish measurable discrete and continuous (because logprob is different)
+    if base_var.owner.op.dtype.startswith("int"):
+        measurable_max = MeasurableMaxDiscrete(list(axis))
+    else:
+        measurable_max = MeasurableMax(list(axis))
+
     max_rv_node = measurable_max.make_node(base_var)
     max_rv = max_rv_node.outputs
 
@@ -131,6 +140,26 @@ def max_logprob(op, values, base_rv, **kwargs):
     return logprob
 
 
+@_logprob.register(MeasurableMaxDiscrete)
+def max_logprob_discrete(op, values, base_rv, **kwargs):
+    r"""Compute the log-likelihood graph for the `Max` operation.
+
+    The formula that we use here is :
+        .. math::
+            \ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
+        where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
+    """
+    (value,) = values
+    logcdf = _logcdf_helper(base_rv, value)
+    logcdf_prev = _logcdf_helper(base_rv, value - 1)
+
+    [n] = constant_fold([base_rv.size])
+
+    logprob = logdiffexp(n * logcdf, n * logcdf_prev)
+
+    return logprob
+
+
 class MeasurableMaxNeg(Max):
     """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
     This shows up in the graph of min, which is (neg(max(neg(x)))."""
diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py
index ff51199491..8eae026c0b 100644
--- a/tests/logprob/test_order.py
+++ b/tests/logprob/test_order.py
@@ -39,6 +39,7 @@
 import numpy as np
 import pytensor.tensor as pt
 import pytest
+import scipy.stats as sp
 
 import pymc as pm
 
@@ -230,3 +231,26 @@ def test_min_non_mul_elemwise_fails():
     x_min_value = pt.vector("x_min_value")
     with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
         x_min_logprob = logp(x_min, x_min_value)
+
+
+@pytest.mark.parametrize(
+    "mu, size, value, axis",
+    [(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
+)
+def test_max_discrete(mu, size, value, axis):
+    x = pm.Poisson.dist(name="x", mu=mu, size=(size))
+    x_max = pt.max(x, axis=axis)
+    x_max_value = pt.scalar("x_max_value")
+    x_max_logprob = logp(x_max, x_max_value)
+
+    test_value = value
+
+    n = size
+    exp_rv = sp.poisson(mu).cdf(test_value) ** n
+    exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n
+
+    np.testing.assert_allclose(
+        np.log(exp_rv - exp_rv_prev),
+        (x_max_logprob.eval({x_max_value: test_value})),
+        rtol=1e-06,
+    )