Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update MO, Operators and Problems #57

Merged
merged 14 commits into from
Aug 5, 2023
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 64 additions & 0 deletions .idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions .idea/evox.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 10 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions .idea/webResources.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Empty file added src/__init__.py
Zhenyu2Liang marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
6 changes: 6 additions & 0 deletions src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@
from .moead import MOEAD
from .ibea import IBEA
from .nsga3 import NSGA3
from .eagmoead import EAGMOEAD
from .moeaddra import MOEADDRA
from .spea2 import SPEA2
from .hype import HypE
# from .rveaa import RVEAa
Zhenyu2Liang marked this conversation as resolved.
Show resolved Hide resolved

188 changes: 188 additions & 0 deletions src/evox/algorithms/mo/eagmoead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import jax
import jax.numpy as jnp
from functools import partial

from evox import jit_class, Algorithm, State
from evox.operators.selection import RouletteWheelSelection, non_dominated_sort, crowding_distance_sort, crowding_distance
from evox.operators.mutation import PmMutation
from evox.operators.crossover import SimulatedBinaryCrossover
from evox.operators.sampling import UniformSampling, LatinHypercubeSampling
from evox.utils import euclidean_dis
from jax.experimental.host_callback import id_print
Zhenyu2Liang marked this conversation as resolved.
Show resolved Hide resolved


@partial(jax.jit, static_argnames=['n'])
def environmental_selection(fitness, n):
rank = non_dominated_sort(fitness)
order = jnp.argsort(rank)
worst_rank = rank[order[n - 1]]
mask = rank == worst_rank
crowding_dis = crowding_distance(fitness, mask)
combined_indices = jnp.lexsort((-crowding_dis, rank))[: n]

return combined_indices


@jit_class
class EAGMOEAD(Algorithm):
"""EAG-MOEA/D algorithm

link: https://ieeexplore.ieee.org/abstract/document/6882229
"""

def __init__(
self,
lb,
ub,
n_objs,
pop_size,
LGs=8,
mutation=PmMutation(),
crossover=SimulatedBinaryCrossover(type=2),
):
self.lb = lb
self.ub = ub
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.LGs = LGs
self.T = jnp.ceil(self.pop_size / 10).astype(int)

self.selection = RouletteWheelSelection(self.pop_size)
self.mutation = mutation
self.crossover = crossover

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)
ext_archive = (
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
+ self.lb
)
fitness = jnp.zeros((self.pop_size, self.n_objs))
# w = UniformSampling(self.pop_size, self.n_objs).random()[0]
w = LatinHypercubeSampling(self.pop_size, self.n_objs).random(subkey2)[0]
B = euclidean_dis(w, w)
B = jnp.argsort(B, axis=1)
B = B[:, : self.T]
return State(
population=ext_archive,
fitness=fitness,
inner_pop=ext_archive,
inner_obj=fitness,
next_generation=ext_archive,
weight_vector=w,
B=B,
s=jnp.zeros((self.pop_size, self.LGs)),
parent=jnp.zeros((self.pop_size, self.T)).astype(int),
offspring_loc=jnp.zeros((self.pop_size, )).astype(int),
gen=0,
is_init=True,
key=key,
)

def ask(self, state):
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state)

def tell(self, state, fitness):
return jax.lax.cond(
state.is_init, self._tell_init, self._tell_normal, state, fitness
)

def _ask_init(self, state):
return state.population, state

def _ask_normal(self, state):
key, subkey = jax.random.split(state.key)
B = state.B
population = state.inner_pop
n, t = jnp.shape(B)
s = jnp.sum(state.s, axis=1) + 1e-6
d = s / jnp.sum(s) + 0.002
d = d / jnp.sum(d)

offspring_loc, state = self.selection(state, 1./d)
parent = jnp.zeros((n, 2)).astype(int)
B = jax.random.permutation(
subkey, B, axis=1, independent=True
).astype(int)

def body_fun(i, val):
val = val.at[i, 0].set(B[offspring_loc[i], 0])
val = val.at[i, 1].set(B[offspring_loc[i], 1])
return val.astype(int)

parent = jax.lax.fori_loop(0, n, body_fun, parent)

selected_p = jnp.r_[population[parent[:, 0]], population[parent[:, 1]]]

crossovered, state = self.crossover(state, selected_p)
next_generation, state = self.mutation(state, crossovered, (self.lb, self.ub))

return next_generation, state.update(
next_generation=next_generation, offspring_loc=offspring_loc, key=key
)

def _tell_init(self, state, fitness):
state = state.update(fitness=fitness, inner_obj=fitness, is_init=False)
return state

def _tell_normal(self, state, fitness):
gen = state.gen + 1
ext_archive = state.population
ext_obj = state.fitness
inner_pop = state.inner_pop
inner_obj = state.inner_obj

offspring = state.next_generation
offspring_obj = fitness
B = state.B
w = state.weight_vector
s = state.s

offspring_loc = state.offspring_loc
vals = (inner_pop, inner_obj)

def body_fun(i, vals):
population, pop_obj = vals
g_old = jnp.sum(pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1)
g_new = w[B[offspring_loc[i], :]] @ jnp.transpose(offspring_obj[i])
idx = B[offspring_loc[i]]
g_new = g_new[:, jnp.newaxis]
g_old = g_old[:, jnp.newaxis]
population = population.at[idx].set(
jnp.where(g_old >= g_new, offspring[i], population[idx])
)
pop_obj = pop_obj.at[idx].set(
jnp.where(g_old >= g_new, offspring_obj[i], pop_obj[idx])
)
return (population, pop_obj)

inner_pop, inner_obj = jax.lax.fori_loop(0, self.pop_size, body_fun, vals)

merged_pop = jnp.concatenate([ext_archive, offspring], axis=0)
merged_fitness = jnp.concatenate([ext_obj, offspring_obj], axis=0)

combined_order = environmental_selection(merged_fitness, self.pop_size)
survivor = merged_pop[combined_order]
survivor_fitness = merged_fitness[combined_order]
mask = combined_order >= self.pop_size
num_valid = jnp.sum(mask)
sucessful = jnp.where(mask, size=self.pop_size)

def update_s(s):
h = offspring_loc[combined_order[sucessful]-self.pop_size]
head = h[0]
h = jnp.where(h == head, -1, h)
h = h.at[0].set(head)
hist, _ = jnp.histogram(h, self.pop_size, range=(0, self.pop_size))
s = s.at[:, gen % self.LGs+1].set(hist)
return s

def no_update(s):
return s

s = jax.lax.cond(num_valid != 0, update_s, no_update, s)
state = state.update(population=survivor, fitness=survivor_fitness, inner_pop=inner_pop, inner_obj=inner_obj,
s=s, gen=gen)
return state
Loading