Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support New Data for MMM model #444

Closed
wants to merge 15 commits into from
107 changes: 93 additions & 14 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
@property
def output_var(self):
"""Defines target variable for the model"""
return "y"
return "likelihood"

def _generate_and_preprocess_model_data( # type: ignore
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]
Expand All @@ -100,8 +100,10 @@
"""
date_data = X[self.date_column]
channel_data = X[self.channel_columns]
coords: Dict[str, Any] = {
self.coords_mutable: Dict[str, Any] = {
"date": date_data,
}
coords: Dict[str, Any] = {
"channel": self.channel_columns,
}

Expand Down Expand Up @@ -315,7 +317,9 @@
)

self._generate_and_preprocess_model_data(X, y)
with pm.Model(coords=self.model_coords) as self.model:
with pm.Model(
coords=self.model_coords, coords_mutable=self.coords_mutable
) as self.model:
channel_data_ = pm.MutableData(
name="channel_data",
value=self.preprocessed_data["X"][self.channel_columns],
Expand Down Expand Up @@ -611,19 +615,41 @@
-------
None
"""
if not isinstance(X, pd.DataFrame):
raise TypeError(
"X must be a pandas DataFrame in order to access the columns"
)
Comment on lines +618 to +621
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm viewing this as required in order to access the columns. Changing the type hint makes mypy mad

If there are any other suggestions, let me know

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think 99% of the people using this package use pandas, so it is ok for me.

new_channel_data: Optional[np.ndarray] = None
coords = {"date": X[self.date_column].to_numpy()}

if isinstance(X, pd.DataFrame):
try:
new_channel_data = X[self.channel_columns].to_numpy()
except KeyError as e:
raise RuntimeError("New data must contain channel_data!", e)
elif isinstance(X, np.ndarray):
new_channel_data = X
else:
raise TypeError("X must be either a pandas DataFrame or a numpy array")
try:
new_channel_data = X[self.channel_columns].to_numpy()
except KeyError as e:
raise RuntimeError("New data must contain channel_data!", e)

Check warning on line 628 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L627-L628

Added lines #L627 - L628 were not covered by tests
Comment on lines +627 to +628
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be the error raised for any missing keys of the dataframe? For instance, date_column or control_columns

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so! I suggest that, for this iteration, we keep this strict and then wait for feedback.


def identity(x):
return x

channel_transformation = (
identity
if not hasattr(self, "channel_transformer")
else self.channel_transformer.transform
)
data: Dict[str, Union[np.ndarray, Any]] = {
"channel_data": channel_transformation(new_channel_data)
}

if self.control_columns is not None:
control_data = X[self.control_columns].to_numpy()
control_transformation = (
identity
if not hasattr(self, "control_transformer")
else self.control_transformer.transform
)
data["control_data"] = control_transformation(control_data)

data: Dict[str, Union[np.ndarray, Any]] = {"channel_data": new_channel_data}
if hasattr(self, "fourier_columns"):
data["fourier_data"] = self._get_fourier_models_data(X)

if y is not None:
if isinstance(y, pd.Series):
Expand All @@ -634,9 +660,12 @@
data["target"] = y
else:
raise TypeError("y must be either a pandas Series or a numpy array")
else:
dtype = self.preprocessed_data["y"].dtype # type: ignore
data["target"] = np.zeros(X.shape[0], dtype=dtype) # type: ignore

with self.model:
pm.set_data(data)
pm.set_data(data, coords=coords)

@classmethod
def _model_config_formatting(cls, model_config: Dict) -> Dict:
Expand Down Expand Up @@ -814,3 +843,53 @@
ylabel="contribution",
)
return fig

def predict_posterior(
self,
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series],
extend_idata: bool = True,
combined: bool = True,
include_last_observations: bool = False,
**kwargs,
) -> DataArray:
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
"""
Generate posterior predictive samples on unseen data.

Parameters
----------
X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features)
The input data used for prediction.
extend_idata : Boolean determining whether the predictions should be added to inference data object.
Defaults to True.
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists.
Defaults to True.
include_last_observations: Whether to include last observed data for carryover adstock and saturation effect.
Defaults to False.
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive

Returns
-------
y_pred : DataArray, shape (n_pred, chains * draws) if combined is True, otherwise (chains, draws, n_pred)
Posterior predictive samples for each input X_pred
"""
if not isinstance(X_pred, pd.DataFrame):
raise ValueError("X_pred must be a pandas DataFrame")

Check warning on line 876 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L876

Added line #L876 was not covered by tests

if include_last_observations:
X_pred = pd.concat([self.X.iloc[-self.adstock_max_lag :], X_pred])
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

posterior_predictive_samples = self.sample_posterior_predictive(
X_pred, extend_idata, combined, **kwargs
)

