Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

MMM load updates #317

Merged
merged 8 commits into from
Jul 13, 2023
11 changes: 5 additions & 6 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,11 @@ def load(cls, fname: str):
model.idata = idata

model.build_model()

if model.id != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
)
# All previously used data is in idata.

return model
Expand Down Expand Up @@ -269,12 +274,6 @@ def fit_result(self, res: az.InferenceData) -> None:
else:
self.idata.posterior = res

@property
def posterior_predictive(self) -> Dataset:
if self.idata is None or "posterior_predictive" not in self.idata:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
return self.idata["posterior_predictive"]

def fit_summary(self, **kwargs):
res = self.fit_result
# Map fitting only gives one value, so we return it. We use arviz
Expand Down
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

class BaseMMM(ModelBuilder):
model: pm.Model
_model_type = "baseMMM"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_model_type = "baseMMM"
_model_type = "BaseMMM"

version = "0.0.2"

def __init__(
self,
Expand Down
65 changes: 65 additions & 0 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import arviz as az
Expand All @@ -19,6 +21,9 @@


class BaseDelayedSaturatedMMM(MMM):
_model_type = "DelayedSaturatedMMM"
version = "0.0.2"

def __init__(
self,
date_column: str,
Expand Down Expand Up @@ -127,6 +132,15 @@ def generate_and_preprocess_model_data(
self.X: pd.DataFrame = X_data
self.y: pd.Series = y

def _save_input_params(self, idata) -> None:
    """Record the model's original input parameters in ``idata.attrs``.

    Each parameter is JSON-encoded so the model can later be re-instantiated
    from the saved inference data alone.
    """
    for attribute in (
        "date_column",
        "control_columns",
        "channel_columns",
        "adstock_max_lag",
        "validate_data",
        "yearly_seasonality",
    ):
        idata.attrs[attribute] = json.dumps(getattr(self, attribute))

def build_model(
self,
X: pd.DataFrame,
Expand Down Expand Up @@ -355,6 +369,57 @@ def _serializable_model_config(self) -> Dict[str, Any]:
]["sigma"].tolist()
return serializable_config

@classmethod
def load(cls, fname: str):
    """Build a DelayedSaturatedMMM instance from a saved netCDF file.

    The model is re-instantiated with the original input parameters stored
    in the file's attrs, rebuilt against the stored fit data, and the saved
    inference data is attached to it.

    Parameters
    ----------
    fname : string
        Path of the file the inference data should be loaded from.

    Returns
    -------
    Returns an instance of DelayedSaturatedMMM.

    Raises
    ------
    ValueError
        If the loaded inference data does not belong to a model of this
        type and configuration.
    """
    idata = az.from_netcdf(Path(str(fname)))
    attrs = idata.attrs
    # json.loads turns tuples into lists, so dims must be converted back.
    model_config = cls._convert_dims_to_tuple(json.loads(attrs["model_config"]))
    model = cls(
        date_column=json.loads(attrs["date_column"]),
        control_columns=json.loads(attrs["control_columns"]),
        channel_columns=json.loads(attrs["channel_columns"]),
        adstock_max_lag=json.loads(attrs["adstock_max_lag"]),
        validate_data=json.loads(attrs["validate_data"]),
        yearly_seasonality=json.loads(attrs["yearly_seasonality"]),
        model_config=model_config,
        sampler_config=json.loads(attrs["sampler_config"]),
    )
    model.idata = idata
    fit_df = idata.fit_data.to_dataframe()
    model.build_model(
        fit_df.drop(columns=[model.output_var]),
        fit_df[model.output_var].values,
    )
    # All previously used data is in idata; verify it belongs to this model.
    if model.id != idata.attrs["id"]:
        raise ValueError(
            f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
        )

    return model

def _data_setter(
self,
X: Union[np.ndarray, pd.DataFrame],
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies = [
"seaborn>=0.12.2",
"xarray",
"xarray-einstats>=0.5.1",
"pymc-experimental>=0.0.7",
"pymc-experimental>=0.0.8",
]

[project.optional-dependencies]
Expand Down
31 changes: 27 additions & 4 deletions tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import numpy as np
import pandas as pd
import pymc as pm
Expand Down Expand Up @@ -140,10 +142,11 @@ def test_load(self):
model = CLVModelTest()
model.build_model()
model.fit(target_accept=0.81, draws=100, chains=2, random_seed=1234)
model.save("test_model.pkl")
model.load("test_model.pkl")
assert model.fit_result is not None
assert model.model is not None
model.save("test_model")
model2 = model.load("test_model")
assert model2.fit_result is not None
assert model2.model is not None
os.remove("test_model")

def test_default_sampler_config(self):
model = CLVModelTest()
Expand Down Expand Up @@ -205,3 +208,23 @@ def test_serializable_model_config(self):
serializable_config = model._serializable_model_config
assert isinstance(serializable_config, dict)
assert serializable_config == model.model_config

def test_fail_id_after_load(self, monkeypatch):
    """load() must raise ValueError when the stored id does not match the model's id."""

    # Replacement property body guaranteeing an id mismatch with the saved file.
    def mock_property(self):
        return "for sure not correct id"

    mock_basic = CLVModelTest()
    mock_basic.fit()
    mock_basic.save("test_model")
    try:
        # Patch the class property so the freshly loaded model's id cannot match
        # the id recorded in the saved inference data.
        monkeypatch.setattr(CLVModelTest, "id", property(mock_property))
        with pytest.raises(
            ValueError,
            match="The file 'test_model' does not contain an inference data of the same model or configuration as 'CLVModelTest'",
        ):
            CLVModelTest.load("test_model")
    finally:
        # Clean up the artifact even if the assertions above fail, so a failing
        # run does not leave "test_model" behind for later tests.
        os.remove("test_model")
32 changes: 12 additions & 20 deletions tests/clv/models/test_beta_geo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import tempfile
from pathlib import Path
import os

import arviz as az
import numpy as np
Expand Down Expand Up @@ -494,32 +492,26 @@ def test_distribution_new_customer(self, data) -> None:
)

def test_save_load_beta_geo(self, data):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)

model = BetaGeoModel(
data=data,
)
model.build_model()
model.fit("map")
model.save(temp)
model.save("test_model")
# Testing the valid case.

model2 = BetaGeoModel.load(temp)
model2 = BetaGeoModel.load("test_model")

# Check if the loaded model is indeed an instance of the class
assert isinstance(model, BetaGeoModel)

# Load data from the file to cross verify
filepath = Path(str(temp))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
# Check if the loaded data matches with the model data
np.testing.assert_array_equal(
model2.customer_id.values, dataset.customer_id.values
)
np.testing.assert_array_equal(model2.frequency.values, dataset.frequency.values)
np.testing.assert_array_equal(model2.T.values, dataset["T"])
np.testing.assert_array_equal(model2.recency.values, dataset.recency.values)
assert model.model_config == json.loads(idata.attrs["model_config"])
assert model.sampler_config == json.loads(idata.attrs["sampler_config"])
assert model.idata == idata
model2.customer_id.values, model.customer_id.values
)
np.testing.assert_array_equal(model2.frequency.values, model.frequency.values)
np.testing.assert_array_equal(model2.T.values, model.T.values)
np.testing.assert_array_equal(model2.recency.values, model.recency.values)
assert model.model_config == model2.model_config
assert model.sampler_config == model2.sampler_config
assert model.idata == model2.idata
os.remove("test_model")
51 changes: 18 additions & 33 deletions tests/clv/models/test_gamma_gamma.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import json
import tempfile
from pathlib import Path
import os
from unittest.mock import patch

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
Expand Down Expand Up @@ -294,35 +291,29 @@ def test_model_repr(self, data, default_model_config):
)

