From cdcf1041e6c655a93ac6cac733361098342866be Mon Sep 17 00:00:00 2001 From: GStechschulte Date: Fri, 1 Mar 2024 16:17:01 +0100 Subject: [PATCH] tests for JAX based samplers except TFP --- tests/test_alternative_samplers.py | 30 ++++++++++-------------------- 1 file changed, 10 insertions(+), 20 deletions(-) diff --git a/tests/test_alternative_samplers.py b/tests/test_alternative_samplers.py index 816755648..a8cfc6805 100644 --- a/tests/test_alternative_samplers.py +++ b/tests/test_alternative_samplers.py @@ -1,10 +1,14 @@ import bambi as bmb +import bayeux as bx import numpy as np import pandas as pd import pytest +# Tensorflow probability based samplers do not work with Bambi models yet. +MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__ if "tfp" not in getattr(bx.mcmc, k).name ] + @pytest.fixture(scope="module") def data_n100(): size = 100 @@ -52,27 +56,13 @@ def test_vi(): ) -@pytest.mark.parametrize( - "args", - [ - ("mcmc", {}), - ("numpyro_nuts", {"chain_method": "vectorized"}), - ("blackjax_nuts", {"chain_method": "vectorized"}), - ], -) -def test_logistic_regression_categoric_alternative_samplers(data_n100, args): +@pytest.mark.parametrize("sampler", MCMC_METHODS) +def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler): model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli") - model.fit(tune=50, draws=50, inference_method=args[0], **args[1]) + model.fit(inference_method=sampler) -@pytest.mark.parametrize( - "args", - [ - ("mcmc", {}), - ("numpyro_nuts", {"chain_method": "vectorized"}), - ("blackjax_nuts", {"chain_method": "vectorized"}), - ], -) -def test_regression_alternative_samplers(data_n100, args): +@pytest.mark.parametrize("sampler", MCMC_METHODS) +def test_regression_alternative_samplers(data_n100, sampler): model = bmb.Model("n1 ~ n2", data_n100) - model.fit(tune=50, draws=50, inference_method=args[0], **args[1]) + model.fit(inference_method=sampler)