Skip to content

Commit

Permalink
rolling back changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ciguaran committed Aug 15, 2024
1 parent f0ce4ba commit 3576690
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 2,330 deletions.
57 changes: 57 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

from blackjax._version import __version__

from .adaptation.chees_adaptation import chees_adaptation
from .adaptation.mclmc_adaptation import mclmc_find_L_and_step_size
from .adaptation.meads_adaptation import meads_adaptation
from .adaptation.pathfinder_adaptation import pathfinder_adaptation
from .adaptation.window_adaptation import window_adaptation
from .base import SamplingAlgorithm, VIAlgorithm
from .diagnostics import effective_sample_size as ess
from .diagnostics import potential_scale_reduction as rhat
Expand All @@ -23,9 +28,19 @@
normal_random_walk,
rmh_as_top_level_api,
)
from .optimizers import dual_averaging, lbfgs
from .sgmcmc import csgld as _csgld
from .sgmcmc import sghmc as _sghmc
from .sgmcmc import sgld as _sgld
from .sgmcmc import sgnht as _sgnht
from .smc import adaptive_tempered
from .smc import inner_kernel_tuning as _inner_kernel_tuning
from .smc import tempered
from .vi import meanfield_vi as _meanfield_vi
from .vi import pathfinder as _pathfinder
from .vi import schrodinger_follmer as _schrodinger_follmer
from .vi import svgd as _svgd
from .vi.pathfinder import PathFinderAlgorithm

"""
The above three classes exist as a backwards compatible way of exposing both the high level, differentiable
Expand Down Expand Up @@ -58,6 +73,16 @@ def __call__(self, *args, **kwargs) -> VIAlgorithm:
return self.differentiable(*args, **kwargs)


@dataclasses.dataclass
class GeneratePathfinderAPI:
differentiable: Callable
approximate: Callable
sample: Callable

def __call__(self, *args, **kwargs) -> PathFinderAlgorithm:
return self.differentiable(*args, **kwargs)


def generate_top_level_api_from(module):
return GenerateSamplingAPI(
module.as_top_level_api, module.init, module.build_kernel
Expand Down Expand Up @@ -98,9 +123,41 @@ def generate_top_level_api_from(module):
smc_family = [tempered_smc, adaptive_tempered_smc]
"Step_fn returning state has a .particles attribute"

# stochastic gradient mcmc
sgld = generate_top_level_api_from(_sgld)
sghmc = generate_top_level_api_from(_sghmc)
sgnht = generate_top_level_api_from(_sgnht)
csgld = generate_top_level_api_from(_csgld)
svgd = generate_top_level_api_from(_svgd)

# variational inference
meanfield_vi = GenerateVariationalAPI(
_meanfield_vi.as_top_level_api,
_meanfield_vi.init,
_meanfield_vi.step,
_meanfield_vi.sample,
)
schrodinger_follmer = GenerateVariationalAPI(
_schrodinger_follmer.as_top_level_api,
_schrodinger_follmer.init,
_schrodinger_follmer.step,
_schrodinger_follmer.sample,
)

pathfinder = GeneratePathfinderAPI(
_pathfinder.as_top_level_api, _pathfinder.approximate, _pathfinder.sample
)


__all__ = [
"__version__",
"dual_averaging", # optimizers
"lbfgs",
"window_adaptation", # mcmc adaptation
"meads_adaptation",
"chees_adaptation",
"pathfinder_adaptation",
"mclmc_find_L_and_step_size", # mclmc adaptation
"ess", # diagnostics
"rhat",
]
Loading

0 comments on commit 3576690

Please sign in to comment.