Skip to content

Commit

Permalink
Enforce check_parameters for alpha in geometric_adstock (pymc-l…
Browse files Browse the repository at this point in the history
…abs#960)

* Enforce check_parameters in geometric_adstock

* Remove check_parameters on l_max. Revert to original test

* Add test_geometric_adstock_bad_alpha

* use ge, le instead of gt,lt

* Update tests/mmm/test_transformers.py

Co-authored-by: Will Dean <[email protected]>

* Simplify test parameters

* Update tests/mmm/test_transformers.py

Co-authored-by: Will Dean <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: Will Dean <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 22, 2024
1 parent 1536dad commit b478d9d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
5 changes: 5 additions & 0 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.dist_math import check_parameters
from pytensor.tensor.random.utils import params_broadcast_shapes


Expand Down Expand Up @@ -235,6 +236,10 @@ def geometric_adstock(
with carryover and shape effects." (2017).
"""
alpha = check_parameters(
alpha, [pt.ge(alpha, 0), pt.le(alpha, 1)], msg="0 <= alpha <= 1"
)

w = pt.power(pt.as_tensor(alpha)[..., None], pt.arange(l_max, dtype=x.dtype))
w = w / pt.sum(w, axis=-1, keepdims=True) if normalize else w
return batched_convolution(x, w, axis=axis, mode=mode)
Expand Down
18 changes: 18 additions & 0 deletions tests/mmm/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytensor.tensor as pt
import pytest
import scipy as sp
from pymc.logprob.utils import ParameterValueError
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.mmm.transformers import (
Expand Down Expand Up @@ -148,6 +149,23 @@ def test_geometric_adstock_good_alpha(self, x, alpha, l_max):
assert y_np[1] == x[1] + alpha * x[0]
assert y_np[2] == x[2] + alpha * x[1] + (alpha**2) * x[0]

@pytest.mark.parametrize(
"alpha",
[-0.3, -2, 22.5, 2],
ids=[
"less_than_zero_0",
"less_than_zero_1",
"greater_than_one_0",
"greater_than_one_1",
],
)
def test_geometric_adstock_bad_alpha(self, alpha):
l_max = 10
x = np.ones(shape=100)
y = geometric_adstock(x=x, alpha=alpha, l_max=l_max)
with pytest.raises(ParameterValueError):
y.eval()

@pytest.mark.parametrize(
argnames="mode",
argvalues=[ConvMode.After, ConvMode.Before, ConvMode.Overlap],
Expand Down

0 comments on commit b478d9d

Please sign in to comment.