diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index f77f3bf1..1e163699 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -14,7 +14,7 @@ class ConvMode(str, Enum): Overlap = "Overlap" -def batched_convolution(x, w, axis: int = 0, mode: ConvMode | str = ConvMode.Before): +def batched_convolution(x, w, axis: int = 0, mode: ConvMode | str = ConvMode.After): R"""Apply a 1D convolution in a vectorized way across multiple batch dimensions. .. plot:: @@ -99,9 +99,9 @@ def batched_convolution(x, w, axis: int = 0, mode: ConvMode | str = ConvMode.Bef # The window is the slice of the padded array that corresponds to the original x if l_max <= 1: window = slice(None) - elif mode == ConvMode.After: - window = slice(l_max - 1, None) elif mode == ConvMode.Before: + window = slice(l_max - 1, None) + elif mode == ConvMode.After: window = slice(None, -l_max + 1) elif mode == ConvMode.Overlap: # Handle even and odd l_max differently if l_max is odd then we can split evenly otherwise we drop from the end @@ -186,7 +186,7 @@ def geometric_adstock( w = pt.power(pt.as_tensor(alpha)[..., None], pt.arange(l_max, dtype=x.dtype)) w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w - return batched_convolution(x, w, axis=axis) + return batched_convolution(x, w, axis=axis, mode=ConvMode.After) def delayed_adstock( @@ -258,7 +258,7 @@ def delayed_adstock( (pt.arange(l_max, dtype=x.dtype) - pt.as_tensor(theta)[..., None]) ** 2, ) w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w - return batched_convolution(x, w, axis=axis) + return batched_convolution(x, w, axis=axis, mode=ConvMode.After) def logistic_saturation(x, lam: Union[npt.NDArray[np.float_], float] = 0.5): diff --git a/tests/mmm/test_transformers.py b/tests/mmm/test_transformers.py index cf2336e2..0d4c5f06 100644 --- a/tests/mmm/test_transformers.py +++ b/tests/mmm/test_transformers.py @@ -67,10 +67,10 @@ def test_batched_convolution(convolution_inputs, convolution_axis, mode): x_val = np.moveaxis( x_val if x_val is not None else getattr(x, "value", x), convolution_axis, 0 ) - if mode == ConvMode.Before: + if mode == ConvMode.After: np.testing.assert_allclose(y_val[0], x_val[0]) np.testing.assert_allclose(y_val[1:], x_val[1:] + x_val[:-1]) - elif mode == ConvMode.After: + elif mode == ConvMode.Before: np.testing.assert_allclose(y_val[-1], x_val[-1]) np.testing.assert_allclose(y_val[:-1], x_val[1:] + x_val[:-1]) elif mode == ConvMode.Overlap: