Skip to content

Commit

Permalink
Changes!
Browse files Browse the repository at this point in the history
  • Loading branch information
cetagostini committed Jun 21, 2024
1 parent 4f695a7 commit 96c1092
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 34 deletions.
15 changes: 8 additions & 7 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(
time_varying_intercept : bool, optional
Whether to consider time-varying intercept, by default False.
Because the `time-varying` variable is centered around 1 and acts as a multiplier,
the variable `base_intercept` now represents the mean of the time-varying intercept.
the variable `baseline_intercept` now represents the mean of the time-varying intercept.
time_varying_media : bool, optional
Whether to consider time-varying media contributions, by default False.
The `time-varying-media` creates a time media variable centered around 1,
Expand Down Expand Up @@ -364,8 +364,9 @@ def build_model(
intercept_distribution = get_distribution(

Check warning on line 364 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L363-L364

Added lines #L363 - L364 were not covered by tests
name=self.model_config["intercept"]["dist"]
)
base_intercept = intercept_distribution(
name="base_intercept", **self.model_config["intercept"]["kwargs"]
baseline_intercept = intercept_distribution(

Check warning on line 367 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L367

Added line #L367 was not covered by tests
name="baseline_intercept",
**self.model_config["intercept"]["kwargs"],
)

intercept_latent_process = create_time_varying_gp_multiplier(

Check warning on line 372 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L372

Added line #L372 was not covered by tests
Expand All @@ -378,7 +379,7 @@ def build_model(
)
intercept = pm.Deterministic(

Check warning on line 380 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L380

Added line #L380 was not covered by tests
name="intercept",
var=base_intercept * intercept_latent_process,
var=baseline_intercept * intercept_latent_process,
dims="date",
)
else:
Expand All @@ -387,8 +388,8 @@ def build_model(
)

if self.time_varying_media:
base_channel_contributions = pm.Deterministic(
name="base_channel_contributions",
baseline_channel_contributions = pm.Deterministic(

Check warning on line 391 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L390-L391

Added lines #L390 - L391 were not covered by tests
name="baseline_channel_contributions",
var=self.forward_pass(x=channel_data_),
dims=("date", "channel"),
)
Expand All @@ -403,7 +404,7 @@ def build_model(
)
channel_contributions = pm.Deterministic(

Check warning on line 405 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L405

Added line #L405 was not covered by tests
name="channel_contributions",
var=base_channel_contributions * media_latent_process[:, None],
var=baseline_channel_contributions * media_latent_process[:, None],
dims=("date", "channel"),
)

Expand Down
38 changes: 11 additions & 27 deletions pymc_marketing/mmm/tvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,31 +90,15 @@
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
from pymc.distributions.shape_utils import Dims

from pymc_marketing.constants import DAYS_IN_YEAR


def _softplus(x: pt.TensorVariable) -> pt.TensorVariable:
"""
Compute the softplus function element-wise on the input tensor.
Parameters
----------
x : pt.TensorVariable
Input tensor.
Returns
-------
pt.TensorVariable
Output tensor after applying the softplus function element-wise.
"""
return pm.math.log(1 + pm.math.exp(x))


def time_varying_prior(
name: str,
X: pt.sharedvar.TensorSharedVariable,
dims: tuple[str, str] | str,
dims: Dims,
X_mid: int | float | None = None,
m: int = 200,
L: int | float | None = None,
Expand Down Expand Up @@ -190,14 +174,14 @@ def time_varying_prior(
phi, sqrt_psd = gp.prior_linearized(Xs=X[:, None] - X_mid)
hsgp_coefs = pm.Normal(f"{name}_hsgp_coefs", dims=hsgp_dims)
f = phi @ (hsgp_coefs * sqrt_psd).T
f = _softplus(f)
f = pt.softplus(f)

Check warning on line 177 in pymc_marketing/mmm/tvp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/tvp.py#L177

Added line #L177 was not covered by tests
centered_f = f - f.mean(axis=0) + 1
return pm.Deterministic(name, centered_f, dims=dims)


def create_time_varying_gp_multiplier(
name: str,
dims: tuple[str, str] | str,
dims: Dims,
time_index: pt.sharedvar.TensorSharedVariable,
time_index_mid: int,
time_resolution: int,
Expand Down Expand Up @@ -228,19 +212,19 @@ def create_time_varying_gp_multiplier(
Time-varying Gaussian Process multiplier for a given variable.
"""

if model_config[f"{name}_tvp_config"]["L"] is None:
model_config[f"{name}_tvp_config"]["L"] = (
time_index_mid + DAYS_IN_YEAR / time_resolution
)
if model_config[f"{name}_tvp_config"]["ls_mu"] is None:
model_config[f"{name}_tvp_config"]["ls_mu"] = DAYS_IN_YEAR / time_resolution * 2
tvp_config = model_config[f"{name}_tvp_config"]

Check warning on line 215 in pymc_marketing/mmm/tvp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/tvp.py#L215

Added line #L215 was not covered by tests

if tvp_config["L"] is None:
tvp_config["L"] = time_index_mid + DAYS_IN_YEAR / time_resolution
if tvp_config["ls_mu"] is None:
tvp_config["ls_mu"] = DAYS_IN_YEAR / time_resolution * 2

Check warning on line 220 in pymc_marketing/mmm/tvp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/tvp.py#L217-L220

Added lines #L217 - L220 were not covered by tests

multiplier = time_varying_prior(

Check warning on line 222 in pymc_marketing/mmm/tvp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/tvp.py#L222

Added line #L222 was not covered by tests
name=f"{name}_temporal_latent_multiplier",
X=time_index,
X_mid=time_index_mid,
dims=dims,
**model_config[f"{name}_tvp_config"],
**tvp_config,
)
return multiplier

Check warning on line 229 in pymc_marketing/mmm/tvp.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/tvp.py#L229

Added line #L229 was not covered by tests

Expand Down

0 comments on commit 96c1092

Please sign in to comment.