Skip to content

Commit

Permalink
fix: init_ask detection and state.first_step
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Aug 6, 2024
1 parent 02f00e7 commit 011d8f0
Showing 1 changed file with 43 additions and 23 deletions.
66 changes: 43 additions & 23 deletions src/evox/workflows/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,16 @@
import jax.numpy as jnp
import ray

from evox import Algorithm, Problem, State, Workflow, use_state
from evox.utils import algorithm_has_init_ask, parse_opt_direction
from evox import (
Algorithm,
Problem,
State,
Workflow,
use_state,
has_init_ask,
has_init_tell,
)
from evox.utils import parse_opt_direction


class WorkerWorkflow(Workflow):
Expand Down Expand Up @@ -36,7 +44,7 @@ def __init__(
self.fit_transforms = fit_transforms

def setup(self, key):
return State(generation=0)
return State(generation=0, first_step=True)

def _get_slice(self, pop_size):
slice_per_worker = pop_size // self.num_workers
Expand All @@ -45,20 +53,23 @@ def _get_slice(self, pop_size):
end = start + slice_per_worker + (self.worker_index < remainder)
return start, end

def _ask(self, state):
if has_init_ask(self.algorithm) and state.first_step:
ask = self.algorithm.init_ask
else:
ask = self.algorithm.ask

# candidate: individuals that need to be evaluated (may differ from population)
# Note: num_cands can be different from init_ask() and ask()
cands, state = use_state(ask)(state)

return cands, state

def step1(self, state: State):
if "pre_ask" in self.non_empty_hooks:
ray.get(self.monitor_actor.push.remote("pre_ask", state))

if state.generation == 0:
is_init = algorithm_has_init_ask(self.algorithm, state)
else:
is_init = False

if is_init:
cand_sol, state = use_state(self.algorithm.init_ask)(state)
else:
cand_sol, state = use_state(self.algorithm.ask)(state)

cand_sol, state = self._ask(state)
if "post_ask" in self.non_empty_hooks:
ray.get(self.monitor_actor.push.remote("post_ask", None, cand_sol))

Expand All @@ -82,12 +93,17 @@ def step1(self, state: State):

return partial_fitness, state

def step2(self, state: State, fitness: List[jax.Array]):
if state.generation == 0:
is_init = algorithm_has_init_ask(self.algorithm, state)
def _tell(self, state, transformed_fitness):
if has_init_tell(self.algorithm) and state.first_step:
tell = self.algorithm.init_tell
else:
is_init = False
tell = self.algorithm.tell

state = use_state(tell)(state, transformed_fitness)

return state

def step2(self, state: State, fitness: List[jax.Array]):
fitness = jnp.concatenate(fitness, axis=0)
fitness = fitness * self.opt_direction

Expand All @@ -112,15 +128,19 @@ def step2(self, state: State, fitness: List[jax.Array]):
)
)

if is_init:
state = use_state(self.algorithm.init_tell)(state, fitness)
else:
state = use_state(self.algorithm.tell)(state, fitness)
state = self._tell(state, fitness)

if "post_tell" in self.non_empty_hooks:
ray.get(self.monitor_actor.push.remote("post_tell", state))

return state.update(generation=state.generation + 1)


if has_init_ask(self.algorithm) and state.first_step:
# this ensures that _step() will be re-jitted
state = state.replace(generation=state.generation + 1, first_step=False)
else:
state = state.replace(generation=state.generation + 1)

return state

def valid(self, state: State, metric: str):
new_state = use_state(self.problem.valid)(state, metric=metric)
Expand Down

0 comments on commit 011d8f0

Please sign in to comment.