Skip to content

Commit

Permalink
Suppport for discrete max/min
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Sep 10, 2023
1 parent 32b7534 commit efb6158
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
13 changes: 6 additions & 7 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,15 @@ def find_measurable_max(fgraph: FunctionGraph, node: Node) -> Optional[List[Tens
if base_var.owner.op.dtype.startswith("int"):
if isinstance(base_var.owner.op, RandomVariable):
measurable_max = MeasurableMaxDiscrete(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

return max_rv
else:
return None
else:
measurable_max = MeasurableMax(list(axis))
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

return max_rv
max_rv_node = measurable_max.make_node(base_var)
max_rv = max_rv_node.outputs

return max_rv


measurable_ir_rewrites_db.register(
Expand Down
6 changes: 3 additions & 3 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def test_min_non_mul_elemwise_fails():

@pytest.mark.parametrize(
"mu, size, value, axis",
[(2, 3, 0.85, -1), (2, 3, 0.01, 0), (1, 2, 0.2, None), (0, 4, 0, 0)],
[(2, 3, 0.85, -1), (2, 3, 1, 0), (1, 2, 2, None), (0, 4, 0, 0)],
)
def test_max_discrete(mu, size, value, axis):
x = pm.Poisson.dist(name="x", mu=mu, size=(size))
Expand All @@ -247,8 +247,8 @@ def test_max_discrete(mu, size, value, axis):
test_value = value

n = size
exp_rv = np.exp(sp.poisson(mu).logcdf(test_value)) ** n
exp_rv_prev = np.exp(sp.poisson(mu).logcdf(test_value - 1)) ** n
exp_rv = sp.poisson(mu).cdf(test_value) ** n
exp_rv_prev = sp.poisson(mu).cdf(test_value - 1) ** n

np.testing.assert_allclose(
np.log(exp_rv - exp_rv_prev),
Expand Down

0 comments on commit efb6158

Please sign in to comment.