Skip to content

Commit

Permalink
Merge pull request #102 from EMI-Group/dev1-lzy
Browse files Browse the repository at this point in the history
Fix multiobjective algorithms and DTLZ problems
  • Loading branch information
BillHuang2001 committed Jan 14, 2024
2 parents 731a745 + 13dc751 commit 0620e56
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 30 deletions.
8 changes: 3 additions & 5 deletions src/evox/algorithms/mo/nsga3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def __init__(
ub,
n_objs,
pop_size,
ref=None,
selection_op=None,
mutation_op=None,
crossover_op=None,
Expand All @@ -46,22 +45,19 @@ def __init__(
self.n_objs = n_objs
self.dim = lb.shape[0]
self.pop_size = pop_size
self.ref = ref

self.selection = selection_op
self.mutation = mutation_op
self.crossover = crossover_op

if self.ref is None:
self.ref = sampling.UniformSampling(pop_size, n_objs)()[0]
if self.selection is None:
self.selection = selection.UniformRand(0.5)
if self.mutation is None:
self.mutation = mutation.Gaussian()
if self.crossover is None:
self.crossover = crossover.UniformRand()

self.ref = self.ref / jnp.linalg.norm(self.ref, axis=1)[:, None]
self.sampling = sampling.UniformSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey = jax.random.split(key)
Expand All @@ -70,6 +66,8 @@ def setup(self, key):
* (self.ub - self.lb)
+ self.lb
)
self.ref = self.sampling(subkey)[0]
self.ref = self.ref / jnp.linalg.norm(self.ref, axis=1)[:, None]
return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
Expand Down
15 changes: 10 additions & 5 deletions src/evox/algorithms/mo/rvea.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import jax.numpy as jnp

from evox.operators import mutation, crossover, selection
from evox.operators.sampling import LatinHypercubeSampling
from evox.operators.sampling import LatinHypercubeSampling, UniformSampling
from evox import Algorithm, State, jit_class


Expand Down Expand Up @@ -59,7 +59,10 @@ def __init__(
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()

self.sampling = LatinHypercubeSampling(self.pop_size, self.n_objs)
if self.n_objs == 2:
self.sampling = UniformSampling(self.pop_size, self.n_objs)
else:
self.sampling = LatinHypercubeSampling(self.pop_size, self.n_objs)

def setup(self, key):
key, subkey1, subkey2 = jax.random.split(key, 3)
Expand All @@ -69,13 +72,14 @@ def setup(self, key):
+ self.lb
)
v = self.sampling(subkey2)[0]
v = v / jnp.linalg.norm(v, axis=0)
v0 = v

return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
next_generation=population,
reference_vector=v,
init_v=v0,
key=key,
gen=0,
)
Expand All @@ -88,7 +92,7 @@ def init_tell(self, state, fitness):
return state

def ask(self, state):
key, subkey, x_key, mut_key = jax.random.split(state.key, 4)
key, subkey, x_key, x1_key, x2_key, mut_key = jax.random.split(state.key, 6)
population = state.population

no_nan_pop = ~jnp.isnan(population).all(axis=1)
Expand All @@ -99,6 +103,7 @@ def ask(self, state):
mating_pool = jax.random.randint(subkey, (self.pop_size,), 0, max_idx)
crossovered = self.crossover(x_key, pop[mating_pool])
next_generation = self.mutation(mut_key, crossovered)
next_generation = jnp.clip(next_generation, self.lb, self.ub)

return next_generation, state.update(next_generation=next_generation, key=key)

Expand Down Expand Up @@ -132,7 +137,7 @@ def no_update(_pop_obj, v):
rv_adaptation,
no_update,
survivor_fitness,
v,
state.init_v,
)

state = state.update(
Expand Down
2 changes: 1 addition & 1 deletion src/evox/operators/sampling/uniform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, n=None, m=None):
self.n = n
self.m = m

def __call__(self):
def __call__(self, key=None):
h1 = 1
while comb(h1 + self.m, self.m - 1) <= self.n:
h1 += 1
Expand Down
8 changes: 4 additions & 4 deletions src/evox/operators/selection/rvea_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def ref_vec_guided(x, f, v, theta):
nan_mask = jnp.isnan(obj).any(axis=1)
associate = jnp.argmin(angle, axis=1)
associate = jnp.where(nan_mask, -1, associate)

partition = jax.vmap(
lambda x: jnp.where(associate == x, jnp.arange(0, n), -1), in_axes=1, out_axes=1
)(jnp.tile(jnp.arange(0, nv), (n, 1)))
associate = jnp.tile(associate[:, jnp.newaxis], (1, nv))
partition = jnp.tile(jnp.arange(0, n)[:, jnp.newaxis], (1, nv))
I = jnp.tile(jnp.arange(0, nv), (n, 1))
partition = (associate == I) * partition + (associate != I) * -1

mask = partition == -1
mask_null = jnp.sum(mask, axis=0) == n
Expand Down
20 changes: 6 additions & 14 deletions src/evox/problems/numerical/dtlz.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,9 @@ def pf(self, state):
for i in range(self.m - 2):
f = jnp.c_[f[:, 0], f]

f = (
f
/ jnp.sqrt(2)
* jnp.tile(
jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))),
(jnp.shape(f)[0], 1),
)
f = f / jnp.sqrt(2) ** jnp.tile(
jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))),
(jnp.shape(f)[0], 1),
)
return f, state

Expand Down Expand Up @@ -283,13 +279,9 @@ def pf(self, state):
for i in range(self.m - 2):
f = jnp.c_[f[:, 0], f]

f = (
f
/ jnp.sqrt(2)
* jnp.tile(
jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))),
(jnp.shape(f)[0], 1),
)
f = f / jnp.sqrt(2) ** jnp.tile(
jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))),
(jnp.shape(f)[0], 1),
)
return f, state

Expand Down
1 change: 0 additions & 1 deletion tests/test_multi_objective_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
def run_moea(algorithm, problem=problems.numerical.DTLZ1(m=M)):
key = jax.random.PRNGKey(42)
monitor = StdMOMonitor(record_pf=False)
# problem = problems.numerical.DTLZ2(m=M)
workflow = workflows.StdWorkflow(
algorithm=algorithm,
problem=problem,
Expand Down

0 comments on commit 0620e56

Please sign in to comment.