diff --git a/src/evox/algorithms/mo/nsga2.py b/src/evox/algorithms/mo/nsga2.py index c794b546..e1353d9b 100644 --- a/src/evox/algorithms/mo/nsga2.py +++ b/src/evox/algorithms/mo/nsga2.py @@ -16,6 +16,7 @@ crossover, ) from evox import Algorithm, jit_class, State +from evox.operators.selection import NonDominate @jit_class @@ -51,6 +52,8 @@ def __init__( self.mutation = mutation.Polynomial((self.lb, self.ub)) if self.crossover is None: self.crossover = crossover.SimulatedBinary() + + self.survivor_selection = NonDominate(self.pop_size) def setup(self, key): key, subkey = jax.random.split(key) @@ -87,14 +90,7 @@ def tell(self, state, fitness): merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0) merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0) - rank = non_dominated_sort(merged_fitness) - order = jnp.argsort(rank) - worst_rank = rank[order[self.pop_size]] - mask = rank == worst_rank - crowding_dis = crowding_distance(merged_fitness, mask) + survivor, survivor_fitness = self.survivor_selection(merged_pop, merged_fitness) - combined_order = jnp.lexsort((-crowding_dis, rank))[: self.pop_size] - survivor = merged_pop[combined_order] - survivor_fitness = merged_fitness[combined_order] state = state.update(population=survivor, fitness=survivor_fitness) return state