Skip to content

Commit

Permalink
minor clean up (#635)
Browse files Browse the repository at this point in the history
* minor clean up

* formatting fix
  • Loading branch information
junpenglao authored Jan 19, 2024
1 parent 2c1d779 commit 5402335
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
3 changes: 2 additions & 1 deletion blackjax/mcmc/termination.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax
import jax.numpy as jnp

from blackjax.mcmc.metrics import CheckTurning
from blackjax.types import Array


Expand All @@ -27,7 +28,7 @@ class IterativeUTurnState(NamedTuple):
idx_max: int


def iterative_uturn_numpyro(is_turning):
def iterative_uturn_numpyro(is_turning: CheckTurning):
"""Numpyro style dynamic U-Turn criterion."""

def new_state(chain_state, max_num_doublings) -> IterativeUTurnState:
Expand Down
27 changes: 16 additions & 11 deletions blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
"""Procedures to build trajectories for algorithms in the HMC family.
To propose a new state, algorithms in the HMC family generally proceed by :cite:p:`betancourt2017conceptual`:
To propose a new state, algorithms in the HMC family generally proceed by
:cite:p:`betancourt2017conceptual`:
1. Sampling a trajectory starting from the initial point;
2. Sampling a new state from this sampled trajectory.
Expand Down Expand Up @@ -299,10 +300,11 @@ def dynamic_recursive_integration(
"""Integrate a trajectory and update the proposal recursively in Python
until the termination criterion is met.
This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with multinomial sampling.
The implemenation here is mostly for validating the progressive implementation
to make sure the two are equivalent. The recursive implementation should not
be used for actually sampling as it cannot be jitted and thus likely slow.
This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with
multinomial sampling. The implemenation here is mostly for validating the
progressive implementation to make sure the two are equivalent. The recursive
implementation should not be used for actually sampling as it cannot be jitted and
thus likely slow.
Parameters
----------
Expand All @@ -313,9 +315,11 @@ def dynamic_recursive_integration(
uturn_check_fn
Determines whether the termination criterion has been met.
divergence_threshold
Value of the difference of energy between two consecutive states above which we say a transition is divergent.
Value of the difference of energy between two consecutive states above which we
say a transition is divergent.
use_robust_uturn_check
Bool to indicate whether to perform additional U turn check between two trajectory.
Bool to indicate whether to perform additional U turn check between two
trajectory.
"""
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
Expand Down Expand Up @@ -348,7 +352,8 @@ def buildtree_integrate(
step_size
The step size of the symplectic integrator.
initial_energy
Initial energy H0 of the HMC step (not to confused with the initial energy of the subtree)
Initial energy H0 of the HMC step (not to confused with the initial energy
of the subtree)
"""
if tree_depth == 0:
Expand Down Expand Up @@ -561,9 +566,9 @@ def expand_once(loop_state):
# Update the proposal
#
# We do not accept proposals that come from diverging or turning
# subtrajectories. However the definition of the acceptance
# probability is such that the acceptance probability needs to be
# computed across the entire trajectory.
# subtrajectories. However the definition of the acceptance probability is
# such that the acceptance probability needs to be computed across the
# entire trajectory.
def update_sum_log_p_accept(inputs):
_, proposal, new_proposal = inputs
return Proposal(
Expand Down
2 changes: 1 addition & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ myst_nb>=1.0.0
numba
numpyro
optax
oryx @ git+https://github.com/jax-ml/oryx.git@main # remove after oryx release
oryx
pymc
scikit-learn
sphinx
Expand Down

0 comments on commit 5402335

Please sign in to comment.