From 190820ce95b75a742018b22069c9e42a268d1e76 Mon Sep 17 00:00:00 2001 From: Kaze Wong Date: Sat, 4 May 2024 12:35:04 -0400 Subject: [PATCH] Add box constraint to optimization --- src/flowMC/strategy/optimization.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/flowMC/strategy/optimization.py b/src/flowMC/strategy/optimization.py index 11418f9..b2ac7ac 100644 --- a/src/flowMC/strategy/optimization.py +++ b/src/flowMC/strategy/optimization.py @@ -26,6 +26,7 @@ class optimization_Adam(Strategy): n_steps: int = 100 learning_rate: float = 1e-2 noise_level: float = 10 + bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]) @property def __name__(self): @@ -33,6 +34,7 @@ def __name__(self): def __init__( self, + bounds: Float[Array, "n_dim 2"] = jnp.array([[-jnp.inf, jnp.inf]]), **kwargs, ): class_keys = list(self.__class__.__annotations__.keys()) @@ -45,6 +47,8 @@ def __init__( optax.adam(learning_rate=self.learning_rate), ) + self.bounds = bounds + def __call__( self, rng_key: PRNGKeyArray, @@ -67,6 +71,7 @@ def _kernel(carry, data): grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level) updates, opt_state = self.solver.update(grad, opt_state, params) params = optax.apply_updates(params, updates) + params = optax.projections.projection_box(params, self.bounds[:, 0], self.bounds[:, 1]) return (key, params, opt_state), None def _single_optimize(