Skip to content

Commit

Permalink
add parameters to linear_trend
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Sep 11, 2024
1 parent aac68a9 commit 28b10d6
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions pymc_marketing/mmm/linear_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,22 @@
"""

from collections.abc import Iterable
from typing import Any, cast

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pydantic import BaseModel, Field, InstanceOf, model_validator
from pymc.distributions.shape_utils import Dims
from pytensor.tensor.variable import TensorVariable
from typing_extensions import Self

from pymc_marketing.mmm.plot import plot_curve
from pymc_marketing.mmm.plot import SelToString, plot_curve
from pymc_marketing.prior import Prior, create_dim_handler


Expand Down Expand Up @@ -278,7 +281,7 @@ def default_priors(self) -> dict[str, Prior]:

return priors

def apply(self, t: pt.TensorLike) -> pt.TensorVariable:
def apply(self, t: pt.TensorLike) -> TensorVariable:
"""Create the linear trend for the given x values.
Parameters
Expand Down Expand Up @@ -409,7 +412,12 @@ def plot_curve(
sample_kwargs: dict | None = None,
hdi_kwargs: dict | None = None,
include_changepoints: bool = True,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool | None = None,
sel_to_string: SelToString | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot the curve samples from the trend.
Parameters
Expand All @@ -424,6 +432,16 @@ def plot_curve(
Keyword arguments for the HDI, by default None.
include_changepoints : bool, optional
Include the change points in the plot, by default True.
axes : npt.NDArray[plt.Axes], optional
Axes to plot the curve, by default None.
same_axes : bool, optional
Use the same axes for the samples, by default False.
colors : Iterable[str], optional
Colors for the samples, by default None.
legend : bool, optional
Include a legend in the plot, by default None.
sel_to_string : SelToString, optional
Function to convert the selection to a string, by default None.
Returns
-------
Expand All @@ -437,6 +455,11 @@ def plot_curve(
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
axes=axes,
same_axes=same_axes,
colors=colors,
legend=legend,
sel_to_string=sel_to_string,
)

if not include_changepoints:
Expand Down

0 comments on commit 28b10d6

Please sign in to comment.