Skip to content

Commit

Permalink
Date Validation and MMM Model Hamonization (Pydantic) (#824)
Browse files Browse the repository at this point in the history
* validate base mmm init class

* validate dateformat

* add comment about date

* remove ()
  • Loading branch information
juanitorduz authored and twiecki committed Sep 10, 2024
1 parent 8ffe8c1 commit f17d990
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 6 deletions.
12 changes: 8 additions & 4 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,23 @@
from pymc_marketing.model_builder import ModelBuilder

__all__ = ["MMMModelBuilder", "BaseValidateMMM"]
from pydantic import Field, validate_call


class MMMModelBuilder(ModelBuilder):
model: pm.Model
_model_type = "BaseMMM"
version = "0.0.2"

@validate_call
def __init__(
self,
date_column: str,
channel_columns: list[str] | tuple[str],
model_config: dict | None = None,
sampler_config: dict | None = None,
date_column: str = Field(..., description="Column name of the date variable."),
channel_columns: list[str] = Field(
min_length=1, description="Column names of the media channel variables."
),
model_config: dict | None = Field(None, description="Model configuration."),
sampler_config: dict | None = Field(None, description="Sampler configuration."),
**kwargs,
) -> None:
self.date_column: str = date_column
Expand Down
10 changes: 8 additions & 2 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
Parameter
---------
date_column : str
Column name of the date variable.
Column name of the date variable. Must be parsable using ~pandas.to_datetime.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int, optional
Expand Down Expand Up @@ -236,7 +236,13 @@ def _generate_and_preprocess_model_data( # type: ignore
_time_resolution: int
The time resolution of the date index. Used by TVP.
"""
date_data = X[self.date_column]
try:
date_data = pd.to_datetime(X[self.date_column])
except Exception as e:
raise ValueError(
f"Could not convert {self.date_column} to datetime. Please check the date format."
) from e

channel_data = X[self.channel_columns]

coords: dict[str, Any] = {
Expand Down
33 changes: 33 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def toy_X(generate_data) -> pd.DataFrame:
return generate_data(date_data)


@pytest.fixture(scope="module")
def toy_X_with_bad_dates() -> pd.DataFrame:
bad_date_data = ["a", "b", "c", "d", "e"]
n: int = len(bad_date_data)
return pd.DataFrame(
data={
"date": bad_date_data,
"channel_1": rng.integers(low=0, high=400, size=n),
"channel_2": rng.integers(low=0, high=50, size=n),
"control_1": rng.gamma(shape=1000, scale=500, size=n),
"control_2": rng.gamma(shape=100, scale=5, size=n),
"other_column_1": rng.integers(low=0, high=100, size=n),
"other_column_2": rng.normal(loc=0, scale=1, size=n),
}
)


@pytest.fixture(scope="class")
def model_config_requiring_serialization() -> dict:
model_config = {
Expand Down Expand Up @@ -206,6 +223,22 @@ def deep_equal(dict1, dict2):
assert model.sampler_config == model2.sampler_config
os.remove("test_save_load")

def test_bad_date_column(self, toy_X_with_bad_dates) -> None:
with pytest.raises(
ValueError,
match="Could not convert bad_date_column to datetime. Please check the date format.",
):
my_mmm = MMM(
date_column="bad_date_column",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
adstock="geometric",
saturation="logistic",
)
y = np.ones(toy_X_with_bad_dates.shape[0])
my_mmm.build_model(X=toy_X_with_bad_dates, y=y)

@pytest.mark.parametrize(
argnames="adstock_max_lag",
argvalues=[1, 4],
Expand Down

0 comments on commit f17d990

Please sign in to comment.