diff --git a/pymc_marketing/mmm/plot.py b/pymc_marketing/mmm/plot.py index 721cadf4..f511490d 100644 --- a/pymc_marketing/mmm/plot.py +++ b/pymc_marketing/mmm/plot.py @@ -243,69 +243,6 @@ def _plot_hdi_selection( SelToString = Callable[[Selection], str] -def plot_hdi( - curve: xr.DataArray, - non_grid_names: set[str], - hdi_kwargs: dict | None = None, - subplot_kwargs: dict[str, Any] | None = None, - plot_kwargs: dict[str, Any] | None = None, - axes: npt.NDArray[Axes] | None = None, - same_axes: bool = False, - colors: Iterable[str] | None = None, - legend: bool = False, - sel_to_string: SelToString | None = None, -) -> tuple[plt.Figure, npt.NDArray[Axes]]: - """Plot hdi of the curve across coords. - - Parameters - ---------- - curve : xr.DataArray - Curve to plot - non_grid_names : set[str] - The names to exclude from the grid. chain and draw are - excluded automatically - n : int, optional - Number of samples to plot - rng : np.random.Generator, optional - Random number generator - axes : npt.NDArray[plt.Axes], optional - Axes to plot on - subplot_kwargs : dict, optional - Additional kwargs to while creating the fig and axes - plot_kwargs : dict, optional - Kwargs for the plot function - - Returns - ------- - tuple[plt.Figure, npt.NDArray[plt.Axes]] - Figure and the axes - - """ - get_plot_data = _create_get_hdi_plot_data(hdi_kwargs or {}) - make_selection = _make_hdi_selection - plot_selection = _plot_hdi_selection - - plot_kwargs = plot_kwargs or {} - plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs} - - return _plot_across_coord( - curve=curve, - non_grid_names=non_grid_names, - get_plot_data=get_plot_data, - make_selection=make_selection, - plot_selection=plot_selection, - subplot_kwargs=subplot_kwargs, - same_axes=same_axes, - axes=axes, - colors=colors, - legend=legend, - plot_kwargs=plot_kwargs, - patch=True, - line=False, - sel_to_string=sel_to_string, - ) - - def random_samples( rng: np.random.Generator, n: int, @@ -467,6 +404,69 @@ def create_title(sel): return fig, return_axes +def plot_hdi( + curve: xr.DataArray, + non_grid_names: set[str], + hdi_kwargs: dict | None = None, + subplot_kwargs: dict[str, Any] | None = None, + plot_kwargs: dict[str, Any] | None = None, + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool = False, + sel_to_string: SelToString | None = None, +) -> tuple[plt.Figure, npt.NDArray[Axes]]: + """Plot hdi of the curve across coords. + + Parameters + ---------- + curve : xr.DataArray + Curve to plot + non_grid_names : set[str] + The names to exclude from the grid. chain and draw are + excluded automatically + n : int, optional + Number of samples to plot + rng : np.random.Generator, optional + Random number generator + axes : npt.NDArray[plt.Axes], optional + Axes to plot on + subplot_kwargs : dict, optional + Additional kwargs to while creating the fig and axes + plot_kwargs : dict, optional + Kwargs for the plot function + + Returns + ------- + tuple[plt.Figure, npt.NDArray[plt.Axes]] + Figure and the axes + + """ + get_plot_data = _create_get_hdi_plot_data(hdi_kwargs or {}) + make_selection = _make_hdi_selection + plot_selection = _plot_hdi_selection + + plot_kwargs = plot_kwargs or {} + plot_kwargs = {**{"alpha": 0.25}, **plot_kwargs} + + return _plot_across_coord( + curve=curve, + non_grid_names=non_grid_names, + get_plot_data=get_plot_data, + make_selection=make_selection, + plot_selection=plot_selection, + subplot_kwargs=subplot_kwargs, + same_axes=same_axes, + axes=axes, + colors=colors, + legend=legend, + plot_kwargs=plot_kwargs, + patch=True, + line=False, + sel_to_string=sel_to_string, + ) + + def plot_samples( curve: xr.DataArray, non_grid_names: set[str],