diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 148762d1..eaa9d382 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -250,6 +250,8 @@ def _save_input_params(self, idata) -> None: idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag) idata.attrs["validate_data"] = json.dumps(self.validate_data) idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality) + idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept) + idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media) def forward_pass( self, x: pt.TensorVariable | npt.NDArray[np.float64] @@ -602,19 +604,32 @@ def load(cls, fname: str): model_config = cls._model_config_formatting( json.loads(idata.attrs["model_config"]) ) - model = cls( - date_column=json.loads(idata.attrs["date_column"]), - control_columns=json.loads(idata.attrs["control_columns"]), - channel_columns=json.loads(idata.attrs["channel_columns"]), - adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]), - adstock=json.loads(idata.attrs.get("adstock", "geometric")), - saturation=json.loads(idata.attrs.get("saturation", "logistic")), - adstock_first=json.loads(idata.attrs.get("adstock_first", True)), - validate_data=json.loads(idata.attrs["validate_data"]), - yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]), - model_config=model_config, - sampler_config=json.loads(idata.attrs["sampler_config"]), - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + model = cls( + date_column=json.loads(idata.attrs["date_column"]), + control_columns=json.loads(idata.attrs["control_columns"]), + # Media Transformations + channel_columns=json.loads(idata.attrs["channel_columns"]), + adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]), + adstock=json.loads(idata.attrs.get("adstock", "geometric")), + saturation=json.loads(idata.attrs.get("saturation", "logistic")), + adstock_first=json.loads(idata.attrs.get("adstock_first", True)), + # Seasonality + yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]), + # TVP + time_varying_intercept=json.loads( + idata.attrs.get("time_varying_intercept", False) + ), + time_varying_media=json.loads( + idata.attrs.get("time_varying_media", False) + ), + # Configurations + validate_data=json.loads(idata.attrs["validate_data"]), + model_config=model_config, + sampler_config=json.loads(idata.attrs["sampler_config"]), + ) + model.idata = idata dataset = idata.fit_data.to_dataframe() X = dataset.drop(columns=[model.output_var]) @@ -622,8 +637,11 @@ def load(cls, fname: str): model.build_model(X, y) # All previously used data is in idata. if model.id != idata.attrs["id"]: - error_msg = f"""The file '{fname}' does not contain an inference data of the same model - or configuration as '{cls._model_type}'""" + error_msg = ( + f"The file '{fname}' does not contain " + "an inference data of the same model or " + f"configuration as '{cls._model_type}'" + ) raise ValueError(error_msg) return model diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index 010a609b..a6b7cc0c 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -637,9 +637,11 @@ def mock_property(self): # Apply the monkeypatch for the property monkeypatch.setattr(MMM, "id", property(mock_property)) - error_msg = """The file 'test_model' does not contain an inference data of the same model - or configuration as 'MMM'""" - + error_msg = ( + "The file 'test_model' does not " + "contain an inference data of the " + "same model or configuration as 'MMM'" + ) with pytest.raises(ValueError, match=error_msg): MMM.load("test_model") os.remove("test_model") @@ -1017,3 +1019,38 @@ def test_initialize_defaults_channel_media_dims() -> None: for transform in [mmm.adstock, mmm.saturation]: for config in transform.function_priors.values(): assert config.dims == ("channel",) + + +@pytest.mark.parametrize( + "time_varying_intercept, time_varying_media", + [ + (True, False), + (False, True), + (True, True), + ], +) +def test_save_load_with_tvp( + time_varying_intercept, time_varying_media, toy_X, toy_y +) -> None: + mmm = MMM( + channel_columns=["channel_1", "channel_2"], + date_column="date", + adstock="geometric", + saturation="logistic", + adstock_max_lag=5, + time_varying_intercept=time_varying_intercept, + time_varying_media=time_varying_media, + ) + mmm = mock_fit(mmm, toy_X, toy_y) + + file = "tmp-model" + mmm.save(file) + loaded_mmm = MMM.load(file) + + assert mmm.time_varying_intercept == loaded_mmm.time_varying_intercept + assert mmm.time_varying_intercept == time_varying_intercept + assert mmm.time_varying_media == loaded_mmm.time_varying_media + assert mmm.time_varying_media == time_varying_media + + # clean up + os.remove(file)