Skip to content

Commit

Permalink
add parameters to fourier plot method
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Sep 11, 2024
1 parent 358ff9a commit 562be70
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"""

from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import Any

import arviz as az
Expand All @@ -219,7 +219,7 @@
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples
from pymc_marketing.mmm.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, create_dim_handler

X_NAME: str = "day"
Expand Down Expand Up @@ -465,6 +465,11 @@ def plot_curve(
subplot_kwargs: dict | None = None,
sample_kwargs: dict | None = None,
hdi_kwargs: dict | None = None,
axes: npt.NDArray[plt.Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool | None = None,
sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
"""Plot the seasonality for one full period.
Expand All @@ -478,6 +483,16 @@ def plot_curve(
Keyword arguments for the plot_full_period_samples method, by default None
hdi_kwargs : dict, optional
Keyword arguments for the plot_full_period_hdi method, by default None
axes : npt.NDArray[plt.Axes], optional
Matplotlib axes, by default None
same_axes : bool, optional
Use the same axes for all plots, by default False
colors : Iterable[str], optional
Colors for the different plots, by default None
legend : bool, optional
Show the legend, by default None
sel_to_string : SelToString, optional
Function to convert the selection to a string, by default None
Returns
-------
Expand All @@ -491,6 +506,11 @@ def plot_curve(
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
axes=axes,
same_axes=same_axes,
colors=colors,
legend=legend,
sel_to_string=sel_to_string,
)

def plot_curve_hdi(
Expand Down

0 comments on commit 562be70

Please sign in to comment.