Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Waste Free SMC available for adaptive tempered and tempered SMC. #721

Merged
merged 19 commits into from
Aug 26, 2024
Merged
4 changes: 4 additions & 0 deletions blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def build_kernel(
resampling_fn: Callable,
target_ess: float,
root_solver: Callable = solver.dichotomy,
**extra_parameters,
) -> Callable:
r"""Build a Tempered SMC step using an adaptive schedule.

Expand Down Expand Up @@ -88,6 +89,7 @@ def compute_delta(state: tempered.TemperedSMCState) -> float:
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
**extra_parameters,
)

def kernel(
Expand Down Expand Up @@ -116,6 +118,7 @@ def as_top_level_api(
target_ess: float,
root_solver: Callable = solver.dichotomy,
num_mcmc_steps: int = 10,
**extra_parameters,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Adaptive Tempered SMC kernel.

Expand Down Expand Up @@ -155,6 +158,7 @@ def as_top_level_api(
resampling_fn,
target_ess,
root_solver,
**extra_parameters,
)

def init_fn(position: ArrayLikeTree, rng_key=None):
Expand Down
60 changes: 45 additions & 15 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -48,12 +48,42 @@ def init(particles: ArrayLikeTree):
return TemperedSMCState(particles, weights, 0.0)


def update_and_take_last(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
num_mcmc_steps,
n_particles,
):
"""
Given N particles, runs num_mcmc_steps of a kernel starting at each particle, and
returns the last values, waisting the previous num_mcmc_steps-1
ciguaran marked this conversation as resolved.
Show resolved Hide resolved
samples per chain.
"""

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

return jax.vmap(mcmc_kernel), n_particles


def build_kernel(
logprior_fn: Callable,
loglikelihood_fn: Callable,
mcmc_step_fn: Callable,
mcmc_init_fn: Callable,
resampling_fn: Callable,
update_strategy: Callable = update_and_take_last,
) -> Callable:
"""Build the base Tempered SMC kernel.

Expand Down Expand Up @@ -141,26 +171,23 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info
update_fn, num_resampled = update_strategy(
mcmc_init_fn,
tempered_logposterior_fn,
shared_mcmc_step_fn,
n_particles=state.weights.shape[0],
num_mcmc_steps=num_mcmc_steps,
)

smc_state, info = smc.base.step(
rng_key,
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
jax.vmap(mcmc_kernel),
update_fn,
jax.vmap(log_weights_fn),
resampling_fn,
num_resampled,
)

tempered_state = TemperedSMCState(
smc_state.particles, smc_state.weights, state.lmbda + delta
)
Expand All @@ -177,7 +204,8 @@ def as_top_level_api(
mcmc_init_fn: Callable,
mcmc_parameters: dict,
resampling_fn: Callable,
num_mcmc_steps: int = 10,
num_mcmc_steps: Optional[int] = 10,
update_strategy=update_and_take_last,
) -> SamplingAlgorithm:
"""Implements the (basic) user interface for the Adaptive Tempered SMC kernel.

Expand All @@ -204,12 +232,14 @@ def as_top_level_api(
A ``SamplingAlgorithm``.

"""

kernel = build_kernel(
logprior_fn,
loglikelihood_fn,
mcmc_step_fn,
mcmc_init_fn,
resampling_fn,
update_strategy,
)

def init_fn(position: ArrayLikeTree, rng_key=None):
Expand Down
70 changes: 70 additions & 0 deletions blackjax/smc/waste_free.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import functools

import jax
import jax.lax
import jax.numpy as jnp


def update_waste_free(
mcmc_init_fn,
logposterior_fn,
mcmc_step_fn,
n_particles: int,
p: int,
num_resampled,
num_mcmc_steps=None,
):
"""
Given M particles, mutates them using p-1 steps. Returns M*P-1 particles,
consistent of the initial plus all the intermediate steps, thus implementing a
waste-free update function
See Algorithm 2: https://arxiv.org/abs/2011.02328
"""
if num_mcmc_steps is not None:
raise ValueError(
"Can't use waste free SMC with a num_mcmc_steps parameter, set num_mcmc_steps = None"
)

num_mcmc_steps = p - 1

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, logposterior_fn)

def body_fn(state, rng_key):
new_state, info = mcmc_step_fn(
rng_key, state, logposterior_fn, **step_parameters
)
return new_state, (new_state, info)

_, (states, infos) = jax.lax.scan(
body_fn, state, jax.random.split(rng_key, num_mcmc_steps)
)
return states, infos

def update(rng_key, position, step_parameters):
"""
Given the initial particles, runs a chain starting at each.
The combines the initial particles with all the particles generated
at each step of each chain.
"""
states, infos = jax.vmap(mcmc_kernel)(rng_key, position, step_parameters)

# step particles is num_resmapled, num_mcmc_steps, dimension_of_variable
# want to transformed into num_resampled * num_mcmc_steps, dimension of variable
def reshape_step_particles(x):
_num_resampled, num_mcmc_steps, *dimension_of_variable = x.shape
return x.reshape((_num_resampled * num_mcmc_steps, *dimension_of_variable))

