Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Nov 14, 2023
1 parent 7dd936d commit e003f73
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,20 +1216,22 @@ def test_basic(self, beetle_data):


def test_wald_family(data_n100):
model = bmb.Model("y1 ~ y2", data_n100, family="wald", link="log")
idata = model.fit(tune=DRAWS, draws=DRAWS)
data_n100["y"] = np.exp(data_n100["y1"])
priors = {"common": bmb.Prior("Normal", mu=0, sigma=1)}
model = bmb.Model("y ~ y2", data_n100, family="wald", link="log", priors=priors)
idata = model.fit(tune=DRAWS, draws=DRAWS, random_seed=1234)

model.predict(idata, kind="mean")
model.predict(idata, kind="pps")

assert (0 < idata.posterior["y1_mean"]).all()
assert (0 < idata.posterior_predictive["y1"]).all()
assert (0 < idata.posterior["y_mean"]).all()
assert (0 < idata.posterior_predictive["y"]).all()

model.predict(idata, kind="mean", data=data_n100.iloc[:20, :])
model.predict(idata, kind="pps", data=data_n100.iloc[:20, :])

assert (0 < idata.posterior["y1_mean"]).all()
assert (0 < idata.posterior_predictive["y1"]).all()
assert (0 < idata.posterior["y_mean"]).all()
assert (0 < idata.posterior_predictive["y"]).all()


def test_predict_include_group_specific():
Expand Down Expand Up @@ -1306,7 +1308,7 @@ def test_predict_new_groups_fail(sleepstudy):
pd.DataFrame({"Days": [1, 2, 3], "Subject": ["x", "y", "z"]}),
),
(
"inhaler",
"inhaler_data",
"rating ~ 1 + period + treat + (1 + treat|subject)",
"categorical",
pd.DataFrame(
Expand Down

0 comments on commit e003f73

Please sign in to comment.