diff --git a/pymc_marketing/clv/models/pareto_nbd.py b/pymc_marketing/clv/models/pareto_nbd.py index 6c35f9a5..db5f10ab 100644 --- a/pymc_marketing/clv/models/pareto_nbd.py +++ b/pymc_marketing/clv/models/pareto_nbd.py @@ -240,11 +240,11 @@ def build_model(self) -> None: # type: ignore[override] "purchase_covariate": self.purchase_covariate_cols, "dropout_covariate": self.dropout_covariate_cols, "obs_var": ["recency", "frequency"], + "customer_id": self.data["customer_id"], } - mutable_coords = {"customer_id": self.data["customer_id"]} - with pm.Model(coords=coords, coords_mutable=mutable_coords) as self.model: + with pm.Model(coords=coords) as self.model: if self.purchase_covariate_cols: - purchase_data = pm.MutableData( + purchase_data = pm.Data( "purchase_data", self.data[self.purchase_covariate_cols], dims=["customer_id", "purchase_covariate"], @@ -273,7 +273,7 @@ def build_model(self) -> None: # type: ignore[override] # churn priors if self.dropout_covariate_cols: - dropout_data = pm.MutableData( + dropout_data = pm.Data( "dropout_data", self.data[self.dropout_covariate_cols], dims=["customer_id", "dropout_covariate"], diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 148762d1..1f04d217 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -203,11 +203,9 @@ def _generate_and_preprocess_model_data( # type: ignore date_data = X[self.date_column] channel_data = X[self.channel_columns] - self.coords_mutable: dict[str, Any] = { - "date": date_data, - } coords: dict[str, Any] = { "channel": self.channel_columns, + "date": date_data, } new_X_dict = { @@ -250,6 +248,8 @@ def _save_input_params(self, idata) -> None: idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag) idata.attrs["validate_data"] = json.dumps(self.validate_data) idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality) + idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept) + idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media) def forward_pass( self, x: pt.TensorVariable | npt.NDArray[np.float64] @@ -347,20 +347,17 @@ def build_model( self._generate_and_preprocess_model_data(X, y) with pm.Model( coords=self.model_coords, - coords_mutable=self.coords_mutable, ) as self.model: channel_data_ = pm.Data( name="channel_data", value=self.preprocessed_data["X"][self.channel_columns], dims=("date", "channel"), - mutable=True, ) target_ = pm.Data( name="target", value=self.preprocessed_data["y"], dims="date", - mutable=True, ) if self.time_varying_intercept | self.time_varying_media: time_index = pm.Data( @@ -441,7 +438,6 @@ def build_model( name="control_data", value=self.preprocessed_data["X"][self.control_columns], dims=("date", "control"), - mutable=True, ) control_contributions = pm.Deterministic( @@ -459,7 +455,6 @@ def build_model( self.date_column ].dt.dayofyear.to_numpy(), dims="date", - mutable=True, ) def create_deterministic(x: pt.TensorVariable) -> None: @@ -544,7 +539,6 @@ def channel_contributions_forward_pass( """ coords = { **self.model_coords, - **self.coords_mutable, } with pm.Model(coords=coords): pm.Deterministic( @@ -602,19 +596,32 @@ def load(cls, fname: str): model_config = cls._model_config_formatting( json.loads(idata.attrs["model_config"]) ) - model = cls( - date_column=json.loads(idata.attrs["date_column"]), - control_columns=json.loads(idata.attrs["control_columns"]), - channel_columns=json.loads(idata.attrs["channel_columns"]), - adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]), - adstock=json.loads(idata.attrs.get("adstock", "geometric")), - saturation=json.loads(idata.attrs.get("saturation", "logistic")), - adstock_first=json.loads(idata.attrs.get("adstock_first", True)), - validate_data=json.loads(idata.attrs["validate_data"]), - yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]), - model_config=model_config, - sampler_config=json.loads(idata.attrs["sampler_config"]), - ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=DeprecationWarning) + model = cls( + date_column=json.loads(idata.attrs["date_column"]), + control_columns=json.loads(idata.attrs["control_columns"]), + # Media Transformations + channel_columns=json.loads(idata.attrs["channel_columns"]), + adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]), + adstock=json.loads(idata.attrs.get("adstock", "geometric")), + saturation=json.loads(idata.attrs.get("saturation", "logistic")), + adstock_first=json.loads(idata.attrs.get("adstock_first", True)), + # Seasonality + yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]), + # TVP + time_varying_intercept=json.loads( + idata.attrs.get("time_varying_intercept", False) + ), + time_varying_media=json.loads( + idata.attrs.get("time_varying_media", False) + ), + # Configurations + validate_data=json.loads(idata.attrs["validate_data"]), + model_config=model_config, + sampler_config=json.loads(idata.attrs["sampler_config"]), + ) + model.idata = idata dataset = idata.fit_data.to_dataframe() X = dataset.drop(columns=[model.output_var]) @@ -622,8 +629,11 @@ def load(cls, fname: str): model.build_model(X, y) # All previously used data is in idata. if model.id != idata.attrs["id"]: - error_msg = f"""The file '{fname}' does not contain an inference data of the same model - or configuration as '{cls._model_type}'""" + error_msg = ( + f"The file '{fname}' does not contain " + "an inference data of the same model or " + f"configuration as '{cls._model_type}'" + ) raise ValueError(error_msg) return model diff --git a/pyproject.toml b/pyproject.toml index eba88b28..0c899c71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,7 @@ dependencies = [ "numpy>=1.17", "pandas", # NOTE: Used as minimum pymc version with ci.yml `OLDEST_PYMC_VERSION` - "pymc>=5.12.0,<5.16.0", + "pymc>=5.13.0,<5.16.0", "scikit-learn>=1.1.1", "seaborn>=0.12.2", "xarray", diff --git a/tests/clv/test_distributions.py b/tests/clv/test_distributions.py index c868d3c1..66356466 100644 --- a/tests/clv/test_distributions.py +++ b/tests/clv/test_distributions.py @@ -261,7 +261,7 @@ def test_pareto_nbd_sample_prior( s = pm.Gamma(name="s", alpha=5, beta=1, size=s_size) beta = pm.Gamma(name="beta", alpha=5, beta=1, size=beta_size) - T = pm.MutableData(name="T", value=np.array(10)) + T = pm.Data(name="T", value=np.array(10)) ParetoNBD( name="pareto_nbd", @@ -436,7 +436,7 @@ def test_beta_geo_beta_binom_sample_prior( gamma = pm.Normal(name="gamma", mu=gamma_true, sigma=1e-4, size=gamma_size) delta = pm.Normal(name="delta", mu=delta_true, sigma=1e-4, size=delta_size) - T = pm.MutableData(name="T", value=np.array(T_true)) + T = pm.Data(name="T", value=np.array(T_true)) BetaGeoBetaBinom( name="beta_geo_beta_binom", diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index 010a609b..a6b7cc0c 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -637,9 +637,11 @@ def mock_property(self): # Apply the monkeypatch for the property monkeypatch.setattr(MMM, "id", property(mock_property)) - error_msg = """The file 'test_model' does not contain an inference data of the same model - or configuration as 'MMM'""" - + error_msg = ( + "The file 'test_model' does not " + "contain an inference data of the " + "same model or configuration as 'MMM'" + ) with pytest.raises(ValueError, match=error_msg): MMM.load("test_model") os.remove("test_model") @@ -1017,3 +1019,38 @@ def test_initialize_defaults_channel_media_dims() -> None: for transform in [mmm.adstock, mmm.saturation]: for config in transform.function_priors.values(): assert config.dims == ("channel",) + + +@pytest.mark.parametrize( + "time_varying_intercept, time_varying_media", + [ + (True, False), + (False, True), + (True, True), + ], +) +def test_save_load_with_tvp( + time_varying_intercept, time_varying_media, toy_X, toy_y +) -> None: + mmm = MMM( + channel_columns=["channel_1", "channel_2"], + date_column="date", + adstock="geometric", + saturation="logistic", + adstock_max_lag=5, + time_varying_intercept=time_varying_intercept, + time_varying_media=time_varying_media, + ) + mmm = mock_fit(mmm, toy_X, toy_y) + + file = "tmp-model" + mmm.save(file) + loaded_mmm = MMM.load(file) + + assert mmm.time_varying_intercept == loaded_mmm.time_varying_intercept + assert mmm.time_varying_intercept == time_varying_intercept + assert mmm.time_varying_media == loaded_mmm.time_varying_media + assert mmm.time_varying_media == time_varying_media + + # clean up + os.remove(file) diff --git a/tests/test_model_builder.py b/tests/test_model_builder.py index 73bac7af..130ebf45 100644 --- a/tests/test_model_builder.py +++ b/tests/test_model_builder.py @@ -110,8 +110,8 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None): with pm.Model(coords=coords) as self.model: if model_config is None: model_config = self.default_model_config - x = pm.MutableData("x", self.X["input"].values) - y_data = pm.MutableData("y_data", self.y) + x = pm.Data("x", self.X["input"].values) + y_data = pm.Data("y_data", self.y) # prior parameters a_loc = model_config["a"]["loc"]