Skip to content

Commit

Permalink
Merge branch 'main' into fourier_base_date_not_index
Browse files Browse the repository at this point in the history
  • Loading branch information
Ishaanjolly authored Sep 30, 2024
2 parents b4f5cf5 + 7a2b13f commit 3d21431
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 110 deletions.
188 changes: 106 additions & 82 deletions docs/source/notebooks/mmm/mmm_budget_allocation_example.ipynb

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion pymc_marketing/clv/models/pareto_nbd.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ def distribution_new_customer(
"purchase_rate",
"recency_frequency",
),
n_samples: int = 1000,
) -> xarray.Dataset:
"""Compute posterior predictive samples of dropout, purchase rate and frequency/recency of new customers.
Expand All @@ -896,6 +897,8 @@ def distribution_new_customer(
Random state to use for sampling.
var_names : sequence of str, optional
Names of the variables to sample from. Defaults to ["dropout", "purchase_rate", "recency_frequency"].
n_samples : int, optional
Number of samples to generate. Defaults to 1000
"""
if data is None:
Expand All @@ -911,7 +914,7 @@ def distribution_new_customer(

if dataset.sizes["chain"] == 1 and dataset.sizes["draw"] == 1:
# For map fit add a dummy draw dimension
dataset = dataset.squeeze("draw").expand_dims(draw=range(1000))
dataset = dataset.squeeze("draw").expand_dims(draw=range(n_samples))

coords = self.model.coords.copy() # type: ignore
coords["customer_id"] = data["customer_id"]
Expand Down Expand Up @@ -1032,6 +1035,7 @@ def distribution_new_customer_recency_frequency(
*,
T: int | np.ndarray | pd.Series | None = None,
random_seed: RandomState | None = None,
n_samples: int = 1000,
) -> xarray.Dataset:
"""Pareto/NBD process representing purchases across the customer population.
Expand All @@ -1052,6 +1056,8 @@ def distribution_new_customer_recency_frequency(
Not required if `data` Dataframe contains a `T` column.
random_seed : ~numpy.random.RandomState, optional
Random state to use for sampling.
n_samples : int, optional
Number of samples to generate. Defaults to 1000.
Returns
-------
Expand All @@ -1064,4 +1070,5 @@ def distribution_new_customer_recency_frequency(
T=T,
random_seed=random_seed,
var_names=["recency_frequency"],
n_samples=n_samples,
)["recency_frequency"]
40 changes: 13 additions & 27 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2239,28 +2239,14 @@ def plot_budget_allocation(
The matplotlib figure object and axis containing the plot.
"""
if original_scale:
channel_contributions = (
samples["channel_contributions"]
.mean(dim=["sample"])
.mean(dim=["date"])
.values
* self.get_target_transformer()["scaler"].scale_
)
channel_contributions = (
samples["channel_contributions"].mean(dim=["date", "sample"]).to_numpy()
)

allocate_spend = (
np.array(list(self.optimal_allocation_dict.values()))
* self.channel_transformer["scaler"].scale_
)
if original_scale:
channel_contributions *= self.get_target_transformer()["scaler"].scale_

else:
channel_contributions = (
samples["channel_contributions"]
.mean(dim=["sample"])
.mean(dim=["date"])
.values
)
allocate_spend = np.array(list(self.optimal_allocation_dict.values()))
allocated_spend = np.array(list(self.optimal_allocation_dict.values()))

if ax is None:
fig, ax = plt.subplots(figsize=figsize)
Expand All @@ -2274,11 +2260,11 @@ def plot_budget_allocation(

bars1 = ax.bar(
index,
allocate_spend,
allocated_spend,
bar_width,
color="b",
color="C0",
alpha=opacity,
label="Allocate Spend",
label="Allocated Spend",
)

ax2 = ax.twinx()
Expand All @@ -2287,19 +2273,19 @@ def plot_budget_allocation(
index + bar_width,
channel_contributions,
bar_width,
color="r",
color="C1",
alpha=opacity,
label="Channel Contributions",
)

ax.set_xlabel("Channels")
ax.set_ylabel("Allocate Spend", color="b")
ax.set_ylabel("Allocate Spend", color="C0")
ax.tick_params(axis="x", rotation=90)
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(self.channel_columns)

ax.set_ylabel("Allocate Spend", color="b", labelpad=10)
ax2.set_ylabel("Channel Contributions", color="r", labelpad=10)
ax.set_ylabel("Allocate Spend", color="C0", labelpad=10)
ax2.set_ylabel("Channel Contributions", color="C1", labelpad=10)

ax.grid(False)
ax2.grid(False)
Expand Down

0 comments on commit 3d21431

Please sign in to comment.