From 85a186597e995dffabac8fb6690fb4a374fa5a1a Mon Sep 17 00:00:00 2001 From: Luke LB Date: Mon, 17 Jul 2023 18:39:09 +0100 Subject: [PATCH] generalised expotential function added --- pymc/logprob/transforms.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 535b554e409..93dc4050c96 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -556,19 +556,29 @@ def measurable_special_log_to_log(fgraph, node): return [pt.log(inp) / pt.log(10)] -@node_rewriter([exp2, expm1, sigmoid]) +@node_rewriter([expm1, sigmoid]) def measurable_special_exp_to_exp(fgraph, node): - """Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form.""" + """Convert expm1, sigmoid of `MeasurableVariable`s to xp form.""" [inp] = node.inputs - - if isinstance(node.op.scalar_op, Exp2): - return [pt.power(2, inp)] if isinstance(node.op.scalar_op, Expm1): return [pt.exp(inp) - 1] if isinstance(node.op.scalar_op, Sigmoid): return [1 / (1 + pt.exp(-inp))] +@node_rewriter([exp2, pow]) +def measurable_general_exp_to_exp(fgraph, node): + """Convert exp2 and any const^x of `MeasurableVariable`s to exp form.""" + if len(node.inputs) > 1: + [const, inp] = node.inputs + else: + [inp] = node.inputs + if isinstance(node.op.scalar_op, Exp2): + return [pt.exp(pt.log(2) * inp)] + if isinstance(node.op.scalar_op, Pow) and isinstance(inp, pt.TensorVariable): + return [pt.exp(pt.log(const) * inp)] + + @node_rewriter( [ exp, @@ -721,6 +731,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) +measurable_ir_rewrites_db.register( + "measurable_general_exp_to_exp", + measurable_general_exp_to_exp, + "basic", + "transform", +) + measurable_ir_rewrites_db.register( "find_measurable_transforms", find_measurable_transforms,