From 3353209738fd2028e0a8b4b92201679cb697c544 Mon Sep 17 00:00:00 2001
From: andrewdipper
Date: Fri, 14 Jun 2024 23:06:24 -0700
Subject: [PATCH 1/5] Enable shared mcmc parameters with tempered smc (#694)

* add parameter filtering

* fix parameter split + docstring

* change extend_params
---
 blackjax/smc/adaptive_tempered.py      |  3 ++-
 blackjax/smc/base.py                   |  7 ++----
 blackjax/smc/tempered.py               | 21 ++++++++++++++---
 tests/smc/test_inner_kernel_tuning.py  | 10 +++------
 tests/smc/test_kernel_compatibility.py | 10 ++++-----
 tests/smc/test_smc.py                  | 31 ++++++++++++--------------
 tests/smc/test_tempered_smc.py         | 22 +++++++++++++-----
 7 files changed, 59 insertions(+), 45 deletions(-)

diff --git a/blackjax/smc/adaptive_tempered.py b/blackjax/smc/adaptive_tempered.py
index b8a611606..10fb194fa 100644
--- a/blackjax/smc/adaptive_tempered.py
+++ b/blackjax/smc/adaptive_tempered.py
@@ -130,7 +130,8 @@ def as_top_level_api(
     mcmc_init_fn
         The MCMC init function used to build a MCMC state from a particle position.
     mcmc_parameters
-        The parameters of the MCMC step function.
+        The parameters of the MCMC step function. Parameters with a leading
+        dimension of length 1 are shared amongst the particles.
     resampling_fn
         The function used to resample the particles.
     target_ess
diff --git a/blackjax/smc/base.py b/blackjax/smc/base.py
index 21d8e12f4..5093cf06b 100644
--- a/blackjax/smc/base.py
+++ b/blackjax/smc/base.py
@@ -150,12 +150,9 @@ def step(
     )
 
 
-def extend_params(n_particles, params):
+def extend_params(params):
     """Given a dictionary of params, repeats them for every single particle. The expected
     usage is in cases where the aim is to repeat the same parameters for all chains within SMC.
     """
 
-    def extend(param):
-        return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0)
-
-    return jax.tree.map(extend, params)
+    return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)
diff --git a/blackjax/smc/tempered.py b/blackjax/smc/tempered.py
index b373d062f..43b83d034 100644
--- a/blackjax/smc/tempered.py
+++ b/blackjax/smc/tempered.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 from typing import Callable, NamedTuple
 
 import jax
@@ -108,6 +109,9 @@ def kernel(
         Current state of the tempered SMC algorithm
     lmbda
         Current value of the tempering parameter
+    mcmc_parameters
+        The parameters of the MCMC step function. Parameters with a leading
+        dimension of length 1 are shared amongst the particles.
 
     Returns
     -------
@@ -119,6 +123,14 @@ def kernel(
     """
     delta = lmbda - state.lmbda
 
+    shared_mcmc_parameters = {}
+    unshared_mcmc_parameters = {}
+    for k, v in mcmc_parameters.items():
+        if v.shape[0] == 1:
+            shared_mcmc_parameters[k] = v[0, ...]
+        else:
+            unshared_mcmc_parameters[k] = v
+
     def log_weights_fn(position: ArrayLikeTree) -> float:
         return delta * loglikelihood_fn(position)
 
@@ -127,11 +139,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
         tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
         return logprior + tempered_loglikelihood
 
+    shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)
+
     def mcmc_kernel(rng_key, position, step_parameters):
         state = mcmc_init_fn(position, tempered_logposterior_fn)
 
         def body_fn(state, rng_key):
-            new_state, info = mcmc_step_fn(
+            new_state, info = shared_mcmc_step_fn(
                 rng_key, state, tempered_logposterior_fn, **step_parameters
             )
             return new_state, info
@@ -142,7 +156,7 @@ def body_fn(state, rng_key):
 
     smc_state, info = smc.base.step(
         rng_key,
-        SMCState(state.particles, state.weights, mcmc_parameters),
+        SMCState(state.particles, state.weights, unshared_mcmc_parameters),
         jax.vmap(mcmc_kernel),
         jax.vmap(log_weights_fn),
         resampling_fn,
@@ -178,7 +192,8 @@ def as_top_level_api(
     mcmc_init_fn
         The MCMC init function used to build a MCMC state from a particle position.
     mcmc_parameters
-        The parameters of the MCMC step function.
+        The parameters of the MCMC step function. Parameters with a leading
+        dimension of length 1 are shared amongst the particles.
     resampling_fn
         The function used to resample the particles.
     num_mcmc_steps
diff --git a/tests/smc/test_inner_kernel_tuning.py b/tests/smc/test_inner_kernel_tuning.py
index bf970ae47..7d6190af5 100644
--- a/tests/smc/test_inner_kernel_tuning.py
+++ b/tests/smc/test_inner_kernel_tuning.py
@@ -94,7 +94,7 @@ def smc_inner_kernel_tuning_test_case(
         proposal_factory.return_value = 100
 
         def mcmc_parameter_update_fn(state, info):
-            return extend_params(1000, {"mean": 100})
+            return extend_params({"mean": 100})
 
         prior = lambda x: stats.norm.logpdf(x)
 
@@ -114,7 +114,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean):
             resampling_fn=resampling.systematic,
             smc_algorithm=smc_algorithm,
             mcmc_parameter_update_fn=mcmc_parameter_update_fn,
-            initial_parameter_value=extend_params(1000, {"mean": 1.0}),
+            initial_parameter_value=extend_params({"mean": 1.0}),
             **smc_parameters,
         )
 
@@ -281,7 +281,6 @@ def test_with_adaptive_tempered(self):
 
         def parameter_update(state, info):
             return extend_params(
-                100,
                 {
                     "inverse_mass_matrix": mass_matrix_from_particles(state.particles),
                     "step_size": 10e-2,
@@ -298,7 +297,6 @@ def parameter_update(state, info):
             resampling.systematic,
             mcmc_parameter_update_fn=parameter_update,
             initial_parameter_value=extend_params(
-                100,
                 dict(
                     inverse_mass_matrix=jnp.eye(2),
                     step_size=10e-2,
@@ -326,7 +324,7 @@ def body(carry):
 
         _, state = inference_loop(smc_kernel, self.key, init_state)
 
-        assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2)
+        assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2)
         self.assert_linear_regression_test_case(state.sampler_state)
 
     @chex.all_variants(with_pmap=False)
@@ -340,7 +338,6 @@ def test_with_tempered_smc(self):
 
         def parameter_update(state, info):
             return extend_params(
-                100,
                 {
                     "inverse_mass_matrix": mass_matrix_from_particles(state.particles),
                     "step_size": 10e-2,
@@ -357,7 +354,6 @@ def parameter_update(state, info):
             resampling.systematic,
             mcmc_parameter_update_fn=parameter_update,
             initial_parameter_value=extend_params(
-                100,
                 dict(
                     inverse_mass_matrix=jnp.eye(2),
                     step_size=10e-2,
diff --git a/tests/smc/test_kernel_compatibility.py b/tests/smc/test_kernel_compatibility.py
index 3e675c2cc..fdda30b3a 100644
--- a/tests/smc/test_kernel_compatibility.py
+++ b/tests/smc/test_kernel_compatibility.py
@@ -50,7 +50,7 @@ def kernel(rng_key, state, logdensity_fn, proposal_mean):
         self.check_compatible(
             kernel,
             blackjax.additive_step_random_walk.init,
-            extend_params(self.n_particles, {"proposal_mean": 1.0}),
+            extend_params({"proposal_mean": 1.0}),
         )
 
     def test_compatible_with_rmh(self):
@@ -70,7 +70,7 @@ def kernel(
         self.check_compatible(
             kernel,
             blackjax.rmh.init,
-            extend_params(self.n_particles, {"proposal_mean": 1.0}),
+            extend_params({"proposal_mean": 1.0}),
         )
 
     def test_compatible_with_hmc(self):
@@ -78,7 +78,6 @@ def test_compatible_with_hmc(self):
             blackjax.hmc.build_kernel(),
             blackjax.hmc.init,
             extend_params(
-                self.n_particles,
                 {
                     "step_size": 0.3,
                     "inverse_mass_matrix": jnp.array([1.0]),
@@ -100,7 +99,7 @@ def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None):
         self.check_compatible(
             kernel,
             blackjax.irmh.init,
-            extend_params(self.n_particles, {"mean": jnp.array([1.0, 1.0])}),
+            extend_params({"mean": jnp.array([1.0, 1.0])}),
         )
 
     def test_compatible_with_nuts(self):
@@ -108,7 +107,6 @@ def test_compatible_with_nuts(self):
             blackjax.nuts.build_kernel(),
             blackjax.nuts.init,
             extend_params(
-                self.n_particles,
                 {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)},
             ),
         )
@@ -117,7 +115,7 @@ def test_compatible_with_mala(self):
         self.check_compatible(
             blackjax.mala.build_kernel(),
             blackjax.mala.init,
-            extend_params(self.n_particles, {"step_size": 1e-10}),
+            extend_params({"step_size": 1e-10}),
         )
 
     @staticmethod
diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py
index 2838e984f..6366182a8 100644
--- a/tests/smc/test_smc.py
+++ b/tests/smc/test_smc.py
@@ -50,16 +50,17 @@ def body_fn(state, rng_key):
     same_for_all_params = dict(
         step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
     )
+
     state = init(
         init_particles,
-        extend_params(num_particles, same_for_all_params),
+        same_for_all_params,
     )
 
     # Run the SMC sampler once
     new_state, info = self.variant(step, static_argnums=(2, 3, 4))(
         sample_key,
         state,
-        jax.vmap(update_fn),
+        jax.vmap(update_fn, in_axes=(0, 0, None)),
         jax.vmap(logdensity_fn),
         resampling.systematic,
     )
@@ -87,7 +88,9 @@ def body_fn(state, rng_key):
             _, (states, info) = jax.lax.scan(body_fn, state, keys)
             return states.position, info
 
-        particles, info = jax.vmap(one_particle_fn)(keys, particles, update_params)
+        particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))(
+            keys, particles, update_params
+        )
         particles = particles.reshape((num_particles,))
         return particles, info
 
@@ -97,13 +100,10 @@ def body_fn(state, rng_key):
     init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
     state = init(
         init_particles,
-        extend_params(
-            num_resampled,
-            dict(
-                step_size=1e-2,
-                inverse_mass_matrix=jnp.eye(1),
-                num_integration_steps=100,
-            ),
+        dict(
+            step_size=1e-2,
+            inverse_mass_matrix=jnp.eye(1),
+            num_integration_steps=100,
         ),
     )
 
@@ -125,7 +125,6 @@ def body_fn(state, rng_key):
 class ExtendParamsTest(chex.TestCase):
     def test_extend_params(self):
         extended = extend_params(
-            3,
             {
                 "a": 50,
                 "b": np.array([50]),
@@ -133,12 +132,10 @@ def test_extend_params(self):
                 "d": np.array([[1, 2], [3, 4]]),
             },
         )
-        np.testing.assert_allclose(extended["a"], np.ones((3,)) * 50)
-        np.testing.assert_allclose(extended["b"], np.array([[50], [50], [50]]))
-        np.testing.assert_allclose(
-            extended["c"], np.array([[50, 60], [50, 60], [50, 60]])
-        )
+        np.testing.assert_allclose(extended["a"], np.ones((1,)) * 50)
+        np.testing.assert_allclose(extended["b"], np.array([[50]]))
+        np.testing.assert_allclose(extended["c"], np.array([[50, 60]]))
         np.testing.assert_allclose(
             extended["d"],
-            np.array([[[1, 2], [3, 4]], [[1, 2], [3, 4]], [[1, 2], [3, 4]]]),
+            np.array([[[1, 2], [3, 4]]]),
         )
diff --git a/tests/smc/test_tempered_smc.py b/tests/smc/test_tempered_smc.py
index a7d9acdd8..527457d62 100644
--- a/tests/smc/test_tempered_smc.py
+++ b/tests/smc/test_tempered_smc.py
@@ -65,16 +65,28 @@ def logprior_fn(x):
 
         hmc_kernel = blackjax.hmc.build_kernel()
         hmc_init = blackjax.hmc.init
-        hmc_parameters = extend_params(
-            num_particles,
+
+        base_params = extend_params(
             {
                 "step_size": 10e-2,
                 "inverse_mass_matrix": jnp.eye(2),
                 "num_integration_steps": 50,
-            },
+            }
         )
 
-        for target_ess in [0.5, 0.75]:
+        # verify results are equivalent with all shared, all unshared, and mixed params
+        hmc_parameters_list = [
+            base_params,
+            jax.tree.map(lambda x: jnp.repeat(x, num_particles, axis=0), base_params),
+            jax.tree_util.tree_map_with_path(
+                lambda path, x: jnp.repeat(x, num_particles, axis=0)
+                if path[0].key == "step_size"
+                else x,
+                base_params,
+            ),
+        ]
+
+        for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list):
             tempering = adaptive_tempered_smc(
                 logprior_fn,
                 loglikelihood_fn,
@@ -115,7 +127,6 @@ def test_fixed_schedule_tempered_smc(self):
         hmc_init = blackjax.hmc.init
         hmc_kernel = blackjax.hmc.build_kernel()
         hmc_parameters = extend_params(
-            100,
             {
                 "step_size": 10e-2,
                 "inverse_mass_matrix": jnp.eye(2),
@@ -182,7 +193,6 @@ def test_normalizing_constant(self):
         hmc_init = blackjax.hmc.init
         hmc_kernel = blackjax.hmc.build_kernel()
         hmc_parameters = extend_params(
-            num_particles,
             {
                 "step_size": 10e-2,
                 "inverse_mass_matrix": jnp.eye(num_dim),
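Usage sketch for the parameter-sharing convention introduced by this patch (an editorial illustration, not part of the diff; the toy densities and the per-particle step sizes are assumptions). `extend_params` now adds a leading axis of length 1, which the tempered kernel broadcasts to every particle, while entries with a leading axis of length `num_particles` are vmapped one value per particle:

    import jax
    import jax.numpy as jnp
    import jax.scipy.stats as stats

    import blackjax
    import blackjax.smc.resampling as resampling
    from blackjax.smc.base import extend_params

    num_particles = 100

    def logprior_fn(x):
        return stats.norm.logpdf(x).sum()

    def loglikelihood_fn(x):
        return stats.norm.logpdf(x, loc=1.0, scale=0.5).sum()

    # Shared: extend_params adds a leading axis of length 1, which the
    # tempered kernel strips and broadcasts to every particle.
    shared = extend_params(
        {"inverse_mass_matrix": jnp.eye(2), "num_integration_steps": 50}
    )
    # Unshared: a leading axis of length num_particles gives each particle
    # its own value through jax.vmap.
    per_particle = {"step_size": jnp.linspace(1e-3, 1e-1, num_particles)}

    tempering = blackjax.adaptive_tempered_smc(
        logprior_fn,
        loglikelihood_fn,
        blackjax.hmc.build_kernel(),
        blackjax.hmc.init,
        {**shared, **per_particle},
        resampling.systematic,
        target_ess=0.5,
        num_mcmc_steps=10,
    )

    state = tempering.init(jax.random.normal(jax.random.key(0), (num_particles, 2)))
    state, info = tempering.step(jax.random.key(1), state)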
From eca35abc1fed16c3d1174482b0cbf16e084c72ae Mon Sep 17 00:00:00 2001
From: andrewdipper
Date: Wed, 19 Jun 2024 22:18:50 -0700
Subject: [PATCH 2/5] convert to bit twiddling (#696)

---
 blackjax/mcmc/termination.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py
index 24e17c3a5..eb1276da3 100644
--- a/blackjax/mcmc/termination.py
+++ b/blackjax/mcmc/termination.py
@@ -64,16 +64,10 @@ def _leaf_idx_to_ckpt_idxs(n):
     """Find the checkpoint id from a step number."""
     # computes the number of non-zero bits except the last bit
     # e.g. 6 -> 2, 7 -> 2, 13 -> 2
-    _, idx_max = jax.lax.while_loop(
-        lambda nc: nc[0] > 0,
-        lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)),
-        (n >> 1, 0),
-    )
+    idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32)
     # computes the number of contiguous last non-zero bits
     # e.g. 6 -> 0, 7 -> 3, 13 -> 1
-    _, num_subtrees = jax.lax.while_loop(
-        lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
-    )
+    num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32)
     idx_min = idx_max - num_subtrees + 1
 
     return idx_min, idx_max
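A quick plain-Python check of the bit-twiddling identities above (an editorial sketch, not part of the patch). `n >> 1` drops the last bit before the popcount, and `(~n & (n + 1)) - 1` isolates the lowest zero bit of `n` and turns it into a mask of the contiguous trailing one-bits, so its popcount matches what the removed `while` loops computed:

    def bit_count(x: int) -> int:
        # plain-Python stand-in for jnp.bitwise_count on non-negative ints
        return bin(x).count("1")

    for n in [6, 7, 13]:
        idx_max = bit_count(n >> 1)                   # 6 -> 2, 7 -> 2, 13 -> 2
        num_subtrees = bit_count((~n & (n + 1)) - 1)  # 6 -> 0, 7 -> 3, 13 -> 1
        print(n, idx_max, num_subtrees)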
From 5764a2b4aff803dff45d903370cc70f2000aa7ef Mon Sep 17 00:00:00 2001
From: Junpeng Lao
Date: Mon, 24 Jun 2024 07:01:50 +0200
Subject: [PATCH 3/5] Remove nightly release (#699)

---
 .github/workflows/nightly.yml | 48 -----------------------------------
 README.md                     |  6 -----
 docs/index.md                 |  6 -----
 3 files changed, 60 deletions(-)
 delete mode 100644 .github/workflows/nightly.yml

diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml
deleted file mode 100644
index 6472e4421..000000000
--- a/.github/workflows/nightly.yml
+++ /dev/null
@@ -1,48 +0,0 @@
-name: Nightly
-
-on:
-  push:
-    branches: [main]
-
-jobs:
-  build_and_publish:
-    name: Build and publish on PyPi
-    runs-on: ubuntu-latest
-    steps:
-      - uses: actions/checkout@v3
-        with:
-          fetch-depth: 0
-      - uses: actions/setup-python@v4
-        with:
-          python-version: 3.11
-      - name: Update pyproject.toml
-        # Taken from https://github.com/aesara-devs/aesara/pull/1375
-        run: |
-          curl -sSLf https://github.com/TomWright/dasel/releases/download/v2.0.2/dasel_linux_amd64 \
-            -L -o /tmp/dasel && chmod +x /tmp/dasel
-          /tmp/dasel put -f pyproject.toml project.name -v blackjax-nightly
-          /tmp/dasel put -f pyproject.toml tool.setuptools_scm.version_scheme -v post-release
-          /tmp/dasel put -f pyproject.toml tool.setuptools_scm.local_scheme -v no-local-version
-      - name: Build the sdist and wheel
-        run: |
-          python -m pip install -U pip
-          python -m pip install build
-          python -m build
-      - name: Check sdist install and imports
-        run: |
-          mkdir -p test-sdist
-          cd test-sdist
-          python -m venv venv-sdist
-          venv-sdist/bin/python -m pip install ../dist/blackjax-nightly-*.tar.gz
-          venv-sdist/bin/python -c "import blackjax"
-      - name: Check wheel install and imports
-        run: |
-          mkdir -p test-wheel
-          cd test-wheel
-          python -m venv venv-wheel
-          venv-wheel/bin/python -m pip install ../dist/blackjax_nightly-*.whl
-      - name: Publish to PyPi
-        uses: pypa/gh-action-pypi-publish@v1.4.2
-        with:
-          user: __token__
-          password: ${{ secrets.PYPI_NIGHTLY_TOKEN }}
diff --git a/README.md b/README.md
index d7d78b15f..a8d847cf9 100644
--- a/README.md
+++ b/README.md
@@ -41,12 +41,6 @@ or via conda-forge:
 conda install -c conda-forge blackjax
 ```
 
-Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`:
-
-```bash
-pip install blackjax-nightly
-```
-
 BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
 version of JAX that will be installed along with BlackJAX will make your code
 run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
diff --git a/docs/index.md b/docs/index.md
index 0fd84d860..edc02631c 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -57,13 +57,7 @@ If you want to use Blackjax with a model implemented with a PPL, go to the relat
 ```{code-block} bash
 pip install blackjax
 ```
-:::
 
-:::{tab-item} Nightly
-```{code-block} bash
-pip install blackjax-nightly
-```
-:::
 
 :::{tab-item} Conda
 ```{code-block} bash

From f8db9aa04d83fed4b3bde0f37ddf6194b229308f Mon Sep 17 00:00:00 2001
From: Gilad Turok <36947659+gil2rok@users.noreply.github.com>
Date: Mon, 24 Jun 2024 01:07:30 -0400
Subject: [PATCH 4/5] Fix doc mistakes (#701)

* Fix equation formatting

* Clarify JAX gradient error

* Fix punctuation + capitalization

* Fix grammar

  Should not begin a sentence with "i.e." in English.

* Fix math formatting error

* Fix typo

  Change parallel _ensample_ chain adaptation to parallel _ensemble_ chain
  adaptation.

* Add SVGD citation to appear in doc

  Currently the SVGD paper is only cited in the `kernel` function, which is
  defined _within_ the `build_kernel` function. Because of this nested
  function format, the SVGD paper is _not_ cited in the documentation. To fix
  this, I added a citation to the SVGD paper in the `as_top_level_api`
  docstring.

* Fix grammar + clarify doc

* Fix typo

---------

Co-authored-by: Junpeng Lao
---
 blackjax/adaptation/meads_adaptation.py |  4 ++--
 blackjax/base.py                        |  2 +-
 blackjax/vi/svgd.py                     |  2 +-
 docs/examples/howto_custom_gradients.md | 13 ++++++-------
 4 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/blackjax/adaptation/meads_adaptation.py b/blackjax/adaptation/meads_adaptation.py
index a431a591d..b383653e8 100644
--- a/blackjax/adaptation/meads_adaptation.py
+++ b/blackjax/adaptation/meads_adaptation.py
@@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple):
     alpha
         Value of the alpha parameter of the generalized HMC algorithm.
     delta
-        Value of the alpha parameter of the generalized HMC algorithm.
+        Value of the delta parameter of the generalized HMC algorithm.
     """
 
@@ -60,7 +60,7 @@ def base():
     with shape.
 
     This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain
-    adaptation instead of parallel ensample chain adaptation.
+    adaptation instead of parallel ensemble chain adaptation.
 
     Returns
     -------
diff --git a/blackjax/base.py b/blackjax/base.py
index f766e98b5..8ea24cd70 100644
--- a/blackjax/base.py
+++ b/blackjax/base.py
@@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple):
     """A pair of functions that represents a MCMC sampling algorithm.
 
     Blackjax sampling algorithms are implemented as a pair of pure functions: a
-    kernel, that takes a new samples starting from the current state, and an
+    kernel, that generates a new sample from the current state, and an
     initialization function that creates a kernel state from a chain position.
 
     As they represent Markov kernels, the kernel functions are pure functions
diff --git a/blackjax/vi/svgd.py b/blackjax/vi/svgd.py
index 881de77e6..e287b813f 100644
--- a/blackjax/vi/svgd.py
+++ b/blackjax/vi/svgd.py
@@ -135,7 +135,7 @@ def as_top_level_api(
     kernel: Callable = rbf_kernel,
     update_kernel_parameters: Callable = update_median_heuristic,
 ):
-    """Implements the (basic) user interface for the svgd algorithm.
+    """Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`.
 
     Parameters
     ----------
diff --git a/docs/examples/howto_custom_gradients.md b/docs/examples/howto_custom_gradients.md
index 731de0eea..653e6d393 100644
--- a/docs/examples/howto_custom_gradients.md
+++ b/docs/examples/howto_custom_gradients.md
@@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$.
 Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport.
 
 Let's consider the following function:
 
 $$
-\begin{align*}
-g(x, y) &= h(y) - \langle x, y\rangle\\
-h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\
-\end{align*}
+\begin{equation*}
+g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1.
+\end{equation*}
 $$
 
 And define the function $f$ as $f(x) = -min_y g(x, y)$, which can be implemented as:
@@ -69,7 +68,7 @@ Note that we also return the value of $y$ where the minimum of $g$ is achieved (t
 
 ### Trying to differentiate the function with `jax.grad`
 
-The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error:
+The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops used in BFGS, and trying to compute it directly raises an error:
 
 ```{code-cell} ipython3
 # We only want the gradient with respect to `x`
@@ -97,17 +96,17 @@ The first order optimality criterion
 ```{math}
 \begin{equation*}
 \nabla_y g(x, y(x)) = 0
 \end{equation*}
 ```
 
-Ensures that:
+ensures that
 
 ```{math}
 \begin{equation*}
 \frac{\mathrm{d}f}{\mathrm{d}x}(x) = y(x)
 \end{equation*}
 ```
 
-i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
+In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
 
 ### Telling JAX to use a custom gradient

From 441412a09e39f514189be84813f812d95709365c Mon Sep 17 00:00:00 2001
From: johannahaffner <38662446+johannahaffner@users.noreply.github.com>
Date: Wed, 31 Jul 2024 17:32:54 +0200
Subject: [PATCH 5/5] Update index.md (#711)

The jitted step remained unused, leading to the example running with an
uncompiled nuts.step. Changing this reduces the execution time by a factor
of 30 on my system and showcases blackjax's speed.
---
 docs/index.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/index.md b/docs/index.md
index edc02631c..fca4787c4 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -41,7 +41,7 @@ rng_key = jax.random.key(0)
 step = jax.jit(nuts.step)
 for i in range(1_000):
     nuts_key = jax.random.fold_in(rng_key, i)
-    state, _ = nuts.step(nuts_key, state)
+    state, _ = step(nuts_key, state)
 ```
 
 :::{note}
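The custom-gradients howto edited in PATCH 4 goes on, in a section not shown in this diff, to register the derivative df/dx = y(x) with `jax.custom_vjp`. A self-contained sketch of that idea (an editorial illustration, not the howto's actual code; h(y) = |y|^p / p and the helper names `argmin_y`, `f_fwd`, `f_bwd` are assumptions):

    import jax
    import jax.numpy as jnp
    from jax.scipy.optimize import minimize

    p = 3.0

    def h(y):
        return jnp.abs(y) ** p / p

    def argmin_y(x):
        # BFGS minimization of g(x, .) = h(.) - x * .; JAX cannot reverse-
        # differentiate through the solver's while loop, hence custom_vjp.
        result = minimize(lambda y: h(y[0]) - x * y[0], jnp.array([1.0]), method="BFGS")
        return result.x[0]

    @jax.custom_vjp
    def f(x):
        y_star = argmin_y(x)
        return -(h(y_star) - x * y_star)

    def f_fwd(x):
        y_star = argmin_y(x)
        # Save the minimizer: by first-order optimality it is also df/dx.
        return -(h(y_star) - x * y_star), y_star

    def f_bwd(y_star, cotangent):
        return (cotangent * y_star,)

    f.defvjp(f_fwd, f_bwd)

    # h'(y) = x gives y(x) = x ** (1 / (p - 1)); for x = 2 and p = 3, sqrt(2).
    print(jax.grad(f)(2.0))  # ~1.41421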