Skip to content

Commit

Permalink
reverted back to just allowing exp2
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke LB committed Jul 17, 2023
1 parent 85a1865 commit 7e4909a
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def measurable_sub_to_neg(fgraph, node):

@node_rewriter([log1p, softplus, log1mexp, log2, log10])
def measurable_special_log_to_log(fgraph, node):
"""Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form."""
"""Convert log1p, log1mexp, softplus, log2, log10 of `MeasurableVariable`s to log form."""
[inp] = node.inputs

if isinstance(node.op.scalar_op, Log1p):
Expand All @@ -556,29 +556,18 @@ def measurable_special_log_to_log(fgraph, node):
return [pt.log(inp) / pt.log(10)]


@node_rewriter([expm1, sigmoid])
@node_rewriter([expm1, sigmoid, exp2])
def measurable_special_exp_to_exp(fgraph, node):
"""Convert expm1, sigmoid of `MeasurableVariable`s to xp form."""
"""Convert expm1, sigmoid, and exp2 of `MeasurableVariable`s to xp form."""
[inp] = node.inputs
if isinstance(node.op.scalar_op, Exp2):
return [pt.exp(pt.log(2) * inp)]
if isinstance(node.op.scalar_op, Expm1):
return [pt.exp(inp) - 1]
if isinstance(node.op.scalar_op, Sigmoid):
return [1 / (1 + pt.exp(-inp))]


@node_rewriter([exp2, pow])
def measurable_general_exp_to_exp(fgraph, node):
"""Convert exp2 and any const^x of `MeasurableVariable`s to exp form."""
if len(node.inputs) > 1:
[const, inp] = node.inputs
else:
[inp] = node.inputs
if isinstance(node.op.scalar_op, Exp2):
return [pt.exp(pt.log(2) * inp)]
if isinstance(node.op.scalar_op, Pow) and isinstance(inp, pt.TensorVariable):
return [pt.exp(pt.log(const) * inp)]


@node_rewriter(
[
exp,
Expand Down

0 comments on commit 7e4909a

Please sign in to comment.