-
Notifications
You must be signed in to change notification settings - Fork 192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
mock MMM posterior with prior #582
Closed
Closed
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
d5b5176
mock the posterior with prior
wd60622 c4f3884
rename and mock in test_plotting
wd60622 f085838
remove tests that wont work with prior as posterior
wd60622 b30f9d3
Remove scale_preserving_logistic_saturation function (#585)
ulfaslak 6e7b701
Bump minimum Python version in environment.yml
ricardoV94 29fde57
Remove useless fixtures in test_gamma_gamma
ricardoV94 ce80f57
Make test_save_load lighter
ricardoV94 6993997
Run slow tests in CI
ricardoV94 277c7d1
Make GammaGamma.test_model_convergence more stable
ricardoV94 e7e3ecb
[pre-commit.ci] pre-commit autoupdate
pre-commit-ci[bot] 49398fe
Add test util for setting fake data in CLV models
ricardoV94 b9e61d2
Implement ParetoNBD with covariates
ricardoV94 13e6e9b
Update version.txt
ricardoV94 4f33440
Make data optional in all ParetoNBD methods
ricardoV94 97b9807
Assign coords in ParetoNBDModel
ricardoV94 a6b6b94
Update version.txt
ricardoV94 5c70875
add test setup
wd60622 7d0dd9c
store commit
wd60622 267cef2
store commit
wd60622 824914a
implement slow but actual fit tests
wd60622 4823d39
correct after failing
wd60622 d8c49ce
Merge branch 'main' into mock-mmm-fit
wd60622 873436a
Merge branch 'main' into mock-mmm-fit
wd60622 8513e46
add variables that make up 500
wd60622 34a79c8
add an expected fail
wd60622 0d36fd1
Merge branch 'pymc-labs:main' into mock-mmm-fit
wd60622 ed279c0
add additional difference between changes by masking
wd60622 de5d5d2
Merge branch 'main' into mock-mmm-fit
wd60622 1c86f78
Merge branch 'main' into mock-mmm-fit
wd60622 90ff335
Merge branch 'main' into mock-mmm-fit
wd60622 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import os | ||
from typing import Dict, List, Optional, Union | ||
|
||
import arviz as az | ||
import numpy as np | ||
|
@@ -119,12 +120,34 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM: | |
) | ||
|
||
|
||
def mock_fit(model, X: pd.DataFrame, y: np.ndarray, **kwargs):
    """Stand in for ``model.fit`` by reusing prior draws as the posterior.

    The model is built, the prior predictive is sampled, and the resulting
    ``prior`` group is stored a second time under ``"posterior"`` alongside a
    ``"fit_data"`` group assembled from ``X`` and ``y``.

    Parameters
    ----------
    model
        Object exposing ``build_model``, ``model``, ``preprocess`` and
        ``set_idata_attrs``.
    X : pd.DataFrame
        Feature frame; its index is reused for the ``y`` series.
    y : np.ndarray
        Target values, aligned positionally with ``X``.
    **kwargs
        Forwarded to ``pm.sample_prior_predictive``.

    Returns
    -------
    The same ``model`` instance, with ``idata`` assigned.
    """
    model.build_model(X=X, y=y)

    with model.model:
        prior_idata = pm.sample_prior_predictive(random_seed=rng, **kwargs)

    for name, data in (("X", X), ("y", y)):
        model.preprocess(name, data)

    target = pd.Series(y, index=X.index, name="y")
    fit_data = pd.concat([X, target], axis=1).to_xarray()
    extra_groups = {"posterior": prior_idata.prior, "fit_data": fit_data}
    prior_idata.add_groups(extra_groups)

    model.idata = prior_idata
    model.set_idata_attrs(idata=prior_idata)

    return model
