Skip to content

Commit

Permalink
Swap before and after and make mode explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
abdalazizrashid committed Jan 26, 2024
1 parent 40d77ba commit e7ac756
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ConvMode(str, Enum):
Overlap = "Overlap"


def batched_convolution(x, w, axis: int = 0, mode: ConvMode | str = ConvMode.Before):
def batched_convolution(x, w, axis: int = 0, mode: ConvMode | str = ConvMode.After):
R"""Apply a 1D convolution in a vectorized way across multiple batch dimensions.
.. plot::
Expand Down Expand Up @@ -99,9 +99,9 @@ def batched_convolution(x, w, axis: int = 0, mode: ConvMode | str = ConvMode.Bef
# The window is the slice of the padded array that corresponds to the original x
if l_max <= 1:
window = slice(None)
elif mode == ConvMode.After:
window = slice(l_max - 1, None)
elif mode == ConvMode.Before:
window = slice(l_max - 1, None)
elif mode == ConvMode.After:
window = slice(None, -l_max + 1)
elif mode == ConvMode.Overlap:
# Handle even and odd l_max differently if l_max is odd then we can split evenly otherwise we drop from the end
Expand Down Expand Up @@ -186,7 +186,7 @@ def geometric_adstock(

w = pt.power(pt.as_tensor(alpha)[..., None], pt.arange(l_max, dtype=x.dtype))
w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w
return batched_convolution(x, w, axis=axis)
return batched_convolution(x, w, axis=axis, mode=ConvMode.After)


def delayed_adstock(
Expand Down Expand Up @@ -258,7 +258,7 @@ def delayed_adstock(
(pt.arange(l_max, dtype=x.dtype) - pt.as_tensor(theta)[..., None]) ** 2,
)
w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w
return batched_convolution(x, w, axis=axis)
return batched_convolution(x, w, axis=axis, mode=ConvMode.After)


def logistic_saturation(x, lam: Union[npt.NDArray[np.float_], float] = 0.5):
Expand Down
4 changes: 2 additions & 2 deletions tests/mmm/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def test_batched_convolution(convolution_inputs, convolution_axis, mode):
x_val = np.moveaxis(
x_val if x_val is not None else getattr(x, "value", x), convolution_axis, 0
)
if mode == ConvMode.Before:
if mode == ConvMode.After:
np.testing.assert_allclose(y_val[0], x_val[0])
np.testing.assert_allclose(y_val[1:], x_val[1:] + x_val[:-1])
elif mode == ConvMode.After:
elif mode == ConvMode.Before:
np.testing.assert_allclose(y_val[-1], x_val[-1])
np.testing.assert_allclose(y_val[:-1], x_val[1:] + x_val[:-1])
elif mode == ConvMode.Overlap:
Expand Down

0 comments on commit e7ac756

Please sign in to comment.