Skip to content

Commit

Permalink
Fix minimum discrete formula and discrete cdf/icdf transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Nov 7, 2023
1 parent e82253e commit c3a538e
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 16 deletions.
15 changes: 11 additions & 4 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,20 @@ def max_neg_logprob_discrete(op, values, base_rv, **kwargs):
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
where $P_{(n)}(x)$ represents the p.m.f of the maximum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""

(value,) = values
logcdf = _logcdf_helper(base_rv, value)
logcdf_prev = _logcdf_helper(base_rv, value - 1)

# The cdf of a negative variable is the survival at the negated value
logcdf = pt.log1mexp(_logcdf_helper(base_rv, -value))
logcdf_prev = pt.log1mexp(_logcdf_helper(base_rv, -(value + 1)))

[n] = constant_fold([base_rv.size])

# logprob = logdiffexp(1-n * logcdf_prev, n * logcdf)
logprob = pt.log((1 - pt.exp(logcdf_prev)) ** n - (1 - pt.exp(logcdf)) ** n)
# Now we can use the same expression as the discrete max
logprob = pt.where(
pt.and_(pt.eq(logcdf, -pt.inf), pt.eq(logcdf_prev, -pt.inf)),
-pt.inf,
logdiffexp(n * logcdf_prev, n * logcdf),
)

return logprob
18 changes: 15 additions & 3 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@
cleanup_ir_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.logprob.utils import CheckParameterValue, check_negation, check_potential_measurability
from pymc.logprob.utils import (
CheckParameterValue,
check_negation,
check_potential_measurability,
)


class TransformedVariable(Op):
Expand Down Expand Up @@ -469,6 +473,10 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
return NotImplementedError

backward_value = op.transform_elemwise.backward(value, *other_inputs)

# Fail if transformation is not injective
Expand Down Expand Up @@ -513,6 +521,10 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
other_inputs = list(inputs)
measurable_input = other_inputs.pop(op.measurable_input_idx)

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
return NotImplementedError

if isinstance(op.scalar_op, MONOTONICALLY_INCREASING_OPS):
pass
elif isinstance(op.scalar_op, MONOTONICALLY_DECREASING_OPS):
Expand Down Expand Up @@ -672,8 +684,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
if check_negation(node.op.scalar_op, node.inputs[0]) is False and not isinstance(
node.op.scalar_op, Add
if not (
check_negation(node.op.scalar_op, node.inputs[0]) or isinstance(node.op.scalar_op, Add)
):
return None

Expand Down
34 changes: 25 additions & 9 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import re

import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest
import scipy.stats as sp
Expand Down Expand Up @@ -257,23 +258,38 @@ def test_max_discrete(mu, size, value, axis):


@pytest.mark.parametrize(
"mu, size, value, axis",
"mu, n, test_value, axis",
[(2, 3, 1, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_min_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
def test_min_discrete(mu, n, test_value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(n,))
x_min = pt.min(x, axis=axis)
x_min_value = pt.scalar("x_min_value")
x_min_logprob = logp(x_min, x_min_value)

test_value = value
sf_before = 1 - sp.poisson(mu).cdf(test_value - 1)
sf = 1 - sp.poisson(mu).cdf(test_value)

n = size
exp_rv = (1 - sp.poisson(mu).cdf(test_value)) ** n
exp_rv_prev = (1 - sp.poisson(mu).cdf(test_value - 1)) ** n
expected_logp = np.log(sf_before**n - sf**n)

np.testing.assert_allclose(
(np.log(exp_rv_prev - exp_rv)),
(x_min_logprob.eval({x_min_value: (test_value)})),
x_min_logprob.eval({x_min_value: test_value}),
expected_logp,
rtol=1e-06,
)


def test_min_max_bernoulli():
p = 0.7
q = 1 - p
n = 3
x = pm.Bernoulli.dist(p=p, shape=(n,))
value = pt.scalar("value", dtype=int)

max_logp_fn = pytensor.function([value], pm.logp(pt.max(x), value))
np.testing.assert_allclose(max_logp_fn(0), np.log(q**n))
np.testing.assert_allclose(max_logp_fn(1), np.log(1 - q**n))

min_logp_fn = pytensor.function([value], pm.logp(pt.min(x), value))
np.testing.assert_allclose(min_logp_fn(1), np.log(p**n))
np.testing.assert_allclose(min_logp_fn(0), np.log(1 - p**n))
13 changes: 13 additions & 0 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.scan import scan

import pymc as pm

from pymc.distributions.continuous import Cauchy
from pymc.distributions.transforms import _default_transform, log, logodds
from pymc.logprob.abstract import MeasurableVariable, _logprob
Expand Down Expand Up @@ -1262,3 +1264,14 @@ def test_invalid_broadcasted_transform_rv_fails():
# This logp derivation should fail or count only once the values that are broadcasted
logprob = logp(y_rv, y_vv)
assert logprob.eval({y_vv: [0, 0, 0, 0], loc: [0, 0, 0, 0]}).shape == ()


def test_discrete_measurable_cdf_icdf():
p = 0.7
rv = -pm.Bernoulli.dist(p=p)

# A negated Bernoulli has pmf {p if x == -1; 1-p if x == 0; 0 otherwise}
assert pm.logp(rv, -2).eval() == -np.inf # Correct
assert pm.logp(rv, -1).eval() == np.log(p) # Correct
assert pm.logp(rv, 0).eval() == np.log(1 - p) # Correct
assert pm.logp(rv, 1).eval() == -np.inf # Correct

0 comments on commit c3a538e

Please sign in to comment.