step_particles = jax.tree.map(reshape_step_particles, states.position)
new_particles = jax.tree.map(
lambda x, y: jnp.concatenate([x, y]), position, step_particles
)
return new_particles, infos

return update, num_resampled


def waste_free_smc(n_particles, p):
if not n_particles % p == 0:
raise ValueError("p must be a divider of n_particles ")
return functools.partial(update_waste_free, num_resampled=int(n_particles / p), p=p)
94 changes: 41 additions & 53 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Test the generic SMC sampler"""
import functools

import chex
import jax
import jax.numpy as jnp
Expand All @@ -9,6 +11,8 @@
import blackjax
import blackjax.smc.resampling as resampling
from blackjax.smc.base import extend_params, init, step
from blackjax.smc.tempered import update_and_take_last
from blackjax.smc.waste_free import update_waste_free


def logdensity_fn(position):
Expand All @@ -29,82 +33,66 @@ def setUp(self):
@chex.variants(with_jit=True)
def test_smc(self):
num_mcmc_steps = 20
num_particles = 1000

def update_fn(rng_key, position, update_params):
hmc = blackjax.hmc(logdensity_fn, **update_params)
state = hmc.init(position)

def body_fn(state, rng_key):
new_state, info = hmc.step(rng_key, state)
return new_state, info

keys = jax.random.split(rng_key, num_mcmc_steps)
last_state, info = jax.lax.scan(body_fn, state, keys)
return last_state.position, info

init_key, sample_key = jax.random.split(self.key)
num_particles = 5000

# Initialize the state of the SMC sampler
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
same_for_all_params = dict(
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
)
hmc_kernel = functools.partial(
blackjax.hmc.build_kernel(), **same_for_all_params
)
hmc_init = blackjax.hmc.init

state = init(
init_particles,
same_for_all_params,
update_fn, _ = update_and_take_last(
hmc_init, logdensity_fn, hmc_kernel, num_mcmc_steps, num_particles
)
init_key, sample_key = jax.random.split(self.key)

# Initialize the state of the SMC sampler
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
state = init(init_particles, {})
# Run the SMC sampler once
new_state, info = self.variant(step, static_argnums=(2, 3, 4))(
sample_key,
state,
jax.vmap(update_fn, in_axes=(0, 0, None)),
update_fn,
jax.vmap(logdensity_fn),
resampling.systematic,
)

assert new_state.particles.shape == (num_particles,)
mean, std = _weighted_avg_and_std(new_state.particles, state.weights)
np.testing.assert_allclose(0.0, mean, atol=1e-1)
np.testing.assert_allclose(1.0, std, atol=1e-1)
np.testing.assert_allclose(mean, 0.0, atol=1e-1)
np.testing.assert_allclose(std, 1.0, atol=1e-1)

@chex.variants(with_jit=True)
def test_smc_waste_free(self):
num_mcmc_steps = 10
p = 500
num_particles = 1000
num_resampled = num_particles // num_mcmc_steps

def waste_free_update_fn(keys, particles, update_params):
def one_particle_fn(rng_key, position, particle_update_params):
hmc = blackjax.hmc(logdensity_fn, **particle_update_params)
state = hmc.init(position)

def body_fn(state, rng_key):
new_state, info = hmc.step(rng_key, state)
return new_state, (state, info)

keys = jax.random.split(rng_key, num_mcmc_steps)
_, (states, info) = jax.lax.scan(body_fn, state, keys)
return states.position, info

particles, info = jax.vmap(one_particle_fn, in_axes=(0, 0, None))(
keys, particles, update_params
)
particles = particles.reshape((num_particles,))
return particles, info

num_resampled = num_particles // p
init_key, sample_key = jax.random.split(self.key)

# Initialize the state of the SMC sampler
init_particles = 0.25 + jax.random.normal(init_key, shape=(num_particles,))
state = init(
init_particles,
dict(
step_size=1e-2,
inverse_mass_matrix=jnp.eye(1),
num_integration_steps=100,
),
{},
)
same_for_all_params = dict(
step_size=1e-2, inverse_mass_matrix=jnp.eye(1), num_integration_steps=50
)
hmc_kernel = functools.partial(
blackjax.hmc.build_kernel(), **same_for_all_params
)
hmc_init = blackjax.hmc.init

waste_free_update_fn, _ = update_waste_free(
hmc_init,
logdensity_fn,
hmc_kernel,
num_particles,
p=p,
num_resampled=num_resampled,
)

# Run the SMC sampler once
Expand All @@ -116,10 +104,10 @@ def body_fn(state, rng_key):
resampling.systematic,
num_resampled,
)

assert new_state.particles.shape == (num_particles,)
mean, std = _weighted_avg_and_std(new_state.particles, state.weights)
np.testing.assert_allclose(0.0, mean, atol=1e-1)
np.testing.assert_allclose(1.0, std, atol=1e-1)
np.testing.assert_allclose(mean, 0.0, atol=1e-1)
np.testing.assert_allclose(std, 1.0, atol=1e-1)


class ExtendParamsTest(chex.TestCase):
Expand Down
Loading
Loading