diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index b53c224e..83394b90 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -28,6 +28,7 @@ from pymc_marketing.model_builder import ModelBuilder from pymc_marketing.model_config import ModelConfig, parse_model_config +from pymc_marketing.utils import from_netcdf class CLVModel(ModelBuilder): @@ -186,17 +187,22 @@ def load(cls, fname: str): >>> imported_model = MyModel.load(name) """ filepath = Path(str(fname)) - idata = az.from_netcdf(filepath) + idata = from_netcdf(filepath) return cls._build_with_idata(idata) @classmethod def _build_with_idata(cls, idata: az.InferenceData): dataset = idata.fit_data.to_dataframe() - model = cls( - dataset, - model_config=json.loads(idata.attrs["model_config"]), # type: ignore - sampler_config=json.loads(idata.attrs["sampler_config"]), - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=DeprecationWarning, + ) + model = cls( + dataset, + model_config=json.loads(idata.attrs["model_config"]), # type: ignore + sampler_config=json.loads(idata.attrs["sampler_config"]), + ) model.idata = idata model.build_model() # type: ignore if model.id != idata.attrs["id"]: diff --git a/pymc_marketing/clv/models/pareto_nbd.py b/pymc_marketing/clv/models/pareto_nbd.py index db5f10ab..81f8cac7 100644 --- a/pymc_marketing/clv/models/pareto_nbd.py +++ b/pymc_marketing/clv/models/pareto_nbd.py @@ -412,7 +412,7 @@ def _extract_predictive_variables( purchase_coefficient = self.fit_result["purchase_coefficient"] alpha = alpha_scale * np.exp( -xarray.dot( - purchase_coefficient, purchase_xarray, dims="purchase_covariate" + purchase_coefficient, purchase_xarray, dim="purchase_covariate" ) ) alpha.name = "alpha" @@ -429,7 +429,7 @@ def _extract_predictive_variables( dropout_coefficient = self.fit_result["dropout_coefficient"] beta = beta_scale * np.exp( -xarray.dot( - dropout_coefficient, 
dropout_xarray, dims="dropout_covariate" + dropout_coefficient, dropout_xarray, dim="dropout_covariate" ) ) beta.name = "beta" diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index 34c1e429..d5f2497e 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -63,7 +63,6 @@ def __init__( ), model_config: dict | None = Field(None, description="Model configuration."), sampler_config: dict | None = Field(None, description="Sampler configuration."), - **kwargs, ) -> None: self.date_column: str = date_column self.channel_columns: list[str] | tuple[str] = channel_columns @@ -392,7 +391,7 @@ def plot_posterior_predictive( if original_scale: likelihood_hdi = self.get_target_transformer().inverse_transform( - Xt=likelihood_hdi + likelihood_hdi ) ax.fill_between( diff --git a/pymc_marketing/mmm/budget_optimizer.py b/pymc_marketing/mmm/budget_optimizer.py index 3b6ba111..8f3719dc 100644 --- a/pymc_marketing/mmm/budget_optimizer.py +++ b/pymc_marketing/mmm/budget_optimizer.py @@ -167,9 +167,8 @@ def allocate_budget( "No budget bounds provided. 
Using default bounds (0, total_budget) for each channel.", stacklevel=2, ) - else: - if not isinstance(budget_bounds, dict): - raise TypeError("`budget_bounds` should be a dictionary.") + elif not isinstance(budget_bounds, dict): + raise TypeError("`budget_bounds` should be a dictionary.") if custom_constraints is None: constraints = {"type": "eq", "fun": lambda x: np.sum(x) - total_budget} @@ -177,11 +176,10 @@ def allocate_budget( "Using default equality constraint: The sum of all budgets should be equal to the total budget.", stacklevel=2, ) + elif not isinstance(custom_constraints, dict): + raise TypeError("`custom_constraints` should be a dictionary.") else: - if not isinstance(custom_constraints, dict): - raise TypeError("`custom_constraints` should be a dictionary.") - else: - constraints = custom_constraints + constraints = custom_constraints num_channels = len(self.parameters.keys()) initial_guess = np.ones(num_channels) * total_budget / num_channels diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 36c9a729..74366e88 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -34,9 +34,11 @@ from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer from pymc_marketing.mmm.components.adstock import ( AdstockTransformation, + GeometricAdstock, _get_adstock_function, ) from pymc_marketing.mmm.components.saturation import ( + LogisticSaturation, SaturationTransformation, _get_saturation_function, ) @@ -54,6 +56,7 @@ from pymc_marketing.mmm.validating import ValidateControlColumns from pymc_marketing.model_config import parse_model_config from pymc_marketing.prior import Prior +from pymc_marketing.utils import from_netcdf __all__ = ["BaseMMM", "MMM", "DelayedSaturatedMMM"] @@ -78,17 +81,20 @@ def __init__( channel_columns: list[str] = Field( min_length=1, description="Column names of the media channel variables." 
), - adstock_max_lag: int = Field( - ..., - gt=0, - description="Number of lags to consider in the adstock transformation.", - ), adstock: str | InstanceOf[AdstockTransformation] = Field( ..., description="Type of adstock transformation to apply." ), saturation: str | InstanceOf[SaturationTransformation] = Field( ..., description="Type of saturation transformation to apply." ), + adstock_max_lag: int | None = Field( + None, + gt=0, + description=( + "Number of lags to consider in the adstock transformation. " + "Defaults to the max lag of the adstock transformation." + ), + ), time_varying_intercept: bool = Field( False, description="Whether to consider time-varying intercept." ), @@ -118,7 +124,6 @@ def __init__( adstock_first: bool = Field( True, description="Whether to apply adstock first." ), - **kwargs, ) -> None: """Constructor method. @@ -128,12 +133,13 @@ def __init__( Column name of the date variable. Must be parsable using ~pandas.to_datetime. channel_columns : List[str] Column names of the media channel variables. - adstock_max_lag : int, optional - Number of lags to consider in the adstock transformation. adstock : str | AdstockTransformation Type of adstock transformation to apply. saturation : str | SaturationTransformation Type of saturation transformation to apply. + adstock_max_lag : int, optional + Number of lags to consider in the adstock transformation. Defaults to the + max lag of the adstock transformation. time_varying_intercept : bool, optional Whether to consider time-varying intercept, by default False. Because the `time-varying` variable is centered around 1 and acts as a multiplier, @@ -159,14 +165,24 @@ def __init__( Whether to apply adstock first, by default True. 
""" self.control_columns = control_columns - self.adstock_max_lag = adstock_max_lag self.time_varying_intercept = time_varying_intercept self.time_varying_media = time_varying_media self.date_column = date_column self.validate_data = validate_data self.adstock_first = adstock_first - self.adstock = _get_adstock_function(function=adstock, l_max=adstock_max_lag) + + if adstock_max_lag is not None: + warnings.warn( + "The `adstock_max_lag` parameter is deprecated. Use `adstock` directly", + DeprecationWarning, + stacklevel=1, + ) + adstock_kwargs = {"l_max": adstock_max_lag} + else: + adstock_kwargs = {} + + self.adstock = _get_adstock_function(function=adstock, **adstock_kwargs) self.saturation = _get_saturation_function(function=saturation) model_config = model_config or {} @@ -184,7 +200,6 @@ def __init__( channel_columns=channel_columns, model_config=model_config, sampler_config=sampler_config, - adstock_max_lag=adstock_max_lag, ) self.yearly_seasonality = yearly_seasonality @@ -287,7 +302,7 @@ def _save_input_params(self, idata) -> None: idata.attrs["adstock_first"] = json.dumps(self.adstock_first) idata.attrs["control_columns"] = json.dumps(self.control_columns) idata.attrs["channel_columns"] = json.dumps(self.channel_columns) - idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag) + idata.attrs["adstock_max_lag"] = json.dumps(self.adstock.l_max) idata.attrs["validate_data"] = json.dumps(self.validate_data) idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality) idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept) @@ -358,7 +373,11 @@ def build_model( .. 
code-block:: python - from pymc_marketing.mmm import MMM + from pymc_marketing.mmm import ( + GeometricAdstock, + LogisticSaturation, + MMM, + ) from pymc_marketing.prior import Prior custom_config = { @@ -374,12 +393,13 @@ def build_model( model = MMM( date_column="date_week", channel_columns=["x1", "x2"], + adstock=GeometricAdstock(l_max=8), + saturation=LogisticSaturation(), control_columns=[ "event_1", "event_2", "t", ], - adstock_max_lag=8, yearly_seasonality=2, model_config=custom_config, ) @@ -634,7 +654,7 @@ def load(cls, fname: str): """ filepath = Path(fname) - idata = az.from_netcdf(filepath) + idata = from_netcdf(filepath) model_config = cls._model_config_formatting( json.loads(idata.attrs["model_config"]) ) @@ -844,22 +864,25 @@ class MMM( import numpy as np import pandas as pd - from pymc_marketing.mmm import MMM + from pymc_marketing.mmm import ( + GeometricAdstock, + LogisticSaturation, + MMM, + ) data_url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv" data = pd.read_csv(data_url, parse_dates=["date_week"]) mmm = MMM( date_column="date_week", - adstock="geometric", - saturation="logistic", channel_columns=["x1", "x2"], + adstock=GeometricAdstock(l_max=8), + saturation=LogisticSaturation(), control_columns=[ "event_1", "event_2", "t", ], - adstock_max_lag=8, yearly_seasonality=2, ) @@ -880,8 +903,12 @@ class MMM( import numpy as np + from pymc_marketing.mmm import ( + GeometricAdstock, + LogisticSaturation, + MMM, + ) from pymc_marketing.prior import Prior - from pymc_marketing.mmm import MMM my_model_config = { "beta_channel": Prior("LogNormal", mu=np.array([2, 1]), sigma=1), @@ -889,18 +916,17 @@ class MMM( } mmm = MMM( - adstock="geometric", - saturation="logistic", - model_config=my_model_config, date_column="date_week", channel_columns=["x1", "x2"], + adstock=GeometricAdstock(l_max=8), + saturation=LogisticSaturation(), control_columns=[ "event_1", "event_2", "t", ], - adstock_max_lag=8, 
yearly_seasonality=2, + model_config=my_model_config, ) As you can see, we can configure all prior and likelihood distributions via the `model_config`. @@ -1765,18 +1791,21 @@ def add_lift_test_measurements( import pandas as pd import numpy as np - from pymc_marketing.mmm import MMM + from pymc_marketing.mmm import ( + GeometricAdstock, + LogisticSaturation, + MMM, + ) model = MMM( - adstock="geometric", - saturation="logistic", date_column="date_week", channel_columns=["x1", "x2"], + adstock=GeometricAdstock(l_max=8), + saturation=LogisticSaturation(), control_columns=[ "event_1", "event_2", ], - adstock_max_lag=8, yearly_seasonality=2, ) @@ -2206,7 +2235,6 @@ def __init__( control_columns: list[str] | None = None, yearly_seasonality: int | None = None, adstock_first: bool = True, - **kwargs, ) -> None: """ Wrapper function for DelayedSaturatedMMM class initializer. @@ -2217,13 +2245,15 @@ def __init__( warnings.warn( "The DelayedSaturatedMMM class is deprecated. Please use the MMM class instead.", DeprecationWarning, - stacklevel=2, + stacklevel=1, ) + adstock = GeometricAdstock(l_max=adstock_max_lag) + saturation = LogisticSaturation() + super().__init__( date_column=date_column, channel_columns=channel_columns, - adstock_max_lag=adstock_max_lag, time_varying_intercept=time_varying_intercept, time_varying_media=time_varying_media, model_config=model_config, @@ -2231,8 +2261,7 @@ def __init__( validate_data=validate_data, control_columns=control_columns, yearly_seasonality=yearly_seasonality, - adstock="geometric", - saturation="logistic", + adstock=adstock, + saturation=saturation, adstock_first=adstock_first, - **kwargs, ) diff --git a/pymc_marketing/mmm/transformers.py b/pymc_marketing/mmm/transformers.py index 76524d5e..5c036445 100644 --- a/pymc_marketing/mmm/transformers.py +++ b/pymc_marketing/mmm/transformers.py @@ -437,10 +437,10 @@ def weibull_adstock( def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5): - """Logistic saturation 
transformation. + r"""Logistic saturation transformation. .. math:: - f(x) = \\frac{1 - e^{-\lambda x}}{1 + e^{-\lambda x}} + f(x) = \frac{1 - e^{-\lambda x}}{1 + e^{-\lambda x}} .. plot:: :context: close-figs @@ -474,7 +474,7 @@ def logistic_saturation(x, lam: npt.NDArray[np.float64] | float = 0.5): ------- tensor Transformed tensor. - """ # noqa: W605 + """ return (1 - pt.exp(-lam * x)) / (1 + pt.exp(-lam * x)) diff --git a/pymc_marketing/model_builder.py b/pymc_marketing/model_builder.py index 02128549..be38010c 100644 --- a/pymc_marketing/model_builder.py +++ b/pymc_marketing/model_builder.py @@ -29,6 +29,7 @@ from pymc_marketing.hsgp_kwargs import HSGPKwargs from pymc_marketing.prior import Prior +from pymc_marketing.utils import from_netcdf # If scikit-learn is available, use its data validator try: @@ -398,7 +399,8 @@ def load(cls, fname: str): >>> imported_model = MyModel.load(name) """ filepath = Path(str(fname)) - idata = az.from_netcdf(filepath) + idata = from_netcdf(filepath) + # needs to be converted, because json.loads was changing tuple to list model_config = cls._model_config_formatting( json.loads(idata.attrs["model_config"]) diff --git a/pymc_marketing/utils.py b/pymc_marketing/utils.py new file mode 100644 index 00000000..1efd5eff --- /dev/null +++ b/pymc_marketing/utils.py @@ -0,0 +1,27 @@ +# Copyright 2024 The PyMC Labs Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import warnings +from pathlib import Path + +import arviz as az + + +def from_netcdf(filepath: str | Path) -> az.InferenceData: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message=r"fit_data group is not defined in the InferenceData scheme", + ) + return az.from_netcdf(filepath) diff --git a/tests/conftest.py b/tests/conftest.py index 0a6b87ee..696f6235 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import warnings + import pandas as pd import pytest from arviz import InferenceData @@ -78,5 +80,12 @@ def set_model_fit(model: CLVModel, fit: InferenceData | Dataset): if not hasattr(model, "model"): model.build_model() model.idata = fit - model.idata.add_groups(fit_data=model.data.to_xarray()) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + model.idata.add_groups(fit_data=model.data.to_xarray()) model.set_idata_attrs(fit) diff --git a/tests/mmm/components/test_adstock.py b/tests/mmm/components/test_adstock.py index 096c40f4..d68b08fe 100644 --- a/tests/mmm/components/test_adstock.py +++ b/tests/mmm/components/test_adstock.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings + import numpy as np import pymc as pm import pytensor.tensor as pt @@ -30,14 +32,16 @@ def adstocks() -> list[AdstockTransformation]: - return [ - DelayedAdstock(l_max=10), - GeometricAdstock(l_max=10), - WeibullAdstock(l_max=10, kind="PDF"), - WeibullAdstock(l_max=10, kind="CDF"), - WeibullPDFAdstock(l_max=10), - WeibullCDFAdstock(l_max=10), - ] + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return [ + DelayedAdstock(l_max=10), + GeometricAdstock(l_max=10), + WeibullAdstock(l_max=10, kind="PDF"), + WeibullAdstock(l_max=10, kind="CDF"), + WeibullPDFAdstock(l_max=10), + WeibullCDFAdstock(l_max=10), + ] @pytest.fixture diff --git a/tests/mmm/test_budget_optimizer.py b/tests/mmm/test_budget_optimizer.py index 244c1b63..4e2f4a47 100644 --- a/tests/mmm/test_budget_optimizer.py +++ b/tests/mmm/test_budget_optimizer.py @@ -17,8 +17,8 @@ import pytest from pymc_marketing.mmm.budget_optimizer import BudgetOptimizer, MinimizeException -from pymc_marketing.mmm.components.adstock import _get_adstock_function -from pymc_marketing.mmm.components.saturation import _get_saturation_function +from pymc_marketing.mmm.components.adstock import GeometricAdstock +from pymc_marketing.mmm.components.saturation import LogisticSaturation @pytest.mark.parametrize( @@ -74,8 +74,8 @@ def test_allocate_budget( expected_response, ): # Initialize Adstock and Saturation Transformations - adstock = _get_adstock_function(function="geometric", l_max=4) - saturation = _get_saturation_function(function="logistic") + adstock = GeometricAdstock(l_max=4) + saturation = LogisticSaturation() # Create BudgetOptimizer Instance optimizer = BudgetOptimizer( @@ -87,11 +87,13 @@ def test_allocate_budget( ) # Allocate Budget - optimal_budgets, total_response = optimizer.allocate_budget( - total_budget=total_budget, - budget_bounds=budget_bounds, - minimize_kwargs=minimize_kwargs, - ) + match = "Using default equality constraint" + with pytest.warns(UserWarning, 
match=match): + optimal_budgets, total_response = optimizer.allocate_budget( + total_budget=total_budget, + budget_bounds=budget_bounds, + minimize_kwargs=minimize_kwargs, + ) # Assert Results assert optimal_budgets == expected_optimal @@ -122,8 +124,9 @@ def test_allocate_budget( def test_allocate_budget_zero_total( total_budget, budget_bounds, parameters, expected_optimal, expected_response ): - adstock = _get_adstock_function(function="geometric", l_max=4) - saturation = _get_saturation_function(function="logistic") + adstock = GeometricAdstock(l_max=4) + saturation = LogisticSaturation() + optimizer = BudgetOptimizer( adstock=adstock, saturation=saturation, @@ -131,9 +134,11 @@ def test_allocate_budget_zero_total( parameters=parameters, adstock_first=True, ) - optimal_budgets, total_response = optimizer.allocate_budget( - total_budget, budget_bounds - ) + match = "Using default equality constraint" + with pytest.warns(UserWarning, match=match): + optimal_budgets, total_response = optimizer.allocate_budget( + total_budget, budget_bounds + ) assert optimal_budgets == pytest.approx(expected_optimal, rel=1e-2) assert total_response == pytest.approx(expected_response, abs=1e-1) @@ -157,8 +162,9 @@ def test_allocate_budget_custom_minimize_args(minimize_mock) -> None: "options": {"ftol": 1e-8, "maxiter": 1_002}, } - adstock = _get_adstock_function(function="geometric", l_max=4) - saturation = _get_saturation_function(function="logistic") + adstock = GeometricAdstock(l_max=4) + saturation = LogisticSaturation() + optimizer = optimizer = BudgetOptimizer( adstock=adstock, saturation=saturation, @@ -166,9 +172,11 @@ def test_allocate_budget_custom_minimize_args(minimize_mock) -> None: parameters=parameters, adstock_first=True, ) - optimizer.allocate_budget( - total_budget, budget_bounds, minimize_kwargs=minimize_kwargs - ) + match = "Using default equality constraint" + with pytest.warns(UserWarning, match=match): + optimizer.allocate_budget( + total_budget, 
budget_bounds, minimize_kwargs=minimize_kwargs + ) kwargs = minimize_mock.call_args_list[0].kwargs @@ -212,8 +220,9 @@ def test_allocate_budget_custom_minimize_args(minimize_mock) -> None: def test_allocate_budget_infeasible_constraints( total_budget, budget_bounds, parameters, custom_constraints ): - adstock = _get_adstock_function(function="geometric", l_max=4) - saturation = _get_saturation_function(function="logistic") + adstock = GeometricAdstock(l_max=4) + saturation = LogisticSaturation() + optimizer = optimizer = BudgetOptimizer( adstock=adstock, saturation=saturation, diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index d060c617..25823976 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings import arviz as az import numpy as np @@ -21,8 +22,11 @@ import xarray as xr from matplotlib import pyplot as plt -from pymc_marketing.mmm.components.adstock import DelayedAdstock -from pymc_marketing.mmm.components.saturation import MichaelisMentenSaturation +from pymc_marketing.mmm.components.adstock import DelayedAdstock, GeometricAdstock +from pymc_marketing.mmm.components.saturation import ( + LogisticSaturation, + MichaelisMentenSaturation, +) from pymc_marketing.mmm.delayed_saturated_mmm import MMM, BaseMMM, DelayedSaturatedMMM from pymc_marketing.prior import Prior @@ -39,14 +43,20 @@ def mock_fit(model, X: pd.DataFrame, y: np.ndarray, **kwargs): model.preprocess("X", X) model.preprocess("y", y) - idata.add_groups( - { - "posterior": idata.prior, - "fit_data": pd.concat( - [X, pd.Series(y, index=X.index, name="y")], axis=1 - ).to_xarray(), - } - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + 
idata.add_groups( + { + "posterior": idata.prior, + "fit_data": pd.concat( + [X, pd.Series(y, index=X.index, name="y")], axis=1 + ).to_xarray(), + } + ) model.idata = idata model.set_idata_attrs(idata=idata) @@ -127,10 +137,9 @@ def mmm() -> MMM: return MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), ) @@ -139,10 +148,9 @@ def mmm_with_fourier_features() -> MMM: return MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), yearly_seasonality=2, ) @@ -201,13 +209,15 @@ def deep_equal(dict1, dict2): return False return True + l_max = 4 + adstock = GeometricAdstock(l_max=l_max) + saturation = LogisticSaturation() model = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, model_config=model_config_requiring_serialization, - adstock="geometric", - saturation="logistic", + adstock=adstock, + saturation=saturation, ) model = mock_fit(model, toy_X, toy_y) model.save("test_save_load") @@ -215,7 +225,7 @@ def deep_equal(dict1, dict2): assert model.date_column == model2.date_column assert model.control_columns == model2.control_columns assert model.channel_columns == model2.channel_columns - assert model.adstock_max_lag == model2.adstock_max_lag + assert model.adstock.l_max == model2.adstock.l_max assert model.validate_data == model2.validate_data assert model.yearly_seasonality == model2.yearly_seasonality assert deep_equal(model.model_config, model2.model_config) @@ -231,10 +241,9 @@ def test_bad_date_column(self, toy_X_with_bad_dates) -> None: my_mmm = MMM( date_column="bad_date_column", channel_columns=["channel_1", "channel_2"], - 
adstock_max_lag=4, + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", ) y = np.ones(toy_X_with_bad_dates.shape[0]) my_mmm.build_model(X=toy_X_with_bad_dates, y=y) @@ -290,12 +299,11 @@ def test_init( date_column="date", channel_columns=channel_columns, control_columns=control_columns, - adstock_max_lag=adstock_max_lag, yearly_seasonality=yearly_seasonality, time_varying_intercept=time_varying_intercept, time_varying_media=time_varying_media, - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=adstock_max_lag), + saturation=LogisticSaturation(), ) mmm.build_model(X=toy_X, y=toy_y) n_channel: int = len(mmm.channel_columns) @@ -368,10 +376,9 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None: date_column="date", channel_columns=["channel_1", "channel_2"], control_columns=["control_1", "control_2"], - adstock_max_lag=2, yearly_seasonality=2, - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=2), + saturation=LogisticSaturation(), ) assert mmm.version == "0.0.3" assert mmm._model_type == "BaseValidateMMM" @@ -483,10 +490,9 @@ def test_get_errors_raises_not_fitted(self) -> None: my_mmm = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), ) with pytest.raises( RuntimeError, @@ -498,10 +504,9 @@ def test_posterior_predictive_raises_not_fitted(self) -> None: my_mmm = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), ) with pytest.raises( RuntimeError, @@ -601,9 +606,8 @@ def 
test_data_setter(self, toy_X, toy_y): base_delayed_saturated_mmm = BaseMMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), ) base_delayed_saturated_mmm = mock_fit(base_delayed_saturated_mmm, toy_X, toy_y) @@ -643,7 +647,7 @@ def test_save_load(self, mmm_fitted: MMM): assert model.date_column == model2.date_column assert model.control_columns == model2.control_columns assert model.channel_columns == model2.channel_columns - assert model.adstock_max_lag == model2.adstock_max_lag + assert model.adstock.l_max == model2.adstock.l_max assert model.validate_data == model2.validate_data assert model.yearly_seasonality == model2.yearly_seasonality assert model.model_config == model2.model_config @@ -659,9 +663,8 @@ def mock_property(self): DSMMM = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), ) # Check that the property returns the new value @@ -712,11 +715,10 @@ def test_model_config( model = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=2, yearly_seasonality=2, model_config=model_config, - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=2), + saturation=LogisticSaturation(), ) model.build_model(X=toy_X, y=toy_y.to_numpy()) @@ -857,7 +859,7 @@ def new_contributions_property_checks(new_contributions, X, model): assert coords["channel"].values.tolist() == model.channel_columns np.testing.assert_allclose( coords["time_since_spend"].values, - np.arange(-model.adstock_max_lag, model.adstock_max_lag + 1), + np.arange(-model.adstock.l_max, model.adstock.l_max + 1), ) # Channel contributions are non-negative @@ -875,10 +877,9 @@ def test_new_spend_contributions_prior_error() -> None: mmm = MMM( 
date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=GeometricAdstock(l_max=4), + saturation=LogisticSaturation(), ) new_spend = np.ones(len(mmm.channel_columns)) match = "sample_prior_predictive" @@ -983,13 +984,14 @@ def test_add_lift_test_measurements(mmm, toy_X, toy_y, df_lift_test) -> None: def test_add_lift_test_measurements_no_model() -> None: + adstock = GeometricAdstock(l_max=4) + saturation = LogisticSaturation() mmm = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=adstock, + saturation=saturation, ) with pytest.raises(RuntimeError, match="The model has not been built yet."): mmm.add_lift_test_measurements( @@ -1010,14 +1012,15 @@ def test_delayed_saturated_mmm_raises_deprecation_warning() -> None: def test_initialize_alternative_with_strings() -> None: - mmm = MMM( - date_column="date", - channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, - control_columns=["control_1", "control_2"], - adstock="delayed", - saturation="michaelis_menten", - ) + with pytest.warns(DeprecationWarning): + mmm = MMM( + date_column="date", + channel_columns=["channel_1", "channel_2"], + adstock_max_lag=4, + control_columns=["control_1", "control_2"], + adstock="delayed", + saturation="michaelis_menten", + ) assert isinstance(mmm.adstock, DelayedAdstock) assert mmm.adstock.l_max == 4 @@ -1028,7 +1031,6 @@ def test_initialize_alternative_with_classes() -> None: mmm = MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], adstock=DelayedAdstock(l_max=10), saturation=MichaelisMentenSaturation(), @@ -1043,7 +1045,6 @@ def test_initialize_defaults_channel_media_dims() -> None: mmm = MMM( date_column="date", 
channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], adstock=DelayedAdstock(l_max=10), saturation=MichaelisMentenSaturation(), @@ -1065,12 +1066,13 @@ def test_initialize_defaults_channel_media_dims() -> None: def test_save_load_with_tvp( time_varying_intercept, time_varying_media, toy_X, toy_y ) -> None: + adstock = GeometricAdstock(l_max=5) + saturation = LogisticSaturation() mmm = MMM( channel_columns=["channel_1", "channel_2"], date_column="date", - adstock="geometric", - saturation="logistic", - adstock_max_lag=5, + adstock=adstock, + saturation=saturation, time_varying_intercept=time_varying_intercept, time_varying_media=time_varying_media, ) diff --git a/tests/mmm/test_plotting.py b/tests/mmm/test_plotting.py index 11fc4f13..a8b189fa 100644 --- a/tests/mmm/test_plotting.py +++ b/tests/mmm/test_plotting.py @@ -11,12 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings + import numpy as np import pandas as pd import pymc as pm import pytest from matplotlib import pyplot as plt +from pymc_marketing.mmm.components.adstock import GeometricAdstock +from pymc_marketing.mmm.components.saturation import LogisticSaturation from pymc_marketing.mmm.delayed_saturated_mmm import MMM, BaseMMM from pymc_marketing.mmm.preprocessing import MaxAbsScaleTarget @@ -55,14 +59,20 @@ def mock_fit_base(model, X: pd.DataFrame, y: np.ndarray, **kwargs): with model.model: idata = pm.sample_prior_predictive(random_seed=rng, **kwargs) - idata.add_groups( - { - "posterior": idata.prior, - "fit_data": pd.concat( - [X, pd.Series(y, index=X.index, name="y")], axis=1 - ).to_xarray(), - } - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + idata.add_groups( + { + "posterior": idata.prior, + "fit_data": pd.concat( + [X, pd.Series(y, index=X.index, name="y")], axis=1 + ).to_xarray(), + } + ) model.idata = idata model.set_idata_attrs(idata=idata) @@ -91,22 +101,23 @@ class ToyMMM(BaseMMM): class ToyMMM(BaseMMM, MaxAbsScaleTarget): pass + adstock = GeometricAdstock(l_max=4) + saturation = LogisticSaturation() + if control == "without_controls": mmm = ToyMMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, - adstock="geometric", - saturation="logistic", + adstock=adstock, + saturation=saturation, ) elif control == "with_controls": mmm = ToyMMM( date_column="date", - adstock_max_lag=4, control_columns=["control_1", "control_2"], channel_columns=["channel_1", "channel_2"], - adstock="geometric", - saturation="logistic", + adstock=adstock, + saturation=saturation, ) for transform in [mmm.adstock, mmm.saturation]: @@ -155,13 +166,14 @@ def test_plots(self, plotting_mmm, func_plot_name, kwargs_plot) -> None: @pytest.fixture(scope="module") def mock_mmm() -> MMM: + adstock = 
GeometricAdstock(l_max=4) + saturation = LogisticSaturation() return MMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, control_columns=["control_1", "control_2"], - adstock="geometric", - saturation="logistic", + adstock=adstock, + saturation=saturation, ) @@ -174,14 +186,20 @@ def mock_fit(model: MMM, X: pd.DataFrame, y: np.ndarray, **kwargs): model.preprocess("X", X) model.preprocess("y", y) - idata.add_groups( - { - "posterior": idata.prior, - "fit_data": pd.concat( - [X, pd.Series(y, index=X.index, name="y")], axis=1 - ).to_xarray(), - } - ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + category=UserWarning, + message="The group fit_data is not defined in the InferenceData scheme", + ) + idata.add_groups( + { + "posterior": idata.prior, + "fit_data": pd.concat( + [X, pd.Series(y, index=X.index, name="y")], axis=1 + ).to_xarray(), + } + ) model.idata = idata model.set_idata_attrs(idata=idata) diff --git a/tests/mmm/test_validating.py b/tests/mmm/test_validating.py index b4f9e466..b3f5a44a 100644 --- a/tests/mmm/test_validating.py +++ b/tests/mmm/test_validating.py @@ -104,19 +104,19 @@ def test_channel_columns(): obj.validate_channel_columns(toy_X) with pytest.raises( ValueError, - match="channel_columns \['out_of_columns'\] not in data", # noqa: W605 + match=r"channel_columns \['out_of_columns'\] not in data", ): obj.channel_columns = ["out_of_columns"] obj.validate_channel_columns(toy_X) with pytest.raises( ValueError, - match="channel_columns \['channel_1', 'channel_1'\] contains duplicates", # noqa: W605 + match=r"channel_columns \['channel_1', 'channel_1'\] contains duplicates", ): obj.channel_columns = ["channel_1", "channel_1"] obj.validate_channel_columns(toy_X) with pytest.raises( ValueError, - match="channel_columns \['channel_1'\] contains negative values", # noqa: W605 + match=r"channel_columns \['channel_1'\] contains negative values", ): new_toy_X = toy_X.copy() 
new_toy_X["channel_1"] -= 1e4 @@ -143,13 +143,13 @@ def test_control_columns(): obj.validate_control_columns(toy_X) with pytest.raises( ValueError, - match="control_columns \['out_of_columns'\] not in data", # noqa: W605 + match=r"control_columns \['out_of_columns'\] not in data", ): obj.control_columns = ["out_of_columns"] obj.validate_control_columns(toy_X) with pytest.raises( ValueError, - match="control_columns \['control_1', 'control_1'\] contains duplicates", # noqa: W605 + match=r"control_columns \['control_1', 'control_1'\] contains duplicates", ): obj.control_columns = ["control_1", "control_1"] obj.validate_control_columns(toy_X) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 130ebf45..0403263b 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -227,7 +227,7 @@ def test_fit(fitted_model_instance): rng = np.random.default_rng(42) assert fitted_model_instance.idata is not None assert "posterior" in fitted_model_instance.idata.groups() - assert fitted_model_instance.idata.posterior.dims["draw"] == 100 + assert fitted_model_instance.idata.posterior.sizes["draw"] == 100 prediction_data = pd.DataFrame({"input": rng.uniform(low=0, high=1, size=100)}) fitted_model_instance.predict(prediction_data) @@ -271,8 +271,8 @@ def test_sample_posterior_predictive(fitted_model_instance, combined): pred = fitted_model_instance.sample_posterior_predictive( prediction_data, combined=combined, extend_idata=True ) - chains = fitted_model_instance.idata.sample_stats.dims["chain"] - draws = fitted_model_instance.idata.sample_stats.dims["draw"] + chains = fitted_model_instance.idata.sample_stats.sizes["chain"] + draws = fitted_model_instance.idata.sample_stats.sizes["draw"] expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred) assert pred[fitted_model_instance.output_var].shape == expected_shape assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)