Commit

Merge branch 'main' of https://github.com/blackjax-devs/blackjax into streaming-pr
reubenharry committed Aug 6, 2024
2 parents 17f597d + 441412a commit 9b216da
Showing 15 changed files with 72 additions and 125 deletions.
48 changes: 0 additions & 48 deletions .github/workflows/nightly.yml

This file was deleted.

6 changes: 0 additions & 6 deletions README.md
@@ -41,12 +41,6 @@ or via conda-forge:
conda install -c conda-forge blackjax
```

-Nightly builds (bleeding edge) of Blackjax can also be installed using `pip`:
-
-```bash
-pip install blackjax-nightly
-```
-
BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
version of JAX that will be installed along with BlackJAX will make your code
run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
4 changes: 2 additions & 2 deletions blackjax/adaptation/meads_adaptation.py
@@ -36,7 +36,7 @@ class MEADSAdaptationState(NamedTuple):
    alpha
        Value of the alpha parameter of the generalized HMC algorithm.
    delta
-        Value of the alpha parameter of the generalized HMC algorithm.
+        Value of the delta parameter of the generalized HMC algorithm.
    """

@@ -60,7 +60,7 @@ def base():
    with shape.
    This is an implementation of Algorithm 3 of :cite:p:`hoffman2022tuning` using cross-chain
-    adaptation instead of parallel ensample chain adaptation.
+    adaptation instead of parallel ensemble chain adaptation.
    Returns
    -------
2 changes: 1 addition & 1 deletion blackjax/base.py
@@ -89,7 +89,7 @@ class SamplingAlgorithm(NamedTuple):
    """A pair of functions that represents a MCMC sampling algorithm.
    Blackjax sampling algorithms are implemented as a pair of pure functions: a
-    kernel, that takes a new samples starting from the current state, and an
+    kernel, that generates a new sample from the current state, and an
    initialization function that creates a kernel state from a chain position.
    As they represent Markov kernels, the kernel functions are pure functions
10 changes: 2 additions & 8 deletions blackjax/mcmc/termination.py
@@ -64,16 +64,10 @@ def _leaf_idx_to_ckpt_idxs(n):
    """Find the checkpoint id from a step number."""
    # computes the number of non-zero bits except the last bit
    # e.g. 6 -> 2, 7 -> 2, 13 -> 2
-    _, idx_max = jax.lax.while_loop(
-        lambda nc: nc[0] > 0,
-        lambda nc: (nc[0] >> 1, nc[1] + (nc[0] & 1)),
-        (n >> 1, 0),
-    )
+    idx_max = jnp.bitwise_count(n >> 1).astype(jnp.int32)
    # computes the number of contiguous last non-zero bits
    # e.g. 6 -> 0, 7 -> 3, 13 -> 1
-    _, num_subtrees = jax.lax.while_loop(
-        lambda nc: (nc[0] & 1) != 0, lambda nc: (nc[0] >> 1, nc[1] + 1), (n, 0)
-    )
+    num_subtrees = jnp.bitwise_count((~n & (n + 1)) - 1).astype(jnp.int32)
    idx_min = idx_max - num_subtrees + 1
    return idx_min, idx_max
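As a quick aside (an editorial check, not part of the commit), the closed-form replacements can be verified against the worked examples in the comments. The sketch below assumes a JAX version that provides `jnp.bitwise_count`, as the new code does:

```python
import jax.numpy as jnp

def popcount_loop(n):
    # Reference implementation mirroring the deleted while_loop logic:
    # count the set bits of n one at a time.
    count = 0
    while n > 0:
        count += n & 1
        n >>= 1
    return count

for n in [6, 7, 13]:
    # Non-zero bits of n, ignoring the last bit.
    idx_max = int(jnp.bitwise_count(n >> 1))
    # (~n & (n + 1)) - 1 is a mask of exactly the trailing 1-bits of n,
    # so its popcount is the number of contiguous last non-zero bits.
    num_subtrees = int(jnp.bitwise_count((~n & (n + 1)) - 1))
    assert idx_max == popcount_loop(n >> 1)
    print(n, idx_max, num_subtrees)  # 6 -> 2 0, 7 -> 2 3, 13 -> 2 1
```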
3 changes: 2 additions & 1 deletion blackjax/smc/adaptive_tempered.py
@@ -130,7 +130,8 @@ def as_top_level_api(
    mcmc_init_fn
        The MCMC init function used to build a MCMC state from a particle position.
    mcmc_parameters
-        The parameters of the MCMC step function.
+        The parameters of the MCMC step function. Parameters with leading dimension
+        length of 1 are shared amongst the particles.
    resampling_fn
        The function used to resample the particles.
    target_ess
7 changes: 2 additions & 5 deletions blackjax/smc/base.py
@@ -150,12 +150,9 @@ def step(
    )


-def extend_params(n_particles, params):
+def extend_params(params):
    """Given a dictionary of params, repeats them for every single particle. The expected
    usage is in cases where the aim is to repeat the same parameters for all chains within SMC.
    """
-
-    def extend(param):
-        return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0)
-
-    return jax.tree.map(extend, params)
+    return jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)
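A minimal usage sketch (editorial, not part of the commit) of the new convention: `extend_params` now only adds a leading axis of length 1, and parameters with that leading shape are treated as shared across particles by the SMC kernels:

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter dictionary; the names are illustrative only.
params = {"step_size": 1e-2, "inverse_mass_matrix": jnp.eye(2)}

# Equivalent to the new extend_params: add a leading axis of length 1.
extended = jax.tree.map(lambda x: jnp.asarray(x)[None, ...], params)

print(extended["step_size"].shape)            # (1,)
print(extended["inverse_mass_matrix"].shape)  # (1, 2, 2)
```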
21 changes: 18 additions & 3 deletions blackjax/smc/tempered.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from functools import partial
from typing import Callable, NamedTuple

import jax
@@ -108,6 +109,9 @@ def kernel(
        Current state of the tempered SMC algorithm
    lmbda
        Current value of the tempering parameter
+    mcmc_parameters
+        The parameters of the MCMC step function. Parameters with leading dimension
+        length of 1 are shared amongst the particles.

    Returns
    -------
@@ -119,6 +123,14 @@
    """
    delta = lmbda - state.lmbda

+    shared_mcmc_parameters = {}
+    unshared_mcmc_parameters = {}
+    for k, v in mcmc_parameters.items():
+        if v.shape[0] == 1:
+            shared_mcmc_parameters[k] = v[0, ...]
+        else:
+            unshared_mcmc_parameters[k] = v
+
    def log_weights_fn(position: ArrayLikeTree) -> float:
        return delta * loglikelihood_fn(position)

@@ -127,11 +139,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
        tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
        return logprior + tempered_loglikelihood

+    shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)
+
    def mcmc_kernel(rng_key, position, step_parameters):
        state = mcmc_init_fn(position, tempered_logposterior_fn)

        def body_fn(state, rng_key):
-            new_state, info = mcmc_step_fn(
+            new_state, info = shared_mcmc_step_fn(
                rng_key, state, tempered_logposterior_fn, **step_parameters
            )
            return new_state, info
@@ -142,7 +156,7 @@ def body_fn(state, rng_key):

    smc_state, info = smc.base.step(
        rng_key,
-        SMCState(state.particles, state.weights, mcmc_parameters),
+        SMCState(state.particles, state.weights, unshared_mcmc_parameters),
        jax.vmap(mcmc_kernel),
        jax.vmap(log_weights_fn),
        resampling_fn,
@@ -178,7 +192,8 @@ def as_top_level_api(
    mcmc_init_fn
        The MCMC init function used to build a MCMC state from a particle position.
    mcmc_parameters
-        The parameters of the MCMC step function.
+        The parameters of the MCMC step function. Parameters with leading dimension
+        length of 1 are shared amongst the particles.
    resampling_fn
        The function used to resample the particles.
    num_mcmc_steps
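To make the shared/unshared convention concrete, here is a small editorial sketch (not from the commit; the parameter names and particle count are illustrative) of the split performed inside `kernel` above:

```python
import jax.numpy as jnp

# step_size has leading dimension 1, so it is shared by every particle;
# inverse_mass_matrix has one entry per particle, so it stays unshared.
mcmc_parameters = {
    "step_size": jnp.array([1e-2]),                # shape (1,)
    "inverse_mass_matrix": jnp.ones((100, 2, 2)),  # shape (100, 2, 2)
}

shared, unshared = {}, {}
for k, v in mcmc_parameters.items():
    if v.shape[0] == 1:
        shared[k] = v[0, ...]   # bound once via functools.partial
    else:
        unshared[k] = v         # mapped over particles via jax.vmap

print(sorted(shared))    # ['step_size']
print(sorted(unshared))  # ['inverse_mass_matrix']
```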
2 changes: 1 addition & 1 deletion blackjax/vi/svgd.py
@@ -135,7 +135,7 @@ def as_top_level_api(
    kernel: Callable = rbf_kernel,
    update_kernel_parameters: Callable = update_median_heuristic,
):
-    """Implements the (basic) user interface for the svgd algorithm.
+    """Implements the (basic) user interface for the svgd algorithm :cite:p:`liu2016stein`.

    Parameters
    ----------
13 changes: 6 additions & 7 deletions docs/examples/howto_custom_gradients.md
@@ -29,10 +29,9 @@ Functions can be defined as the minimum of another one, $f(x) = min_{y} g(x,y)$.
Our example is taken from the theory of [convex conjugates](https://en.wikipedia.org/wiki/Convex_conjugate), used for example in optimal transport. Let's consider the following function:

$$
-\begin{align*}
-g(x, y) &= h(y) - \langle x, y\rangle\\
-h(x) &= \frac{1}{p}|x|^p,\qquad p > 1.\\
-\end{align*}
+\begin{equation*}
+g(x, y) = h(y) - \langle x, y\rangle,\qquad h(x) = \frac{1}{p}|x|^p,\qquad p > 1.
+\end{equation*}
$$

And define the function $f$ as $f(x) = -min_y g(x, y)$ which we can be implemented as:
@@ -69,7 +68,7 @@ Note the we also return the value of $y$ where the minimum of $g$ is achieved (t

### Trying to differentate the function with `jax.grad`

-The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops, and trying to compute it directly raises an error:
+The gradient of the function $f$ is undefined for JAX, which cannot differentiate through `while` loops used in BFGS, and trying to compute it directly raises an error:

```{code-cell} ipython3
# We only want the gradient with respect to `x`
@@ -97,15 +96,15 @@ The first order optimality criterion
\end{equation*}
```

-Ensures that:
+ensures that

```{math}
\begin{equation*}
\frac{df}{dx} = y(x).
\end{equation*}
```

-i.e. the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
+In other words, the value of the derivative at $x$ is the value $y(x)$ at which the minimum of the function $g$ is achieved.
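For this particular $h$ the claim is easy to verify in closed form (an editorial check, not part of the commit), writing $q$ for the conjugate exponent with $1/p + 1/q = 1$:

```{math}
\begin{align*}
x = h'(y(x)) = \operatorname{sign}(y(x))\,|y(x)|^{p-1}
&\;\Longrightarrow\; y(x) = \operatorname{sign}(x)\,|x|^{1/(p-1)},\\
f(x) = \langle x, y(x)\rangle - h(y(x)) = \frac{1}{q}\,|x|^{q}
&\;\Longrightarrow\; \frac{df}{dx} = \operatorname{sign}(x)\,|x|^{q-1} = y(x),
\end{align*}
```

since $q - 1 = 1/(p-1)$; the derivative of $f$ is indeed the minimizer $y(x)$.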


### Telling JAX to use a custom gradient
8 changes: 1 addition & 7 deletions docs/index.md
@@ -41,7 +41,7 @@ rng_key = jax.random.key(0)
step = jax.jit(nuts.step)
for i in range(1_000):
    nuts_key = jax.random.fold_in(rng_key, i)
-    state, _ = nuts.step(nuts_key, state)
+    state, _ = step(nuts_key, state)
```

:::{note}
@@ -57,13 +57,7 @@ If you want to use Blackjax with a model implemented with a PPL, go to the relat
```{code-block} bash
pip install blackjax
```
:::
-
-:::{tab-item} Nightly
-```{code-block} bash
-pip install blackjax-nightly
-```
-:::

:::{tab-item} Conda
```{code-block} bash
10 changes: 3 additions & 7 deletions tests/smc/test_inner_kernel_tuning.py
@@ -94,7 +94,7 @@ def smc_inner_kernel_tuning_test_case(
        proposal_factory.return_value = 100

        def mcmc_parameter_update_fn(state, info):
-            return extend_params(1000, {"mean": 100})
+            return extend_params({"mean": 100})

        prior = lambda x: stats.norm.logpdf(x)

@@ -114,7 +114,7 @@ def wrapped_kernel(rng_key, state, logdensity, mean):
            resampling_fn=resampling.systematic,
            smc_algorithm=smc_algorithm,
            mcmc_parameter_update_fn=mcmc_parameter_update_fn,
-            initial_parameter_value=extend_params(1000, {"mean": 1.0}),
+            initial_parameter_value=extend_params({"mean": 1.0}),
            **smc_parameters,
        )

@@ -281,7 +281,6 @@ def test_with_adaptive_tempered(self):

        def parameter_update(state, info):
            return extend_params(
-                100,
                {
                    "inverse_mass_matrix": mass_matrix_from_particles(state.particles),
                    "step_size": 10e-2,
@@ -298,7 +297,6 @@ def parameter_update(state, info):
            resampling.systematic,
            mcmc_parameter_update_fn=parameter_update,
            initial_parameter_value=extend_params(
-                100,
                dict(
                    inverse_mass_matrix=jnp.eye(2),
                    step_size=10e-2,
@@ -326,7 +324,7 @@ def body(carry):

        _, state = inference_loop(smc_kernel, self.key, init_state)

-        assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2)
+        assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2)
        self.assert_linear_regression_test_case(state.sampler_state)

    @chex.all_variants(with_pmap=False)
@@ -340,7 +338,6 @@ def test_with_tempered_smc(self):

        def parameter_update(state, info):
            return extend_params(
-                100,
                {
                    "inverse_mass_matrix": mass_matrix_from_particles(state.particles),
                    "step_size": 10e-2,
@@ -357,7 +354,6 @@ def parameter_update(state, info):
            resampling.systematic,
            mcmc_parameter_update_fn=parameter_update,
            initial_parameter_value=extend_params(
-                100,
                dict(
                    inverse_mass_matrix=jnp.eye(2),
                    step_size=10e-2,
10 changes: 4 additions & 6 deletions tests/smc/test_kernel_compatibility.py
@@ -50,7 +50,7 @@ def kernel(rng_key, state, logdensity_fn, proposal_mean):
        self.check_compatible(
            kernel,
            blackjax.additive_step_random_walk.init,
-            extend_params(self.n_particles, {"proposal_mean": 1.0}),
+            extend_params({"proposal_mean": 1.0}),
        )

    def test_compatible_with_rmh(self):
@@ -70,15 +70,14 @@ def kernel(
        self.check_compatible(
            kernel,
            blackjax.rmh.init,
-            extend_params(self.n_particles, {"proposal_mean": 1.0}),
+            extend_params({"proposal_mean": 1.0}),
        )

    def test_compatible_with_hmc(self):
        self.check_compatible(
            blackjax.hmc.build_kernel(),
            blackjax.hmc.init,
            extend_params(
-                self.n_particles,
                {
                    "step_size": 0.3,
                    "inverse_mass_matrix": jnp.array([1.0]),
@@ -100,15 +99,14 @@ def kernel(rng_key, state, logdensity_fn, mean, proposal_logdensity_fn=None):
        self.check_compatible(
            kernel,
            blackjax.irmh.init,
-            extend_params(self.n_particles, {"mean": jnp.array([1.0, 1.0])}),
+            extend_params({"mean": jnp.array([1.0, 1.0])}),
        )

    def test_compatible_with_nuts(self):
        self.check_compatible(
            blackjax.nuts.build_kernel(),
            blackjax.nuts.init,
            extend_params(
-                self.n_particles,
                {"step_size": 1e-10, "inverse_mass_matrix": jnp.eye(2)},
            ),
        )
@@ -117,7 +115,7 @@ def test_compatible_with_mala(self):
        self.check_compatible(
            blackjax.mala.build_kernel(),
            blackjax.mala.init,
-            extend_params(self.n_particles, {"step_size": 1e-10}),
+            extend_params({"step_size": 1e-10}),
        )

    @staticmethod