From cd80120ee594b7cb8817c7714e4fb0db6c30eaa6 Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Mon, 15 Apr 2024 18:20:59 +0200 Subject: [PATCH] add test to check for inference method names --- tests/test_alternative_samplers.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index 6222f3df3..f75cd4f6b 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -30,6 +30,17 @@ def data_n100(): return data +def test_inference_method_names(): + names = bmb.inference_methods.names + + # Check PyMC inference method family + assert "mcmc" in names["pymc"].keys() + assert "vi" in names["pymc"].keys() + + # Check bayeu inference method family. Currently, only MCMC methods are supported + assert "mcmc" in names["bayeux"].keys() + + def test_laplace(): data = pd.DataFrame(np.repeat((0, 1), (30, 60)), columns=["w"]) priors = {"Intercept": bmb.Prior("Uniform", lower=0, upper=1)} @@ -56,7 +67,7 @@ def test_vi(): (mode_n.item(), std_n.item()), (mode_a.item(), std_a.item()), decimal=2 ) -# + @pytest.mark.parametrize("sampler", MCMC_METHODS_FILTERED) def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli")