diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index 0ed73876..b03cfbae 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -21,19 +21,23 @@ """ import warnings +from collections.abc import Iterable from copy import deepcopy from inspect import signature from typing import Any -import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pymc as pm import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure from pymc.distributions.shape_utils import Dims from pytensor import tensor as pt +from pytensor.tensor.variable import TensorVariable from pymc_marketing.mmm.plot import ( + SelToString, plot_curve, plot_hdi, plot_samples, @@ -299,12 +303,10 @@ def variable_mapping(self) -> dict[str, str]: def _create_distributions( self, dims: Dims | None = None - ) -> dict[str, pt.TensorVariable]: + ) -> dict[str, TensorVariable]: dim_handler: DimHandler = create_dim_handler(dims) - def create_variable( - parameter_name: str, variable_name: str - ) -> pt.TensorVariable: + def create_variable(parameter_name: str, variable_name: str) -> TensorVariable: dist = self.function_priors[parameter_name] var = dist.create_variable(variable_name) return dim_handler(var, dist.dims) @@ -344,7 +346,12 @@ def plot_curve( subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot curve HDI and samples. Parameters @@ -357,6 +364,16 @@ def plot_curve( Keyword arguments for the plot_curve_sample function. Defaults to None. hdi_kwargs : dict, optional Keyword arguments for the plot_curve_hdi function. Defaults to None. + axes : npt.NDArray[plt.Axes], optional + The exact axes to plot on. Overrides any subplot_kwargs + same_axes : bool, optional + If the axes should be the same for all plots. Defaults to False. + colors : Iterable[str], optional + The colors to use for the plot. Defaults to None. + legend : bool, optional + If the legend should be shown. Defaults to None. + sel_to_string : SelToString, optional + The function to convert the selection to a string. Defaults to None. Returns ------- @@ -369,6 +386,11 @@ def plot_curve( subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, + axes=axes, + same_axes=same_axes, + colors=colors, + legend=legend, + sel_to_string=sel_to_string, ) def _sample_curve( @@ -424,8 +446,8 @@ def plot_curve_samples( rng: np.random.Generator | None = None, plot_kwargs: dict | None = None, subplot_kwargs: dict | None = None, - axes: npt.NDArray[plt.Axes] | None = None, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot samples from the curve. Parameters @@ -466,8 +488,8 @@ def plot_curve_hdi( hdi_kwargs: dict | None = None, plot_kwargs: dict | None = None, subplot_kwargs: dict | None = None, - axes: npt.NDArray[plt.Axes] | None = None, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot the HDI of the curve. Parameters @@ -494,9 +516,10 @@ def plot_curve_hdi( axes=axes, subplot_kwargs=subplot_kwargs, plot_kwargs=plot_kwargs, + hdi_kwargs=hdi_kwargs, ) - def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable: + def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable: """Call within a model context. Used internally of the MMM to apply the transformation to the data. diff --git a/pymc_marketing/mmm/fourier.py b/pymc_marketing/mmm/fourier.py index 23bac226..ba806060 100644 --- a/pymc_marketing/mmm/fourier.py +++ b/pymc_marketing/mmm/fourier.py @@ -205,7 +205,7 @@ """ -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any import arviz as az @@ -219,7 +219,7 @@ from typing_extensions import Self from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR -from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples +from pymc_marketing.mmm.plot import SelToString, plot_curve, plot_hdi, plot_samples from pymc_marketing.prior import Prior, create_dim_handler X_NAME: str = "day" @@ -465,6 +465,11 @@ def plot_curve( subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, + axes: npt.NDArray[plt.Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: """Plot the seasonality for one full period. @@ -478,6 +483,16 @@ def plot_curve( Keyword arguments for the plot_full_period_samples method, by default None hdi_kwargs : dict, optional Keyword arguments for the plot_full_period_hdi method, by default None + axes : npt.NDArray[plt.Axes], optional + Matplotlib axes, by default None + same_axes : bool, optional + Use the same axes for all plots, by default False + colors : Iterable[str], optional + Colors for the different plots, by default None + legend : bool, optional + Show the legend, by default None + sel_to_string : SelToString, optional + Function to convert the selection to a string, by default None Returns ------- @@ -491,6 +506,11 @@ def plot_curve( subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, + axes=axes, + same_axes=same_axes, + colors=colors, + legend=legend, + sel_to_string=sel_to_string, ) def plot_curve_hdi( @@ -596,9 +616,9 @@ class YearlyFourier(FourierBase): dist = Prior("Laplace", mu=mu, b=b, dims="fourier") yearly = YearlyFourier(n_order=2, prior=dist) prior = yearly.sample_prior(random_seed=rng) - curve = yearly.sample_full_period(prior) + curve = yearly.sample_curve(prior) - _, axes = yearly.plot_full_period(curve) + _, axes = yearly.plot_curve(curve) axes[0].set(title="Yearly Fourier Seasonality") plt.show() @@ -643,9 +663,9 @@ class MonthlyFourier(FourierBase): dist = Prior("Laplace", mu=mu, b=b, dims="fourier") yearly = MonthlyFourier(n_order=2, prior=dist) prior = yearly.sample_prior(samples=100) - curve = yearly.sample_full_period(prior) + curve = yearly.sample_curve(prior) - _, axes = yearly.plot_full_period(curve) + _, axes = yearly.plot_curve(curve) axes[0].set(title="Monthly Fourier Seasonality") plt.show() diff --git a/pymc_marketing/mmm/linear_trend.py b/pymc_marketing/mmm/linear_trend.py index e1aac9ea..d9321d02 100644 --- a/pymc_marketing/mmm/linear_trend.py +++ b/pymc_marketing/mmm/linear_trend.py @@ -52,19 +52,22 @@ """ +from collections.abc import Iterable from typing import Any, cast -import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pymc as pm import pytensor.tensor as pt import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure from pydantic import BaseModel, Field, InstanceOf, model_validator from pymc.distributions.shape_utils import Dims +from pytensor.tensor.variable import TensorVariable from typing_extensions import Self -from pymc_marketing.mmm.plot import plot_curve +from pymc_marketing.mmm.plot import SelToString, plot_curve from pymc_marketing.prior import Prior, create_dim_handler @@ -278,7 +281,7 @@ def default_priors(self) -> dict[str, Prior]: return priors - def apply(self, t: pt.TensorLike) -> pt.TensorVariable: + def apply(self, t: pt.TensorLike) -> TensorVariable: """Create the linear trend for the given x values. Parameters @@ -409,7 +412,12 @@ def plot_curve( sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, include_changepoints: bool = True, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot the curve samples from the trend. Parameters @@ -424,6 +432,16 @@ def plot_curve( Keyword arguments for the HDI, by default None. include_changepoints : bool, optional Include the change points in the plot, by default True. + axes : npt.NDArray[plt.Axes], optional + Axes to plot the curve, by default None. + same_axes : bool, optional + Use the same axes for the samples, by default False. + colors : Iterable[str], optional + Colors for the samples, by default None. + legend : bool, optional + Include a legend in the plot, by default None. + sel_to_string : SelToString, optional + Function to convert the selection to a string, by default None. Returns ------- @@ -437,6 +455,11 @@ def plot_curve( subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, + axes=axes, + same_axes=same_axes, + colors=colors, + legend=legend, + sel_to_string=sel_to_string, ) if not include_changepoints: diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 84f4eaef..6efdba7d 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -11,18 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Plotting functions for the MMM.""" +"""Plot distributions stored in xarray.DataArray across coordinates. + +Used to plot the prior and posterior of the various MMM components. + +See the :func:`plot_curve` function for more information. + +""" import warnings -from collections.abc import Generator, MutableMapping, Sequence -from itertools import product -from typing import Any +from collections.abc import Callable, Generator, Iterable, MutableMapping, Sequence +from itertools import product, repeat +from typing import Any, Concatenate, ParamSpec, cast import arviz as az import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pandas as pd import xarray as xr +from matplotlib.axes import Axes +from matplotlib.lines import Line2D +from matplotlib.patches import Patch from pymc_marketing.mmm.utils import drop_scalar_coords @@ -73,6 +83,48 @@ def get_total_coord_size(coords: Coords) -> int: return total_size +def create_legend_handles( + colors: Iterable[str], + alpha: float = 0.5, + line: bool = True, + patch: bool = True, +) -> list[Line2D | Patch | tuple[Line2D, Patch]]: + """Create the legend handles for the given colors. + + Parameters + ---------- + colors : Iterable[str] + The colors to create the legend handles. + alpha : float, optional + The alpha value for the patches, by default 0.5. + line : bool, optional + Whether to include the line, by default True. + patch : bool, optional + Whether to include the patch, by default True. + + Returns + ------- + list[Line2D | Patch | tuple[Line2D, Patch]] + The legend handles. + + """ + if not line and not patch: + raise ValueError("At least one of line or patch must be True") + + def create_handle( + color: str, alpha: float + ) -> Line2D | Patch | tuple[Line2D, Patch]: + if line and patch: + return Line2D([0], [0], color=color), Patch(color=color, alpha=alpha) + + if line: + return Line2D([0], [0], color=color) + + return Patch(color=color, alpha=alpha) + + return [create_handle(color, alpha) for color in colors] + + def set_subplot_kwargs_defaults( subplot_kwargs: MutableMapping[str, Any], total_size: int, @@ -104,9 +156,12 @@ def set_subplot_kwargs_defaults( subplot_kwargs["ncols"] = total_size // subplot_kwargs["nrows"] +Selection = dict[str, Any] + + def selections( coords: Coords, -) -> Generator[dict[str, Any], None, None]: +) -> Generator[Selection, None, None]: """Create generator of selections. Parameters @@ -125,82 +180,69 @@ def selections( yield {name: value for name, value in zip(coord_names, values, strict=True)} -def plot_hdi( - curve: xr.DataArray, - non_grid_names: set[str], - hdi_kwargs: dict | None = None, - subplot_kwargs: dict[str, Any] | None = None, - plot_kwargs: dict[str, Any] | None = None, - axes: npt.NDArray[plt.Axes] | None = None, -) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: - """Plot hdi of the curve across coords. +P = ParamSpec("P") +GetPlotData = Callable[[xr.DataArray], xr.DataArray] +MakeSelection = Callable[[xr.DataArray, Selection], pd.DataFrame] +PlotSelection = Callable[Concatenate[pd.DataFrame, Axes, str, P], Axes] - Parameters - ---------- - curve : xr.DataArray - Curve to plot - non_grid_names : set[str] - The names to exclude from the grid. chain and draw are - excluded automatically - n : int, optional - Number of samples to plot - rng : np.random.Generator, optional - Random number generator - axes : npt.NDArray[plt.Axes], optional - Axes to plot on - subplot_kwargs : dict, optional - Additional kwargs to while creating the fig and axes - plot_kwargs : dict, optional - Kwargs for the plot function - Returns - ------- - tuple[plt.Figure, npt.NDArray[plt.Axes]] - Figure and the axes +def _get_sample_plot_data(data): + return data - """ - curve = drop_scalar_coords(curve) - hdi_kwargs = hdi_kwargs or {} - conf = az.hdi(curve, **hdi_kwargs)[curve.name] - - plot_coords = get_plot_coords( - conf.coords, - non_grid_names=non_grid_names.union({"hdi"}), +def _create_make_sample_selection( + rng, + n: int, + n_chains: int, + n_draws: int, +) -> MakeSelection: + rng = rng or np.random.default_rng() + idx = random_samples( + rng, + n=n, + n_chains=n_chains, + n_draws=n_draws, ) - total_size = get_total_coord_size(plot_coords) - if axes is None: - subplot_kwargs = subplot_kwargs or {} - subplot_kwargs = {**{"sharey": True, "sharex": True}, **subplot_kwargs} - set_subplot_kwargs_defaults(subplot_kwargs, total_size) - fig, axes = plt.subplots(**subplot_kwargs) - else: - fig = plt.gcf() + def make_sample_selection(data, sel): + return data.sel(sel).to_series().unstack().loc[idx, :].T - plot_kwargs = plot_kwargs or {} - plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs} + return make_sample_selection - for i, (ax, sel) in enumerate( - zip(np.ravel(axes), selections(plot_coords), strict=False) - ): - color = f"C{i}" - df_conf = conf.sel(sel).to_series().unstack() - ax.fill_between( - x=df_conf.index, - y1=df_conf["lower"], - y2=df_conf["higher"], - color=color, - **plot_kwargs, - ) - title = ", ".join(f"{name}={value}" for name, value in sel.items()) - ax.set_title(title) +def _plot_sample_selection(df, ax: Axes, color: str, **plot_kwargs) -> Axes: + return df.plot(ax=ax, color=color, **plot_kwargs) - if not isinstance(axes, np.ndarray): - axes = np.array([axes]) - return fig, axes +def _create_get_hdi_plot_data(hdi_kwargs) -> GetPlotData: + def get_plot_data(data: xr.DataArray) -> xr.DataArray: + hdi: xr.Dataset = az.hdi(data, **hdi_kwargs) + return hdi[data.name] + + return get_plot_data + + +def _make_hdi_selection(data: xr.DataArray, sel: dict[str, Any]) -> pd.DataFrame: + return data.sel(sel).to_series().unstack() + + +def _plot_hdi_selection( + df: pd.DataFrame, + ax: Axes, + color: str, + **plot_kwargs, +) -> Axes: + ax.fill_between( + x=df.index, + y1=df["lower"], + y2=df["higher"], + color=color, + **plot_kwargs, + ) + return ax + + +SelToString = Callable[[Selection], str] def random_samples( @@ -235,15 +277,211 @@ def random_samples( ] +def generate_colors(n: int, start: int = 0) -> list[str]: + """Generate list of colors. + + Parameters + ---------- + n : int + Number of colors to generate + start : int, optional + Starting index, by default 0 + + Returns + ------- + list[str] + List of colors + + Examples + -------- + Generate 5 colors starting from index 1 + + .. code-block:: python + + colors = generate_colors(5, start=1) + print(colors) + # ['C1', 'C2', 'C3', 'C4', 'C5'] + + """ + return [f"C{i}" for i in range(start, start + n)] + + +def _plot_across_coord( + curve: xr.DataArray, + non_grid_names: set[str], + get_plot_data: GetPlotData, + make_selection: MakeSelection, + plot_selection: PlotSelection, + subplot_kwargs: dict | None = None, + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool = False, + plot_kwargs: dict[str, Any] | None = None, + patch: bool = True, + line: bool = True, + sel_to_string: SelToString | None = None, +) -> tuple[plt.Figure, npt.NDArray[Axes]]: + """Plot data array across coords. + + Commonality used for the `plot_samples` and `plot_hdi` functions. + Differences depending on the `get_plot_data`, `make_selection` and + `plot_selection` functions passed. + + Allows for plotting each coordinate combination on a separate axis + or on the same axis. + + """ + if sel_to_string is None: + + def sel_to_string(sel): + return ", ".join(f"{key}={value}" for key, value in sel.items()) + + curve = drop_scalar_coords(curve) + + data = get_plot_data(curve) + + plot_coords = get_plot_coords( + data.coords, + non_grid_names=non_grid_names.union({"chain", "draw", "hdi"}), + ) + total_size = get_total_coord_size(plot_coords) + + if axes is None and not same_axes: + subplot_kwargs = subplot_kwargs or {} + subplot_kwargs = {**{"sharey": True, "sharex": True}, **subplot_kwargs} + set_subplot_kwargs_defaults(subplot_kwargs, total_size) + fig, axes = plt.subplots(**subplot_kwargs) + axes_iter = np.ravel(axes) + return_axes = axes + + create_title = sel_to_string + + create_legend_label = None + elif axes is not None and same_axes: + fig = plt.gcf() + axes_iter = repeat(axes[0], total_size) # type: ignore + return_axes = np.array([axes]) if not isinstance(axes, np.ndarray) else axes + + def create_title(sel): + return "" + + create_legend_label = sel_to_string + + elif axes is None and same_axes: + fig, ax = plt.subplots(ncols=1, nrows=1) + axes_iter = repeat(ax, total_size) # type: ignore + return_axes = np.array([ax]) + + def create_title(sel): + return "" + + create_legend_label = sel_to_string + else: + fig = plt.gcf() + axes_iter = np.ravel(axes) # type: ignore + return_axes = np.array([axes]) if not isinstance(axes, np.ndarray) else axes + + create_title = sel_to_string # type: ignore + + create_legend_label = None + + colors = cast(Iterable[str], colors or generate_colors(n=total_size, start=0)) + + for color, ax, sel in zip(colors, axes_iter, selections(plot_coords), strict=False): + ax = data.pipe(make_selection, sel=sel).pipe( + plot_selection, + ax=ax, + color=color, + **plot_kwargs, + ) + title = create_title(sel) + ax.set_title(title) + + if same_axes and legend and create_legend_label is not None: + handles = create_legend_handles(colors, patch=patch, line=line) + labels = [create_legend_label(sel) for sel in selections(plot_coords)] + ax.legend(handles=handles, labels=labels) + + return fig, return_axes + + +def plot_hdi( + curve: xr.DataArray, + non_grid_names: set[str], + hdi_kwargs: dict | None = None, + subplot_kwargs: dict[str, Any] | None = None, + plot_kwargs: dict[str, Any] | None = None, + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool = False, + sel_to_string: SelToString | None = None, +) -> tuple[plt.Figure, npt.NDArray[Axes]]: + """Plot hdi of the curve across coords. + + Parameters + ---------- + curve : xr.DataArray + Curve to plot + non_grid_names : set[str] + The names to exclude from the grid. chain and draw are + excluded automatically + n : int, optional + Number of samples to plot + rng : np.random.Generator, optional + Random number generator + axes : npt.NDArray[plt.Axes], optional + Axes to plot on + subplot_kwargs : dict, optional + Additional kwargs to while creating the fig and axes + plot_kwargs : dict, optional + Kwargs for the plot function + + Returns + ------- + tuple[plt.Figure, npt.NDArray[plt.Axes]] + Figure and the axes + + """ + get_plot_data = _create_get_hdi_plot_data(hdi_kwargs or {}) + make_selection = _make_hdi_selection + plot_selection = _plot_hdi_selection + + plot_kwargs = plot_kwargs or {} + plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs} + + return _plot_across_coord( + curve=curve, + non_grid_names=non_grid_names, + get_plot_data=get_plot_data, + make_selection=make_selection, + plot_selection=plot_selection, + subplot_kwargs=subplot_kwargs, + same_axes=same_axes, + axes=axes, + colors=colors, + legend=legend, + plot_kwargs=plot_kwargs, + patch=True, + line=False, + sel_to_string=sel_to_string, + ) + + def plot_samples( curve: xr.DataArray, non_grid_names: set[str], n: int = 10, rng: np.random.Generator | None = None, - axes: npt.NDArray[plt.Axes] | None = None, + axes: npt.NDArray[Axes] | None = None, subplot_kwargs: dict[str, Any] | None = None, plot_kwargs: dict[str, Any] | None = None, -) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool = False, + sel_to_string: SelToString | None = None, +) -> tuple[plt.Figure, npt.NDArray[Axes]]: """Plot n samples of the curve across coords. Parameters @@ -263,6 +501,8 @@ def plot_samples( Additional kwargs to while creating the fig and axes plot_kwargs : dict, optional Kwargs for the plot function + same_axes : bool + All of the plots in the same axis Returns ------- @@ -270,21 +510,17 @@ def plot_samples( Figure and the axes """ - curve = drop_scalar_coords(curve) - - plot_coords = get_plot_coords( - curve.coords, - non_grid_names=non_grid_names.union({"chain", "draw"}), + get_plot_data = _get_sample_plot_data + + n_chains = curve.sizes["chain"] + n_draws = curve.sizes["draw"] + make_selection = _create_make_sample_selection( + rng=rng, + n=n, + n_chains=n_chains, + n_draws=n_draws, ) - total_size = get_total_coord_size(plot_coords) - - if axes is None: - subplot_kwargs = subplot_kwargs or {} - subplot_kwargs = {**{"sharey": True, "sharex": True}, **subplot_kwargs} - set_subplot_kwargs_defaults(subplot_kwargs, total_size) - fig, axes = plt.subplots(**subplot_kwargs) - else: - fig = plt.gcf() + plot_selection = _plot_sample_selection plot_kwargs = plot_kwargs or {} plot_kwargs = { @@ -292,28 +528,23 @@ def plot_samples( **plot_kwargs, } - rng = rng or np.random.default_rng() - idx = random_samples( - rng, n=n, n_chains=curve.sizes["chain"], n_draws=curve.sizes["draw"] + return _plot_across_coord( + curve=curve, + non_grid_names=non_grid_names, + get_plot_data=get_plot_data, + make_selection=make_selection, + plot_selection=plot_selection, + subplot_kwargs=subplot_kwargs, + plot_kwargs=plot_kwargs, + same_axes=same_axes, + axes=axes, + colors=colors, + legend=legend, + patch=False, + line=True, + sel_to_string=sel_to_string, ) - for i, (ax, sel) in enumerate( - zip(np.ravel(axes), selections(plot_coords), strict=False) - ): - color = f"C{i}" - - df_curve = curve.sel(sel).to_series().unstack() - df_sample = df_curve.loc[idx, :] - - df_sample.T.plot(ax=ax, color=color, **plot_kwargs) - title = ", ".join(f"{name}={value}" for name, value in sel.items()) - ax.set_title(title) - - if not isinstance(axes, np.ndarray): - axes = np.array([axes]) - - return fig, axes - def plot_curve( curve: xr.DataArray, @@ -321,7 +552,12 @@ def plot_curve( subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, -) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, +) -> tuple[plt.Figure, npt.NDArray[Axes]]: """Plot HDI with samples of the curve across coords. Parameters @@ -337,12 +573,83 @@ def plot_curve( Kwargs for the :func:`plot_curve` function hdi_kwargs : dict, optional Kwargs for the :func:`plot_hdi` function + same_axes : bool + If all of the plots are on the same axis + colors : Iterable[str], optional + Colors for the plots + legend : bool, optional + If to include a legend. Defaults to True if same_axes + sel_to_string : Callable[[Selection], str], optional + Function to convert selection to a string. Defaults to + ", ".join(f"{key}={value}" for key, value in sel.items()) Returns ------- tuple[plt.Figure, npt.NDArray[plt.Axes]] Figure and the axes + Examples + -------- + Plot prior for arbitrary Deterministic in PyMC model + + .. plot:: + :include-source: True + :context: reset + + import numpy as np + import pandas as pd + + import pymc as pm + + import matplotlib.pyplot as plt + + from pymc_marketing.mmm.plot import plot_curve + + seed = sum(map(ord, "Arbitrary curve")) + rng = np.random.default_rng(seed) + + dates = pd.date_range("2024-01-01", periods=52, freq="W") + + coords = {"date": dates, "product": ["A", "B"]} + with pm.Model(coords=coords) as model: + data = pm.Normal( + "data", + mu=[-0.5, 0.5], + sigma=1, + dims=("date", "product"), + ) + cumsum = pm.Deterministic( + "cumsum", + data.cumsum(axis=0), + dims=("date", "product"), + ) + idata = pm.sample_prior_predictive(random_seed=rng) + + curve = idata.prior["cumsum"] + + fig, axes = plot_curve( + curve, + non_grid_names={"date"}, + subplot_kwargs={"figsize": (15, 5)}, + ) + plt.show() + + Plot same curve on same axes with custom colors + + .. plot:: + :include-source: True + :context: close-figs + + colors = ["red", "blue"] + fig, axes = plot_curve( + curve, + non_grid_names={"date"}, + same_axes=True, + colors=colors, + ) + axes[0].set(title="Same data but on same axes and custom colors") + plt.show() + """ curve = drop_scalar_coords(curve) @@ -352,6 +659,23 @@ def plot_curve( if "subplot_kwargs" not in sample_kwargs: sample_kwargs["subplot_kwargs"] = subplot_kwargs + if "axes" not in sample_kwargs: + sample_kwargs["axes"] = axes + + if same_axes: + sample_kwargs["same_axes"] = True + sample_kwargs["legend"] = False + hdi_kwargs["same_axes"] = True + hdi_kwargs["legend"] = legend if isinstance(legend, bool) else True + + if colors is not None: + sample_kwargs["colors"] = colors + hdi_kwargs["colors"] = colors + + if sel_to_string is not None: + sample_kwargs["sel_to_string"] = sel_to_string + hdi_kwargs["sel_to_string"] = sel_to_string + fig, axes = plot_samples( curve, non_grid_names=non_grid_names, diff --git a/tests/mmm/test_plot.py b/tests/mmm/test_plot.py index 4fa24051..ed942bbd 100644 --- a/tests/mmm/test_plot.py +++ b/tests/mmm/test_plot.py @@ -18,6 +18,7 @@ import xarray as xr from pymc_marketing.mmm.plot import ( + plot_curve, plot_hdi, plot_samples, random_samples, @@ -77,7 +78,7 @@ def test_random_samples(sample_frame) -> None: assert len(df_sub) == n -@pytest.fixture +@pytest.fixture(scope="module") def mock_curve() -> xr.DataArray: coords = { "chain": np.arange(1), @@ -91,15 +92,77 @@ def mock_curve() -> xr.DataArray: ) -def test_plot_samples(mock_curve) -> None: - fig, axes = plot_samples(mock_curve, non_grid_names={"chain", "draw", "day"}) +@pytest.mark.parametrize("plot_func", [plot_samples, plot_hdi]) +@pytest.mark.parametrize( + "same_axes", [True, False], ids=["same_axes", "different_axes"] +) +@pytest.mark.parametrize("legend", [True, False], ids=["legend", "no_legend"]) +def test_plot_functions(mock_curve, plot_func, same_axes: bool, legend: bool) -> None: + fig, axes = plot_func( + mock_curve, + non_grid_names={"day"}, + same_axes=same_axes, + legend=legend, + ) - assert axes.size == 5 + assert axes.size == (1 if same_axes else mock_curve.sizes["geo"]) assert isinstance(fig, plt.Figure) + plt.close(fig) -def test_plot_hdi(mock_curve) -> None: - fig, axes = plot_hdi(mock_curve, non_grid_names={"day"}) +def test_plot_curve(mock_curve) -> None: + fig, axes = plot_curve(mock_curve, non_grid_names={"day"}) - assert axes.size == 5 + assert axes.size == mock_curve.sizes["geo"] assert isinstance(fig, plt.Figure) + plt.close(fig) + + +def test_plot_curve_supply_axes_same_axes(mock_curve) -> None: + _, ax = plt.subplots() + axes = np.array([ax]) + + fig, modified_axes = plot_curve( + mock_curve, + non_grid_names={"day"}, + axes=axes, + same_axes=True, + ) + + np.testing.assert_equal(axes, modified_axes) + plt.close(fig) + + +def test_plot_curve_custom_colors(mock_curve) -> None: + colors = ["red", "blue", "green", "yellow", "purple"] + + fig, axes = plot_curve(mock_curve, non_grid_names={"day"}, colors=colors) + + for ax, color in zip(axes, colors, strict=True): + for line in ax.get_lines(): + assert line.get_color() == color + + plt.close(fig) + + +def test_plot_curve_custom_sel_to_string(mock_curve) -> None: + def custom_sel_to_string(sel): + return ", ".join(f"{key}: {value}" for key, value in sel.items()) + + fig, axes = plot_curve( + mock_curve, + non_grid_names={"day"}, + sel_to_string=custom_sel_to_string, + ) + + titles = [ax.get_title() for ax in axes] + + assert titles == [ + "geo: 0", + "geo: 1", + "geo: 2", + "geo: 3", + "geo: 4", + ] + + plt.close(fig)