From 561104b99b5deb68c4826ca9e9e0c7bcd31ab26a Mon Sep 17 00:00:00 2001 From: Michal Raczycki <119355382+michaelraczycki@users.noreply.github.com> Date: Thu, 13 Jul 2023 09:53:57 +0200 Subject: [PATCH] MMM load updates (#317) * fixed dims format, added save_load test to delayed_saturated_mmm * removing temp file after test_save_load in DelayedSaturatedMMM tests * unifying save tests, adding model.id checks into load * pymc-experimental version bump * removing property @posterior_predictive from basic.py * extending tests to cover id preservation with load * _model_type update --- pymc_marketing/clv/models/basic.py | 11 ++-- pymc_marketing/mmm/base.py | 2 + pymc_marketing/mmm/delayed_saturated_mmm.py | 65 +++++++++++++++++++++ pyproject.toml | 2 +- tests/clv/models/test_basic.py | 31 ++++++++-- tests/clv/models/test_beta_geo.py | 32 ++++------ tests/clv/models/test_gamma_gamma.py | 51 ++++++---------- tests/clv/models/test_shifted_beta_geo.py | 30 ++++------ tests/mmm/test_delayed_saturated_mmm.py | 44 ++++++++++++++ 9 files changed, 184 insertions(+), 84 deletions(-) diff --git a/pymc_marketing/clv/models/basic.py b/pymc_marketing/clv/models/basic.py index 4ee2e641..5b35d186 100644 --- a/pymc_marketing/clv/models/basic.py +++ b/pymc_marketing/clv/models/basic.py @@ -180,6 +180,11 @@ def load(cls, fname: str): model.idata = idata model.build_model() + + if model.id != idata.attrs["id"]: + raise ValueError( + f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" + ) # All previously used data is in idata. return model @@ -269,12 +274,6 @@ def fit_result(self, res: az.InferenceData) -> None: else: self.idata.posterior = res - @property - def posterior_predictive(self) -> Dataset: - if self.idata is None or "posterior_predictive" not in self.idata: - raise RuntimeError("The model hasn't been fit yet, call .fit() first") - return self.idata["posterior_predictive"] - def fit_summary(self, **kwargs): res = self.fit_result # Map fitting only gives one value, so we return it. We use arviz diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py index cf224f14..d0583f8c 100644 --- a/pymc_marketing/mmm/base.py +++ b/pymc_marketing/mmm/base.py @@ -30,6 +30,8 @@ class BaseMMM(ModelBuilder): model: pm.Model + _model_type = "BaseMMM" + version = "0.0.2" def __init__( self, diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 1edac8fb..40dfcec4 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -1,3 +1,5 @@ +import json +from pathlib import Path from typing import Any, Dict, List, Optional, Union import arviz as az @@ -19,6 +21,9 @@ class BaseDelayedSaturatedMMM(MMM): + _model_type = "DelayedSaturatedMMM" + version = "0.0.2" + def __init__( self, date_column: str, @@ -127,6 +132,15 @@ def generate_and_preprocess_model_data( self.X: pd.DataFrame = X_data self.y: pd.Series = y + def _save_input_params(self, idata) -> None: + """Saves input parameters to the attrs of idata.""" + idata.attrs["date_column"] = json.dumps(self.date_column) + idata.attrs["control_columns"] = json.dumps(self.control_columns) + idata.attrs["channel_columns"] = json.dumps(self.channel_columns) + idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag) + idata.attrs["validate_data"] = json.dumps(self.validate_data) + idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality) + def build_model( self, X: pd.DataFrame, @@ -355,6 +369,57 @@ def _serializable_model_config(self) -> Dict[str, Any]: ]["sigma"].tolist() return serializable_config + @classmethod + def load(cls, fname: str): + """ + Creates a DelayedSaturatedMMM instance from a file, + instantiating the model with the saved original input parameters. + Loads inference data for the model. + + Parameters + ---------- + fname : string + This denotes the name with path from where idata should be loaded from. + + Returns + ------- + Returns an instance of DelayedSaturatedMMM. + + Raises + ------ + ValueError + If the inference data that is loaded doesn't match with the model. + """ + + filepath = Path(str(fname)) + idata = az.from_netcdf(filepath) + # needs to be converted, because json.loads was changing tuple to list + model_config = cls._convert_dims_to_tuple( + json.loads(idata.attrs["model_config"]) + ) + model = cls( + date_column=json.loads(idata.attrs["date_column"]), + control_columns=json.loads(idata.attrs["control_columns"]), + channel_columns=json.loads(idata.attrs["channel_columns"]), + adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]), + validate_data=json.loads(idata.attrs["validate_data"]), + yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]), + model_config=model_config, + sampler_config=json.loads(idata.attrs["sampler_config"]), + ) + model.idata = idata + dataset = idata.fit_data.to_dataframe() + X = dataset.drop(columns=[model.output_var]) + y = dataset[model.output_var].values + model.build_model(X, y) + # All previously used data is in idata. + if model.id != idata.attrs["id"]: + raise ValueError( + f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" + ) + + return model + def _data_setter( self, X: Union[np.ndarray, pd.DataFrame], diff --git a/pyproject.toml b/pyproject.toml index 87ce2c97..6fb26407 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "seaborn>=0.12.2", "xarray", "xarray-einstats>=0.5.1", - "pymc-experimental>=0.0.7", + "pymc-experimental>=0.0.8", ] [project.optional-dependencies] diff --git a/tests/clv/models/test_basic.py b/tests/clv/models/test_basic.py index 01f4eb3b..1fb901a9 100644 --- a/tests/clv/models/test_basic.py +++ b/tests/clv/models/test_basic.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pandas as pd import pymc as pm @@ -140,10 +142,11 @@ def test_load(self): model = CLVModelTest() model.build_model() model.fit(target_accept=0.81, draws=100, chains=2, random_seed=1234) - model.save("test_model.pkl") - model.load("test_model.pkl") - assert model.fit_result is not None - assert model.model is not None + model.save("test_model") + model2 = model.load("test_model") + assert model2.fit_result is not None + assert model2.model is not None + os.remove("test_model") def test_default_sampler_config(self): model = CLVModelTest() @@ -205,3 +208,23 @@ def test_serializable_model_config(self): serializable_config = model._serializable_model_config assert isinstance(serializable_config, dict) assert serializable_config == model.model_config + + def test_fail_id_after_load(self, monkeypatch): + # This is the new behavior for the property + def mock_property(self): + return "for sure not correct id" + + # Now create an instance of MyClass + mock_basic = CLVModelTest() + + # Check that the property returns the new value + mock_basic.fit() + mock_basic.save("test_model") + # Apply the monkeypatch for the property + monkeypatch.setattr(CLVModelTest, "id", property(mock_property)) + with pytest.raises( + ValueError, + match="The file 'test_model' does not contain an inference data of the same model or configuration as 'CLVModelTest'", + ): + CLVModelTest.load("test_model") + os.remove("test_model") diff --git a/tests/clv/models/test_beta_geo.py b/tests/clv/models/test_beta_geo.py index 60e14a97..8a7d6c1d 100644 --- a/tests/clv/models/test_beta_geo.py +++ b/tests/clv/models/test_beta_geo.py @@ -1,6 +1,4 @@ -import json -import tempfile -from pathlib import Path +import os import arviz as az import numpy as np @@ -494,32 +492,26 @@ def test_distribution_new_customer(self, data) -> None: ) def test_save_load_beta_geo(self, data): - temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - model = BetaGeoModel( data=data, ) model.build_model() model.fit("map") - model.save(temp) + model.save("test_model") # Testing the valid case. - model2 = BetaGeoModel.load(temp) + model2 = BetaGeoModel.load("test_model") # Check if the loaded model is indeed an instance of the class assert isinstance(model, BetaGeoModel) - - # Load data from the file to cross verify - filepath = Path(str(temp)) - idata = az.from_netcdf(filepath) - dataset = idata.fit_data.to_dataframe() # Check if the loaded data matches with the model data np.testing.assert_array_equal( - model2.customer_id.values, dataset.customer_id.values - ) - np.testing.assert_array_equal(model2.frequency.values, dataset.frequency.values) - np.testing.assert_array_equal(model2.T.values, dataset["T"]) - np.testing.assert_array_equal(model2.recency.values, dataset.recency.values) - assert model.model_config == json.loads(idata.attrs["model_config"]) - assert model.sampler_config == json.loads(idata.attrs["sampler_config"]) - assert model.idata == idata + model2.customer_id.values, model.customer_id.values + ) + np.testing.assert_array_equal(model2.frequency.values, model.frequency.values) + np.testing.assert_array_equal(model2.T.values, model.T.values) + np.testing.assert_array_equal(model2.recency.values, model.recency.values) + assert model.model_config == model2.model_config + assert model.sampler_config == model2.sampler_config + assert model.idata == model2.idata + os.remove("test_model") diff --git a/tests/clv/models/test_gamma_gamma.py b/tests/clv/models/test_gamma_gamma.py index 5fa0e49d..99295dc6 100644 --- a/tests/clv/models/test_gamma_gamma.py +++ b/tests/clv/models/test_gamma_gamma.py @@ -1,9 +1,6 @@ -import json -import tempfile -from pathlib import Path +import os from unittest.mock import patch -import arviz as az import numpy as np import pandas as pd import pymc as pm @@ -294,35 +291,29 @@ def test_model_repr(self, data, default_model_config): ) def test_save_load_beta_geo(self, data): - temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - model = GammaGammaModel( data=data, ) model.build_model() model.fit("map") - model.save(temp) + model.save("test_model") # Testing the valid case. - model2 = GammaGammaModel.load(temp) + model2 = GammaGammaModel.load("test_model") # Check if the loaded model is indeed an instance of the class assert isinstance(model, GammaGammaModel) - - # Load data from the file to cross verify - filepath = Path(str(temp)) - idata = az.from_netcdf(filepath) - dataset = idata.fit_data.to_dataframe() # Check if the loaded data matches with the model data - assert np.array_equal(model2.customer_id.values, dataset.customer_id.values) - assert np.array_equal(model2.frequency, dataset.frequency) + assert np.array_equal(model2.customer_id.values, model.customer_id.values) + assert np.array_equal(model2.frequency, model.frequency) assert np.array_equal( - model2.mean_transaction_value, dataset.mean_transaction_value + model2.mean_transaction_value, model.mean_transaction_value ) - assert model.model_config == json.loads(idata.attrs["model_config"]) - assert model.sampler_config == json.loads(idata.attrs["sampler_config"]) - assert model.idata == idata + assert model.model_config == model2.model_config + assert model.sampler_config == model2.sampler_config + assert model.idata == model2.idata + os.remove("test_model") class TestGammaGammaModelIndividual(BaseTestGammaGammaModel): @@ -457,33 +448,27 @@ def test_model_repr(self, individual_data, default_model_config): ) def test_save_load_beta_geo(self, individual_data): - temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - model = GammaGammaModelIndividual( data=individual_data, ) model.build_model() model.fit("map") - model.save(temp) + model.save("test_model") # Testing the valid case. - model2 = GammaGammaModelIndividual.load(temp) + model2 = GammaGammaModelIndividual.load("test_model") # Check if the loaded model is indeed an instance of the class assert isinstance(model, GammaGammaModelIndividual) - - # Load data from the file to cross verify - filepath = Path(str(temp)) - idata = az.from_netcdf(filepath) - dataset = idata.fit_data.to_dataframe() # Check if the loaded data matches with the model data np.testing.assert_array_equal( - model2.customer_id.values, dataset.customer_id.values + model2.customer_id.values, model.customer_id.values ) np.testing.assert_array_equal( - model2.individual_transaction_value, dataset.individual_transaction_value + model2.individual_transaction_value, model.individual_transaction_value ) - assert model.model_config == json.loads(idata.attrs["model_config"]) - assert model.sampler_config == json.loads(idata.attrs["sampler_config"]) - assert model.idata == idata + assert model.model_config == model2.model_config + assert model.sampler_config == model2.sampler_config + assert model.idata == model2.idata + os.remove("test_model") diff --git a/tests/clv/models/test_shifted_beta_geo.py b/tests/clv/models/test_shifted_beta_geo.py index abfb3f5c..dfb0bb0c 100644 --- a/tests/clv/models/test_shifted_beta_geo.py +++ b/tests/clv/models/test_shifted_beta_geo.py @@ -1,6 +1,4 @@ -import json -import tempfile -from pathlib import Path +import os import arviz as az import numpy as np @@ -249,31 +247,23 @@ def test_distribution_new_customer(self): ) def test_save_load_beta_geo(self, data): - temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - model = ShiftedBetaGeoModelIndividual( data=data, ) model.build_model() model.fit("map") - model.save(temp) + model.save("test_model") # Testing the valid case. - - model2 = ShiftedBetaGeoModelIndividual.load(temp) - + model2 = ShiftedBetaGeoModelIndividual.load("test_model") # Check if the loaded model is indeed an instance of the class assert isinstance(model, ShiftedBetaGeoModelIndividual) - - # Load data from the file to cross verify - filepath = Path(str(temp)) - idata = az.from_netcdf(filepath) - dataset = idata.fit_data.to_dataframe() # Check if the loaded data matches with the model data np.testing.assert_array_equal( - model2.customer_id.values, dataset.customer_id.values + model2.customer_id.values, model.customer_id.values ) - np.testing.assert_array_equal(model2.t_churn, dataset.t_churn) - np.testing.assert_array_equal(model2.T, dataset["T"]) - assert model.model_config == json.loads(idata.attrs["model_config"]) - assert model.sampler_config == json.loads(idata.attrs["sampler_config"]) - assert model.idata == idata + np.testing.assert_array_equal(model2.t_churn, model.t_churn) + np.testing.assert_array_equal(model2.T, model.T) + assert model.model_config == model2.model_config + assert model.sampler_config == model2.sampler_config + assert model.idata == model2.idata + os.remove("test_model") diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index ab8d2fae..cfc36a60 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional import arviz as az @@ -174,6 +175,8 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None: adstock_max_lag=2, yearly_seasonality=2, ) + assert mmm.version == "0.0.2" + assert mmm._model_type == "DelayedSaturatedMMM" assert mmm.model_config is not None n_channel: int = len(mmm.channel_columns) n_control: int = len(mmm.control_columns) @@ -391,3 +394,44 @@ def test_data_setter(self, toy_X, toy_y): ) except Exception as e: pytest.fail(f"_data_setter failed with error {e}") + + def test_save_load(self, mmm_fitted): + model = mmm_fitted + + model.save("test_save_load") + model2 = BaseDelayedSaturatedMMM.load("test_save_load") + assert model.date_column == model2.date_column + assert model.control_columns == model2.control_columns + assert model.channel_columns == model2.channel_columns + assert model.adstock_max_lag == model2.adstock_max_lag + assert model.validate_data == model2.validate_data + assert model.yearly_seasonality == model2.yearly_seasonality + assert model.model_config == model2.model_config + assert model.sampler_config == model2.sampler_config + os.remove("test_save_load") + + def test_fail_id_after_load(self, monkeypatch, toy_X, toy_y): + # This is the new behavior for the property + def mock_property(self): + return "for sure not correct id" + + # Now create an instance of MyClass + DSMMM = DelayedSaturatedMMM( + date_column="date", + channel_columns=["channel_1", "channel_2"], + adstock_max_lag=4, + ) + + # Check that the property returns the new value + DSMMM.fit( + toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng + ) + DSMMM.save("test_model") + # Apply the monkeypatch for the property + monkeypatch.setattr(DelayedSaturatedMMM, "id", property(mock_property)) + with pytest.raises( + ValueError, + match="The file 'test_model' does not contain an inference data of the same model or configuration as 'DelayedSaturatedMMM'", + ): + DelayedSaturatedMMM.load("test_model") + os.remove("test_model")