From 47f08ea1734da6a3c558a6756d60b850139eb532 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Thu, 7 Sep 2023 13:27:44 +0100 Subject: [PATCH] fixed duplicat test and removed negation from log1mexp logprob conversion --- pymc/logprob/transforms.py | 2 +- tests/logprob/test_transforms.py | 30 +++--------------------------- 2 files changed, 4 insertions(+), 28 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index adb0ef0e79..b5b528ea44 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -595,7 +595,7 @@ def measurable_special_log_to_log(fgraph, node): if isinstance(node.op.scalar_op, Softplus): return [pt.log(1 + pt.exp(inp))] if isinstance(node.op.scalar_op, Log1mexp): - return [pt.log(1 - pt.exp(pt.neg(inp)))] + return [pt.log(1 - pt.exp(inp))] if isinstance(node.op.scalar_op, Log2): return [pt.log(inp) / pt.log(2)] if isinstance(node.op.scalar_op, Log10): diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index f73b5709ab..dfb9bc8770 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -746,6 +746,8 @@ def test_chained_transform(self): ArcsinhTransform(), ArccoshTransform(), ArctanhTransform(), + LogTransform(), + ExpTransform(), ], ) def test_check_jac_det(self, transform): @@ -1052,32 +1054,6 @@ def test_negative_value_frac_power_transform_logp(self, power): assert np.isfinite(x_logp_fn(2.5)) assert np.isneginf(x_logp_fn(-2.5)) - @pytest.mark.parametrize( - "transform", - [ - ErfTransform(), - ErfcTransform(), - ErfcxTransform(), - SinhTransform(), - CoshTransform(), - TanhTransform(), - ArcsinhTransform(), - ArccoshTransform(), - ArctanhTransform(), - LogTransform(), - ExpTransform(), - ], - ) - def test_check_jac_det(self, transform): - check_jacobian_det( - transform, - Vector(Rplusbig, 2), - pt.dvector, - [0.1, 0.1], - elemwise=True, - rv_var=pt.random.normal(0.5, 1, name="base_rv"), - ) - @pytest.mark.parametrize("test_val", (2.5, -2.5)) def test_absolute_rv_transform(test_val): @@ -1153,7 +1129,7 @@ def test_cosh_rv_transform(): TRANSFORMATIONS = { "log1p": (pt.log1p, lambda x: pt.log(1 + x)), "softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))), - "log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(pt.neg(x)))), + "log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))), "log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)), "log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)), "exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),