Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable shared mcmc parameters with tempered smc #694

Merged
merged 3 commits into from
Jun 15, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion blackjax/smc/adaptive_tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def as_top_level_api(
mcmc_init_fn
The MCMC init function used to build a MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function.
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
resampling_fn
The function used to resample the particles.
target_ess
Expand Down
21 changes: 18 additions & 3 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple

import jax
Expand Down Expand Up @@ -108,6 +109,9 @@ def kernel(
Current state of the tempered SMC algorithm
lmbda
Current value of the tempering parameter
mcmc_parameters
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.

Returns
-------
Expand All @@ -119,6 +123,14 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {}
unshared_mcmc_parameters = {}
for k, v in mcmc_parameters.items():
if v.shape[0] == 1:
shared_mcmc_parameters[k] = v[0, ...]
else:
unshared_mcmc_parameters[k] = v

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)

Expand All @@ -127,11 +139,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = mcmc_step_fn(
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info
Expand All @@ -142,7 +156,7 @@ def body_fn(state, rng_key):

smc_state, info = smc.base.step(
rng_key,
SMCState(state.particles, state.weights, mcmc_parameters),
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
jax.vmap(mcmc_kernel),
jax.vmap(log_weights_fn),
resampling_fn,
Expand Down Expand Up @@ -178,7 +192,8 @@ def as_top_level_api(
mcmc_init_fn
The MCMC init function used to build a MCMC state from a particle position.
mcmc_parameters
The parameters of the MCMC step function.
The parameters of the MCMC step function. Parameters with leading dimension
length of 1 are shared amongst the particles.
resampling_fn
The function used to resample the particles.
num_mcmc_steps
Expand Down
10 changes: 5 additions & 5 deletions tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_with_adaptive_tempered(self):

def parameter_update(state, info):
return extend_params(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

calling extend params shouldn't be needed anymore

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still using extend_params since it needs to convert to an array and add a leading dimension of length one. I could have added a separate helper but it'd do the same thing as extend_params but just without the jnp.repeat part

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is worth to modify extend_params into something simple without the jnp.repeat.

Would this work?

def extend_params(params: ArrayLikeTree):
    return jax.tree.map(lambda a: jnp.expand_dims(a, 0), params)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would work, perhaps with potentially a jnp.asarray() included. The concern I have with modifying extend_params is I'd assume there's still a use case for it with initializing unshared parameters?

A current use case is for the remaining tests that are using duplicated parameters for testing the unshared case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ciguaran Is there any other use case besides dimension matching initially?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's none.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd assume there's still a use case for it with initializing unshared parameters -> that is a good concern although I think that in most cases we wouldn't initialize by "extending" with copies. We would do it by sampling from some probability distribution. So you can modify it as Junpeng suggested.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, that makes sense. thanks

100,
1,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
Expand All @@ -298,7 +298,7 @@ def parameter_update(state, info):
resampling.systematic,
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
1,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
Expand Down Expand Up @@ -326,7 +326,7 @@ def body(carry):

_, state = inference_loop(smc_kernel, self.key, init_state)

assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2)
assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2)
self.assert_linear_regression_test_case(state.sampler_state)

@chex.all_variants(with_pmap=False)
Expand All @@ -340,7 +340,7 @@ def test_with_tempered_smc(self):

def parameter_update(state, info):
return extend_params(
100,
1,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
Expand All @@ -357,7 +357,7 @@ def parameter_update(state, info):
resampling.systematic,
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
1,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
Expand Down
27 changes: 18 additions & 9 deletions tests/smc/test_tempered_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,25 @@ def logprior_fn(x):

hmc_kernel = blackjax.hmc.build_kernel()
hmc_init = blackjax.hmc.init
hmc_parameters = extend_params(
num_particles,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
"num_integration_steps": 50,
},
hmc_parameters_list = [
extend_params(
num_particles if extend else 1,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a comment highlighting that here you are testing that extending with a copy is the same as having a 1 dimension parameter. Otherwise it may be tricky to get why this is even hapening.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
"num_integration_steps": 50,
},
)
for extend in [True, False]
]
hmc_parameters_list.append(
extend_params(
num_particles, {"step_size": 10e-2, "num_integration_steps": 50}
)
| extend_params(num_particles, {"inverse_mass_matrix": jnp.eye(2)})
)

for target_ess in [0.5, 0.75]:
for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list):
tempering = adaptive_tempered_smc(
logprior_fn,
loglikelihood_fn,
Expand Down Expand Up @@ -115,7 +124,7 @@ def test_fixed_schedule_tempered_smc(self):
hmc_init = blackjax.hmc.init
hmc_kernel = blackjax.hmc.build_kernel()
hmc_parameters = extend_params(
100,
1,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
Expand Down
Loading