Skip to content

Commit

Permalink
Test customer_lifetime_value after thinning
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 5, 2023
1 parent 4d61373 commit 372ec23
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 9 deletions.
22 changes: 13 additions & 9 deletions pymc_marketing/clv/models/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def __init__(
def __repr__(self):
return f"{self._model_type}\n{self.model.str_repr()}"

def _add_fit_data_group(self, data: pd.DataFrame) -> None:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
assert self.idata is not None
self.idata.add_groups(fit_data=data.to_xarray())

def fit( # type: ignore
self,
fit_method: str = "mcmc",
Expand Down Expand Up @@ -60,16 +70,10 @@ def fit( # type: ignore
f"Fit method options are ['mcmc', 'map'], got: {fit_method}"
)

self.set_idata_attrs(idata)
self.idata = idata

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=UserWarning,
message="The group fit_data is not defined in the InferenceData scheme",
)
self.idata.add_groups(fit_data=self.data.to_xarray()) # type: ignore
self.set_idata_attrs(self.idata)
if self.data is not None:
self._add_fit_data_group(self.data)

return self.idata

Expand Down
35 changes: 35 additions & 0 deletions tests/clv/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ def fitted_bg(test_summary_data) -> BetaGeoModel:
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.idata = fake_fit
model.set_idata_attrs(model.idata)
model._add_fit_data_group(model.data)

return model

Expand Down Expand Up @@ -106,6 +108,8 @@ def fitted_gg(test_summary_data) -> GammaGammaModel:
)
fake_fit.add_groups(dict(posterior=fake_fit.prior))
model.idata = fake_fit
model.set_idata_attrs(model.idata)
model._add_fit_data_group(model.data)

return model

Expand Down Expand Up @@ -271,6 +275,37 @@ def test_map_posterior_mix_fit_types(

assert res.dims == ("chain", "draw", "customer_id")

def test_clv_after_thinning(self, test_summary_data, fitted_gg, fitted_bg):
t = test_summary_data.head()

ggf_clv = fitted_gg.expected_customer_lifetime_value(
transaction_model=fitted_bg,
customer_id=t.index,
frequency=t["frequency"],
recency=t["recency"],
T=t["T"],
mean_transaction_value=t["monetary_value"],
)

fitted_gg_thinned = fitted_gg.thin_fit_result(keep_every=10)
fitted_bg_thinned = fitted_bg.thin_fit_result(keep_every=10)
ggf_clv_thinned = fitted_gg_thinned.expected_customer_lifetime_value(
transaction_model=fitted_bg_thinned,
customer_id=t.index,
frequency=t["frequency"],
recency=t["recency"],
T=t["T"],
mean_transaction_value=t["monetary_value"],
)

assert ggf_clv.shape == (1, 50, 5)
assert ggf_clv_thinned.shape == (1, 5, 5)

np.testing.assert_equal(
ggf_clv.isel(draw=slice(None, None, 10)).values,
ggf_clv_thinned.values,
)


def test_find_first_transactions_observation_period_end_none(transaction_data):
max_date = transaction_data["date"].max()
Expand Down

0 comments on commit 372ec23

Please sign in to comment.