From 4d2ac0d8f57620e00788f2f3c0d1184109b184fc Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Tue, 7 Nov 2023 18:33:38 +0530 Subject: [PATCH] Check output is still integer --- pymc/logprob/transforms.py | 6 ++++++ tests/logprob/test_order.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index c73095fa2ba..262374c2482 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -40,6 +40,7 @@ from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np +import pytensor import pytensor.tensor as pt from pytensor import scan @@ -670,6 +671,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li if isinstance(node.op, MeasurableVariable): return None # pragma: no cover + pytensor.dprint(node.outputs) + rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None) if rv_map_feature is None: return None # pragma: no cover @@ -681,6 +684,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li return None [measurable_input] = measurable_inputs + [measurable_output] = node.outputs # Do not apply rewrite to discrete variables if measurable_input.type.dtype.startswith("int"): @@ -688,6 +692,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li check_negation(node.op.scalar_op, node.inputs[0]) or isinstance(node.op.scalar_op, Add) ): return None + if not measurable_output.type.dtype.startswith("int"): + return None # Check that other inputs are not potentially measurable, in which case this rewrite # would be invalid diff --git a/tests/logprob/test_order.py b/tests/logprob/test_order.py index a7bbc97f053..b1b59666d82 100644 --- a/tests/logprob/test_order.py +++ b/tests/logprob/test_order.py @@ -278,6 +278,8 @@ def test_min_discrete(mu, n, test_value, axis): rtol=1e-06, ) + assert 0 + def test_min_max_bernoulli(): p = 0.7