Skip to content

Commit

Permalink
Update EMO algorithms (#125)
Browse files Browse the repository at this point in the history
* fix bugs of lmocso

* recover ibea selection

* modify the sample operators

* delete comment
  • Loading branch information
Zhenyu2Liang committed Mar 25, 2024
1 parent ffa5d3e commit 262409d
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 48 deletions.
3 changes: 2 additions & 1 deletion src/evox/algorithms/mo/eagmoead.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
non_dominated_sort,
crowding_distance,
)
from evox.operators.sampling import UniformSampling, LatinHypercubeSampling
from evox.operators.sampling import LatinHypercubeSampling
from evox.utils import pairwise_euclidean_dist


Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)

ext_archive = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
Expand Down
36 changes: 16 additions & 20 deletions src/evox/algorithms/mo/ibea.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,26 +107,22 @@ def tell(self, state, fitness):

merged_fitness, I, C = cal_fitness(merged_obj, self.kappa)

# Different from the original paper, the selection here is directly through fitness.
next_ind = jnp.argsort(-merged_fitness)[0: self.pop_size]

# The following code is from the original paper's implementation
# and is kept for reference purposes but is not being used in this version.
# n = jnp.shape(merged_pop)[0]
# next_ind = jnp.arange(n)
# vals = (next_ind, merged_fitness)
# def body_fun(i, vals):
# next_ind, merged_fitness = vals
# x = jnp.argmin(merged_fitness)
# merged_fitness += jnp.exp(-I[x, :] / C[x] / self.kappa)
# merged_fitness = merged_fitness.at[x].set(jnp.max(merged_fitness))
# next_ind = next_ind.at[x].set(-1)
# return (next_ind, merged_fitness)
#
# next_ind, merged_fitness = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)
#
# next_ind = jnp.where(next_ind != -1, size=n, fill_value=-1)[0]
# next_ind = next_ind[0: self.pop_size]
n = jnp.shape(merged_pop)[0]
next_ind = jnp.arange(n)
vals = (next_ind, merged_fitness)

def body_fun(i, vals):
next_ind, merged_fitness = vals
x = jnp.argmin(merged_fitness)
merged_fitness += jnp.exp(-I[x, :] / C[x] / self.kappa)
merged_fitness = merged_fitness.at[x].set(jnp.max(merged_fitness))
next_ind = next_ind.at[x].set(-1)
return (next_ind, merged_fitness)

next_ind, merged_fitness = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)

next_ind = jnp.where(next_ind != -1, size=n, fill_value=-1)[0]
next_ind = next_ind[0 : self.pop_size]

survivor = merged_pop[next_ind]
survivor_fitness = merged_obj[next_ind]
Expand Down
16 changes: 9 additions & 7 deletions src/evox/algorithms/mo/lmocso.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import jax.numpy as jnp

from evox.operators import mutation, selection
from evox.operators.sampling import LatinHypercubeSampling
from evox.operators.sampling import UniformSampling
from evox import Algorithm, State, jit_class


