diff --git a/pymc_marketing/mmm/base.py b/pymc_marketing/mmm/base.py
index f5e8adfd..34c1e429 100644
--- a/pymc_marketing/mmm/base.py
+++ b/pymc_marketing/mmm/base.py
@@ -46,6 +46,7 @@ from pymc_marketing.model_builder import ModelBuilder
 
 __all__ = ["MMMModelBuilder", "BaseValidateMMM"]
 
+from pydantic import Field, validate_call
 
 
 class MMMModelBuilder(ModelBuilder):
@@ -53,12 +54,16 @@ class MMMModelBuilder(ModelBuilder):
     _model_type = "BaseMMM"
     version = "0.0.2"
 
+    # Validate constructor arguments at call time (e.g. a non-empty channel list).
+    @validate_call
     def __init__(
         self,
-        date_column: str,
-        channel_columns: list[str] | tuple[str],
-        model_config: dict | None = None,
-        sampler_config: dict | None = None,
+        date_column: str = Field(..., description="Column name of the date variable."),
+        channel_columns: list[str] = Field(
+            min_length=1, description="Column names of the media channel variables."
+        ),
+        model_config: dict | None = Field(None, description="Model configuration."),
+        sampler_config: dict | None = Field(None, description="Sampler configuration."),
         **kwargs,
     ) -> None:
         self.date_column: str = date_column
diff --git a/pymc_marketing/mmm/delayed_saturated_mmm.py b/pymc_marketing/mmm/delayed_saturated_mmm.py
index 7e894ac2..36c9a729 100644
--- a/pymc_marketing/mmm/delayed_saturated_mmm.py
+++ b/pymc_marketing/mmm/delayed_saturated_mmm.py
@@ -125,7 +125,7 @@ def __init__(
         Parameter
         ---------
         date_column : str
-            Column name of the date variable.
+            Column name of the date variable. Must be parsable using :func:`pandas.to_datetime`.
         channel_columns : List[str]
             Column names of the media channel variables.
         adstock_max_lag : int, optional
@@ -236,7 +236,16 @@ def _generate_and_preprocess_model_data(  # type: ignore
         _time_resolution: int
             The time resolution of the date index. Used by TVP.
         """
-        date_data = X[self.date_column]
+        # Look the column up outside the ``try`` so that a missing column surfaces
+        # as a ``KeyError`` instead of being misreported as a date-format problem.
+        raw_date_data = X[self.date_column]
+        try:
+            date_data = pd.to_datetime(raw_date_data)
+        except (ValueError, TypeError) as e:
+            raise ValueError(
+                f"Could not convert {self.date_column} to datetime. Please check the date format."
+            ) from e
+
         channel_data = X[self.channel_columns]
 
         coords: dict[str, Any] = {
diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py
index 63e4ed4a..d060c617 100644
--- a/tests/mmm/test_delayed_saturated_mmm.py
+++ b/tests/mmm/test_delayed_saturated_mmm.py
@@ -82,6 +82,24 @@ def toy_X(generate_data) -> pd.DataFrame:
     return generate_data(date_data)
 
 
+@pytest.fixture(scope="module")
+def toy_X_with_bad_dates() -> pd.DataFrame:
+    """Toy data set whose ``date`` column holds values that are not dates."""
+    bad_date_data = ["a", "b", "c", "d", "e"]
+    n: int = len(bad_date_data)
+    return pd.DataFrame(
+        data={
+            "date": bad_date_data,
+            "channel_1": rng.integers(low=0, high=400, size=n),
+            "channel_2": rng.integers(low=0, high=50, size=n),
+            "control_1": rng.gamma(shape=1000, scale=500, size=n),
+            "control_2": rng.gamma(shape=100, scale=5, size=n),
+            "other_column_1": rng.integers(low=0, high=100, size=n),
+            "other_column_2": rng.normal(loc=0, scale=1, size=n),
+        }
+    )
+
+
 @pytest.fixture(scope="class")
 def model_config_requiring_serialization() -> dict:
     model_config = {
@@ -206,6 +224,24 @@ def deep_equal(dict1, dict2):
         assert model.sampler_config == model2.sampler_config
         os.remove("test_save_load")
 
+    def test_bad_date_column(self, toy_X_with_bad_dates) -> None:
+        """Unparsable values in the date column raise a descriptive ValueError."""
+        my_mmm = MMM(
+            date_column="date",
+            channel_columns=["channel_1", "channel_2"],
+            adstock_max_lag=4,
+            control_columns=["control_1", "control_2"],
+            adstock="geometric",
+            saturation="logistic",
+        )
+        y = np.ones(toy_X_with_bad_dates.shape[0])
+        # Only ``build_model`` is expected to raise; construction must succeed.
+        with pytest.raises(
+            ValueError,
+            match="Could not convert date to datetime",
+        ):
+            my_mmm.build_model(X=toy_X_with_bad_dates, y=y)
+
     @pytest.mark.parametrize(
         argnames="adstock_max_lag",
         argvalues=[1, 4],