From 9cdb0cff6e95ebdf6d22f3d3abfe92fa1b4de1bb Mon Sep 17 00:00:00 2001 From: Zhenyu2Liang <1370096263@qq.com> Date: Thu, 3 Aug 2023 20:30:42 +0800 Subject: [PATCH] update tests --- tests/test_crowding_distance.py | 9 +++-- tests/test_multi_objective_algorithms.py | 45 ++++++++++++------------ tests/test_non_dominated_sort.py | 2 +- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/tests/test_crowding_distance.py b/tests/test_crowding_distance.py index 51c328a9..d1cf45c6 100644 --- a/tests/test_crowding_distance.py +++ b/tests/test_crowding_distance.py @@ -2,12 +2,11 @@ import jax.numpy as jnp import chex import evox as ex -from evox.operators.selection import crowding_distance, non_dominated_sort def test_crowding_distance1(): x = jnp.array([[0.5, 4], [1, 2.5], [2, 2], [3, 1], [4, 0.8]]) - distance = crowding_distance(x) + distance = ex.operators.crowding_distance(x) chex.assert_trees_all_close( distance, jnp.array( @@ -25,9 +24,9 @@ def test_crowding_distance1(): def test_crowding_distance2(): key = jax.random.PRNGKey(314) x = jax.random.normal(key, (128, 8)) - rank = non_dominated_sort(x) + rank = ex.operators.non_dominated_sort(x) pareto_front = x[rank == 0] - distance = crowding_distance(pareto_front) + distance = ex.operators.crowding_distance(pareto_front) ground_truth = jnp.array( [ jnp.inf, @@ -128,7 +127,7 @@ def test_crowding_distance2(): def test_masked_crowding_distance1(): x = jnp.array([[-1, -1], [0.5, 4], [1, 2.5], [2, 2], [-2, -2], [3, 1], [4, 0.8], [-3, -3], [-3, -4]]) mask = jnp.array([False, True, True, True, False, True, True, False, False]) - distance = crowding_distance(x, mask) + distance = ex.operators.crowding_distance(x, mask) chex.assert_trees_all_close( distance, jnp.array( diff --git a/tests/test_multi_objective_algorithms.py b/tests/test_multi_objective_algorithms.py index b67b84cb..59cf887c 100644 --- a/tests/test_multi_objective_algorithms.py +++ b/tests/test_multi_objective_algorithms.py @@ -6,9 +6,9 @@ def run_moea(algorithm, problem=problems.classic.DTLZ1(m=3)): - key = jax.random.PRNGKey(123) + key = jax.random.PRNGKey(42) monitor = StdMOMonitor(record_pf=False) - problem = problems.classic.DTLZ1(m=3) + problem = problems.classic.DTLZ2(m=3) pipeline = pipelines.StdPipeline( algorithm=algorithm, problem=problem, @@ -27,8 +27,8 @@ def run_moea(algorithm, problem=problems.classic.DTLZ1(m=3)): def test_ibea(): algorithm = algorithms.IBEA( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) @@ -37,8 +37,8 @@ def test_ibea(): def test_moead(): algorithm = algorithms.MOEAD( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, type=1, @@ -48,8 +48,8 @@ def test_moead(): def test_nsga2(): algorithm = algorithms.NSGA2( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) @@ -58,18 +58,17 @@ def test_nsga2(): def test_rvea(): algorithm = algorithms.RVEA( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) run_moea(algorithm) - - + def test_nsga3(): algorithm = algorithms.NSGA3( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) @@ -78,8 +77,8 @@ def test_nsga3(): def test_eagmoead(): algorithm = algorithms.EAGMOEAD( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) @@ -88,8 +87,8 @@ def test_eagmoead(): def test_hype(): algorithm = algorithms.HypE( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) @@ -98,8 +97,8 @@ def test_hype(): def test_moeaddra(): algorithm = algorithms.MOEADDRA( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) @@ -108,9 +107,9 @@ def test_moeaddra(): def test_spea2(): algorithm = algorithms.SPEA2( - lb=jnp.full(shape=(3,), fill_value=0), - ub=jnp.full(shape=(3,), fill_value=1), + lb=jnp.full(shape=(12,), fill_value=0), + ub=jnp.full(shape=(12,), fill_value=1), n_objs=3, pop_size=100, ) - run_moea(algorithm) + run_moea(algorithm) \ No newline at end of file diff --git a/tests/test_non_dominated_sort.py b/tests/test_non_dominated_sort.py index 74db773f..8cd19b00 100644 --- a/tests/test_non_dominated_sort.py +++ b/tests/test_non_dominated_sort.py @@ -1,5 +1,5 @@ import evox -from evox.operators.selection import non_dominated_sort +from evox.operators import non_dominated_sort import jax import jax.numpy as jnp import chex