Skip to content

Commit

Permalink
Save off media transformations (#882)
Browse files Browse the repository at this point in the history
* to_dict via lookup_name

* parse to and from dict for attrs

* improve the codecov

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change test with change in default behavior

* increase the MMM model version

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and twiecki committed Sep 10, 2024
1 parent 64af0ce commit 552d9c7
Show file tree
Hide file tree
Showing 8 changed files with 305 additions and 16 deletions.
40 changes: 39 additions & 1 deletion pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from pymc_marketing.prior import Prior
class MyAdstock(AdstockTransformation):
lookup_name: str = "my_adstock"
def function(self, x, alpha):
return x * alpha
Expand Down Expand Up @@ -92,7 +94,7 @@ def __init__(
True, description="Whether to normalize the adstock values."
),
mode: ConvMode = Field(ConvMode.After, description="Convolution mode."),
priors: dict[str, str | InstanceOf[Prior]] | None = Field(
priors: dict[str, InstanceOf[Prior]] | None = Field(
default=None, description="Priors for the parameters."
),
prefix: str | None = Field(None, description="Prefix for the parameters."),
Expand All @@ -103,6 +105,27 @@ def __init__(

super().__init__(priors=priors, prefix=prefix)

def __repr__(self) -> str:
    """Return a constructor-like summary of this adstock's configuration.

    The text is the runtime class name followed, in parentheses, by
    ``prefix``, ``l_max``, ``normalize``, the name of ``mode`` and
    ``function_priors`` (exposed under the label ``priors``).
    """
    class_name = self.__class__.__name__
    fields = (
        f"prefix={self.prefix!r}",
        f"l_max={self.l_max}",
        f"normalize={self.normalize}",
        f"mode={self.mode.name!r}",
        f"priors={self.function_priors}",
    )
    return f"{class_name}({', '.join(fields)})"

def to_dict(self) -> dict:
    """Serialize the adstock transformation to a dictionary.

    Returns
    -------
    dict
        The dictionary produced by the parent class's ``to_dict``,
        extended with ``"l_max"``, ``"normalize"`` and ``"mode"`` (stored
        as the name of the mode rather than the enum member itself).
    """
    serialized = super().to_dict()
    serialized.update(
        l_max=self.l_max,
        normalize=self.normalize,
        mode=self.mode.name,
    )
    return serialized

def sample_curve(
self,
parameters: xr.Dataset,
Expand Down Expand Up @@ -371,6 +394,21 @@ def function(self, x, lam, k):
}


def register_adstock_transformation(cls: type[AdstockTransformation]) -> None:
    """Add an adstock transformation class to ``ADSTOCK_TRANSFORMATIONS``.

    Parameters
    ----------
    cls : type[AdstockTransformation]
        Class to register. It is stored under its ``lookup_name``; item
        assignment is used, so registering under an already-used name
        overwrites the earlier entry.
    """
    key = cls.lookup_name
    ADSTOCK_TRANSFORMATIONS[key] = cls


def adstock_from_dict(data: dict) -> AdstockTransformation:
    """Create an adstock transformation from a dictionary.

    Parameters
    ----------
    data : dict
        Must contain ``"lookup_name"``, which selects the class in
        ``ADSTOCK_TRANSFORMATIONS``. An optional ``"priors"`` mapping has
        each of its values passed through ``Prior.from_json``. Every
        remaining item is forwarded to the class as a keyword argument.

    Returns
    -------
    AdstockTransformation
        Instance of the class registered under ``data["lookup_name"]``.

    Raises
    ------
    KeyError
        If ``"lookup_name"`` is missing from ``data`` or is not a key of
        ``ADSTOCK_TRANSFORMATIONS``.
    """
    # Shallow copy so that popping "lookup_name" and replacing "priors"
    # never touches the caller's dictionary.
    data = data.copy()
    lookup_name = data.pop("lookup_name")
    cls = ADSTOCK_TRANSFORMATIONS[lookup_name]

    # "priors" is optional, like every other constructor argument: when it
    # is absent or None, nothing is deserialized and the constructor's own
    # handling of a missing ``priors`` applies.
    serialized_priors = data.get("priors")
    if serialized_priors is not None:
        data["priors"] = {
            name: Prior.from_json(spec) for name, spec in serialized_priors.items()
        }
    return cls(**data)


def _get_adstock_function(
function: str | AdstockTransformation,
**kwargs,
Expand Down
44 changes: 40 additions & 4 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,10 @@ class Transformation:
Parameters
----------
priors : dict, optional
priors : dict[str, Prior], optional
Dictionary with the priors for the parameters of the function. The keys should be the
parameter names and the values should be dictionaries with the distribution and kwargs.
parameter names and the values the priors. If not provided, it will use the default
priors from the subclass.
prefix : str, optional
The prefix for the variables that will be created. If not provided, it will use the prefix
from the subclass.
Expand All @@ -112,12 +113,43 @@ class Transformation:
lookup_name: str

def __init__(
self, priors: dict[str, Any | Prior] | None = None, prefix: str | None = None
self, priors: dict[str, Prior] | None = None, prefix: str | None = None
) -> None:
self._checks()
self.function_priors = priors # type: ignore
self.prefix = prefix or self.prefix

def __repr__(self) -> str:
    """Return a constructor-like summary of this transformation.

    The text is the runtime class name followed, in parentheses, by
    ``prefix`` and ``function_priors`` (exposed under the label
    ``priors``).
    """
    class_name = self.__class__.__name__
    fields = (
        f"prefix={self.prefix!r}",
        f"priors={self.function_priors}",
    )
    return f"{class_name}({', '.join(fields)})"

def to_dict(self) -> dict[str, Any]:
    """Convert the transformation to a dictionary.

    Returns
    -------
    dict
        The dictionary defining the transformation, with the keys
        ``"lookup_name"``, ``"prefix"`` and ``"priors"``. Each prior is
        stored as the result of its ``to_json`` method.
    """
    serialized_priors = {
        name: prior.to_json() for name, prior in self.function_priors.items()
    }
    return {
        "lookup_name": self.lookup_name,
        "prefix": self.prefix,
        "priors": serialized_priors,
    }

def __eq__(self, other: Any) -> bool:
    """Compare two transformations through their dictionary form.

    Parameters
    ----------
    other : Any
        Object to compare against.

    Returns
    -------
    bool
        ``True`` when ``other`` is an instance of this object's class and
        both ``to_dict`` results are equal. For any other type
        ``NotImplemented`` is returned, so Python can try the reflected
        comparison on ``other`` before the ``==`` expression settles on
        ``False``.
    """
    if not isinstance(other, self.__class__):
        # Defer to the other operand instead of answering on its behalf.
        return NotImplemented

    return self.to_dict() == other.to_dict()

@property
def function_priors(self) -> dict[str, Prior]:
    """Return the priors held in the private ``_function_priors`` attribute."""
    return self._function_priors
Expand All @@ -137,7 +169,7 @@ def update_priors(self, priors: dict[str, Prior]) -> None:
Parameters
----------
priors : dict
priors : dict[str, Prior]
Dictionary with the new priors for the parameters of the function.
Examples
Expand All @@ -150,6 +182,7 @@ def update_priors(self, priors: dict[str, Prior]) -> None:
from pymc_marketing.prior import Prior
class MyTransformation(Transformation):
lookup_name: str = "my_transformation"
prefix: str = "transformation"
function = lambda x, lam: x * lam
default_priors = {"lam": Prior("Gamma", alpha=3, beta=1)}
Expand Down Expand Up @@ -200,6 +233,9 @@ def _has_all_attributes(self) -> None:
if not hasattr(self, "function"):
raise NotImplementedError("function must be implemented in the subclass")

if not hasattr(self, "lookup_name"):
raise NotImplementedError("lookup_name must be implemented in the subclass")

def _has_defaults_for_all_arguments(self) -> None:
function_signature = signature(self.function)

Expand Down
18 changes: 18 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from pymc_marketing.prior import Prior
class InfiniteReturns(SaturationTransformation):
lookup_name: str = "infinite_returns"
def function(self, x, b):
return b * x
Expand Down Expand Up @@ -109,6 +111,7 @@ def infinite_returns(x, b):
return b * x
class InfiniteReturns(SaturationTransformation):
lookup_name = "infinite_returns"
function = infinite_returns
default_priors = {"b": Prior("HalfNormal")}
Expand Down Expand Up @@ -417,6 +420,21 @@ def function(self, x, alpha, beta):
}


def register_saturation_transformation(cls: type[SaturationTransformation]) -> None:
    """Add a saturation transformation class to ``SATURATION_TRANSFORMATIONS``.

    Parameters
    ----------
    cls : type[SaturationTransformation]
        Class to register. It is stored under its ``lookup_name``; item
        assignment is used, so registering under an already-used name
        overwrites the earlier entry.
    """
    key = cls.lookup_name
    SATURATION_TRANSFORMATIONS[key] = cls


def saturation_from_dict(data: dict) -> SaturationTransformation:
    """Create a saturation transformation from a dictionary.

    Parameters
    ----------
    data : dict
        Must contain ``"lookup_name"``, which selects the class in
        ``SATURATION_TRANSFORMATIONS``. An optional ``"priors"`` mapping
        has each of its values passed through ``Prior.from_json``. Every
        remaining item is forwarded to the class as a keyword argument.

    Returns
    -------
    SaturationTransformation
        Instance of the class registered under ``data["lookup_name"]``.

    Raises
    ------
    KeyError
        If ``"lookup_name"`` is missing from ``data`` or is not a key of
        ``SATURATION_TRANSFORMATIONS``.
    """
    # Shallow copy so that popping "lookup_name" and replacing "priors"
    # never touches the caller's dictionary.
    data = data.copy()
    cls = SATURATION_TRANSFORMATIONS[data.pop("lookup_name")]

    # "priors" is optional: when it is absent or None, nothing is
    # deserialized and the constructor's own handling of a missing
    # ``priors`` applies.
    serialized_priors = data.get("priors")
    if serialized_priors is not None:
        data["priors"] = {
            key: Prior.from_json(value) for key, value in serialized_priors.items()
        }
    return cls(**data)


def _get_saturation_function(
function: str | SaturationTransformation,
) -> SaturationTransformation:
Expand Down
12 changes: 7 additions & 5 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@
AdstockTransformation,
GeometricAdstock,
_get_adstock_function,
adstock_from_dict,
)
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
SaturationTransformation,
_get_saturation_function,
saturation_from_dict,
)
from pymc_marketing.mmm.fourier import YearlyFourier
from pymc_marketing.mmm.lift_test import (
Expand Down Expand Up @@ -299,8 +301,8 @@ def _generate_and_preprocess_model_data( # type: ignore
def create_idata_attrs(self) -> dict[str, str]:
attrs = super().create_idata_attrs()
attrs["date_column"] = json.dumps(self.date_column)
attrs["adstock"] = json.dumps(self.adstock.lookup_name)
attrs["saturation"] = json.dumps(self.saturation.lookup_name)
attrs["adstock"] = json.dumps(self.adstock.to_dict())
attrs["saturation"] = json.dumps(self.saturation.to_dict())
attrs["adstock_first"] = json.dumps(self.adstock_first)
attrs["control_columns"] = json.dumps(self.control_columns)
attrs["channel_columns"] = json.dumps(self.channel_columns)
Expand Down Expand Up @@ -645,8 +647,8 @@ def attrs_to_init_kwargs(cls, attrs) -> dict[str, Any]:
"control_columns": json.loads(attrs["control_columns"]),
"channel_columns": json.loads(attrs["channel_columns"]),
"adstock_max_lag": json.loads(attrs["adstock_max_lag"]),
"adstock": json.loads(attrs.get("adstock", '"geometric"')),
"saturation": json.loads(attrs.get("saturation", '"logistic"')),
"adstock": adstock_from_dict(json.loads(attrs["adstock"])),
"saturation": saturation_from_dict(json.loads(attrs["saturation"])),
"adstock_first": json.loads(attrs.get("adstock_first", "true")),
"yearly_seasonality": json.loads(attrs["yearly_seasonality"]),
"time_varying_intercept": json.loads(
Expand Down Expand Up @@ -909,7 +911,7 @@ class MMM(
""" # noqa: E501

_model_type: str = "MMM"
version: str = "0.0.1"
version: str = "0.0.2"

def channel_contributions_forward_pass(
self, channel_data: npt.NDArray[np.float64]
Expand Down
67 changes: 67 additions & 0 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
from pydantic import ValidationError

from pymc_marketing.mmm.components.adstock import (
ADSTOCK_TRANSFORMATIONS,
AdstockTransformation,
DelayedAdstock,
GeometricAdstock,
WeibullAdstock,
WeibullCDFAdstock,
WeibullPDFAdstock,
_get_adstock_function,
adstock_from_dict,
register_adstock_transformation,
)
from pymc_marketing.mmm.transformers import ConvMode
from pymc_marketing.prior import Prior


def adstocks() -> list[AdstockTransformation]:
Expand Down Expand Up @@ -141,3 +146,65 @@ def test_adstock_sample_curve(adstock) -> None:
assert isinstance(curve, xr.DataArray)
assert curve.name == "adstock"
assert curve.shape == (1, 500, adstock.l_max)


def test_adstock_from_dict() -> None:
    """A serialized geometric adstock is rebuilt as the equivalent instance."""
    serialized_prior = {"dist": "Beta", "kwargs": {"alpha": 1, "beta": 2}}
    payload = {
        "lookup_name": "geometric",
        "l_max": 10,
        "prefix": "test",
        "mode": "Before",
        "priors": {"alpha": serialized_prior},
    }

    restored = adstock_from_dict(payload)

    expected = GeometricAdstock(
        l_max=10,
        prefix="test",
        priors={"alpha": Prior("Beta", alpha=1, beta=2)},
        mode=ConvMode.Before,
    )
    assert restored == expected


def test_register_adstock_transformation() -> None:
    """A registered custom adstock can be looked up and rebuilt from a dict."""

    class NewTransformation(AdstockTransformation):
        lookup_name: str = "new_transformation"
        default_priors = {}

        def function(self, x):
            return x

    register_adstock_transformation(NewTransformation)
    try:
        assert "new_transformation" in ADSTOCK_TRANSFORMATIONS

        data = {
            "lookup_name": "new_transformation",
            "l_max": 10,
            "normalize": False,
            "mode": "Before",
            "priors": {},
        }
        adstock = adstock_from_dict(data)
        assert adstock == NewTransformation(
            l_max=10, mode=ConvMode.Before, normalize=False, priors={}
        )
    finally:
        # The registry is module-level state: drop the throwaway entry so
        # this test does not leak into whichever tests run afterwards.
        ADSTOCK_TRANSFORMATIONS.pop("new_transformation", None)


def test_repr() -> None:
    """The repr of a default geometric adstock lists its full configuration."""
    expected = (
        "GeometricAdstock(prefix='adstock', l_max=10, "
        "normalize=True, "
        "mode='After', "
        "priors={'alpha': Prior(\"Beta\", alpha=1, beta=3)}"
        ")"
    )
    rendered = repr(GeometricAdstock(l_max=10))
    assert rendered == expected
Loading

0 comments on commit 552d9c7

Please sign in to comment.