Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenyu2Liang committed Aug 3, 2023
1 parent e06e228 commit 9cdb0cf
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 29 deletions.
9 changes: 4 additions & 5 deletions tests/test_crowding_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import jax.numpy as jnp
import chex
import evox as ex
from evox.operators.selection import crowding_distance, non_dominated_sort


def test_crowding_distance1():
x = jnp.array([[0.5, 4], [1, 2.5], [2, 2], [3, 1], [4, 0.8]])
distance = crowding_distance(x)
distance = ex.operators.crowding_distance(x)
chex.assert_trees_all_close(
distance,
jnp.array(
Expand All @@ -25,9 +24,9 @@ def test_crowding_distance1():
def test_crowding_distance2():
key = jax.random.PRNGKey(314)
x = jax.random.normal(key, (128, 8))
rank = non_dominated_sort(x)
rank = ex.operators.non_dominated_sort(x)
pareto_front = x[rank == 0]
distance = crowding_distance(pareto_front)
distance = ex.operators.crowding_distance(pareto_front)
ground_truth = jnp.array(
[
jnp.inf,
Expand Down Expand Up @@ -128,7 +127,7 @@ def test_crowding_distance2():
def test_masked_crowding_distance1():
x = jnp.array([[-1, -1], [0.5, 4], [1, 2.5], [2, 2], [-2, -2], [3, 1], [4, 0.8], [-3, -3], [-3, -4]])
mask = jnp.array([False, True, True, True, False, True, True, False, False])
distance = crowding_distance(x, mask)
distance = ex.operators.crowding_distance(x, mask)
chex.assert_trees_all_close(
distance,
jnp.array(
Expand Down
45 changes: 22 additions & 23 deletions tests/test_multi_objective_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@


def run_moea(algorithm, problem=problems.classic.DTLZ1(m=3)):
key = jax.random.PRNGKey(123)
key = jax.random.PRNGKey(42)
monitor = StdMOMonitor(record_pf=False)
problem = problems.classic.DTLZ1(m=3)
problem = problems.classic.DTLZ2(m=3)
pipeline = pipelines.StdPipeline(
algorithm=algorithm,
problem=problem,
Expand All @@ -27,8 +27,8 @@ def run_moea(algorithm, problem=problems.classic.DTLZ1(m=3)):

def test_ibea():
algorithm = algorithms.IBEA(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
Expand All @@ -37,8 +37,8 @@ def test_ibea():

def test_moead():
algorithm = algorithms.MOEAD(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
type=1,
Expand All @@ -48,8 +48,8 @@ def test_moead():

def test_nsga2():
algorithm = algorithms.NSGA2(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
Expand All @@ -58,18 +58,17 @@ def test_nsga2():

def test_rvea():
algorithm = algorithms.RVEA(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
run_moea(algorithm)



def test_nsga3():
algorithm = algorithms.NSGA3(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
Expand All @@ -78,8 +77,8 @@ def test_nsga3():

def test_eagmoead():
algorithm = algorithms.EAGMOEAD(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
Expand All @@ -88,8 +87,8 @@ def test_eagmoead():

def test_hype():
algorithm = algorithms.HypE(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
Expand All @@ -98,8 +97,8 @@ def test_hype():

def test_moeaddra():
algorithm = algorithms.MOEADDRA(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
Expand All @@ -108,9 +107,9 @@ def test_moeaddra():

def test_spea2():
algorithm = algorithms.SPEA2(
lb=jnp.full(shape=(3,), fill_value=0),
ub=jnp.full(shape=(3,), fill_value=1),
lb=jnp.full(shape=(12,), fill_value=0),
ub=jnp.full(shape=(12,), fill_value=1),
n_objs=3,
pop_size=100,
)
run_moea(algorithm)
run_moea(algorithm)
2 changes: 1 addition & 1 deletion tests/test_non_dominated_sort.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import evox
from evox.operators.selection import non_dominated_sort
from evox.operators import non_dominated_sort
import jax
import jax.numpy as jnp
import chex
Expand Down

0 comments on commit 9cdb0cf

Please sign in to comment.