Skip to content

Commit

Permalink
Merge pull request #57 from EMI-Group/dev2-lzy
Browse files Browse the repository at this point in the history
Update MO, Operators and Problems
  • Loading branch information
Zhenyu2Liang committed Aug 5, 2023
2 parents 54f357f + 01b67dc commit afca7f2
Show file tree
Hide file tree
Showing 26 changed files with 1,194 additions and 212 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ __pycache__
/benchmarks
/build
/src/evox.egg-info
/.idea

logs/
evox/algorithms/so/cso_*.py
Expand Down
5 changes: 5 additions & 0 deletions src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@
from .moead import MOEAD
from .ibea import IBEA
from .nsga3 import NSGA3
from .eagmoead import EAGMOEAD
from .moeaddra import MOEADDRA
from .spea2 import SPEA2
from .hype import HypE
from .moeadm2m import MOEADM2M
194 changes: 194 additions & 0 deletions src/evox/algorithms/mo/eagmoead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import jax
import jax.numpy as jnp
from functools import partial

from evox import jit_class, Algorithm, State
from evox.operators import selection, mutation, crossover, non_dominated_sort, crowding_distance
from evox.operators.sampling import UniformSampling, LatinHypercubeSampling
from evox.utils import euclidean_dis


@partial(jax.jit, static_argnums=[1])
def environmental_selection(fitness, n):
rank = non_dominated_sort(fitness)
order = jnp.argsort(rank)
worst_rank = rank[order[n - 1]]
mask = rank == worst_rank
crowding_dis = crowding_distance(fitness, mask)
combined_indices = jnp.lexsort((-crowding_dis, rank))[: n]

return combined_indices


@jit_class
class EAGMOEAD(Algorithm):
"""EAG-MOEA/D algorithm
link: https://ieeexplore.ieee.org/abstract/document/6882229
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
LGs=8,
selection_op=None,
mutation_op=None,
crossover_op=None,
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.LGs = LGs
self.T = jnp.ceil(self.pop_size / 10).astype(int)

self.selection = selection_op
self.mutation = mutation_op
self.crossover = crossover_op

if self.selection is None:
self.selection = selection.RouletteWheelSelection(self.pop_size)
if self.mutation is None:
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary(type=2)
self.sample = LatinHypercubeSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)
ext_archive = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
fitness = jnp.zeros((self.pop_size, self.n_objs))

w = self.sample(subkey2)[0]
B = euclidean_dis(w, w)
B = jnp.argsort(B, axis=1)
B = B[:, : self.T]
return State(
population=ext_archive,
fitness=fitness,
inner_pop=ext_archive,
inner_obj=fitness,
next_generation=ext_archive,
weight_vector=w,
B=B,
s=jnp.zeros((self.pop_size, self.LGs)),
parent=jnp.zeros((self.pop_size, self.T)).astype(int),
offspring_loc=jnp.zeros((self.pop_size, )).astype(int),
gen=0,
is_init=True,
key=key,
)

def ask(self, state):
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state)

def tell(self, state, fitness):
return jax.lax.cond(
state.is_init, self._tell_init, self._tell_normal, state, fitness
)

def _ask_init(self, state):
return state.population, state

def _ask_normal(self, state):
key, per_key, sel_key, x_key, mut_key = jax.random.split(state.key, 5)
B = state.B
population = state.inner_pop
n, t = jnp.shape(B)
s = jnp.sum(state.s, axis=1) + 1e-6
d = s / jnp.sum(s) + 0.002
d = d / jnp.sum(d)

_, offspring_loc = self.selection(sel_key, population, 1./d)
parent = jnp.zeros((n, 2)).astype(int)
B = jax.random.permutation(
per_key, B, axis=1, independent=True
).astype(int)

def body_fun(i, val):
val = val.at[i, 0].set(B[offspring_loc[i], 0])
val = val.at[i, 1].set(B[offspring_loc[i], 1])
return val.astype(int)

parent = jax.lax.fori_loop(0, n, body_fun, parent)

selected_p = jnp.r_[population[parent[:, 0]], population[parent[:, 1]]]

crossovered = self.crossover(x_key, selected_p)
next_generation = self.mutation(mut_key, crossovered)

return next_generation, state.update(
next_generation=next_generation, offspring_loc=offspring_loc, key=key
)

def _tell_init(self, state, fitness):
state = state.update(fitness=fitness, inner_obj=fitness, is_init=False)
return state

def _tell_normal(self, state, fitness):
gen = state.gen + 1
ext_archive = state.population
ext_obj = state.fitness
inner_pop = state.inner_pop
inner_obj = state.inner_obj

offspring = state.next_generation
offspring_obj = fitness
B = state.B
w = state.weight_vector
s = state.s

offspring_loc = state.offspring_loc
vals = (inner_pop, inner_obj)

def body_fun(i, vals):
population, pop_obj = vals
g_old = jnp.sum(pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1)
g_new = w[B[offspring_loc[i], :]] @ jnp.transpose(offspring_obj[i])
idx = B[offspring_loc[i]]
g_new = g_new[:, jnp.newaxis]
g_old = g_old[:, jnp.newaxis]
population = population.at[idx].set(
jnp.where(g_old >= g_new, offspring[i], population[idx])
)
pop_obj = pop_obj.at[idx].set(
jnp.where(g_old >= g_new, offspring_obj[i], pop_obj[idx])
)
return (population, pop_obj)

inner_pop, inner_obj = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)

merged_pop = jnp.concatenate([ext_archive, offspring], axis=0)
merged_fitness = jnp.concatenate([ext_obj, offspring_obj], axis=0)

combined_order = environmental_selection(merged_fitness, self.pop_size)
survivor = merged_pop[combined_order]
survivor_fitness = merged_fitness[combined_order]
mask = combined_order >= self.pop_size
num_valid = jnp.sum(mask)
sucessful = jnp.where(mask, size=self.pop_size)

def update_s(s):
h = offspring_loc[combined_order[sucessful]-self.pop_size]
head = h[0]
h = jnp.where(h == head, -1, h)
h = h.at[0].set(head)
hist, _ = jnp.histogram(h, self.pop_size, range=(0, self.pop_size))
s = s.at[:, gen % self.LGs+1].set(hist)
return s

def no_update(s):
return s

s = jax.lax.cond(num_valid != 0, update_s, no_update, s)
state = state.update(population=survivor, fitness=survivor_fitness, inner_pop=inner_pop, inner_obj=inner_obj,
s=s, gen=gen)
return state
154 changes: 154 additions & 0 deletions src/evox/algorithms/mo/hype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import jax
import jax.numpy as jnp
from functools import partial

from evox import jit_class, Algorithm, State
from evox.operators import selection, mutation, crossover, non_dominated_sort


@partial(jax.jit, static_argnums=[0, 1])
def calculate_alpha(N, k):
alpha = jnp.zeros(N)

for i in range(1, k + 1):
num = jnp.prod((k - jnp.arange(1, i)) / (N - jnp.arange(1, i)))
alpha = alpha.at[i-1].set(num / i)
return alpha


@partial(jax.jit, static_argnums=[2, 3])
def cal_hv(points, ref, k, n_sample, key):
n, m = jnp.shape(points)
alpha = calculate_alpha(n, k)

f_min = jnp.min(points, axis=0)

s = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref)

pds = jnp.zeros((n, n_sample), dtype=bool)
ds = jnp.zeros((n_sample, ))

def body_fun1(i, vals):
pds, ds = vals
x = jnp.sum((jnp.tile(points[i, :], (n_sample, 1)) - s) <= 0, axis=1) == m
pds = pds.at[i].set(jnp.where(x, True, pds[i]))
ds = jnp.where(x, ds+1, ds)
return pds, ds

pds, ds = jax.lax.fori_loop(0, n, body_fun1, (pds, ds))
ds = ds - 1

f = jnp.zeros((n,))

def body_fun2(i, val):
temp = jnp.where(pds[i, :], ds, -1).astype(int)
value = jnp.where(temp!=-1, alpha[temp], 0)
value = jnp.sum(value)
val = val.at[i].set(value)
return val

f = jax.lax.fori_loop(0, n, body_fun2, f)
f = f * jnp.prod(ref - f_min) / n_sample

return f

@jit_class
class HypE(Algorithm):
"""HypE algorithm
link: https://direct.mit.edu/evco/article-abstract/19/1/45/1363/HypE-An-Algorithm-for-Fast-Hypervolume-Based-Many
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
n_sample=10000,
mutation_op=None,
crossover_op=None,
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.n_sample = n_sample

