Replace iterative RNG split and carry with jax.random.fold_in (#656)
* Replace iterative RNG split and carry with `jax.random.fold_in`

* revert unintended change

* file formatting

* change `jax.tree_map` to `jax.tree.map`

* revert unintended file

* fiddle with rng_key

* seed again
junpenglao committed Apr 8, 2024
1 parent 7cf4f9d commit a5f7482
Showing 25 changed files with 129 additions and 143 deletions.
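
The recurring pattern in this commit, as a minimal sketch (toy loop, not taken from any file here): instead of splitting the carried key at every iteration, each iteration derives its key from a fixed key and the step index.

```python
import jax

rng_key = jax.random.key(0)

# Before: split the key and thread the fresh half through the loop carry.
carry = rng_key
for _ in range(3):
    carry, subkey = jax.random.split(carry)
    print(jax.random.uniform(subkey))

# After: derive each iteration's key from a fixed key and the step index.
for step in range(3):
    subkey = jax.random.fold_in(rng_key, step)
    print(jax.random.uniform(subkey))
```

Both yield a distinct subkey per iteration; the `fold_in` form also removes the key from the loop-carried state, which is what lets the `lax.scan` and `lax.while_loop` carries in the files below shrink.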
README.md (4 changes: 2 additions & 2 deletions)
@@ -81,8 +81,8 @@ state = nuts.init(initial_position)

 # Iterate
 rng_key = jax.random.key(0)
-for _ in range(100):
-    rng_key, nuts_key = jax.random.split(rng_key)
+for step in range(100):
+    nuts_key = jax.random.fold_in(rng_key, step)
     state, _ = nuts.step(nuts_key, state)
 ```

blackjax/adaptation/chees_adaptation.py (16 changes: 7 additions & 9 deletions)
@@ -361,20 +361,18 @@ def run(
), "initial `positions` leading dimension must be equal to the `num_chains`"
num_dim = pytree_size(positions) // num_chains

key_init, key_step = jax.random.split(rng_key)
next_random_arg_fn = lambda i: i + 1
init_random_arg = 0

if jitter_generator is not None:
jitter_gn = lambda key: jitter_generator(key) * jitter_amount + (
1.0 - jitter_amount
)
next_random_arg_fn = lambda key: jax.random.split(key)[1]
init_random_arg = key_init
rng_key, carry_key = jax.random.split(rng_key)
jitter_gn = lambda i: jitter_generator(
jax.random.fold_in(carry_key, i)
) * jitter_amount + (1.0 - jitter_amount)
else:
jitter_gn = lambda i: dynamic_hmc.halton_sequence(
i, np.ceil(np.log2(num_steps + max_sampling_steps))
) * jitter_amount + (1.0 - jitter_amount)
next_random_arg_fn = lambda i: i + 1
init_random_arg = 0

def integration_steps_fn(random_generator_arg, trajectory_length_adjusted):
return jnp.asarray(
Expand Down Expand Up @@ -425,7 +423,7 @@ def one_step(carry, rng_key):
init_states = batch_init(positions)
init_adaptation_state = init(init_random_arg, step_size)

keys_step = jax.random.split(key_step, num_steps)
keys_step = jax.random.split(rng_key, num_steps)
(last_states, last_adaptation_state), info = jax.lax.scan(
one_step, (init_states, init_adaptation_state), keys_step
)
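
A standalone sketch of the counter-keyed jitter generator adopted above (the helper name and arguments are illustrative, not the library API):

```python
import jax

def make_jitter_gn(rng_key, jitter_generator, jitter_amount):
    # Reserve a dedicated key once; the generator then depends only on a counter.
    _, carry_key = jax.random.split(rng_key)

    def jitter_gn(i):
        u = jitter_generator(jax.random.fold_in(carry_key, i))
        return u * jitter_amount + (1.0 - jitter_amount)

    return jitter_gn

jitter_gn = make_jitter_gn(jax.random.key(0), jax.random.uniform, 0.5)
print(jitter_gn(0), jitter_gn(1))  # distinct jitter for each step index
```

Making `jitter_gn` a pure function of the step index is what allows `one_step` to be scanned without threading a key through its carry.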
blackjax/adaptation/mclmc_adaptation.py (4 changes: 2 additions & 2 deletions)
@@ -231,12 +231,12 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
         kalman_state = (0.0, jnp.zeros(dim), jnp.zeros(dim))
 
         # run the steps
-        kalman_state = jax.lax.scan(
+        kalman_state, *_ = jax.lax.scan(
             step,
             init=(state, params, adap0, kalman_state),
             xs=(outer_weights, L_step_size_adaptation_keys),
             length=num_steps1 + num_steps2,
-        )[0]
+        )
         state, params, _, kalman_state_output = kalman_state
 
         L = params.L
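
Context for the unpacking change: `jax.lax.scan` returns a `(final_carry, stacked_outputs)` pair, so tuple unpacking names the carry directly instead of indexing with `[0]`. A toy example:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    return carry + x, carry  # (new carry, per-step output)

final_carry, *_ = jax.lax.scan(step, jnp.float32(0.0), jnp.arange(4.0))
print(final_carry)  # 6.0
```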
blackjax/adaptation/meads_adaptation.py (8 changes: 4 additions & 4 deletions)
@@ -93,16 +93,16 @@ def compute_parameters(
         of the generalized HMC algorithm.
         """
-        mean_position = jax.tree_map(lambda p: p.mean(axis=0), positions)
-        sd_position = jax.tree_map(lambda p: p.std(axis=0), positions)
-        normalized_positions = jax.tree_map(
+        mean_position = jax.tree.map(lambda p: p.mean(axis=0), positions)
+        sd_position = jax.tree.map(lambda p: p.std(axis=0), positions)
+        normalized_positions = jax.tree.map(
             lambda p, mu, sd: (p - mu) / sd,
             positions,
             mean_position,
             sd_position,
         )
 
-        batch_grad_scaled = jax.tree_map(
+        batch_grad_scaled = jax.tree.map(
             lambda grad, sd: grad * sd, logdensity_grad, sd_position
         )
 
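
The `jax.tree_map` to `jax.tree.map` renames in this and the following files track JAX's consolidation of the tree utilities under the `jax.tree` namespace (`jax.tree_map` was deprecated around JAX 0.4.25); behavior is unchanged. A small example of the new spelling:

```python
import jax
import jax.numpy as jnp

# Toy pytree of chain positions: four samples per leaf.
positions = {"x": jnp.ones((4, 2)), "y": jnp.arange(4.0)}
mean_position = jax.tree.map(lambda p: p.mean(axis=0), positions)
print(mean_position)  # {'x': Array([1., 1.], ...), 'y': Array(1.5, ...)}
```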
blackjax/adaptation/step_size.py (14 changes: 7 additions & 7 deletions)
@@ -158,8 +158,8 @@ def final(da_state: DualAveragingAdaptationState) -> float:
 class ReasonableStepSizeState(NamedTuple):
     """State carried through the search for a reasonable first step size.
-    rng_key
-        Key used by JAX's random number generator.
+    step
+        The current iteration step.
     direction: {-1, 1}
         Determines whether the step size should be increased or decreased during the
         previous step search. If direction = 1 it will be increased, otherwise decreased.
@@ -171,7 +171,7 @@ class ReasonableStepSizeState(NamedTuple):
"""

rng_key: PRNGKey
step: int
direction: int
previous_direction: int
step_size: float
@@ -243,17 +243,17 @@ def do_continue(rss_state: ReasonableStepSizeState) -> bool:

     def update(rss_state: ReasonableStepSizeState) -> ReasonableStepSizeState:
         """Perform one step of the step size search."""
-        rng_key, direction, _, step_size = rss_state
-        rng_key, subkey = jax.random.split(rng_key)
+        i, direction, _, step_size = rss_state
+        subkey = jax.random.fold_in(rng_key, i)
 
         step_size = (2.0**direction) * step_size
         kernel = kernel_generator(step_size)
         _, info = kernel(subkey, reference_state)
 
         new_direction = jnp.where(target_accept < info.acceptance_rate, 1, -1)
-        return ReasonableStepSizeState(rng_key, new_direction, direction, step_size)
+        return ReasonableStepSizeState(i + 1, new_direction, direction, step_size)
 
-    rss_state = ReasonableStepSizeState(rng_key, 0, 0, initial_step_size)
+    rss_state = ReasonableStepSizeState(0, 0, 0, initial_step_size)
     rss_state = jax.lax.while_loop(do_continue, update, rss_state)
 
     return rss_state.step_size
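
The same refactor in isolation, as a stripped-down sketch (toy condition and body, assumed purely for illustration): the `while_loop` state carries an integer counter instead of a key, and `fold_in` derives each iteration's key from the enclosing `rng_key`.

```python
import jax
import jax.numpy as jnp

def accumulate(rng_key, x0):
    def cond(state):
        i, x = state
        return (i < 10) & (x < 1.0)

    def body(state):
        i, x = state
        subkey = jax.random.fold_in(rng_key, i)  # per-iteration key from the counter
        return i + 1, x + jax.random.uniform(subkey)

    return jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(x0)))

print(accumulate(jax.random.key(0), 0.0))
```

With the key out of the carry, `ReasonableStepSizeState` stays a plain numeric record.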
blackjax/mcmc/elliptical_slice.py (14 changes: 7 additions & 7 deletions)
@@ -208,7 +208,7 @@ def generate(
         rng_key: PRNGKey, state: EllipSliceState
     ) -> tuple[EllipSliceState, EllipSliceInfo]:
         position, logdensity = state
-        key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 3)
+        key_slice, key_momentum, key_uniform, key_theta = jax.random.split(rng_key, 4)
         # step 1: sample momentum
         momentum = momentum_generator(key_momentum, position)
         # step 2: get slice (y)
@@ -235,20 +235,20 @@ def slice_fn(vals):
             likelihood is continuous with respect to the parameter being sampled.
             """
-            rng, _, subiter, theta, theta_min, theta_max, *_ = vals
-            rng, thetak = jax.random.split(rng)
+            _, subiter, theta, theta_min, theta_max, *_ = vals
+            thetak = jax.random.fold_in(key_slice, subiter)
             theta = jax.random.uniform(thetak, minval=theta_min, maxval=theta_max)
             p, m = ellipsis(position, momentum, theta, mean)
             logdensity = logdensity_fn(p)
             theta_min = jnp.where(theta < 0, theta, theta_min)
             theta_max = jnp.where(theta > 0, theta, theta_max)
             subiter += 1
-            return rng, logdensity, subiter, theta, theta_min, theta_max, p, m
+            return logdensity, subiter, theta, theta_min, theta_max, p, m
 
-        _, logdensity, subiter, theta, *_, position, momentum = jax.lax.while_loop(
-            lambda vals: vals[1] <= logy,
+        logdensity, subiter, theta, *_, position, momentum = jax.lax.while_loop(
+            lambda vals: vals[0] <= logy,
             slice_fn,
-            (rng_key, logdensity, 1, theta, theta_min, theta_max, p, m),
+            (logdensity, 1, theta, theta_min, theta_max, p, m),
         )
         return (
             EllipSliceState(position, logdensity),
blackjax/mcmc/ghmc.py (2 changes: 1 addition & 1 deletion)
@@ -185,7 +185,7 @@ def update_momentum(rng_key, state, alpha, momentum_generator):
"""
position, momentum, *_ = state

momentum = jax.tree_map(
momentum = jax.tree.map(
lambda prev_momentum, shifted_momentum: prev_momentum * jnp.sqrt(1.0 - alpha)
+ jnp.sqrt(alpha) * shifted_momentum,
momentum,
blackjax/mcmc/trajectory.py (32 changes: 15 additions & 17 deletions)
@@ -201,17 +201,17 @@ def integrate(

         def do_keep_integrating(loop_state):
             """Decide whether we should continue integrating the trajectory"""
-            _, integration_state, (is_diverging, has_terminated) = loop_state
+            integration_state, (is_diverging, has_terminated) = loop_state
             return (
                 (integration_state.step < max_num_steps)
                 & ~has_terminated
                 & ~is_diverging
             )
 
         def add_one_state(loop_state):
-            rng_key, integration_state, _ = loop_state
+            integration_state, _ = loop_state
             step, proposal, trajectory, termination_state = integration_state
-            rng_key, proposal_key = jax.random.split(rng_key)
+            proposal_key = jax.random.fold_in(rng_key, step)
 
             new_state = integrator(trajectory.rightmost_state, direction * step_size)
             new_proposal = generate_proposal(initial_energy, new_state)
@@ -246,7 +246,7 @@ def add_one_state(loop_state):
                 new_termination_state,
             )
 
-            return (rng_key, new_integration_state, (is_diverging, has_terminated))
+            return (new_integration_state, (is_diverging, has_terminated))
 
         proposal_placeholder = generate_proposal(initial_energy, initial_state)
         trajectory_placeholder = Trajectory(
@@ -259,12 +259,12 @@
             termination_state,
         )
 
-        _, integration_state, (is_diverging, has_terminated) = jax.lax.while_loop(
+        new_integration_state, (is_diverging, has_terminated) = jax.lax.while_loop(
            do_keep_integrating,
            add_one_state,
-            (rng_key, integration_state_placeholder, (False, False)),
+            (integration_state_placeholder, (False, False)),
        )
-        step, proposal, trajectory, termination_state = integration_state
+        _, proposal, trajectory, termination_state = new_integration_state
 
         # In the while_loop we always extend on the right most direction.
         new_trajectory = jax.lax.cond(
@@ -496,8 +496,7 @@ def dynamic_multiplicative_expansion(
     step_size
         The step size used by the symplectic integrator.
     max_num_expansions
-        The maximum number of trajectory expansions until the proposal is
-        returned.
+        The maximum number of trajectory expansions until the proposal is returned.
     rate
         The rate of the geometrical expansion. Typically 2 in NUTS, this is why
         the literature often refers to "tree doubling".
@@ -513,7 +512,7 @@ def expand(
     ):
         def do_keep_expanding(loop_state) -> bool:
             """Determine whether we need to keep expanding the trajectory."""
-            _, expansion_state, (is_diverging, is_turning) = loop_state
+            expansion_state, (is_diverging, is_turning) = loop_state
             return (
                 (expansion_state.step < max_num_expansions)
                 & ~is_diverging
@@ -531,12 +530,11 @@ def expand_once(loop_state):
                 the subtrajectory.
             """
-            rng_key, expansion_state, _ = loop_state
+            expansion_state, _ = loop_state
             step, proposal, trajectory, termination_state = expansion_state
 
-            rng_key, direction_key, trajectory_key, proposal_key = jax.random.split(
-                rng_key, 4
-            )
+            subkey = jax.random.fold_in(rng_key, step)
+            direction_key, trajectory_key, proposal_key = jax.random.split(subkey, 3)
 
             # create new subtrajectory that is twice as long as the current
             # trajectory.
@@ -608,12 +606,12 @@ def update_sum_log_p_accept(inputs):
             )
             info = (is_diverging, is_turning_subtree | is_turning)
 
-            return (rng_key, new_state, info)
+            return (new_state, info)
 
-        _, expansion_state, (is_diverging, is_turning) = jax.lax.while_loop(
+        expansion_state, (is_diverging, is_turning) = jax.lax.while_loop(
             do_keep_expanding,
             expand_once,
-            (rng_key, initial_expansion_state, (False, False)),
+            (initial_expansion_state, (False, False)),
         )
 
         return expansion_state, (is_diverging, is_turning)
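
When one step needs several keys, the folded key can be split once, as `expand_once` now does. A minimal sketch of that composition (names are illustrative):

```python
import jax

rng_key = jax.random.key(0)

def keys_for_step(step):
    # One fold per step, then one split for the step's independent uses.
    subkey = jax.random.fold_in(rng_key, step)
    return jax.random.split(subkey, 3)

direction_key, trajectory_key, proposal_key = keys_for_step(0)
```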
blackjax/optimizers/lbfgs.py (4 changes: 2 additions & 2 deletions)
@@ -139,7 +139,7 @@ def minimize_lbfgs(
         f=history_raveled.f,
         g=unravel_fn_mapped(history_raveled.g),
         alpha=unravel_fn_mapped(history_raveled.alpha),
-        update_mask=jax.tree_map(
+        update_mask=jax.tree.map(
             lambda x: x.astype(history_raveled.update_mask.dtype),
             unravel_fn_mapped(history_raveled.update_mask.astype(x0_raveled.dtype)),
         ),
@@ -230,7 +230,7 @@ def scan_body(tup, it):
         scan_body, ((init_step, initial_history), True), jnp.arange(maxiter)
     )
     # Append initial state to history.
-    history = jax.tree_map(
+    history = jax.tree.map(
         lambda x, y: jnp.concatenate([x[None, ...], y], axis=0),
         initial_history,
         history,
blackjax/sgmcmc/gradients.py (2 changes: 1 addition & 1 deletion)
@@ -125,7 +125,7 @@ def cv_grad_estimator_fn(
         grad_estimate = logdensity_grad_estimator(position, minibatch)
         center_grad_estimate = logdensity_grad_estimator(centering_position, minibatch)
 
-        return jax.tree_map(
+        return jax.tree.map(
             lambda grad_est, cv_grad_est, cv_grad: cv_grad + grad_est - cv_grad_est,
             grad_estimate,
             center_grad_estimate,
blackjax/smc/base.py (4 changes: 2 additions & 2 deletions)
@@ -135,7 +135,7 @@ def step(
         num_resampled = num_particles
 
     resampling_idx = resample_fn(resampling_key, state.weights, num_resampled)
-    particles = jax.tree_map(lambda x: x[resampling_idx], state.particles)
+    particles = jax.tree.map(lambda x: x[resampling_idx], state.particles)
 
     keys = jax.random.split(updating_key, num_resampled)
     particles, update_info = update_fn(keys, particles, state.update_parameters)
@@ -158,4 +158,4 @@ def extend_params(n_particles, params):
     def extend(param):
         return jnp.repeat(jnp.asarray(param)[None, ...], n_particles, axis=0)
 
-    return jax.tree_map(extend, params)
+    return jax.tree.map(extend, params)
blackjax/vi/meanfield_vi.py (14 changes: 7 additions & 7 deletions)
@@ -48,8 +48,8 @@ def init(
     **optimizer_kwargs,
 ) -> MFVIState:
     """Initialize the mean-field VI state."""
-    mu = jax.tree_map(jnp.zeros_like, position)
-    rho = jax.tree_map(lambda x: -2.0 * jnp.ones_like(x), position)
+    mu = jax.tree.map(jnp.zeros_like, position)
+    rho = jax.tree.map(lambda x: -2.0 * jnp.ones_like(x), position)
     opt_state = optimizer.init((mu, rho))
     return MFVIState(mu, rho, opt_state)

@@ -99,7 +99,7 @@ def kl_divergence_fn(parameters):

     elbo, elbo_grad = jax.value_and_grad(kl_divergence_fn)(parameters)
     updates, new_opt_state = optimizer.update(elbo_grad, state.opt_state, parameters)
-    new_parameters = jax.tree_map(lambda p, u: p + u, parameters, updates)
+    new_parameters = jax.tree.map(lambda p, u: p + u, parameters, updates)
     new_state = MFVIState(new_parameters[0], new_parameters[1], new_opt_state)
     return new_state, MFVIInfo(elbo)

@@ -151,7 +151,7 @@ def sample_fn(rng_key: PRNGKey, state: MFVIState, num_samples: int):


 def _sample(rng_key, mu, rho, num_samples):
-    sigma = jax.tree_map(jnp.exp, rho)
+    sigma = jax.tree.map(jnp.exp, rho)
     mu_flatten, unravel_fn = jax.flatten_util.ravel_pytree(mu)
     sigma_flat, _ = jax.flatten_util.ravel_pytree(sigma)
     flatten_sample = (
@@ -162,11 +162,11 @@ def _sample(rng_key, mu, rho, num_samples):


 def generate_meanfield_logdensity(mu, rho):
-    sigma_param = jax.tree_map(jnp.exp, rho)
+    sigma_param = jax.tree.map(jnp.exp, rho)
 
     def meanfield_logdensity(position):
-        logq_pytree = jax.tree_map(jsp.stats.norm.logpdf, position, mu, sigma_param)
-        logq = jax.tree_map(jnp.sum, logq_pytree)
+        logq_pytree = jax.tree.map(jsp.stats.norm.logpdf, position, mu, sigma_param)
+        logq = jax.tree.map(jnp.sum, logq_pytree)
         return jax.tree_util.tree_reduce(jnp.add, logq)
 
     return meanfield_logdensity
blackjax/vi/pathfinder.py (2 changes: 1 addition & 1 deletion)
@@ -197,7 +197,7 @@ def path_finder_body_fn(rng_key, S, Z, alpha_l, theta, theta_grad):
     )
 
     max_elbo_idx = jnp.argmax(elbo)
-    return jax.tree_map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo(
+    return jax.tree.map(lambda x: x[max_elbo_idx], pathfinder_result), PathfinderInfo(
         pathfinder_result
     )

(The remaining changed files are not shown here.)
