Skip to content

Commit

Permalink
Remove flooring
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed May 4, 2024
1 parent 190820c commit 63c3b0b
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/flowMC/proposal/NF_proposal.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
@jax.tree_util.register_pytree_node_class
class NFProposal(ProposalBase):
model: NFModel
prob_floor: float

def __init__(
self, logpdf: Callable, jit: bool, model: NFModel, n_flow_sample: int = 10000, prob_floor: float = -13.0
self, logpdf: Callable, jit: bool, model: NFModel, n_flow_sample: int = 10000
):
super().__init__(logpdf, jit)
self.model = model
self.prob_floor = prob_floor
self.n_flow_sample = n_flow_sample
self.update_vmap = jax.vmap(self.update, in_axes=(None, (0)))
if self.jit is True:
Expand Down Expand Up @@ -127,12 +125,11 @@ def sample(
n_chains = initial_position.shape[0]
n_dim = initial_position.shape[-1]
log_prob_initial = self.logpdf_vmap(initial_position, data)[:, None]
log_prob_nf_initial = jnp.maximum(self.model.log_prob(initial_position)[:, None], self.prob_floor)
log_prob_nf_initial = self.model.log_prob(initial_position)[:, None]

proposal_position, log_prob_proposal, log_prob_nf_proposal = self.sample_flow(
subkeys[0], initial_position, data, n_steps
)
log_prob_nf_proposal = jnp.maximum(log_prob_nf_proposal, self.prob_floor)

state = (
jax.random.split(subkeys[1], n_chains),
Expand Down

0 comments on commit 63c3b0b

Please sign in to comment.