Skip to content

Commit

Permalink
Merge pull request #126 from EMI-Group/core_improve
Browse files Browse the repository at this point in the history
Core EvoX improvement
  • Loading branch information
BillHuang2001 committed Apr 14, 2024
2 parents 94815a1 + 528871f commit 0731898
Show file tree
Hide file tree
Showing 37 changed files with 663 additions and 585 deletions.
6 changes: 0 additions & 6 deletions docs/source/api/algorithms/so/cpso_s.rst

This file was deleted.

20 changes: 6 additions & 14 deletions docs/source/guide/advanced/2-jit-able.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,9 @@ And since `bar` uses the dynamic index, which is not compatible with `jax.jit`,

## Solution

To solve is problem, it is common practice to jit-compile low-level components, thus giving high-level components more freedom.
In EvoX, we have some general rules on whether a function should be jit-able or not.

| Component | jit-able |
| ----------- | -------- |
| `Workflow` | Optional |
| `Algorithm` | Yes |
| `Problem` | Optional |
| `Operators` | Yes |
| `Monitor` | No |

For standard workflow, one can jit compile when not using monitors and working with jit-able problems.
But even though the workflow can be compiled, there isn't much performance gain.
For problems, it depends on the task.
1. jit-compile low-level components, and give high-level components more freedom.
2. Use [`host callback`](https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html) to call a function on CPU in a jit context.

In EvoX, we almost guarantee that all low-level components are jit-compiled (all operators), and high-level components (`Workflow`) can have both jit-compiled variants (e.g. {doc}`StdWorkflow <api/workflows/standard>`) and non-jit-compiled variants (e.g. {doc}`StdWorkflow <api/workflows/non_jit>`).

Please be aware that using callbacks to jump out of the jit context is not free. Data needs to be transferred between CPU and GPU, which can be an overhead.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ classifiers = [
dependencies = [
"jax >= 0.4.16",
"jaxlib >= 0.3.0",
"jax_dataclasses >= 1.6.0",
"optax >= 0.1.0",
"pyarrow >= 10.0.0",
]
Expand Down
2 changes: 1 addition & 1 deletion src/evox/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .core.workflow import Workflow
from .core.algorithm import Algorithm
from .core.module import *
from .core.operator import Operator
from .core.problem import Problem
from .core.state import State
from .core.monitor import Monitor
Expand Down
2 changes: 1 addition & 1 deletion src/evox/algorithms/containers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .clustered_algorithm import ClusterdAlgorithm, RandomMaskAlgorithm
from .tree_algorithm import TreeAlgorithm
from .coevolution import VectorizedCoevolution, Coevolution
from .coevolution import VectorizedCoevolution, coevolution
206 changes: 110 additions & 96 deletions src/evox/algorithms/containers/coevolution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dataclasses import field
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from evox import Algorithm, State, jit_class
from jax import vmap
from jax.tree_util import tree_map

from evox import Algorithm, State, Static, Stack, dataclass, jit_class, use_state


@jit_class
class VectorizedCoevolution(Algorithm):
Expand Down Expand Up @@ -97,6 +99,112 @@ def tell(self, state: State, fitness: jax.Array) -> State:
)


def coevolution(
base_algorithms,
dim,
num_subpops,
subpop_size,
random_subpop=True,
dtype=jnp.float32,
):
subproblem_dim = dim // num_subpops
algorithm_class = base_algorithms.__class__

@jit_class
@dataclass
class Coevolution(Algorithm):
base_algorithms: Stack[Algorithm]
dim: Static[int]
num_subpops: Static[int]
subpop_size: Static[int]
random_subpop: Static[bool]

def setup(self, key: jax.Array) -> State:
if self.random_subpop:
key, subkey = jax.random.split(key)
permutation = jax.random.permutation(subkey, self.dim)
else:
permutation = None

best_dec = jnp.empty((self.dim,), dtype=dtype)
best_fit = jnp.full((self.num_subpops,), jnp.inf) # fitness