def test_save_load_beta_geo(self, data):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)

model = GammaGammaModel(
data=data,
)
model.build_model()
model.fit("map")
model.save(temp)
model.save("test_model")
# Testing the valid case.

model2 = GammaGammaModel.load(temp)
model2 = GammaGammaModel.load("test_model")

# Check if the loaded model is indeed an instance of the class
assert isinstance(model, GammaGammaModel)

# Load data from the file to cross verify
filepath = Path(str(temp))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
# Check if the loaded data matches with the model data
assert np.array_equal(model2.customer_id.values, dataset.customer_id.values)
assert np.array_equal(model2.frequency, dataset.frequency)
assert np.array_equal(model2.customer_id.values, model.customer_id.values)
assert np.array_equal(model2.frequency, model.frequency)
assert np.array_equal(
model2.mean_transaction_value, dataset.mean_transaction_value
model2.mean_transaction_value, model.mean_transaction_value
)

assert model.model_config == json.loads(idata.attrs["model_config"])
assert model.sampler_config == json.loads(idata.attrs["sampler_config"])
assert model.idata == idata
assert model.model_config == model2.model_config
assert model.sampler_config == model2.sampler_config
assert model.idata == model2.idata
os.remove("test_model")


class TestGammaGammaModelIndividual(BaseTestGammaGammaModel):
Expand Down Expand Up @@ -457,33 +448,27 @@ def test_model_repr(self, individual_data, default_model_config):
)

