
Commit

remove noqa from plots (#761)
* remove noqa

* fix escape

* empty

* empty
juanitorduz authored and twiecki committed Sep 10, 2024
1 parent fb81614 commit 697bd99
Showing 6 changed files with 618 additions and 409 deletions.
712 changes: 394 additions & 318 deletions docs/source/notebooks/mmm/mmm_tvp_example.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pymc_marketing/mmm/base.py
@@ -399,7 +399,7 @@ def plot_posterior_predictive(
y2=likelihood_hdi[:, 1],
color="C0",
alpha=alpha,
label=f"${100 * hdi_prob}\%$ HDI", # noqa: W605
label=f"${100 * hdi_prob}\\%$ HDI",
)

ax.plot(
@@ -507,7 +507,7 @@ def plot_errors(
y2=errors_hdi["errors"].sel(hdi="higher"),
color="C3",
alpha=alpha,
label=f"${100 * hdi_prob}\%$ HDI", # noqa: W605
label=f"${100 * hdi_prob}\\%$ HDI",
)

ax.plot(
@@ -591,7 +591,7 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
y2=hdi.isel(hdi=1),
color=f"C{i}",
alpha=0.25,
label=f"$94\%$ HDI ({var_contribution})", # noqa: W605
label=f"$94\\%$ HDI ({var_contribution})",
)
ax.plot(
np.asarray(self.X[self.date_column]),
@@ -625,7 +625,7 @@ def plot_components_contributions(self, **plt_kwargs: Any) -> plt.Figure:
y2=intercept_hdi[:, 1],
color=f"C{i + 1}",
alpha=0.25,
label="$94\%$ HDI (intercept)", # noqa: W605
label="$94\\%$ HDI (intercept)",
)
ax.plot(
np.asarray(self.X[self.date_column]),
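All four label changes in this file follow the same pattern: inside a regular (non-raw) Python string, \% is an invalid escape sequence (flake8 W605, and a warning in recent Python versions), whereas \\% passes the literal \% that Matplotlib's mathtext expects for a percent sign, so the noqa comments can be dropped. A minimal sketch with made-up plot data:

import matplotlib.pyplot as plt

hdi_prob = 0.94
fig, ax = plt.subplots()
ax.fill_between(
    x=[0, 1, 2],
    y1=[0.1, 0.2, 0.15],
    y2=[0.3, 0.5, 0.4],
    color="C0",
    alpha=0.25,
    label=f"${100 * hdi_prob}\\%$ HDI",  # renders as "94.0% HDI", no noqa needed
)
ax.legend()
plt.show()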
142 changes: 93 additions & 49 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -40,11 +40,11 @@
_get_saturation_function,
)
from pymc_marketing.mmm.lift_test import (
add_lift_measurements_to_likelihood,
add_lift_measurements_to_likelihood_from_saturation,
scale_lift_measurements,
)
from pymc_marketing.mmm.preprocessing import MaxAbsScaleChannels, MaxAbsScaleTarget
from pymc_marketing.mmm.tvp import create_time_varying_intercept, infer_time_index
from pymc_marketing.mmm.tvp import create_time_varying_gp_multiplier, infer_time_index
from pymc_marketing.mmm.utils import (
apply_sklearn_transformer_across_dim,
create_new_spend_data,
@@ -81,6 +81,7 @@ def __init__(
adstock: str | AdstockTransformation,
saturation: str | SaturationTransformation,
time_varying_intercept: bool = False,
time_varying_media: bool = False,
model_config: dict | None = None,
sampler_config: dict | None = None,
validate_data: bool = True,
@@ -105,6 +106,13 @@
Type of saturation transformation to apply.
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
Because the `time-varying` variable is centered around 1 and acts as a multiplier,
the variable `baseline_intercept` now represents the mean of the time-varying intercept.
time_varying_media : bool, optional
Whether to consider time-varying media contributions, by default False.
When enabled, a latent media variable centered around 1 acts as a global
multiplier (scaling factor) on all channel contributions, so every media
channel shares the same latent fluctuation.
model_config : Dictionary, optional
dictionary of parameters that initialise model configuration.
Class-default defined by the user default_model_config method.
@@ -121,6 +129,7 @@
self.control_columns = control_columns
self.adstock_max_lag = adstock_max_lag
self.time_varying_intercept = time_varying_intercept
self.time_varying_media = time_varying_media
self.yearly_seasonality = yearly_seasonality
self.date_column = date_column
self.validate_data = validate_data
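As a usage sketch of the two flags (hypothetical column names; this assumes the class is the one exported as MMM from pymc_marketing.mmm):

from pymc_marketing.mmm import MMM

mmm = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    adstock="geometric",
    saturation="logistic",
    adstock_max_lag=8,
    yearly_seasonality=2,
    time_varying_intercept=True,  # intercept = baseline_intercept * latent multiplier
    time_varying_media=True,      # one shared latent multiplier scales every channel
)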
@@ -220,7 +229,7 @@ def _generate_and_preprocess_model_data(  # type: ignore
self.X: pd.DataFrame = X_data
self.y: pd.Series | np.ndarray = y

if self.time_varying_intercept:
if self.time_varying_intercept | self.time_varying_media:
self._time_index = np.arange(0, X.shape[0])
self._time_index_mid = X.shape[0] // 2
self._time_resolution = (
@@ -344,33 +353,67 @@ def build_model(
dims="date",
mutable=True,
)

if self.time_varying_intercept:
if self.time_varying_intercept | self.time_varying_media:
time_index = pm.Data(
"time_index",
self._time_index,
dims="date",
)
intercept_dist = get_distribution(

if self.time_varying_intercept:
intercept_distribution = get_distribution(
name=self.model_config["intercept"]["dist"]
)
intercept = create_time_varying_intercept(
time_index,
self._time_index_mid,
self._time_resolution,
intercept_dist,
self.model_config,
baseline_intercept = intercept_distribution(
name="baseline_intercept",
**self.model_config["intercept"]["kwargs"],
)

intercept_latent_process = create_time_varying_gp_multiplier(
name="intercept",
dims="date",
time_index=time_index,
time_index_mid=self._time_index_mid,
time_resolution=self._time_resolution,
model_config=self.model_config,
)
intercept = pm.Deterministic(
name="intercept",
var=baseline_intercept * intercept_latent_process,
dims="date",
)
else:
intercept = create_distribution_from_config(
name="intercept", config=self.model_config
)

channel_contributions = pm.Deterministic(
name="channel_contributions",
var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)
if self.time_varying_media:
baseline_channel_contributions = pm.Deterministic(
name="baseline_channel_contributions",
var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)

media_latent_process = create_time_varying_gp_multiplier(
name="media",
dims="date",
time_index=time_index,
time_index_mid=self._time_index_mid,
time_resolution=self._time_resolution,
model_config=self.model_config,
)
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=baseline_channel_contributions * media_latent_process[:, None],
dims=("date", "channel"),
)

else:
channel_contributions = pm.Deterministic(
name="channel_contributions",
var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)

mu_var = intercept + channel_contributions.sum(axis=-1)
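The time-varying intercept and time-varying media branches above share one pattern: draw a baseline quantity, draw a latent time process centered around 1, and combine them in a Deterministic. The sketch below illustrates that pattern in plain PyMC; it is not the library's create_time_varying_gp_multiplier, and the HSGP hyperparameters and the `1 + f` centering are illustrative assumptions:

import numpy as np
import pymc as pm

n_dates = 104
time_index = np.arange(n_dates)

with pm.Model(coords={"date": time_index}) as model:
    # Baseline level whose mean the time-varying multiplier fluctuates around.
    baseline_intercept = pm.Normal("baseline_intercept", mu=0, sigma=2)

    # Smooth latent process over time via the Hilbert-space GP approximation.
    ls = pm.Gamma("ls", alpha=5, beta=1)
    eta = pm.Exponential("eta", lam=1)
    cov_func = eta**2 * pm.gp.cov.Matern52(input_dim=1, ls=ls)
    gp = pm.gp.HSGP(m=[200], c=1.5, cov_func=cov_func)
    f = gp.prior("f", X=time_index[:, None])

    # Center the multiplier around 1 so `baseline_intercept` keeps its interpretation.
    multiplier = pm.Deterministic("intercept_latent_multiplier", 1 + f, dims="date")
    intercept = pm.Deterministic("intercept", baseline_intercept * multiplier, dims="date")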

@@ -383,11 +426,6 @@
)
):
if self.model_config["gamma_control"].get("dims") != "control":
msg = (
"The 'dims' key in gamma_control must be 'control'."
" This will be fixed automatically."
)
warnings.warn(msg, stacklevel=2)
self.model_config["gamma_control"]["dims"] = "control"

gamma_control = create_distribution_from_config(
@@ -427,11 +465,6 @@
)

if self.model_config["gamma_fourier"].get("dims") != "fourier_mode":
msg = (
"The 'dims' key in gamma_fourier must be 'fourier_mode'."
" This will be fixed automatically."
)
warnings.warn(msg, stacklevel=2)
self.model_config["gamma_fourier"]["dims"] = "fourier_mode"

gamma_fourier = create_distribution_from_config(
@@ -465,7 +498,7 @@

@property
def default_model_config(self) -> dict:
base_config = {
base_config: dict[str, Any] = {
"intercept": {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 2}},
"likelihood": {
"dist": "Normal",
@@ -483,25 +516,30 @@ def default_model_config(self) -> dict:
"kwargs": {"mu": 0, "b": 1},
"dims": "fourier_mode",
},
"intercept_tvp_kwargs": {
}

if self.time_varying_intercept:
base_config["intercept_tvp_config"] = {
"m": 200,
"L": None,
"eta_lam": 1,
"ls_mu": None,
"ls_sigma": 10,
"cov_func": None,
},
}
}
if self.time_varying_media:
base_config["media_tvp_config"] = {
"m": 200,
"L": None,
"eta_lam": 1,
"ls_mu": None,
"ls_sigma": 10,
"cov_func": None,
}

for media_transform in [self.adstock, self.saturation]:
for param, config in media_transform.function_priors.items():
for config in media_transform.function_priors.values():
if "dims" not in config:
msg = (
f"{param} doesn't have a 'dims' key in config. Setting to channel."
f" Set priors explicitly in {media_transform.__class__.__name__}"
" to avoid this warning."
)
warnings.warn(msg, stacklevel=2)
config["dims"] = "channel"

return {
@@ -712,7 +750,7 @@ def identity(x):
if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if self.time_varying_intercept:
if self.time_varying_intercept | self.time_varying_media:
data["time_index"] = infer_time_index(
X[self.date_column], self.X[self.date_column], self._time_resolution
)
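Here infer_time_index maps out-of-sample dates onto the index scale the GP multipliers were trained on, so they can extrapolate past the training window. A conceptual sketch of that mapping (not the library function) with hypothetical weekly dates:

import pandas as pd

train_dates = pd.date_range("2024-01-01", periods=52, freq="W-MON")
new_dates = pd.date_range(train_dates[-1] + pd.Timedelta(weeks=1), periods=8, freq="W-MON")

time_resolution = (train_dates[1] - train_dates[0]).days        # 7 days per index step
new_time_index = (new_dates - train_dates[0]).days // time_resolution
print(new_time_index.tolist())  # [52, 53, ..., 59], continuing past the last training index (51)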
@@ -1057,7 +1095,7 @@ def plot_channel_contributions_grid(
y1=hdi_contribution[:, 0],
y2=hdi_contribution[:, 1],
color=f"C{i}",
label=f"{channel} $94\%$ HDI contribution", # noqa: W605
label=f"{channel} $94\\%$ HDI contribution",
alpha=0.4,
)

@@ -1671,7 +1709,7 @@ def sample_posterior_predictive(
def add_lift_test_measurements(
self,
df_lift_test: pd.DataFrame,
dist: pm.Distribution = pm.Gamma,
dist: type[pm.Distribution] = pm.Gamma,
name: str = "lift_measurements",
) -> None:
"""Add lift tests to the model.
@@ -1770,14 +1808,18 @@ def add_lift_test_measurements(
channel_transform=self.channel_transformer.transform,
target_transform=self.target_transformer.transform,
)
with self.model:
add_lift_measurements_to_likelihood(
df_lift_test=df_lift_test_scaled,
variable_mapping=self.saturation.variable_mapping,
saturation_function=self.saturation.function,
dist=dist,
name=name,
)
# This is coupled with the name of the
# latent process Deterministic
time_varying_var_name = (
"media_temporal_latent_multiplier" if self.time_varying_media else None
)
add_lift_measurements_to_likelihood_from_saturation(
df_lift_test=df_lift_test_scaled,
saturation=self.saturation,
time_varying_var_name=time_varying_var_name,
model=self.model,
dist=dist,
)
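With the new helper the lift tests are attached to the saturation curve directly and, when time_varying_media=True, rescaled by the latent multiplier named above. A hedged usage sketch with a fitted `mmm` instance like the one sketched earlier (all values are made up; the column names are the ones the lift-test documentation asks for):

import pandas as pd

df_lift_test = pd.DataFrame(
    {
        "channel": ["x1", "x2"],
        "x": [1_000.0, 2_000.0],    # spend level at which each test was run
        "delta_x": [200.0, 400.0],  # additional spend during the test
        "delta_y": [50.0, 80.0],    # measured incremental outcome
        "sigma": [10.0, 15.0],      # uncertainty of the measured lift
    }
)

mmm.add_lift_test_measurements(df_lift_test)  # requires a built/fitted model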

def _create_synth_dataset(
self,
@@ -2150,6 +2192,7 @@ def __init__(
channel_columns: list[str],
adstock_max_lag: int,
time_varying_intercept: bool = False,
time_varying_media: bool = False,
model_config: dict | None = None,
sampler_config: dict | None = None,
validate_data: bool = True,
@@ -2175,6 +2218,7 @@ def __init__(
channel_columns=channel_columns,
adstock_max_lag=adstock_max_lag,
time_varying_intercept=time_varying_intercept,
time_varying_media=time_varying_media,
model_config=model_config,
sampler_config=sampler_config,
validate_data=validate_data,
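The GP hyperparameters behind both multipliers come from the intercept_tvp_config and media_tvp_config entries that default_model_config now adds only when the corresponding flag is on. A hypothetical override through model_config (all values are illustrative, and the comments are a best-effort reading of the keys, not authoritative documentation):

from pymc_marketing.mmm import MMM

tvp_config = {
    "m": 200,          # number of HSGP basis functions
    "L": None,         # HSGP boundary; left to the library default when None
    "eta_lam": 1,      # prior rate for the GP amplitude
    "ls_mu": None,     # lengthscale prior location; left to the default when None
    "ls_sigma": 10,
    "cov_func": None,  # covariance function; library default when None
}

mmm = MMM(
    date_column="date_week",
    channel_columns=["x1", "x2"],
    adstock="geometric",
    saturation="logistic",
    adstock_max_lag=8,
    time_varying_intercept=True,
    time_varying_media=True,
    model_config={
        "intercept_tvp_config": tvp_config,
        "media_tvp_config": dict(tvp_config, ls_sigma=5),
    },
)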