From b34b4ccdb1fd641617604ef6e8ebb03ce32854a4 Mon Sep 17 00:00:00 2001 From: andrewdipper Date: Sun, 6 Oct 2024 22:09:36 -0600 Subject: [PATCH] switch to blackjax run_inference_algorithm --- pymc/sampling/jax.py | 79 +++++++++++++++------------- tests/sampling/test_mcmc_external.py | 57 +------------------- 2 files changed, 45 insertions(+), 91 deletions(-) diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index 516b631a5c..c60fcac68c 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -278,27 +278,22 @@ def map_fn(x): assert draws % num_chunks == 0 nsteps = draws // num_chunks - # Run adaptation - adapt = blackjax.window_adaptation( - algorithm=algorithm, - logdensity_fn=logprob_fn, - target_acceptance_rate=target_accept, - adaptation_info_fn=get_filter_adapt_info_fn(), - progress_bar=progressbar, - **nuts_kwargs, - ) - + # Run adaptation for sampling parameters @map_fn def run_adaptation(seed, init_position): - return adapt.run(seed, init_position, num_steps=tune) - - (last_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) - - def _one_step(state, x, kernel): - del x - state, rng_key = state - key, _skey = jax.random.split(rng_key) - state, info = kernel(_skey, state) + return blackjax.window_adaptation( + algorithm=algorithm, + logdensity_fn=logprob_fn, + target_acceptance_rate=target_accept, + adaptation_info_fn=get_filter_adapt_info_fn(), + progress_bar=progressbar, + **nuts_kwargs, + ).run(seed, init_position, num_steps=tune) + + (adapt_state, tuned_params), _ = run_adaptation(adapt_seed, initial_points) + + # Filters output from each sampling step + def _transform_fn(state, info): position = state.position stats = { "diverging": info.is_divergent, @@ -308,42 +303,54 @@ def _one_step(state, x, kernel): "acceptance_rate": info.acceptance_rate, "lp": state.logdensity, } - return (state, key), (position, stats) + return position, stats + # Performs sampling for each chunk + # random keys are carried with state @map_fn @partial(jax.jit, donate_argnums=0) def _multi_step(state, imm, ss): - start_state, key = state - scan_fn = blackjax.progress_bar.gen_scan_fn(nsteps, progressbar) - - kernel = algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss).step - - (last_state, key), (raw_samples, stats) = scan_fn( - partial(_one_step, kernel=kernel), (start_state, key), jnp.arange(nsteps) + state, key = state + key, _skey = jax.random.split(key) + last_state, (raw_samples, stats) = blackjax.util.run_inference_algorithm( + _skey, + algorithm(logprob_fn, inverse_mass_matrix=imm, step_size=ss), + num_steps=nsteps, + initial_state=state, + progress_bar=progressbar, + transform=_transform_fn, ) samples, log_likelihoods = postprocess_fn(raw_samples) return (last_state, key), ((samples, log_likelihoods), stats) - sample_fn = partial( + chunk_sample_fn = partial( _multi_step, imm=tuned_params["inverse_mass_matrix"], ss=tuned_params["step_size"] ) + if progressbar: logger.info("Sampling chunk %d of %d:" % (1, num_chunks)) - (last_state, seed), (samples, stats) = sample_fn((last_state, sample_seed)) + + # Sample first chunk + last_state, sample_data = chunk_sample_fn((adapt_state, sample_seed)) + + # If single chunk sampling return results on device if num_chunks == 1: - return samples[0], stats, samples[1], blackjax + ((samples, log_likelihoods), stats) = sample_data + return samples, stats, log_likelihoods, blackjax + # Provision space for all samples on the cpu + save first chunk output = _set_tree( - jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), (samples, stats)), - jax.device_put((samples, stats), jax.devices("cpu")[0]), + jax.tree.map(jax.vmap(partial(_gen_arr, nchunk=num_chunks)), sample_data), + jax.device_put(sample_data, jax.devices("cpu")[0]), 0, ) - del samples, stats + del sample_data - last_state, (all_samples, all_stats) = _do_chunked_sampling( - (last_state, seed), output, num_chunks, nsteps, sample_fn, progressbar + # Sample remaining chunks + _, ((samples, log_likelihoods), stats) = _do_chunked_sampling( + last_state, output, num_chunks, nsteps, chunk_sample_fn, progressbar ) - return all_samples[0], all_stats, all_samples[1], blackjax + return samples, stats, log_likelihoods, blackjax def _numpyro_stats_to_dict(posterior): diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index 654e3c70ce..d1f4544f4c 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -75,9 +75,9 @@ def test_external_nuts_sampler(recwarn, nuts_sampler): assert idata_reference.posterior.attrs.keys() == idata1.posterior.attrs.keys() -@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) -def test_external_nuts_chunking(nuts_sampler): +def test_numpyro_external_nuts_chunking(): # chunked sampling should give exact same results as non-chunked + nuts_sampler = "numpyro" pytest.importorskip(nuts_sampler) with Model(): @@ -104,56 +104,3 @@ def test_external_nuts_chunking(nuts_sampler): np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L) assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys() - - -def test_step_args(): - with Model() as model: - a = Normal("a") - idata = sample( - nuts_sampler="numpyro", - target_accept=0.5, - nuts={"max_treedepth": 10}, - random_seed=1410, - ) - - npt.assert_almost_equal(idata.sample_stats.acceptance_rate.mean(), 0.5, decimal=1) - - -@pytest.mark.skipif(jax.default_backend() == "cpu", reason="need default backend that is not cpu") -@pytest.mark.parametrize("nuts_sampler", ["blackjax", "numpyro"]) -def test_postprocessing_backend(nuts_sampler): - pytest.importorskip(nuts_sampler) - default_backend = jax.default_backend() - - with Model(): - x = Normal("x", 100, 5) - y = Data("y", [1, 2, 3, 4]) - - Normal("L", mu=x, sigma=0.1, observed=y) - - base_kwargs = dict( - nuts_sampler=nuts_sampler, - random_seed=123, - chains=4, - tune=200, - draws=200, - progressbar=False, - initvals={"x": 0.0}, - idata_kwargs={"log_likelihood": True}, - ) - - idata1 = sample( - **base_kwargs, - nuts_sampler_kwargs={ - "postprocessing_backend": default_backend, - "chain_method": "vectorized", - }, - ) - idata2 = sample( - **base_kwargs, - nuts_sampler_kwargs={"postprocessing_backend": "cpu", "chain_method": "vectorized"}, - ) - - np.testing.assert_array_equal(idata1.posterior.x, idata2.posterior.x) - np.testing.assert_array_equal(idata1.log_likelihood.L, idata2.log_likelihood.L) - assert idata1.posterior.attrs.keys() == idata2.posterior.attrs.keys()