Skip to content

Commit

Permalink
Check output is still integer
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhruvanshu-Joshi committed Nov 7, 2023
1 parent c3a538e commit 4d2ac0d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import pytensor
import pytensor.tensor as pt

from pytensor import scan
Expand Down Expand Up @@ -670,6 +671,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
if isinstance(node.op, MeasurableVariable):
return None # pragma: no cover

pytensor.dprint(node.outputs)

rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
if rv_map_feature is None:
return None # pragma: no cover
Expand All @@ -681,13 +684,16 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
return None

[measurable_input] = measurable_inputs
[measurable_output] = node.outputs

# Do not apply rewrite to discrete variables
if measurable_input.type.dtype.startswith("int"):
if not (

Check warning on line 691 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L691

Added line #L691 was not covered by tests
check_negation(node.op.scalar_op, node.inputs[0]) or isinstance(node.op.scalar_op, Add)
):
return None
if not measurable_output.type.dtype.startswith("int"):
return None

Check warning on line 696 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L694-L696

Added lines #L694 - L696 were not covered by tests

# Check that other inputs are not potentially measurable, in which case this rewrite
# would be invalid
Expand Down
2 changes: 2 additions & 0 deletions tests/logprob/test_order.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,8 @@ def test_min_discrete(mu, n, test_value, axis):
rtol=1e-06,
)

assert 0


def test_min_max_bernoulli():
p = 0.7
Expand Down

0 comments on commit 4d2ac0d

Please sign in to comment.