if self.output_var not in posterior_predictive_samples:
raise KeyError(

Check warning on line 886 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L886

Added line #L886 was not covered by tests
f"Output variable {self.output_var} not found in posterior predictive samples."
)

if include_last_observations:
posterior_predictive_samples = posterior_predictive_samples.isel(
date=slice(self.adstock_max_lag, None)
)

return posterior_predictive_samples[self.output_var]
83 changes: 79 additions & 4 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ def mmm() -> DelayedSaturatedMMM:
)


@pytest.fixture(scope="class")
def mmm_with_fourier_features() -> DelayedSaturatedMMM:
return DelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
control_columns=["control_1", "control_2"],
yearly_seasonality=2,
)


@pytest.fixture(scope="class")
def mmm_fitted(
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series
Expand All @@ -96,6 +107,16 @@ def mmm_fitted(
return mmm


@pytest.fixture(scope="class")
def mmm_fitted_with_fourier_features(
mmm_with_fourier_features, toy_X: pd.DataFrame, toy_y: pd.Series
) -> DelayedSaturatedMMM:
mmm_with_fourier_features.fit(
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2
)
return mmm_with_fourier_features


class TestDelayedSaturatedMMM:
def test_save_load_with_not_serializable_model_config(
self, model_config_requiring_serialization, toy_X, toy_y
Expand Down Expand Up @@ -456,7 +477,8 @@ def test_data_setter(self, toy_X, toy_y):
with pytest.raises(TypeError):
base_delayed_saturated_mmm._data_setter(toy_X, y_incorrect)

with pytest.raises(RuntimeError):
# Missing the date column
with pytest.raises(KeyError):
X_wrong_df = pd.DataFrame(
{"column1": np.random.rand(135), "column2": np.random.rand(135)}
)
Expand All @@ -467,12 +489,10 @@ def test_data_setter(self, toy_X, toy_y):
except Exception as e:
pytest.fail(f"_data_setter failed with error {e}")

try:
with pytest.raises(TypeError, match="X must be a pandas DataFrame"):
base_delayed_saturated_mmm._data_setter(
X_correct_ndarray, y_correct_ndarray
)
except Exception as e:
pytest.fail(f"_data_setter failed with error {e}")

def test_save_load(self, mmm_fitted):
model = mmm_fitted
Expand Down Expand Up @@ -515,6 +535,61 @@ def mock_property(self):
DelayedSaturatedMMM.load("test_model")
os.remove("test_model")

@pytest.mark.parametrize(
"model_name", ["mmm_fitted", "mmm_fitted_with_fourier_features"]
)
@pytest.mark.parametrize(
"new_dates",
[
# 2021-12-31 is the last date in the toy data
# Old and New dates
pd.date_range(start="2021-11-01", end="2022-03-01", freq="W-MON"),
# Only Old dates
pd.date_range(start="2019-06-01", end="2021-12-31", freq="W-MON"),
# Only New dates
pd.date_range(start="2022-01-01", end="2022-03-01", freq="W-MON"),
# Less than the adstock_max_lag (4) of the model
pd.date_range(start="2022-01-01", freq="W-MON", periods=1),
],
)
def test_new_data_predictions(
self,
model_name: str,
new_dates: pd.DatetimeIndex,
request,
) -> None:
mmm = request.getfixturevalue(model_name)
n = new_dates.size
X_pred = pd.DataFrame(
{
"date": new_dates,
"channel_1": rng.integers(low=0, high=400, size=n),
"channel_2": rng.integers(low=0, high=50, size=n),
"control_1": rng.gamma(shape=1000, scale=500, size=n),
"control_2": rng.gamma(shape=100, scale=5, size=n),
"other_column_1": rng.integers(low=0, high=100, size=n),
"other_column_2": rng.normal(loc=0, scale=1, size=n),
}
)
Comment on lines +563 to +573
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have this as a fixture and split these tests, as they have many asserts checking different functions?


pp_without = mmm.predict_posterior(
X_pred=X_pred, include_last_observations=False
)
pp_with = mmm.predict_posterior(X_pred=X_pred, include_last_observations=True)

assert pp_without.coords.equals(pp_with.coords)

posterior_predictive = mmm.sample_posterior_predictive(
X_pred=X_pred, extend_idata=False, combined=True
)
pd.testing.assert_index_equal(
pd.DatetimeIndex(posterior_predictive.coords["date"]), new_dates
)
assert posterior_predictive["likelihood"].shape[0] == new_dates.size

posterior_predictive_mean = mmm.predict(X_pred=X_pred)
assert posterior_predictive_mean.shape[0] == new_dates.size

@pytest.mark.parametrize(
argnames="model_config",
argvalues=[
Expand Down