Implement different convolution modes (#454)
abdalazizrashid authored Dec 7, 2023
1 parent c1d8909 commit c87446d
Showing 2 changed files with 58 additions and 12 deletions.
46 changes: 38 additions & 8 deletions pymc_marketing/mmm/transformers.py
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Union

import numpy as np
@@ -6,7 +7,13 @@
from pytensor.tensor.random.utils import params_broadcast_shapes


def batched_convolution(x, w, axis: int = 0):
class ConvMode(Enum):
After = "After"
Before = "Before"
Overlap = "Overlap"


def batched_convolution(x, w, axis: int = 0, mode: ConvMode = ConvMode.Before):
"""Apply a 1D convolution in a vectorized way across multiple batch dimensions.
Parameters
@@ -18,6 +25,12 @@ def batched_convolution(x, w, axis: int = 0):
to use in the convolution.
axis : int
The axis of ``x`` along which to apply the convolution
mode : ConvMode, optional
The convolution mode determines how the convolution handles the boundaries of the input signal ``x``. The default is ConvMode.Before.
- ConvMode.Before: each element's effect trails after it, a carryover ("Adstock") effect; this preserves the previous default behavior.
- ConvMode.After: each element's effect leads ahead of it, an anticipation ("Excitement") effect.
- ConvMode.Overlap: each element's effect overlaps both the preceding and succeeding elements ("Pull-Forward" and "Pull-Backward").
Returns
-------
@@ -43,22 +56,39 @@ def batched_convolution(x, w, axis: int = 0):
# The last dimension of x is the "time" axis, which doesn't get broadcast
# The last dimension of w is the number of time steps that go into the convolution
x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])

x = pt.broadcast_to(x, x_shape)
w = pt.broadcast_to(w, w_shape)
x_time = x.shape[-1]
shape = (*x.shape, w.shape[-1])
# Make a tensor with x at the different time lags needed for the convolution
x_shape = x.shape
# Add the size of the kernel to the time axis
shape = (*x_shape[:-1], x_shape[-1] + w.shape[-1] - 1, w.shape[-1])
padded_x = pt.zeros(shape, dtype=x.dtype)
if l_max is not None:
for i in range(l_max):
padded_x = pt.set_subtensor(
padded_x[..., i:x_time, i], x[..., : x_time - i]
)
else: # pragma: no cover

if l_max is None: # pragma: no cover
raise NotImplementedError(
"At the moment, convolving with weight arrays that don't have a concrete shape "
"at compile time is not supported."
)
# The window is the slice of the padded array that corresponds to the original x
if l_max <= 1:
window = slice(None)
elif mode == ConvMode.After:
window = slice(l_max - 1, None)
elif mode == ConvMode.Before:
window = slice(None, -l_max + 1)
elif mode == ConvMode.Overlap:
# Center the window: an odd l_max trims evenly from both sides,
# while an even l_max drops the extra element from the end
window = slice((l_max // 2) - (1 if l_max % 2 == 0 else 0), -(l_max // 2))
else:
raise ValueError(f"Wrong Mode: {mode}, expected of ConvMode")

for i in range(l_max):
padded_x = pt.set_subtensor(padded_x[..., i : x_time + i, i], x)

padded_x = padded_x[..., window, :]

# The convolution is treated as an element-wise product, that then gets reduced
# along the dimension that represents the convolution time lags
conv = pt.sum(padded_x * w[..., None, :], axis=-1)
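For intuition, here is a minimal sketch of the new mode argument. It is not part of the commit; the example arrays and the outputs in the comments are illustrative values worked out from the window slices above, and they agree with the updated tests below.

import numpy as np

from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

x = np.zeros(5)
x[2] = 1.0                      # a single impulse at t=2
w = np.array([1.0, 0.5, 0.25])  # three-step kernel, so l_max = 3

# Before (default): the effect trails the impulse (carryover)
batched_convolution(x, w, mode=ConvMode.Before).eval()   # [0., 0., 1., 0.5, 0.25]
# After: the effect leads the impulse (anticipation)
batched_convolution(x, w, mode=ConvMode.After).eval()    # [1., 0.5, 0.25, 0., 0.]
# Overlap: the effect straddles the impulse on both sides
batched_convolution(x, w, mode=ConvMode.Overlap).eval()  # [0., 1., 0.5, 0.25, 0.]

# Batch dimensions broadcast as before; only the time axis is convolved:
x_batch = np.broadcast_to(x, (4, 5))
batched_convolution(x_batch, w, axis=1, mode=ConvMode.Before).eval().shape  # (4, 5)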
24 changes: 20 additions & 4 deletions tests/mmm/test_transformers.py
@@ -5,6 +5,7 @@
from pytensor.tensor.var import TensorVariable

from pymc_marketing.mmm.transformers import (
ConvMode,
batched_convolution,
delayed_adstock,
geometric_adstock,
@@ -51,9 +52,10 @@ def convolution_axis(request):
return request.param


def test_batched_convolution(convolution_inputs, convolution_axis):
@pytest.mark.parametrize("mode", [ConvMode.After, ConvMode.Before, ConvMode.Overlap])
def test_batched_convolution(convolution_inputs, convolution_axis, mode):
x, w, x_val, w_val = convolution_inputs
y = batched_convolution(x, w, convolution_axis)
y = batched_convolution(x, w, convolution_axis, mode)
if x_val is None:
y_val = y.eval()
expected_shape = getattr(x, "value", x).shape
@@ -65,8 +67,22 @@ def test_batched_convolution(convolution_inputs, convolution_axis):
x_val = np.moveaxis(
x_val if x_val is not None else getattr(x, "value", x), convolution_axis, 0
)
assert np.allclose(y_val[0], x_val[0])
assert np.allclose(y_val[1:], x_val[1:] + x_val[:-1])
if mode == ConvMode.Before:
np.testing.assert_allclose(y_val[0], x_val[0])
np.testing.assert_allclose(y_val[1:], x_val[1:] + x_val[:-1])
elif mode == ConvMode.After:
np.testing.assert_allclose(y_val[-1], x_val[-1])
np.testing.assert_allclose(y_val[:-1], x_val[1:] + x_val[:-1])
elif mode == ConvMode.Overlap:
np.testing.assert_allclose(y_val[0], x_val[0])
np.testing.assert_allclose(y_val[1:-1], x_val[1:-1] + x_val[:-2])


def test_batched_convolution_invalid_mode(convolution_inputs, convolution_axis):
x, w, x_val, w_val = convolution_inputs
invalid_mode = "InvalidMode"
with pytest.raises(ValueError):
batched_convolution(x, w, convolution_axis, invalid_mode)


def test_batched_convolution_broadcasting():
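As an independent sanity check, also not part of the commit: each mode's window is just a slice of NumPy's full convolution, so the expectations in the tests above can be reproduced without the padded-tensor machinery. The arrays here are again illustrative.

import numpy as np

from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

x = np.random.default_rng(0).normal(size=7)
w = np.array([1.0, 0.5, 0.25])
l_max = len(w)

full = np.convolve(x, w, mode="full")  # length len(x) + l_max - 1

before = full[: len(x)]      # ConvMode.Before keeps the leading slice
after = full[l_max - 1 :]    # ConvMode.After keeps the trailing slice
overlap = full[(l_max // 2) - (1 if l_max % 2 == 0 else 0) : -(l_max // 2)]

np.testing.assert_allclose(batched_convolution(x, w, mode=ConvMode.Before).eval(), before)
np.testing.assert_allclose(batched_convolution(x, w, mode=ConvMode.After).eval(), after)
np.testing.assert_allclose(batched_convolution(x, w, mode=ConvMode.Overlap).eval(), overlap)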
