Skip to content

Commit

Permalink
update media transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Sep 11, 2024
1 parent 562be70 commit aac68a9
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,23 @@
"""

import warnings
from collections.abc import Iterable
from copy import deepcopy
from inspect import signature
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pymc as pm
import xarray as xr
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from pymc.distributions.shape_utils import Dims
from pytensor import tensor as pt
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.mmm.plot import (
SelToString,
plot_curve,
plot_hdi,
plot_samples,
Expand Down Expand Up @@ -299,12 +303,10 @@ def variable_mapping(self) -> dict[str, str]:

def _create_distributions(
self, dims: Dims | None = None
) -> dict[str, pt.TensorVariable]:
) -> dict[str, TensorVariable]:
dim_handler: DimHandler = create_dim_handler(dims)

def create_variable(
parameter_name: str, variable_name: str
) -> pt.TensorVariable:
def create_variable(parameter_name: str, variable_name: str) -> TensorVariable:
dist = self.function_priors[parameter_name]
var = dist.create_variable(variable_name)
return dim_handler(var, dist.dims)
Expand Down Expand Up @@ -344,7 +346,12 @@ def plot_curve(
subplot_kwargs: dict | None = None,
sample_kwargs: dict | None = None,
hdi_kwargs: dict | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
same_axes: bool = False,
colors: Iterable[str] | None = None,
legend: bool | None = None,
sel_to_string: SelToString | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot curve HDI and samples.
Parameters
Expand All @@ -357,6 +364,16 @@ def plot_curve(
Keyword arguments for the plot_curve_sample function. Defaults to None.
hdi_kwargs : dict, optional
Keyword arguments for the plot_curve_hdi function. Defaults to None.
axes : npt.NDArray[plt.Axes], optional
The exact axes to plot on. Overrides any subplot_kwargs
same_axes : bool, optional
If the axes should be the same for all plots. Defaults to False.
colors : Iterable[str], optional
The colors to use for the plot. Defaults to None.
legend : bool, optional
If the legend should be shown. Defaults to None.
sel_to_string : SelToString, optional
The function to convert the selection to a string. Defaults to None.
Returns
-------
Expand All @@ -369,6 +386,11 @@ def plot_curve(
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
axes=axes,
same_axes=same_axes,
colors=colors,
legend=legend,
sel_to_string=sel_to_string,
)

def _sample_curve(
Expand Down Expand Up @@ -424,8 +446,8 @@ def plot_curve_samples(
rng: np.random.Generator | None = None,
plot_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
axes: npt.NDArray[plt.Axes] | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot samples from the curve.
Parameters
Expand Down Expand Up @@ -466,8 +488,8 @@ def plot_curve_hdi(
hdi_kwargs: dict | None = None,
plot_kwargs: dict | None = None,
subplot_kwargs: dict | None = None,
axes: npt.NDArray[plt.Axes] | None = None,
) -> tuple[plt.Figure, npt.NDArray[plt.Axes]]:
axes: npt.NDArray[Axes] | None = None,
) -> tuple[Figure, npt.NDArray[Axes]]:
"""Plot the HDI of the curve.
Parameters
Expand All @@ -494,9 +516,10 @@ def plot_curve_hdi(
axes=axes,
subplot_kwargs=subplot_kwargs,
plot_kwargs=plot_kwargs,
hdi_kwargs=hdi_kwargs,
)

def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> pt.TensorVariable:
def apply(self, x: pt.TensorLike, dims: Dims | None = None) -> TensorVariable:
"""Call within a model context.
Used internally of the MMM to apply the transformation to the data.
Expand Down

0 comments on commit aac68a9

Please sign in to comment.