Skip to content

Commit

Permalink
update evosax
Browse files — browse the repository at this point in the history
  • Loading branch information
Thibeau Wouters committed Apr 4, 2024
1 parent cac6386 commit 84cdf38
Show / hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
11 changes: 3 additions & 8 deletions src/flowMC/sampler/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,8 @@ def sample(self, initial_position: Array, data: dict):

self.local_sampler_tuning(initial_position, data)
last_step = initial_position
# Run a single iteration, to get the runtime of the compilation
start_time = time.time()
_, _ = self.sampling_loop(last_step, data)
end_time = time.time()
runtime = end_time - start_time
with open(self.outdir + "runtime_compilation.txt", "w") as f:
f.write(str(runtime))

# Training loop
start_time = time.time()
if self.use_global == True:
last_step = self.global_sampler_tuning(last_step, data)
Expand All @@ -170,6 +164,7 @@ def sample(self, initial_position: Array, data: dict):
with open(self.outdir + "runtime_training.txt", "w") as f:
f.write(str(runtime))

# Production loop
start_time = time.time()
last_step = self.production_run(last_step, data)
end_time = time.time()
Expand Down Expand Up @@ -355,7 +350,7 @@ def global_sampler_tuning(
"""
print("Training normalizing flow")
last_step = initial_position
for _ in tqdm(
for i in tqdm(
range(self.n_loop_training),
desc="Tuning global sampler",
):
Expand Down
14 changes: 13 additions & 1 deletion src/flowMC/utils/EvolutionaryOptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, ndims, popsize=100, verbose=False):
self.history = []
self.state = None

def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_step = 0):
def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_step = 0, early_stopping_patience: int = 50):
"""
Optimize the objective function.
Expand All @@ -68,6 +68,8 @@ def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_st
progress_bar = tqdm.tqdm(range(n_loops), "Generation: ") if self.verbose else range(n_loops)
self.bound = bound
self.state = self.strategy.initialize(key, self.es_params)
best_fitness = 1e10
patience = 0
if keep_history_step > 0:
self.history = []
for i in progress_bar:
Expand All @@ -78,6 +80,16 @@ def optimize(self, objective, bound, n_loops = 100, seed = 9527, keep_history_st
else:
for i in progress_bar:
key, self.state, _ = self.optimize_step(subkey, self.state, objective, bound)
# Early stopping
if early_stopping_patience > 0:
if self.state.best_fitness < best_fitness:
best_fitness = self.state.best_fitness
patience = 0
else:
patience += 1
if patience == early_stopping_patience:
print("Exiting early due to early stopping!")
break
if self.verbose: progress_bar.set_description(f"Generation: {i}, Fitness: {self.state.best_fitness:.4f}")

def optimize_step(self, key: jax.random.PRNGKey, state, objective: callable, bound):
Expand Down

0 comments on commit 84cdf38

Please sign in to comment.