Skip to content

Commit

Permalink
Incorporating suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Aug 18, 2023
1 parent fea004c commit 086f7fe
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 56 deletions.
24 changes: 17 additions & 7 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.scalar.basic import Mul
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.var import TensorVariable
Expand Down Expand Up @@ -130,14 +132,15 @@ def max_logprob(op, values, base_rv, **kwargs):


class MeasurableMaxNeg(Max):
"""A placeholder used to specify a log-likelihood for a min sub-graph."""
"""A placeholder used to specify a log-likelihood for a max(neg(x)) sub-graph.
This shows up in the graph of min, which is neg(max(neg(x))).


MeasurableVariable.register(MeasurableMaxNeg)


@node_rewriter(tracks=[Max])
def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
def find_measurable_max_neg(fgraph: FunctionGraph, node: Node) -> Optional[List[TensorVariable]]:
rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)

if rv_map_feature is None:
Expand All @@ -154,12 +157,19 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if not rv_map_feature.request_measurable(node.inputs):
return None

# Min is the Max of the negation of the same distribution. Hence, op must be Elemiwise
# Min is the Max of the negation of the same distribution. Hence, op must be Elemwise
if not isinstance(base_var.owner.op, Elemwise):
return None

# negation is -1*(rv). Hence the scalar_op must be Mul
if not isinstance(base_var.owner.op.scalar_op, Mul):
# negation is rv * (-1). Hence the scalar_op must be Mul
try:
if not (
isinstance(base_var.owner.op.scalar_op, Mul)
and len(base_var.owner.inputs) == 2
and get_underlying_scalar_constant_value(base_var.owner.inputs[1]) == -1
):
return None
except NotScalarConstantError:
return None

Check warning on line 173 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L172-L173

Added lines #L172 - L173 were not covered by tests

base_rv = base_var.owner.inputs[0]
Expand Down Expand Up @@ -191,8 +201,8 @@ def find_measurable_min(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens


measurable_ir_rewrites_db.register(
"find_measurable_min",
find_measurable_min,
"find_measurable_max_neg",
find_measurable_max_neg,
"basic",
"min",
)
Expand Down
81 changes: 32 additions & 49 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,104 +58,87 @@ def test_argmax():


@pytest.mark.parametrize(
"if_max",
"pt_op",
[
True,
False,
pt.max,
pt.min,
],
)
def test_non_iid_fails(if_max):
def test_non_iid_fails(pt_op):
"""Test whether the logprob for ```pt.max``` or ```pt.min``` for non i.i.d is correctly rejected"""
x = pm.Normal.dist([0, 1, 2, 3, 4], 1, shape=(5,))
x.name = "x"
if if_max == True:
x_m = pt.max(x, axis=-1)
x_m_value = pt.vector("x_max_value")
else:
x_m = pt.min(x, axis=-1)
x_m_value = pt.vector("x_min_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
"if_max",
[True, False],
"pt_op",
[
pt.max,
pt.min,
],
)
def test_non_rv_fails(if_max):
def test_non_rv_fails(pt_op):
"""Test whether the logprob for ```pt.max``` for non-RVs is correctly rejected"""
x = pt.exp(pt.random.beta(0, 1, size=(3,)))
x.name = "x"
if if_max == True:
x_m = pt.max(x, axis=-1)
x_m_value = pt.vector("x_max_value")
else:
x_m = pt.min(x, axis=-1)
x_m_value = pt.vector("x_min_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
"if_max",
"pt_op",
[
True,
False,
pt.max,
pt.min,
],
)
def test_multivariate_rv_fails(if_max):
def test_multivariate_rv_fails(pt_op):
_alpha = pt.scalar()
_k = pt.iscalar()
x = pm.StickBreakingWeights.dist(_alpha, _k)
x.name = "x"
if if_max == True:
x_m = pt.max(x, axis=-1)
x_m_value = pt.vector("x_max_value")
else:
x_m = pt.min(x, axis=-1)
x_m_value = pt.vector("x_min_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
"if_max",
"pt_op",
[
True,
False,
pt.max,
pt.min,
],
)
def test_categorical(if_max):
def test_categorical(pt_op):
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
x.name = "x"
if if_max == True:
x_m = pt.max(x, axis=-1)
x_m_value = pt.vector("x_max_value")
else:
x_m = pt.min(x, axis=-1)
x_m_value = pt.vector("x_min_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)


@pytest.mark.parametrize(
"if_max",
"pt_op",
[
True,
False,
pt.max,
pt.min,
],
)
def test_non_supp_axis(if_max):
def test_non_supp_axis(pt_op):
"""Test whether the logprob for ```pt.max``` for unsupported axis is correctly rejected"""
x = pt.random.normal(0, 1, size=(3, 3))
x.name = "x"
if if_max == True:
x_m = pt.max(x, axis=-1)
x_m_value = pt.vector("x_max_value")
else:
x_m = pt.min(x, axis=-1)
x_m_value = pt.vector("x_min_value")
x_m = pt_op(x, axis=-1)
x_m_value = pt.vector("x_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_m, x_m_value)

Expand Down

0 comments on commit 086f7fe

Please sign in to comment.