diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py index 578e6a9f..526a5f8e 100644 --- a/pymc_marketing/mmm/delayed_saturated_mmm.py +++ b/pymc_marketing/mmm/delayed_saturated_mmm.py @@ -107,7 +107,7 @@ def __init__( time_varying_intercept : bool, optional Whether to consider time-varying intercept, by default False. Because the `time-varying` variable is centered around 1 and acts as a multiplier, - the variable `base_intercept` now represents the mean of the time-varying intercept. + the variable `baseline_intercept` now represents the mean of the time-varying intercept. time_varying_media : bool, optional Whether to consider time-varying media contributions, by default False. The `time-varying-media` creates a time media variable centered around 1, @@ -364,8 +364,9 @@ def build_model( intercept_distribution = get_distribution( name=self.model_config["intercept"]["dist"] ) - base_intercept = intercept_distribution( - name="base_intercept", **self.model_config["intercept"]["kwargs"] + baseline_intercept = intercept_distribution( + name="baseline_intercept", + **self.model_config["intercept"]["kwargs"], ) intercept_latent_process = create_time_varying_gp_multiplier( @@ -378,7 +379,7 @@ def build_model( ) intercept = pm.Deterministic( name="intercept", - var=base_intercept * intercept_latent_process, + var=baseline_intercept * intercept_latent_process, dims="date", ) else: @@ -387,8 +388,8 @@ def build_model( ) if self.time_varying_media: - base_channel_contributions = pm.Deterministic( - name="base_channel_contributions", + baseline_channel_contributions = pm.Deterministic( + name="baseline_channel_contributions", var=self.forward_pass(x=channel_data_), dims=("date", "channel"), ) @@ -403,7 +404,7 @@ def build_model( ) channel_contributions = pm.Deterministic( name="channel_contributions", - var=base_channel_contributions * media_latent_process[:, None], + var=baseline_channel_contributions * media_latent_process[:, None], dims=("date", "channel"), ) diff --git a/pymc_marketing/mmm/tvp.py b/pymc_marketing/mmm/tvp.py index 61c9a111..9b8c8423 100644 --- a/pymc_marketing/mmm/tvp.py +++ b/pymc_marketing/mmm/tvp.py @@ -90,31 +90,15 @@ import pandas as pd import pymc as pm import pytensor.tensor as pt +from pymc.distributions.shape_utils import Dims from pymc_marketing.constants import DAYS_IN_YEAR -def _softplus(x: pt.TensorVariable) -> pt.TensorVariable: - """ - Compute the softplus function element-wise on the input tensor. - - Parameters - ---------- - x : pt.TensorVariable - Input tensor. - - Returns - ------- - pt.TensorVariable - Output tensor after applying the softplus function element-wise. - """ - return pm.math.log(1 + pm.math.exp(x)) - - def time_varying_prior( name: str, X: pt.sharedvar.TensorSharedVariable, - dims: tuple[str, str] | str, + dims: Dims, X_mid: int | float | None = None, m: int = 200, L: int | float | None = None, @@ -190,14 +174,14 @@ def time_varying_prior( phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid) hsgp_coefs = pm.Normal(f"{name}_hsgp_coefs", dims=hsgp_dims) f = phi @ (hsgp_coefs * sqrt_psd).T - f = _softplus(f) + f = pt.softplus(f) centered_f = f - f.mean(axis=0) + 1 return pm.Deterministic(name, centered_f, dims=dims) def create_time_varying_gp_multiplier( name: str, - dims: tuple[str, str] | str, + dims: Dims, time_index: pt.sharedvar.TensorSharedVariable, time_index_mid: int, time_resolution: int, @@ -228,19 +212,19 @@ def create_time_varying_gp_multiplier( Time-varying Gaussian Process multiplier for a given variable. """ - if model_config[f"{name}_tvp_config"]["L"] is None: - model_config[f"{name}_tvp_config"]["L"] = ( - time_index_mid + DAYS_IN_YEAR / time_resolution - ) - if model_config[f"{name}_tvp_config"]["ls_mu"] is None: - model_config[f"{name}_tvp_config"]["ls_mu"] = DAYS_IN_YEAR / time_resolution * 2 + tvp_config = model_config[f"{name}_tvp_config"] + + if tvp_config["L"] is None: + tvp_config["L"] = time_index_mid + DAYS_IN_YEAR / time_resolution + if tvp_config["ls_mu"] is None: + tvp_config["ls_mu"] = DAYS_IN_YEAR / time_resolution * 2 multiplier = time_varying_prior( name=f"{name}_temporal_latent_multiplier", X=time_index, X_mid=time_index_mid, dims=dims, - **model_config[f"{name}_tvp_config"], + **tvp_config, ) return multiplier