-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support New Data for MMM model #444
Changes from 7 commits
27f8ce2
72af239
6af2fda
1fbb966
b3484fc
3b292e3
4626348
2cd0b31
f99600b
c49ab8e
f85a375
f8f2a8e
91a15c6
f41a1f2
ba62ce1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,8 +99,10 @@ def _generate_and_preprocess_model_data( # type: ignore | |
""" | ||
date_data = X[self.date_column] | ||
channel_data = X[self.channel_columns] | ||
coords: Dict[str, Any] = { | ||
self.coords_mutable: Dict[str, Any] = { | ||
"date": date_data, | ||
} | ||
coords: Dict[str, Any] = { | ||
"channel": self.channel_columns, | ||
} | ||
|
||
|
@@ -174,7 +176,9 @@ def build_model( | |
""" | ||
model_config = self.model_config | ||
self._generate_and_preprocess_model_data(X, y) | ||
with pm.Model(coords=self.model_coords) as self.model: | ||
with pm.Model( | ||
coords=self.model_coords, coords_mutable=self.coords_mutable | ||
) as self.model: | ||
channel_data_ = pm.MutableData( | ||
name="channel_data", | ||
value=self.preprocessed_data["X"][self.channel_columns].to_numpy(), | ||
|
@@ -484,19 +488,41 @@ def _data_setter( | |
------- | ||
None | ||
""" | ||
if not isinstance(X, pd.DataFrame): | ||
raise TypeError( | ||
"X must be a pandas DataFrame in order to access the columns" | ||
) | ||
new_channel_data: Optional[np.ndarray] = None | ||
coords = {"date": X[self.date_column].to_numpy()} | ||
|
||
if isinstance(X, pd.DataFrame): | ||
try: | ||
new_channel_data = X[self.channel_columns].to_numpy() | ||
except KeyError as e: | ||
raise RuntimeError("New data must contain channel_data!", e) | ||
elif isinstance(X, np.ndarray): | ||
new_channel_data = X | ||
else: | ||
raise TypeError("X must be either a pandas DataFrame or a numpy array") | ||
try: | ||
new_channel_data = X[self.channel_columns].to_numpy() | ||
except KeyError as e: | ||
raise RuntimeError("New data must contain channel_data!", e) | ||
Comment on lines
+627
to
+628
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be the error raises by any missing keys of the dataframe? For instance, date_column or control_columns There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think so! I suggest that, for this iteration, we keep this strict and then wait for feedback. |
||
|
||
data: Dict[str, Union[np.ndarray, Any]] = {"channel_data": new_channel_data} | ||
def identity(x): | ||
return x | ||
|
||
channel_transformation = ( | ||
identity | ||
if not hasattr(self, "channel_transformer") | ||
else self.channel_transformer.transform | ||
) | ||
data: Dict[str, Union[np.ndarray, Any]] = { | ||
"channel_data": channel_transformation(new_channel_data) | ||
} | ||
|
||
if self.control_columns is not None: | ||
control_data = X[self.control_columns].to_numpy() | ||
control_transformation = ( | ||
identity | ||
if not hasattr(self, "control_transformer") | ||
else self.control_transformer.transform | ||
) | ||
data["control_data"] = control_transformation(control_data) | ||
|
||
if hasattr(self, "fourier_columns"): | ||
data["fourier_data"] = self._get_fourier_models_data(X) | ||
|
||
if y is not None: | ||
if isinstance(y, pd.Series): | ||
|
@@ -507,9 +533,12 @@ def _data_setter( | |
data["target"] = y | ||
else: | ||
raise TypeError("y must be either a pandas Series or a numpy array") | ||
else: | ||
dtype = self.preprocessed_data["y"].dtype # type: ignore | ||
data["target"] = np.zeros(X.shape[0], dtype=dtype) # type: ignore | ||
|
||
with self.model: | ||
pm.set_data(data) | ||
pm.set_data(data, coords=coords) | ||
|
||
@classmethod | ||
def _model_config_formatting(cls, model_config: Dict) -> Dict: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -80,6 +80,17 @@ def mmm() -> DelayedSaturatedMMM: | |
) | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def mmm_with_fourier_features() -> DelayedSaturatedMMM: | ||
return DelayedSaturatedMMM( | ||
date_column="date", | ||
channel_columns=["channel_1", "channel_2"], | ||
adstock_max_lag=4, | ||
control_columns=["control_1", "control_2"], | ||
yearly_seasonality=2, | ||
) | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def mmm_fitted( | ||
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series | ||
|
@@ -88,6 +99,16 @@ def mmm_fitted( | |
return mmm | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def mmm_fitted_with_fourier_features( | ||
mmm_with_fourier_features, toy_X: pd.DataFrame, toy_y: pd.Series | ||
) -> DelayedSaturatedMMM: | ||
mmm_with_fourier_features.fit( | ||
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2 | ||
) | ||
return mmm_with_fourier_features | ||
|
||
|
||
class TestDelayedSaturatedMMM: | ||
def test_save_load_with_not_serializable_model_config( | ||
self, model_config_requiring_serialization, toy_X, toy_y | ||
|
@@ -448,9 +469,21 @@ def test_data_setter(self, toy_X, toy_y): | |
with pytest.raises(TypeError): | ||
base_delayed_saturated_mmm._data_setter(toy_X, y_incorrect) | ||
|
||
# Missing the date column | ||
with pytest.raises(KeyError): | ||
X_wrong_df = pd.DataFrame( | ||
{"column_1": np.random.rand(135), "column_2": np.random.rand(135)} | ||
) | ||
base_delayed_saturated_mmm._data_setter(X_wrong_df, toy_y) | ||
|
||
# Missing a channel column (and not date) | ||
with pytest.raises(RuntimeError): | ||
X_wrong_df = pd.DataFrame( | ||
{"column1": np.random.rand(135), "column2": np.random.rand(135)} | ||
{ | ||
"date": pd.to_datetime("2023-01-01"), | ||
"column1": np.random.rand(135), | ||
"column2": np.random.rand(135), | ||
} | ||
) | ||
base_delayed_saturated_mmm._data_setter(X_wrong_df, toy_y) | ||
|
||
|
@@ -459,12 +492,10 @@ def test_data_setter(self, toy_X, toy_y): | |
except Exception as e: | ||
pytest.fail(f"_data_setter failed with error {e}") | ||
|
||
try: | ||
with pytest.raises(TypeError, match="X must be a pandas DataFrame"): | ||
base_delayed_saturated_mmm._data_setter( | ||
X_correct_ndarray, y_correct_ndarray | ||
) | ||
except Exception as e: | ||
pytest.fail(f"_data_setter failed with error {e}") | ||
|
||
def test_save_load(self, mmm_fitted): | ||
model = mmm_fitted | ||
|
@@ -506,3 +537,55 @@ def mock_property(self): | |
): | ||
DelayedSaturatedMMM.load("test_model") | ||
os.remove("test_model") | ||
|
||
@pytest.mark.parametrize( | ||
"model_name", ["mmm_fitted", "mmm_fitted_with_fourier_features"] | ||
) | ||
@pytest.mark.parametrize( | ||
"new_dates", | ||
[ | ||
# 2021-12-31 is the last date in the toy data | ||
# Old and New dates | ||
pd.date_range(start="2021-11-01", end="2022-03-01", freq="W-MON"), | ||
# Only Old dates | ||
pd.date_range(start="2019-06-01", end="2021-12-31", freq="W-MON"), | ||
# Only New dates | ||
pd.date_range(start="2022-01-01", end="2022-03-01", freq="W-MON"), | ||
# Less than the adstock_max_lag (4) of the model | ||
pd.date_range(start="2022-01-01", freq="W-MON", periods=1), | ||
], | ||
) | ||
def test_new_data_predictions( | ||
self, | ||
model_name: str, | ||
mmm_fitted: DelayedSaturatedMMM, | ||
new_dates: pd.DatetimeIndex, | ||
request, | ||
) -> None: | ||
mmm = request.getfixturevalue(model_name) | ||
n = new_dates.size | ||
new_X = pd.DataFrame( | ||
{ | ||
"date": new_dates, | ||
"channel_1": rng.integers(low=0, high=400, size=n), | ||
"channel_2": rng.integers(low=0, high=50, size=n), | ||
"control_1": rng.gamma(shape=1000, scale=500, size=n), | ||
"control_2": rng.gamma(shape=100, scale=5, size=n), | ||
"other_column_1": rng.integers(low=0, high=100, size=n), | ||
"other_column_2": rng.normal(loc=0, scale=1, size=n), | ||
} | ||
) | ||
|
||
with pytest.raises( | ||
TypeError, | ||
match=r"The DType <class 'numpy.dtype\[datetime64\]'> could not be promoted by", | ||
): | ||
mmm.predict_posterior(X_pred=new_X) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. as noted in the comment, the ModelBuilder method will need an override if this is to work with dates in the input DataFrame |
||
|
||
posterior_predictive = mmm.sample_posterior_predictive( | ||
X_pred=new_X, extend_idata=False, combined=True | ||
) | ||
pd.testing.assert_index_equal( | ||
pd.DatetimeIndex(posterior_predictive.coords["date"]), new_dates | ||
) | ||
assert posterior_predictive["likelihood"].shape[0] == new_dates.size |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm viewing this as required in order to access the columns. Changing the type hint makes mypy mad
If there are any other suggestions, let me know
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think 99% of the cases people using this package use pandas, so is ok for me.