Skip to content

Commit

Permalink
dev: introduce dedup option
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Jun 13, 2024
1 parent 916b5da commit 6b49e1c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 12 deletions.
39 changes: 33 additions & 6 deletions src/evox/operators/selection/non_dominate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, lax, pure_callback, vmap
Expand Down Expand Up @@ -178,27 +179,53 @@ def crowding_distance_sort(x: jax.Array, mask: jax.Array = None):
return jnp.argsort(distance)


@partial(jit, static_argnums=[2, 3])
def non_dominate(population, fitness, topk, deduplicate=False):
    """Select the topk individuals based on their non-dominated-sort ranking.

    Individuals are ordered by rank first; ties are broken with the result of
    ``crowding_distance_sort`` computed over the worst rank that still falls
    inside the first ``topk`` positions.

    Args:
        population: Array of individuals, indexed along axis 0.
        fitness: Fitness values, indexed along axis 0 in step with
            ``population``. The deduplication mask is broadcast over a
            trailing axis, so a 2-D array is expected in that mode.
        topk: Number of individuals to keep. Static under ``jit``.
        deduplicate: When True, duplicated rows of ``population`` are pushed
            behind every unique row before ranking. Static under ``jit``.
            Defaults to False so three-argument calls keep working.

    Returns:
        A tuple ``(selected_population, selected_fitness)``. The returned
        fitness is always the original value of the selected individual,
        never the ``inf`` placeholder used internally for duplicates.
    """
    # Fitness actually used for ranking; differs from `fitness` only when
    # duplicates are masked out below.
    ranking_fitness = fitness

    if deduplicate:
        # `size` keeps the output shape static (required under jit). Slots
        # beyond the number of unique rows are padding and have count == 0.
        _, unique_index, unique_count = jnp.unique(
            population,
            axis=0,
            size=population.shape[0],
            return_index=True,
            return_counts=True,
        )
        population = population[unique_index]
        fitness = fitness[unique_index]
        count = jnp.sum(unique_count > 0)
        # Padding slots get an infinite fitness so they rank last. The real
        # fitness is kept in `fitness`, so if topk exceeds the number of
        # unique individuals and a padding slot is selected anyway, its
        # original fitness is what gets returned.
        ranking_fitness = jnp.where(
            (jnp.arange(fitness.shape[0]) < count)[:, jnp.newaxis],
            fitness,
            jnp.inf,
        )

    # First apply the non-dominated sort.
    rank = non_dominated_sort(ranking_fitness)
    # Then find the worst rank within topk and use crowding_distance_sort
    # as the tiebreaker for that rank.
    order = jnp.argsort(rank)
    worst_rank = rank[order[topk - 1]]
    mask = rank == worst_rank
    crowding_distance = crowding_distance_sort(ranking_fitness, mask)

    # lexsort treats the last key as primary: rank first, then the negated
    # crowding result.
    combined_order = jnp.lexsort((-crowding_distance, rank))[:topk]
    return population[combined_order], fitness[combined_order]


@jit_class
class NonDominate:
    """Callable wrapper around ``non_dominate`` with fixed settings.

    Args:
        topk: Number of individuals to keep on every call.
        deduplicate: Forwarded to ``non_dominate``; when True, duplicated
            individuals are ranked behind all unique ones. Defaults to False.
    """

    def __init__(self, topk, deduplicate=False):
        self.topk = topk
        self.deduplicate = deduplicate

    def __call__(self, population, fitness):
        """Return ``non_dominate(population, fitness, topk, deduplicate)``."""
        return non_dominate(population, fitness, self.topk, self.deduplicate)
38 changes: 32 additions & 6 deletions src/evox/operators/selection/topk_fit.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,47 @@
from jax import lax, jit
import jax.numpy as jnp
from evox import jit_class
from functools import partial


@partial(jit, static_argnums=[2, 3])
def topk_fit(population, fitness, topk, deduplicate=False):
    """Select the topk individuals with the smallest fitness.

    Args:
        population: Array of individuals, indexed along axis 0.
        fitness: 1-D array of fitness values, one per individual
            (smaller is better).
        topk: Number of individuals to keep. Static under ``jit``.
        deduplicate: When True, duplicated rows of ``population`` are pushed
            behind every unique row before selection. Static under ``jit``.
            Defaults to False so three-argument calls keep working.

    Returns:
        A tuple ``(selected_population, selected_fitness)`` ordered by
        ascending fitness. The returned fitness is always the original value
        of the selected individual, never the ``inf`` placeholder used
        internally for duplicates.
    """
    # Fitness actually used for ordering; differs from `fitness` only when
    # duplicates are masked out below.
    ranking_fitness = fitness

    if deduplicate:
        # `size` keeps the output shape static (required under jit). Slots
        # beyond the number of unique rows are padding and have count == 0.
        _, unique_index, unique_count = jnp.unique(
            population,
            axis=0,
            size=population.shape[0],
            return_index=True,
            return_counts=True,
        )
        population = population[unique_index]
        fitness = fitness[unique_index]
        count = jnp.sum(unique_count > 0)
        # Padding slots get an infinite fitness so they sort last. The real
        # fitness is kept in `fitness`, so if topk exceeds the number of
        # unique individuals and a padding slot is selected anyway, its
        # original fitness is what gets returned.
        ranking_fitness = jnp.where(
            jnp.arange(fitness.shape[0]) < count, fitness, jnp.inf
        )

    index = jnp.argsort(ranking_fitness)[:topk]
    return population[index], fitness[index]


@jit_class
class TopkFit:
    """Callable wrapper around ``topk_fit`` with fixed settings.

    Args:
        topk: Number of individuals to keep on every call.
        deduplicate: Forwarded to ``topk_fit``; when True, duplicated
            individuals are ordered behind all unique ones. Defaults to False.
    """

    def __init__(self, topk, deduplicate=False):
        self.topk = topk
        self.deduplicate = deduplicate

    def __call__(self, population, fitness):
        """Return ``topk_fit(population, fitness, topk, deduplicate)``."""
        return topk_fit(population, fitness, self.topk, self.deduplicate)

0 comments on commit 6b49e1c

Please sign in to comment.