-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #57 from EMI-Group/dev2-lzy
Update MO, Operators and Problems
- Loading branch information
Showing
26 changed files
with
1,194 additions
and
212 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ __pycache__ | |
/benchmarks | ||
/build | ||
/src/evox.egg-info | ||
/.idea | ||
|
||
logs/ | ||
evox/algorithms/so/cso_*.py | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
from functools import partial | ||
|
||
from evox import jit_class, Algorithm, State | ||
from evox.operators import selection, mutation, crossover, non_dominated_sort, crowding_distance | ||
from evox.operators.sampling import UniformSampling, LatinHypercubeSampling | ||
from evox.utils import euclidean_dis | ||
|
||
|
||
@partial(jax.jit, static_argnums=[1]) | ||
def environmental_selection(fitness, n): | ||
rank = non_dominated_sort(fitness) | ||
order = jnp.argsort(rank) | ||
worst_rank = rank[order[n - 1]] | ||
mask = rank == worst_rank | ||
crowding_dis = crowding_distance(fitness, mask) | ||
combined_indices = jnp.lexsort((-crowding_dis, rank))[: n] | ||
|
||
return combined_indices | ||
|
||
|
||
@jit_class | ||
class EAGMOEAD(Algorithm): | ||
"""EAG-MOEA/D algorithm | ||
link: https://ieeexplore.ieee.org/abstract/document/6882229 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
lb, | ||
ub, | ||
n_objs, | ||
pop_size, | ||
LGs=8, | ||
selection_op=None, | ||
mutation_op=None, | ||
crossover_op=None, | ||
): | ||
self.lb = lb | ||
self.ub = ub | ||
self.n_objs = n_objs | ||
self.dim = lb.shape[0] | ||
self.pop_size = pop_size | ||
self.LGs = LGs | ||
self.T = jnp.ceil(self.pop_size / 10).astype(int) | ||
|
||
self.selection = selection_op | ||
self.mutation = mutation_op | ||
self.crossover = crossover_op | ||
|
||
if self.selection is None: | ||
self.selection = selection.RouletteWheelSelection(self.pop_size) | ||
if self.mutation is None: | ||
self.mutation = mutation.Polynomial((lb, ub)) | ||
if self.crossover is None: | ||
self.crossover = crossover.SimulatedBinary(type=2) | ||
self.sample = LatinHypercubeSampling(self.pop_size, self.n_objs) | ||
|
||
def setup(self, key): | ||
key, subkey1, subkey2 = jax.random.split(key, 3) | ||
ext_archive = ( | ||
jax.random.uniform(subkey1, shape=(self.pop_size, self.dim)) | ||
* (self.ub - self.lb) | ||
+ self.lb | ||
) | ||
fitness = jnp.zeros((self.pop_size, self.n_objs)) | ||
|
||
w = self.sample(subkey2)[0] | ||
B = euclidean_dis(w, w) | ||
B = jnp.argsort(B, axis=1) | ||
B = B[:, : self.T] | ||
return State( | ||
population=ext_archive, | ||
fitness=fitness, | ||
inner_pop=ext_archive, | ||
inner_obj=fitness, | ||
next_generation=ext_archive, | ||
weight_vector=w, | ||
B=B, | ||
s=jnp.zeros((self.pop_size, self.LGs)), | ||
parent=jnp.zeros((self.pop_size, self.T)).astype(int), | ||
offspring_loc=jnp.zeros((self.pop_size, )).astype(int), | ||
gen=0, | ||
is_init=True, | ||
key=key, | ||
) | ||
|
||
def ask(self, state): | ||
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) | ||
|
||
def tell(self, state, fitness): | ||
return jax.lax.cond( | ||
state.is_init, self._tell_init, self._tell_normal, state, fitness | ||
) | ||
|
||
def _ask_init(self, state): | ||
return state.population, state | ||
|
||
def _ask_normal(self, state): | ||
key, per_key, sel_key, x_key, mut_key = jax.random.split(state.key, 5) | ||
B = state.B | ||
population = state.inner_pop | ||
n, t = jnp.shape(B) | ||
s = jnp.sum(state.s, axis=1) + 1e-6 | ||
d = s / jnp.sum(s) + 0.002 | ||
d = d / jnp.sum(d) | ||
|
||
_, offspring_loc = self.selection(sel_key, population, 1./d) | ||
parent = jnp.zeros((n, 2)).astype(int) | ||
B = jax.random.permutation( | ||
per_key, B, axis=1, independent=True | ||
).astype(int) | ||
|
||
def body_fun(i, val): | ||
val = val.at[i, 0].set(B[offspring_loc[i], 0]) | ||
val = val.at[i, 1].set(B[offspring_loc[i], 1]) | ||
return val.astype(int) | ||
|
||
parent = jax.lax.fori_loop(0, n, body_fun, parent) | ||
|
||
selected_p = jnp.r_[population[parent[:, 0]], population[parent[:, 1]]] | ||
|
||
crossovered = self.crossover(x_key, selected_p) | ||
next_generation = self.mutation(mut_key, crossovered) | ||
|
||
return next_generation, state.update( | ||
next_generation=next_generation, offspring_loc=offspring_loc, key=key | ||
) | ||
|
||
def _tell_init(self, state, fitness): | ||
state = state.update(fitness=fitness, inner_obj=fitness, is_init=False) | ||
return state | ||
|
||
def _tell_normal(self, state, fitness): | ||
gen = state.gen + 1 | ||
ext_archive = state.population | ||
ext_obj = state.fitness | ||
inner_pop = state.inner_pop | ||
inner_obj = state.inner_obj | ||
|
||
offspring = state.next_generation | ||
offspring_obj = fitness | ||
B = state.B | ||
w = state.weight_vector | ||
s = state.s | ||
|
||
offspring_loc = state.offspring_loc | ||
vals = (inner_pop, inner_obj) | ||
|
||
def body_fun(i, vals): | ||
population, pop_obj = vals | ||
g_old = jnp.sum(pop_obj[B[offspring_loc[i], :]] * w[B[offspring_loc[i], :]], axis=1) | ||
g_new = w[B[offspring_loc[i], :]] @ jnp.transpose(offspring_obj[i]) | ||
idx = B[offspring_loc[i]] | ||
g_new = g_new[:, jnp.newaxis] | ||
g_old = g_old[:, jnp.newaxis] | ||
population = population.at[idx].set( | ||
jnp.where(g_old >= g_new, offspring[i], population[idx]) | ||
) | ||
pop_obj = pop_obj.at[idx].set( | ||
jnp.where(g_old >= g_new, offspring_obj[i], pop_obj[idx]) | ||
) | ||
return (population, pop_obj) | ||
|
||
inner_pop, inner_obj = jax.lax.fori_loop(0, self.pop_size, body_fun, vals) | ||
|
||
merged_pop = jnp.concatenate([ext_archive, offspring], axis=0) | ||
merged_fitness = jnp.concatenate([ext_obj, offspring_obj], axis=0) | ||
|
||
combined_order = environmental_selection(merged_fitness, self.pop_size) | ||
survivor = merged_pop[combined_order] | ||
survivor_fitness = merged_fitness[combined_order] | ||
mask = combined_order >= self.pop_size | ||
num_valid = jnp.sum(mask) | ||
sucessful = jnp.where(mask, size=self.pop_size) | ||
|
||
def update_s(s): | ||
h = offspring_loc[combined_order[sucessful]-self.pop_size] | ||
head = h[0] | ||
h = jnp.where(h == head, -1, h) | ||
h = h.at[0].set(head) | ||
hist, _ = jnp.histogram(h, self.pop_size, range=(0, self.pop_size)) | ||
s = s.at[:, gen % self.LGs+1].set(hist) | ||
return s | ||
|
||
def no_update(s): | ||
return s | ||
|
||
s = jax.lax.cond(num_valid != 0, update_s, no_update, s) | ||
state = state.update(population=survivor, fitness=survivor_fitness, inner_pop=inner_pop, inner_obj=inner_obj, | ||
s=s, gen=gen) | ||
return state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
from functools import partial | ||
|
||
from evox import jit_class, Algorithm, State | ||
from evox.operators import selection, mutation, crossover, non_dominated_sort | ||
|
||
|
||
@partial(jax.jit, static_argnums=[0, 1]) | ||
def calculate_alpha(N, k): | ||
alpha = jnp.zeros(N) | ||
|
||
for i in range(1, k + 1): | ||
num = jnp.prod((k - jnp.arange(1, i)) / (N - jnp.arange(1, i))) | ||
alpha = alpha.at[i-1].set(num / i) | ||
return alpha | ||
|
||
|
||
@partial(jax.jit, static_argnums=[2, 3]) | ||
def cal_hv(points, ref, k, n_sample, key): | ||
n, m = jnp.shape(points) | ||
alpha = calculate_alpha(n, k) | ||
|
||
f_min = jnp.min(points, axis=0) | ||
|
||
s = jax.random.uniform(key, shape=(n_sample, m), minval=f_min, maxval=ref) | ||
|
||
pds = jnp.zeros((n, n_sample), dtype=bool) | ||
ds = jnp.zeros((n_sample, )) | ||
|
||
def body_fun1(i, vals): | ||
pds, ds = vals | ||
x = jnp.sum((jnp.tile(points[i, :], (n_sample, 1)) - s) <= 0, axis=1) == m | ||
pds = pds.at[i].set(jnp.where(x, True, pds[i])) | ||
ds = jnp.where(x, ds+1, ds) | ||
return pds, ds | ||
|
||
pds, ds = jax.lax.fori_loop(0, n, body_fun1, (pds, ds)) | ||
ds = ds - 1 | ||
|
||
f = jnp.zeros((n,)) | ||
|
||
def body_fun2(i, val): | ||
temp = jnp.where(pds[i, :], ds, -1).astype(int) | ||
value = jnp.where(temp!=-1, alpha[temp], 0) | ||
value = jnp.sum(value) | ||
val = val.at[i].set(value) | ||
return val | ||
|
||
f = jax.lax.fori_loop(0, n, body_fun2, f) | ||
f = f * jnp.prod(ref - f_min) / n_sample | ||
|
||
return f | ||
|
||
@jit_class | ||
class HypE(Algorithm): | ||
"""HypE algorithm | ||
link: https://direct.mit.edu/evco/article-abstract/19/1/45/1363/HypE-An-Algorithm-for-Fast-Hypervolume-Based-Many | ||
""" | ||
|
||
def __init__( | ||
self, | ||
lb, | ||
ub, | ||
n_objs, | ||
pop_size, | ||
n_sample=10000, | ||
mutation_op=None, | ||
crossover_op=None, | ||
): | ||
self.lb = lb | ||
self.ub = ub | ||
self.n_objs = n_objs | ||
self.dim = lb.shape[0] | ||
self.pop_size = pop_size | ||
self.n_sample = n_sample | ||
|
||
self.mutation = mutation_op | ||
self.crossover = crossover_op | ||
self.selection = selection.Tournament(n_round=self.pop_size) | ||
if self.mutation is None: | ||
self.mutation = mutation.Polynomial((lb, ub)) | ||
if self.crossover is None: | ||
self.crossover = crossover.SimulatedBinary() | ||
|
||
def setup(self, key): | ||
key, subkey = jax.random.split(key) | ||
population = ( | ||
jax.random.uniform(subkey, shape=(self.pop_size, self.dim)) | ||
* (self.ub - self.lb) | ||
+ self.lb | ||
) | ||
return State( | ||
population=population, | ||
fitness=jnp.zeros((self.pop_size, self.n_objs)), | ||
next_generation=population, | ||
ref_point=jnp.zeros((self.n_objs, )), | ||
key=key, | ||
is_init=True, | ||
) | ||
|
||
def ask(self, state): | ||
return jax.lax.cond(state.is_init, self._ask_init, self._ask_normal, state) | ||
|
||
def tell(self, state, fitness): | ||
return jax.lax.cond( | ||
state.is_init, self._tell_init, self._tell_normal, state, fitness | ||
) | ||
|
||
def _ask_init(self, state): | ||
return state.population, state | ||
|
||
def _ask_normal(self, state): | ||
population = state.population | ||
pop_obj = state.fitness | ||
key, subkey, sel_key, x_key, mut_key = jax.random.split(state.key, 5) | ||
hv = cal_hv(pop_obj, state.ref_point, self.pop_size, self.n_sample, subkey) | ||
|
||
selected, _ = self.selection(sel_key, population, -hv) | ||
crossovered = self.crossover(x_key, selected) | ||
next_generation = self.mutation(mut_key, crossovered) | ||
|
||
return next_generation, state.update(next_generation=next_generation) | ||
|
||
def _tell_init(self, state, fitness): | ||
ref_point = jnp.zeros((self.n_objs, )) + jnp.max(fitness)*1.2 | ||
state = state.update(fitness=fitness, ref_point=ref_point, is_init=False) | ||
return state | ||
|
||
def _tell_normal(self, state, fitness): | ||
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) | ||
merged_obj = jnp.concatenate([state.fitness, fitness], axis=0) | ||
|
||
n = jnp.shape(merged_pop)[0] | ||
|
||
rank = non_dominated_sort(merged_obj) | ||
order = jnp.argsort(rank) | ||
worst_rank = rank[order[n-1]] | ||
mask = rank == worst_rank | ||
|
||
key, subkey = jax.random.split(state.key) | ||
hv = cal_hv(merged_obj, state.ref_point, n, self.n_sample, subkey) | ||
|
||
dis = jnp.where(mask, hv, -jnp.inf) | ||
|
||
combined_indices = jnp.lexsort((-dis, rank))[: self.pop_size] | ||
|
||
survivor = merged_pop[combined_indices] | ||
survivor_fitness = merged_obj[combined_indices] | ||
|
||
state = state.update(population=survivor, fitness=survivor_fitness, key=key) | ||
|
||
return state |
Oops, something went wrong.