self.mutation = mutation_op
self.crossover = crossover_op
self.selection = selection.Tournament(n_round=self.pop_size)
if self.mutation is None:
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()

def setup(self, key):
key, subkey = jax.random.split(key)
population = (
jax.random.uniform(subkey, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
next_generation=population,
ref_point=jnp.zeros((self.n_objs, )),
key=key,
is_init=True,
)

def ask(self, state):
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state)

def tell(self, state, fitness):
return jax.lax.cond(
state.is_init, self._tell_init, self._tell_normal, state, fitness
)

def _ask_init(self, state):
return state.population, state

def _ask_normal(self, state):
population = state.population
pop_obj = state.fitness
key, subkey, sel_key, x_key, mut_key = jax.random.split(state.key, 5)
hv = cal_hv(pop_obj, state.ref_point, self.pop_size, self.n_sample, subkey)

selected, _ = self.selection(sel_key, population, -hv)
crossovered = self.crossover(x_key, selected)
next_generation = self.mutation(mut_key, crossovered)

return next_generation, state.update(next_generation=next_generation)

def _tell_init(self, state, fitness):
ref_point = jnp.zeros((self.n_objs, )) + jnp.max(fitness)*1.2
state = state.update(fitness=fitness, ref_point=ref_point, is_init=False)
return state

def _tell_normal(self, state, fitness):
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)
merged_obj = jnp.concatenate([state.fitness, fitness], axis=0)

n = jnp.shape(merged_pop)[0]

rank = non_dominated_sort(merged_obj)
order = jnp.argsort(rank)
worst_rank = rank[order[n-1]]
mask = rank == worst_rank

key, subkey = jax.random.split(state.key)
hv = cal_hv(merged_obj, state.ref_point, n, self.n_sample, subkey)

dis = jnp.where(mask, hv, -jnp.inf)

combined_indices = jnp.lexsort((-dis, rank))[: self.pop_size]

survivor = merged_pop[combined_indices]
survivor_fitness = merged_obj[combined_indices]

state = state.update(population=survivor, fitness=survivor_fitness, key=key)

return state
Loading

0 comments on commit afca7f2

Please sign in to comment.