From 562be70220b2682029a67d8acf253404ae47afe4 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 10 Sep 2024 20:52:22 -0400 Subject: [PATCH] add parameters to fourier plot method --- pymc_marketing/mmm/fourier.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/pymc_marketing/mmm/fourier.py b/pymc_marketing/mmm/fourier.py index 23bac226..c284bac3 100644 --- a/pymc_marketing/mmm/fourier.py +++ b/pymc_marketing/mmm/fourier.py @@ -205,7 +205,7 @@ """ -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any import arviz as az @@ -219,7 +219,7 @@ from typing_extensions import Self from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR -from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples +from pymc_marketing.mmm.plot import SelToString, plot_curve, plot_hdi, plot_samples from pymc_marketing.prior import Prior, create_dim_handler X_NAME: str = "day" @@ -465,6 +465,11 @@ def plot_curve( subplot_kwargs: dict | None = None, sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, + axes: npt.NDArray[plt.Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: """Plot the seasonality for one full period. @@ -478,6 +483,16 @@ def plot_curve( Keyword arguments for the plot_full_period_samples method, by default None hdi_kwargs : dict, optional Keyword arguments for the plot_full_period_hdi method, by default None + axes : npt.NDArray[plt.Axes], optional + Matplotlib axes, by default None + same_axes : bool, optional + Use the same axes for all plots, by default False + colors : Iterable[str], optional + Colors for the different plots, by default None + legend : bool, optional + Show the legend, by default None + sel_to_string : SelToString, optional + Function to convert the selection to a string, by default None Returns ------- @@ -491,6 +506,11 @@ def plot_curve( subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, + axes=axes, + same_axes=same_axes, + colors=colors, + legend=legend, + sel_to_string=sel_to_string, ) def plot_curve_hdi(