Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MO, Operators and Problems #57

Merged
merged 14 commits into from
Aug 5, 2023
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ __pycache__
/benchmarks
/build
/src/evox.egg-info
/.idea

logs/
evox/algorithms/so/cso_*.py
Expand Down
5 changes: 5 additions & 0 deletions src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@
from .moead import MOEAD
from .ibea import IBEA
from .nsga3 import NSGA3
from .eagmoead import EAGMOEAD
from .moeaddra import MOEADDRA
from .spea2 import SPEA2
from .hype import HypE
from .moeadm2m import MOEADM2M
194 changes: 194 additions & 0 deletions src/evox/algorithms/mo/eagmoead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import jax
import jax.numpy as jnp
from functools import partial

from evox import jit_class, Algorithm, State
from evox.operators import selection, mutation, crossover, non_dominated_sort, crowding_distance
from evox.operators.sampling import UniformSampling, LatinHypercubeSampling
from evox.utils import euclidean_dis


@partial(jax.jit, static_argnums=[1])
def environmental_selection(fitness, n):
rank = non_dominated_sort(fitness)
order = jnp.argsort(rank)
worst_rank = rank[order[n - 1]]
mask = rank == worst_rank
crowding_dis = crowding_distance(fitness, mask)
combined_indices = jnp.lexsort((-crowding_dis, rank))[: n]

return combined_indices


@jit_class
class EAGMOEAD(Algorithm):
"""EAG-MOEA/D algorithm

link: https://ieeexplore.ieee.org/abstract/document/6882229
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
LGs=8,
selection_op=None,
mutation_op=None,
crossover_op=None,
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.LGs = LGs
self.T = jnp.ceil(self.pop_size / 10).astype(int)

self.selection = selection_op
self.mutation = mutation_op
self.crossover = crossover_op

if self.selection is None:
self.selection = selection.RouletteWheelSelection(self.pop_size)
if self.mutation is None:
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary(type=2)
self.sample = LatinHypercubeSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)
ext_archive = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
fitness = jnp.zeros((self.pop_size, self.n_objs))

w = self.sample(subkey2)[0]
B = euclidean_dis(w, w)
B = jnp.argsort(B, axis=1)
B = B[:, : self.T]
return State(
population=ext_archive,
fitness=fitness,
inner_pop=ext_archive,
inner_obj=fitness,
next_generation=ext_archive,
weight_vector=w,
B=B,
s=jnp.zeros((self.pop_size, self.LGs)),
parent=jnp.zeros((self.pop_size, self.T)).astype(int),
offspring_loc=jnp.zeros((self.pop_size, )).astype(int),
gen=0,
is_init=True,
key=key,
)

def ask(self, state):
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state)

def tell(self, state, fitness):
return jax.lax.cond(
state.is_init, self._tell_init, self._tell_normal, state, fitness
)

def _ask_init(self, state):
return state.population, state

def _ask_normal(self, state):
key, per_key, sel_key, x_key, mut_key = jax.random.split(state.key, 5)
B = state.B
population = state.inner_pop
n, t = jnp.shape(B)
s = jnp.sum(state.s, axis=1) + 1e-6
d = s / jnp.sum(s) + 0.002
d = d / jnp.sum(d)

_, offspring_loc = self.selection(sel_key, population, 1./d)
parent = jnp.zeros((n, 2)).astype(int)
B = jax.random.permutation(
per_key, B, axis=1, independent=True
).astype(int)

def body_fun(i, val):
val = val.at[i, 0].set(B[offspring_loc[i], 0])
val = val.at[i, 1].set(B[offspring_loc[i], 1])
return val.astype(int)

parent = jax.lax.fori_loop(0, n, body_fun, parent)

selected_p = jnp.r_[population[parent[:, 0]], population[parent[:, 1]]]

crossovered = self.crossover(x_key, selected_p)
next_generation = self.mutation(mut_key, crossovered)

return next_generation, state.update(
next_generation=next_generation, offspring_loc=offspring_loc, key=key
)

def _tell_init(self, state, fitness):
state = state.update(fitness=fitness, inner_obj=fitness, is_init=False)
return state

def _tell_normal(self, state, fitness):
gen = state.gen + 1
ext_archive = state.population
ext_obj = state.fitness
inner_pop = state.inner_pop
inner_obj = state.inner_obj

offspring = state.next_generation
offspring_obj = fitness
B = state.B
w = state.weight_vector
s = state.s

offspring_loc = state.offspring_loc
vals = (inner_pop, inner_obj)

def body_fun(i, vals):
population, pop_obj = vals
g_old = jnp.sum(pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1)
g_new = w[B[offspring_loc[i], :]] @ jnp.transpose(offspring_obj[i])
idx = B[offspring_loc[i]]
g_new = g_new[:, jnp.newaxis]
g_old = g_old[:, jnp.newaxis]
population = population.at[idx].set(
jnp.where(g_old >= g_new, offspring[i], population[idx])
)
pop_obj = pop_obj.at[idx].set(
jnp.where(g_old >= g_new, offspring_obj[i], pop_obj[idx])
)
return (population, pop_obj)

inner_pop, inner_obj = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)

merged_pop = jnp.concatenate([ext_archive, offspring], axis=0)
merged_fitness = jnp.concatenate([ext_obj, offspring_obj], axis=0)

combined_order = environmental_selection(merged_fitness, self.pop_size)
survivor = merged_pop[combined_order]
survivor_fitness = merged_fitness[combined_order]
mask = combined_order >= self.pop_size
num_valid = jnp.sum(mask)
sucessful = jnp.where(mask, size=self.pop_size)

def update_s(s):
h = offspring_loc[combined_order[sucessful]-self.pop_size]
head = h[0]
h = jnp.where(h == head, -1, h)
h = h.at[0].set(head)
hist, _ = jnp.histogram(h, self.pop_size, range=(0, self.pop_size))
s = s.at[:, gen % self.LGs+1].set(hist)
return s

def no_update(s):
return s

s = jax.lax.cond(num_valid != 0, update_s, no_update, s)
state = state.update(population=survivor, fitness=survivor_fitness, inner_pop=inner_pop, inner_obj=inner_obj,
s=s, gen=gen)
return state
154 changes: 154 additions & 0 deletions src/evox/algorithms/mo/hype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import jax
import jax.numpy as jnp
from functools import partial

from evox import jit_class, Algorithm, State
from evox.operators import selection, mutation, crossover, non_dominated_sort


@partial(jax.jit, static_argnums=[0, 1])
def calculate_alpha(N, k):
alpha = jnp.zeros(N)

for i in range(1, k + 1):
num = jnp.prod((k - jnp.arange(1, i)) / (N - jnp.arange(1, i)))
alpha = alpha.at[i-1].set(num / i)
return alpha


@partial(jax.jit, static_argnums=[2, 3])
def cal_hv(points, ref, k, n_sample, key):
n, m = jnp.shape(points)
alpha = calculate_alpha(n, k)

f_min = jnp.min(points, axis=0)

s = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref)

pds = jnp.zeros((n, n_sample), dtype=bool)
ds = jnp.zeros((n_sample, ))

def body_fun1(i, vals):
pds, ds = vals
x = jnp.sum((jnp.tile(points[i, :], (n_sample, 1)) - s) <= 0, axis=1) == m
pds = pds.at[i].set(jnp.where(x, True, pds[i]))
ds = jnp.where(x, ds+1, ds)
return pds, ds

pds, ds = jax.lax.fori_loop(0, n, body_fun1, (pds, ds))
ds = ds - 1

f = jnp.zeros((n,))

def body_fun2(i, val):
temp = jnp.where(pds[i, :], ds, -1).astype(int)
value = jnp.where(temp!=-1, alpha[temp], 0)
value = jnp.sum(value)
val = val.at[i].set(value)
return val

f = jax.lax.fori_loop(0, n, body_fun2, f)
f = f * jnp.prod(ref - f_min) / n_sample

return f

@jit_class
class HypE(Algorithm):
"""HypE algorithm

