From ef45da25ba3a956f4c0d8a990a884326ec9ae247 Mon Sep 17 00:00:00 2001 From: ColtAllen Date: Mon, 23 Oct 2023 10:35:40 -0600 Subject: [PATCH] Revert sample_kwargs, build_model, and fit --- pymc_marketing/clv/models/pareto_nbd.py | 19 ++++--------------- tests/clv/models/test_pareto_nbd.py | 7 +++++++ 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/pymc_marketing/clv/models/pareto_nbd.py b/pymc_marketing/clv/models/pareto_nbd.py index 03cb832b..fcab7f1f 100644 --- a/pymc_marketing/clv/models/pareto_nbd.py +++ b/pymc_marketing/clv/models/pareto_nbd.py @@ -202,7 +202,7 @@ def __init__( self.r_prior, self.alpha_prior, self.s_prior, self.beta_prior ) - self.build_model() + # TODO: Add self.build_model() call here @property def default_model_config(self) -> Dict[str, Dict]: @@ -323,7 +323,7 @@ def fit(self, fit_method: str = "map", **kwargs): # type: ignore ) super().fit(fit_method, **kwargs) - return self + # TODO: return self or None? def expected_purchases( self, @@ -646,13 +646,8 @@ def _distribution_new_customers( s = pm.HalfFlat("s") beta = pm.HalfFlat("beta") - if shape_kwargs is None: - shape_kwargs = {"shape": 1000} - - pm.Gamma( - "population_purchase_rate", alpha=r, beta=1 / alpha, **shape_kwargs - ) - pm.Gamma("population_dropout", alpha=s, beta=1 / beta, **shape_kwargs) + pm.Gamma("population_purchase_rate", alpha=r, beta=1 / alpha, shape=1000) + pm.Gamma("population_dropout", alpha=s, beta=1 / beta, shape=1000) ParetoNBD( name="customer_population", @@ -672,7 +667,6 @@ def _distribution_new_customers( def distribution_new_customer_dropout( self, random_seed: Optional[RandomState] = None, - shape_kwargs: Optional[Dict] = None, ) -> xarray.Dataset: """Sample from the Gamma distribution representing dropout times for new customers. @@ -691,13 +685,11 @@ def distribution_new_customer_dropout( return self._distribution_new_customers( random_seed=random_seed, var_names=["population_dropout"], - shape_kwargs=shape_kwargs, )["population_dropout"] def distribution_new_customer_purchase_rate( self, random_seed: Optional[RandomState] = None, - shape_kwargs: Optional[Dict] = None, ) -> xarray.Dataset: """Sample from the Gamma distribution representing purchase rates for new customers. @@ -717,13 +709,11 @@ def distribution_new_customer_purchase_rate( return self._distribution_new_customers( random_seed=random_seed, var_names=["population_purchase_rate"], - shape_kwargs=shape_kwargs, )["population_purchase_rate"] def distribution_customer_population( self, random_seed: Optional[RandomState] = None, - shape_kwargs: Optional[Dict] = None, ) -> xarray.Dataset: """Pareto/NBD process representing purchases across the customer population. @@ -742,5 +732,4 @@ def distribution_customer_population( return self._distribution_new_customers( random_seed=random_seed, var_names=["customer_population"], - shape_kwargs=shape_kwargs, )["customer_population"] diff --git a/tests/clv/models/test_pareto_nbd.py b/tests/clv/models/test_pareto_nbd.py index 772dda63..57565f68 100644 --- a/tests/clv/models/test_pareto_nbd.py +++ b/tests/clv/models/test_pareto_nbd.py @@ -37,6 +37,8 @@ def setup_class(cls): # Instantiate model with CDNOW data for testing cls.model = ParetoNBDModel(cls.data) + # TODO: This can be removed after build_model() is called internally with __init__ + cls.model.build_model() # Also instantiate lifetimes model for comparison cls.lifetimes_model = ParetoNBDFitter() @@ -88,6 +90,9 @@ def test_model(self, model_config, default_model_config): for config in (model_config, default_model_config): model = ParetoNBDModel(self.data, config) + # TODO: This can be removed after build_model() is called internally with __init__ + model.build_model() + assert isinstance( model.model["r"].owner.op, pm.Weibull @@ -178,6 +183,8 @@ def test_model_convergence(self, fit_method, rtol, sample_kwargs): model = ParetoNBDModel( data=self.data, ) + # TODO: This can be removed after build_model() is called internally with __init__ + model.build_model() if sample_kwargs is None: sample_kwargs = {}