Skip to content

Commit

Permalink
generalised expotential function added
Browse files Browse the repository at this point in the history
  • Loading branch information
Luke LB committed Jul 17, 2023
1 parent 48166a2 commit 85a1865
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,19 +556,29 @@ def measurable_special_log_to_log(fgraph, node):
return [pt.log(inp) / pt.log(10)]


@node_rewriter([exp2, expm1, sigmoid])
@node_rewriter([expm1, sigmoid])
def measurable_special_exp_to_exp(fgraph, node):
"""Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form."""
"""Convert expm1, sigmoid of `MeasurableVariable`s to xp form."""
[inp] = node.inputs

Check warning on line 562 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L553-L562

Added lines #L553 - L562 were not covered by tests

if isinstance(node.op.scalar_op, Exp2):
return [pt.power(2, inp)]
if isinstance(node.op.scalar_op, Expm1):
return [pt.exp(inp) - 1]
if isinstance(node.op.scalar_op, Sigmoid):
return [1 / (1 + pt.exp(-inp))]


@node_rewriter([exp2, pow])
def measurable_general_exp_to_exp(fgraph, node):
"""Convert exp2 and any const^x of `MeasurableVariable`s to exp form."""
if len(node.inputs) > 1:

Check warning on line 572 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L568-L572

Added lines #L568 - L572 were not covered by tests
[const, inp] = node.inputs
else:
[inp] = node.inputs
if isinstance(node.op.scalar_op, Exp2):
return [pt.exp(pt.log(2) * inp)]
if isinstance(node.op.scalar_op, Pow) and isinstance(inp, pt.TensorVariable):
return [pt.exp(pt.log(const) * inp)]

Check warning on line 579 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L578-L579

Added lines #L578 - L579 were not covered by tests


Check warning on line 581 in pymc/logprob/transforms.py

View check run for this annotation

Codecov / codecov/patch

pymc/logprob/transforms.py#L581

Added line #L581 was not covered by tests
@node_rewriter(
[
exp,
Expand Down Expand Up @@ -721,6 +731,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
"transform",
)

measurable_ir_rewrites_db.register(
"measurable_general_exp_to_exp",
measurable_general_exp_to_exp,
"basic",
"transform",
)

measurable_ir_rewrites_db.register(
"find_measurable_transforms",
find_measurable_transforms,
Expand Down

0 comments on commit 85a1865

Please sign in to comment.