Skip to content

Commit

Permalink
Add logprob for discrete minimum of IID variables
Browse files Browse the repository at this point in the history
Co-authored-by: Dhruvanshu-Joshi <[email protected]>
  • Loading branch information
2 people authored and ricardoV94 committed Feb 5, 2024
1 parent dfc4788 commit 6554683
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 30 deletions.
80 changes: 50 additions & 30 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,7 @@
from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable
Expand All @@ -56,6 +53,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import find_negated_var
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold

Expand Down Expand Up @@ -168,6 +166,13 @@ class MeasurableMaxNeg(Max):
MeasurableVariable.register(MeasurableMaxNeg)


class MeasurableDiscreteMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


MeasurableVariable.register(MeasurableDiscreteMaxNeg)


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
Expand All @@ -180,37 +185,20 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[

base_var = node.inputs[0]

if base_var.owner is None:
return None

if not rv_map_feature.request_measurable(node.inputs):
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
return None

base_rv = find_negated_var(base_var)

# negation is rv * (-1). Hence the scalar_op must be Mul
try:
if not (
isinstance(base_var.owner.op.scalar_op, Mul)
and len(base_var.owner.inputs) == 2
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
):
return None
except NotScalarConstantError:
if base_rv is None:
return None

base_rv = base_var.owner.inputs[0]

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

# TODO: We are currently only supporting continuous rvs
if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.inputs[3:]:
if params.type.ndim != 0:
Expand All @@ -222,11 +210,16 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[
if axis != base_var_dims:
return None

measurable_min = MeasurableMaxNeg(list(axis))
min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs
if not rv_map_feature.request_measurable([base_rv]):
return None

return min_rv
# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
else:
measurable_min = MeasurableMaxNeg(list(axis))

return measurable_min.make_node(base_rv).outputs


measurable_ir_rewrites_db.register(
Expand All @@ -238,14 +231,13 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[list[


@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_var, **kwargs):
def max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
"""
(value,) = values
base_rv = base_var.owner.inputs[0]

logprob = _logprob_helper(base_rv, -value)
logcdf = _logcdf_helper(base_rv, -value)
Expand All @@ -254,3 +246,31 @@ def max_neg_logprob(op, values, base_var, **kwargs):
logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)

return logprob


@_logprob.register(MeasurableDiscreteMaxNeg)
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
.. math::
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""

(value,) = values

# The cdf of a negative variable is the survival at the negated value
logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))

[n] = constant_fold([base_rv.size])

# Now we can use the same expression as the discrete max
logprob = pt.where(
pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
-pt.inf,
logdiffexp(n * logcdf_prev, n * logcdf),
)

return logprob
27 changes: 27 additions & 0 deletions pymc/logprob/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

Expand Down Expand Up @@ -296,3 +300,26 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):
(const_value,) = inputs
values, const_value = pt.broadcast_arrays(values, const_value)
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)


def find_negated_var(var):
"""Return a variable that is being multiplied by -1 or None otherwise."""

if (
not (var.owner)
and isinstance(var.owner.op, Elemwise)
and isinstance(var.owner.op.scalar_op, Mul)
):
return None
if len(var.owner.inputs) != 2:
return None

inputs = var.owner.inputs
for mul_var, mul_const in (inputs, reversed(inputs)):
try:
if get_underlying_scalar_constant_value(mul_const) == -1:
return mul_var
except NotScalarConstantError:
continue

return None
39 changes: 39 additions & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import re

import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest
import scipy.stats as sp
Expand Down Expand Up @@ -254,3 +255,41 @@ def test_max_discrete(mu, size, value, axis):
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)


@pytest.mark.parametrize(
"mu, n, test_value, axis",
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_min_discrete(mu, n, test_value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
x_min = pt.min(x, axis=axis)
x_min_value = pt.scalar("x_min_value")
x_min_logprob = logp(x_min, x_min_value)

sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
sf = 1 - sp.poisson(mu).cdf(test_value)

expected_logp = np.log(sf_before**n - sf**n)

np.testing.assert_allclose(
x_min_logprob.eval({x_min_value: test_value}),
expected_logp,
rtol=1e-06,
)


def test_min_max_bernoulli():
p = 0.7
q = 1 - p
n = 3
x = pm.Bernoulli.dist(name="x", p=p, shape=(n,))
value = pt.scalar("value", dtype=int)

max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value))
np.testing.assert_allclose(max_logp_fn(0), np.log(q**n))
np.testing.assert_allclose(max_logp_fn(1), np.log(1 - q**n))

min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))

0 comments on commit 6554683

Please sign in to comment.