diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py
index caa47a4d..d469769f 100644
--- a/pymc_marketing/mmm/transformers.py
+++ b/pymc_marketing/mmm/transformers.py
@@ -1,3 +1,4 @@
+from enum import Enum
 from typing import Union
 
 import numpy as np
@@ -6,7 +7,13 @@
 from pytensor.tensor.random.utils import params_broadcast_shapes
 
 
-def batched_convolution(x, w, axis: int = 0):
+class ConvMode(Enum):
+    After = "After"
+    Before = "Before"
+    Overlap = "Overlap"
+
+
+def batched_convolution(x, w, axis: int = 0, mode: ConvMode = ConvMode.Before):
     """Apply a 1D convolution in a vectorized way across multiple batch dimensions.
 
     Parameters
@@ -18,6 +25,12 @@
         to use in the convolution.
     axis : int
        The axis of ``x`` along witch to apply the convolution
+    mode : ConvMode, optional
+        The convolution mode determines how the convolution is applied at the boundaries of the input signal, denoted as ``x``. The default mode is ConvMode.Before.
+
+        - ConvMode.After: Applies the convolution with the "Adstock" effect, resulting in a trailing decay effect.
+        - ConvMode.Before: Applies the convolution with the "Excitement" effect, creating a leading effect similar to the wow factor.
+        - ConvMode.Overlap: Applies the convolution with both "Pull-Forward" and "Pull-Backward" effects, where the effect overlaps with both preceding and succeeding elements.
 
     Returns
     -------
@@ -43,22 +56,39 @@
     # The last dimension of x is the "time" axis, which doesn't get broadcast
     # The last dimension of w is the number of time steps that go into the convolution
     x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])
+
     x = pt.broadcast_to(x, x_shape)
     w = pt.broadcast_to(w, w_shape)
     x_time = x.shape[-1]
-    shape = (*x.shape, w.shape[-1])
     # Make a tensor with x at the different time lags needed for the convolution
+    x_shape = x.shape
+    # Add the size of the kernel to the time axis
+    shape = (*x_shape[:-1], x_shape[-1] + w.shape[-1] - 1, w.shape[-1])
     padded_x = pt.zeros(shape, dtype=x.dtype)
-    if l_max is not None:
-        for i in range(l_max):
-            padded_x = pt.set_subtensor(
-                padded_x[..., i:x_time, i], x[..., : x_time - i]
-            )
-    else:  # pragma: no cover
+
+    if l_max is None:  # pragma: no cover
         raise NotImplementedError(
             "At the moment, convolving with weight arrays that don't have a concrete shape "
             "at compile time is not supported."
         )
+    # The window is the slice of the padded array that corresponds to the original x
+    if l_max <= 1:
+        window = slice(None)
+    elif mode == ConvMode.After:
+        window = slice(l_max - 1, None)
+    elif mode == ConvMode.Before:
+        window = slice(None, -l_max + 1)
+    elif mode == ConvMode.Overlap:
+        # Handle even and odd l_max differently: if l_max is odd the window splits evenly, otherwise we drop one extra element from the end
+        window = slice((l_max // 2) - (1 if l_max % 2 == 0 else 0), -(l_max // 2))
+    else:
+        raise ValueError(f"Wrong Mode: {mode}, expected a member of ConvMode")
+
+    for i in range(l_max):
+        padded_x = pt.set_subtensor(padded_x[..., i : x_time + i, i], x)
+
+    padded_x = padded_x[..., window, :]
+
     # The convolution is treated as an element-wise product, that then gets reduced
     # along the dimension that represents the convolution time lags
     conv = pt.sum(padded_x * w[..., None, :], axis=-1)
diff --git a/tests/mmm/test_transformers.py b/tests/mmm/test_transformers.py
index 2870e715..cf2336e2 100644
--- a/tests/mmm/test_transformers.py
+++ b/tests/mmm/test_transformers.py
@@ -5,6 +5,7 @@
 from pytensor.tensor.var import TensorVariable
 
 from pymc_marketing.mmm.transformers import (
+    ConvMode,
     batched_convolution,
     delayed_adstock,
     geometric_adstock,
@@ -51,9 +52,10 @@
     return request.param
 
 
-def test_batched_convolution(convolution_inputs, convolution_axis):
+@pytest.mark.parametrize("mode", [ConvMode.After, ConvMode.Before, ConvMode.Overlap])
+def test_batched_convolution(convolution_inputs, convolution_axis, mode):
     x, w, x_val, w_val = convolution_inputs
-    y = batched_convolution(x, w, convolution_axis)
+    y = batched_convolution(x, w, convolution_axis, mode)
     if x_val is None:
         y_val = y.eval()
         expected_shape = getattr(x, "value", x).shape
@@ -65,8 +67,22 @@
     x_val = np.moveaxis(
         x_val if x_val is not None else getattr(x, "value", x), convolution_axis, 0
     )
-    assert np.allclose(y_val[0], x_val[0])
-    assert np.allclose(y_val[1:], x_val[1:] + x_val[:-1])
+    if mode == ConvMode.Before:
+        np.testing.assert_allclose(y_val[0], x_val[0])
+        np.testing.assert_allclose(y_val[1:], x_val[1:] + x_val[:-1])
+    elif mode == ConvMode.After:
+        np.testing.assert_allclose(y_val[-1], x_val[-1])
+        np.testing.assert_allclose(y_val[:-1], x_val[1:] + x_val[:-1])
+    elif mode == ConvMode.Overlap:
+        np.testing.assert_allclose(y_val[0], x_val[0])
+        np.testing.assert_allclose(y_val[1:-1], x_val[1:-1] + x_val[:-2])
+
+
+def test_batched_convolution_invalid_mode(convolution_inputs, convolution_axis):
+    x, w, x_val, w_val = convolution_inputs
+    invalid_mode = "InvalidMode"
+    with pytest.raises(ValueError):
+        batched_convolution(x, w, convolution_axis, invalid_mode)
 
 
 def test_batched_convolution_broadcasting():
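
A minimal usage sketch of the new mode argument once this patch is applied. The inputs and the expected outputs in the comments are illustrative only, worked out by hand from the window slices introduced above (ConvMode.Before keeps the start of the padded convolution, ConvMode.After keeps the end, ConvMode.Overlap keeps a centered window); they are not part of the diff.

    import numpy as np

    from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

    # Hypothetical example data, chosen only to make the boundary handling visible.
    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    w = np.ones(3)  # three-step kernel, so l_max == 3

    # Keeps the first len(x) positions of the padded convolution.
    print(batched_convolution(x, w, mode=ConvMode.Before).eval())   # [ 1.  3.  6.  9. 12.]
    # Keeps the last len(x) positions of the padded convolution.
    print(batched_convolution(x, w, mode=ConvMode.After).eval())    # [ 6.  9. 12.  9.  5.]
    # Keeps a centered window of the padded convolution.
    print(batched_convolution(x, w, mode=ConvMode.Overlap).eval())  # [ 3.  6.  9. 12.  9.]

All three calls return an array with the same length as x; the modes differ only in which slice of the length len(x) + l_max - 1 padded convolution is kept.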