Logprob derivation of Min for Discrete IID distributions #6968

Merged · 2 commits · Feb 5, 2024

80 changes: 50 additions & 30 deletions pymc/logprob/order.py
@@ -41,10 +41,7 @@
from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable
@@ -56,6 +53,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import find_negated_var
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold

@@ -168,6 +166,13 @@
MeasurableVariable.register(MeasurableMaxNeg)


class MeasurableDiscreteMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for sub-graphs of negative maxima of discrete variables"""


MeasurableVariable.register(MeasurableDiscreteMaxNeg)


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
@@ -180,37 +185,20 @@

base_var = node.inputs[0]

if base_var.owner is None:
return None

if not rv_map_feature.request_measurable(node.inputs):
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
if not (base_var.owner is not None and isinstance(base_var.owner.op, Elemwise)):
return None

base_rv = find_negated_var(base_var)

# negation is rv * (-1). Hence the scalar_op must be Mul
try:
if not (
isinstance(base_var.owner.op.scalar_op, Mul)
and len(base_var.owner.inputs) == 2
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
):
return None
except NotScalarConstantError:
if base_rv is None:
return None

base_rv = base_var.owner.inputs[0]

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

# TODO: We are currently only supporting continuous rvs
if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.inputs[3:]:
if params.type.ndim != 0:
@@ -222,11 +210,16 @@
if axis != base_var_dims:
return None

measurable_min = MeasurableMaxNeg(list(axis))
min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs
if not rv_map_feature.request_measurable([base_rv]):
return None

return min_rv
# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
measurable_min = MeasurableDiscreteMaxNeg(list(axis))
else:
measurable_min = MeasurableMaxNeg(list(axis))

return measurable_min.make_node(base_rv).outputs


measurable_ir_rewrites_db.register(
@@ -238,14 +231,13 @@


@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_var, **kwargs):
def max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
"""
(value,) = values
base_rv = base_var.owner.inputs[0]

logprob = _logprob_helper(base_rv, -value)
logcdf = _logcdf_helper(base_rv, -value)
@@ -254,3 +246,31 @@
logprob = (n - 1) * pt.math.log(1 - pt.math.exp(logcdf)) + logprob + pt.math.log(n)

return logprob
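
For reference, the discrete registration that follows replaces the density-based formula above with a difference of survival probabilities; this is standard i.i.d. order statistics, not part of the diff itself. Since min_i X_i >= x exactly when all n variables satisfy X_i > x - 1, we have P(min >= x) = (1 - F(x - 1))^n, and therefore

    P(min = x) = P(min >= x) - P(min >= x + 1) = (1 - F(x - 1))^n - (1 - F(x))^n

Taking logs of this difference is exactly the logdiffexp(n * logcdf_prev, n * logcdf) expression computed below.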


@_logprob.register(MeasurableDiscreteMaxNeg)
def discrete_max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.

The formula that we use here is :
.. math::
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""

(value,) = values

# The cdf of a negative variable is the survival at the negated value
logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))

[n] = constant_fold([base_rv.size])

# Now we can use the same expression as the discrete max
logprob = pt.where(
pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
-pt.inf,
logdiffexp(n * logcdf_prev, n * logcdf),
)

return logprob
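
Taken together, these order.py changes make the min of a discrete i.i.d. vector measurable. A minimal sketch of the resulting behavior, mirroring the new test_min_discrete case below (the variable names and the scipy cross-check are illustrative, not code from the PR):

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import scipy.stats as sp

# Min of three i.i.d. Poisson(2) variables. pt.min is built as -max(-x), and
# the new rewrite dispatches the negated integer case to MeasurableDiscreteMaxNeg.
x = pm.Poisson.dist(mu=2, size=(3,))
value = pt.scalar("value")
min_logp = pm.logp(pt.min(x), value)

# Cross-check against P(min = v) = (1 - F(v - 1))^n - (1 - F(v))^n at v = 1
sf_prev = 1 - sp.poisson(2).cdf(0)  # P(X > 0), i.e. P(X >= 1)
sf = 1 - sp.poisson(2).cdf(1)       # P(X > 1)
np.testing.assert_allclose(
    min_logp.eval({value: 1.0}),
    np.log(sf_prev**3 - sf**3),
    rtol=1e-6,
)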
23 changes: 20 additions & 3 deletions pymc/logprob/transforms.py
@@ -117,7 +117,11 @@
_logprob_helper,
)
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
find_negated_var,
)


class Transform(abc.ABC):
@@ -229,6 +233,10 @@
other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
raise NotImplementedError("logcdf of transformed discrete variables not implemented")

backward_value = op.transform_elemwise.backward(value, *other_inputs)

# Fail if transformation is not injective
@@ -273,6 +281,10 @@
other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
raise NotImplementedError("icdf of transformed discrete variables not implemented")

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
@@ -429,10 +441,15 @@
return None

[measurable_input] = measurable_inputs
[measurable_output] = node.outputs

# Do not apply rewrite to discrete variables
# Do not apply rewrite to discrete variables except for their addition and negation
if measurable_input.type.dtype.startswith("int"):
return None
if not (find_negated_var(measurable_output) or isinstance(node.op.scalar_op, Add)):
return None
# Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
if not measurable_output.type.dtype.startswith("int"):
return None

# Check that other inputs are not potentially measurable, in which case this rewrite
# would be invalid
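
Net effect of the transforms.py changes, sketched with a Bernoulli for illustration (mirrors the new tests at the bottom of this PR; the snippet itself is not part of the diff):

import numpy as np
import pymc as pm

from pymc.logprob.basic import icdf, logcdf, logp

rv = pm.Bernoulli.dist(p=0.7) + 5  # shifted discrete RV: logp is now derivable
vv = rv.type()
np.testing.assert_allclose(logp(rv, vv).eval({vv: 6}), np.log(0.7))

# logcdf and icdf of transformed discrete variables raise instead of silently
# reusing the continuous change-of-variables machinery.
for func in (logcdf, icdf):
    try:
        func(rv, 0)
    except NotImplementedError:
        pass  # expected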
27 changes: 27 additions & 0 deletions pymc/logprob/utils.py
@@ -48,6 +48,10 @@
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

@@ -295,3 +299,26 @@
(const_value,) = inputs
values, const_value = pt.broadcast_arrays(values, const_value)
return pt.switch(pt.isclose(values, const_value, rtol=op.rtol, atol=op.atol), 0.0, -np.inf)


def find_negated_var(var):
"""Return a variable that is being multiplied by -1 or None otherwise."""

if not (
    var.owner
    and isinstance(var.owner.op, Elemwise)
    and isinstance(var.owner.op.scalar_op, Mul)
):
    return None

if len(var.owner.inputs) != 2:
return None

inputs = var.owner.inputs
for mul_var, mul_const in (inputs, reversed(inputs)):
try:
if get_underlying_scalar_constant_value(mul_const) == -1:
return mul_var
except NotScalarConstantError:
continue

return None
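
A quick illustration of the helper's contract (hypothetical snippet written for this review, not code from the PR):

import pytensor.tensor as pt

from pymc.logprob.utils import find_negated_var

x = pt.scalar("x")
assert find_negated_var(x * -1) is x    # Elemwise Mul with a -1 constant
assert find_negated_var(-1 * x) is x    # the constant may sit on either input
assert find_negated_var(x * 2) is None  # multiplier must be exactly -1
assert find_negated_var(x + 1) is None  # not a Mul at all
# Note: -x is built as a Neg node; negation is expected to already be in
# Mul(x, -1) form by the time this helper runs inside the rewrite pipeline.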
39 changes: 39 additions & 0 deletions tests/logprob/test_order.py
@@ -37,6 +37,7 @@
import re

import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest
import scipy.stats as sp
@@ -254,3 +255,41 @@ def test_max_discrete(mu, size, value, axis):
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)


@pytest.mark.parametrize(
"mu, n, test_value, axis",
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_min_discrete(mu, n, test_value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
x_min = pt.min(x, axis=axis)
x_min_value = pt.scalar("x_min_value")
x_min_logprob = logp(x_min, x_min_value)

sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
sf = 1 - sp.poisson(mu).cdf(test_value)

expected_logp = np.log(sf_before**n - sf**n)

np.testing.assert_allclose(
x_min_logprob.eval({x_min_value: test_value}),
expected_logp,
rtol=1e-06,
)


def test_min_max_bernoulli():
p = 0.7
q = 1 - p
n = 3
x = pm.Bernoulli.dist(name="x", p=p, shape=(n,))
value = pt.scalar("value", dtype=int)

max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value))
np.testing.assert_allclose(max_logp_fn(0), np.log(q**n))
np.testing.assert_allclose(max_logp_fn(1), np.log(1 - q**n))

min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))
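
As a sanity check on test_min_max_bernoulli: with p = 0.7, q = 0.3, and n = 3, the max is 0 only when all three draws are 0 and the min is 1 only when all three are 1, so P(max = 0) = q^3 = 0.027, P(max = 1) = 1 - q^3 = 0.973, P(min = 1) = p^3 = 0.343, and P(min = 0) = 1 - p^3 = 0.657; the test asserts the logs of exactly these values.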
42 changes: 38 additions & 4 deletions tests/logprob/test_transforms.py
@@ -44,6 +44,7 @@
from pytensor.graph.basic import equal_computations

from pymc.distributions.continuous import Cauchy, ChiSquared
from pymc.distributions.discrete import Bernoulli
from pymc.logprob.basic import conditional_logp, icdf, logcdf, logp
from pymc.logprob.transforms import (
ArccoshTransform,
@@ -680,18 +681,51 @@ def test_multivariate_rv_transform(shift, scale):
)


def test_discrete_rv_unary_transform_fails():
def test_not_implemented_discrete_rv_transform():
y_rv = pt.exp(pt.random.poisson(1))
with pytest.raises(RuntimeError, match="could not be derived"):
conditional_logp({y_rv: y_rv.clone()})


def test_discrete_rv_multinary_transform_fails():
y_rv = 5 + pt.random.poisson(1)
y_rv = 5 * pt.random.poisson(1)
with pytest.raises(RuntimeError, match="could not be derived"):
conditional_logp({y_rv: y_rv.clone()})


def test_negated_discrete_rv_transform():
p = 0.7
rv = -Bernoulli.dist(p=p)
vv = rv.type()
logp_fn = pytensor.function([vv], logp(rv, vv))

# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
assert logp_fn(-2) == -np.inf
np.testing.assert_allclose(logp_fn(-1), np.log(p))
np.testing.assert_allclose(logp_fn(0), np.log(1 - p))
assert logp_fn(1) == -np.inf

# Logcdf and icdf not supported yet
for func in (logcdf, icdf):
with pytest.raises(NotImplementedError):
func(rv, 0)


def test_shifted_discrete_rv_transform():
p = 0.7
rv = Bernoulli.dist(p=p) + 5
vv = rv.type()
rv_logp_fn = pytensor.function([vv], logp(rv, vv))

assert rv_logp_fn(4) == -np.inf
np.testing.assert_allclose(rv_logp_fn(5), np.log(1 - p))
np.testing.assert_allclose(rv_logp_fn(6), np.log(p))
assert rv_logp_fn(7) == -np.inf

# Logcdf and icdf not supported yet
for func in (logcdf, icdf):
with pytest.raises(NotImplementedError):
func(rv, 0)


@pytest.mark.xfail(reason="Check not implemented yet")
def test_invalid_broadcasted_transform_rv_fails():
loc = pt.vector("loc")