Skip to content

Commit

Permalink
fix: coevolution
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Apr 14, 2024
1 parent bf234cd commit 7c96129
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 115 deletions.
2 changes: 1 addition & 1 deletion src/evox/algorithms/containers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .clustered_algorithm import ClusterdAlgorithm, RandomMaskAlgorithm
from .tree_algorithm import TreeAlgorithm
from .coevolution import VectorizedCoevolution, Coevolution
from .coevolution import VectorizedCoevolution, coevolution
206 changes: 110 additions & 96 deletions src/evox/algorithms/containers/coevolution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from dataclasses import field
from functools import partial
from typing import Optional, Tuple
from typing import List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
from evox import Algorithm, State, jit_class
from jax import vmap
from jax.tree_util import tree_map

from evox import Algorithm, State, Static, Stack, dataclass, jit_class, use_state


@jit_class
class VectorizedCoevolution(Algorithm):
Expand Down Expand Up @@ -97,6 +99,112 @@ def tell(self, state: State, fitness: jax.Array) -> State:
)


def coevolution(
base_algorithms,
dim,
num_subpops,
subpop_size,
random_subpop=True,
dtype=jnp.float32,
):
subproblem_dim = dim // num_subpops
algorithm_class = base_algorithms.__class__

@jit_class
@dataclass
class Coevolution(Algorithm):
base_algorithms: Stack[Algorithm]
dim: Static[int]
num_subpops: Static[int]
subpop_size: Static[int]
random_subpop: Static[bool]

def setup(self, key: jax.Array) -> State:
if self.random_subpop:
key, subkey = jax.random.split(key)
permutation = jax.random.permutation(subkey, self.dim)
else:
permutation = None

best_dec = jnp.empty((self.dim,), dtype=dtype)
best_fit = jnp.full((self.num_subpops,), jnp.inf) # fitness

return State(
coop_pops=jnp.empty((self.subpop_size, self.dim)),
best_dec=best_dec,
best_fit=best_fit,
iter_counter=0,
permutation=permutation,
)

def init_ask(self, state: State):
init_subpops, state = use_state(vmap(algorithm_class.init_ask))(
base_algorithms, state
)
# init_subpops (num_subpops, subpop_size, sub_dim)
init_pop = init_subpops.transpose((1, 0, 2)).reshape(-1, self.dim)
return init_pop, state.update(coop_pops=init_pop)

def ask(self, state: State) -> Tuple[jax.Array, State]:
subpop_index = state.iter_counter % self.num_subpops

subpop, state = use_state(algorithm_class.ask, subpop_index)(
base_algorithms, state
)

# co-operate
tiled_best_dec = jnp.tile(state.best_dec, (self.subpop_size, 1))
coop_pops = jax.lax.dynamic_update_slice(
tiled_best_dec, subpop, (0, subpop_index * subproblem_dim)
)

# if random_subpop is set, do a inverse permutation here.
if self.random_subpop:
coop_pops = coop_pops.at[:, state.permutation].set(coop_pops)

return coop_pops, state.update(coop_pops=coop_pops)

def init_tell(self, state, fitness):
best_fit = jnp.min(fitness)
state = use_state(vmap(algorithm_class.init_tell, in_axes=(0, 0, None)))(
base_algorithms, state, fitness
)
best_dec = state.coop_pops[jnp.argmin(fitness)]
return state.update(
best_fit=jnp.tile(best_fit, self.num_subpops),
coop_pops=None,
)

def tell(self, state: State, fitness: jax.Array) -> State:
subpop_index = state.iter_counter % self.num_subpops
state = use_state(algorithm_class.tell, subpop_index)(
base_algorithms, state, fitness
)
min_fitness = jnp.min(fitness)

best_dec_this_gen = state.coop_pops[jnp.argmin(fitness)]
if self.random_subpop:
# if random_subpop is set, permutate the decision variables.
best_dec_this_gen = best_dec_this_gen[state.permutation]

best_dec = jax.lax.select(
state.best_fit[subpop_index] > min_fitness,
best_dec_this_gen,
state.best_dec,
)

best_fit = state.best_fit.at[subpop_index].min(min_fitness)

return state.update(
best_dec=best_dec,
best_fit=best_fit,
iter_counter=state.iter_counter + 1,
coop_pops=None,
)

return Coevolution(base_algorithms, dim, num_subpops, subpop_size, random_subpop)


