From 5a237a8e7699db451736738560e6da323e7b4024 Mon Sep 17 00:00:00 2001 From: Bill Huang Date: Thu, 3 Aug 2023 14:17:21 +0800 Subject: [PATCH] fix: some operators --- src/evox/operators/crossover/one_point.py | 3 +-- src/evox/operators/selection/non_dominate.py | 14 +++++++++++--- src/evox/operators/selection/topk_fit.py | 2 +- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/evox/operators/crossover/one_point.py b/src/evox/operators/crossover/one_point.py index 924fde1e..947105a9 100644 --- a/src/evox/operators/crossover/one_point.py +++ b/src/evox/operators/crossover/one_point.py @@ -17,8 +17,7 @@ def _unpair(x): def _one_point_crossover(key, parents): _, dim = parents.shape point = random.choice(key, dim) + 1 - mask = jnp.ones((point,)) - mask = jnp.pad(mask, (0, dim - point), "constant", constant_values=(0, 0)) + mask = jnp.arange(dim) < point c1 = jnp.where(mask, parents[0], parents[1]) c2 = jnp.where(mask, parents[1], parents[0]) return jnp.stack([c1, c2]) diff --git a/src/evox/operators/selection/non_dominate.py b/src/evox/operators/selection/non_dominate.py index 7a38a751..f5a0f51f 100644 --- a/src/evox/operators/selection/non_dominate.py +++ b/src/evox/operators/selection/non_dominate.py @@ -1,6 +1,8 @@ from jax import lax, jit +import jax.numpy as jnp from evox import jit_class from evox.operators.non_dominated_sort import non_dominated_sort +from evox.operators.crowding_distance_sort import crowding_distance_sort from functools import partial @@ -9,9 +11,15 @@ def non_dominate(population, fitness, topk): """Selection the topk individuals besed on their ranking with non-dominated sort, returns the selected population and the corresponding fitness. """ - ranking = non_dominated_sort(fitness) - _, index = lax.topk(ranking, topk) - return population[index], fitness[index] + # first apply non_dominated sort + rank = non_dominated_sort(fitness) + # then find the worst rank within topk, and use crodwing_distance_sort as tiebreaker + worst_rank = -lax.top_k(-rank, topk) + mask = rank == worst_rank + crowding_distance = crowding_distance_sort(fitness, mask) + + combined_order = jnp.lexsort((-crowding_distance, rank))[:topk] + return population[combined_order], fitness[combined_order] @jit_class diff --git a/src/evox/operators/selection/topk_fit.py b/src/evox/operators/selection/topk_fit.py index 01f8d2b3..61e78155 100644 --- a/src/evox/operators/selection/topk_fit.py +++ b/src/evox/operators/selection/topk_fit.py @@ -8,7 +8,7 @@ def topk_fit(population, fitness, topk): """Selection the topk individuals besed on the fitness, returns the selected population and the corresponding fitness. """ - topk_fit, index = lax.topk(-fitness, topk) + topk_fit, index = lax.top_k(-fitness, topk) return population[index], topk_fit