From fcd9d4ad81bf8a3b7d72dec6f61fba9721a044ff Mon Sep 17 00:00:00 2001 From: Zhenyu2Liang <1370096263@qq.com> Date: Sun, 14 Jan 2024 03:19:16 +0800 Subject: [PATCH 1/3] fix rvea, nsga3 and dtlz --- src/evox/algorithms/mo/nsga3.py | 8 ++--- src/evox/algorithms/mo/rvea.py | 15 +++++--- src/evox/operators/sampling/uniform.py | 2 +- .../operators/selection/rvea_selection.py | 8 ++--- src/evox/problems/numerical/dtlz.py | 34 +++++++------------ tests/test_multi_objective_algorithms.py | 4 +-- 6 files changed, 33 insertions(+), 38 deletions(-) diff --git a/src/evox/algorithms/mo/nsga3.py b/src/evox/algorithms/mo/nsga3.py index 75cceceb..529b8b8f 100644 --- a/src/evox/algorithms/mo/nsga3.py +++ b/src/evox/algorithms/mo/nsga3.py @@ -36,7 +36,6 @@ def __init__( ub, n_objs, pop_size, - ref=None, selection_op=None, mutation_op=None, crossover_op=None, @@ -46,14 +45,11 @@ def __init__( self.n_objs = n_objs self.dim = lb.shape[0] self.pop_size = pop_size - self.ref = ref self.selection = selection_op self.mutation = mutation_op self.crossover = crossover_op - if self.ref is None: - self.ref = sampling.UniformSampling(pop_size, n_objs)()[0] if self.selection is None: self.selection = selection.UniformRand(0.5) if self.mutation is None: @@ -61,7 +57,7 @@ def __init__( if self.crossover is None: self.crossover = crossover.UniformRand() - self.ref = self.ref / jnp.linalg.norm(self.ref, axis=1)[:, None] + self.sampling = sampling.UniformSampling(self.pop_size, self.n_objs) def setup(self, key): key, subkey = jax.random.split(key) @@ -70,6 +66,8 @@ def setup(self, key): * (self.ub - self.lb) + self.lb ) + self.ref = self.sampling(subkey)[0] + self.ref = self.ref / jnp.linalg.norm(self.ref, axis=1)[:, None] return State( population=population, fitness=jnp.zeros((self.pop_size, self.n_objs)), diff --git a/src/evox/algorithms/mo/rvea.py b/src/evox/algorithms/mo/rvea.py index f92dfdb2..1e59819b 100644 --- a/src/evox/algorithms/mo/rvea.py +++ b/src/evox/algorithms/mo/rvea.py @@ -9,7 +9,7 @@ import jax.numpy as jnp from evox.operators import mutation, crossover, selection -from evox.operators.sampling import LatinHypercubeSampling +from evox.operators.sampling import LatinHypercubeSampling, UniformSampling from evox import Algorithm, State, jit_class @@ -59,7 +59,10 @@ def __init__( if self.crossover is None: self.crossover = crossover.SimulatedBinary() - self.sampling = LatinHypercubeSampling(self.pop_size, self.n_objs) + if self.n_objs == 2: + self.sampling = UniformSampling(self.pop_size, self.n_objs) + else: + self.sampling = LatinHypercubeSampling(self.pop_size, self.n_objs) def setup(self, key): key, subkey1, subkey2 = jax.random.split(key, 3) @@ -69,13 +72,14 @@ def setup(self, key): + self.lb ) v = self.sampling(subkey2)[0] - v = v / jnp.linalg.norm(v, axis=0) + v0 = v return State( population=population, fitness=jnp.zeros((self.pop_size, self.n_objs)), next_generation=population, reference_vector=v, + init_v=v0, key=key, gen=0, ) @@ -88,7 +92,7 @@ def init_tell(self, state, fitness): return state def ask(self, state): - key, subkey, x_key, mut_key = jax.random.split(state.key, 4) + key, subkey, x_key, x1_key, x2_key, mut_key = jax.random.split(state.key, 6) population = state.population no_nan_pop = ~jnp.isnan(population).all(axis=1) @@ -99,6 +103,7 @@ def ask(self, state): mating_pool = jax.random.randint(subkey, (self.pop_size,), 0, max_idx) crossovered = self.crossover(x_key, pop[mating_pool]) next_generation = self.mutation(mut_key, crossovered) + next_generation = jnp.clip(next_generation, self.lb, self.ub) return next_generation, state.update(next_generation=next_generation, key=key) @@ -132,7 +137,7 @@ def no_update(_pop_obj, v): rv_adaptation, no_update, survivor_fitness, - v, + state.init_v, ) state = state.update( diff --git a/src/evox/operators/sampling/uniform.py b/src/evox/operators/sampling/uniform.py index 3498eb9d..f04e813b 100644 --- a/src/evox/operators/sampling/uniform.py +++ b/src/evox/operators/sampling/uniform.py @@ -17,7 +17,7 @@ def __init__(self, n=None, m=None): self.n = n self.m = m - def __call__(self): + def __call__(self, key): h1 = 1 while comb(h1 + self.m, self.m - 1) <= self.n: h1 += 1 diff --git a/src/evox/operators/selection/rvea_selection.py b/src/evox/operators/selection/rvea_selection.py index ffb779db..0c040b7d 100644 --- a/src/evox/operators/selection/rvea_selection.py +++ b/src/evox/operators/selection/rvea_selection.py @@ -23,10 +23,10 @@ def ref_vec_guided(x, f, v, theta): nan_mask = jnp.isnan(obj).any(axis=1) associate = jnp.argmin(angle, axis=1) associate = jnp.where(nan_mask, -1, associate) - - partition = jax.vmap( - lambda x: jnp.where(associate == x, jnp.arange(0, n), -1), in_axes=1, out_axes=1 - )(jnp.tile(jnp.arange(0, nv), (n, 1))) + associate = jnp.tile(associate[:, jnp.newaxis], (1, nv)) + partition = jnp.tile(jnp.arange(0, n)[:, jnp.newaxis], (1, nv)) + I = jnp.tile(jnp.arange(0, nv), (n, 1)) + partition = (associate == I) * partition + (associate != I) * -1 mask = partition == -1 mask_null = jnp.sum(mask, axis=0) == n diff --git a/src/evox/problems/numerical/dtlz.py b/src/evox/problems/numerical/dtlz.py index c011a011..95cb29ef 100644 --- a/src/evox/problems/numerical/dtlz.py +++ b/src/evox/problems/numerical/dtlz.py @@ -29,7 +29,7 @@ def evaluate(self, state, X): return jax.jit(jax.vmap(self._dtlz))(X), state def pf(self, state): - f = self.sample()[0] / 2 + f = self.sample(state.key)[0] / 2 return f, state @@ -43,7 +43,7 @@ def __init__(self, d=None, m=None, ref_num=100): self.d = self.m + 4 else: self.d = d - super().__init__(self.d, self.m, ref_num) + super().__init__(d, m, ref_num) def evaluate(self, state, X): m = self.m @@ -79,7 +79,7 @@ def __init__(self, d=None, m=None, ref_num=1000): self.d = self.m + 9 else: self.d = d - super().__init__(self.d, self.m, ref_num) + super().__init__(d, m, ref_num) def evaluate(self, state, X): m = self.m @@ -103,7 +103,7 @@ def evaluate(self, state, X): return f, state def pf(self, state): - f = self.sample()[0] + f = self.sample(state.key)[0] f /= jnp.tile(jnp.sqrt(jnp.sum(f**2, axis=1, keepdims=True)), (1, self.m)) return f, state @@ -183,7 +183,7 @@ def __init__(self, d=None, m=None, ref_num=1000): self.d = self.m + 9 else: self.d = d - super().__init__(self.d, self.m, ref_num) + super().__init__(d, m, ref_num) def evaluate(self, state, X): m = self.m @@ -222,13 +222,9 @@ def pf(self, state): for i in range(self.m - 2): f = jnp.c_[f[:, 0], f] - f = ( - f - / jnp.sqrt(2) - * jnp.tile( - jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))), - (jnp.shape(f)[0], 1), - ) + f = f / jnp.sqrt(2) ** jnp.tile( + jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))), + (jnp.shape(f)[0], 1), ) return f, state @@ -243,7 +239,7 @@ def __init__(self, d=None, m=None, ref_num=1000): self.d = self.m + 9 else: self.d = d - super().__init__(self.d, self.m, ref_num) + super().__init__(d, m, ref_num) def evaluate(self, state, X): m = self.m @@ -283,13 +279,9 @@ def pf(self, state): for i in range(self.m - 2): f = jnp.c_[f[:, 0], f] - f = ( - f - / jnp.sqrt(2) - * jnp.tile( - jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))), - (jnp.shape(f)[0], 1), - ) + f = f / jnp.sqrt(2) ** jnp.tile( + jnp.hstack((self.m - 2, jnp.arange(self.m - 2, -1, -1))), + (jnp.shape(f)[0], 1), ) return f, state @@ -305,7 +297,7 @@ def __init__(self, d=None, m=None, ref_num=1000): else: self.d = d - super().__init__(self.d, self.m, ref_num) + super().__init__(d, m, ref_num) self.sample = GridSampling(self.ref_num * self.m, self.m - 1) def evaluate(self, state, X): diff --git a/tests/test_multi_objective_algorithms.py b/tests/test_multi_objective_algorithms.py index 9d957cb7..ee822c61 100644 --- a/tests/test_multi_objective_algorithms.py +++ b/tests/test_multi_objective_algorithms.py @@ -7,16 +7,16 @@ N = 12 M = 3 +D = 7 POP_SIZE = 100 LB = 0 UB = 1 ITER = 10 -def run_moea(algorithm, problem=problems.numerical.DTLZ1(m=M)): +def run_moea(algorithm, problem=problems.numerical.DTLZ1(d=D, m=M)): key = jax.random.PRNGKey(42) monitor = StdMOMonitor(record_pf=False) - # problem = problems.numerical.DTLZ2(m=M) workflow = workflows.StdWorkflow( algorithm=algorithm, problem=problem, From ffc38ef0cfb6f82c65ecd4a79d20014aca5cf57e Mon Sep 17 00:00:00 2001 From: Zhenyu2Liang <1370096263@qq.com> Date: Sun, 14 Jan 2024 03:29:53 +0800 Subject: [PATCH 2/3] fix dtlz --- src/evox/problems/numerical/dtlz.py | 10 +++++----- tests/test_multi_objective_algorithms.py | 3 +-- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/src/evox/problems/numerical/dtlz.py b/src/evox/problems/numerical/dtlz.py index 95cb29ef..865ecb49 100644 --- a/src/evox/problems/numerical/dtlz.py +++ b/src/evox/problems/numerical/dtlz.py @@ -43,7 +43,7 @@ def __init__(self, d=None, m=None, ref_num=100): self.d = self.m + 4 else: self.d = d - super().__init__(d, m, ref_num) + super().__init__(self.d, self.m, ref_num) def evaluate(self, state, X): m = self.m @@ -79,7 +79,7 @@ def __init__(self, d=None, m=None, ref_num=1000): self.d = self.m + 9 else: self.d = d - super().__init__(d, m, ref_num) + super().__init__(self.d, self.m, ref_num) def evaluate(self, state, X): m = self.m @@ -183,7 +183,7 @@ def __init__(self, d=None, m=None, ref_num=1000): self.d = self.m + 9 else: self.d = d - super().__init__(d, m, ref_num) + super().__init__(self.d, self.m, ref_num) def evaluate(self, state, X): m = self.m @@ -239,7 +239,7 @@ def __init__(self, d=None, m=None, ref_num=1000): self.d = self.m + 9 else: self.d = d - super().__init__(d, m, ref_num) + super().__init__(self.d, self.m, ref_num) def evaluate(self, state, X): m = self.m @@ -297,7 +297,7 @@ def __init__(self, d=None, m=None, ref_num=1000): else: self.d = d - super().__init__(d, m, ref_num) + super().__init__(self.d, self.m, ref_num) self.sample = GridSampling(self.ref_num * self.m, self.m - 1) def evaluate(self, state, X): diff --git a/tests/test_multi_objective_algorithms.py b/tests/test_multi_objective_algorithms.py index ee822c61..337ff2d0 100644 --- a/tests/test_multi_objective_algorithms.py +++ b/tests/test_multi_objective_algorithms.py @@ -7,14 +7,13 @@ N = 12 M = 3 -D = 7 POP_SIZE = 100 LB = 0 UB = 1 ITER = 10 -def run_moea(algorithm, problem=problems.numerical.DTLZ1(d=D, m=M)): +def run_moea(algorithm, problem=problems.numerical.DTLZ1(m=M)): key = jax.random.PRNGKey(42) monitor = StdMOMonitor(record_pf=False) workflow = workflows.StdWorkflow( From 13dc751124d8572713db86865aba83b8f18a27f1 Mon Sep 17 00:00:00 2001 From: Zhenyu2Liang <1370096263@qq.com> Date: Sun, 14 Jan 2024 03:59:04 +0800 Subject: [PATCH 3/3] fix uniform --- src/evox/operators/sampling/uniform.py | 2 +- src/evox/problems/numerical/dtlz.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/evox/operators/sampling/uniform.py b/src/evox/operators/sampling/uniform.py index f04e813b..b7378bca 100644 --- a/src/evox/operators/sampling/uniform.py +++ b/src/evox/operators/sampling/uniform.py @@ -17,7 +17,7 @@ def __init__(self, n=None, m=None): self.n = n self.m = m - def __call__(self, key): + def __call__(self, key=None): h1 = 1 while comb(h1 + self.m, self.m - 1) <= self.n: h1 += 1 diff --git a/src/evox/problems/numerical/dtlz.py b/src/evox/problems/numerical/dtlz.py index 865ecb49..f0e3fd29 100644 --- a/src/evox/problems/numerical/dtlz.py +++ b/src/evox/problems/numerical/dtlz.py @@ -29,7 +29,7 @@ def evaluate(self, state, X): return jax.jit(jax.vmap(self._dtlz))(X), state def pf(self, state): - f = self.sample(state.key)[0] / 2 + f = self.sample()[0] / 2 return f, state @@ -103,7 +103,7 @@ def evaluate(self, state, X): return f, state def pf(self, state): - f = self.sample(state.key)[0] + f = self.sample()[0] f /= jnp.tile(jnp.sqrt(jnp.sum(f**2, axis=1, keepdims=True)), (1, self.m)) return f, state