Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RVEAa algorithm #132

Merged
merged 4 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/api/algorithms/mo/im_moea.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
======
IMMOEA
======

.. autoclass:: evox.algorithms.IMMOEA
:members:
2 changes: 2 additions & 0 deletions docs/source/api/algorithms/mo/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ Multi-objective
spea2
sra
tdea
rveaa
im_moea
6 changes: 6 additions & 0 deletions docs/source/api/algorithms/mo/rveaa.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
=====
RVEAa
=====

.. autoclass:: evox.algorithms.RVEAa
:members:
1 change: 1 addition & 0 deletions src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from .bce_ibea import BCEIBEA
from .lmocso import LMOCSO
from .im_moea import IMMOEA
from .rveaa import RVEAa
3 changes: 2 additions & 1 deletion src/evox/algorithms/mo/eagmoead.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import jax
import jax.numpy as jnp
from functools import partial
import math

from evox import jit_class, Algorithm, State
from evox.operators import (
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(
self.dim = lb.shape[0]
self.pop_size = pop_size
self.LGs = LGs
self.T = jnp.ceil(self.pop_size / 10).astype(int)
self.T = math.ceil(self.pop_size / 10)

self.selection = selection_op
self.mutation = mutation_op
Expand Down
21 changes: 6 additions & 15 deletions src/evox/algorithms/mo/rvea.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,11 @@ def init_tell(self, state, fitness):
return state

def ask(self, state):
key, subkey, x_key, x1_key, x2_key, mut_key = jax.random.split(state.key, 6)
population = state.population
key, subkey, x_key, mut_key = jax.random.split(state.key, 4)

population = state.population
no_nan_pop = ~jnp.isnan(population).all(axis=1)
max_idx = jnp.sum(no_nan_pop).astype(int)

pop = population[jnp.where(no_nan_pop, size=self.pop_size, fill_value=-1)]

mating_pool = jax.random.randint(subkey, (self.pop_size,), 0, max_idx)
Expand All @@ -117,26 +116,18 @@ def tell(self, state, fitness):
merged_pop, merged_fitness, v, (current_gen / self.max_gen) ** self.alpha
)

def rv_adaptation(pop_obj, v):
v_temp = v * jnp.tile(
(jnp.nanmax(pop_obj, axis=0) - jnp.nanmin(pop_obj, axis=0)), (len(v), 1)
)

next_v = v_temp / jnp.tile(
jnp.sqrt(jnp.sum(v_temp**2, axis=1)).reshape(len(v), 1),
(1, jnp.shape(v)[1]),
)

return next_v
def rv_adaptation(pop_obj, v, v0):
return v0 * (jnp.nanmax(pop_obj, axis=0) - jnp.nanmin(pop_obj, axis=0))

def no_update(_pop_obj, v):
def no_update(_pop_obj, v, v0):
return v

v = jax.lax.cond(
current_gen % (1 / self.fr) == 0,
rv_adaptation,
no_update,
survivor_fitness,
v,
state.init_v,
)

Expand Down
215 changes: 215 additions & 0 deletions src/evox/algorithms/mo/rveaa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# --------------------------------------------------------------------------------------
# 1. RVEA algorithm is described in the following papers:
#
# Title: A Reference Vector Guided Evolutionary Algorithm for Many-Objective Optimization
# Link: https://ieeexplore.ieee.org/document/7386636
# --------------------------------------------------------------------------------------

import jax
import jax.numpy as jnp

from evox.operators import mutation, crossover, selection
from evox.operators.sampling import UniformSampling
from evox import Algorithm, State, jit_class
from evox.utils import cos_dist
from evox.operators import non_dominated_sort


@jax.jit
def rv_regeneration(pop_obj, v, key):
"""
Regenerate reference vectors regenerate strategy.
"""
pop_obj = pop_obj - jnp.nanmin(pop_obj, axis=0)
cosine = cos_dist(pop_obj, v)

associate = jnp.nanargmax(cosine, axis=1)

invalid = jnp.sum(associate[:, jnp.newaxis] == jnp.arange(v.shape[0]), axis=0)
rand = jax.random.uniform(key, (v.shape[0], v.shape[1])) * jnp.nanmax(
pop_obj, axis=0
)
v = jnp.where(invalid[:, jnp.newaxis] == 0, rand, v)

return v


@jax.jit
def batch_truncation(pop, obj):
"""
Use the batch truncation operator to select the best n solutions.
"""
n = jnp.shape(pop)[0] // 2
cosine = cos_dist(obj, obj)
not_all_nan_rows = ~jnp.isnan(cosine).all(axis=1)
mask = jnp.eye(jnp.shape(cosine)[0], dtype=bool) & not_all_nan_rows[:, None]
cosine = jnp.where(mask, 0, cosine)

sorted_indices = jnp.sort(-cosine, axis=1)
rank = jnp.argsort(
jnp.where(jnp.isnan(sorted_indices[:, 0]), -jnp.inf, sorted_indices[:, 0])
)

mask = jnp.ones(jnp.shape(rank)[0], dtype=bool)
mask = mask.at[rank[:n]].set(False)[:, jnp.newaxis]

new_pop = jnp.where(mask, pop, jnp.nan)
new_obj = jnp.where(mask, obj, jnp.nan)

return new_pop, new_obj


@jit_class
class RVEAa(Algorithm):
"""RVEAa algorithms (RVEA embedded with the reference vector regeneration strategy)

link: https://ieeexplore.ieee.org/document/7386636

Args:
alpha : The parameter controlling the rate of change of penalty. Defaults to 2.
fr : The frequency of reference vector adaptation. Defaults to 0.1.
max_gen : The maximum number of generations. Defaults to 100.
If the number of iterations is not 100, change the value based on the actual value.
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
alpha=2,
fr=0.1,
max_gen=100,
selection_op=None,
mutation_op=None,
crossover_op=None,
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.alpha = alpha
self.fr = fr
self.max_gen = max_gen

self.selection = selection_op
self.mutation = mutation_op
self.crossover = crossover_op

if self.selection is None:
self.selection = selection.ReferenceVectorGuided()
if self.mutation is None:
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()

self.sampling = UniformSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2, subkey3 = jax.random.split(key, 4)

v = self.sampling(subkey1)[0]
v0 = v
self.pop_size = v.shape[0]

population0 = (
jax.random.uniform(subkey2, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
population = jnp.concatenate(
[
population0,
jnp.full(shape=(self.pop_size, self.dim), fill_value=jnp.nan),
],
axis=0,
)
v = jnp.concatenate(
[v, jax.random.uniform(subkey3, shape=(self.pop_size, self.n_objs))], axis=0
)

return State(
population=population,
fitness=jnp.zeros((self.pop_size * 2, self.n_objs)),
next_generation=population0,
reference_vector=v,
init_v=v0,
key=key,
gen=0,
)

def init_ask(self, state):
return state.population, state

def init_tell(self, state, fitness):
state = state.update(fitness=fitness)
return state

def ask(self, state):
key, subkey, x_key, mut_key = jax.random.split(state.key, 4)

population = state.population
no_nan_pop = ~jnp.isnan(population).all(axis=1)
max_idx = jnp.sum(no_nan_pop).astype(int)
pop = population[jnp.where(no_nan_pop, size=self.pop_size, fill_value=-1)]

mating_pool = jax.random.randint(subkey, (self.pop_size,), 0, max_idx)
crossovered = self.crossover(x_key, pop[mating_pool])
next_generation = self.mutation(mut_key, crossovered)
next_generation = jnp.clip(next_generation, self.lb, self.ub)

return next_generation, state.update(next_generation=next_generation, key=key)

def tell(self, state, fitness):
key, subkey = jax.random.split(state.key, 2)
current_gen = state.gen + 1

v = state.reference_vector
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)

merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0)

rank = non_dominated_sort(merged_fitness)
merged_fitness = jnp.where(rank[:, jnp.newaxis] == 0, merged_fitness, jnp.nan)
merged_pop = jnp.where(rank[:, jnp.newaxis] == 0, merged_pop, jnp.nan)

survivor, survivor_fitness = self.selection(
merged_pop, merged_fitness, v, (current_gen / self.max_gen) ** self.alpha
)

def rv_adaptation(pop_obj, v, v0):
return v0 * (jnp.nanmax(pop_obj, axis=0) - jnp.nanmin(pop_obj, axis=0))

def no_update(_pop_obj, v, v0):
return v

v_adapt = jax.lax.cond(
current_gen % (1 / self.fr) == 0,
rv_adaptation,
no_update,
survivor_fitness,
v[: self.pop_size],
state.init_v,
)

v_regen = rv_regeneration(survivor_fitness, v[self.pop_size :], subkey)
v = jnp.concatenate([v_adapt, v_regen], axis=0)

survivor, survivor_fitness = jax.lax.cond(
current_gen + 1 == self.max_gen,
batch_truncation,
lambda x, y: (x, y),
survivor,
survivor_fitness,
)

state = state.update(
population=survivor,
fitness=survivor_fitness,
reference_vector=v,
gen=current_gen,
key=key,
)
return state
11 changes: 11 additions & 0 deletions tests/test_multi_objective_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def test_lmocso():
)
run_moea(algorithm)


def test_im_moea():
algorithm = algorithms.IMMOEA(
lb=jnp.full(shape=(N,), fill_value=LB),
Expand All @@ -208,3 +209,13 @@ def test_im_moea():
pop_size=POP_SIZE,
)
run_moea(algorithm)


def test_rveaa():
algorithm = algorithms.RVEAa(
lb=jnp.full(shape=(N,), fill_value=LB),
ub=jnp.full(shape=(N,), fill_value=UB),
n_objs=M,
pop_size=POP_SIZE,
)
run_moea(algorithm)
Loading