diff --git a/tests/mmm/test_plotting.py b/tests/mmm/test_plotting.py index 4d4a8cf6..b94961e6 100644 --- a/tests/mmm/test_plotting.py +++ b/tests/mmm/test_plotting.py @@ -75,8 +75,11 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget): mmm.fit( X=toy_X, y=toy_y, - prior_predictive=True, - posterior_predictive=True, + sample_kwargs={ + "prior_kwargs": {"samples": 100}, + "prior_predictive": True, + "posterior_predictive": True, + }, ) mmm._prior_predictive = mmm.prior_predictive mmm._fit_result = mmm.fit_result