Skip to content

Commit

Permalink
Remove warnings during tests (#823)
Browse files Browse the repository at this point in the history
* address save and load tests

* catch warning on load

* remove warnings in budget optimizer

* remove plotting warnings

* remove validating warnings

* consolidate the loading function

* remove warnings in tests

* incorporate the docstring feedback

* only one deprecation warnings on DelayedSaturatedMMM

* dont have deprecation on test
  • Loading branch information
wd60622 authored and twiecki committed Sep 10, 2024
1 parent f17d990 commit 4d98e88
Show file tree
Hide file tree
Showing 15 changed files with 286 additions and 183 deletions.
18 changes: 12 additions & 6 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from pymc_marketing.model_builder import ModelBuilder
from pymc_marketing.model_config import ModelConfig, parse_model_config
from pymc_marketing.utils import from_netcdf


class CLVModel(ModelBuilder):
Expand Down Expand Up @@ -186,17 +187,22 @@ def load(cls, fname: str):
>>> imported_model = MyModel.load(name)
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
idata = from_netcdf(filepath)
return cls._build_with_idata(idata)

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
dataset = idata.fit_data.to_dataframe()
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=DeprecationWarning,
)
model = cls(
dataset,
model_config=json.loads(idata.attrs["model_config"]), # type: ignore
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata
model.build_model() # type: ignore
if model.id != idata.attrs["id"]:
Expand Down
4 changes: 2 additions & 2 deletions pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def _extract_predictive_variables(
purchase_coefficient = self.fit_result["purchase_coefficient"]
alpha = alpha_scale * np.exp(
-xarray.dot(
purchase_coefficient, purchase_xarray, dims="purchase_covariate"
purchase_coefficient, purchase_xarray, dim="purchase_covariate"
)
)
alpha.name = "alpha"
Expand All @@ -429,7 +429,7 @@ def _extract_predictive_variables(
dropout_coefficient = self.fit_result["dropout_coefficient"]
beta = beta_scale * np.exp(
-xarray.dot(
dropout_coefficient, dropout_xarray, dims="dropout_covariate"
dropout_coefficient, dropout_xarray, dim="dropout_covariate"
)
)
beta.name = "beta"
Expand Down
3 changes: 1 addition & 2 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(
),
model_config: dict | None = Field(None, description="Model configuration."),
sampler_config: dict | None = Field(None, description="Sampler configuration."),
**kwargs,
) -> None:
self.date_column: str = date_column
self.channel_columns: list[str] | tuple[str] = channel_columns
Expand Down Expand Up @@ -392,7 +391,7 @@ def plot_posterior_predictive(

if original_scale:
likelihood_hdi = self.get_target_transformer().inverse_transform(
Xt=likelihood_hdi
likelihood_hdi
)

ax.fill_between(
Expand Down
12 changes: 5 additions & 7 deletions pymc_marketing/mmm/budget_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,19 @@ def allocate_budget(
"No budget bounds provided. Using default bounds (0, total_budget) for each channel.",
stacklevel=2,
)
else:
if not isinstance(budget_bounds, dict):
raise TypeError("`budget_bounds` should be a dictionary.")
elif not isinstance(budget_bounds, dict):
raise TypeError("`budget_bounds` should be a dictionary.")

if custom_constraints is None:
constraints = {"type": "eq", "fun": lambda x: np.sum(x) - total_budget}
warnings.warn(
"Using default equality constraint: The sum of all budgets should be equal to the total budget.",
stacklevel=2,
)
elif not isinstance(custom_constraints, dict):
raise TypeError("`custom_constraints` should be a dictionary.")
else:
if not isinstance(custom_constraints, dict):
raise TypeError("`custom_constraints` should be a dictionary.")
else:
constraints = custom_constraints
constraints = custom_constraints

num_channels = len(self.parameters.keys())
initial_guess = np.ones(num_channels) * total_budget / num_channels
Expand Down
97 changes: 63 additions & 34 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer
from pymc_marketing.mmm.components.adstock import (
AdstockTransformation,
GeometricAdstock,
_get_adstock_function,
)
from pymc_marketing.mmm.components.saturation import (
LogisticSaturation,
SaturationTransformation,
_get_saturation_function,
)
Expand All @@ -54,6 +56,7 @@
from pymc_marketing.mmm.validating import ValidateControlColumns
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior
from pymc_marketing.utils import from_netcdf

__all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"]

Expand All @@ -78,17 +81,20 @@ def __init__(
channel_columns: list[str] = Field(
min_length=1, description="Column names of the media channel variables."
),
adstock_max_lag: int = Field(
...,
gt=0,
description="Number of lags to consider in the adstock transformation.",
),
adstock: str | InstanceOf[AdstockTransformation] = Field(
..., description="Type of adstock transformation to apply."
),
saturation: str | InstanceOf[SaturationTransformation] = Field(
..., description="Type of saturation transformation to apply."
),
adstock_max_lag: int | None = Field(
None,
gt=0,
description=(
"Number of lags to consider in the adstock transformation. "
"Defaults to the max lag of the adstock transformation."
),
),
time_varying_intercept: bool = Field(
False, description="Whether to consider time-varying intercept."
),
Expand Down Expand Up @@ -118,7 +124,6 @@ def __init__(
adstock_first: bool = Field(
True, description="Whether to apply adstock first."
),
**kwargs,
) -> None:
"""Constructor method.
Expand All @@ -128,12 +133,13 @@ def __init__(
Column name of the date variable. Must be parsable using ~pandas.to_datetime.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int, optional
Number of lags to consider in the adstock transformation.
adstock : str | AdstockTransformation
Type of adstock transformation to apply.
saturation : str | SaturationTransformation
Type of saturation transformation to apply.
adstock_max_lag : int, optional
Number of lags to consider in the adstock transformation. Defaults to the
max lag of the adstock transformation.
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
Because the `time-varying` variable is centered around 1 and acts as a multiplier,
Expand All @@ -159,14 +165,24 @@ def __init__(
Whether to apply adstock first, by default True.
"""
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_intercept = time_varying_intercept
self.time_varying_media = time_varying_media
self.date_column = date_column
self.validate_data = validate_data

self.adstock_first = adstock_first
self.adstock = _get_adstock_function(function=adstock, l_max=adstock_max_lag)

if adstock_max_lag is not None:
warnings.warn(
"The `adstock_max_lag` parameter is deprecated. Use `adstock` directly",
DeprecationWarning,
stacklevel=1,
)
adstock_kwargs = {"l_max": adstock_max_lag}
else:
adstock_kwargs = {}

self.adstock = _get_adstock_function(function=adstock, **adstock_kwargs)
self.saturation = _get_saturation_function(function=saturation)

model_config = model_config or {}
Expand All @@ -184,7 +200,6 @@ def __init__(
channel_columns=channel_columns,
model_config=model_config,
sampler_config=sampler_config,
adstock_max_lag=adstock_max_lag,
)

self.yearly_seasonality = yearly_seasonality
Expand Down Expand Up @@ -287,7 +302,7 @@ def _save_input_params(self, idata) -> None:
idata.attrs["adstock_first"] = json.dumps(self.adstock_first)
idata.attrs["control_columns"] = json.dumps(self.control_columns)
idata.attrs["channel_columns"] = json.dumps(self.channel_columns)
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag)
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max)
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
Expand Down Expand Up @@ -358,7 +373,11 @@ def build_model(
.. code-block:: python
from pymc_marketing.mmm import MMM
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation
MMM,
)
from pymc_marketing.prior import Prior
custom_config = {
Expand All @@ -374,12 +393,13 @@ def build_model(
model = MMM(
date_column="date_week",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config=custom_config,
)
Expand Down Expand Up @@ -634,7 +654,7 @@ def load(cls, fname: str):
"""

filepath = Path(fname)
idata = az.from_netcdf(filepath)
idata = from_netcdf(filepath)
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
Expand Down Expand Up @@ -844,22 +864,25 @@ class MMM(
import numpy as np
import pandas as pd
from pymc_marketing.mmm import MMM
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation
MMM,
)
data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
data = pd.read_csv(data_url, parse_dates=["date_week"])
mmm = MMM(
date_column="date_week",
adstock="geometric",
saturation="logistic",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
)
Expand All @@ -880,27 +903,30 @@ class MMM(
import numpy as np
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation
MMM,
)
from pymc_marketing.prior import Prior
from pymc_marketing.mmm import MMM
my_model_config = {
"beta_channel": Prior("LogNormal", mu=np.array([2, 1]), sigma=1),
"likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=2)),
}
mmm = MMM(
adstock="geometric",
saturation="logistic",
model_config=my_model_config,
date_column="date_week",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
"t",
],
adstock_max_lag=8,
yearly_seasonality=2,
model_config=my_model_config,
)
As you can see, we can configure all prior and likelihood distributions via the `model_config`.
Expand Down Expand Up @@ -1765,18 +1791,21 @@ def add_lift_test_measurements(
import pandas as pd
import numpy as np
from pymc_marketing.mmm import MMM
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation,
MMM,
)
model = MMM(
adstock="geometric",
saturation="logistic",
date_column="date_week",
channel_columns=["x1", "x2"],
adstock=GeometricAdstock(l_max=8),
saturation=LogisticSaturation(),
control_columns=[
"event_1",
"event_2",
],
adstock_max_lag=8,
yearly_seasonality=2,
)
Expand Down Expand Up @@ -2206,7 +2235,6 @@ def __init__(
control_columns: list[str] | None = None,
yearly_seasonality: int | None = None,
adstock_first: bool = True,
**kwargs,
) -> None:
"""
Wrapper function for DelayedSaturatedMMM class initializer.
Expand All @@ -2217,22 +2245,23 @@ def __init__(
warnings.warn(
"The DelayedSaturatedMMM class is deprecated. Please use the MMM class instead.",
DeprecationWarning,
stacklevel=2,
stacklevel=1,
)

adstock = GeometricAdstock(l_max=adstock_max_lag)
saturation = LogisticSaturation()

super().__init__(
date_column=date_column,
channel_columns=channel_columns,
adstock_max_lag=adstock_max_lag,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
model_config=model_config,
sampler_config=sampler_config,
validate_data=validate_data,
control_columns=control_columns,
yearly_seasonality=yearly_seasonality,
adstock="geometric",
saturation="logistic",
adstock=adstock,
saturation=saturation,
adstock_first=adstock_first,
**kwargs,
)
Loading

0 comments on commit 4d98e88

Please sign in to comment.