Skip to content

Commit

Permalink
dev: NSGA2 use built-in operator
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Jun 13, 2024
1 parent 6b49e1c commit 93ae83e
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/evox/algorithms/mo/nsga2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
crossover,
)
from evox import Algorithm, jit_class, State
from evox.operators.selection import NonDominate


@jit_class
Expand Down Expand Up @@ -51,6 +52,8 @@ def __init__(
self.mutation = mutation.Polynomial((self.lb, self.ub))
if self.crossover is None:
self.crossover = crossover.SimulatedBinary()

self.survivor_selection = NonDominate(self.pop_size)

def setup(self, key):
key, subkey = jax.random.split(key)
Expand Down Expand Up @@ -87,14 +90,7 @@ def tell(self, state, fitness):
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)
merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0)

rank = non_dominated_sort(merged_fitness)
order = jnp.argsort(rank)
worst_rank = rank[order[self.pop_size]]
mask = rank == worst_rank
crowding_dis = crowding_distance(merged_fitness, mask)
survivor, survivor_fitness = self.survivor_selection(merged_pop, merged_fitness)

combined_order = jnp.lexsort((-crowding_dis, rank))[: self.pop_size]
survivor = merged_pop[combined_order]
survivor_fitness = merged_fitness[combined_order]
state = state.update(population=survivor, fitness=survivor_fitness)
return state

0 comments on commit 93ae83e

Please sign in to comment.