Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Media transformation sampling & plotting methods #734

Merged
merged 12 commits into from
Jun 12, 2024
85 changes: 85 additions & 0 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

.. code-block:: python

from pymc_marketing.mmm import AdstockTransformation

class MyAdstock(AdstockTransformation):
def function(self, x, alpha):
return x * alpha
Expand All @@ -35,6 +37,9 @@ def function(self, x, alpha):

import warnings

import numpy as np
import xarray as xr

from pymc_marketing.mmm.components.base import Transformation
from pymc_marketing.mmm.transformers import (
ConvMode,
Expand Down Expand Up @@ -74,12 +79,62 @@ def __init__(

super().__init__(priors=priors, prefix=prefix)

def sample_curve(
self,
parameters: xr.Dataset,
amount: float = 1.0,
) -> xr.DataArray:
"""Sample the adstock transformation given parameters.

Parameters
----------
parameters : xr.Dataset
Dataset with parameter values.
amount : float, optional
Amount to apply the adstock transformation to, by default 1.0.

Returns
-------
xr.DataArray
Adstocked version of the amount.

"""

time_since = np.arange(0, self.l_max)
coords = {
"time since exposure": time_since,
}
x = np.zeros(self.l_max)
x[0] = amount

return self._sample_curve(
var_name="adstock",
parameters=parameters,
x=x,
coords=coords,
)


class GeometricAdstock(AdstockTransformation):
"""Wrapper around geometric adstock function.

For more information, see :func:`pymc_marketing.mmm.transformers.geometric_adstock`.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import GeometricAdstock

rng = np.random.default_rng(0)

adstock = GeometricAdstock(l_max=10)
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()

"""

lookup_name = "geometric"
Expand All @@ -97,6 +152,21 @@ class DelayedAdstock(AdstockTransformation):

For more information, see :func:`pymc_marketing.mmm.transformers.delayed_adstock`.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import DelayedAdstock

rng = np.random.default_rng(0)

adstock = DelayedAdstock(l_max=10)
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()

"""

lookup_name = "delayed"
Expand All @@ -122,6 +192,21 @@ class WeibullAdstock(AdstockTransformation):

For more information, see :func:`pymc_marketing.mmm.transformers.weibull_adstock`.

.. plot::
:context: close-figs

import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import WeibullAdstock

rng = np.random.default_rng(0)

adstock = WeibullAdstock(l_max=10, kind="CDF")
prior = adstock.sample_prior(random_seed=rng)
curve = adstock.sample_curve(prior)
adstock.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()

"""

lookup_name = "weibull"
Expand Down
Loading
Loading