return State(
coop_pops=jnp.empty((self.subpop_size, self.dim)),
best_dec=best_dec,
best_fit=best_fit,
iter_counter=0,
permutation=permutation,
)

def init_ask(self, state: State):
init_subpops, state = use_state(vmap(algorithm_class.init_ask))(
base_algorithms, state
)
# init_subpops (num_subpops, subpop_size, sub_dim)
init_pop = init_subpops.transpose((1, 0, 2)).reshape(-1, self.dim)
return init_pop, state.update(coop_pops=init_pop)

def ask(self, state: State) -> Tuple[jax.Array, State]:
subpop_index = state.iter_counter % self.num_subpops

subpop, state = use_state(algorithm_class.ask, subpop_index)(
base_algorithms, state
)

# co-operate
tiled_best_dec = jnp.tile(state.best_dec, (self.subpop_size, 1))
coop_pops = jax.lax.dynamic_update_slice(
tiled_best_dec, subpop, (0, subpop_index * subproblem_dim)
)

# if random_subpop is set, do a inverse permutation here.
if self.random_subpop:
coop_pops = coop_pops.at[:, state.permutation].set(coop_pops)

return coop_pops, state.update(coop_pops=coop_pops)

def init_tell(self, state, fitness):
best_fit = jnp.min(fitness)
state = use_state(vmap(algorithm_class.init_tell, in_axes=(0, 0, None)))(
base_algorithms, state, fitness
)
best_dec = state.coop_pops[jnp.argmin(fitness)]
return state.update(
best_fit=jnp.tile(best_fit, self.num_subpops),
coop_pops=None,
)

def tell(self, state: State, fitness: jax.Array) -> State:
subpop_index = state.iter_counter % self.num_subpops
state = use_state(algorithm_class.tell, subpop_index)(
base_algorithms, state, fitness
)
min_fitness = jnp.min(fitness)

best_dec_this_gen = state.coop_pops[jnp.argmin(fitness)]
if self.random_subpop:
# if random_subpop is set, permutate the decision variables.
best_dec_this_gen = best_dec_this_gen[state.permutation]

best_dec = jax.lax.select(
state.best_fit[subpop_index] > min_fitness,
best_dec_this_gen,
state.best_dec,
)

best_fit = state.best_fit.at[subpop_index].min(min_fitness)

return state.update(
best_dec=best_dec,
best_fit=best_fit,
iter_counter=state.iter_counter + 1,
coop_pops=None,
)

return Coevolution(base_algorithms, dim, num_subpops, subpop_size, random_subpop)


