Skip to content

Commit

Permalink
fix: some operators
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Aug 3, 2023
1 parent 3255488 commit 5a237a8
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/evox/operators/crossover/one_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def _unpair(x):
def _one_point_crossover(key, parents):
_, dim = parents.shape
point = random.choice(key, dim) + 1
mask = jnp.ones((point,))
mask = jnp.pad(mask, (0, dim - point), "constant", constant_values=(0, 0))
mask = jnp.arange(dim) < point
c1 = jnp.where(mask, parents[0], parents[1])
c2 = jnp.where(mask, parents[1], parents[0])
return jnp.stack([c1, c2])
Expand Down
14 changes: 11 additions & 3 deletions src/evox/operators/selection/non_dominate.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from jax import lax, jit
import jax.numpy as jnp
from evox import jit_class
from evox.operators.non_dominated_sort import non_dominated_sort
from evox.operators.crowding_distance_sort import crowding_distance_sort
from functools import partial


Expand All @@ -9,9 +11,15 @@ def non_dominate(population, fitness, topk):
"""Selection the topk individuals besed on their ranking with non-dominated sort,
returns the selected population and the corresponding fitness.
"""
ranking = non_dominated_sort(fitness)
_, index = lax.topk(ranking, topk)
return population[index], fitness[index]
# first apply non_dominated sort
rank = non_dominated_sort(fitness)
# then find the worst rank within topk, and use crodwing_distance_sort as tiebreaker
worst_rank = -lax.top_k(-rank, topk)
mask = rank == worst_rank
crowding_distance = crowding_distance_sort(fitness, mask)

combined_order = jnp.lexsort((-crowding_distance, rank))[:topk]
return population[combined_order], fitness[combined_order]


@jit_class
Expand Down
2 changes: 1 addition & 1 deletion src/evox/operators/selection/topk_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def topk_fit(population, fitness, topk):
"""Selection the topk individuals besed on the fitness,
returns the selected population and the corresponding fitness.
"""
topk_fit, index = lax.topk(-fitness, topk)
topk_fit, index = lax.top_k(-fitness, topk)
return population[index], topk_fit


Expand Down

0 comments on commit 5a237a8

Please sign in to comment.