diff --git a/README.md b/README.md
index a8d847cf9..06d5b46cf 100644
--- a/README.md
+++ b/README.md
@@ -75,9 +75,10 @@ state = nuts.init(initial_position)
 
 # Iterate
 rng_key = jax.random.key(0)
-for step in range(100):
-    nuts_key = jax.random.fold_in(rng_key, step)
-    state, _ = nuts.step(nuts_key, state)
+step = jax.jit(nuts.step)
+for i in range(100):
+    nuts_key = jax.random.fold_in(rng_key, i)
+    state, _ = step(nuts_key, state)
 ```
 
 See [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.
diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index 76a016242..7645a890b 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree
 
 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size, streaming_average_update
+from blackjax.util import incremental_value_update, pytree_size
 
 
 class MCLMCAdaptationState(NamedTuple):
@@ -199,9 +199,9 @@ def step(iteration_state, weight_and_key):
         x = ravel_pytree(state.position)[0]
 
         # update the running average of x, x^2
-        streaming_avg = streaming_average_update(
-            current_value=jnp.array([x, jnp.square(x)]),
-            previous_weight_and_average=streaming_avg,
+        streaming_avg = incremental_value_update(
+            expectation=jnp.array([x, jnp.square(x)]),
+            incremental_val=streaming_avg,
             weight=(1 - mask) * success * params.step_size,
             zero_prevention=mask,
         )
diff --git a/blackjax/adaptation/window_adaptation.py b/blackjax/adaptation/window_adaptation.py
index 63c54bad0..69a098325 100644
--- a/blackjax/adaptation/window_adaptation.py
+++ b/blackjax/adaptation/window_adaptation.py
@@ -28,7 +28,7 @@
     dual_averaging_adaptation,
 )
 from blackjax.base import AdaptationAlgorithm
-from blackjax.progress_bar import progress_bar_scan
+from blackjax.progress_bar import gen_scan_fn
 from blackjax.types import Array, ArrayLikeTree, PRNGKey
 from blackjax.util import pytree_size
 
@@ -333,17 +333,16 @@ def run(rng_key: PRNGKey, position: ArrayLikeTree, num_steps: int = 1000):
 
         if progress_bar:
             print("Running window adaptation")
-            one_step_ = jax.jit(progress_bar_scan(num_steps)(one_step))
-        else:
-            one_step_ = jax.jit(one_step)
-
+        scan_fn = gen_scan_fn(num_steps, progress_bar=progress_bar)
+        start_state = (init_state, init_adaptation_state)
         keys = jax.random.split(rng_key, num_steps)
         schedule = build_schedule(num_steps)
-        last_state, info = jax.lax.scan(
-            one_step_,
-            (init_state, init_adaptation_state),
+        last_state, info = scan_fn(
+            one_step,
+            start_state,
             (jnp.arange(num_steps), keys, schedule),
         )
+
         last_chain_state, last_warmup_state, *_ = last_state
 
         step_size, inverse_mass_matrix = adapt_final(last_warmup_state)
diff --git a/blackjax/progress_bar.py b/blackjax/progress_bar.py
index ac509b9b6..a1425df88 100644
--- a/blackjax/progress_bar.py
+++ b/blackjax/progress_bar.py
@@ -14,14 +14,19 @@
 """Progress bar decorators for use with step functions.
 Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
""" +from threading import Lock + from fastprogress.fastprogress import progress_bar from jax import lax from jax.experimental import io_callback +from jax.numpy import array def progress_bar_scan(num_samples, print_rate=None): "Progress bar for a JAX scan" progress_bars = {} + idx_counter = 0 + lock = Lock() if print_rate is None: if num_samples > 20: @@ -29,41 +34,44 @@ def progress_bar_scan(num_samples, print_rate=None): else: print_rate = 1 # if you run the sampler for less than 20 iterations - def _define_bar(arg): - del arg - progress_bars[0] = progress_bar(range(num_samples)) - progress_bars[0].update(0) + def _calc_chain_idx(iter_num): + nonlocal idx_counter + with lock: + idx = idx_counter + idx_counter += 1 + return idx + + def _update_bar(arg, chain_id): + chain_id = int(chain_id) + if arg == 0: + chain_id = _calc_chain_idx(arg) + progress_bars[chain_id] = progress_bar(range(num_samples)) + progress_bars[chain_id].update(0) - def _update_bar(arg): - progress_bars[0].update_bar(arg + 1) + progress_bars[chain_id].update_bar(arg + 1) + return chain_id - def _close_bar(arg): - del arg - progress_bars[0].on_iter_end() + def _close_bar(arg, chain_id): + progress_bars[int(chain_id)].on_iter_end() - def _update_progress_bar(iter_num): + def _update_progress_bar(iter_num, chain_id): "Updates progress bar of a JAX scan or loop" - _ = lax.cond( - iter_num == 0, - lambda _: io_callback(_define_bar, None, iter_num), - lambda _: None, - operand=None, - ) - _ = lax.cond( + chain_id = lax.cond( # update every multiple of `print_rate` except at the end (iter_num % print_rate == 0) | (iter_num == (num_samples - 1)), - lambda _: io_callback(_update_bar, None, iter_num), - lambda _: None, + lambda _: io_callback(_update_bar, array(0), iter_num, chain_id), + lambda _: chain_id, operand=None, ) _ = lax.cond( iter_num == num_samples - 1, - lambda _: io_callback(_close_bar, None, None), + lambda _: io_callback(_close_bar, None, iter_num + 1, chain_id), lambda _: None, operand=None, ) + return chain_id def _progress_bar_scan(func): """Decorator that adds a progress bar to `body_fun` used in `lax.scan`. 
@@ -77,9 +85,26 @@ def wrapper_progress_bar(carry, x):
                 iter_num, *_ = x
             else:
                 iter_num = x
-            _update_progress_bar(iter_num)
-            return func(carry, x)
+            subcarry, chain_id = carry
+            chain_id = _update_progress_bar(iter_num, chain_id)
+            subcarry, y = func(subcarry, x)
+
+            return (subcarry, chain_id), y
 
         return wrapper_progress_bar
 
     return _progress_bar_scan
+
+
+def gen_scan_fn(num_samples, progress_bar, print_rate=None):
+    if progress_bar:
+
+        def scan_wrap(f, init, *args, **kwargs):
+            func = progress_bar_scan(num_samples, print_rate)(f)
+            carry = (init, -1)
+            (last_state, _), output = lax.scan(func, carry, *args, **kwargs)
+            return last_state, output
+
+        return scan_wrap
+    else:
+        return lax.scan
diff --git a/blackjax/util.py b/blackjax/util.py
index cdb9f4c91..b6c5367b5 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -3,15 +3,14 @@
 from functools import partial
 from typing import Callable, Union
 
-import jax
 import jax.numpy as jnp
 from jax import jit, lax
 from jax.flatten_util import ravel_pytree
 from jax.random import normal, split
-from jax.tree_util import tree_leaves
+from jax.tree_util import tree_leaves, tree_map
 
 from blackjax.base import SamplingAlgorithm, VIAlgorithm
-from blackjax.progress_bar import progress_bar_scan
+from blackjax.progress_bar import gen_scan_fn
 from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey
 
 
@@ -149,9 +148,7 @@ def run_inference_algorithm(
     initial_state: ArrayLikeTree = None,
     initial_position: ArrayLikeTree = None,
     progress_bar: bool = False,
-    transform: Callable = lambda x: x,
-    return_state_history=True,
-    expectation: Callable = lambda x: x,
+    transform: Callable = lambda state, info: (state, info),
 ) -> tuple:
     """Wrapper to run an inference algorithm.
 
@@ -166,8 +163,7 @@
     initial_state
         The initial state of the inference algorithm.
     initial_position
-        The initial position of the inference algorithm. This is used when the initial
-        state is not provided.
+        The initial position of the inference algorithm. This is used when the initial state is not provided.
     inference_algorithm
         One of blackjax's sampling algorithms or variational inference algorithms.
     num_steps
@@ -175,26 +171,14 @@
     progress_bar
         Whether to display a progress bar.
     transform
-        A transformation of the trace of states to be returned. This is useful for
+        A transformation of the trace of states (and info) to be returned. This is useful for
         computing determinstic variables, or returning a subset of the states. By default,
         the states are returned as is.
-    expectation
-        A function that computes the expectation of the state. This is done
-        incrementally, so doesn't require storing all the states.
-    return_state_history
-        if False, `run_inference_algorithm` will only return an expectation of the value
-        of transform, and return that average instead of the full set of samples. This
-        is useful when memory is a bottleneck.
 
     Returns
     -------
-    If return_state_history is True:
     1. The final state.
-    2. The trace of the state.
-    3. The trace of the info of the inference algorithm for diagnostics.
-    If return_state_history is False:
-    1. This is the expectation of state over the chain. Otherwise the final state.
-    2. The final state of the inference algorithm.
+    2. The history of states.
""" if initial_state is None and initial_position is None: @@ -212,58 +196,116 @@ def run_inference_algorithm( keys = split(rng_key, num_steps) - def one_step(average_and_state, xs, return_state): + def one_step(state, xs): _, rng_key = xs - average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average_update(expectation(transform(state)), average) - if return_state: - return (average, state), (transform(state), info) - else: - return (average, state), None + return state, transform(state, info) - one_step = jax.jit(partial(one_step, return_state=return_state_history)) + scan_fn = gen_scan_fn(num_steps, progress_bar) - if progress_bar: - one_step = progress_bar_scan(num_steps)(one_step) + xs = jnp.arange(num_steps), keys + final_state, history = scan_fn(one_step, initial_state, xs) - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + return final_state, history - if not return_state_history: - return average, transform(final_state) - else: - state_history, info_history = history - return transform(final_state), state_history, info_history +def store_only_expectation_values( + sampling_algorithm, + state_transform=lambda x: x, + incremental_value_transform=lambda x: x, + burn_in=0, +): + """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same + kernel but only stores the streaming expectation values of some observables, not the full states; to save memory. + + It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample. + + Example: + + .. 
code:: + + init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0),3) + model = StandardNormal(2) + initial_position = model.sample_init(init_key) + initial_state = blackjax.mcmc.mclmc.init( + position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key + ) + integrator_type = "mclachlan" + L = 1.0 + step_size = 0.1 + num_steps = 4 + + integrator = map_integrator_type_to_integrator['mclmc'][integrator_type] + state_transform = lambda state: state.position + memory_efficient_sampling_alg, transform = store_only_expectation_values( + sampling_algorithm=sampling_alg, + state_transform=state_transform) + + initial_state = memory_efficient_sampling_alg.init(initial_state) + + final_state, trace_at_every_step = run_inference_algorithm( + + rng_key=run_key, + initial_state=initial_state, + inference_algorithm=memory_efficient_sampling_alg, + num_steps=num_steps, + transform=transform, + progress_bar=True, + ) + """ + + def init_fn(state): + averaging_state = (0.0, state_transform(state)) + return (state, averaging_state) + + def update_fn(rng_key, state_and_incremental_val): + state, averaging_state = state_and_incremental_val + state, info = sampling_algorithm.step( + rng_key, state + ) # update the state with the sampling algorithm + averaging_state = incremental_value_update( + state_transform(state), + averaging_state, + weight=( + averaging_state[0] >= burn_in + ), # If we want to eliminate some number of steps as a burn-in + zero_prevention=1e-10 * (burn_in > 0), + ) + # update the expectation value with the running average + return (state, averaging_state), info + + def transform(state_and_incremental_val, info): + (state, (_, incremental_value)) = state_and_incremental_val + return incremental_value_transform(incremental_value), info + + return SamplingAlgorithm(init_fn, update_fn), transform -def streaming_average_update( - current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0 + +def incremental_value_update( + expectation, incremental_val, weight=1.0, zero_prevention=0.0 ): """Compute the streaming average of a function O(x) using a weight. 
     Parameters:
     ----------
-    current_value
-        the current value of the function that we want to take average of
-    previous_weight_and_average
-        tuple of (previous_weight, previous_average) where previous_weight is the
-        sum of weights and average is the current estimated average
+    expectation
+        the value of the expectation at the current timestep
+    incremental_val
+        tuple of (total, average) where total is the sum of weights and average is the current average
     weight
         weight of the current state
     zero_prevention
         small value to prevent division by zero
 
     Returns:
     ----------
-        new total weight and streaming average
+        the updated total weight and streaming average
     """
-    previous_weight, previous_average = previous_weight_and_average
-    current_weight = previous_weight + weight
-    current_average = jax.tree.map(
-        lambda x, avg: (previous_weight * avg + weight * x)
-        / (current_weight + zero_prevention),
-        current_value,
-        previous_average,
+
+    total, average = incremental_val
+    average = tree_map(
+        lambda exp, av: (total * av + weight * exp)
+        / (total + weight + zero_prevention),
+        expectation,
+        average,
     )
-    return current_weight, current_average
+    total += weight
+    return total, average
diff --git a/docs/index.md b/docs/index.md
index edc02631c..fca4787c4 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -41,7 +41,7 @@ rng_key = jax.random.key(0)
 step = jax.jit(nuts.step)
 for i in range(1_000):
     nuts_key = jax.random.fold_in(rng_key, i)
-    state, _ = nuts.step(nuts_key, state)
+    state, _ = step(nuts_key, state)
 ```
 
 :::{note}
diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py
index 68751bee8..4b34511be 100644
--- a/tests/adaptation/test_adaptation.py
+++ b/tests/adaptation/test_adaptation.py
@@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters):
     algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)
 
     chain_keys = jax.random.split(inference_key, num_chains)
-    _, _, infos = jax.vmap(
+    _, (_, infos) = jax.vmap(
         lambda key, state: run_inference_algorithm(
             rng_key=key,
             initial_state=state,
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 18a07625b..c399929da 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -135,12 +135,12 @@ def run_mclmc(
         sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
     )
 
-    _, samples, _ = run_inference_algorithm(
+    _, samples = run_inference_algorithm(
         rng_key=run_key,
         initial_state=blackjax_state_after_tuning,
         inference_algorithm=sampling_alg,
         num_steps=num_steps,
-        transform=lambda x: x.position,
+        transform=lambda state, info: state.position,
     )
 
     return samples
@@ -197,7 +197,7 @@ def check_attrs(attribute, keyset):
         for i, attribute in enumerate(["state", "info", "adaptation_state"]):
             check_attrs(attribute, keysets[i])
 
-        _, states, _ = run_inference_algorithm(
+        _, (states, _) = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=inference_algorithm,
@@ -223,15 +223,16 @@ def test_mala(self):
         mala = blackjax.mala(logposterior_fn, 1e-5)
         state = mala.init({"coefs": 1.0, "log_scale": 1.0})
 
-        _, states, _ = run_inference_algorithm(
+        _, states = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=mala,
+            transform=lambda state, info: state.position,
             num_steps=10_000,
         )
 
-        coefs_samples = states.position["coefs"][3000:]
-        scale_samples = np.exp(states.position["log_scale"][3000:])
+        coefs_samples = states["coefs"][3000:]
+        scale_samples = np.exp(states["log_scale"][3000:])
 
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -375,15 +376,16 @@ def test_pathfinder_adaptation(
         )
         inference_algorithm = algorithm(logposterior_fn, **parameters)
 
-        _, states, _ = run_inference_algorithm(
+        _, states = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=inference_algorithm,
             num_steps=num_sampling_steps,
+            transform=lambda state, info: state.position,
         )
 
-        coefs_samples = states.position["coefs"]
-        scale_samples = np.exp(states.position["log_scale"])
+        coefs_samples = states["coefs"]
+        scale_samples = np.exp(states["log_scale"])
 
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -418,17 +420,18 @@ def test_meads(self):
         inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)
 
         chain_keys = jax.random.split(inference_key, num_chains)
-        _, states, _ = jax.vmap(
+        _, states = jax.vmap(
             lambda key, state: run_inference_algorithm(
                 rng_key=key,
                 initial_state=state,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=100,
             )
         )(chain_keys, last_states)
 
-        coefs_samples = states.position["coefs"]
-        scale_samples = np.exp(states.position["log_scale"])
+        coefs_samples = states["coefs"]
+        scale_samples = np.exp(states["log_scale"])
 
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -465,17 +468,18 @@ def test_chees(self, jitter_generator):
         inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)
 
         chain_keys = jax.random.split(inference_key, num_chains)
-        _, states, _ = jax.vmap(
+        _, states = jax.vmap(
             lambda key, state: run_inference_algorithm(
                 rng_key=key,
                 initial_state=state,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=100,
             )
         )(chain_keys, last_states)
 
-        coefs_samples = states.position["coefs"]
-        scale_samples = np.exp(states.position["log_scale"])
+        coefs_samples = states["coefs"]
+        scale_samples = np.exp(states["log_scale"])
 
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-1)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-1)
@@ -494,15 +498,16 @@ def test_barker(self):
         barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
         state = barker.init({"coefs": 1.0, "log_scale": 1.0})
 
-        _, states, _ = run_inference_algorithm(
+        _, states = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=barker,
+            transform=lambda state, info: state.position,
             num_steps=10_000,
         )
 
-        coefs_samples = states.position["coefs"][3000:]
-        scale_samples = np.exp(states.position["log_scale"][3000:])
+        coefs_samples = states["coefs"][3000:]
+        scale_samples = np.exp(states["log_scale"][3000:])
 
         np.testing.assert_allclose(np.mean(scale_samples), 1.0, atol=1e-2)
         np.testing.assert_allclose(np.mean(coefs_samples), 3.0, atol=1e-2)
@@ -679,19 +684,20 @@ def test_latent_gaussian(self):
 
         initial_state = inference_algorithm.init(jnp.zeros((1,)))
 
-        _, states, _ = self.variant(
+        _, states = self.variant(
             functools.partial(
                 run_inference_algorithm,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=self.sampling_steps,
             ),
        )(rng_key=self.key, initial_state=initial_state)
 
         np.testing.assert_allclose(
-            np.var(states.position[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2
+            np.var(states[self.burnin :]), 1 / (1 + 0.5), rtol=1e-2, atol=1e-2
         )
 
         np.testing.assert_allclose(
-            np.mean(states.position[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2
+            np.mean(states[self.burnin :]), 2 / 3, rtol=1e-2, atol=1e-2
         )
@@ -724,7 +730,7 @@ def univariate_normal_test_case(
         **kwargs,
     ):
         inference_key, orbit_key = jax.random.split(rng_key)
-        _, states, _ = self.variant(
+        _, (states, info) = self.variant(
             functools.partial(
                 run_inference_algorithm,
                 inference_algorithm=inference_algorithm,
@@ -855,7 +861,7 @@ def postprocess_samples(states, key):
             20_000,
             burnin,
             postprocess_samples,
-            transform=lambda x: (x.positions, x.weights),
+            transform=lambda state, info: ((state.positions, state.weights), info),
         )
 
     @chex.all_variants(with_pmap=False)
@@ -997,14 +1003,15 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
             functools.partial(
                 run_inference_algorithm,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=2_000,
             )
         )
 
-        _, states, _ = inference_loop_multiple_chains(
+        _, states = inference_loop_multiple_chains(
             rng_key=multi_chain_sample_key, initial_state=initial_states
         )
-        posterior_samples = states.position[:, -1000:]
+        posterior_samples = states[:, -1000:]
         posterior_delta = posterior_samples - true_loc
         posterior_variance = posterior_delta**2.0
         posterior_correlation = jnp.prod(posterior_delta, axis=-1, keepdims=True) / (
diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index c2295e7e2..2d108a48d 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -48,7 +48,7 @@ def run_regression(algorithm, **parameters):
     )
     inference_algorithm = algorithm(logdensity_fn, **parameters)
 
-    _, states, _ = run_inference_algorithm(
+    _, (states, _) = run_inference_algorithm(
         rng_key=inference_key,
         initial_state=state,
         inference_algorithm=inference_algorithm,
diff --git a/tests/test_util.py b/tests/test_util.py
index 1f03498dd..78198f013 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -4,7 +4,7 @@
 from absl.testing import absltest, parameterized
 
 import blackjax
-from blackjax.util import run_inference_algorithm
+from blackjax.util import run_inference_algorithm, store_only_expectation_values
 
 
 class RunInferenceAlgorithmTest(chex.TestCase):
@@ -30,7 +30,7 @@ def check_compatible(self, initial_state, progress_bar):
             inference_algorithm=self.algorithm,
             num_steps=self.num_steps,
             progress_bar=progress_bar,
-            transform=lambda x: x.position,
+            transform=lambda state, info: state.position,
         )
 
     def test_streaming(self):
@@ -41,37 +41,49 @@ def logdensity_fn(x):
             10,
         )
 
-        init_key, run_key = jax.random.split(self.key, 2)
-
+        init_key, state_key, run_key = jax.random.split(self.key, 3)
         initial_state = blackjax.mcmc.mclmc.init(
-            position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
+            position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key
+        )
+        L = 1.0
+        step_size = 0.1
+        num_steps = 4
+
+        sampling_alg = blackjax.mclmc(
+            logdensity_fn,
+            L=L,
+            step_size=step_size,
         )
-        alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1)
+
+        state_transform = lambda x: x.position
 
-        _, states, info = run_inference_algorithm(
+        _, samples = run_inference_algorithm(
             rng_key=run_key,
             initial_state=initial_state,
-            inference_algorithm=alg,
-            num_steps=50,
-            progress_bar=False,
-            expectation=lambda x: x,
-            transform=lambda x: x.position,
-            return_state_history=True,
+            inference_algorithm=sampling_alg,
+            num_steps=num_steps,
+            transform=lambda state, info: state_transform(state),
+            progress_bar=True,
+        )
+
+        print("average of steps (slow way):", samples.mean(axis=0))
+
+        memory_efficient_sampling_alg, transform = store_only_expectation_values(
+            sampling_algorithm=sampling_alg, state_transform=state_transform
         )
 
-        average, _ = run_inference_algorithm(
+        initial_state = memory_efficient_sampling_alg.init(initial_state)
+
+        final_state, trace_at_every_step = run_inference_algorithm(
             rng_key=run_key,
             initial_state=initial_state,
-            inference_algorithm=alg,
-            num_steps=50,
-            progress_bar=False,
-            expectation=lambda x: x,
-            transform=lambda x: x.position,
-            return_state_history=False,
+            inference_algorithm=memory_efficient_sampling_alg,
+            num_steps=num_steps,
+            transform=transform,
+            progress_bar=True,
         )
 
-        assert jnp.allclose(states.mean(axis=0), average)
+        assert jnp.allclose(trace_at_every_step[0][-1], samples.mean(axis=0))
 
     @parameterized.parameters([True, False])
     def test_compatible_with_initial_pos(self, progress_bar):
@@ -81,7 +93,7 @@ def test_compatible_with_initial_pos(self, progress_bar):
             inference_algorithm=self.algorithm,
             num_steps=self.num_steps,
             progress_bar=progress_bar,
-            transform=lambda x: x.position,
+            transform=lambda state, info: state.position,
         )
 
     @parameterized.parameters([True, False])
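For reference, a minimal sketch of how the reworked `run_inference_algorithm` and `store_only_expectation_values` are meant to be called once this patch is applied. The standard-normal log density, MALA step size, and step count below are illustrative stand-ins, not values taken from the diff.

```python
# Illustrative sketch only: assumes a blackjax checkout with this patch applied.
import jax
import jax.numpy as jnp
import blackjax
from blackjax.util import run_inference_algorithm, store_only_expectation_values


def logdensity_fn(x):
    # Standard normal target, chosen only to keep the example self-contained.
    return -0.5 * jnp.sum(jnp.square(x))


run_key = jax.random.key(0)
alg = blackjax.mala(logdensity_fn, 1e-1)
initial_state = alg.init(jnp.zeros(2))

# `transform` now receives (state, info); whatever it returns is what gets traced.
final_state, positions = run_inference_algorithm(
    rng_key=run_key,
    initial_state=initial_state,
    inference_algorithm=alg,
    num_steps=1_000,
    transform=lambda state, info: state.position,
)

# Memory-efficient variant: trace only the running mean of the position.
expectation_alg, transform = store_only_expectation_values(
    sampling_algorithm=alg, state_transform=lambda state: state.position
)
final_state, (running_means, _) = run_inference_algorithm(
    rng_key=run_key,
    initial_state=expectation_alg.init(initial_state),
    inference_algorithm=expectation_alg,
    num_steps=1_000,
    transform=transform,
)
```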