diff --git a/pymc_marketing/mmm/components/base.py b/pymc_marketing/mmm/components/base.py index 0ed73876..b03cfbae 100644 --- a/pymc_marketing/mmm/components/base.py +++ b/pymc_marketing/mmm/components/base.py @@ -21,19 +21,23 @@ """ import warnings +from collections.abc import Iterable from copy import deepcopy from inspect import signature from typing import Any -import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pymc as pm import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure from pymc.distributions.shape_utils import Dims from pytensor import tensor as pt +from pytensor.tensor.variable import TensorVariable from pymc_marketing.mmm.plot import ( + SelToString, plot_curve, plot_hdi, plot_samples, @@ -299,12 +303,10 @@ def variable_mapping(self) -> dict[str, str]: def _create_distributions( self, dims: Dims | None = None - ) -> dict[str, pt.TensorVariable]: + ) -> dict[str, TensorVariable]: dim_handler: DimHandler = create_dim_handler(dims) - def create_variable( - parameter_name: str, variable_name: str - ) -> pt.TensorVariable: + def create_variable(parameter_name: str, variable_name: str) -> TensorVariable: dist = self.function_priors[parameter_name] var = dist.create_variable(variable_name) return dim_handler(var, dist.dims) @@ -344,7 +346,12 @@ def plot_curve( subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot curve HDI and samples. Parameters @@ -357,6 +364,16 @@ def plot_curve( Keyword arguments for the plot_curve_sample function. Defaults to None. hdi_kwargs : dict, optional Keyword arguments for the plot_curve_hdi function. Defaults to None. + axes : npt.NDArray[plt.Axes], optional + The exact axes to plot on. Overrides any subplot_kwargs + same_axes : bool, optional + If the axes should be the same for all plots. Defaults to False. + colors : Iterable[str], optional + The colors to use for the plot. Defaults to None. + legend : bool, optional + If the legend should be shown. Defaults to None. + sel_to_string : SelToString, optional + The function to convert the selection to a string. Defaults to None. Returns ------- @@ -369,6 +386,11 @@ def plot_curve( subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, + axes=axes, + same_axes=same_axes, + colors=colors, + legend=legend, + sel_to_string=sel_to_string, ) def _sample_curve( @@ -424,8 +446,8 @@ def plot_curve_samples( rng: np.random.Generator | None = None, plot_kwargs: dict | None = None, subplot_kwargs: dict | None = None, - axes: npt.NDArray[plt.Axes] | None = None, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot samples from the curve. Parameters @@ -466,8 +488,8 @@ def plot_curve_hdi( hdi_kwargs: dict | None = None, plot_kwargs: dict | None = None, subplot_kwargs: dict | None = None, - axes: npt.NDArray[plt.Axes] | None = None, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot the HDI of the curve. Parameters @@ -494,9 +516,10 @@ def plot_curve_hdi( axes=axes, subplot_kwargs=subplot_kwargs, plot_kwargs=plot_kwargs, + hdi_kwargs=hdi_kwargs, ) - def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable: + def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable: """Call within a model context. Used internally of the MMM to apply the transformation to the data.