Commit

Merge pull request #168 from kazewong/165-add-probability-floor-to-normalizing-flow-model

165 add probability floor to normalizing flow model -> updating Adam optimizer to optimize with constraint
kazewong authored May 6, 2024
2 parents aae17e1 + 63c3b0b commit b4301c5
Showing 1 changed file with 82 additions and 5 deletions.
87 changes: 82 additions & 5 deletions src/flowMC/strategy/optimization.py
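
For context, the core of the change is a projected gradient step: after every Adam update the parameters are clamped back into a user-supplied box with optax.projections.projection_box, so the chains cannot wander into regions where the log-probability blows up. A minimal standalone sketch of the pattern (illustrative only, not part of the commit; the objective, bounds, and learning rate below are made up):

import jax
import jax.numpy as jnp
import optax

def negative_log_prob(x):  # hypothetical objective; flowMC supplies the real target
    return jnp.sum(x**2)

bounds = jnp.array([[-1.0, 1.0], [-2.0, 2.0]])  # one (lower, upper) row per dimension
solver = optax.adam(learning_rate=1e-2)

params = jnp.array([0.9, 1.5])
opt_state = solver.init(params)

for _ in range(100):
    grad = jax.grad(negative_log_prob)(params)
    updates, opt_state = solver.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    # Projection step: clip each coordinate back into its [lower, upper] interval.
    params = optax.projections.projection_box(params, bounds[:, 0], bounds[:, 1])

The diff below applies the same projection inside the _kernel used by __call__.
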
@@ -10,23 +10,31 @@


class optimization_Adam(Strategy):

"""
Optimize a set of chains using Adam optimization.
Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.
Args:
n_steps: int = 100
Number of optimization steps.
learning_rate: float = 1e-2
Learning rate for the optimization.
noise_level: float = 10
Variance of the noise added to the gradients.
"""

n_steps: int = 100
learning_rate: float = 1e-2
noise_level: float = 10
bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

@property
def __name__(self):
return "AdamOptimization"

def __init__(
self,
bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
**kwargs,
):
class_keys = list(self.__class__.__annotations__.keys())
@@ -39,6 +47,8 @@ def __init__(
            optax.adam(learning_rate=self.learning_rate),
        )

        self.bounds = bounds

    def __call__(
        self,
        rng_key: PRNGKeyArray,
@@ -61,6 +71,7 @@ def _kernel(carry, data):
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
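            # New in this PR: project the updated parameters back into the box given by
            # self.bounds, keeping each chain inside the user-specified constraints.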
            params = optax.projections.projection_box(
                params, self.bounds[:, 0], self.bounds[:, 1]
            )
            return (key, params, opt_state), None

        def _single_optimize(
@@ -91,11 +102,77 @@ def _single_optimize(
summary["final_positions"] = optimized_positions
summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)

if jnp.isinf(summary['final_log_prob']).any() or jnp.isnan(summary['final_log_prob']).any():
if (
jnp.isinf(summary["final_log_prob"]).any()
or jnp.isnan(summary["final_log_prob"]).any()
):
print("Warning: Optimization accessed infinite or NaN log-probabilities.")

return rng_key, optimized_positions, local_sampler, global_sampler, summary

    def optimize(
        self,
        rng_key: PRNGKeyArray,
        objective: Callable,
        initial_position: Float[Array, " n_chain n_dim"],
    ):
        """
        Standalone optimization function that takes an objective function and returns the optimized positions.

        Args:
            rng_key: PRNGKeyArray
                Random key for the optimization.
            objective: Callable
                Objective function to optimize.
            initial_position: Float[Array, " n_chain n_dim"]
                Initial positions for the optimization.
        """
        grad_fn = jax.jit(jax.grad(objective))

        def _kernel(carry, data):
            key, params, opt_state = carry

            key, subkey = jax.random.split(key)
            grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
            updates, opt_state = self.solver.update(grad, opt_state, params)
            params = optax.apply_updates(params, updates)
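            # Note: as of this commit, this standalone kernel applies the same noisy Adam
            # step but does not project onto self.bounds (unlike the __call__ kernel above).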
            return (key, params, opt_state), None

        def _single_optimize(
            key: PRNGKeyArray,
            initial_position: Float[Array, " n_dim"],
        ) -> Float[Array, " n_dim"]:

            opt_state = self.solver.init(initial_position)

            (key, params, opt_state), _ = jax.lax.scan(
                _kernel,
                (key, initial_position, opt_state),
                jnp.arange(self.n_steps),
            )

            return params  # type: ignore

        print("Using Adam optimization")
        rng_key, subkey = jax.random.split(rng_key)
        keys = jax.random.split(subkey, initial_position.shape[0])
        optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
            keys, initial_position
        )

        summary = {}
        summary["initial_positions"] = initial_position
        summary["initial_log_prob"] = jax.jit(jax.vmap(objective))(initial_position)
        summary["final_positions"] = optimized_positions
        summary["final_log_prob"] = jax.jit(jax.vmap(objective))(optimized_positions)

        if (
            jnp.isinf(summary["final_log_prob"]).any()
            or jnp.isnan(summary["final_log_prob"]).any()
        ):
            print("Warning: Optimization accessed infinite or NaN log-probabilities.")

        return rng_key, optimized_positions, summary

class Evosax_CMA_ES(Strategy):

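A rough usage sketch for the optimization_Adam change above (hypothetical values; it assumes, as the truncated __init__ suggests, that keyword arguments named after the class annotations are stored as attributes). Note that the box projection is applied in the strategy's __call__ kernel; the standalone optimize method runs the same noisy Adam step without it.

import jax
import jax.numpy as jnp
from flowMC.strategy.optimization import optimization_Adam

# Constrain every dimension of a 3-D problem to the box [-5, 5] during the Adam warm-up.
strategy = optimization_Adam(
    n_steps=200,
    learning_rate=1e-2,
    noise_level=10,
    bounds=jnp.array([[-5.0, 5.0]] * 3),
)

def negative_log_prob(x):  # hypothetical objective to minimize
    return jnp.sum((x - 1.0) ** 2)

rng_key = jax.random.PRNGKey(0)
initial_position = jax.random.normal(rng_key, (20, 3))  # 20 chains, 3 dimensions
rng_key, optimized_positions, summary = strategy.optimize(
    rng_key, negative_log_prob, initial_position
)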