diff --git a/src/flowMC/proposal/NF_proposal.py b/src/flowMC/proposal/NF_proposal.py index f4bf52d..bcb98dd 100644 --- a/src/flowMC/proposal/NF_proposal.py +++ b/src/flowMC/proposal/NF_proposal.py @@ -14,14 +14,12 @@ @jax.tree_util.register_pytree_node_class class NFProposal(ProposalBase): model: NFModel - prob_floor: float def __init__( - self, logpdf: Callable, jit: bool, model: NFModel, n_flow_sample: int = 10000, prob_floor: float = -13.0 + self, logpdf: Callable, jit: bool, model: NFModel, n_flow_sample: int = 10000 ): super().__init__(logpdf, jit) self.model = model - self.prob_floor = prob_floor self.n_flow_sample = n_flow_sample self.update_vmap = jax.vmap(self.update, in_axes=(None, (0))) if self.jit is True: @@ -127,12 +125,11 @@ def sample( n_chains = initial_position.shape[0] n_dim = initial_position.shape[-1] log_prob_initial = self.logpdf_vmap(initial_position, data)[:, None] - log_prob_nf_initial = jnp.maximum(self.model.log_prob(initial_position)[:, None], self.prob_floor) + log_prob_nf_initial = self.model.log_prob(initial_position)[:, None] proposal_position, log_prob_proposal, log_prob_nf_proposal = self.sample_flow( subkeys[0], initial_position, data, n_steps ) - log_prob_nf_proposal = jnp.maximum(log_prob_nf_proposal, self.prob_floor) state = ( jax.random.split(subkeys[1], n_chains),