diff --git a/pymc_marketing/clv/models/beta_geo_beta_binom.py b/pymc_marketing/clv/models/beta_geo_beta_binom.py index 45f8b68f..f9eec3f1 100644 --- a/pymc_marketing/clv/models/beta_geo_beta_binom.py +++ b/pymc_marketing/clv/models/beta_geo_beta_binom.py @@ -524,6 +524,7 @@ def _distribution_new_customers( "purchase_rate", "recency_frequency", ), + n_samples: int = 1000, ) -> xarray.Dataset: """Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers. @@ -542,6 +543,8 @@ def _distribution_new_customers( Random state to use for sampling. var_names : sequence of str, optional Names of the variables to sample from. Defaults to ["dropout", "purchase_rate", "recency_frequency"]. + n_samples : int, optional + Number of posterior predictive samples to generate. Defaults to 1000 """ if data is None: @@ -557,7 +560,7 @@ def _distribution_new_customers( if dataset.sizes["chain"] == 1 and dataset.sizes["draw"] == 1: # For map fit add a dummy draw dimension - dataset = dataset.squeeze("draw").expand_dims(draw=range(1000)) + dataset = dataset.squeeze("draw").expand_dims(draw=range(n_samples)) coords = self.model.coords.copy() # type: ignore coords["customer_id"] = data["customer_id"] @@ -668,6 +671,7 @@ def distribution_new_customer_recency_frequency( *, T: int | np.ndarray | pd.Series | None = None, random_seed: RandomState | None = None, + n_samples: int = 1, ) -> xarray.Dataset: """BG/BB process representing purchases across the customer population. @@ -687,6 +691,8 @@ def distribution_new_customer_recency_frequency( Not required if `data` Dataframe contains a `T` column. random_seed : ~numpy.random.RandomState, optional Random state to use for sampling. + n_samples : int, optional + Number of samples to generate. Defaults to 1. Returns ------- @@ -698,4 +704,5 @@ def distribution_new_customer_recency_frequency( T=T, random_seed=random_seed, var_names=["recency_frequency"], + n_samples=n_samples, )["recency_frequency"]