Skip to content

Commit

Permalink
Save & load support for time varying parameters (#815)
Browse files Browse the repository at this point in the history
* add missing init for save and load

* get rid of warnings from JSON parsing

* new error message without line break
  • Loading branch information
wd60622 authored Jul 8, 2024
1 parent 9b691a9 commit 03e9215
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 18 deletions.
48 changes: 33 additions & 15 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def _save_input_params(self, idata) -> None:
idata.attrs["adstock_max_lag"] = json.dumps(self.adstock_max_lag)
idata.attrs["validate_data"] = json.dumps(self.validate_data)
idata.attrs["yearly_seasonality"] = json.dumps(self.yearly_seasonality)
idata.attrs["time_varying_intercept"] = json.dumps(self.time_varying_intercept)
idata.attrs["time_varying_media"] = json.dumps(self.time_varying_media)

def forward_pass(
self, x: pt.TensorVariable | npt.NDArray[np.float64]
Expand Down Expand Up @@ -602,28 +604,44 @@ def load(cls, fname: str):
model_config = cls._model_config_formatting(
json.loads(idata.attrs["model_config"])
)
model = cls(
date_column=json.loads(idata.attrs["date_column"]),
control_columns=json.loads(idata.attrs["control_columns"]),
channel_columns=json.loads(idata.attrs["channel_columns"]),
adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
adstock=json.loads(idata.attrs.get("adstock", "geometric")),
saturation=json.loads(idata.attrs.get("saturation", "logistic")),
adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
validate_data=json.loads(idata.attrs["validate_data"]),
yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=DeprecationWarning)
model = cls(
date_column=json.loads(idata.attrs["date_column"]),
control_columns=json.loads(idata.attrs["control_columns"]),
# Media Transformations
channel_columns=json.loads(idata.attrs["channel_columns"]),
adstock_max_lag=json.loads(idata.attrs["adstock_max_lag"]),
adstock=json.loads(idata.attrs.get("adstock", "geometric")),
saturation=json.loads(idata.attrs.get("saturation", "logistic")),
adstock_first=json.loads(idata.attrs.get("adstock_first", True)),
# Seasonality
yearly_seasonality=json.loads(idata.attrs["yearly_seasonality"]),
# TVP
time_varying_intercept=json.loads(
idata.attrs.get("time_varying_intercept", False)
),
time_varying_media=json.loads(
idata.attrs.get("time_varying_media", False)
),
# Configurations
validate_data=json.loads(idata.attrs["validate_data"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)

model.idata = idata
dataset = idata.fit_data.to_dataframe()
X = dataset.drop(columns=[model.output_var])
y = dataset[model.output_var].values
model.build_model(X, y)
# All previously used data is in idata.
if model.id != idata.attrs["id"]:
error_msg = f"""The file '{fname}' does not contain an inference data of the same model
or configuration as '{cls._model_type}'"""
error_msg = (
f"The file '{fname}' does not contain "
"an inference data of the same model or "
f"configuration as '{cls._model_type}'"
)
raise ValueError(error_msg)

return model
Expand Down
43 changes: 40 additions & 3 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,11 @@ def mock_property(self):
# Apply the monkeypatch for the property
monkeypatch.setattr(MMM, "id", property(mock_property))

error_msg = """The file 'test_model' does not contain an inference data of the same model
or configuration as 'MMM'"""

error_msg = (
"The file 'test_model' does not "
"contain an inference data of the "
"same model or configuration as 'MMM'"
)
with pytest.raises(ValueError, match=error_msg):
MMM.load("test_model")
os.remove("test_model")
Expand Down Expand Up @@ -1017,3 +1019,38 @@ def test_initialize_defaults_channel_media_dims() -> None:
for transform in [mmm.adstock, mmm.saturation]:
for config in transform.function_priors.values():
assert config.dims == ("channel",)


@pytest.mark.parametrize(
    "time_varying_intercept, time_varying_media",
    [
        (True, False),
        (False, True),
        (True, True),
    ],
)
def test_save_load_with_tvp(
    time_varying_intercept, time_varying_media, toy_X, toy_y
) -> None:
    """Check that the time-varying flags survive a save/load round trip.

    An ``MMM`` is built with the parametrized ``time_varying_intercept`` and
    ``time_varying_media`` values, passed through ``mock_fit``, saved to disk
    and loaded back with ``MMM.load``. The loaded model must report the same
    flag values as the original model and as the parametrized inputs.
    """
    mmm = MMM(
        channel_columns=["channel_1", "channel_2"],
        date_column="date",
        adstock="geometric",
        saturation="logistic",
        adstock_max_lag=5,
        time_varying_intercept=time_varying_intercept,
        time_varying_media=time_varying_media,
    )
    mmm = mock_fit(mmm, toy_X, toy_y)

    file = "tmp-model"
    mmm.save(file)
    try:
        loaded_mmm = MMM.load(file)

        assert mmm.time_varying_intercept == loaded_mmm.time_varying_intercept
        assert mmm.time_varying_intercept == time_varying_intercept
        assert mmm.time_varying_media == loaded_mmm.time_varying_media
        assert mmm.time_varying_media == time_varying_media
    finally:
        # Clean up even when loading or an assertion fails, so a failing
        # case does not leave the saved file behind for subsequent tests.
        os.remove(file)

0 comments on commit 03e9215

Please sign in to comment.