Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use bayeux to access a wide range of samplers #775

Merged
merged 34 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
44966d3
use bayeux to access a wide range of samplers
GStechschulte Feb 4, 2024
061a1b0
use bayeux to access a wide range of samplers
GStechschulte Feb 4, 2024
8afe534
add notebook links to family table (#774)
GStechschulte Feb 4, 2024
9f1d9d1
access methods programatically
GStechschulte Feb 5, 2024
9b42fc2
clean bayeux idata to be consistent with pymc model coords
GStechschulte Feb 10, 2024
91ce2a0
rename alternative sampler args in tests
GStechschulte Feb 10, 2024
89a2aee
change docstring to reflect bayeux sampler names
GStechschulte Feb 10, 2024
d6058ad
bayeux dependencies are numpyro/jax/jaxlib/blackjax
GStechschulte Feb 10, 2024
722c8b5
rename idata coords and dims to PyMC model
GStechschulte Feb 19, 2024
ccc2877
add JAX based sampler dependencies
GStechschulte Feb 19, 2024
74b4e8b
Update code of conduct (#783)
tomicapretto Feb 21, 2024
47bb161
[WIP] Fix HSGP predictions (#780)
tomicapretto Feb 29, 2024
9f6fc2a
bayeux 0.1.9 updates
GStechschulte Mar 1, 2024
10bb508
bump bayeux version
GStechschulte Mar 1, 2024
f7bf97f
remove TFP methods, optimizers, and resolve pylint errors
GStechschulte Mar 1, 2024
1147d96
alternative backends docs
GStechschulte Mar 1, 2024
cdcf104
tests for JAX based samplers except TFP
GStechschulte Mar 1, 2024
bf1e478
add TFP backend example
GStechschulte Mar 1, 2024
27a41e6
add TFP MCMC methods
GStechschulte Mar 1, 2024
98f7da8
don't use flowmc, chees, meads for categorical model
GStechschulte Mar 3, 2024
4ae1092
call model.backend.inference_methods to show list of samplers
GStechschulte Mar 3, 2024
81936a2
docstring changes
GStechschulte Mar 3, 2024
f6d8894
inference_methods attribute and change JAX random seed
GStechschulte Mar 3, 2024
02d1df6
Add FutureWarning to inference_method parameter
GStechschulte Mar 4, 2024
dd278d4
black formatting and resolve pylint errors
GStechschulte Mar 4, 2024
b0e94a4
fix package name
GStechschulte Mar 4, 2024
65fd945
drop 3.9 and add 3.12 to testing matrix
GStechschulte Mar 19, 2024
4712f1a
change Python versions in requires-python and target-version
GStechschulte Mar 19, 2024
d508214
remove python 3.11 black target-version
GStechschulte Mar 19, 2024
1d05684
pin requires-python to <3.13
GStechschulte Mar 19, 2024
f06715e
pip upgrade setuptools
GStechschulte Mar 19, 2024
ef575d3
Bump PyMC to 5.12
tomicapretto Mar 28, 2024
9bf90a6
Upgrade black and pylint
tomicapretto Mar 28, 2024
9f9d769
remove upgrading of setup tools
GStechschulte Mar 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions bambi/backend/pymc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
import re
import traceback


Expand All @@ -23,6 +24,23 @@
__version__ = version("bambi")


PYMC_SAMPLERS = ["mcmc"]
BAYEUX_SAMPLERS = [
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
"blackjax_hmc",
"blackjax_chees_hmc",
"blackjax_meads_hmc",
"blackjax_nuts",
"blackjax_hmc_pathfinder",
"blackjax_nuts_pathfinder",
"flowmc_rqspline_hmc",
"flowmc_rqspline_mala",
"flowmc_realnvp_hmc",
"flowmc_realnvp_mala",
"numpyro_hmc",
"numpyro_nuts",
]


class PyMCModel:
"""PyMC model-fitting backend."""

Expand Down Expand Up @@ -95,7 +113,7 @@ def run(
"""Run PyMC sampler."""
inference_method = inference_method.lower()
# NOTE: Methods return different types of objects (idata, approximation, and dictionary)
if inference_method in ["mcmc", "nuts_numpyro", "nuts_blackjax"]:
if inference_method in (PYMC_SAMPLERS + BAYEUX_SAMPLERS):
result = self._run_mcmc(
draws,
tune,
Expand Down Expand Up @@ -169,8 +187,8 @@ def _run_mcmc(
sampler_backend="mcmc",
**kwargs,
):
with self.model:
if sampler_backend == "mcmc":
if sampler_backend in PYMC_SAMPLERS:
with self.model:
try:
idata = pm.sample(
draws=draws,
Expand Down Expand Up @@ -203,43 +221,27 @@ def _run_mcmc(
random_seed=random_seed,
**kwargs,
)
idata_from = "pymc"
else:
raise
elif sampler_backend == "nuts_numpyro":
import pymc.sampling_jax # pylint: disable=import-outside-toplevel

if not chains:
# sample_numpyro_nuts does not handle chains = None like pm.sample does
chains = 4
idata = pymc.sampling_jax.sample_numpyro_nuts(
draws=draws,
tune=tune,
chains=chains,
random_seed=random_seed,
**kwargs,
)
elif sampler_backend == "nuts_blackjax":
import pymc.sampling_jax # pylint: disable=import-outside-toplevel

# sample_blackjax_nuts does not handle chains = None like pm.sample does
if not chains:
chains = 4
idata = pymc.sampling_jax.sample_blackjax_nuts(
draws=draws,
tune=tune,
chains=chains,
random_seed=random_seed,
**kwargs,
)
else:
raise ValueError(
f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
f"'mcmc', 'nuts_numpyro' or 'nuts_blackjax'"
)
idata = self._clean_results(idata, omit_offsets, include_mean)
elif sampler_backend in BAYEUX_SAMPLERS:
import bayeux as bx
import jax

bx_model = bx.Model.from_pymc(self.model)
ColCarroll marked this conversation as resolved.
Show resolved Hide resolved
bx_sampler = getattr(bx_model.mcmc, sampler_backend)
idata = bx_sampler(seed=jax.random.key(0), **kwargs)
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
idata_from = "bayeux"
else:
raise ValueError(
f"sampler_backend value {sampler_backend} is not valid. Please choose one of"
f"{PYMC_SAMPLERS + BAYEUX_SAMPLERS}"
)

idata = self._clean_results(idata, omit_offsets, include_mean, idata_from)
return idata

def _clean_results(self, idata, omit_offsets, include_mean):
def _clean_results(self, idata, omit_offsets, include_mean, idata_from):
for group in idata.groups():

getattr(idata, group).attrs["modeling_interface"] = "bambi"
Expand All @@ -258,6 +260,13 @@ def _clean_results(self, idata, omit_offsets, include_mean):

dims_original = list(self.model.coords)

# Identify bayeux idata and rename dims and coordinates to match PyMC model
if idata_from == "bayeux":
pymc_model_dims = [dim for dim in dims_original if "_obs" not in dim]
bayeux_dims = [dim for dim in idata.posterior.dims if not dim.startswith(("chain", "draw"))]
cleaned_dims = dict(zip(bayeux_dims, pymc_model_dims))
idata = idata.rename(cleaned_dims)

# Discard dims that are in the model but unused in the posterior
dims_original = [dim for dim in dims_original if dim in idata.posterior.dims]

Expand Down
4 changes: 2 additions & 2 deletions bambi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def fit(
using the ``fit`` function.
Finally, ``"laplace"``, in which case a Laplace approximation is used and is not
recommended other than for pedagogical use.
To use the PyMC numpyro and blackjax samplers, use ``nuts_numpyro`` or ``nuts_blackjax``
To use the PyMC numpyro and blackjax samplers, use ``numpyro_nuts`` or ``blackjax_nuts``
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
respectively. Both methods will only work if you can use NUTS sampling, so your model
must be differentiable.
init : str
Expand Down Expand Up @@ -306,7 +306,7 @@ def fit(
Returns
-------
An ArviZ ``InferenceData`` instance if inference_method is ``"mcmc"`` (default),
"nuts_numpyro", "nuts_blackjax" or "laplace".
"numpyro_nuts", "blackjax_nuts" or "laplace".
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
An ``Approximation`` object if ``"vi"``.
"""
method = kwargs.pop("method", None)
Expand Down
58 changes: 29 additions & 29 deletions docs/notebooks/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -970,35 +970,35 @@
"\n",
"<center>\n",
"\n",
"|Family name |Response distribution | Default link |\n",
"|:----------------------------- |:------------------------------- |:--------------- |\n",
"asymmetriclaplace | AsymmetricLaplace | identity |\n",
"bernoulli | Bernoulli | logit |\n",
"beta | Beta | logit |\n",
"beta_binomial | BetaBinomial | logit |\n",
"binomial | Binomial | logit | \n",
"categorical | Categorical | softmax | \n",
"cumulative | Cumulative | logit | \n",
"dirichlet_multinomial | DirichletMultinomial | logit |\n",
"exponential | Exponential | log | \n",
"gamma | Gamma | inverse |\n",
"gaussian | Normal | identity |\n",
"hurdle_gamma | HurdleGamma | log |\n",
"hurdle_lognormal | HurdleLogNormal | identity |\n",
"hurdle_negativebinomial | HurdleNegativeBinomial | log |\n",
"hurdle_poisson | HurdlePoisson | log |\n",
"multinomial | Multinomial | softmax |\n",
"negativebinomial | NegativeBinomial | log |\n",
"laplace | Laplace | identity |\n",
"poisson | Poisson | log |\n",
"sratio | StoppingRatio | logit |\n",
"t | StudentT | identity |\n",
"vonmises | VonMises | tan(x / 2) |\n",
"wald | InverseGaussian | inverse squared |\n",
"weibull | Weibull | log |\n",
"zero_inflated_binomial | ZeroInflatedBinomial | logit |\n",
"zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log |\n",
"zero_inflated_poisson | ZeroInflatedPoisson | log |\n",
"|Family name |Response distribution | Default link | Example notebook |\n",
"|:----------------------------- |:------------------------------- |:--------------- |:-----------------|\n",
"asymmetriclaplace | AsymmetricLaplace | identity | [Quantile Regression](https://bambinos.github.io/bambi/notebooks/quantile_regression.html#quantile-regression) |\n",
"bernoulli | Bernoulli | logit | [Logistic Regression](https://bambinos.github.io/bambi/notebooks/logistic_regression.html) |\n",
"beta | Beta | logit | [Beta Regression](https://bambinos.github.io/bambi/notebooks/beta_regression.html) |\n",
"beta_binomial | BetaBinomial | logit | _To be added_ |\n",
"binomial | Binomial | logit | [Hierarchical Logistic Regression](https://bambinos.github.io/bambi/notebooks/hierarchical_binomial_bambi.html) | \n",
"categorical | Categorical | softmax | [Categorical Regression](https://bambinos.github.io/bambi/notebooks/categorical_regression.html) | \n",
"cumulative | Cumulative | logit | [Ordinal Models](https://bambinos.github.io/bambi/notebooks/ordinal_regression.html#cumulative-model) | \n",
"dirichlet_multinomial | DirichletMultinomial | logit | _To be added_ |\n",
"exponential | Exponential | log | [Survival Models](https://bambinos.github.io/bambi/notebooks/survival_model.html#survival-models) | \n",
"gamma | Gamma | inverse | [Gamma Regression](https://bambinos.github.io/bambi/notebooks/wald_gamma_glm.html) |\n",
"gaussian | Normal | identity | [Multiple Linear Regression](https://bambinos.github.io/bambi/notebooks/ESCS_multiple_regression.html) |\n",
"hurdle_gamma | HurdleGamma | log | _To be added_ |\n",
"hurdle_lognormal | HurdleLogNormal | identity | _To be added_ |\n",
"hurdle_negativebinomial | HurdleNegativeBinomial | log | _To be added_ |\n",
"hurdle_poisson | HurdlePoisson | log | [Hurdle Poisson Regression](https://bambinos.github.io/bambi/notebooks/zero_inflated_regression.html#hurdle-poisson) |\n",
"multinomial | Multinomial | softmax | _To be added_ |\n",
"negativebinomial | NegativeBinomial | log | [Negative Binomial Regression](https://bambinos.github.io/bambi/notebooks/negative_binomial.html) |\n",
"laplace | Laplace | identity | _To be added_ |\n",
"poisson | Poisson | log | [Gaussian Processes with a Poisson likelihood](https://bambinos.github.io/bambi/notebooks/hsgp_2d.html#a-more-complex-example-poisson-likelihood-with-group-specific-effects) |\n",
"sratio | StoppingRatio | logit | [Ordinal Models](https://bambinos.github.io/bambi/notebooks/ordinal_regression.html#sequential-model) |\n",
"t | StudentT | identity | [Robust Linear Regression](https://bambinos.github.io/bambi/notebooks/t_regression.html) |\n",
"vonmises | VonMises | tan(x / 2) | [Circular Regression](https://bambinos.github.io/bambi/notebooks/circular_regression.html#circular-regression) |\n",
"wald | InverseGaussian | inverse squared | [Wald Regression](https://bambinos.github.io/bambi/notebooks/wald_gamma_glm.html) |\n",
"weibull | Weibull | log | _To be added_ |\n",
"zero_inflated_binomial | ZeroInflatedBinomial | logit | _To be added_ |\n",
"zero_inflated_negativebinomial | ZeroInflatedNegativeBinomial | log | _To be added_ |\n",
"zero_inflated_poisson | ZeroInflatedPoisson | log | [Zero Inflated Poisson Regression](https://bambinos.github.io/bambi/notebooks/zero_inflated_regression.html#zero-inflated-poisson)|\n",
"\n",
"\n",
"</center>\n",
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ dev = [
"seaborn>=0.9.0",
]
jax = [
"bayeux>=0.1.6",
"blackjax>=1.0.0",
"jax>=0.3.1",
"jaxlib>=0.3.1",
"numpyro>=0.9.0",
"flowMC>=0.2.4",
GStechschulte marked this conversation as resolved.
Show resolved Hide resolved
]

[project.urls]
Expand Down
8 changes: 4 additions & 4 deletions tests/test_alternative_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def test_vi():
"args",
[
("mcmc", {}),
("nuts_numpyro", {"chain_method": "vectorized"}),
("nuts_blackjax", {"chain_method": "vectorized"}),
("numpyro_nuts", {"chain_method": "vectorized"}),
("blackjax_nuts", {"chain_method": "vectorized"}),
],
)
def test_logistic_regression_categoric_alternative_samplers(data_n100, args):
Expand All @@ -69,8 +69,8 @@ def test_logistic_regression_categoric_alternative_samplers(data_n100, args):
"args",
[
("mcmc", {}),
("nuts_numpyro", {"chain_method": "vectorized"}),
("nuts_blackjax", {"chain_method": "vectorized"}),
("numpyro_nuts", {"chain_method": "vectorized"}),
("blackjax_nuts", {"chain_method": "vectorized"}),
],
)
def test_regression_alternative_samplers(data_n100, args):
Expand Down
Loading