Skip to content

Commit

Permalink
fixed duplicat test and removed negation from log1mexp logprob conver…
Browse files Browse the repository at this point in the history
…sion
  • Loading branch information
Luke LB committed Sep 7, 2023
1 parent 2882a0e commit 47f08ea
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,7 @@ def measurable_special_log_to_log(fgraph, node):
if isinstance(node.op.scalar_op, Softplus):
return [pt.log(1 + pt.exp(inp))]
if isinstance(node.op.scalar_op, Log1mexp):
return [pt.log(1 - pt.exp(pt.neg(inp)))]
return [pt.log(1 - pt.exp(inp))]
if isinstance(node.op.scalar_op, Log2):
return [pt.log(inp) / pt.log(2)]
if isinstance(node.op.scalar_op, Log10):
Expand Down
30 changes: 3 additions & 27 deletions tests/logprob/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,8 @@ def test_chained_transform(self):
ArcsinhTransform(),
ArccoshTransform(),
ArctanhTransform(),
LogTransform(),
ExpTransform(),
],
)
def test_check_jac_det(self, transform):
Expand Down Expand Up @@ -1052,32 +1054,6 @@ def test_negative_value_frac_power_transform_logp(self, power):
assert np.isfinite(x_logp_fn(2.5))
assert np.isneginf(x_logp_fn(-2.5))

@pytest.mark.parametrize(
"transform",
[
ErfTransform(),
ErfcTransform(),
ErfcxTransform(),
SinhTransform(),
CoshTransform(),
TanhTransform(),
ArcsinhTransform(),
ArccoshTransform(),
ArctanhTransform(),
LogTransform(),
ExpTransform(),
],
)
def test_check_jac_det(self, transform):
check_jacobian_det(
transform,
Vector(Rplusbig, 2),
pt.dvector,
[0.1, 0.1],
elemwise=True,
rv_var=pt.random.normal(0.5, 1, name="base_rv"),
)


@pytest.mark.parametrize("test_val", (2.5, -2.5))
def test_absolute_rv_transform(test_val):
Expand Down Expand Up @@ -1153,7 +1129,7 @@ def test_cosh_rv_transform():
TRANSFORMATIONS = {
"log1p": (pt.log1p, lambda x: pt.log(1 + x)),
"softplus": (pt.softplus, lambda x: pt.log(1 + pt.exp(x))),
"log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(pt.neg(x)))),
"log1mexp": (pt.log1mexp, lambda x: pt.log(1 - pt.exp(x))),
"log2": (pt.log2, lambda x: pt.log(x) / pt.log(2)),
"log10": (pt.log10, lambda x: pt.log(x) / pt.log(10)),
"exp2": (pt.exp2, lambda x: pt.exp(pt.log(2) * x)),
Expand Down

0 comments on commit 47f08ea

Please sign in to comment.