def test_save_load_beta_geo(self, individual_data):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)

model = GammaGammaModelIndividual(
data=individual_data,
)
model.build_model()
model.fit("map")
model.save(temp)
model.save("test_model")
# Testing the valid case.

model2 = GammaGammaModelIndividual.load(temp)
model2 = GammaGammaModelIndividual.load("test_model")

# Check if the loaded model is indeed an instance of the class
assert isinstance(model, GammaGammaModelIndividual)

# Load data from the file to cross verify
filepath = Path(str(temp))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
# Check if the loaded data matches with the model data
np.testing.assert_array_equal(
model2.customer_id.values, dataset.customer_id.values
model2.customer_id.values, model.customer_id.values
)
np.testing.assert_array_equal(
model2.individual_transaction_value, dataset.individual_transaction_value
model2.individual_transaction_value, model.individual_transaction_value
)

assert model.model_config == json.loads(idata.attrs["model_config"])
assert model.sampler_config == json.loads(idata.attrs["sampler_config"])
assert model.idata == idata
assert model.model_config == model2.model_config
assert model.sampler_config == model2.sampler_config
assert model.idata == model2.idata
os.remove("test_model")
30 changes: 10 additions & 20 deletions tests/clv/models/test_shifted_beta_geo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import tempfile
from pathlib import Path
import os

import arviz as az
import numpy as np
Expand Down Expand Up @@ -249,31 +247,23 @@ def test_distribution_new_customer(self):
)

def test_save_load_beta_geo(self, data):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)

model = ShiftedBetaGeoModelIndividual(
data=data,
)
model.build_model()
model.fit("map")
model.save(temp)
model.save("test_model")
# Testing the valid case.

model2 = ShiftedBetaGeoModelIndividual.load(temp)

model2 = ShiftedBetaGeoModelIndividual.load("test_model")
# Check if the loaded model is indeed an instance of the class
assert isinstance(model, ShiftedBetaGeoModelIndividual)

# Load data from the file to cross verify
filepath = Path(str(temp))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
# Check if the loaded data matches with the model data
np.testing.assert_array_equal(
model2.customer_id.values, dataset.customer_id.values
model2.customer_id.values, model.customer_id.values
)
np.testing.assert_array_equal(model2.t_churn, dataset.t_churn)
np.testing.assert_array_equal(model2.T, dataset["T"])
assert model.model_config == json.loads(idata.attrs["model_config"])
assert model.sampler_config == json.loads(idata.attrs["sampler_config"])
assert model.idata == idata
np.testing.assert_array_equal(model2.t_churn, model.t_churn)
np.testing.assert_array_equal(model2.T, model.T)
assert model.model_config == model2.model_config
assert model.sampler_config == model2.sampler_config
assert model.idata == model2.idata
os.remove("test_model")
Loading