From 93ae83e6f3f2fc8229c9cd4ade51a3bc7f6f5f83 Mon Sep 17 00:00:00 2001 From: Bill Huang Date: Thu, 6 Jun 2024 16:01:03 +0800 Subject: [PATCH] dev: NSGA2 use built-in operator --- src/evox/algorithms/mo/nsga2.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/evox/algorithms/mo/nsga2.py b/src/evox/algorithms/mo/nsga2.py index c794b546..e1353d9b 100644 --- a/src/evox/algorithms/mo/nsga2.py +++ b/src/evox/algorithms/mo/nsga2.py @@ -16,6 +16,7 @@ crossover, ) from evox import Algorithm, jit_class, State +from evox.operators.selection import NonDominate @jit_class @@ -51,6 +52,8 @@ def __init__( self.mutation = mutation.Polynomial((self.lb, self.ub)) if self.crossover is None: self.crossover = crossover.SimulatedBinary() + + self.survivor_selection = NonDominate(self.pop_size) def setup(self, key): key, subkey = jax.random.split(key) @@ -87,14 +90,7 @@ def tell(self, state, fitness): merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) - rank = non_dominated_sort(merged_fitness) - order = jnp.argsort(rank) - worst_rank = rank[order[self.pop_size]] - mask = rank == worst_rank - crowding_dis = crowding_distance(merged_fitness, mask) + survivor, survivor_fitness = self.survivor_selection(merged_pop, merged_fitness) - combined_order = jnp.lexsort((-crowding_dis, rank))[: self.pop_size] - survivor = merged_pop[combined_order] - survivor_fitness = merged_fitness[combined_order] state = state.update(population=survivor, fitness=survivor_fitness) return state