Skip to content

Commit

Permalink
refactor: use ctx manager; move code to .
Browse files Browse the repository at this point in the history
  • Loading branch information
aaraney committed May 2, 2024
1 parent 4ea3419 commit d2bab85
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions python/ngen_cal/src/ngen/cal/_optimizers/grey_wolf.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,12 @@ def __init__(
self.name = __name__

def optimize(
self, objective_func, iters, n_processes=None, verbose=True, **kwargs
self,
objective_func,
iters: int,
n_processes: Optional[int] = None,
verbose: bool = True,
**kwargs,
):
"""Optimize the swarm for a number of iterations
Expand All @@ -131,7 +136,7 @@ def optimize(
objective function to be evaluated
iters : int
number of iterations
n_processes : int
n_processes : int, optional
number of processes to use for parallel particle evaluation (default: None = no parallelization)
verbose : bool
enable or disable the logs and progress bar (default: True = enable logs)
Expand All @@ -143,7 +148,23 @@ def optimize(
tuple
the global best cost and the global best position.
"""
if n_processes is None:
return self._optimize(objective_func, iters, verbose, pool=None)
else:
with mp.Pool(n_processes) as pool:
return self._optimize(objective_func, iters, verbose, pool=pool)

def _optimize(
self,
objective_func,
iters: int,
verbose: bool = True,
pool: Optional[mp.Pool] = None,
**kwargs,
):
"""
`pool` lifecycle is managed by `optimize` method. DO NOT CLOSE IT HERE.
"""
# Apply verbosity
if self.start_iter>0:
verbose = False
Expand All @@ -158,9 +179,7 @@ def optimize(
lvl=log_level,
)

# Setup Pool of processes for parallel evaluation
pool = None if n_processes is None else mp.Pool(n_processes)

# TODO: @hellkite500, ftol_history is unused. should it be? or can we remove it?
ftol_history = deque(maxlen=self.ftol_iter)

# Compute cost of initial swarm
Expand Down Expand Up @@ -204,6 +223,8 @@ def optimize(
X2 = beta - A2 * Dbeta
X3 = delta - A3 * Ddelta
self.swarm.position = (X1 + X2 + X3) / 3
# TODO: @hellkite500, is this right?
assert self.bounds is not None
self.swarm.position = np.clip(self.swarm.position, self.bounds[0], self.bounds[1])
# Compute current cost and update local best
self.swarm.current_cost = compute_objective_function(self.swarm, objective_func, pool=pool, **kwargs)
Expand All @@ -217,13 +238,12 @@ def optimize(
self.update_history(i+2)

# Obtain the final best_cost and the final best_position
# TODO: @hellkite500, `best_cost` should be `float` here... so no copy
# method.
final_best_cost = self.swarm.best_cost.copy()
final_best_pos = self.swarm.best_pos.copy()
# Write report in log and return final cost and position
self.rep.log("Optimization finished | best cost: {}, best pos: {}".format(final_best_cost, final_best_pos), lvl=log_level)
# Close Pool of Processes
if n_processes is not None:
pool.close()
return (final_best_cost, final_best_pos)

def _hist_to_csv(self, i: int, name: str, index: List, key: str, label: Optional[str]=None) -> None:
Expand Down

0 comments on commit d2bab85

Please sign in to comment.