Implement different convolution modes #454

Merged
merged 9 commits into from Dec 7, 2023
Changes from 6 commits
47 changes: 39 additions & 8 deletions pymc_marketing/mmm/transformers.py
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Union

import numpy as np
@@ -6,7 +7,13 @@
from pytensor.tensor.random.utils import params_broadcast_shapes


def batched_convolution(x, w, axis: int = 0):
class ConvMode(Enum):
After = "After"
Before = "Before"
Overlap = "Overlap"


def batched_convolution(x, w, axis: int = 0, mode: ConvMode = ConvMode.Before):
"""Apply a 1D convolution in a vectorized way across multiple batch dimensions.

Parameters
@@ -18,6 +25,12 @@ def batched_convolution(x, w, axis: int = 0):
        to use in the convolution.
    axis : int
        The axis of ``x`` along which to apply the convolution.
    mode : ConvMode, optional
        The convolution mode determines how the convolution is applied at the
        boundaries of the input signal ``x``. The default mode is ConvMode.Before.

        - ConvMode.After: Applies the convolution with the "Adstock" effect,
          resulting in a trailing decay effect.
        - ConvMode.Before: Applies the convolution with the "Excitement" effect,
          creating a leading effect similar to a wow factor.
        - ConvMode.Overlap: Applies the convolution with both "Pull-Forward" and
          "Pull-Backward" effects, so that the effect overlaps with both the
          preceding and the succeeding elements.

    Returns
    -------
@@ -43,22 +56,40 @@ def batched_convolution(x, w, axis: int = 0):
    # The last dimension of x is the "time" axis, which doesn't get broadcast
    # The last dimension of w is the number of time steps that go into the convolution
    x_shape, w_shape = params_broadcast_shapes([x.shape, w.shape], [1, 1])

    x = pt.broadcast_to(x, x_shape)
    w = pt.broadcast_to(w, w_shape)
    x_time = x.shape[-1]
    shape = (*x.shape, w.shape[-1])
    # Make a tensor with x at the different time lags needed for the convolution
    x_shape = list(x.shape)
    # Add the size of the kernel to the time axis
    x_shape[-1] = x_shape[-1] + w.shape[-1] - 1
    shape = [*x_shape, w.shape[-1]]
    padded_x = pt.zeros(shape, dtype=x.dtype)
    if l_max is not None:
        for i in range(l_max):
            padded_x = pt.set_subtensor(
                padded_x[..., i:x_time, i], x[..., : x_time - i]
            )
    else:  # pragma: no cover

    if l_max is None:  # pragma: no cover
        raise NotImplementedError(
            "At the moment, convolving with weight arrays that don't have a concrete shape "
            "at compile time is not supported."
        )
    # The window is the slice of the padded array that corresponds to the original x
    if l_max <= 1:
        window = slice(None)
    elif mode == ConvMode.After:
        window = slice(None, -l_max + 1)
    elif mode == ConvMode.Before:
        window = slice(l_max - 1, None)
    elif mode == ConvMode.Overlap:
        # Handle even and odd l_max differently: when l_max is odd the excess
        # splits evenly around x, otherwise we drop one extra step from the end
        window = slice((l_max // 2) - (1 if l_max % 2 == 0 else 0), -(l_max // 2))
    else:
        raise ValueError(f"Wrong Mode: {mode}, expected one of ConvMode")

    for i in range(l_max):
        padded_x = pt.set_subtensor(padded_x[..., i : x_time + i, i], x)

    padded_x = padded_x[..., window, :]
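    # Illustration (hand-worked, hypothetical values): for x = [1, 2, 3, 4] and
    # l_max = 2 the loop above builds the padded lag matrix
    #   [[1, 0],
    #    [2, 1],
    #    [3, 2],
    #    [4, 3],
    #    [0, 4]]
    # and the window trims it back to x's length: After keeps the first four
    # rows, Before the last four, and Overlap a centered slice (the split is
    # only even when l_max is odd).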

    # The convolution is treated as an element-wise product that then gets reduced
    # along the dimension that represents the convolution time lags
    conv = pt.sum(padded_x * w[..., None, :], axis=-1)
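To see what each mode does at the boundaries, here is a minimal sketch, assuming the After/Before window mapping above (the one that matches the docstring); the expected outputs are worked out by hand for a toy series and a symmetric 3-step kernel:

```python
import numpy as np

from pymc_marketing.mmm.transformers import ConvMode, batched_convolution

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([1.0, 1.0, 1.0])  # concrete kernel, so l_max == 3 at compile time

# Trailing ("Adstock") effect: each x[t] bleeds into later time steps.
print(batched_convolution(x, w, mode=ConvMode.After).eval())    # [1. 3. 6. 9.]
# Leading ("Excitement") effect: each x[t] bleeds into earlier time steps.
print(batched_convolution(x, w, mode=ConvMode.Before).eval())   # [6. 9. 7. 4.]
# Centered effect: the kernel overlaps both sides of each time step.
print(batched_convolution(x, w, mode=ConvMode.Overlap).eval())  # [3. 6. 9. 7.]
```

Passing anything that is not a ConvMode member falls through to the ValueError branch, which the new invalid-mode test below exercises.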
31 changes: 27 additions & 4 deletions tests/mmm/test_transformers.py
@@ -5,6 +5,7 @@
from pytensor.tensor.var import TensorVariable

from pymc_marketing.mmm.transformers import (
    ConvMode,
    batched_convolution,
    delayed_adstock,
    geometric_adstock,
@@ -51,9 +52,10 @@ def convolution_axis(request):
    return request.param


def test_batched_convolution(convolution_inputs, convolution_axis):
@pytest.mark.parametrize("mode", [ConvMode.After, ConvMode.Before, ConvMode.Overlap])
def test_batched_convolution(convolution_inputs, convolution_axis, mode):
    x, w, x_val, w_val = convolution_inputs
    y = batched_convolution(x, w, convolution_axis)
    y = batched_convolution(x, w, convolution_axis, mode)
    if x_val is None:
        y_val = y.eval()
        expected_shape = getattr(x, "value", x).shape
@@ -65,8 +67,29 @@ def test_batched_convolution(convolution_inputs, convolution_axis):
    x_val = np.moveaxis(
        x_val if x_val is not None else getattr(x, "value", x), convolution_axis, 0
    )
    assert np.allclose(y_val[0], x_val[0])
    assert np.allclose(y_val[1:], x_val[1:] + x_val[:-1])
    mode_assertions = {
        ConvMode.Before: lambda: (
            np.allclose(y_val[-1], x_val[-1]),
            np.allclose(y_val[:-1], x_val[1:] + x_val[:-1]),
        ),
        ConvMode.After: lambda: (
            np.allclose(y_val[0], x_val[0]),
            np.allclose(y_val[1:], x_val[1:] + x_val[:-1]),
        ),
        ConvMode.Overlap: lambda: (
            np.allclose(y_val[0], x_val[0]),
            np.allclose(y_val[1:-1], x_val[1:-1] + x_val[:-2]),
        ),
    }

    assert all(mode_assertions[mode]())
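    # These checks pin down the alignment of each mode (assuming the fixture's
    # kernel has two unit weights, as the original assertions implied): After
    # stays aligned at the start (trailing effect), Before at the end (leading
    # effect), and Overlap is checked on the interior points only.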


def test_batched_convolution_invalid_mode(convolution_inputs, convolution_axis):
    x, w, x_val, w_val = convolution_inputs
    invalid_mode = "InvalidMode"
    with pytest.raises(ValueError):
        batched_convolution(x, w, convolution_axis, invalid_mode)


def test_batched_convolution_broadcasting():
Expand Down