From 738d6de21fd1409736a991319f35d9558817760c Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sun, 25 Jun 2023 13:39:00 +0100 Subject: [PATCH 01/11] LogTransform with arb base n now works --- pymc/logprob/transforms.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index cbd613abdcc..d48a25ae3d7 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -83,6 +83,8 @@ erfcx, exp, log, + log2, + log10, mul, neg, pow, @@ -523,6 +525,23 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] +@node_rewriter([log, log2, log10]) +def measurable_logs_to_logn(fgraph, node): + """Convert logrithm funtions involving `MeasurableVariable`s to logarithm base n""" + [inp] = node.inputs + scalar_op = node.op.scalar_op + + def logn(input, base): + return pt.log(input) / pt.log(base) + + if isinstance(scalar_op, log): + return [logn(inp, np.exp)] + if isinstance(scalar_op, log2): + return [logn(inp, 2)] + if isinstance(scalar_op, log10): + return [logn(inp, 10)] + + @node_rewriter( [exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx] ) @@ -808,11 +827,15 @@ def log_jac_det(self, value, *inputs): class LogTransform(RVTransform): name = "log" + def __init__(self, base=pt.exp(1)): + self.base = base + super().__init__() + def forward(self, value, *inputs): - return pt.log(value) + return pt.log(value) / pt.log(self.base) def backward(self, value, *inputs): - return pt.exp(value) + return pt.power(self.base, value) def log_jac_det(self, value, *inputs): return value From 892c8d169a22c4618d030596d543720d9db2fefd Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sun, 25 Jun 2023 20:47:51 +0100 Subject: [PATCH 02/11] log2 and log10 now working --- pymc/logprob/transforms.py | 48 +++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index d48a25ae3d7..210daec0c8b 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -62,6 +62,8 @@ Erfcx, Exp, Log, + Log2, + Log10, Mul, Pow, Sinh, @@ -380,6 +382,8 @@ class MeasurableTransform(MeasurableElemwise): valid_scalar_types = ( Exp, Log, + Log2, + Log10, Add, Mul, Pow, @@ -525,25 +529,26 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] -@node_rewriter([log, log2, log10]) -def measurable_logs_to_logn(fgraph, node): - """Convert logrithm funtions involving `MeasurableVariable`s to logarithm base n""" - [inp] = node.inputs - scalar_op = node.op.scalar_op - - def logn(input, base): - return pt.log(input) / pt.log(base) - - if isinstance(scalar_op, log): - return [logn(inp, np.exp)] - if isinstance(scalar_op, log2): - return [logn(inp, 2)] - if isinstance(scalar_op, log10): - return [logn(inp, 10)] - - @node_rewriter( - [exp, log, add, mul, pow, abs, sinh, cosh, tanh, arcsinh, arccosh, arctanh, erf, erfc, erfcx] + [ + exp, + log, + log2, + log10, + add, + mul, + pow, + abs, + sinh, + cosh, + tanh, + arcsinh, + arccosh, + arctanh, + erf, + erfc, + erfcx, + ] ) def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[List[Node]]: """Find measurable transformations from Elemwise operators.""" @@ -582,7 +587,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform_dict = { Exp: ExpTransform(), - Log: LogTransform(), Abs: AbsTransform(), Sinh: SinhTransform(), Cosh: CoshTransform(), @@ -612,6 +616,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform = LocTransform( transform_args_fn=lambda *inputs: inputs[-1], ) + elif isinstance(scalar_op, Log): + transform = LogTransform() + elif isinstance(scalar_op, Log2): + transform = LogTransform(base=2) + elif isinstance(scalar_op, Log10): + transform = LogTransform(base=10) elif transform is None: transform_inputs = (measurable_input, pt.mul(*other_inputs)) transform = ScaleTransform( From 3556150a854cdf64026eb2d56ceb4862c47f7da2 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sun, 25 Jun 2023 22:01:22 +0100 Subject: [PATCH 03/11] log1p, softplus, and log1mexp now working --- pymc/logprob/transforms.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 210daec0c8b..3e99e9a234b 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -62,11 +62,14 @@ Erfcx, Exp, Log, + Log1mexp, + Log1p, Log2, Log10, Mul, Pow, Sinh, + Softplus, Sqr, Sqrt, Tanh, @@ -85,6 +88,8 @@ erfcx, exp, log, + log1mexp, + log1p, log2, log10, mul, @@ -92,6 +97,7 @@ pow, reciprocal, sinh, + softplus, sqr, sqrt, sub, @@ -384,6 +390,9 @@ class MeasurableTransform(MeasurableElemwise): Log, Log2, Log10, + Log1p, + Softplus, + Log1mexp, Add, Mul, Pow, @@ -529,6 +538,19 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] +@node_rewriter([log1p, softplus, log1mexp]) +def measurable_special_log_to_log(fgraph, node): + """Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form.""" + [inp] = node.inputs + + if isinstance(node.op.scalar_op, Log1p): + return [pt.log(1 + inp)] + if isinstance(node.op.scalar_op, Softplus): + return [pt.log(1 + pt.exp(inp))] + if isinstance(node.op.scalar_op, Log1p): + return [pt.log(1 - pt.exp(pt.neg(inp)))] + + @node_rewriter( [ exp, @@ -675,6 +697,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) +measurable_ir_rewrites_db.register( + "measurable_special_log_to_log", + measurable_special_log_to_log, + "basic", + "transform", +) + measurable_ir_rewrites_db.register( "find_measurable_transforms", find_measurable_transforms, From a96b6c45f4104db553a2bea36c1da345c18f01fc Mon Sep 17 00:00:00 2001 From: Luke LB Date: Tue, 11 Jul 2023 18:33:23 +0100 Subject: [PATCH 04/11] updated tests but failing --- pymc/logprob/transforms.py | 58 ++++++++++++++++++++++++++------ tests/logprob/test_transforms.py | 13 ++++++- 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3e99e9a234b..42f3ade1670 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -61,6 +61,8 @@ Erfc, Erfcx, Exp, + Exp2, + Expm1, Log, Log1mexp, Log1p, @@ -68,6 +70,7 @@ Log10, Mul, Pow, + Sigmoid, Sinh, Softplus, Sqr, @@ -87,6 +90,8 @@ erfc, erfcx, exp, + exp2, + expm1, log, log1mexp, log1p, @@ -96,6 +101,7 @@ neg, pow, reciprocal, + sigmoid, sinh, softplus, sqr, @@ -387,6 +393,9 @@ class MeasurableTransform(MeasurableElemwise): valid_scalar_types = ( Exp, + Exp2, + Expm1, + Sigmoid, Log, Log2, Log10, @@ -547,13 +556,16 @@ def measurable_special_log_to_log(fgraph, node): return [pt.log(1 + inp)] if isinstance(node.op.scalar_op, Softplus): return [pt.log(1 + pt.exp(inp))] - if isinstance(node.op.scalar_op, Log1p): + if isinstance(node.op.scalar_op, Log1mexp): return [pt.log(1 - pt.exp(pt.neg(inp)))] @node_rewriter( [ exp, + exp2, + expm1, + sigmoid, log, log2, log10, @@ -609,6 +621,12 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform_dict = { Exp: ExpTransform(), + Exp2: ExpTransform(base=2), + Expm1: ExpTransform(m=True), + Log: LogTransform(), + Log2: LogTransform(base=2), + Log10: LogTransform(base=10), + Sigmoid: SigmoidTransform(), Abs: AbsTransform(), Sinh: SinhTransform(), Cosh: CoshTransform(), @@ -638,18 +656,11 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform = LocTransform( transform_args_fn=lambda *inputs: inputs[-1], ) - elif isinstance(scalar_op, Log): - transform = LogTransform() - elif isinstance(scalar_op, Log2): - transform = LogTransform(base=2) - elif isinstance(scalar_op, Log10): - transform = LogTransform(base=10) elif transform is None: transform_inputs = (measurable_input, pt.mul(*other_inputs)) transform = ScaleTransform( transform_args_fn=lambda *inputs: inputs[-1], ) - transform_op = MeasurableTransform( scalar_op=scalar_op, transform=transform, @@ -883,14 +894,25 @@ def log_jac_det(self, value, *inputs): class ExpTransform(RVTransform): name = "exp" + def __init__(self, base=pt.exp(1), m=False): + self.base = base + self.m = m + super().__init__() + def forward(self, value, *inputs): - return pt.exp(value) + if self.m: + return pt.power(self.base, value) - 1 + else: + return pt.power(self.base, value) def backward(self, value, *inputs): - return pt.log(value) + if self.m: + return pt.log(value + 1) + else: + return pt.log(value) / pt.log(self.base) def log_jac_det(self, value, *inputs): - return -pt.log(value) + return -pt.log(value) / pt.log(self.base) class AbsTransform(RVTransform): @@ -1002,6 +1024,20 @@ def log_jac_det(self, value, *inputs): return value +class SigmoidTransform(RVTransform): + name = "Sigmoid" + + def forward(self, value, *inputs): + return pt.expit(value) + + def backward(self, value, *inputs): + return pt.log(value / (1 - value)) + + def log_jac_det(self, value, *inputs): + sigmoid_value = pt.sigmoid(value) + return pt.log(sigmoid_value * (1 - sigmoid_value)) + + class LogOddsTransform(RVTransform): name = "logodds" diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 785e2599fbb..723688c5c30 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -1034,9 +1034,16 @@ def test_multivariate_transform(shift, scale): (pt.arcsinh, ArcsinhTransform()), (pt.arccosh, ArccoshTransform()), (pt.arctanh, ArctanhTransform()), + (pt.exp2, ExpTransform(base=2)), + (pt.expm1, ExpTransform(m=True)), + (pt.log2, LogTransform(base=2)), + (pt.log10, LogTransform(base=10)), + (pt.log1p, LogTransform()), + (pt.log1mexp, LogTransform()), + (pt.log1pexp, LogTransform()), ], ) -def test_erf_logp(pt_transform, transform): +def test_transform_logp(pt_transform, transform): base_rv = pt.random.normal( 0.5, 1, name="base_rv" ) # Something not centered around 0 is usually better @@ -1069,6 +1076,10 @@ def test_erf_logp(pt_transform, transform): ArcsinhTransform(), ArccoshTransform(), ArctanhTransform(), + ExpTransform(base=2), + ExpTransform(m=True), + LogTransform(base=2), + LogTransform(base=10), ], ) def test_check_jac_det(transform): From 5513d8cd5f6cb9ab6bc13d490a2de673caf883b5 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Thu, 13 Jul 2023 18:58:26 +0100 Subject: [PATCH 05/11] 3 tests failing --- pymc/logprob/transforms.py | 10 ++++++++-- tests/logprob/test_transforms.py | 11 +++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 42f3ade1670..ef42e4fadb6 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -888,7 +888,10 @@ def backward(self, value, *inputs): return pt.power(self.base, value) def log_jac_det(self, value, *inputs): - return value + if self.base == pt.exp(1): + return value + else: + return pt.log(pt.power(self.base, value) * pt.log(self.base)) class ExpTransform(RVTransform): @@ -912,7 +915,10 @@ def backward(self, value, *inputs): return pt.log(value) / pt.log(self.base) def log_jac_det(self, value, *inputs): - return -pt.log(value) / pt.log(self.base) + if self.m: + return -pt.log(value + 1) + else: + return -pt.log(value) - pt.log(pt.log(self.base)) class AbsTransform(RVTransform): diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 723688c5c30..741f21cb503 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -72,7 +72,8 @@ TransformValuesRewrite, transformed_variable, ) -from pymc.testing import assert_no_rvs +from pymc.testing import Rplusbig, Vector, assert_no_rvs +from tests.distributions.test_transform import check_jacobian_det class DirichletScipyDist: @@ -1038,9 +1039,6 @@ def test_multivariate_transform(shift, scale): (pt.expm1, ExpTransform(m=True)), (pt.log2, LogTransform(base=2)), (pt.log10, LogTransform(base=10)), - (pt.log1p, LogTransform()), - (pt.log1mexp, LogTransform()), - (pt.log1pexp, LogTransform()), ], ) def test_transform_logp(pt_transform, transform): @@ -1060,10 +1058,6 @@ def test_transform_logp(pt_transform, transform): ) -from pymc.testing import Rplusbig, Vector -from tests.distributions.test_transform import check_jacobian_det - - @pytest.mark.parametrize( "transform", [ @@ -1080,6 +1074,7 @@ def test_transform_logp(pt_transform, transform): ExpTransform(m=True), LogTransform(base=2), LogTransform(base=10), + ExpTransform(), ], ) def test_check_jac_det(transform): From e51b14660974bf60470dd06c617e0c198b11545a Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sun, 16 Jul 2023 23:27:56 +0100 Subject: [PATCH 06/11] implemented rewrites and reverted Transform classes --- pymc/logprob/transforms.py | 83 ++++++++++++-------------------- tests/logprob/test_transforms.py | 8 --- 2 files changed, 31 insertions(+), 60 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index ef42e4fadb6..240dd86eb1f 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -547,7 +547,7 @@ def measurable_sub_to_neg(fgraph, node): return [pt.add(minuend, pt.neg(subtrahend))] -@node_rewriter([log1p, softplus, log1mexp]) +@node_rewriter([log1p, softplus, log1mexp, log2, log10]) def measurable_special_log_to_log(fgraph, node): """Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form.""" [inp] = node.inputs @@ -558,17 +558,29 @@ def measurable_special_log_to_log(fgraph, node): return [pt.log(1 + pt.exp(inp))] if isinstance(node.op.scalar_op, Log1mexp): return [pt.log(1 - pt.exp(pt.neg(inp)))] + if isinstance(node.op.scalar_op, Log2): + return [pt.log(inp) / pt.log(2)] + if isinstance(node.op.scalar_op, Log10): + return [pt.log(inp) / pt.log(10)] + + +@node_rewriter([exp2, expm1, sigmoid]) +def measurable_special_exp_to_exp(fgraph, node): + """Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form.""" + [inp] = node.inputs + + if isinstance(node.op.scalar_op, Exp2): + return [pt.power(2, inp)] + if isinstance(node.op.scalar_op, Expm1): + return [pt.exp(inp + 1)] + if isinstance(node.op.scalar_op, Sigmoid): + return [1 / (1 + pt.exp(-inp))] @node_rewriter( [ exp, - exp2, - expm1, - sigmoid, log, - log2, - log10, add, mul, pow, @@ -621,12 +633,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li transform_dict = { Exp: ExpTransform(), - Exp2: ExpTransform(base=2), - Expm1: ExpTransform(m=True), Log: LogTransform(), - Log2: LogTransform(base=2), - Log10: LogTransform(base=10), - Sigmoid: SigmoidTransform(), Abs: AbsTransform(), Sinh: SinhTransform(), Cosh: CoshTransform(), @@ -715,6 +722,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) +measurable_ir_rewrites_db.register( + "measurable_special_exp_to_exp", + measurable_special_exp_to_exp, + "basic", + "transform", +) + measurable_ir_rewrites_db.register( "find_measurable_transforms", find_measurable_transforms, @@ -877,48 +891,27 @@ def log_jac_det(self, value, *inputs): class LogTransform(RVTransform): name = "log" - def __init__(self, base=pt.exp(1)): - self.base = base - super().__init__() - def forward(self, value, *inputs): - return pt.log(value) / pt.log(self.base) + return pt.log(value) def backward(self, value, *inputs): - return pt.power(self.base, value) + return pt.exp(value) def log_jac_det(self, value, *inputs): - if self.base == pt.exp(1): - return value - else: - return pt.log(pt.power(self.base, value) * pt.log(self.base)) + return value class ExpTransform(RVTransform): name = "exp" - def __init__(self, base=pt.exp(1), m=False): - self.base = base - self.m = m - super().__init__() - def forward(self, value, *inputs): - if self.m: - return pt.power(self.base, value) - 1 - else: - return pt.power(self.base, value) + return pt.exp(value) def backward(self, value, *inputs): - if self.m: - return pt.log(value + 1) - else: - return pt.log(value) / pt.log(self.base) + return pt.log(value) def log_jac_det(self, value, *inputs): - if self.m: - return -pt.log(value + 1) - else: - return -pt.log(value) - pt.log(pt.log(self.base)) + return -pt.log(value) class AbsTransform(RVTransform): @@ -1030,20 +1023,6 @@ def log_jac_det(self, value, *inputs): return value -class SigmoidTransform(RVTransform): - name = "Sigmoid" - - def forward(self, value, *inputs): - return pt.expit(value) - - def backward(self, value, *inputs): - return pt.log(value / (1 - value)) - - def log_jac_det(self, value, *inputs): - sigmoid_value = pt.sigmoid(value) - return pt.log(sigmoid_value * (1 - sigmoid_value)) - - class LogOddsTransform(RVTransform): name = "logodds" diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 741f21cb503..caa1ac60839 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -1035,10 +1035,6 @@ def test_multivariate_transform(shift, scale): (pt.arcsinh, ArcsinhTransform()), (pt.arccosh, ArccoshTransform()), (pt.arctanh, ArctanhTransform()), - (pt.exp2, ExpTransform(base=2)), - (pt.expm1, ExpTransform(m=True)), - (pt.log2, LogTransform(base=2)), - (pt.log10, LogTransform(base=10)), ], ) def test_transform_logp(pt_transform, transform): @@ -1070,10 +1066,6 @@ def test_transform_logp(pt_transform, transform): ArcsinhTransform(), ArccoshTransform(), ArctanhTransform(), - ExpTransform(base=2), - ExpTransform(m=True), - LogTransform(base=2), - LogTransform(base=10), ExpTransform(), ], ) From c2a71410c5c884f9ed18799a2fbcbba1bbb54e53 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Sun, 16 Jul 2023 23:35:28 +0100 Subject: [PATCH 07/11] removed additions to node rewriter --- pymc/logprob/transforms.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 240dd86eb1f..b4561e7cf7a 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -393,15 +393,7 @@ class MeasurableTransform(MeasurableElemwise): valid_scalar_types = ( Exp, - Exp2, - Expm1, - Sigmoid, Log, - Log2, - Log10, - Log1p, - Softplus, - Log1mexp, Add, Mul, Pow, From 48166a2565562c3d3a20ac010d4b3cb84de82c10 Mon Sep 17 00:00:00 2001 From: Luke Lewis-Borrell <35955390+LukeLB@users.noreply.github.com> Date: Mon, 17 Jul 2023 17:08:38 +0100 Subject: [PATCH 08/11] Update pymc/logprob/transforms.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pymc/logprob/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index b4561e7cf7a..535b554e409 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -564,7 +564,7 @@ def measurable_special_exp_to_exp(fgraph, node): if isinstance(node.op.scalar_op, Exp2): return [pt.power(2, inp)] if isinstance(node.op.scalar_op, Expm1): - return [pt.exp(inp + 1)] + return [pt.exp(inp) - 1] if isinstance(node.op.scalar_op, Sigmoid): return [1 / (1 + pt.exp(-inp))] From 85a186597e995dffabac8fb6690fb4a374fa5a1a Mon Sep 17 00:00:00 2001 From: Luke LB Date: Mon, 17 Jul 2023 18:39:09 +0100 Subject: [PATCH 09/11] generalised expotential function added --- pymc/logprob/transforms.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 535b554e409..93dc4050c96 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -556,19 +556,29 @@ def measurable_special_log_to_log(fgraph, node): return [pt.log(inp) / pt.log(10)] -@node_rewriter([exp2, expm1, sigmoid]) +@node_rewriter([expm1, sigmoid]) def measurable_special_exp_to_exp(fgraph, node): - """Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form.""" + """Convert expm1, sigmoid of `MeasurableVariable`s to xp form.""" [inp] = node.inputs - - if isinstance(node.op.scalar_op, Exp2): - return [pt.power(2, inp)] if isinstance(node.op.scalar_op, Expm1): return [pt.exp(inp) - 1] if isinstance(node.op.scalar_op, Sigmoid): return [1 / (1 + pt.exp(-inp))] +@node_rewriter([exp2, pow]) +def measurable_general_exp_to_exp(fgraph, node): + """Convert exp2 and any const^x of `MeasurableVariable`s to exp form.""" + if len(node.inputs) > 1: + [const, inp] = node.inputs + else: + [inp] = node.inputs + if isinstance(node.op.scalar_op, Exp2): + return [pt.exp(pt.log(2) * inp)] + if isinstance(node.op.scalar_op, Pow) and isinstance(inp, pt.TensorVariable): + return [pt.exp(pt.log(const) * inp)] + + @node_rewriter( [ exp, @@ -721,6 +731,13 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) +measurable_ir_rewrites_db.register( + "measurable_general_exp_to_exp", + measurable_general_exp_to_exp, + "basic", + "transform", +) + measurable_ir_rewrites_db.register( "find_measurable_transforms", find_measurable_transforms, From 7e4909af8f62c72977a9f62d3d89feec11d28a32 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Mon, 17 Jul 2023 19:03:04 +0100 Subject: [PATCH 10/11] reverted back to just allowing exp2 --- pymc/logprob/transforms.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 93dc4050c96..aa057669e91 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -541,7 +541,7 @@ def measurable_sub_to_neg(fgraph, node): @node_rewriter([log1p, softplus, log1mexp, log2, log10]) def measurable_special_log_to_log(fgraph, node): - """Convert log1p, log1mexp, softplus of `MeasurableVariable`s to log form.""" + """Convert log1p, log1mexp, softplus, log2, log10 of `MeasurableVariable`s to log form.""" [inp] = node.inputs if isinstance(node.op.scalar_op, Log1p): @@ -556,29 +556,18 @@ def measurable_special_log_to_log(fgraph, node): return [pt.log(inp) / pt.log(10)] -@node_rewriter([expm1, sigmoid]) +@node_rewriter([expm1, sigmoid, exp2]) def measurable_special_exp_to_exp(fgraph, node): - """Convert expm1, sigmoid of `MeasurableVariable`s to xp form.""" + """Convert expm1, sigmoid, and exp2 of `MeasurableVariable`s to xp form.""" [inp] = node.inputs + if isinstance(node.op.scalar_op, Exp2): + return [pt.exp(pt.log(2) * inp)] if isinstance(node.op.scalar_op, Expm1): return [pt.exp(inp) - 1] if isinstance(node.op.scalar_op, Sigmoid): return [1 / (1 + pt.exp(-inp))] -@node_rewriter([exp2, pow]) -def measurable_general_exp_to_exp(fgraph, node): - """Convert exp2 and any const^x of `MeasurableVariable`s to exp form.""" - if len(node.inputs) > 1: - [const, inp] = node.inputs - else: - [inp] = node.inputs - if isinstance(node.op.scalar_op, Exp2): - return [pt.exp(pt.log(2) * inp)] - if isinstance(node.op.scalar_op, Pow) and isinstance(inp, pt.TensorVariable): - return [pt.exp(pt.log(const) * inp)] - - @node_rewriter( [ exp, From 4bc507799a7225bd4396289f54d6dc69e72c6042 Mon Sep 17 00:00:00 2001 From: Luke LB Date: Mon, 17 Jul 2023 19:04:27 +0100 Subject: [PATCH 11/11] removed measurable rewrite --- pymc/logprob/transforms.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index aa057669e91..25b0e0b1dbd 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -720,12 +720,6 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li "transform", ) -measurable_ir_rewrites_db.register( - "measurable_general_exp_to_exp", - measurable_general_exp_to_exp, - "basic", - "transform", -) measurable_ir_rewrites_db.register( "find_measurable_transforms",