From 9b0a2470ba8f807d9b58dd12b457a4a903db07d2 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Fri, 18 Aug 2023 21:29:35 +0200 Subject: [PATCH] make tests run faster --- tests/mmm/test_plotting.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/mmm/test_plotting.py b/tests/mmm/test_plotting.py index 1a88d1c8..dd1aef73 100644 --- a/tests/mmm/test_plotting.py +++ b/tests/mmm/test_plotting.py @@ -13,7 +13,7 @@ @pytest.fixture(scope="module") def toy_X() -> pd.DataFrame: date_data: pd.DatetimeIndex = pd.date_range( - start="2019-06-01", end="2021-12-31", freq="W-MON" + start="2020-06-01", end="2021-12-31", freq="W-MON" ) n: int = date_data.size @@ -62,12 +62,12 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget): mmm = ToyMMM( date_column="date", channel_columns=["channel_1", "channel_2"], - adstock_max_lag=4, + adstock_max_lag=2, ) elif control == "with_controls": mmm = ToyMMM( date_column="date", - adstock_max_lag=4, + adstock_max_lag=2, control_columns=["control_1", "control_2"], channel_columns=["channel_1", "channel_2"], ) @@ -75,6 +75,8 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget): mmm.fit( X=toy_X, y=toy_y, + chains=2, + draws=20, ) mmm._prior_predictive = mmm.prior_predictive mmm._fit_result = mmm.fit_result