Expand Down Expand Up @@ -80,22 +80,21 @@ def __init__(
if self.mutation is None:
self.mutation = mutation.Polynomial((lb, ub))

self.sampling = LatinHypercubeSampling(self.pop_size, self.n_objs)
self.sampling = UniformSampling(self.pop_size, self.n_objs)

def setup(self, key):
state_key, init_key, vector_key = jax.random.split(key, 3)
v = self.sampling(vector_key)[0]
self.pop_size = v.shape[0]

population = (
jax.random.uniform(init_key, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
velocity = jnp.zeros((self.pop_size, self.dim))
velocity = jnp.zeros((self.pop_size // 2 * 2, self.dim))
fitness = jnp.full((self.pop_size, self.n_objs), jnp.inf)

v = self.sampling(vector_key)[0]
v = v / jnp.linalg.norm(v, axis=0)

return State(
population=population,
next_generation=population,
Expand All @@ -117,10 +116,13 @@ def ask(self, state):
no_nan_pop = ~jnp.isnan(population).all(axis=1)
max_idx = jnp.sum(no_nan_pop).astype(int)
pop = population[jnp.where(no_nan_pop, size=self.pop_size, fill_value=-1)]

mating_pool = jax.random.randint(mating_key, (self.pop_size,), 0, max_idx)
population = pop[mating_pool]

randperm = jax.random.permutation(pairing_key, self.pop_size).reshape(2, -1)
randperm = jax.random.permutation(pairing_key, self.pop_size // 2 * 2).reshape(
2, -1
)

# calculate the shift-based density estimation(SDE) fitness
sde_fitness = cal_fitness(state.fitness)
Expand Down
14 changes: 9 additions & 5 deletions src/evox/algorithms/mo/moead.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@

from evox import Algorithm, State, jit_class
from evox.operators import crossover, mutation
from evox.operators.sampling import LatinHypercubeSampling, UniformSampling
from evox.operators.sampling import UniformSampling
from evox.utils import pairwise_euclidean_dist


@jit_class
class MOEAD(Algorithm):
"""MOEA/D algorithm
"""Parallel MOEA/D algorithm
link: https://ieeexplore.ieee.org/document/4358754
"""
Expand All @@ -43,7 +43,7 @@ def __init__(
self.dim = lb.shape[0]
self.pop_size = pop_size
self.type = type
self.T = int(math.ceil(self.pop_size / 10))
self.T = 0

self.mutation = mutation_op
self.crossover = crossover_op
Expand All @@ -52,16 +52,20 @@ def __init__(
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary(type=2)
self.sample = LatinHypercubeSampling(self.pop_size, self.n_objs)
self.sample = UniformSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)
w, _ = self.sample(subkey2)
self.pop_size = w.shape[0]
self.T = int(math.ceil(self.pop_size / 10))

population = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
w, _ = self.sample(subkey2)

B = pairwise_euclidean_dist(w, w)
B = jnp.argsort(B, axis=1)
B = B[:, : self.T]
Expand Down
3 changes: 1 addition & 2 deletions src/evox/algorithms/mo/moeaddra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from evox import Algorithm, State, jit_class
from evox.operators import crossover, mutation, selection
from evox.operators.sampling import LatinHypercubeSampling, UniformSampling
from evox.operators.sampling import LatinHypercubeSampling
from evox.utils import pairwise_euclidean_dist


Expand All @@ -33,7 +33,6 @@ def __init__(
ub,
n_objs,
pop_size,
# type=1,
mutation_op=None,
crossover_op=None,
):
Expand Down
12 changes: 8 additions & 4 deletions src/evox/algorithms/mo/moeadm2m.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,25 +115,29 @@ def __init__(
self.n_objs = n_objs
self.dim = lb.shape[0]
self.k = k
self.pop_size = (jnp.ceil(pop_size / self.k) * self.k).astype(int)
self.pop_size = pop_size
self.s = int(self.pop_size / self.k)
self.max_gen = max_gen

self.mutation = Mutation()
self.crossover = Crossover()

self.sample = sampling.LatinHypercubeSampling(self.k, self.n_objs)
self.sample = sampling.UniformSampling(self.k, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)

w, k = self.sample(subkey2)
self.k = int(k)
self.pop_size = (jnp.ceil(self.pop_size / self.k) * self.k).astype(int)
self.s = int(self.pop_size / self.k)

population = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)

w = self.sample(subkey2)[0]

return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
Expand Down
14 changes: 7 additions & 7 deletions src/evox/algorithms/mo/rvea.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax.numpy as jnp

from evox.operators import mutation, crossover, selection
from evox.operators.sampling import LatinHypercubeSampling, UniformSampling
from evox.operators.sampling import UniformSampling
from evox import Algorithm, State, jit_class


Expand Down Expand Up @@ -59,20 +59,20 @@ def __init__(
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()

if self.n_objs == 2:
self.sampling = UniformSampling(self.pop_size, self.n_objs)
else:
self.sampling = LatinHypercubeSampling(self.pop_size, self.n_objs)
self.sampling = UniformSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)

v = self.sampling(subkey2)[0]
v0 = v
self.pop_size = v.shape[0]

population = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
v = self.sampling(subkey2)[0]
v0 = v

return State(
population=population,
Expand Down
6 changes: 4 additions & 2 deletions src/evox/algorithms/mo/tdea.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,18 @@ def __init__(
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()
self.sampling = sampling.LatinHypercubeSampling(self.pop_size, self.n_objs)
self.sample = sampling.UniformSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)
w, _ = self.sample(subkey2)
self.pop_size = w.shape[0]

population = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
w = self.sampling(subkey2)[0]
return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
Expand Down

0 comments on commit 262409d

Please sign in to comment.