diff --git a/pymc_marketing/mmm/linear_trend.py b/pymc_marketing/mmm/linear_trend.py index e1aac9ea..d9321d02 100644 --- a/pymc_marketing/mmm/linear_trend.py +++ b/pymc_marketing/mmm/linear_trend.py @@ -52,19 +52,22 @@ """ +from collections.abc import Iterable from typing import Any, cast -import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt import pymc as pm import pytensor.tensor as pt import xarray as xr +from matplotlib.axes import Axes +from matplotlib.figure import Figure from pydantic import BaseModel, Field, InstanceOf, model_validator from pymc.distributions.shape_utils import Dims +from pytensor.tensor.variable import TensorVariable from typing_extensions import Self -from pymc_marketing.mmm.plot import plot_curve +from pymc_marketing.mmm.plot import SelToString, plot_curve from pymc_marketing.prior import Prior, create_dim_handler @@ -278,7 +281,7 @@ def default_priors(self) -> dict[str, Prior]: return priors - def apply(self, t: pt.TensorLike) -> pt.TensorVariable: + def apply(self, t: pt.TensorLike) -> TensorVariable: """Create the linear trend for the given x values. Parameters @@ -409,7 +412,12 @@ def plot_curve( sample_kwargs: dict | None = None, hdi_kwargs: dict | None = None, include_changepoints: bool = True, - ) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]: + axes: npt.NDArray[Axes] | None = None, + same_axes: bool = False, + colors: Iterable[str] | None = None, + legend: bool | None = None, + sel_to_string: SelToString | None = None, + ) -> tuple[Figure, npt.NDArray[Axes]]: """Plot the curve samples from the trend. Parameters @@ -424,6 +432,16 @@ def plot_curve( Keyword arguments for the HDI, by default None. include_changepoints : bool, optional Include the change points in the plot, by default True. + axes : npt.NDArray[plt.Axes], optional + Axes to plot the curve, by default None. + same_axes : bool, optional + Use the same axes for the samples, by default False. + colors : Iterable[str], optional + Colors for the samples, by default None. + legend : bool, optional + Include a legend in the plot, by default None. + sel_to_string : SelToString, optional + Function to convert the selection to a string, by default None. Returns ------- @@ -437,6 +455,11 @@ def plot_curve( subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, + axes=axes, + same_axes=same_axes, + colors=colors, + legend=legend, + sel_to_string=sel_to_string, ) if not include_changepoints: