diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 5c88a8ff0e0..1db159d2520 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -688,7 +688,8 @@ def test_discrete_rv_unary_transform_fails(): conditional_logp({y_rv: y_rv.clone()}) -def test_discrete_rv_multinary_transform_fails(): +# add 2 tests. One fir supported and one for unsupported +def test_discrete_rv_multinary_transform(): y_rv = 5 + pt.random.poisson(1) with pytest.raises(RuntimeError, match="could not be derived"): conditional_logp({y_rv: y_rv.clone()})