Skip to content

Commit

Permalink
add test to check for inference method names
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Apr 15, 2024
1 parent 5e059c2 commit cd80120
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion tests/test_alternative_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,17 @@ def data_n100():
return data


def test_inference_method_names():
names = bmb.inference_methods.names

# Check PyMC inference method family
assert "mcmc" in names["pymc"].keys()
assert "vi" in names["pymc"].keys()

# Check bayeu inference method family. Currently, only MCMC methods are supported
assert "mcmc" in names["bayeux"].keys()


def test_laplace():
data = pd.DataFrame(np.repeat((0, 1), (30, 60)), columns=["w"])
priors = {"Intercept": bmb.Prior("Uniform", lower=0, upper=1)}
Expand All @@ -56,7 +67,7 @@ def test_vi():
(mode_n.item(), std_n.item()), (mode_a.item(), std_a.item()), decimal=2
)

#

@pytest.mark.parametrize("sampler", MCMC_METHODS_FILTERED)
def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler):
model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli")
Expand Down

0 comments on commit cd80120

Please sign in to comment.