diff --git a/src/evox/algorithms/so/pso_variants/cpso_s.py b/src/evox/algorithms/so/pso_variants/cpso_s.py index 9dc619ef..e62f27a9 100644 --- a/src/evox/algorithms/so/pso_variants/cpso_s.py +++ b/src/evox/algorithms/so/pso_variants/cpso_s.py @@ -10,146 +10,45 @@ from evox.utils import * from evox import Algorithm, State, jit_class +from evox.algorithms.containers.coevolution import Coevolution +from evox.algorithms.so.pso_variants.pso import PSO # CPSO-S: Cooperative PSO @jit_class -class CPSOS(Algorithm): +class CPSOS(Coevolution): + """Cooperative particle swarm optimizer. + Implemented using EvoX's built-in coevolution framework. + CPSOS essentially a wrapper around PSO and Coevolution. + + https://ieeexplore.ieee.org/document/1304845 + """ + def __init__( self, lb, # lower bound of problem ub, # upper bound of problem - pop_size, # population size for one swarm of a single dimension - inertia_weight, # w - pbest_coefficient, # c_pbest - gbest_coefficient, # c_gbest + subpop_size: int, + inertia_weight: float, # w + cognitive_coefficient: float, # c_pbest + social_coefficient: float, # c_gbest ): - self.dim = lb.shape[0] - self.lb = lb - self.ub = ub - self.pop_size = pop_size - self.w = inertia_weight - self.c_pbest = pbest_coefficient - self.c_gbest = gbest_coefficient - - def setup(self, key): - state_key, init_pop_key, init_v_key = jax.random.split(key, 3) - ub = jnp.broadcast_to(self.ub[:, None], shape=(self.dim, self.pop_size)) - lb = jnp.broadcast_to(self.lb[:, None], shape=(self.dim, self.pop_size)) - length = ub - lb - _population = jax.random.uniform(init_pop_key, shape=(self.dim, self.pop_size)) - _population = _population * length + lb - - context_vector = _population[:, 0] # b - broadcast_context_vector = jnp.broadcast_to( - context_vector[:, None], (self.dim, self.pop_size) - ) - cond = jnp.broadcast_to( - jnp.arange(self.dim)[:, None] == jnp.arange(self.dim), - shape=(self.dim, self.dim), - )[:, :, None] - population = jnp.where( - cond, _population[None, :], broadcast_context_vector[None, :] + assert jnp.all(lb[0] == lb) and jnp.all( + ub[0] == ub + ), "Currently the coevolution framewrok restricts that the upper/lower bound should be the same across dimensions" + pso = PSO( + lb[:1], + ub[:1], + subpop_size, + inertia_weight, + cognitive_coefficient, + social_coefficient, ) - population = jnp.transpose(population, axes=(0, 2, 1)) - - velocity = jax.random.uniform(init_v_key, shape=(self.dim, self.pop_size)) - velocity = velocity * length * 2 - length - - return State( - # _population/velocity: shape:(dim, pop_size) - # _population has dim different swarms, each swarm has pop_size particles - # each particle in a swarm only records one single number, which is the position of this particle in this dimension - _population=_population, - velocity=velocity, - # population: shape:(dim, pop_size, dim) - # population has dim different swarms, each swarm has pop_size particles - # population is the combination of _population and context_vector - # it is used to calculate the fitness of each particle - population=population, - # pbest: like other algorithms, pbest is the best position of each particle - pbest_position=_population, # shape:(dim, pop_size) - pbest_fitness=jnp.full( - ( - self.dim, - self.pop_size, - ), - jnp.inf, - ), # shape:(dim, pop_size) - # gbest: !!! gbest is the best position of the swarm, not the whole population !!! - # Because each swarm only focuses on one dimension, so the gbest in cpso_s should only record one single number, - # But for the convenience of coding, we still use a array with shape(dim, dim) to represent the gbest position. - # which represents the best position of this swarm. - # therefore, the shape of gbest position is (dim, dim) and gbest fitness is (dim,) - gbest_position=_population[:, 0], # shape:(dim,) - gbest_fitness=jnp.full((self.dim,), jnp.inf), # shape:(dim,) - # in fact, gbest_position is the same as context_vector - # but we still use gbest_position to represent the best position of one swarm - context_vector=context_vector, - key=state_key, - ) - - def ask(self, state): - return state.population, state - - def tell(self, state, fitness): - state_key, rand_key_gbest, rand_key_pbest = jax.random.split(state.key, num=3) - - # ----------------- Update pbest ----------------- - compare = state.pbest_fitness > fitness - pbest_position = jnp.where(compare, state._population, state.pbest_position) - pbest_fitness = jnp.minimum(state.pbest_fitness, fitness) - - # ----------------- Update gbest ----------------- - gbest_fitness = jnp.amin(pbest_fitness, axis=1) - gbest_index = jnp.argmin(pbest_fitness, axis=1) - gbest_position = pbest_position[ - jnp.arange(pbest_position.shape[0]), gbest_index - ] - - # ------------------------------------------------------ - - rand_pbest = jax.random.uniform(rand_key_pbest, shape=(self.dim, self.pop_size)) - rand_gbest = jax.random.uniform(rand_key_gbest, shape=(self.dim, self.pop_size)) - velocity = ( - self.w * state.velocity - + self.c_pbest * rand_pbest * (pbest_position - state._population) - + self.c_gbest - * rand_gbest - * ( - jnp.broadcast_to( - gbest_position[:, None], shape=(self.dim, self.pop_size) - ) - - state._population - ) - ) - _population = state._population + velocity - ub = jnp.broadcast_to(self.ub[:, None], shape=(self.dim, self.pop_size)) - lb = jnp.broadcast_to(self.lb[:, None], shape=(self.dim, self.pop_size)) - _population = jnp.clip(_population, lb, ub) - - # ----------------- Update population ----------------- - context_vector = gbest_position - broadcast_context_vector = jnp.broadcast_to( - context_vector[:, None], (self.dim, self.pop_size) - ) - cond = jnp.broadcast_to( - jnp.arange(self.dim)[:, None] == jnp.arange(self.dim), - shape=(self.dim, self.dim), - )[:, :, None] - population = jnp.where( - cond, _population[None, :], broadcast_context_vector[None, :] - ) - population = jnp.transpose(population, axes=(0, 2, 1)) - - return state.update( - _population=_population, - velocity=velocity, - population=population, - pbest_position=pbest_position, - pbest_fitness=pbest_fitness, - gbest_position=gbest_position, - gbest_fitness=gbest_fitness, - context_vector=context_vector, - key=state_key, + super().__init__( + base_algorithm=pso, + dim=lb.shape[0], + num_subpops=lb.shape[0], + subpop_size=subpop_size, + num_subpop_iter=1, + random_subpop=False, ) diff --git a/tests/test_single_objective_algorithms.py b/tests/test_single_objective_algorithms.py index 32312c56..2c4de31e 100644 --- a/tests/test_single_objective_algorithms.py +++ b/tests/test_single_objective_algorithms.py @@ -5,6 +5,7 @@ from evox.algorithms import ( CMAES, SepCMAES, + CPSOS, CSO, DE, PGPE, @@ -46,6 +47,18 @@ def run_single_objective_algorithm( return monitor.get_best_fitness() +def test_cpso_s(): + lb = jnp.full((5,), -32.0) + ub = jnp.full((5,), 32.0) + algorithm = CPSOS(lb, ub, 100, + inertia_weight=0.6, + cognitive_coefficient=2.5, + social_coefficient=0.8 + ) + fitness = run_single_objective_algorithm(algorithm) + assert fitness < 0.1 + + def test_cso(): lb = jnp.full((5,), -32.0) ub = jnp.full((5,), 32.0)