Skip to content

Commit

Permalink
tests for JAX based samplers except TFP
Browse files Browse the repository at this point in the history
  • Loading branch information
GStechschulte committed Mar 1, 2024
1 parent 1147d96 commit cdcf104
Showing 1 changed file with 10 additions and 20 deletions.
30 changes: 10 additions & 20 deletions tests/test_alternative_samplers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import bambi as bmb
import bayeux as bx
import numpy as np
import pandas as pd

import pytest


# Tensorflow probability based samplers do not work with Bambi models yet.
MCMC_METHODS = [getattr(bx.mcmc, k).name for k in bx.mcmc.__all__ if "tfp" not in getattr(bx.mcmc, k).name ]

@pytest.fixture(scope="module")
def data_n100():
size = 100
Expand Down Expand Up @@ -52,27 +56,13 @@ def test_vi():
)


@pytest.mark.parametrize(
"args",
[
("mcmc", {}),
("numpyro_nuts", {"chain_method": "vectorized"}),
("blackjax_nuts", {"chain_method": "vectorized"}),
],
)
def test_logistic_regression_categoric_alternative_samplers(data_n100, args):
@pytest.mark.parametrize("sampler", MCMC_METHODS)
def test_logistic_regression_categoric_alternative_samplers(data_n100, sampler):
model = bmb.Model("b1 ~ n1", data_n100, family="bernoulli")
model.fit(tune=50, draws=50, inference_method=args[0], **args[1])
model.fit(inference_method=sampler)


@pytest.mark.parametrize(
"args",
[
("mcmc", {}),
("numpyro_nuts", {"chain_method": "vectorized"}),
("blackjax_nuts", {"chain_method": "vectorized"}),
],
)
def test_regression_alternative_samplers(data_n100, args):
@pytest.mark.parametrize("sampler", MCMC_METHODS)
def test_regression_alternative_samplers(data_n100, sampler):
model = bmb.Model("n1 ~ n2", data_n100)
model.fit(tune=50, draws=50, inference_method=args[0], **args[1])
model.fit(inference_method=sampler)

0 comments on commit cdcf104

Please sign in to comment.