Skip to content

Commit

Permalink
make tests run faster
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Aug 18, 2023
1 parent 4c9f4ca commit 9b0a247
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/mmm/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.fixture(scope="module")
def toy_X() -> pd.DataFrame:
date_data: pd.DatetimeIndex = pd.date_range(
start="2019-06-01", end="2021-12-31", freq="W-MON"
start="2020-06-01", end="2021-12-31", freq="W-MON"
)

n: int = date_data.size
Expand Down Expand Up @@ -62,19 +62,21 @@ class ToyMMM(BaseDelayedSaturatedMMM, MaxAbsScaleTarget):
mmm = ToyMMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
adstock_max_lag=4,
adstock_max_lag=2,
)
elif control == "with_controls":
mmm = ToyMMM(
date_column="date",
adstock_max_lag=4,
adstock_max_lag=2,
control_columns=["control_1", "control_2"],
channel_columns=["channel_1", "channel_2"],
)
# fit the model
mmm.fit(
X=toy_X,
y=toy_y,
chains=2,
draws=20,
)
mmm._prior_predictive = mmm.prior_predictive
mmm._fit_result = mmm.fit_result
Expand Down

0 comments on commit 9b0a247

Please sign in to comment.