Added suggested changes
Dhruvanshu-Joshi committed Nov 7, 2023
1 parent df05946 commit e82253e
Showing 4 changed files with 44 additions and 98 deletions.
110 changes: 18 additions & 92 deletions pymc/logprob/order.py
@@ -41,10 +41,7 @@
from pytensor.graph.basic import Node
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable
@@ -56,6 +53,7 @@
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import check_negation
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold

@@ -187,36 +185,28 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[

base_var = node.inputs[0]

if base_var.owner is None:
return None

if not rv_map_feature.request_measurable(node.inputs):
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if base_var.owner is None or not isinstance(base_var.owner.op, Elemwise):
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
if len(base_var.owner.inputs) == 2:
    # inputs are never None; a constant input is identified by having no owner
    if base_var.owner.inputs[0].owner is None:
        base_rv = base_var.owner.inputs[1]
        scalar_constant = base_var.owner.inputs[0]
    else:
        base_rv = base_var.owner.inputs[0]
        scalar_constant = base_var.owner.inputs[1]
else:
    return None

# negation is rv * (-1). Hence the scalar_op must be Mul
try:
if not (
isinstance(base_var.owner.op.scalar_op, Mul)
and len(base_var.owner.inputs) == 2
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
):
return None
except NotScalarConstantError:
if not check_negation(base_var.owner.op.scalar_op, scalar_constant):
    return None

base_rv = base_var.owner.inputs[0]

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

if isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.dtype.startswith("int"):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.inputs[3:]:
if params.type.ndim != 0:
@@ -228,65 +218,9 @@ def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[
if axis != base_var_dims:
return None

measurable_min = MeasurableMaxNeg(list(axis))
min_rv_node = measurable_min.make_node(base_var)
min_rv = min_rv_node.outputs

return min_rv
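
For intuition: `pt.min` does not get its own rewrite because pytensor builds `min(x)` as `-max(-x)`, so a `Max` node over a negated input is the only pattern the rewriter needs to track. A minimal sketch of that graph shape (not part of this commit; it assumes the `min(x) = -max(-x)` construction, and that canonicalization later turns the inner `neg` into the `Mul` by -1 matched above):

import pymc as pm
import pytensor.tensor as pt
from pytensor.tensor.math import Max

x = pm.Normal.dist(0, 1, size=3)
x_min = pt.min(x, axis=None)

# x_min is the outer negation of max(-x); its input's owner is the Max node
# whose own input is the Elemwise negation that find_measurable_max_neg matches
max_node = x_min.owner.inputs[0].owner
assert isinstance(max_node.op, Max)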


@node_rewriter(tracks=[Max])
def find_measurable_max_neg_discrete(
fgraph: FunctionGraph, node: Node
) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
return None # pragma: no cover

if isinstance(node.op, MeasurableMaxNeg):
return None # pragma: no cover

base_var = node.inputs[0]

if base_var.owner is None:
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
return None

# negation is rv * (-1). Hence the scalar_op must be Mul
try:
if not (
isinstance(base_var.owner.op.scalar_op, Mul)
and len(base_var.owner.inputs) == 2
and get_underlying_scalar_constant_value(base_var.owner.inputs[0]) == -1
):
return None
except NotScalarConstantError:
return None

base_rv = base_var.owner.inputs[1]

# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_rv.owner.op, RandomVariable) and base_rv.owner.op.ndim_supp == 0):
return None

if not rv_map_feature.request_measurable([base_rv]):
return None

# univariate i.i.d. test which also rules out other distributions
for params in base_rv.owner.inputs[3:]:
if params.type.ndim != 0:
return None

# Check whether axis is supported or not
axis = set(node.op.axis)
base_var_dims = set(range(base_var.ndim))
if axis != base_var_dims:
return None

# distinguish measurable discrete and continuous (because logprob is different)
if base_rv.owner.op.dtype.startswith("int"):
if isinstance(base_rv.owner.op, RandomVariable):
@@ -296,7 +230,7 @@ def find_measurable_max_neg_discrete(
else:
measurable_min = MeasurableMaxNeg(list(axis))

min_rv_node = measurable_min.make_node(base_var)
min_rv_node = measurable_min.make_node(base_rv)
min_rv = min_rv_node.outputs
return min_rv
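
Both branches are reachable through the same `pt.min` entry point; a short sketch (not part of this commit) of the two dispatch paths:

import pymc as pm
import pytensor.tensor as pt

x_cont = pm.Normal.dist(0, 1, size=3)
x_disc = pm.Poisson.dist(mu=2.0, size=3)

# float dtype -> MeasurableMaxNeg; integer dtype -> MeasurableMaxNegDiscrete
logp_min_cont = pm.logp(pt.min(x_cont), 0.5)
logp_min_disc = pm.logp(pt.min(x_disc), 1)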

@@ -305,27 +239,18 @@
"find_measurable_max_neg",
find_measurable_max_neg,
"basic",
"min",
)


measurable_ir_rewrites_db.register(
"find_measurable_max_neg_discrete",
find_measurable_max_neg_discrete,
"basic",
"min_discrete",
)


@_logprob.register(MeasurableMaxNeg)
def max_neg_logprob(op, values, base_var, **kwargs):
def max_neg_logprob(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
\ln(f_{(n)}(x)) = \ln(n) + (n-1) \ln(1 - F(x)) + \ln(f(x))
where f(x) represents the p.d.f and F(x) represents the c.d.f of the distribution respectively.
"""
(value,) = values
base_rv = base_var.owner.inputs[0]

logprob = _logprob_helper(base_rv, -value)
logcdf = _logcdf_helper(base_rv, -value)
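
As a numerical sanity check on this formula (a sketch, not part of the commit), the closed form can be compared against a finite-difference derivative of the minimum's c.d.f. for n i.i.d. standard normals:

import numpy as np
import scipy.stats as st

n, x = 5, 0.3
dist = st.norm(0, 1)

# ln f_(1)(x) = ln(n) + (n - 1) ln(1 - F(x)) + ln f(x)
log_pdf_min = np.log(n) + (n - 1) * dist.logsf(x) + dist.logpdf(x)

# P(min <= x) = 1 - (1 - F(x))**n; differentiate numerically
eps = 1e-6
numeric = ((1 - dist.sf(x + eps) ** n) - (1 - dist.sf(x - eps) ** n)) / (2 * eps)

np.testing.assert_allclose(np.exp(log_pdf_min), numeric, rtol=1e-5)
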
@@ -337,12 +262,12 @@ def max_neg_logprob(op, values, base_var, **kwargs):


@_logprob.register(MeasurableMaxNegDiscrete)
def maxneg_logprob_discrete(op, values, base_rv, **kwargs):
def max_neg_logprob_discrete(op, values, base_rv, **kwargs):
r"""Compute the log-likelihood graph for the `Max` operation.
The formula that we use here is :
.. math::
\ln(P_{(n)}(x)) = \ln(F(x)^n - F(x-1)^n)
\ln(P_{(n)}(x)) = \ln((1 - F(x - 1))^n - (1 - F(x))^n)
where $P_{(n)}(x)$ represents the p.m.f of the minimum statistic and $F(x)$ represents the c.d.f of the i.i.d. variables.
"""
(value,) = values
@@ -351,6 +276,7 @@ def maxneg_logprob_discrete(op, values, base_rv, **kwargs):

[n] = constant_fold([base_rv.size])

logprob = logdiffexp(n * logcdf_prev, n * logcdf)
# logprob = logdiffexp(1-n * logcdf_prev, n * logcdf)
logprob = pt.log((1 - pt.exp(logcdf_prev)) ** n - (1 - pt.exp(logcdf)) ** n)

return logprob
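
The discrete identity P(min = x) = (1 - F(x - 1))^n - (1 - F(x))^n can likewise be spot-checked (a sketch, not part of the commit); note that a logdiffexp/log1mexp formulation of the same quantity would be the numerically safer variant of the direct expression above:

import numpy as np
import scipy.stats as st

mu, n, x = 3.0, 4, 2
dist = st.poisson(mu)

# survival form: sf(x - 1) = P(X >= x) = 1 - F(x - 1)
pmf_min = dist.sf(x - 1) ** n - dist.sf(x) ** n

# Monte Carlo check of the same probability
rng = np.random.default_rng(0)
samples = rng.poisson(mu, size=(200_000, n)).min(axis=1)
np.testing.assert_allclose(pmf_min, (samples == x).mean(), atol=1e-2)
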
6 changes: 4 additions & 2 deletions pymc/logprob/transforms.py
@@ -127,7 +127,7 @@
cleanup_ir_rewrites_db,
measurable_ir_rewrites_db,
)
from pymc.logprob.utils import CheckParameterValue, check_potential_measurability
from pymc.logprob.utils import CheckParameterValue, check_negation, check_potential_measurability


class TransformedVariable(Op):
@@ -672,7 +672,9 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
if str(node.op) != "Mul" and str(node.op) != "Add":
if not check_negation(node.op.scalar_op, node.inputs[0]) and not isinstance(
    node.op.scalar_op, Add
):
    return None

# Check that other inputs are not potentially measurable, in which case this rewrite
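
In effect, the relaxed check keeps exactly two transforms measurable for discrete inputs: negation (via `check_negation`) and shifts (via `Add`). A sketch of the intent (not part of this commit, and assuming its rewrites are in place):

import pymc as pm

x = pm.Poisson.dist(mu=2.0)

neg_logp = pm.logp(-x, -1)  # negation: matched by the check_negation branch
shift_logp = pm.logp(x + 5, 6)  # shift: matched by the Add branch
# A general rescaling such as 2 * x fails both branches and is rejected,
# so no logprob rewrite is applied to it.
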
18 changes: 18 additions & 0 deletions pymc/logprob/utils.py
@@ -60,6 +60,9 @@
from pytensor.graph.op import HasInnerGraph
from pytensor.link.c.type import CType
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.variable import TensorVariable

@@ -311,3 +314,18 @@ def expand(r):
for node in walk(makeiter(vars), expand, False)
if node.owner and isinstance(node.owner.op, (RandomVariable, MeasurableVariable))
}


def check_negation(scalar_op, scalar_constant):
    """Check that the base variable involves a multiplication by -1 (i.e. a negation)."""

    try:
        if not (
            isinstance(scalar_op, Mul)
            and get_underlying_scalar_constant_value(scalar_constant) == -1
        ):
            return False
    except NotScalarConstantError:
        return False

    return True
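
A usage sketch for the new helper (not part of this commit), assuming the negation was built as `rv * -1` so that the constant is the second input of the Elemwise `Mul`:

import pytensor.tensor as pt
from pymc.logprob.utils import check_negation

x = pt.random.normal(size=(3,))
neg_x = x * -1

# True: the scalar op is Mul and the second input folds to the constant -1
assert check_negation(neg_x.owner.op.scalar_op, neg_x.owner.inputs[1])
# False: the first input is the random variable, not a scalar constant
assert not check_negation(neg_x.owner.op.scalar_op, neg_x.owner.inputs[0])
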
8 changes: 4 additions & 4 deletions tests/logprob/test_order.py
@@ -263,14 +263,14 @@ def test_max_discrete(mu, size, value, axis):
def test_min_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
x_min = pt.min(x, axis=axis)
x_min_value = pt.vector("x_min_value")
x_min_value = pt.scalar("x_min_value")
x_min_logprob = logp(x_min, x_min_value)

test_value = [value]
test_value = value

n = size
exp_rv = sp.poisson(mu).cdf(test_value[0]) ** n
exp_rv_prev = sp.poisson(mu).cdf(test_value[0] - 1) ** n
exp_rv = (1 - sp.poisson(mu).cdf(test_value)) ** n
exp_rv_prev = (1 - sp.poisson(mu).cdf(test_value - 1)) ** n

np.testing.assert_allclose(
(np.log(exp_rv_prev - exp_rv)),