@jit_class
class Coevolution(Algorithm):
"""
Expand All @@ -123,97 +231,3 @@ def __init__(
self.num_subpop_iter = num_subpop_iter
self.random_subpop = random_subpop
self.sub_dim = dim // num_subpops

def setup(self, key: jax.Array) -> State:
if self.random_subpop:
key, subkey = jax.random.split(key)
self.permutation = jax.random.permutation(subkey, self.dim)

best_dec = jnp.empty((self.dim,))
best_fit = jnp.full((self.num_subpops,), jnp.inf) # fitness
keys = jax.random.split(key, self.num_subpops)
base_alg_state = vmap(self._base_algorithm.init)(keys)

return State(
iter_counter=0,
subpops=jnp.empty((self.num_subpops, self.subpop_size, self.sub_dim)),
best_dec=best_dec,
best_fit=best_fit,
base_alg_state=base_alg_state,
)

def ask(self, state: State) -> Tuple[jax.Array, State]:
subpop_index = (state.iter_counter // self.num_subpop_iter) % self.num_subpops

# Ask all the algorithms once to initialize the best solution,
def init_best_dec(state):
# in the first iteration, we don't really have a best solution
# so just pick the first solution.
init_subpops, _base_alg_state = vmap(self._base_algorithm.ask)(state.base_alg_state)
first_dec = init_subpops[:, 0, :]
return first_dec.reshape((self.dim,))

best_dec = jax.lax.cond(
state.iter_counter == 0,
init_best_dec,
lambda state: state.best_dec,
state,
)

subpop, sub_alg_state = self._base_algorithm.ask(
state.base_alg_state[subpop_index]
)
subpops = state.subpops.at[subpop_index].set(subpop)
base_alg_state = tree_map(
lambda old, new: old.at[subpop_index].set(new),
state.base_alg_state,
sub_alg_state,
)

# co-operate
tiled_best_dec = jnp.tile(best_dec, (self.subpop_size, 1))
coop_pops = jax.lax.dynamic_update_slice(
tiled_best_dec, subpop, (0, subpop_index * self.sub_dim)
)

# if random_subpop is set, do a inverse permutation here.
if self.random_subpop:
coop_pops = coop_pops.at[:, self.permutation].set(coop_pops)

return coop_pops, state.update(
subpops=subpops,
base_alg_state=base_alg_state,
coop_pops=coop_pops,
)

def tell(self, state: State, fitness: jax.Array) -> State:
subpop_index = (state.iter_counter // self.num_subpop_iter) % self.num_subpops
subpop_base_alg_state = self._base_algorithm.tell(
state.base_alg_state[subpop_index], fitness
)
base_alg_state = tree_map(
lambda old, new: old.at[subpop_index].set(new),
state.base_alg_state,
subpop_base_alg_state,
)
min_fitness = jnp.min(fitness)

best_dec_this_gen = state.coop_pops[jnp.argmin(fitness)]
if self.random_subpop:
# if random_subpop is set, permutate the decision variables.
best_dec_this_gen = best_dec_this_gen[self.permutation]

best_dec = jax.lax.select(
state.best_fit[subpop_index] > min_fitness,
best_dec_this_gen,
state.best_dec,
)

best_fit = state.best_fit.at[subpop_index].min(min_fitness)

return state.update(
base_alg_state=base_alg_state,
best_dec=best_dec,
best_fit=best_fit,
iter_counter=state.iter_counter + 1,
)
37 changes: 19 additions & 18 deletions tests/test_containers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import jax
import jax.numpy as jnp
import pytest
from evox import algorithms, workflows, problems
from evox import algorithms, workflows, problems, Stateful
from evox.monitors import StdSOMonitor


Expand Down Expand Up @@ -33,6 +33,7 @@ def test_clustered_cma_es():
assert min_fitness < 2


@pytest.mark.skip(reason="currently unsupported")
@pytest.mark.parametrize("random_subpop", [True, False])
def test_vectorized_coevolution(random_subpop):
# create a workflow
Expand Down Expand Up @@ -95,32 +96,32 @@ def test_vectorized_coevolution(random_subpop):
assert min_fitness < 1


@pytest.mark.parametrize(
"random_subpop, num_subpop_iter", [(True, 1), (False, 1), (True, 2), (False, 2)]
)
def test_coevolution(random_subpop, num_subpop_iter):
@pytest.mark.parametrize("random_subpop", [True, False])
def test_coevolution(random_subpop):
# create a workflow
monitor = StdSOMonitor()
base_algorithm = algorithms.CSO(
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=20,
)
base_algorithms = Stateful.stack([base_algorithm] * 4)
algorithm = algorithms.coevolution(
base_algorithms,
dim=40,
num_subpops=4,
subpop_size=10,
random_subpop=random_subpop,
)

workflow = workflows.StdWorkflow(
algorithms.Coevolution(
base_algorithm=algorithms.CSO(
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=20,
),
dim=40,
num_subpops=4,
subpop_size=10,
num_subpop_iter=num_subpop_iter,
random_subpop=random_subpop,
),
algorithm=algorithm,
problem=problems.numerical.Ackley(),
monitors=[monitor],
)
# init the workflow
key = jax.random.PRNGKey(42)
state = workflow.init(key)

for i in range(4 * 200):
state = workflow.step(state)

Expand Down

0 comments on commit 7c96129

Please sign in to comment.