Skip to content

Commit

Permalink
Add user customization to plot_curve methods (#1018)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Sep 11, 2024
1 parent 40dee1d commit 9362709
Show file tree
Hide file tree
Showing 5 changed files with 588 additions and 135 deletions.
45 changes: 34 additions & 11 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,23 @@
"""

import warnings
from collections.abc import Iterable
from copy import deepcopy
from inspect import signature
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pymc as pm
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pymc.distributions.shape_utils import Dims
from pytensor import tensor as pt
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.mmm.plot import (
SelToString,
plot_curve,
plot_hdi,
plot_samples,
Expand Down Expand Up @@ -299,12 +303,10 @@ def variable_mapping(self) -> dict[str, str]:

def _create_distributions(
self, dims: Dims | None = None
) -> dict[str, pt.TensorVariable]:
) -> dict[str, TensorVariable]:
dim_handler: DimHandler = create_dim_handler(dims)

def create_variable(
parameter_name: str, variable_name: str
) -> pt.TensorVariable:
def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
dist = self.function_priors[parameter_name]
var = dist.create_variable(variable_name)
return dim_handler(var, dist.dims)
Expand Down Expand Up @@ -344,7 +346,12 @@ def plot_curve(
subplot_kwargs: dict | None = None,
sample_kwargs: dict | None = None,
hdi_kwargs: dict | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool | None = None,
sel_to_string: SelToString | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot curve HDI and samples.
Parameters
Expand All @@ -357,6 +364,16 @@ def plot_curve(
Keyword arguments for the plot_curve_sample function. Defaults to None.
hdi_kwargs : dict, optional
Keyword arguments for the plot_curve_hdi function. Defaults to None.
axes : npt.NDArray[plt.Axes], optional
The exact axes to plot on. Overrides any subplot_kwargs
same_axes : bool, optional
If the axes should be the same for all plots. Defaults to False.
colors : Iterable[str], optional
The colors to use for the plot. Defaults to None.
legend : bool, optional
If the legend should be shown. Defaults to None.
sel_to_string : SelToString, optional
The function to convert the selection to a string. Defaults to None.
Returns
-------
Expand All @@ -369,6 +386,11 @@ def plot_curve(
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
axes=axes,
same_axes=same_axes,
colors=colors,
legend=legend,
sel_to_string=sel_to_string,
)

def _sample_curve(
Expand Down Expand Up @@ -424,8 +446,8 @@ def plot_curve_samples(
rng: np.random.Generator | None = None,
plot_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
axes: npt.NDArray[plt.Axes] | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot samples from the curve.
Parameters
Expand Down Expand Up @@ -466,8 +488,8 @@ def plot_curve_hdi(
hdi_kwargs: dict | None = None,
plot_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
axes: npt.NDArray[plt.Axes] | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot the HDI of the curve.
Parameters
Expand All @@ -494,9 +516,10 @@ def plot_curve_hdi(
axes=axes,
subplot_kwargs=subplot_kwargs,
plot_kwargs=plot_kwargs,
hdi_kwargs=hdi_kwargs,
)

def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable:
def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable:
"""Call within a model context.
Used internally of the MMM to apply the transformation to the data.
Expand Down
32 changes: 26 additions & 6 deletions pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@
"""

from collections.abc import Callable
from collections.abc import Callable, Iterable
from typing import Any

import arviz as az
Expand All @@ -219,7 +219,7 @@
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
from pymc_marketing.mmm.plot import plot_curve, plot_hdi, plot_samples
from pymc_marketing.mmm.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, create_dim_handler

X_NAME: str = "day"
Expand Down Expand Up @@ -465,6 +465,11 @@ def plot_curve(
subplot_kwargs: dict | None = None,
sample_kwargs: dict | None = None,
hdi_kwargs: dict | None = None,
axes: npt.NDArray[plt.Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool | None = None,
sel_to_string: SelToString | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
"""Plot the seasonality for one full period.
Expand All @@ -478,6 +483,16 @@ def plot_curve(
Keyword arguments for the plot_full_period_samples method, by default None
hdi_kwargs : dict, optional
Keyword arguments for the plot_full_period_hdi method, by default None
axes : npt.NDArray[plt.Axes], optional
Matplotlib axes, by default None
same_axes : bool, optional
Use the same axes for all plots, by default False
colors : Iterable[str], optional
Colors for the different plots, by default None
legend : bool, optional
Show the legend, by default None
sel_to_string : SelToString, optional
Function to convert the selection to a string, by default None
Returns
-------
Expand All @@ -491,6 +506,11 @@ def plot_curve(
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
axes=axes,
same_axes=same_axes,
colors=colors,
legend=legend,
sel_to_string=sel_to_string,
)

def plot_curve_hdi(
Expand Down Expand Up @@ -596,9 +616,9 @@ class YearlyFourier(FourierBase):
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
yearly = YearlyFourier(n_order=2, prior=dist)
prior = yearly.sample_prior(random_seed=rng)
curve = yearly.sample_full_period(prior)
curve = yearly.sample_curve(prior)
_, axes = yearly.plot_full_period(curve)
_, axes = yearly.plot_curve(curve)
axes[0].set(title="Yearly Fourier Seasonality")
plt.show()
Expand Down Expand Up @@ -643,9 +663,9 @@ class MonthlyFourier(FourierBase):
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
yearly = MonthlyFourier(n_order=2, prior=dist)
prior = yearly.sample_prior(samples=100)
curve = yearly.sample_full_period(prior)
curve = yearly.sample_curve(prior)
_, axes = yearly.plot_full_period(curve)
_, axes = yearly.plot_curve(curve)
axes[0].set(title="Monthly Fourier Seasonality")
plt.show()
Expand Down
31 changes: 27 additions & 4 deletions pymc_marketing/mmm/linear_trend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,22 @@
"""

from collections.abc import Iterable
from typing import Any, cast

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pydantic import BaseModel, Field, InstanceOf, model_validator
from pymc.distributions.shape_utils import Dims
from pytensor.tensor.variable import TensorVariable
from typing_extensions import Self

from pymc_marketing.mmm.plot import plot_curve
from pymc_marketing.mmm.plot import SelToString, plot_curve
from pymc_marketing.prior import Prior, create_dim_handler


Expand Down Expand Up @@ -278,7 +281,7 @@ def default_priors(self) -> dict[str, Prior]:

return priors

def apply(self, t: pt.TensorLike) -> pt.TensorVariable:
def apply(self, t: pt.TensorLike) -> TensorVariable:
"""Create the linear trend for the given x values.
Parameters
Expand Down Expand Up @@ -409,7 +412,12 @@ def plot_curve(
sample_kwargs: dict | None = None,
hdi_kwargs: dict | None = None,
include_changepoints: bool = True,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool | None = None,
sel_to_string: SelToString | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot the curve samples from the trend.
Parameters
Expand All @@ -424,6 +432,16 @@ def plot_curve(
Keyword arguments for the HDI, by default None.
include_changepoints : bool, optional
Include the change points in the plot, by default True.
axes : npt.NDArray[plt.Axes], optional
Axes to plot the curve, by default None.
same_axes : bool, optional
Use the same axes for the samples, by default False.
colors : Iterable[str], optional
Colors for the samples, by default None.
legend : bool, optional
Include a legend in the plot, by default None.
sel_to_string : SelToString, optional
Function to convert the selection to a string, by default None.
Returns
-------
Expand All @@ -437,6 +455,11 @@ def plot_curve(
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
axes=axes,
same_axes=same_axes,
colors=colors,
legend=legend,
sel_to_string=sel_to_string,
)

if not include_changepoints:
Expand Down
Loading

0 comments on commit 9362709

Please sign in to comment.