Skip to content

Commit

Permalink
Logprob derivation for Min of continuous IID variables (#6846)
Browse files Browse the repository at this point in the history
Co-authored-by: Ricardo Vieira <[email protected]>
  • Loading branch information
Dhruvanshu-Joshi and ricardoV94 authored Sep 4, 2023
1 parent f249f12 commit 6d2a289
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 23 deletions.
101 changes: 100 additions & 1 deletion pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@
from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable
Expand Down Expand Up @@ -122,7 +126,102 @@ def max_logprob(op, values, base_rv, **kwargs):
logcdf = _logcdf_helper(base_rv, value)

[n] = constant_fold([base_rv.size])

logprob = (n - 1) * logcdf + logprob + pt.math.log(n)

return logprob


class MeasurableMaxNeg(Max):
    """A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.

    This shows up in the graph of min, which is neg(max(neg(x))).
    """


# Register the placeholder with `MeasurableVariable` (presumably an ABC-style
# virtual-subclass registration, so `isinstance(op, MeasurableVariable)` holds
# for this op -- confirm against the `MeasurableVariable` definition).
MeasurableVariable.register(MeasurableMaxNeg)


@node_rewriter(tracks=[Max])
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
    """Replace ``Max(x * -1)`` by a `MeasurableMaxNeg` when ``x`` is a supported RV.

    The rewrite only fires when all of the following hold:

    * ``fgraph`` carries a ``preserve_rv_mappings`` feature and that feature
      accepts the node's inputs via ``request_measurable``;
    * the input of the `Max` is an `Elemwise` `Mul` with exactly two inputs,
      the second being the scalar constant ``-1``;
    * the first input of that `Mul` is the output of a `RandomVariable` with
      ``ndim_supp == 0`` whose dtype does not start with ``"int"``;
    * every input of the RV node from index 3 onwards is 0-dimensional;
    * the `Max` reduces over every axis of its input.

    Parameters
    ----------
    fgraph
        The function graph being rewritten.
    node
        The `Max` node under consideration.

    Returns
    -------
    The outputs of the new `MeasurableMaxNeg` node, or ``None`` when the
    rewrite does not apply.
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

    if rv_map_feature is None:
        return None  # pragma: no cover

    # Already rewritten: `MeasurableMaxNeg` is itself a `Max` subclass.
    if isinstance(node.op, MeasurableMaxNeg):
        return None  # pragma: no cover

    base_var = node.inputs[0]

    if base_var.owner is None:
        return None

    if not rv_map_feature.request_measurable(node.inputs):
        return None

    neg_node = base_var.owner

    # Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
    if not isinstance(neg_node.op, Elemwise):
        return None

    # negation is rv * (-1). Hence the scalar_op must be Mul
    if not (isinstance(neg_node.op.scalar_op, Mul) and len(neg_node.inputs) == 2):
        return None

    # Only the constant lookup can raise, so keep the `try` body minimal.
    try:
        factor = get_underlying_scalar_constant_value(neg_node.inputs[1])
    except NotScalarConstantError:
        return None
    if factor != -1:
        return None

    base_rv = neg_node.inputs[0]

    # A root variable (no owner) cannot be a RandomVariable output; reject it
    # here instead of failing with an AttributeError on `.owner.op` below.
    if base_rv.owner is None:
        return None

    rv_op = base_rv.owner.op

    # Non-univariate distributions and non-RVs must be rejected
    if not (isinstance(rv_op, RandomVariable) and rv_op.ndim_supp == 0):
        return None

    # TODO: We are currently only supporting continuous rvs
    if rv_op.dtype.startswith("int"):
        return None

    # univariate i.i.d. test which also rules out other distributions
    if any(param.type.ndim != 0 for param in base_rv.owner.inputs[3:]):
        return None

    # Check whether axis is supported or not: the reduction must cover all dims
    axis = set(node.op.axis)
    if axis != set(range(base_var.ndim)):
        return None

    measurable_min = MeasurableMaxNeg(list(axis))
    min_rv_node = measurable_min.make_node(base_var)

    return min_rv_node.outputs


# Add the min rewrite to the measurable IR rewrites database under the name
# "find_measurable_max_neg". The trailing strings ("basic", "min") are
# presumably the tags used to include/exclude the rewrite -- confirm against
# the database's `register` signature.
measurable_ir_rewrites_db.register(
    "find_measurable_max_neg",
    find_measurable_max_neg,
    "basic",
    "min",
)


@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_var, **kwargs):
    r"""Compute the log-likelihood graph for the `Max` of a negated i.i.d. variable.

    ``base_var`` is the negation ``-x`` of the underlying random variable
    ``x``, so a value ``v`` of ``max(-x)`` corresponds to the minimum of ``x``
    being equal to ``-v``. The density of the minimum of ``n`` i.i.d. draws is

        \ln(f_{(1)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))

    where f(x) represents the p.d.f and F(x) represents the c.d.f of the
    distribution respectively. It is evaluated here at ``x = -v``.

    Parameters
    ----------
    op
        The `MeasurableMaxNeg` op (unused).
    values
        A sequence holding exactly one value variable for ``max(-x)``.
    base_var
        The negated variable; its owner's first input is the base RV ``x``.

    Returns
    -------
    The log-probability graph of ``max(-x)`` evaluated at the given value.
    """
    (value,) = values
    base_rv = base_var.owner.inputs[0]

    # Both terms are evaluated on the scale of the un-negated variable.
    neg_value = -value
    base_logprob = _logprob_helper(base_rv, neg_value)
    logcdf = _logcdf_helper(base_rv, neg_value)

    [n] = constant_fold([base_rv.size])

    # log(1 - F(x)) computed as log1mexp(log F(x)); this avoids the precision
    # loss of the naive log(1 - exp(.)) when F(x) is close to 0 or to 1.
    log_survival = pt.math.log1mexp(logcdf)

    return (n - 1) * log_survival + base_logprob + pt.math.log(n)
127 changes: 105 additions & 22 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import pymc as pm

from pymc import logp
from pymc.logprob import conditional_logp
from pymc.testing import assert_no_rvs


Expand All @@ -58,55 +57,90 @@ def test_argmax():
x_max_logprob = logp(x_max, x_max_value)


def test_max_non_iid_fails():
"""Test whether the logprob for ```pt.max``` for non i.i.d is correctly rejected"""
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_iid_fails(pt_op):
"""Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_max_non_rv_fails():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_rv_fails(pt_op):
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_max_multivariate_rv_fails():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_multivariate_rv_fails(pt_op):
_alpha = pt.scalar()
_k = pt.iscalar()
x = pm.StickBreakingWeights.dist(_alpha, _k)
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_max_categorical():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_categorical(pt_op):
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


def test_non_supp_axis_max():
@pytest.mark.parametrize(
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_supp_axis(pt_op):
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
x = pt.random.normal(0, 1, size=(3, 3))
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -147,3 +181,52 @@ def test_max_logprob(shape, value, axis):
(x_max_logprob.eval({x_max_value: test_value})),
rtol=1e-06,
)


@pytest.mark.parametrize(
    "shape, value, axis",
    [
        (3, 0.85, -1),
        (3, 0.01, 0),
        (2, 0.2, None),
        (4, 0.5, 0),
        ((3, 4), 0.9, None),
        ((3, 4), 0.75, (1, 0)),
    ],
)
def test_min_logprob(shape, value, axis):
    r"""Test whether the logprob for ```pt.min``` produces the correct result.

    The fact that order statistics of i.i.d. uniform RVs ~ Beta is used here:
        U_1, \dots, U_n \stackrel{\text{i.i.d.}}{\sim} \text{Uniform}(0, 1) \Rightarrow U_{(k)} \sim \text{Beta}(k, n + 1 - k)
    for all 1<=k<=n

    The minimum is the first order statistic (k = 1), hence the reference
    distribution is Beta(1, n).
    """
    x = pt.random.uniform(0, 1, size=shape)
    x.name = "x"
    x_min = pt.min(x, axis=axis)
    x_min_value = pt.scalar("x_min_value")
    x_min_logprob = logp(x_min, x_min_value)

    assert_no_rvs(x_min_logprob)

    # Total number of i.i.d. draws being reduced over.
    n = np.prod(shape)
    beta_rv = pt.random.beta(1, n, name="beta")
    beta_vv = beta_rv.clone()
    beta_rv_logprob = logp(beta_rv, beta_vv)

    np.testing.assert_allclose(
        beta_rv_logprob.eval({beta_vv: value}),
        (x_min_logprob.eval({x_min_value: value})),
        rtol=1e-06,
    )


def test_min_non_mul_elemwise_fails():
    """Check that ```pt.min``` over a non-mul elemwise transform of an RV has no logprob."""
    transformed = pt.log(pt.random.beta(0, 1, size=(3,)))
    transformed.name = "x"
    minimum = pt.min(transformed, axis=-1)
    minimum_value = pt.vector("x_min_value")

    expected_message = re.escape("Logprob method not implemented")
    with pytest.raises(RuntimeError, match=expected_message):
        logp(minimum, minimum_value)

0 comments on commit 6d2a289

Please sign in to comment.