Skip to content

Commit

Permalink
#1063 Pareto_nbd distribution enhancements (#1067)
Browse files Browse the repository at this point in the history
* feat: test.txt added for commit check

* feat: replaced plot_curve with plot_samples within ./mmm/plot.py

* feat: n_samples added to distributions_new_customers

* feat: added n_samples to distributions_new_customer and distribution_new_customer_frequency

* remove test.txt

* redo commit before samples from sample in ./mmm/plot.py

* feat(pareto_nbd.py): changed n_samples = 1000 from 1

* fix (pareto_nbd.py): corrected the doc string for dis_new_cust_freq with n_samples = 1000 default
  • Loading branch information
Ishaanjolly authored Sep 30, 2024
1 parent da26dc8 commit 7a2b13f
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def distribution_new_customer(
"purchase_rate",
"recency_frequency",
),
n_samples: int = 1000,
) -> xarray.Dataset:
"""Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers.
Expand All @@ -896,6 +897,8 @@ def distribution_new_customer(
Random state to use for sampling.
var_names : sequence of str, optional
Names of the variables to sample from. Defaults to ["dropout", "purchase_rate", "recency_frequency"].
n_samples : int, optional
Number of draws to generate when the model was fit with MAP (a single chain and a single draw). Defaults to 1000.
"""
if data is None:
Expand All @@ -911,7 +914,7 @@ def distribution_new_customer(

if dataset.sizes["chain"] == 1 and dataset.sizes["draw"] == 1:
# For map fit add a dummy draw dimension
dataset = dataset.squeeze("draw").expand_dims(draw=range(1000))
dataset = dataset.squeeze("draw").expand_dims(draw=range(n_samples))

coords = self.model.coords.copy() # type: ignore
coords["customer_id"] = data["customer_id"]
Expand Down Expand Up @@ -1032,6 +1035,7 @@ def distribution_new_customer_recency_frequency(
*,
T: int | np.ndarray | pd.Series | None = None,
random_seed: RandomState | None = None,
n_samples: int = 1000,
) -> xarray.Dataset:
"""Pareto/NBD process representing purchases across the customer population.
Expand All @@ -1052,6 +1056,8 @@ def distribution_new_customer_recency_frequency(
Not required if `data` Dataframe contains a `T` column.
random_seed : ~numpy.random.RandomState, optional
Random state to use for sampling.
n_samples : int, optional
Number of draws to generate when the model was fit with MAP (a single chain and a single draw). Defaults to 1000.
Returns
-------
Expand All @@ -1064,4 +1070,5 @@ def distribution_new_customer_recency_frequency(
T=T,
random_seed=random_seed,
var_names=["recency_frequency"],
n_samples=n_samples,
)["recency_frequency"]

0 comments on commit 7a2b13f

Please sign in to comment.