Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support New Data for MMM model #444

Closed
wants to merge 15 commits into from
55 changes: 42 additions & 13 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,10 @@ def _generate_and_preprocess_model_data( # type: ignore
"""
date_data = X[self.date_column]
channel_data = X[self.channel_columns]
coords: Dict[str, Any] = {
self.coords_mutable: Dict[str, Any] = {
"date": date_data,
}
coords: Dict[str, Any] = {
"channel": self.channel_columns,
}

Expand Down Expand Up @@ -174,7 +176,9 @@ def build_model(
"""
model_config = self.model_config
self._generate_and_preprocess_model_data(X, y)
with pm.Model(coords=self.model_coords) as self.model:
with pm.Model(
coords=self.model_coords, coords_mutable=self.coords_mutable
) as self.model:
channel_data_ = pm.MutableData(
name="channel_data",
value=self.preprocessed_data["X"][self.channel_columns].to_numpy(),
Expand Down Expand Up @@ -484,19 +488,41 @@ def _data_setter(
-------
None
"""
if not isinstance(X, pd.DataFrame):
raise TypeError(
"X must be a pandas DataFrame in order to access the columns"
)
Comment on lines +618 to +621
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm viewing this as required in order to access the columns. Changing the type hint makes mypy mad

If there are any other suggestions, let me know

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 99% of the cases people using this package use pandas, so is ok for me.

new_channel_data: Optional[np.ndarray] = None
coords = {"date": X[self.date_column].to_numpy()}

if isinstance(X, pd.DataFrame):
try:
new_channel_data = X[self.channel_columns].to_numpy()
except KeyError as e:
raise RuntimeError("New data must contain channel_data!", e)
elif isinstance(X, np.ndarray):
new_channel_data = X
else:
raise TypeError("X must be either a pandas DataFrame or a numpy array")
try:
new_channel_data = X[self.channel_columns].to_numpy()
except KeyError as e:
raise RuntimeError("New data must contain channel_data!", e)
Comment on lines +627 to +628
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be the error raises by any missing keys of the dataframe? For instance, date_column or control_columns

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so! I suggest that, for this iteration, we keep this strict and then wait for feedback.


data: Dict[str, Union[np.ndarray, Any]] = {"channel_data": new_channel_data}
def identity(x):
return x

channel_transformation = (
identity
if not hasattr(self, "channel_transformer")
else self.channel_transformer.transform
)
data: Dict[str, Union[np.ndarray, Any]] = {
"channel_data": channel_transformation(new_channel_data)
}

if self.control_columns is not None:
control_data = X[self.control_columns].to_numpy()
control_transformation = (
identity
if not hasattr(self, "control_transformer")
else self.control_transformer.transform
)
data["control_data"] = control_transformation(control_data)

if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if y is not None:
if isinstance(y, pd.Series):
Expand All @@ -507,9 +533,12 @@ def _data_setter(
data["target"] = y
else:
raise TypeError("y must be either a pandas Series or a numpy array")
else:
dtype = self.preprocessed_data["y"].dtype # type: ignore
data["target"] = np.zeros(X.shape[0], dtype=dtype) # type: ignore

with self.model:
pm.set_data(data)
pm.set_data(data, coords=coords)

@classmethod
def _model_config_formatting(cls, model_config: Dict) -> Dict:
Expand Down
91 changes: 87 additions & 4 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ def mmm() -> DelayedSaturatedMMM:
)


@pytest.fixture(scope="class")
def mmm_with_fourier_features() -> DelayedSaturatedMMM:
return DelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
yearly_seasonality=2,
)


@pytest.fixture(scope="class")
def mmm_fitted(
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series
Expand All @@ -88,6 +99,16 @@ def mmm_fitted(
return mmm


@pytest.fixture(scope="class")
def mmm_fitted_with_fourier_features(
mmm_with_fourier_features, toy_X: pd.DataFrame, toy_y: pd.Series
) -> DelayedSaturatedMMM:
mmm_with_fourier_features.fit(
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2
)
return mmm_with_fourier_features


class TestDelayedSaturatedMMM:
def test_save_load_with_not_serializable_model_config(
self, model_config_requiring_serialization, toy_X, toy_y
Expand Down Expand Up @@ -448,9 +469,21 @@ def test_data_setter(self, toy_X, toy_y):
with pytest.raises(TypeError):
base_delayed_saturated_mmm._data_setter(toy_X, y_incorrect)

# Missing the date column
with pytest.raises(KeyError):
X_wrong_df = pd.DataFrame(
{"column_1": np.random.rand(135), "column_2": np.random.rand(135)}
)
base_delayed_saturated_mmm._data_setter(X_wrong_df, toy_y)

# Missing a channel column (and not date)
with pytest.raises(RuntimeError):
X_wrong_df = pd.DataFrame(
{"column1": np.random.rand(135), "column2": np.random.rand(135)}
{
"date": pd.to_datetime("2023-01-01"),
"column1": np.random.rand(135),
"column2": np.random.rand(135),
}
)
base_delayed_saturated_mmm._data_setter(X_wrong_df, toy_y)

Expand All @@ -459,12 +492,10 @@ def test_data_setter(self, toy_X, toy_y):
except Exception as e:
pytest.fail(f"_data_setter failed with error {e}")

try:
with pytest.raises(TypeError, match="X must be a pandas DataFrame"):
base_delayed_saturated_mmm._data_setter(
X_correct_ndarray, y_correct_ndarray
)
except Exception as e:
pytest.fail(f"_data_setter failed with error {e}")

def test_save_load(self, mmm_fitted):
model = mmm_fitted
Expand Down Expand Up @@ -506,3 +537,55 @@ def mock_property(self):
):
DelayedSaturatedMMM.load("test_model")
os.remove("test_model")

@pytest.mark.parametrize(
"model_name", ["mmm_fitted", "mmm_fitted_with_fourier_features"]
)
@pytest.mark.parametrize(
"new_dates",
[
# 2021-12-31 is the last date in the toy data
# Old and New dates
pd.date_range(start="2021-11-01", end="2022-03-01", freq="W-MON"),
# Only Old dates
pd.date_range(start="2019-06-01", end="2021-12-31", freq="W-MON"),
# Only New dates
pd.date_range(start="2022-01-01", end="2022-03-01", freq="W-MON"),
# Less than the adstock_max_lag (4) of the model
pd.date_range(start="2022-01-01", freq="W-MON", periods=1),
],
)
def test_new_data_predictions(
self,
model_name: str,
mmm_fitted: DelayedSaturatedMMM,
new_dates: pd.DatetimeIndex,
request,
) -> None:
mmm = request.getfixturevalue(model_name)
n = new_dates.size
new_X = pd.DataFrame(
{
"date": new_dates,
"channel_1": rng.integers(low=0, high=400, size=n),
"channel_2": rng.integers(low=0, high=50, size=n),
"control_1": rng.gamma(shape=1000, scale=500, size=n),
"control_2": rng.gamma(shape=100, scale=5, size=n),
"other_column_1": rng.integers(low=0, high=100, size=n),
"other_column_2": rng.normal(loc=0, scale=1, size=n),
}
)

with pytest.raises(
TypeError,
match=r"The DType <class 'numpy.dtype\[datetime64\]'> could not be promoted by",
):
mmm.predict_posterior(X_pred=new_X)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as noted in the comment, the ModelBuilder method will need an override if this is to work with dates in the input DataFrame


posterior_predictive = mmm.sample_posterior_predictive(
X_pred=new_X, extend_idata=False, combined=True
)
pd.testing.assert_index_equal(
pd.DatetimeIndex(posterior_predictive.coords["date"]), new_dates
)
assert posterior_predictive["likelihood"].shape[0] == new_dates.size