Skip to content

Commit

Permalink
Waste Free SMC available for adaptive tempered and tempered SMC. (#721)
Browse files Browse the repository at this point in the history
* extracting taking last

* test passing

* layering

* example

* more

* Adding another example

* tests in place

* rolling back changes

* Adding test for num_mcmc_steps

* format

* better test coverage

* linter

* Flake8

* black

* Update blackjax/smc/waste_free.py

Co-authored-by: Junpeng Lao <[email protected]>

* fixing linter

---------

Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
ciguaran and junpenglao authored Aug 26, 2024
1 parent 072cc81 commit b02b60b
Show file tree
Hide file tree
Showing 5 changed files with 323 additions and 68 deletions.
4 changes: 4 additions & 0 deletions blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def build_kernel(
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = solver.dichotomy,
**extra_parameters,
) -> Callable:
r"""Build a Tempered SMC step using an adaptive schedule.
Expand Down Expand Up @@ -88,6 +89,7 @@ def compute_delta(state: tempered.TemperedSMCState) -> float:
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
**extra_parameters,
)

def kernel(
Expand Down Expand Up @@ -116,6 +118,7 @@ def as_top_level_api(
target_ess: float,
root_solver: Callable = solver.dichotomy,
num_mcmc_steps: int = 10,
**extra_parameters,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Adaptive Tempered SMC kernel.
Expand Down Expand Up @@ -155,6 +158,7 @@ def as_top_level_api(
resampling_fn,
target_ess,
root_solver,
**extra_parameters,
)

def init_fn(position: ArrayLikeTree, rng_key=None):
Expand Down
60 changes: 45 additions & 15 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -48,12 +48,42 @@ def init(particles: ArrayLikeTree):
return TemperedSMCState(particles, weights, 0.0)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""
Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
update_strategy: Callable = update_and_take_last,
) -> Callable:
"""Build the base Tempered SMC kernel.
Expand Down Expand Up @@ -141,26 +171,23 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info
update_fn, num_resampled = update_strategy(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

smc_state, info = smc.base.step(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
jax.vmap(mcmc_kernel),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
)

tempered_state = TemperedSMCState(
smc_state.particles, smc_state.weights, state.lmbda + delta
)
Expand All @@ -177,7 +204,8 @@ def as_top_level_api(
mcmc_init_fn: Callable,
mcmc_parameters: dict,
resampling_fn: Callable,
num_mcmc_steps: int = 10,
num_mcmc_steps: Optional[int] = 10,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Adaptive Tempered SMC kernel.
Expand All @@ -204,12 +232,14 @@ def as_top_level_api(
A ``SamplingAlgorithm``.
"""

kernel = build_kernel(
logprior_fn,
loglikelihood_fn,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
update_strategy,
)

def init_fn(position: ArrayLikeTree, rng_key=None):
Expand Down
70 changes: 70 additions & 0 deletions blackjax/smc/waste_free.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import functools

import jax
import jax.lax
import jax.numpy as jnp


def update_waste_free(
mcmc_init_fn,
logposterior_fn,
mcmc_step_fn,
n_particles: int,
p: int,
num_resampled,
num_mcmc_steps=None,
):
"""
Given M particles, mutates them using p-1 steps. Returns M*P-1 particles,
consistent of the initial plus all the intermediate steps, thus implementing a
waste-free update function
See Algorithm 2: https://arxiv.org/abs/2011.02328
"""
if num_mcmc_steps is not None:
raise ValueError(
"Can't use waste free SMC with a num_mcmc_steps parameter, set num_mcmc_steps = None"
)

num_mcmc_steps = p - 1

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, logposterior_fn)

def body_fn(state, rng_key):
new_state, info = mcmc_step_fn(
rng_key, state, logposterior_fn, **step_parameters
)
return new_state, (new_state, info)

_, (states, infos) = jax.lax.scan(
body_fn, state, jax.random.split(rng_key, num_mcmc_steps)
)
return states, infos

def update(rng_key, position, step_parameters):
"""
Given the initial particles, runs a chain starting at each.
The combines the initial particles with all the particles generated
at each step of each chain.
"""
states, infos = jax.vmap(mcmc_kernel)(rng_key, position, step_parameters)

# step particles is num_resmapled, num_mcmc_steps, dimension_of_variable
# want to transformed into num_resampled * num_mcmc_steps, dimension of variable
def reshape_step_particles(x):
_num_resampled, num_mcmc_steps, *dimension_of_variable = x.shape
return x.reshape((_num_resampled * num_mcmc_steps, *dimension_of_variable))

step_particles = jax.tree.map(reshape_step_particles, states.position)
new_particles = jax.tree.map(
lambda x, y: jnp.concatenate([x, y]), position, step_particles
)
return new_particles, infos

return update, num_resampled


def waste_free_smc(n_particles, p):
if not n_particles % p == 0:
raise ValueError("p must be a divider of n_particles ")
return functools.partial(update_waste_free, num_resampled=int(n_particles / p), p=p)
94 changes: 41 additions & 53 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Test the generic SMC sampler"""
import functools

import chex
import jax
import jax.numpy as jnp
Expand All @@ -9,6 +11,8 @@
import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc.base import extend_params, init, step
from blackjax.smc.tempered import update_and_take_last
from blackjax.smc.waste_free import update_waste_free


def logdensity_fn(position):
Expand All @@ -29,82 +33,66 @@ def setUp(self):
@chex.variants(with_jit=True)
def test_smc(self):
num_mcmc_steps = 20
num_particles = 1000

def update_fn(rng_key, position, update_params):
hmc = blackjax.hmc(logdensity_fn, **update_params)
state = hmc.init(position)

def body_fn(state, rng_key):
new_state, info = hmc.step(rng_key, state)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

init_key, sample_key = jax.random.split(self.key)
num_particles = 5000

# Initialize the state of the SMC sampler
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
same_for_all_params = dict(
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
)
hmc_kernel = functools.partial(
blackjax.hmc.build_kernel(), **same_for_all_params
)
hmc_init = blackjax.hmc.init

state = init(
init_particles,
same_for_all_params,
update_fn, _ = update_and_take_last(
hmc_init, logdensity_fn, hmc_kernel, num_mcmc_steps, num_particles
)
init_key, sample_key = jax.random.split(self.key)

# Initialize the state of the SMC sampler
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
state = init(init_particles, {})
# Run the SMC sampler once
new_state, info = self.variant(step, static_argnums=(2, 3, 4))(
sample_key,
state,
jax.vmap(update_fn, in_axes=(0, 0, None)),
update_fn,
jax.vmap(logdensity_fn),
resampling.systematic,
)

assert new_state.particles.shape == (num_particles,)
mean, std = _weighted_avg_and_std(new_state.particles, state.weights)
np.testing.assert_allclose(0.0, mean, atol=1e-1)
np.testing.assert_allclose(1.0, std, atol=1e-1)
np.testing.assert_allclose(mean, 0.0, atol=1e-1)
np.testing.assert_allclose(std, 1.0, atol=1e-1)

@chex.variants(with_jit=True)
def test_smc_waste_free(self):
num_mcmc_steps = 10
p = 500
num_particles = 1000
num_resampled = num_particles // num_mcmc_steps

def waste_free_update_fn(keys, particles, update_params):
def one_particle_fn(rng_key, position, particle_update_params):
hmc = blackjax.hmc(logdensity_fn, **particle_update_params)
state = hmc.init(position)

def body_fn(state, rng_key):
new_state, info = hmc.step(rng_key, state)
return new_state, (state, info)

keys = jax.random.split(rng_key, num_mcmc_steps)
_, (states, info) = jax.lax.scan(body_fn, state, keys)
return states.position, info

particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))(
keys, particles, update_params
)
particles = particles.reshape((num_particles,))
return particles, info

num_resampled = num_particles // p
init_key, sample_key = jax.random.split(self.key)

# Initialize the state of the SMC sampler
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
state = init(
init_particles,
dict(
step_size=1e-2,
inverse_mass_matrix=jnp.eye(1),
num_integration_steps=100,
),
{},
)
same_for_all_params = dict(
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
)
hmc_kernel = functools.partial(
blackjax.hmc.build_kernel(), **same_for_all_params
)
hmc_init = blackjax.hmc.init

waste_free_update_fn, _ = update_waste_free(
hmc_init,
logdensity_fn,
hmc_kernel,
num_particles,
p=p,
num_resampled=num_resampled,
)

# Run the SMC sampler once
Expand All @@ -116,10 +104,10 @@ def body_fn(state, rng_key):
resampling.systematic,
num_resampled,
)

assert new_state.particles.shape == (num_particles,)
mean, std = _weighted_avg_and_std(new_state.particles, state.weights)
np.testing.assert_allclose(0.0, mean, atol=1e-1)
np.testing.assert_allclose(1.0, std, atol=1e-1)
np.testing.assert_allclose(mean, 0.0, atol=1e-1)
np.testing.assert_allclose(std, 1.0, atol=1e-1)


class ExtendParamsTest(chex.TestCase):
Expand Down
Loading

0 comments on commit b02b60b

Please sign in to comment.