Skip to content

Commit

Permalink
Revert sample_kwargs, build_model, and fit
Browse files Browse the repository at this point in the history
  • Loading branch information
ColtAllen committed Oct 23, 2023
1 parent 6c5af0b commit ef45da2
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
19 changes: 4 additions & 15 deletions pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __init__(
self.r_prior, self.alpha_prior, self.s_prior, self.beta_prior
)

self.build_model()
# TODO: Add self.build_model() call here

@property
def default_model_config(self) -> Dict[str, Dict]:
Expand Down Expand Up @@ -323,7 +323,7 @@ def fit(self, fit_method: str = "map", **kwargs): # type: ignore
)
super().fit(fit_method, **kwargs)

return self
# TODO: return self or None?

def expected_purchases(
self,
Expand Down Expand Up @@ -646,13 +646,8 @@ def _distribution_new_customers(
s = pm.HalfFlat("s")
beta = pm.HalfFlat("beta")

if shape_kwargs is None:
shape_kwargs = {"shape": 1000}

pm.Gamma(
"population_purchase_rate", alpha=r, beta=1 / alpha, **shape_kwargs
)
pm.Gamma("population_dropout", alpha=s, beta=1 / beta, **shape_kwargs)
pm.Gamma("population_purchase_rate", alpha=r, beta=1 / alpha, shape=1000)
pm.Gamma("population_dropout", alpha=s, beta=1 / beta, shape=1000)

ParetoNBD(
name="customer_population",
Expand All @@ -672,7 +667,6 @@ def _distribution_new_customers(
def distribution_new_customer_dropout(
self,
random_seed: Optional[RandomState] = None,
shape_kwargs: Optional[Dict] = None,
) -> xarray.Dataset:
"""Sample from the Gamma distribution representing dropout times for new customers.
Expand All @@ -691,13 +685,11 @@ def distribution_new_customer_dropout(
return self._distribution_new_customers(
random_seed=random_seed,
var_names=["population_dropout"],
shape_kwargs=shape_kwargs,
)["population_dropout"]

def distribution_new_customer_purchase_rate(
self,
random_seed: Optional[RandomState] = None,
shape_kwargs: Optional[Dict] = None,
) -> xarray.Dataset:
"""Sample from the Gamma distribution representing purchase rates for new customers.
Expand All @@ -717,13 +709,11 @@ def distribution_new_customer_purchase_rate(
return self._distribution_new_customers(
random_seed=random_seed,
var_names=["population_purchase_rate"],
shape_kwargs=shape_kwargs,
)["population_purchase_rate"]

def distribution_customer_population(
self,
random_seed: Optional[RandomState] = None,
shape_kwargs: Optional[Dict] = None,
) -> xarray.Dataset:
"""Pareto/NBD process representing purchases across the customer population.
Expand All @@ -742,5 +732,4 @@ def distribution_customer_population(
return self._distribution_new_customers(
random_seed=random_seed,
var_names=["customer_population"],
shape_kwargs=shape_kwargs,
)["customer_population"]
7 changes: 7 additions & 0 deletions tests/clv/models/test_pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def setup_class(cls):

# Instantiate model with CDNOW data for testing
cls.model = ParetoNBDModel(cls.data)
# TODO: This can be removed after build_model() is called internally with __init__
cls.model.build_model()

# Also instantiate lifetimes model for comparison
cls.lifetimes_model = ParetoNBDFitter()
Expand Down Expand Up @@ -88,6 +90,9 @@ def test_model(self, model_config, default_model_config):
for config in (model_config, default_model_config):
model = ParetoNBDModel(self.data, config)

# TODO: This can be removed after build_model() is called internally with __init__
model.build_model()

assert isinstance(
model.model["r"].owner.op,
pm.Weibull
Expand Down Expand Up @@ -178,6 +183,8 @@ def test_model_convergence(self, fit_method, rtol, sample_kwargs):
model = ParetoNBDModel(
data=self.data,
)
# TODO: This can be removed after build_model() is called internally with __init__
model.build_model()

if sample_kwargs is None:
sample_kwargs = {}
Expand Down

0 comments on commit ef45da2

Please sign in to comment.