Skip to content

Commit

Permalink
Add box constraint to optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
kazewong committed May 4, 2024
1 parent 6665bc2 commit 190820c
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/flowMC/strategy/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@ class optimization_Adam(Strategy):
n_steps: int = 100
learning_rate: float = 1e-2
noise_level: float = 10
bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]])

@property
def __name__(self):
return "AdamOptimization"

def __init__(
self,
bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]),
**kwargs,
):
class_keys = list(self.__class__.__annotations__.keys())
Expand All @@ -45,6 +47,8 @@ def __init__(
optax.adam(learning_rate=self.learning_rate),
)

self.bounds = bounds

def __call__(
self,
rng_key: PRNGKeyArray,
Expand All @@ -67,6 +71,7 @@ def _kernel(carry, data):
grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
updates, opt_state = self.solver.update(grad, opt_state, params)
params = optax.apply_updates(params, updates)
params = optax.projections.projection_box(params, self.bounds[:, 0], self.bounds[:, 1])
return (key, params, opt_state), None

def _single_optimize(
Expand Down

0 comments on commit 190820c

Please sign in to comment.