diff --git a/blackjax/adaptation/base.py b/blackjax/adaptation/base.py
index e0a01e596..c510abaf4 100644
--- a/blackjax/adaptation/base.py
+++ b/blackjax/adaptation/base.py
@@ -11,7 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import NamedTuple
+from typing import NamedTuple, Set
+
+import jax
 
 from blackjax.types import ArrayTree
 
@@ -25,3 +27,34 @@ class AdaptationInfo(NamedTuple):
     state: NamedTuple
     info: NamedTuple
     adaptation_state: NamedTuple
+
+
+def return_all_adapt_info(state, info, adaptation_state):
+    """Return fully populated AdaptationInfo. Used for adaptation_info_fn
+    parameters of the adaptation algorithms.
+    """
+    return AdaptationInfo(state, info, adaptation_state)
+
+
+def get_filter_adapt_info_fn(
+    state_keys: Set[str] = set(),
+    info_keys: Set[str] = set(),
+    adapt_state_keys: Set[str] = set(),
+):
+    """Generate a function to filter what is saved in AdaptationInfo. Used
+    for adaptation_info_fn parameters of the adaptation algorithms.
+    adaptation_info_fn=get_filter_adapt_info_fn() saves no auxiliary information.
+    """
+
+    def filter_tuple(tup, key_set):
+        mapfn = lambda key, val: None if key not in key_set else val
+        return jax.tree.map(mapfn, type(tup)(*tup._fields), tup)
+
+    def filter_fn(state, info, adaptation_state):
+        sample_state = filter_tuple(state, state_keys)
+        new_info = filter_tuple(info, info_keys)
+        new_adapt_state = filter_tuple(adaptation_state, adapt_state_keys)
+
+        return AdaptationInfo(sample_state, new_info, new_adapt_state)
+
+    return filter_fn
diff --git a/blackjax/adaptation/chees_adaptation.py b/blackjax/adaptation/chees_adaptation.py
index e81bbeef8..60b3e719f 100644
--- a/blackjax/adaptation/chees_adaptation.py
+++ b/blackjax/adaptation/chees_adaptation.py
@@ -10,7 +10,7 @@
 import blackjax.mcmc.dynamic_hmc as dynamic_hmc
 import blackjax.optimizers.dual_averaging as dual_averaging
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
 from blackjax.base import AdaptationAlgorithm
 from blackjax.types import Array, ArrayLikeTree, PRNGKey
 from blackjax.util import pytree_size
@@ -278,6 +278,7 @@
     jitter_amount: float = 1.0,
     target_acceptance_rate: float = OPTIMAL_TARGET_ACCEPTANCE_RATE,
     decay_rate: float = 0.5,
+    adaptation_info_fn: Callable = return_all_adapt_info,
 ) -> AdaptationAlgorithm:
     """Adapt the step size and trajectory length (number of integration steps
     / step size) parameters of the jittered HMC algorithm.
@@ -337,6 +338,11 @@
         Float representing how much to favor recent iterations over earlier ones
         in the optimization of step size and trajectory length. A value of 1 gives
         equal weight to all history. A value of 0 gives weight only to the most
         recent iteration.
+    adaptation_info_fn
+        Function to select the adaptation info returned. See return_all_adapt_info
+        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
+        information is saved; this can result in excessive memory usage if the
+        information is unused.
 
     Returns
     -------
@@ -411,10 +417,8 @@ def one_step(carry, rng_key):
                 info.is_divergent,
             )
 
-            return (new_states, new_adaptation_state), AdaptationInfo(
-                new_states,
-                info,
-                new_adaptation_state,
+            return (new_states, new_adaptation_state), adaptation_info_fn(
+                new_states, info, new_adaptation_state
             )
 
         batch_init = jax.vmap(
diff --git a/blackjax/adaptation/meads_adaptation.py b/blackjax/adaptation/meads_adaptation.py
index 8ed135fb5..a431a591d 100644
--- a/blackjax/adaptation/meads_adaptation.py
+++ b/blackjax/adaptation/meads_adaptation.py
@@ -17,7 +17,7 @@
 import jax.numpy as jnp
 
 import blackjax.mcmc as mcmc
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
 from blackjax.base import AdaptationAlgorithm
 from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
@@ -165,6 +165,7 @@ def update(
 
 def meads_adaptation(
     logdensity_fn: Callable,
     num_chains: int,
+    adaptation_info_fn: Callable = return_all_adapt_info,
 ) -> AdaptationAlgorithm:
     """Adapt the parameters of the Generalized HMC algorithm.
@@ -194,6 +195,11 @@
         The log density probability density function from which we wish to sample.
     num_chains
         Number of chains used for cross-chain warm-up training.
+    adaptation_info_fn
+        Function to select the adaptation info returned. See return_all_adapt_info
+        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
+        information is saved; this can result in excessive memory usage if the
+        information is unused.
 
     Returns
     -------
@@ -227,10 +233,8 @@ def one_step(carry, rng_key):
             adaptation_state, new_states.position, new_states.logdensity_grad
         )
 
-        return (new_states, new_adaptation_state), AdaptationInfo(
-            new_states,
-            info,
-            new_adaptation_state,
+        return (new_states, new_adaptation_state), adaptation_info_fn(
+            new_states, info, new_adaptation_state
         )
 
     def run(rng_key: PRNGKey, positions: ArrayLikeTree, num_steps: int = 1000):
diff --git a/blackjax/adaptation/pathfinder_adaptation.py b/blackjax/adaptation/pathfinder_adaptation.py
index efcc55741..c0b4ebc50 100644
--- a/blackjax/adaptation/pathfinder_adaptation.py
+++ b/blackjax/adaptation/pathfinder_adaptation.py
@@ -18,7 +18,7 @@
 import jax.numpy as jnp
 
 import blackjax.vi as vi
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
 from blackjax.adaptation.step_size import (
     DualAveragingAdaptationState,
     dual_averaging_adaptation,
@@ -141,6 +141,7 @@ def pathfinder_adaptation(
     logdensity_fn: Callable,
     initial_step_size: float = 1.0,
     target_acceptance_rate: float = 0.80,
+    adaptation_info_fn: Callable = return_all_adapt_info,
     **extra_parameters,
 ) -> AdaptationAlgorithm:
     """Adapt the value of the inverse mass matrix and step size parameters of
@@ -156,6 +157,11 @@
         The initial step size used in the algorithm.
     target_acceptance_rate
         The acceptance rate that we target during step size adaptation.
+    adaptation_info_fn
+        Function to select the adaptation info returned. See return_all_adapt_info
+        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
+        information is saved; this can result in excessive memory usage if the
+        information is unused.
     **extra_parameters
         The extra parameters to pass to the algorithm, e.g. the number of
         integration steps for HMC.
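
Aside: the `filter_tuple` helper added in `blackjax/adaptation/base.py` above relies on `jax.tree.map` accepting a first argument that is a structural prefix of the other trees. `type(tup)(*tup._fields)` builds a namedtuple whose leaves are the field-name strings, so the mapped function receives each field's entire subtree as `val` and can drop it wholesale. A minimal sketch of that behaviour on a toy state tuple (`ToyState` and its field values are hypothetical, not part of this patch):

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):  # hypothetical stand-in for a sampler state
    position: dict
    logdensity: jnp.ndarray


def filter_tuple(tup, key_set):
    # Same pattern as blackjax.adaptation.base: the first tree's leaves are
    # the field names, a pytree prefix of `tup`, so each field's whole
    # subtree is either kept or replaced by None in a single map.
    mapfn = lambda key, val: None if key not in key_set else val
    return jax.tree.map(mapfn, type(tup)(*tup._fields), tup)


state = ToyState(position={"x": jnp.ones(2)}, logdensity=jnp.array(-3.1))
print(filter_tuple(state, {"logdensity"}))
# ToyState(position=None, logdensity=Array(-3.1, ...)): the nested
# `position` dict is dropped as one unit.
```

Because dropped fields become `None` (an empty pytree) rather than disappearing from the tuple, the `AdaptationInfo` returned by `filter_fn` keeps a stable structure, which lets `jax.lax.scan` accumulate it without storing the pruned leaves.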
@@ -188,7 +194,7 @@ def one_step(carry, rng_key):
         )
         return (
             (new_state, new_adaptation_state),
-            AdaptationInfo(new_state, info, new_adaptation_state),
+            adaptation_info_fn(new_state, info, new_adaptation_state),
         )
 
     def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 400):
diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py
index e15121dc5..dd3e7b282 100644
--- a/blackjax/adaptation/window_adaptation.py
+++ b/blackjax/adaptation/window_adaptation.py
@@ -17,7 +17,7 @@
 import jax
 import jax.numpy as jnp
 
-from blackjax.adaptation.base import AdaptationInfo, AdaptationResults
+from blackjax.adaptation.base import AdaptationResults, return_all_adapt_info
 from blackjax.adaptation.mass_matrix import (
     MassMatrixAdaptationState,
     mass_matrix_adaptation,
@@ -248,6 +248,7 @@ def window_adaptation(
     initial_step_size: float = 1.0,
     target_acceptance_rate: float = 0.80,
     progress_bar: bool = False,
+    adaptation_info_fn: Callable = return_all_adapt_info,
     **extra_parameters,
 ) -> AdaptationAlgorithm:
     """Adapt the value of the inverse mass matrix and step size parameters of
@@ -278,6 +279,11 @@
         The acceptance rate that we target during step size adaptation.
     progress_bar
         Whether we should display a progress bar.
+    adaptation_info_fn
+        Function to select the adaptation info returned. See return_all_adapt_info
+        and get_filter_adapt_info_fn in blackjax.adaptation.base. By default all
+        information is saved; this can result in excessive memory usage if the
+        information is unused.
     **extra_parameters
         The extra parameters to pass to the algorithm, e.g. the number of
         integration steps for HMC.
@@ -316,7 +322,7 @@ def one_step(carry, xs):
 
         return (
             (new_state, new_adaptation_state),
-            AdaptationInfo(new_state, info, new_adaptation_state),
+            adaptation_info_fn(new_state, info, new_adaptation_state),
         )
 
     def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py
index 4450e61f9..68751bee8 100644
--- a/tests/adaptation/test_adaptation.py
+++ b/tests/adaptation/test_adaptation.py
@@ -6,6 +6,7 @@
 
 import blackjax
 from blackjax.adaptation import window_adaptation
+from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
 from blackjax.util import run_inference_algorithm
 
 
@@ -34,7 +35,32 @@ def test_adaptation_schedule(num_steps, expected_schedule):
     assert np.array_equal(adaptation_schedule, expected_schedule)
 
 
-def test_chees_adaptation():
+@pytest.mark.parametrize(
+    "adaptation_filters",
+    [
+        {
+            "filter_fn": return_all_adapt_info,
+            "return_sets": None,
+        },
+        {
+            "filter_fn": get_filter_adapt_info_fn(),
+            "return_sets": (set(), set(), set()),
+        },
+        {
+            "filter_fn": get_filter_adapt_info_fn(
+                {"logdensity"},
+                {"proposal"},
+                {"random_generator_arg", "step", "da_state"},
+            ),
+            "return_sets": (
+                {"logdensity"},
+                {"proposal"},
+                {"random_generator_arg", "step", "da_state"},
+            ),
+        },
+    ],
+)
+def test_chees_adaptation(adaptation_filters):
     logprob_fn = lambda x: jax.scipy.stats.norm.logpdf(
         x, loc=0.0, scale=jnp.array([1.0, 10.0])
     ).sum()
@@ -47,7 +73,10 @@
     init_key, warmup_key, inference_key = jax.random.split(jax.random.key(346), 3)
 
     warmup = blackjax.chees_adaptation(
-        logprob_fn, num_chains=num_chains, target_acceptance_rate=0.75
+        logprob_fn,
+        num_chains=num_chains,
+        target_acceptance_rate=0.75,
+        adaptation_info_fn=adaptation_filters["filter_fn"],
     )
 
     initial_positions = jax.random.normal(init_key, (num_chains, 2))
@@ -71,6 +100,24 @@
     )(chain_keys, last_states)
 
     harmonic_mean = 1.0 / jnp.mean(1.0 / infos.acceptance_rate)
+
+    def check_attrs(attribute, keyset):
+        for name, param in getattr(warmup_info, attribute)._asdict().items():
+            if name in keyset:
+                assert param is not None
+            else:
+                assert param is None
+
+    keysets = adaptation_filters["return_sets"]
+    if keysets is None:
+        keysets = (
+            warmup_info.state._fields,
+            warmup_info.info._fields,
+            warmup_info.adaptation_state._fields,
+        )
+    for i, attribute in enumerate(["state", "info", "adaptation_state"]):
+        check_attrs(attribute, keysets[i])
+
     np.testing.assert_allclose(harmonic_mean, 0.75, atol=1e-1)
     np.testing.assert_allclose(parameters["step_size"], 1.5, rtol=2e-1)
     np.testing.assert_array_less(infos.num_integration_steps.mean(), 15.0)
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index f334206e6..cccd34c98 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -13,6 +13,7 @@
 import blackjax
 import blackjax.diagnostics as diagnostics
 import blackjax.mcmc.random_walk
+from blackjax.adaptation.base import get_filter_adapt_info_fn, return_all_adapt_info
 from blackjax.util import run_inference_algorithm
 
 
@@ -56,6 +57,27 @@ def rmh_proposal_distribution(rng_key, position):
     },
 ]
 
+window_adaptation_filters = [
+    {
+        "filter_fn": return_all_adapt_info,
+        "return_sets": None,
+    },
+    {
+        "filter_fn": get_filter_adapt_info_fn(),
+        "return_sets": (set(), set(), set()),
+    },
+    {
+        "filter_fn": get_filter_adapt_info_fn(
+            {"position"}, {"is_divergent"}, {"ss_state", "inverse_mass_matrix"}
+        ),
+        "return_sets": (
+            {"position"},
+            {"is_divergent"},
+            {"ss_state", "inverse_mass_matrix"},
+        ),
+    },
+]
+
 
 class LinearRegressionTest(chex.TestCase):
     """Test sampling of a linear regression model."""
@@ -112,8 +134,14 @@
 
         return samples
 
-    @parameterized.parameters(itertools.product(regression_test_cases, [True, False]))
-    def test_window_adaptation(self, case, is_mass_matrix_diagonal):
+    @parameterized.parameters(
+        itertools.product(
+            regression_test_cases, [True, False], window_adaptation_filters
+        )
+    )
+    def test_window_adaptation(
+        self, case, is_mass_matrix_diagonal, window_adapt_config
+    ):
         """Test the HMC kernel and the Stan warmup."""
         rng_key, init_key0, init_key1 = jax.random.split(self.key, 3)
         x_data = jax.random.normal(init_key0, shape=(1000, 1))
@@ -131,15 +159,33 @@ def test_window_adaptation(self, case, is_mass_matrix_diagonal):
             logposterior_fn,
             is_mass_matrix_diagonal,
             progress_bar=True,
+            adaptation_info_fn=window_adapt_config["filter_fn"],
             **case["parameters"],
         )
-        (state, parameters), _ = warmup.run(
+        (state, parameters), info = warmup.run(
             warmup_key,
             case["initial_position"],
             case["num_warmup_steps"],
         )
         inference_algorithm = case["algorithm"](logposterior_fn, **parameters)
 
+        def check_attrs(attribute, keyset):
+            for name, param in getattr(info, attribute)._asdict().items():
+                if name in keyset:
+                    assert param is not None
+                else:
+                    assert param is None
+
+        keysets = window_adapt_config["return_sets"]
+        if keysets is None:
+            keysets = (
+                info.state._fields,
+                info.info._fields,
+                info.adaptation_state._fields,
+            )
+        for i, attribute in enumerate(["state", "info", "adaptation_state"]):
+            check_attrs(attribute, keysets[i])
+
         _, states, _ = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
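
End-to-end, the new keyword threads through each adaptation loop in place of the hard-coded `AdaptationInfo` constructor. A hedged usage sketch mirroring the tests above (the toy target, key set, and step counts are illustrative only, not part of the patch):

```python
import jax
import jax.numpy as jnp

import blackjax
from blackjax.adaptation.base import get_filter_adapt_info_fn

# Toy standard-normal target, for illustration.
logdensity_fn = lambda x: jax.scipy.stats.norm.logpdf(x).sum()

# Keep only the `position` field of the sampler state in the returned
# AdaptationInfo; every other field comes back as None.
warmup = blackjax.window_adaptation(
    blackjax.hmc,
    logdensity_fn,
    adaptation_info_fn=get_filter_adapt_info_fn(state_keys={"position"}),
    num_integration_steps=10,
)

(state, parameters), info = warmup.run(
    jax.random.key(0), jnp.zeros(2), num_steps=200
)

assert info.state.logdensity is None  # filtered out, not accumulated
assert info.state.position.shape == (200, 2)  # retained across the warmup scan
```

Since filtered leaves are `None` throughout the warmup scan, they are never materialised on device, which is the memory saving the new docstrings refer to.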