@jit_class
class Coevolution(Algorithm):
"""
Expand All @@ -123,97 +231,3 @@ def __init__(
self.num_subpop_iter = num_subpop_iter
self.random_subpop = random_subpop
self.sub_dim = dim // num_subpops

def setup(self, key: jax.Array) -> State:
if self.random_subpop:
key, subkey = jax.random.split(key)
self.permutation = jax.random.permutation(subkey, self.dim)

best_dec = jnp.empty((self.dim,))
best_fit = jnp.full((self.num_subpops,), jnp.inf) # fitness
keys = jax.random.split(key, self.num_subpops)
base_alg_state = vmap(self._base_algorithm.init)(keys)

return State(
iter_counter=0,
subpops=jnp.empty((self.num_subpops, self.subpop_size, self.sub_dim)),
best_dec=best_dec,
best_fit=best_fit,
base_alg_state=base_alg_state,
)

def ask(self, state: State) -> Tuple[jax.Array, State]:
subpop_index = (state.iter_counter // self.num_subpop_iter) % self.num_subpops

# Ask all the algorithms once to initialize the best solution,
def init_best_dec(state):
# in the first iteration, we don't really have a best solution
# so just pick the first solution.
init_subpops, _base_alg_state = vmap(self._base_algorithm.ask)(state.base_alg_state)
first_dec = init_subpops[:, 0, :]
return first_dec.reshape((self.dim,))

best_dec = jax.lax.cond(
state.iter_counter == 0,
init_best_dec,
lambda state: state.best_dec,
state,
)

subpop, sub_alg_state = self._base_algorithm.ask(
state.base_alg_state[subpop_index]
)
subpops = state.subpops.at[subpop_index].set(subpop)
base_alg_state = tree_map(
lambda old, new: old.at[subpop_index].set(new),
state.base_alg_state,
sub_alg_state,
)

# co-operate
tiled_best_dec = jnp.tile(best_dec, (self.subpop_size, 1))
coop_pops = jax.lax.dynamic_update_slice(
tiled_best_dec, subpop, (0, subpop_index * self.sub_dim)
)

# if random_subpop is set, do a inverse permutation here.
if self.random_subpop:
coop_pops = coop_pops.at[:, self.permutation].set(coop_pops)

return coop_pops, state.update(
subpops=subpops,
base_alg_state=base_alg_state,
coop_pops=coop_pops,
)

def tell(self, state: State, fitness: jax.Array) -> State:
subpop_index = (state.iter_counter // self.num_subpop_iter) % self.num_subpops
subpop_base_alg_state = self._base_algorithm.tell(
state.base_alg_state[subpop_index], fitness
)
base_alg_state = tree_map(
lambda old, new: old.at[subpop_index].set(new),
state.base_alg_state,
subpop_base_alg_state,
)
min_fitness = jnp.min(fitness)

best_dec_this_gen = state.coop_pops[jnp.argmin(fitness)]
if self.random_subpop:
# if random_subpop is set, permutate the decision variables.
best_dec_this_gen = best_dec_this_gen[self.permutation]

best_dec = jax.lax.select(
state.best_fit[subpop_index] > min_fitness,
best_dec_this_gen,
state.best_dec,
)

best_fit = state.best_fit.at[subpop_index].min(min_fitness)

return state.update(
base_alg_state=base_alg_state,
best_dec=best_dec,
best_fit=best_fit,
iter_counter=state.iter_counter + 1,
)
2 changes: 1 addition & 1 deletion src/evox/algorithms/so/es_variants/nes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def __init__(

def setup(self, key):
mean = self.init_mean
sigma = math.pow(jnp.prod(jnp.diag(self.init_covar)), 1 / self.dim)
sigma = jnp.pow(jnp.prod(jnp.diag(self.init_covar)), 1 / self.dim)
B = self.init_covar / sigma
population = jnp.empty((self.pop_size, self.dim))
noise = jnp.empty_like(population)
Expand Down
12 changes: 6 additions & 6 deletions src/evox/algorithms/so/es_variants/open_es.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
import jax.numpy as jnp
import optax

import evox
from evox import Algorithm, State, jit_class, use_state, utils


@evox.jit_class
class OpenES(evox.Algorithm):
@jit_class
class OpenES(Algorithm):
def __init__(
self,
center_init,
Expand Down Expand Up @@ -47,7 +47,7 @@ def __init__(
self.mirrored_sampling = mirrored_sampling

if optimizer == "adam":
self.optimizer = evox.utils.OptaxWrapper(
self.optimizer = utils.OptaxWrapper(
optax.adam(learning_rate=learning_rate), center_init
)
else:
Expand All @@ -57,7 +57,7 @@ def setup(self, key):
# placeholder
population = jnp.tile(self.center_init, (self.pop_size, 1))
noise = jnp.tile(self.center_init, (self.pop_size, 1))
return evox.State(
return State(
population=population, center=self.center_init, noise=noise, key=key
)

Expand All @@ -77,6 +77,6 @@ def tell(self, state, fitness):
if self.optimizer is None:
center = state.center - self.learning_rate * grad
else:
updates, state = self.optimizer.update(state, state.center)
updates, state = use_state(self.optimizer.update)(state, state.center)
center = optax.apply_updates(state.center, updates)
return state.update(center=center)
Loading

0 comments on commit 0731898

Please sign in to comment.