Skip to content

Commit

Permalink
Reject multivarate and nonrvs for logp of max
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Jul 17, 2023
1 parent 0c9f112 commit 9f9f16b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
7 changes: 2 additions & 5 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,8 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if not rv_map_feature.request_measurable(node.inputs):
return None

Check warning on line 79 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L78-L79

Added lines #L78 - L79 were not covered by tests

# Non-univariate non-RVs must be rejected
if (
not isinstance(base_var.owner.op, RandomVariable)
and base_var.owner.inputs[0].owner.op.ndim_supp != 0
):
# Non-univariate distributions and non-RVs must be rejected
if not (isinstance(base_var.owner.op, RandomVariable) and base_var.owner.op.ndim_supp == 0):
return None

Check warning on line 83 in pymc/logprob/order.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/order.py#L82-L83

Added lines #L82 - L83 were not covered by tests

# TODO: We are currently only supporting continuous rvs
Expand Down
11 changes: 11 additions & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,17 @@ def test_max_non_rv_fails():
x_max_logprob = logp(x_max, x_max_value)


def test_max_multivariate_rv_fails():
_alpha = pt.scalar()
_k = pt.iscalar()
x = pm.StickBreakingWeights.dist(_alpha, _k)
x.name = "x"
x_max = pt.max(x, axis=-1)
x_max_value = pt.vector("x_max_value")
with pytest.raises(RuntimeError, match=re.escape("Logprob method not implemented")):
x_max_logprob = logp(x_max, x_max_value)


def test_max_categorical():
"""Test whether the logprob for ```pt.max``` for unsupported distributions is correctly rejected"""
x = pm.Categorical.dist([1, 1, 1, 1], shape=(5,))
Expand Down

0 comments on commit 9f9f16b

Please sign in to comment.