Skip to content

Commit

Permalink
Add option to thin fit result in CLV models
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 5, 2023
1 parent 96da752 commit 4d61373
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 49 deletions.
46 changes: 39 additions & 7 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Dict, Optional, Tuple, cast

import arviz as az
import pandas as pd
import pymc as pm
from pymc import Model, str_for_dist
from pymc.backends import NDArray
Expand All @@ -20,10 +21,13 @@ class CLVModel(ModelBuilder):

def __init__(
    self,
    data: Optional[pd.DataFrame] = None,
    *,
    model_config: Optional[Dict] = None,
    sampler_config: Optional[Dict] = None,
):
    """Initialise the model with optional data and configuration.

    Parameters
    ----------
    data
        Optional DataFrame; stored unchanged on the instance as ``self.data``.
    model_config
        Keyword-only. Forwarded to the parent initialiser.
    sampler_config
        Keyword-only. Forwarded to the parent initialiser.
    """
    super().__init__(model_config, sampler_config)
    self.data = data

def __repr__(self):
    """Return the model type on one line, followed by ``self.model.str_repr()``."""
    model_type = self._model_type
    model_text = self.model.str_repr()
    return f"{model_type}\n{model_text}"
Expand Down Expand Up @@ -138,24 +142,52 @@ def load(cls, fname: str):
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
dataset = idata.fit_data.to_dataframe()
return cls._build_with_idata(idata)

@classmethod
def _build_with_idata(cls, idata: az.InferenceData):
    """Rebuild a model instance from an existing ``InferenceData`` object.

    The DataFrame stored in ``idata.fit_data`` and the JSON-encoded
    ``model_config`` / ``sampler_config`` attributes of ``idata`` are used to
    construct a new instance of ``cls``. ``idata`` is then attached to that
    instance and its model is built.

    Parameters
    ----------
    idata
        Inference data carrying a ``fit_data`` group and the ``model_config``,
        ``sampler_config`` and ``id`` attributes.

    Returns
    -------
    The newly built model, with ``idata`` set as its ``idata`` attribute.

    Raises
    ------
    ValueError
        If the ``id`` of the rebuilt model differs from ``idata.attrs["id"]``.
    """
    dataset = idata.fit_data.to_dataframe()
    model = cls(
        dataset,
        model_config=json.loads(idata.attrs["model_config"]),  # type: ignore
        sampler_config=json.loads(idata.attrs["sampler_config"]),
    )
    model.idata = idata

    model.build_model()  # type: ignore

    if model.id != idata.attrs["id"]:
        # No file name is in scope here (the inference data need not come
        # from disk), so the message only names the model type.
        raise ValueError(f"Inference data not compatible with {cls._model_type}")
    return model

def thin_fit_result(self, keep_every: int):
    """Return a copy of the model with a thinned fit result.

    This is useful when computing summary statistics that may require too
    much memory per posterior draw.

    Parameters
    ----------
    keep_every
        Stride used to slice the ``draw`` dimension of the inference data:
        every ``keep_every``-th draw is kept, starting from the first one.
        Must be at least 1.

    Returns
    -------
    A new model of the same type as ``self``, rebuilt from a copy of the
    thinned inference data.

    Raises
    ------
    ValueError
        If ``keep_every`` is smaller than 1.

    Examples
    --------

    .. code-block:: python

        fitted_gg = ...
        fitted_bg = ...

        fitted_gg_thinned = fitted_gg.thin_fit_result(keep_every=10)
        fitted_bg_thinned = fitted_bg.thin_fit_result(keep_every=10)

        clv_thinned = fitted_gg_thinned.expected_customer_lifetime_value(
            transaction_model=fitted_bg_thinned,
            customer_id=t.index,
            frequency=t["frequency"],
            recency=t["recency"],
            T=t["T"],
            mean_transaction_value=t["monetary_value"],
        )

    """
    self.fit_result  # Raise Error if fit didn't happen yet
    assert self.idata is not None
    if keep_every < 1:
        # A zero step fails deep inside slicing with an unrelated message and
        # a negative step would silently reverse the draws; reject both here.
        raise ValueError(
            f"keep_every must be a positive integer, got {keep_every!r}"
        )
    new_idata = self.idata.isel(draw=slice(None, None, keep_every)).copy()
    return type(self)._build_with_idata(new_idata)

