Skip to content

Commit

Permalink
refactor: Move plot_prior_vs_posterior into MMM module
Browse files Browse the repository at this point in the history
  • Loading branch information
louismagowan committed Oct 17, 2024
1 parent 80cb3a6 commit 651dacb
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions pymc_marketing/mmm/mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1056,6 +1056,148 @@ def plot_channel_parameter(self, param_name: str, **plt_kwargs: Any) -> plt.Figu
)
return fig

def plot_prior_vs_posterior(
    self,
    var_name: str,
    alphabetical_sort: bool = True,
    figsize: tuple[int, int] | None = None,
) -> plt.Figure:
    """Plot prior vs posterior distributions of a variable, one panel per channel.

    For every channel in ``self.channel_columns`` this draws a KDE of the
    samples of ``var_name`` taken from ``self.prior`` and from
    ``self.fit_result`` on a three-column grid, and marks the mean of each
    distribution with a dashed vertical line.

    Panels are ordered either alphabetically by channel name or by the
    difference between the posterior and prior means, with the largest
    difference (posterior - prior) first.

    Parameters
    ----------
    var_name : str
        The variable to analyze (e.g., 'adstock_alpha'). It is looked up in
        both ``self.prior`` and ``self.fit_result`` and selected along the
        ``channel`` coordinate.
    alphabetical_sort : bool, optional
        Whether to sort the channels alphabetically (True) or by the
        difference between the posterior and prior means (False).
        Default is True.
    figsize : tuple of int, optional
        Figure size in inches. If None, it is set to ``(25, 5 * num_rows)``,
        where ``num_rows`` is the number of grid rows needed for the channels.

    Returns
    -------
    fig : plt.Figure
        The matplotlib figure object.

    Raises
    ------
    ValueError
        If the required attributes (fit_result, prior) were not found.
    ValueError
        If var_name is not a string.
    ValueError
        If ``self.channel_columns`` is empty, so there is nothing to plot.

    Notes
    -----
    The presence of ``var_name`` in ``self.prior`` / ``self.fit_result`` is
    not validated here; any error raised by that lookup propagates unchanged.
    """
    if not hasattr(self, "fit_result") or not hasattr(self, "prior"):
        raise ValueError(
            "Required attributes (fit_result, prior) not found. "
            "Ensure you've called model.fit() and model.sample_prior_predictive()"
        )

    if not isinstance(var_name, str):
        raise ValueError(
            "var_name must be a string. Please provide a single variable name."
        )

    channels = list(self.channel_columns)
    num_channels = len(channels)
    if not channels:
        # Fail early with a clear message: a zero-row grid cannot be created
        # and there would be no axes to draw on.
        raise ValueError("No channels to plot: channel_columns is empty.")

    # Grid geometry: fixed three columns, as many rows as needed (ceil division).
    num_cols = 3
    num_rows = (num_channels + num_cols - 1) // num_cols

    if figsize is None:
        figsize = (25, 5 * num_rows)

    # Look the variable up once instead of once per channel and per use.
    prior_var = self.prior[var_name]
    posterior_var = self.fit_result[var_name]

    # Collect per-channel samples and means in a single pass so the selection
    # is not repeated when plotting.
    channel_stats = []
    for channel in channels:
        prior_sel = prior_var.sel(channel=channel)
        posterior_sel = posterior_var.sel(channel=channel)
        prior_mean = prior_sel.mean().values
        posterior_mean = posterior_sel.mean().values
        channel_stats.append(
            (
                channel,
                prior_sel,
                posterior_sel,
                prior_mean,
                posterior_mean,
                posterior_mean - prior_mean,
            )
        )

    # Choose how to order the panels.
    if alphabetical_sort:
        ordered_stats = sorted(channel_stats, key=lambda item: item[0])
    else:
        # Largest (posterior - prior) mean difference first.
        ordered_stats = sorted(channel_stats, key=lambda item: item[5], reverse=True)

    fig, axs = plt.subplots(num_rows, num_cols, figsize=figsize)
    axs = axs.flatten()  # Flatten the grid for simple positional access

    for ax, stats in zip(axs, ordered_stats):
        channel, prior_sel, posterior_sel, prior_mean, posterior_mean, difference = (
            stats
        )

        # KDE of the prior samples for this channel
        sns.kdeplot(
            prior_sel.values.flatten(),
            ax=ax,
            label="Prior Predictive",
            color="blue",
            fill=True,
        )

        # Vertical line at the mean of the prior distribution
        ax.axvline(
            prior_mean,
            color="blue",
            linestyle="--",
            linewidth=2,
            label=f"Prior Mean: {prior_mean:.2f}",
        )

        # KDE of the posterior samples for this channel
        sns.kdeplot(
            posterior_sel.values.flatten(),
            ax=ax,
            label="Posterior Predictive",
            color="red",
            fill=True,
            alpha=0.15,
        )

        # Vertical line at the mean of the posterior distribution
        ax.axvline(
            posterior_mean,
            color="red",
            linestyle="--",
            linewidth=2,
            label=f"Posterior Mean: {posterior_mean:.2f} (Diff: {difference:.2f})",
        )

        # Titles and labels
        ax.set_title(channel)  # Subplot title is just the channel name
        ax.set_xlabel(var_name.capitalize())
        ax.set_ylabel("Density")
        ax.legend(loc="upper right")

    # Overall figure title
    fig.suptitle(f"Prior vs Posterior Distributions | {var_name}", fontsize=16)

    # Remove the grid cells that did not receive a channel.
    for unused_ax in axs[num_channels:]:
        fig.delaxes(unused_ax)

    # Lay out this specific figure (not whichever one pyplot considers
    # current), leaving room at the top for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.97])

    return fig

Check warning on line 1199 in pymc_marketing/mmm/mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/mmm.py#L1199

Added line #L1199 was not covered by tests

def get_ts_contribution_posterior(
self, var_contribution: str, original_scale: bool = False
) -> DataArray:
Expand Down

0 comments on commit 651dacb

Please sign in to comment.