Skip to content

Commit

Permalink
add parameter filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdipper committed Jun 13, 2024
1 parent dd9ba03 commit 7a8d3ce
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 16 deletions.
14 changes: 12 additions & 2 deletions blackjax/smc/tempered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, NamedTuple

import jax
Expand Down Expand Up @@ -119,6 +120,13 @@ def kernel(
"""
delta = lmbda - state.lmbda

shared_mcmc_parameters = {
k: v[0, ...] for k, v in mcmc_parameters.items() if v.shape[0] == 1
}
unshared_mcmc_parameters = {
k: v for k, v in mcmc_parameters.items() if v.shape[0] != 1
}

def log_weights_fn(position: ArrayLikeTree) -> float:
return delta * loglikelihood_fn(position)

Expand All @@ -127,11 +135,13 @@ def tempered_logposterior_fn(position: ArrayLikeTree) -> float:
tempered_loglikelihood = state.lmbda * loglikelihood_fn(position)
return logprior + tempered_loglikelihood

shared_mcmc_step_fn = partial(mcmc_step_fn, **shared_mcmc_parameters)

def mcmc_kernel(rng_key, position, step_parameters):
state = mcmc_init_fn(position, tempered_logposterior_fn)

def body_fn(state, rng_key):
new_state, info = mcmc_step_fn(
new_state, info = shared_mcmc_step_fn(
rng_key, state, tempered_logposterior_fn, **step_parameters
)
return new_state, info
Expand All @@ -142,7 +152,7 @@ def body_fn(state, rng_key):

smc_state, info = smc.base.step(
rng_key,
SMCState(state.particles, state.weights, mcmc_parameters),
SMCState(state.particles, state.weights, unshared_mcmc_parameters),
jax.vmap(mcmc_kernel),
jax.vmap(log_weights_fn),
resampling_fn,
Expand Down
10 changes: 5 additions & 5 deletions tests/smc/test_inner_kernel_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def test_with_adaptive_tempered(self):

def parameter_update(state, info):
return extend_params(
100,
1,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
Expand All @@ -298,7 +298,7 @@ def parameter_update(state, info):
resampling.systematic,
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
1,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
Expand Down Expand Up @@ -326,7 +326,7 @@ def body(carry):

_, state = inference_loop(smc_kernel, self.key, init_state)

assert state.parameter_override["inverse_mass_matrix"].shape == (100, 2, 2)
assert state.parameter_override["inverse_mass_matrix"].shape == (1, 2, 2)
self.assert_linear_regression_test_case(state.sampler_state)

@chex.all_variants(with_pmap=False)
Expand All @@ -340,7 +340,7 @@ def test_with_tempered_smc(self):

def parameter_update(state, info):
return extend_params(
100,
1,
{
"inverse_mass_matrix": mass_matrix_from_particles(state.particles),
"step_size": 10e-2,
Expand All @@ -357,7 +357,7 @@ def parameter_update(state, info):
resampling.systematic,
mcmc_parameter_update_fn=parameter_update,
initial_parameter_value=extend_params(
100,
1,
dict(
inverse_mass_matrix=jnp.eye(2),
step_size=10e-2,
Expand Down
27 changes: 18 additions & 9 deletions tests/smc/test_tempered_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,25 @@ def logprior_fn(x):

hmc_kernel = blackjax.hmc.build_kernel()
hmc_init = blackjax.hmc.init
hmc_parameters = extend_params(
num_particles,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
"num_integration_steps": 50,
},
hmc_parameters_list = [
extend_params(
num_particles if extend else 1,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
"num_integration_steps": 50,
},
)
for extend in [True, False]
]
hmc_parameters_list.append(
extend_params(
num_particles, {"step_size": 10e-2, "num_integration_steps": 50}
)
| extend_params(num_particles, {"inverse_mass_matrix": jnp.eye(2)})
)

for target_ess in [0.5, 0.75]:
for target_ess, hmc_parameters in zip([0.5, 0.5, 0.75], hmc_parameters_list):
tempering = adaptive_tempered_smc(
logprior_fn,
loglikelihood_fn,
Expand Down Expand Up @@ -115,7 +124,7 @@ def test_fixed_schedule_tempered_smc(self):
hmc_init = blackjax.hmc.init
hmc_kernel = blackjax.hmc.build_kernel()
hmc_parameters = extend_params(
100,
1,
{
"step_size": 10e-2,
"inverse_mass_matrix": jnp.eye(2),
Expand Down

0 comments on commit 7a8d3ce

Please sign in to comment.