link: https://direct.mit.edu/evco/article-abstract/19/1/45/1363/HypE-An-Algorithm-for-Fast-Hypervolume-Based-Many
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
n_sample=10000,
mutation_op=None,
crossover_op=None,
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.n_sample = n_sample

self.mutation = mutation_op
self.crossover = crossover_op
self.selection = selection.Tournament(n_round=self.pop_size)
if self.mutation is None:
self.mutation = mutation.Polynomial((lb, ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()

def setup(self, key):
key, subkey = jax.random.split(key)
population = (
jax.random.uniform(subkey, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
next_generation=population,
ref_point=jnp.zeros((self.n_objs, )),
key=key,
is_init=True,
)

def ask(self, state):
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state)

def tell(self, state, fitness):
return jax.lax.cond(
state.is_init, self._tell_init, self._tell_normal, state, fitness
)

def _ask_init(self, state):
return state.population, state

def _ask_normal(self, state):
population = state.population
pop_obj = state.fitness
key, subkey, sel_key, x_key, mut_key = jax.random.split(state.key, 5)
hv = cal_hv(pop_obj, state.ref_point, self.pop_size, self.n_sample, subkey)

selected, _ = self.selection(sel_key, population, -hv)
crossovered = self.crossover(x_key, selected)
next_generation = self.mutation(mut_key, crossovered)

return next_generation, state.update(next_generation=next_generation)

def _tell_init(self, state, fitness):
ref_point = jnp.zeros((self.n_objs, )) + jnp.max(fitness)*1.2
state = state.update(fitness=fitness, ref_point=ref_point, is_init=False)
return state

def _tell_normal(self, state, fitness):
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)
merged_obj = jnp.concatenate([state.fitness, fitness], axis=0)

n = jnp.shape(merged_pop)[0]

rank = non_dominated_sort(merged_obj)
order = jnp.argsort(rank)
worst_rank = rank[order[n-1]]
mask = rank == worst_rank

key, subkey = jax.random.split(state.key)
hv = cal_hv(merged_obj, state.ref_point, n, self.n_sample, subkey)

dis = jnp.where(mask, hv, -jnp.inf)

combined_indices = jnp.lexsort((-dis, rank))[: self.pop_size]

survivor = merged_pop[combined_indices]
survivor_fitness = merged_obj[combined_indices]

state = state.update(population=survivor, fitness=survivor_fitness, key=key)

return state
Loading