mock MMM posterior with prior #582

Closed
wants to merge 30 commits from mock-mmm-fit into pymc-labs:main

Changes from all commits (30 commits)

d5b5176
mock the posterior with prior
wd60622 Mar 10, 2024
c4f3884
rename and mock in test_plotting
wd60622 Mar 10, 2024
f085838
remove tests that wont work with prior as posterior
wd60622 Mar 14, 2024
b30f9d3
Remove scale_preserving_logistic_saturation function (#585)
ulfaslak Mar 11, 2024
6e7b701
Bump minimum Python version in environment.yml
ricardoV94 Mar 14, 2024
29fde57
Remove useless fixtures in test_gamma_gamma
ricardoV94 Mar 4, 2024
ce80f57
Make test_save_load lighter
ricardoV94 Mar 14, 2024
6993997
Run slow tests in CI
ricardoV94 Feb 29, 2024
277c7d1
Make GammaGamma.test_model_convergence more stable
ricardoV94 Mar 14, 2024
e7e3ecb
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] Mar 12, 2024
49398fe
Add test util for setting fake data in CLV models
ricardoV94 Mar 1, 2024
b9e61d2
Implement ParetoNBD with covariates
ricardoV94 Mar 1, 2024
13e6e9b
Update version.txt
ricardoV94 Mar 15, 2024
4f33440
Make data optional in all ParetoNBD methods
ricardoV94 Mar 15, 2024
97b9807
Assign coords in ParetoNBDModel
ricardoV94 Mar 15, 2024
a6b6b94
Update version.txt
ricardoV94 Mar 15, 2024
5c70875
add test setup
wd60622 Mar 16, 2024
7d0dd9c
store commit
wd60622 Mar 17, 2024
267cef2
store commit
wd60622 Mar 18, 2024
824914a
implement slow but actual fit tests
wd60622 Mar 21, 2024
4823d39
correct after failing
wd60622 Mar 21, 2024
d8c49ce
Merge branch 'main' into mock-mmm-fit
wd60622 Apr 1, 2024
873436a
Merge branch 'main' into mock-mmm-fit
wd60622 Apr 8, 2024
8513e46
add variables that make up 500
wd60622 Apr 8, 2024
34a79c8
add an expected fail
wd60622 Apr 8, 2024
0d36fd1
Merge branch 'pymc-labs:main' into mock-mmm-fit
wd60622 Apr 11, 2024
ed279c0
add additional difference between changes by masking
wd60622 Apr 11, 2024
de5d5d2
Merge branch 'main' into mock-mmm-fit
wd60622 May 2, 2024
1c86f78
Merge branch 'main' into mock-mmm-fit
wd60622 May 21, 2024
90ff335
Merge branch 'main' into mock-mmm-fit
wd60622 Jun 3, 2024
190 changes: 152 additions & 38 deletions tests/mmm/test_delayed_saturated_mmm.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict, List, Optional, Union

import arviz as az
import numpy as np
@@ -119,12 +120,34 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM:
)


def mock_fit(model, X: pd.DataFrame, y: np.ndarray, **kwargs):
model.build_model(X=X, y=y)

with model.model:
idata = pm.sample_prior_predictive(random_seed=rng, **kwargs)

model.preprocess("X", X)
model.preprocess("y", y)

idata.add_groups(
{
"posterior": idata.prior,
"fit_data": pd.concat(
[X, pd.Series(y, index=X.index, name="y")], axis=1
).to_xarray(),
}
)
model.idata = idata
model.set_idata_attrs(idata=idata)

return model


@pytest.fixture(scope="module")
def mmm_fitted(
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series
) -> DelayedSaturatedMMM:
mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2, random_seed=rng)
return mmm
return mock_fit(mmm, toy_X, toy_y.to_numpy())


@pytest.fixture(scope="module")
@@ -142,10 +165,7 @@ def mmm_fitted_with_fourier_features(
toy_X: pd.DataFrame,
toy_y: pd.Series,
) -> DelayedSaturatedMMM:
mmm_with_fourier_features.fit(
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2, random_seed=rng
)
return mmm_with_fourier_features
return mock_fit(mmm_with_fourier_features, toy_X, toy_y.to_numpy())


class TestDelayedSaturatedMMM:
@@ -173,9 +193,7 @@ def deep_equal(dict1, dict2):
adstock_max_lag=4,
model_config=model_config_requiring_serialization,
)
model.fit(
toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng
)
model = mock_fit(model, toy_X, toy_y.to_numpy())
model.save("test_save_load")
model2 = DelayedSaturatedMMM.load("test_save_load")
assert model.date_column == model2.date_column
@@ -296,9 +314,6 @@ def test_init(
)

def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
draws: int = 100
chains: int = 2

mmm = BaseDelayedSaturatedMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
@@ -311,35 +326,32 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None:
assert mmm.model_config is not None
n_channel: int = len(mmm.channel_columns)
n_control: int = len(mmm.control_columns)
mmm.fit(
X=toy_X,
y=toy_y,
target_accept=0.81,
draws=draws,
chains=chains,
random_seed=rng,
)
fourier_terms: int = 2 * mmm.yearly_seasonality
mmm = mock_fit(mmm, toy_X, toy_y.to_numpy())
idata: az.InferenceData = mmm.fit_result

chains = 1
draws = 500
assert (
az.extract(data=idata, var_names=["intercept"], combined=True)
.to_numpy()
.size
== draws * chains
== chains * draws
)
assert az.extract(
data=idata, var_names=["beta_channel"], combined=True
).to_numpy().shape == (n_channel, draws * chains)
).to_numpy().shape == (n_channel, chains * draws)
assert az.extract(
data=idata, var_names=["alpha"], combined=True
).to_numpy().shape == (n_channel, draws * chains)
).to_numpy().shape == (n_channel, chains * draws)
assert az.extract(
data=idata, var_names=["lam"], combined=True
).to_numpy().shape == (n_channel, draws * chains)
).to_numpy().shape == (n_channel, chains * draws)
assert az.extract(
data=idata, var_names=["gamma_control"], combined=True
).to_numpy().shape == (
n_channel,
draws * chains,
chains * draws,
)

mean_model_contributions_ts = mmm.compute_mean_contributions_over_time(
@@ -522,12 +534,12 @@ def test_get_channel_contributions_forward_pass_grid_shapes(
) -> None:
n_channels = len(mmm_fitted.channel_columns)
data_range = mmm_fitted.X.shape[0]
draws = 3
chains = 2
grid_size = 2
contributions = mmm_fitted.get_channel_contributions_forward_pass_grid(
start=0, stop=1.5, num=grid_size
)
chains = 1
draws = 500
assert contributions.shape == (
grid_size,
chains,
@@ -566,8 +578,10 @@ def test_data_setter(self, toy_X, toy_y):
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
)
base_delayed_saturated_mmm.fit(
X=toy_X, y=toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng
base_delayed_saturated_mmm = mock_fit(
base_delayed_saturated_mmm,
toy_X,
toy_y.to_numpy(),
)

X_correct_ndarray = np.random.randint(low=0, high=100, size=(135, 2))
@@ -626,9 +640,7 @@ def mock_property(self):
)

# Check that the property returns the new value
DSMMM.fit(
toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng
)
DSMMM = mock_fit(DSMMM, toy_X, toy_y.to_numpy())
DSMMM.save("test_model")
# Apply the monkeypatch for the property
monkeypatch.setattr(DelayedSaturatedMMM, "id", property(mock_property))
@@ -806,12 +818,6 @@ def test_new_data_predict_method(

assert isinstance(posterior_predictive_mean, np.ndarray)
assert posterior_predictive_mean.shape[0] == new_dates.size
# Original scale constraint
assert np.all(posterior_predictive_mean >= 0)

# Domain kept close
lower, upper = np.quantile(a=posterior_predictive_mean, q=[0.025, 0.975], axis=0)
assert lower < toy_y.mean() < upper
Comment on lines -809 to -814

Contributor:
Does it make sense for these tests not to work with the changes? They seem like rather useful checks.

Contributor Author:
Commented on this above; I will add these checks to the tests that run an actual fit.
Using the normal likelihood doesn't guarantee the prior will meet these constraints. They were geared toward a fitted model, not a prior predictive. A minimal sketch of why is given below.
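
A minimal, standalone sketch (toy model and names are illustrative, not the MMM itself) of why a Normal likelihood gives no positivity guarantee for prior predictive draws:

import numpy as np
import pymc as pm

# Toy model only: a Normal likelihood is unbounded, so prior predictive
# draws can be negative even when the observed data are non-negative.
with pm.Model() as toy_model:
    intercept = pm.Normal("intercept", mu=0.5, sigma=1.0)
    pm.Normal("y", mu=intercept, sigma=0.25, observed=np.zeros(10))
    prior_idata = pm.sample_prior_predictive(random_seed=42)

all_non_negative = bool((prior_idata.prior_predictive["y"] >= 0).all())
print(all_non_negative)  # typically False: nothing constrains the prior draws to be >= 0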



def test_get_valid_distribution(mmm):
@@ -883,6 +889,11 @@ def test_new_spend_contributions(mmm_fitted) -> None:


def test_new_spend_contributions_prior_error(mmm) -> None:
prior_index = [i for i, group in enumerate(mmm.idata._groups) if group == "prior"][
0
]
mmm.idata._groups.pop(prior_index)

new_spend = np.ones(len(mmm.channel_columns))
match = "sample_prior_predictive"
with pytest.raises(RuntimeError, match=match):
@@ -956,6 +967,109 @@ def test_plot_new_spend_contributions_prior_select_channels(
assert isinstance(ax, plt.Axes)


@pytest.fixture(scope="module")
def fixed_model_parameters() -> dict[str, Union[float, list[float]]]:
return {
"intercept": 5.0,
"beta_channel": [0.15, 0.5],
"alpha": [0.5, 0.5],
"lam": [0.5, 0.5],
"likelihood_sigma": 0.25,
"gamma_control": [0.0001, 0.005],
}


def random_mask(df: pd.DataFrame, mask_value: float = 0.0) -> pd.DataFrame:
shape = df.shape

mask = rng.choice([0, 1], size=shape, p=[0.75, 0.25])
return df.mul(mask)


@pytest.fixture(scope="module")
def masked_toy_X(toy_X) -> pd.DataFrame:
return toy_X.set_index("date").pipe(random_mask).reset_index()
Comment on lines +982 to +991

Contributor Author:
The rationale here is that since all of the toy_X columns are pretty much white noise, the generated y is also pretty much white noise and the model doesn't fit well.

Here there are some clear stops in toy_X, making y depend noticeably on the covariates. A small sketch of the masking idea follows below.
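
As an illustration of the masking idea (a standalone toy frame, not the actual test fixtures), roughly 75% of the spend entries are zeroed out so each channel shows clear on/off periods:

import numpy as np
import pandas as pd

# Standalone illustration of random_mask: zero out ~75% of the spend values
# so the generated target has visible on/off structure tied to the channels.
demo_rng = np.random.default_rng(seed=0)
toy_spend = pd.DataFrame(
    demo_rng.uniform(size=(8, 2)), columns=["channel_1", "channel_2"]
)

mask = demo_rng.choice([0, 1], size=toy_spend.shape, p=[0.75, 0.25])
masked_spend = toy_spend.mul(mask)
print(masked_spend)  # most entries are now exactly 0.0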



@pytest.fixture(scope="module")
def model_generated_y(mmm, masked_toy_X, fixed_model_parameters) -> np.ndarray:
fake_y = np.ones(len(masked_toy_X))
mmm.build_model(masked_toy_X, fake_y)

fixed_model = pm.do(mmm.model, fixed_model_parameters)
return pm.draw(fixed_model["y"], random_seed=rng)


@pytest.fixture(scope="module")
def actually_fit_mmm(mmm, masked_toy_X, model_generated_y) -> DelayedSaturatedMMM:
mmm.fit(masked_toy_X, model_generated_y, random_seed=rng)
return mmm


@pytest.mark.slow
def test_mmm_sampling_stats(actually_fit_mmm) -> None:
idata = actually_fit_mmm.idata

assert idata.sample_stats.diverging.sum() == 0


@pytest.mark.slow
def test_mmm_channel_contributions_positive(actually_fit_mmm) -> None:
contributions = actually_fit_mmm.fit_result["channel_contributions"]

assert (contributions >= 0).all()


@pytest.mark.slow
def test_mmm_mean_predictions_positive(actually_fit_mmm) -> None:
"""Not required technically, but based on the model parameters."""
mean_predictions = actually_fit_mmm.fit_result["mu"]

assert (mean_predictions >= 0).all()


@pytest.mark.xfail(reason="Constantly failing")
@pytest.mark.slow
def test_mmm_fit_posterior_close_to_actual_parameters(
actually_fit_mmm, fixed_model_parameters
) -> None:
posterior = actually_fit_mmm.fit_result

assert isinstance(posterior, xr.Dataset)

hdi = az.hdi(posterior)

for parameter, actual in fixed_model_parameters.items():
hdi_parameter = hdi[parameter]

lower = hdi_parameter.sel(hdi="lower").values
upper = hdi_parameter.sel(hdi="higher").values

if isinstance(actual, float):
assert lower < actual < upper
else:
assert (lower < actual).all() and (actual < upper).all()


@pytest.mark.slow
def test_mmm_fit_better_than_naive_model(actually_fit_mmm, toy_X, toy_y) -> None:
preprocessed_y = actually_fit_mmm.preprocessed_data["y"]

preprocessed_y_mean = preprocessed_y.mean()

def mse(y_pred, *args, **kwargs):
return ((preprocessed_y - y_pred) ** 2).mean(*args, **kwargs)

posterior = actually_fit_mmm.fit_result

mse_mean_model = mse(preprocessed_y_mean)
mse_mmm_model = mse(posterior["mu"], "date")

mmm_sample_is_better = mse_mmm_model < mse_mean_model

assert mmm_sample_is_better.all()


@pytest.fixture
def df_lift_test() -> pd.DataFrame:
return pd.DataFrame(
25 changes: 21 additions & 4 deletions tests/mmm/test_plotting.py
@@ -13,6 +13,7 @@
# limitations under the License.
import numpy as np
import pandas as pd
import pymc as pm
import pytest
from matplotlib import pyplot as plt

@@ -49,6 +50,25 @@ def toy_y(toy_X) -> pd.Series:
return pd.Series(rng.integers(low=0, high=100, size=toy_X.shape[0]))


def mock_fit(model, X: pd.DataFrame, y: np.ndarray, **kwargs):
model.build_model(X=X, y=y)
with model.model:
idata = pm.sample_prior_predictive(random_seed=rng, **kwargs)

idata.add_groups(
{
"posterior": idata.prior,
"fit_data": pd.concat(
[X, pd.Series(y, index=X.index, name="y")], axis=1
).to_xarray(),
}
)
model.idata = idata
model.set_idata_attrs(idata=idata)

return model


class TestBasePlotting:
@pytest.fixture(
scope="module",
@@ -85,10 +105,7 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
channel_columns=["channel_1", "channel_2"],
)
# fit the model
mmm.fit(
X=toy_X,
y=toy_y,
)
mmm = mock_fit(mmm, toy_X, toy_y.to_numpy())
mmm.sample_prior_predictive(toy_X, toy_y, extend_idata=True, combined=True)
mmm.sample_posterior_predictive(toy_X, extend_idata=True, combined=True)
mmm._prior_predictive = mmm.prior_predictive