Skip to content

Commit

Permalink
added plotting options to plot_posterior_predictive
Browse files Browse the repository at this point in the history
  • Loading branch information
jsnyde0 committed Sep 30, 2024
1 parent 7a2b13f commit 9c874b6
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 35 deletions.
42 changes: 37 additions & 5 deletions docs/source/notebooks/mmm/mmm_example.ipynb

Large diffs are not rendered by default.

227 changes: 197 additions & 30 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,31 +360,63 @@ def plot_prior_predictive(self, **plt_kwargs: Any) -> plt.Figure:
return fig

def plot_posterior_predictive(
self, original_scale: bool = False, ax: plt.Axes = None, **plt_kwargs: Any
self,
original_scale: bool = False,
add_hdi: bool = True,
add_mean: bool = True,
add_gradient: bool = False,
ax: plt.Axes = None,
**plt_kwargs: Any,
) -> plt.Figure:
"""Plot posterior distribution from the model fit.
"""
Plot the posterior predictive distribution from the model fit.
This function creates a visualization of the model's posterior predictive distribution,
allowing for comparison with observed data. It can include highest density intervals (HDI),
mean predictions, and a gradient representation of the full distribution.
Parameters
----------
original_scale : bool, optional
Whether to plot in the original scale.
If True, plot in the original scale of the target variable.
If False, plot in the transformed scale used for modeling. Default is False.
add_hdi : bool, optional
If True, add highest density intervals to the plot. Default is True.
add_mean : bool, optional
If True, add the mean prediction to the plot. Default is True.
add_gradient : bool, optional
If True, add a gradient representation of the full posterior distribution. Default is False.
ax : plt.Axes, optional
Matplotlib axis object.
**plt_kwargs
Keyword arguments passed to `plt.subplots`.
A matplotlib Axes object to plot on. If None, a new figure and axes will be created.
**plt_kwargs : dict
Additional keyword arguments to pass to plt.subplots() when creating a new figure.
Returns
-------
plt.Figure
The matplotlib Figure object containing the plot.
Raises
------
ValueError
If the length of the target variable doesn't match the length
of the date column in the posterior predictive data.
Notes
-----
This function visualizes the model's predictions against the observed data.
The observed data is always plotted as a black line.
Depending on the parameters, it can also show:
- HDI (Highest Density Intervals) at 94% and 50% levels
- Mean prediction line
- Gradient representation of the full posterior distribution
If predicting out-of-sample, ensure that `self.y` is overwritten with the
corresponding non-transformed target variable.
"""
try:
posterior_predictive_data: Dataset = self.posterior_predictive

except Exception as e:
raise RuntimeError(
"Make sure the model has bin fitted and the posterior predictive has been sampled!"
) from e
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
original_scale=original_scale
)

target_to_plot = np.asarray(
self.y
Expand All @@ -404,25 +436,20 @@ def plot_posterior_predictive(
else:
fig = ax.figure

if original_scale:
posterior_predictive_data = apply_sklearn_transformer_across_dim(
data=posterior_predictive_data,
func=self.get_target_transformer().inverse_transform,
dim_name="date",
)
if add_hdi:
for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
ax = self._add_hdi_to_plot(
ax=ax, original_scale=original_scale, hdi_prob=hdi_prob, alpha=alpha
)

for hdi_prob, alpha in zip((0.94, 0.50), (0.2, 0.4), strict=True):
likelihood_hdi: DataArray = az.hdi(
ary=posterior_predictive_data, hdi_prob=hdi_prob
)[self.output_var]
if add_mean:
ax = self._add_mean_to_plot(
ax=ax, original_scale=original_scale, color="blue"
)

ax.fill_between(
x=posterior_predictive_data.date,
y1=likelihood_hdi[:, 0],
y2=likelihood_hdi[:, 1],
color="C0",
alpha=alpha,
label=f"{hdi_prob:.0%} HDI",
if add_gradient:
ax = self._add_gradient_to_plot(
ax=ax, original_scale=original_scale, n_percentiles=30, palette="Blues"
)

ax.plot(
Expand All @@ -440,6 +467,146 @@ def plot_posterior_predictive(

return fig

def _get_posterior_predictive_data(self, original_scale: bool = False) -> Dataset:
"""Get the posterior predictive data."""
try:
posterior_predictive_data: Dataset = self.posterior_predictive

except Exception as e:
raise RuntimeError(
"Make sure the model has bin fitted and the posterior predictive has been sampled!"
) from e

if original_scale:
posterior_predictive_data = apply_sklearn_transformer_across_dim(
data=posterior_predictive_data,
func=self.get_target_transformer().inverse_transform,
dim_name="date",
)
return posterior_predictive_data

def _add_mean_to_plot(
self, ax, original_scale: bool = False, color="blue", linestyle="-", **kwargs
) -> plt.Axes:
"""Add mean prediction to existing plot."""
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
original_scale=original_scale
)

mean_prediction = posterior_predictive_data[self.output_var].mean(
dim=["chain", "draw"]
)

ax.plot(
np.asarray(posterior_predictive_data.date),
mean_prediction,
color=color,
linestyle=linestyle,
label="Mean Prediction",
)
return ax

def _add_hdi_to_plot(
self,
ax: plt.Axes,
original_scale: bool = False,
hdi_prob: float = 0.94,
color: str = "C0",
alpha: float = 0.2,
**kwargs,
) -> plt.Axes:
"""Add HDI to existing plot."""
posterior_predictive_data: Dataset = self._get_posterior_predictive_data(
original_scale=original_scale
)

likelihood_hdi: DataArray = az.hdi(
ary=posterior_predictive_data, hdi_prob=hdi_prob
)[self.output_var]

ax.fill_between(
x=posterior_predictive_data.date,
y1=likelihood_hdi[:, 0],
y2=likelihood_hdi[:, 1],
color=color,
alpha=alpha,
label=f"{hdi_prob:.0%} HDI",
**kwargs,
)
return ax

def _add_gradient_to_plot(
self,
ax: plt.Axes,
original_scale: bool = False,
n_percentiles: int = 30,
palette: str = "Blues",
**kwargs,
) -> plt.Axes:
"""
Add a gradient representation of the posterior predictive distribution to an existing plot.
This method creates a shaded area plot where the color intensity represents
the density of the posterior predictive distribution.
Parameters
----------
ax : plt.Axes
The matplotlib axes object to add the gradient to.
original_scale : bool, optional
If True, use the original scale of the data. Default is False.
n_percentiles : int, optional
Number of percentile ranges to use for the gradient. Default is 30.
palette : str, optional
Color palette to use for the gradient. Default is "Blues".
**kwargs
Additional keyword arguments passed to ax.fill_between().
Returns
-------
plt.Axes
The matplotlib axes object with the gradient added.
"""
# Get posterior predictive data and flatten it
posterior_predictive = self._get_posterior_predictive_data(
original_scale=original_scale
)
posterior_predictive_flattened = posterior_predictive.stack(
sample=("chain", "draw")
).to_dataarray()
dates = posterior_predictive.date.values

# Set up color map and ranges
cmap = plt.get_cmap(palette)
color_range = np.linspace(0.3, 1.0, n_percentiles // 2)
percentile_ranges = np.linspace(3, 97, n_percentiles)

# Create gradient by filling between percentile ranges
for i in range(len(percentile_ranges) - 1):
lower_percentile = np.percentile(
posterior_predictive_flattened, percentile_ranges[i], axis=2
).squeeze()
upper_percentile = np.percentile(
posterior_predictive_flattened, percentile_ranges[i + 1], axis=2
).squeeze()
if i < n_percentiles // 2:
color_val = color_range[i]
else:
color_val = color_range[n_percentiles - i - 2]
alpha_val = 0.2 + 0.8 * (
1 - abs(2 * i / n_percentiles - 1)
) # Higher alpha in the middle
ax.fill_between(
x=dates,
y1=lower_percentile,
y2=upper_percentile,
color=cmap(color_val),
alpha=alpha_val,
**kwargs,
)

return ax

def get_errors(self, original_scale: bool = False) -> DataArray:
"""Get model errors posterior distribution.
Expand Down
56 changes: 56 additions & 0 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,62 @@ class ToyMMM(BaseMMM, MaxAbsScaleTarget):
("plot_posterior_predictive", {}),
("plot_posterior_predictive", {"original_scale": True}),
("plot_posterior_predictive", {"ax": plt.subplots()[1]}),
(
"plot_posterior_predictive",
{
"add_mean": True,
"original_scale": False,
},
),
(
"plot_posterior_predictive",
{
"add_gradient": True,
"original_scale": True,
},
),
(
"plot_posterior_predictive",
{
"add_hdi": True,
"original_scale": False,
},
),
(
"plot_posterior_predictive",
{
"add_mean": True,
"add_hdi": True,
"original_scale": True,
},
),
(
"plot_posterior_predictive",
{
"add_mean": True,
"add_gradient": True,
"add_hdi": True,
"original_scale": False,
},
),
(
"plot_posterior_predictive",
{
"add_mean": True,
"add_gradient": True,
"add_hdi": True,
"original_scale": True,
},
),
(
"plot_posterior_predictive",
{
"add_mean": False,
"add_gradient": True,
"add_hdi": False,
"original_scale": False,
},
),
("plot_errors", {}),
("plot_errors", {"original_scale": True}),
("plot_errors", {"ax": plt.subplots()[1]}),
Expand Down

0 comments on commit 9c874b6

Please sign in to comment.