|
||
|
||
@pytest.fixture(scope="module") | ||
def mmm_fitted( | ||
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series | ||
) -> DelayedSaturatedMMM: | ||
mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2, random_seed=rng) | ||
return mmm | ||
return mock_fit(mmm, toy_X, toy_y.to_numpy()) | ||
|
||
|
||
@pytest.fixture(scope="module") | ||
|
@@ -142,10 +165,7 @@ def mmm_fitted_with_fourier_features( | |
toy_X: pd.DataFrame, | ||
toy_y: pd.Series, | ||
) -> DelayedSaturatedMMM: | ||
mmm_with_fourier_features.fit( | ||
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2, random_seed=rng | ||
) | ||
return mmm_with_fourier_features | ||
return mock_fit(mmm_with_fourier_features, toy_X, toy_y.to_numpy()) | ||
|
||
|
||
class TestDelayedSaturatedMMM: | ||
|
@@ -173,9 +193,7 @@ def deep_equal(dict1, dict2): | |
adstock_max_lag=4, | ||
model_config=model_config_requiring_serialization, | ||
) | ||
model.fit( | ||
toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng | ||
) | ||
model = mock_fit(model, toy_X, toy_y.to_numpy()) | ||
model.save("test_save_load") | ||
model2 = DelayedSaturatedMMM.load("test_save_load") | ||
assert model.date_column == model2.date_column | ||
|
@@ -296,9 +314,6 @@ def test_init( | |
) | ||
|
||
def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None: | ||
draws: int = 100 | ||
chains: int = 2 | ||
|
||
mmm = BaseDelayedSaturatedMMM( | ||
date_column="date", | ||
channel_columns=["channel_1", "channel_2"], | ||
|
@@ -311,35 +326,32 @@ def test_fit(self, toy_X: pd.DataFrame, toy_y: pd.Series) -> None: | |
assert mmm.model_config is not None | ||
n_channel: int = len(mmm.channel_columns) | ||
n_control: int = len(mmm.control_columns) | ||
mmm.fit( | ||
X=toy_X, | ||
y=toy_y, | ||
target_accept=0.81, | ||
draws=draws, | ||
chains=chains, | ||
random_seed=rng, | ||
) | ||
fourier_terms: int = 2 * mmm.yearly_seasonality | ||
mmm = mock_fit(mmm, toy_X, toy_y.to_numpy()) | ||
idata: az.InferenceData = mmm.fit_result | ||
|
||
chains = 1 | ||
draws = 500 | ||
assert ( | ||
az.extract(data=idata, var_names=["intercept"], combined=True) | ||
.to_numpy() | ||
.size | ||
== draws * chains | ||
== chains * draws | ||
) | ||
assert az.extract( | ||
data=idata, var_names=["beta_channel"], combined=True | ||
).to_numpy().shape == (n_channel, draws * chains) | ||
).to_numpy().shape == (n_channel, chains * draws) | ||
assert az.extract( | ||
data=idata, var_names=["alpha"], combined=True | ||
).to_numpy().shape == (n_channel, draws * chains) | ||
).to_numpy().shape == (n_channel, chains * draws) | ||
assert az.extract( | ||
data=idata, var_names=["lam"], combined=True | ||
).to_numpy().shape == (n_channel, draws * chains) | ||
).to_numpy().shape == (n_channel, chains * draws) | ||
assert az.extract( | ||
data=idata, var_names=["gamma_control"], combined=True | ||
).to_numpy().shape == ( | ||
n_channel, | ||
draws * chains, | ||
chains * draws, | ||
) | ||
|
||
mean_model_contributions_ts = mmm.compute_mean_contributions_over_time( | ||
|
@@ -522,12 +534,12 @@ def test_get_channel_contributions_forward_pass_grid_shapes( | |
) -> None: | ||
n_channels = len(mmm_fitted.channel_columns) | ||
data_range = mmm_fitted.X.shape[0] | ||
draws = 3 | ||
chains = 2 | ||
grid_size = 2 | ||
contributions = mmm_fitted.get_channel_contributions_forward_pass_grid( | ||
start=0, stop=1.5, num=grid_size | ||
) | ||
chains = 1 | ||
draws = 500 | ||
assert contributions.shape == ( | ||
grid_size, | ||
chains, | ||
|
@@ -566,8 +578,10 @@ def test_data_setter(self, toy_X, toy_y): | |
channel_columns=["channel_1", "channel_2"], | ||
adstock_max_lag=4, | ||
) | ||
base_delayed_saturated_mmm.fit( | ||
X=toy_X, y=toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng | ||
base_delayed_saturated_mmm = mock_fit( | ||
base_delayed_saturated_mmm, | ||
toy_X, | ||
toy_y.to_numpy(), | ||
) | ||
|
||
X_correct_ndarray = np.random.randint(low=0, high=100, size=(135, 2)) | ||
|
@@ -626,9 +640,7 @@ def mock_property(self): | |
) | ||
|
||
# Check that the property returns the new value | ||
DSMMM.fit( | ||
toy_X, toy_y, target_accept=0.81, draws=100, chains=2, random_seed=rng | ||
) | ||
DSMMM = mock_fit(DSMMM, toy_X, toy_y.to_numpy()) | ||
DSMMM.save("test_model") | ||
# Apply the monkeypatch for the property | ||
monkeypatch.setattr(DelayedSaturatedMMM, "id", property(mock_property)) | ||
|
@@ -806,12 +818,6 @@ def test_new_data_predict_method( | |
|
||
assert isinstance(posterior_predictive_mean, np.ndarray) | ||
assert posterior_predictive_mean.shape[0] == new_dates.size | ||
# Original scale constraint | ||
assert np.all(posterior_predictive_mean >= 0) | ||
|
||
# Domain kept close | ||
lower, upper = np.quantile(a=posterior_predictive_mean, q=[0.025, 0.975], axis=0) | ||
assert lower < toy_y.mean() < upper | ||
|
||
|
||
def test_get_valid_distribution(mmm): | ||
|
@@ -883,6 +889,11 @@ def test_new_spend_contributions(mmm_fitted) -> None: | |
|
||
|
||
def test_new_spend_contributions_prior_error(mmm) -> None: | ||
prior_index = [i for i, group in enumerate(mmm.idata._groups) if group == "prior"][ | ||
0 | ||
] | ||
mmm.idata._groups.pop(prior_index) | ||
|
||
new_spend = np.ones(len(mmm.channel_columns)) | ||
match = "sample_prior_predictive" | ||
with pytest.raises(RuntimeError, match=match): | ||
|
@@ -956,6 +967,109 @@ def test_plot_new_spend_contributions_prior_select_channels( | |
assert isinstance(ax, plt.Axes) | ||
|
||
|
||
@pytest.fixture(scope="module")
def fixed_model_parameters() -> dict[str, Union[float, list[float]]]:
    """Return the known parameter values used to generate synthetic targets.

    Scalar parameters are floats; the remaining entries are two-element lists.
    """
    return dict(
        intercept=5.0,
        beta_channel=[0.15, 0.5],
        alpha=[0.5, 0.5],
        lam=[0.5, 0.5],
        likelihood_sigma=0.25,
        gamma_control=[0.0001, 0.005],
    )
|
||
|
||
def random_mask(df: pd.DataFrame, mask_value: float = 0.0) -> pd.DataFrame:
    """Randomly mask cells of ``df``, keeping each cell with probability 0.25.

    A 0/1 array with the shape of ``df`` is drawn from the module-level
    ``rng``. Cells drawn as 1 keep their value; cells drawn as 0 are masked.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose cells are masked element-wise.
    mask_value : float, default 0.0
        Value written into masked cells. With the default, the result is the
        plain element-wise product ``df * mask``, exactly as before.

    Returns
    -------
    pd.DataFrame
        A new frame with the same shape as ``df``.
    """
    keep = rng.choice([0, 1], size=df.shape, p=[0.75, 0.25])

    # Default path: multiplying by the 0/1 mask zeroes the masked cells.
    masked = df.mul(keep)

    # Previously ``mask_value`` was accepted but ignored; honour it when the
    # caller asks for a fill other than zero.
    if mask_value != 0:
        masked = masked.where(keep.astype(bool), mask_value)

    return masked
|
||
|
||
@pytest.fixture(scope="module")
def masked_toy_X(toy_X) -> pd.DataFrame:
    """Return ``toy_X`` with its non-date columns passed through ``random_mask``.

    ``date`` is moved into the index so it is not masked, then restored as a
    column afterwards.
    """
    indexed = toy_X.set_index("date")
    masked = random_mask(indexed)
    return masked.reset_index()
Comment on lines
+982
to
+991
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The rationale here is that since all of the toy_X columns are pretty much white noise, the generated y is also pretty much white noise and the model doesn't fit well. Here there are some clear stops in toy_X, making the y notably depend on the covariates. |
||
|
||
|
||
@pytest.fixture(scope="module")
def model_generated_y(mmm, masked_toy_X, fixed_model_parameters) -> np.ndarray:
    """Draw one target series from the model with its parameters pinned.

    The model is first built against a placeholder target of ones, then the
    values in ``fixed_model_parameters`` are substituted via ``pm.do`` and a
    single draw of ``"y"`` is taken.
    """
    n_rows = len(masked_toy_X)
    mmm.build_model(masked_toy_X, np.ones(n_rows))

    pinned_model = pm.do(mmm.model, fixed_model_parameters)
    y_variable = pinned_model["y"]
    return pm.draw(y_variable, random_seed=rng)
|
||
|
||
@pytest.fixture(scope="module")
def actually_fit_mmm(mmm, masked_toy_X, model_generated_y) -> DelayedSaturatedMMM:
    """Run a real ``fit`` of ``mmm`` on the masked features and generated target.

    The ``mmm`` fixture object is fitted in place and returned.
    """
    mmm.fit(X=masked_toy_X, y=model_generated_y, random_seed=rng)
    return mmm
|
||
|
||
@pytest.mark.slow
def test_mmm_sampling_stats(actually_fit_mmm) -> None:
    """The fitted model's ``sample_stats.diverging`` must sum to zero."""
    n_diverging = actually_fit_mmm.idata.sample_stats.diverging.sum()

    assert n_diverging == 0
|
||
|
||
@pytest.mark.slow
def test_mmm_channel_contributions_positive(actually_fit_mmm) -> None:
    """Every ``channel_contributions`` value in the fit result is non-negative."""
    fit_result = actually_fit_mmm.fit_result

    assert (fit_result["channel_contributions"] >= 0).all()
|
||
|
||
@pytest.mark.slow
def test_mmm_mean_predictions_positive(actually_fit_mmm) -> None:
    """Every ``mu`` value in the fit result is non-negative.

    Not a strict requirement of the model; it follows from the parameter
    values chosen for the generated data.
    """
    fit_result = actually_fit_mmm.fit_result

    assert (fit_result["mu"] >= 0).all()
|
||
|
||
@pytest.mark.xfail(reason="Constantly failing")
@pytest.mark.slow
def test_mmm_fit_posterior_close_to_actual_parameters(
    actually_fit_mmm, fixed_model_parameters
) -> None:
    """Each generating parameter should fall strictly inside its HDI.

    Scalar parameters are compared directly; list-valued parameters must
    satisfy the bound element-wise.
    """
    fit_result = actually_fit_mmm.fit_result
    assert isinstance(fit_result, xr.Dataset)

    intervals = az.hdi(fit_result)

    for name, true_value in fixed_model_parameters.items():
        bounds = intervals[name]
        low = bounds.sel(hdi="lower").values
        high = bounds.sel(hdi="higher").values

        if not isinstance(true_value, float):
            # List-valued parameter: every component must be inside its bounds.
            assert (low < true_value).all() and (true_value < high).all()
            continue

        assert low < true_value < high
|
||
|
||
@pytest.mark.slow
def test_mmm_fit_better_than_naive_model(actually_fit_mmm, toy_X, toy_y) -> None:
    """Every ``mu`` sample must beat a constant-mean baseline on MSE.

    The baseline predicts the mean of the preprocessed target everywhere; the
    model's error is averaged over the ``date`` dimension only, so the
    comparison is made for each remaining coordinate of ``mu``.
    """
    observed = actually_fit_mmm.preprocessed_data["y"]

    def mse(y_pred, *args, **kwargs):
        squared_error = (observed - y_pred) ** 2
        return squared_error.mean(*args, **kwargs)

    baseline_mse = mse(observed.mean())
    model_mse = mse(actually_fit_mmm.fit_result["mu"], "date")

    assert (model_mse < baseline_mse).all()
|
||
|
||
@pytest.fixture | ||
def df_lift_test() -> pd.DataFrame: | ||
return pd.DataFrame( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it make sense for these tests not to work with the changes? They seem like rather useful checks
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Commented this above. I will add these to the tests from the actual fit.
Using the normal likelihood doesn't guarantee the prior will meet these constraints. They were geared toward a fitted model, not a prior predictive.