@staticmethod
def _check_prior_ndim(prior, ndim: int = 0):
Expand Down
2 changes: 1 addition & 1 deletion pymc_marketing/clv/models/beta_geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ def __init__(
except KeyError:
raise KeyError("T column is missing from data")
super().__init__(
data=data,
model_config=model_config,
sampler_config=sampler_config,
)
self.data = data
self.a_prior = self._create_distribution(self.model_config["a_prior"])
self.b_prior = self._create_distribution(self.model_config["b_prior"])
self.alpha_prior = self._create_distribution(self.model_config["alpha_prior"])
Expand Down
6 changes: 4 additions & 2 deletions pymc_marketing/clv/models/gamma_gamma.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ class BaseGammaGammaModel(CLVModel):
def __init__(
self,
data: pd.DataFrame,
*,
model_config: Optional[Dict] = None,
sampler_config: Optional[Dict] = None,
):
super().__init__(model_config, sampler_config)
self.data = data
super().__init__(
data=data, model_config=model_config, sampler_config=sampler_config
)
self.p_prior = self._create_distribution(self.model_config["p_prior"])
self.q_prior = self._create_distribution(self.model_config["q_prior"])
self.v_prior = self._create_distribution(self.model_config["v_prior"])
Expand Down
83 changes: 44 additions & 39 deletions tests/clv/models/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,38 @@
import pandas as pd
import pymc as pm
import pytest
from arviz import InferenceData
from arviz import InferenceData, from_dict

from pymc_marketing.clv.models.basic import CLVModel


class CLVModelTest(CLVModel):
    """Minimal concrete ``CLVModel`` used to exercise the base-class behaviour.

    The model has a single configurable prior ``x`` and one observed
    variable ``y`` whose mean is ``x``.
    """

    _model_type = "CLVModelTest"

    def __init__(self, data=None, **kwargs):
        # Fall back to a small random dataset so tests can build the model
        # without supplying data explicitly.
        if data is None:
            data = pd.DataFrame({"y": np.random.randn(10)})
        super().__init__(data=data, **kwargs)
        self.x_prior = self._create_distribution(self.model_config["x"])
        self._process_priors(self.x_prior)

    @property
    def default_model_config(self):
        """Default configuration: a standard normal prior for ``x``."""
        return {
            "x": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}},
        }

    def build_model(self):
        """Create ``self.model`` with the ``x`` prior and the observed ``y``."""
        with pm.Model() as self.model:
            x = self.model.register_rv(self.x_prior, name="x")
            pm.Normal("y", mu=x, sigma=1, observed=self.data["y"])


class TestCLVModel:
def test_repr(self):
    """The repr is the model type followed by the model's variable listing."""
    model = CLVModelTest()
    model.build_model()
    assert repr(model) == "CLVModelTest\nx ~ Normal(0, 1)\ny ~ Normal(x, 1)"

def test_check_prior_ndim(self):
prior = pm.Normal.dist(shape=(5,)) # ndim = 1
Expand Down Expand Up @@ -123,14 +117,12 @@ def test_fit_result_error(self):
def test_load(self):
    """A saved model can be loaded back with its fit result and model."""
    model = CLVModelTest()
    model.build_model()
    # A tiny sampling run is enough: only the save/load round trip is checked.
    model.fit(tune=0, chains=2, draws=5)
    model.save("test_model")
    try:
        model2 = model.load("test_model")
        assert model2.fit_result is not None
        assert model2.model is not None
    finally:
        # Always delete the file so a failing assertion does not leave it
        # behind for later tests.
        os.remove("test_model")

def test_default_sampler_config(self):
model = CLVModelTest()
Expand All @@ -148,13 +140,12 @@ def test_set_fit_result(self):
with pytest.warns(UserWarning, match="Overriding pre-existing fit_result"):
model.fit_result = fake_fit
model.idata = None
model.sample_prior_predictive(samples=50, extend_idata=True)
model.fit_result = fake_fit

def test_fit_summary_for_mcmc(self):
    """``fit_summary`` returns a DataFrame after an MCMC fit."""
    model = CLVModelTest()
    model.build_model()
    # One short sampling run; the summary type is all that is checked.
    model.fit(tune=0, chains=2, draws=5)
    summ = model.fit_summary()
    assert isinstance(summ, pd.DataFrame)

Expand All @@ -171,17 +162,31 @@ def mock_property(self):

# Now create an instance of MyClass
mock_basic = CLVModelTest()

# Check that the property returns the new value
mock_basic.fit()
mock_basic.fit(tune=0, chains=2, draws=5)
mock_basic.save("test_model")
try:
# Apply the monkeypatch for the property
monkeypatch.setattr(CLVModelTest, "id", property(mock_property))
with pytest.raises(
ValueError,
match="The file 'test_model' does not contain an inference data of the same model or configuration as 'CLVModelTest'",
):
CLVModelTest.load("test_model")
finally:
os.remove("test_model")

# Apply the monkeypatch for the property
monkeypatch.setattr(CLVModelTest, "id", property(mock_property))
with pytest.raises(
ValueError,
match="Inference data not compatible with CLVModelTest",
):
CLVModelTest.load("test_model")
os.remove("test_model")

def test_thin_fit_result(self):
    """Thinning yields a distinct model with fewer draws and equal data."""
    observed = pd.DataFrame({"y": [-3, -2, -1]})
    model = CLVModelTest(data=observed)
    model.build_model()

    # Attach a hand-made posterior (4 chains x 1000 draws) instead of sampling.
    posterior_idata = from_dict({"x": np.random.normal(size=(4, 1000))})
    posterior_idata.add_groups({"fit_data": observed.to_xarray()})
    model.set_idata_attrs(posterior_idata)
    model.idata = posterior_idata

    thin_model = model.thin_fit_result(keep_every=20)

    assert thin_model is not model
    assert thin_model.idata is not model.idata

    thinned_x = thin_model.idata.posterior["x"]
    assert len(thinned_x.chain) == 4
    assert len(thinned_x.draw) == 50

    assert thin_model.data is not model.data
    assert np.all(thin_model.data == model.data)

0 comments on commit 4d61373

Please sign in to comment.