Skip to content

Commit

Permalink
#1035 Distribution new customer enhancements (#1061)
Browse files Browse the repository at this point in the history
* feat: test.txt added for commit check

* feat: replaced plot_curve with plot_samples within ./mmm/plot.py

* feat: n_samples added to distributions_new_customers

* revert the plot.py changes

---------

Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
Ishaanjolly and juanitorduz authored Sep 23, 2024
1 parent 1a1703c commit d05c2d8
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pymc_marketing/clv/models/beta_geo_beta_binom.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def _distribution_new_customers(
"purchase_rate",
"recency_frequency",
),
n_samples: int = 1000,
) -> xarray.Dataset:
"""Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers.
Expand All @@ -542,6 +543,8 @@ def _distribution_new_customers(
Random state to use for sampling.
var_names : sequence of str, optional
Names of the variables to sample from. Defaults to ["dropout", "purchase_rate", "recency_frequency"].
n_samples : int, optional
Number of posterior predictive samples to generate. Defaults to 1000
"""
if data is None:
Expand All @@ -557,7 +560,7 @@ def _distribution_new_customers(

if dataset.sizes["chain"] == 1 and dataset.sizes["draw"] == 1:
# For map fit add a dummy draw dimension
dataset = dataset.squeeze("draw").expand_dims(draw=range(1000))
dataset = dataset.squeeze("draw").expand_dims(draw=range(n_samples))

coords = self.model.coords.copy() # type: ignore
coords["customer_id"] = data["customer_id"]
Expand Down Expand Up @@ -668,6 +671,7 @@ def distribution_new_customer_recency_frequency(
*,
T: int | np.ndarray | pd.Series | None = None,
random_seed: RandomState | None = None,
n_samples: int = 1,
) -> xarray.Dataset:
"""BG/BB process representing purchases across the customer population.
Expand All @@ -687,6 +691,8 @@ def distribution_new_customer_recency_frequency(
Not required if `data` Dataframe contains a `T` column.
random_seed : ~numpy.random.RandomState, optional
Random state to use for sampling.
n_samples : int, optional
Number of samples to generate. Defaults to 1.
Returns
-------
Expand All @@ -698,4 +704,5 @@ def distribution_new_customer_recency_frequency(
T=T,
random_seed=random_seed,
var_names=["recency_frequency"],
n_samples=n_samples,
)["recency_frequency"]

0 comments on commit d05c2d8

Please sign in to comment.