-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support New Data for MMM model #444
Changes from all commits
27f8ce2
72af239
6af2fda
1fbb966
b3484fc
3b292e3
4626348
2cd0b31
f99600b
c49ab8e
f85a375
f8f2a8e
91a15c6
f41a1f2
ba62ce1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,7 +83,7 @@ | |
@property | ||
def output_var(self): | ||
"""Defines target variable for the model""" | ||
return "y" | ||
return "likelihood" | ||
|
||
def _generate_and_preprocess_model_data( # type: ignore | ||
self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray] | ||
|
@@ -100,8 +100,10 @@ | |
""" | ||
date_data = X[self.date_column] | ||
channel_data = X[self.channel_columns] | ||
coords: Dict[str, Any] = { | ||
self.coords_mutable: Dict[str, Any] = { | ||
"date": date_data, | ||
} | ||
coords: Dict[str, Any] = { | ||
"channel": self.channel_columns, | ||
} | ||
|
||
|
@@ -315,7 +317,9 @@ | |
) | ||
|
||
self._generate_and_preprocess_model_data(X, y) | ||
with pm.Model(coords=self.model_coords) as self.model: | ||
with pm.Model( | ||
coords=self.model_coords, coords_mutable=self.coords_mutable | ||
) as self.model: | ||
channel_data_ = pm.MutableData( | ||
name="channel_data", | ||
value=self.preprocessed_data["X"][self.channel_columns], | ||
|
@@ -611,19 +615,41 @@ | |
------- | ||
None | ||
""" | ||
if not isinstance(X, pd.DataFrame): | ||
raise TypeError( | ||
"X must be a pandas DataFrame in order to access the columns" | ||
) | ||
new_channel_data: Optional[np.ndarray] = None | ||
coords = {"date": X[self.date_column].to_numpy()} | ||
|
||
if isinstance(X, pd.DataFrame): | ||
try: | ||
new_channel_data = X[self.channel_columns].to_numpy() | ||
except KeyError as e: | ||
raise RuntimeError("New data must contain channel_data!", e) | ||
elif isinstance(X, np.ndarray): | ||
new_channel_data = X | ||
else: | ||
raise TypeError("X must be either a pandas DataFrame or a numpy array") | ||
try: | ||
new_channel_data = X[self.channel_columns].to_numpy() | ||
except KeyError as e: | ||
raise RuntimeError("New data must contain channel_data!", e) | ||
Comment on lines
+627
to
+628
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. Should this be the error raised for any missing keys of the dataframe? For instance, date_column or control_columns? There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. I think so! I suggest that, for this iteration, we keep this strict and then wait for feedback. |
||
|
||
def identity(x): | ||
return x | ||
|
||
channel_transformation = ( | ||
identity | ||
if not hasattr(self, "channel_transformer") | ||
else self.channel_transformer.transform | ||
) | ||
data: Dict[str, Union[np.ndarray, Any]] = { | ||
"channel_data": channel_transformation(new_channel_data) | ||
} | ||
|
||
if self.control_columns is not None: | ||
control_data = X[self.control_columns].to_numpy() | ||
control_transformation = ( | ||
identity | ||
if not hasattr(self, "control_transformer") | ||
else self.control_transformer.transform | ||
) | ||
data["control_data"] = control_transformation(control_data) | ||
|
||
data: Dict[str, Union[np.ndarray, Any]] = {"channel_data": new_channel_data} | ||
if hasattr(self, "fourier_columns"): | ||
data["fourier_data"] = self._get_fourier_models_data(X) | ||
|
||
if y is not None: | ||
if isinstance(y, pd.Series): | ||
|
@@ -634,9 +660,12 @@ | |
data["target"] = y | ||
else: | ||
raise TypeError("y must be either a pandas Series or a numpy array") | ||
else: | ||
dtype = self.preprocessed_data["y"].dtype # type: ignore | ||
data["target"] = np.zeros(X.shape[0], dtype=dtype) # type: ignore | ||
|
||
with self.model: | ||
pm.set_data(data) | ||
pm.set_data(data, coords=coords) | ||
|
||
@classmethod | ||
def _model_config_formatting(cls, model_config: Dict) -> Dict: | ||
|
@@ -814,3 +843,53 @@ | |
ylabel="contribution", | ||
) | ||
return fig | ||
|
||
def predict_posterior( | ||
self, | ||
X_pred: Union[np.ndarray, pd.DataFrame, pd.Series], | ||
extend_idata: bool = True, | ||
combined: bool = True, | ||
include_last_observations: bool = False, | ||
**kwargs, | ||
) -> DataArray: | ||
wd60622 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Generate posterior predictive samples on unseen data. | ||
|
||
Parameters | ||
--------- | ||
X_pred : array-like if sklearn is available, otherwise array, shape (n_pred, n_features) | ||
The input data used for prediction. | ||
extend_idata : Boolean determining whether the predictions should be added to inference data object. | ||
Defaults to True. | ||
combined: Combine chain and draw dims into sample. Won't work if a dim named sample already exists. | ||
Defaults to True. | ||
include_last_observations: Whether to include last observed data for carryover adstock and saturation effect. | ||
Defaults to False. | ||
**kwargs: Additional arguments to pass to pymc.sample_posterior_predictive | ||
|
||
Returns | ||
------- | ||
y_pred : DataArray, shape (n_pred, chains * draws) if combined is True, otherwise (chains, draws, n_pred) | ||
Posterior predictive samples for each input X_pred | ||
""" | ||
if not isinstance(X_pred, pd.DataFrame): | ||
raise ValueError("X_pred must be a pandas DataFrame") | ||
|
||
if include_last_observations: | ||
X_pred = pd.concat([self.X.iloc[-self.adstock_max_lag :], X_pred]) | ||
wd60622 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
posterior_predictive_samples = self.sample_posterior_predictive( | ||
X_pred, extend_idata, combined, **kwargs | ||
) | ||
|
||
if self.output_var not in posterior_predictive_samples: | ||
raise KeyError( | ||
f"Output variable {self.output_var} not found in posterior predictive samples." | ||
) | ||
|
||
if include_last_observations: | ||
posterior_predictive_samples = posterior_predictive_samples.isel( | ||
date=slice(self.adstock_max_lag, None) | ||
) | ||
|
||
return posterior_predictive_samples[self.output_var] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,17 @@ def mmm() -> DelayedSaturatedMMM: | |
) | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def mmm_with_fourier_features() -> DelayedSaturatedMMM: | ||
return DelayedSaturatedMMM( | ||
date_column="date", | ||
channel_columns=["channel_1", "channel_2"], | ||
adstock_max_lag=4, | ||
control_columns=["control_1", "control_2"], | ||
yearly_seasonality=2, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def mmm_fitted( | ||
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series | ||
|
@@ -96,6 +107,16 @@ def mmm_fitted( | |
return mmm | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def mmm_fitted_with_fourier_features( | ||
mmm_with_fourier_features, toy_X: pd.DataFrame, toy_y: pd.Series | ||
) -> DelayedSaturatedMMM: | ||
mmm_with_fourier_features.fit( | ||
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2 | ||
) | ||
return mmm_with_fourier_features | ||
|
||
|
||
class TestDelayedSaturatedMMM: | ||
def test_save_load_with_not_serializable_model_config( | ||
self, model_config_requiring_serialization, toy_X, toy_y | ||
|
@@ -456,7 +477,8 @@ def test_data_setter(self, toy_X, toy_y): | |
with pytest.raises(TypeError): | ||
base_delayed_saturated_mmm._data_setter(toy_X, y_incorrect) | ||
|
||
with pytest.raises(RuntimeError): | ||
# Missing the date column | ||
with pytest.raises(KeyError): | ||
X_wrong_df = pd.DataFrame( | ||
{"column1": np.random.rand(135), "column2": np.random.rand(135)} | ||
) | ||
|
@@ -467,12 +489,10 @@ def test_data_setter(self, toy_X, toy_y): | |
except Exception as e: | ||
pytest.fail(f"_data_setter failed with error {e}") | ||
|
||
try: | ||
with pytest.raises(TypeError, match="X must be a pandas DataFrame"): | ||
base_delayed_saturated_mmm._data_setter( | ||
X_correct_ndarray, y_correct_ndarray | ||
) | ||
except Exception as e: | ||
pytest.fail(f"_data_setter failed with error {e}") | ||
|
||
def test_save_load(self, mmm_fitted): | ||
model = mmm_fitted | ||
|
@@ -515,6 +535,61 @@ def mock_property(self): | |
DelayedSaturatedMMM.load("test_model") | ||
os.remove("test_model") | ||
|
||
@pytest.mark.parametrize( | ||
"model_name", ["mmm_fitted", "mmm_fitted_with_fourier_features"] | ||
) | ||
@pytest.mark.parametrize( | ||
"new_dates", | ||
[ | ||
# 2021-12-31 is the last date in the toy data | ||
# Old and New dates | ||
pd.date_range(start="2021-11-01", end="2022-03-01", freq="W-MON"), | ||
# Only Old dates | ||
pd.date_range(start="2019-06-01", end="2021-12-31", freq="W-MON"), | ||
# Only New dates | ||
pd.date_range(start="2022-01-01", end="2022-03-01", freq="W-MON"), | ||
# Less than the adstock_max_lag (4) of the model | ||
pd.date_range(start="2022-01-01", freq="W-MON", periods=1), | ||
], | ||
) | ||
def test_new_data_predictions( | ||
self, | ||
model_name: str, | ||
new_dates: pd.DatetimeIndex, | ||
request, | ||
) -> None: | ||
mmm = request.getfixturevalue(model_name) | ||
n = new_dates.size | ||
X_pred = pd.DataFrame( | ||
{ | ||
"date": new_dates, | ||
"channel_1": rng.integers(low=0, high=400, size=n), | ||
"channel_2": rng.integers(low=0, high=50, size=n), | ||
"control_1": rng.gamma(shape=1000, scale=500, size=n), | ||
"control_2": rng.gamma(shape=100, scale=5, size=n), | ||
"other_column_1": rng.integers(low=0, high=100, size=n), | ||
"other_column_2": rng.normal(loc=0, scale=1, size=n), | ||
} | ||
) | ||
Comment on lines
+563
to
+573
There was a problem hiding this comment. Choose a reason for hiding this comment The reason will be displayed to describe this comment to others. Learn more. We could have this as a fixture and split this test, as it has many |
||
|
||
pp_without = mmm.predict_posterior( | ||
X_pred=X_pred, include_last_observations=False | ||
) | ||
pp_with = mmm.predict_posterior(X_pred=X_pred, include_last_observations=True) | ||
|
||
assert pp_without.coords.equals(pp_with.coords) | ||
|
||
posterior_predictive = mmm.sample_posterior_predictive( | ||
X_pred=X_pred, extend_idata=False, combined=True | ||
) | ||
pd.testing.assert_index_equal( | ||
pd.DatetimeIndex(posterior_predictive.coords["date"]), new_dates | ||
) | ||
assert posterior_predictive["likelihood"].shape[0] == new_dates.size | ||
|
||
posterior_predictive_mean = mmm.predict(X_pred=X_pred) | ||
assert posterior_predictive_mean.shape[0] == new_dates.size | ||
|
||
@pytest.mark.parametrize( | ||
argnames="model_config", | ||
argvalues=[ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm viewing this as required in order to access the columns. Changing the type hint makes mypy mad
If there are any other suggestions, let me know
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that in 99% of cases people using this package use pandas, so it is OK for me.