diff --git a/examples/algorithm/so/pso_variants/eg_cpso_s.py b/examples/algorithm/so/pso_variants/eg_cpso_s.py new file mode 100644 index 00000000..99d6ada1 --- /dev/null +++ b/examples/algorithm/so/pso_variants/eg_cpso_s.py @@ -0,0 +1,53 @@ +from evox import algorithms, problems, pipelines, monitors +import jax +import jax.numpy as jnp +from functools import partial +import evox as ex + +algorithm = algorithms.so.pso_vatients.CPSO_S( + lb=jnp.full(shape=(10,), fill_value=-32), + ub=jnp.full(shape=(10,), fill_value=32), + pop_size=15, + inertia_weight=0.4, + pbest_coefficient=2.5, + gbest_coefficient=0.8, +) + +def _ackley_func(a, b, c, x): + return ( + -a * jnp.exp(-b * jnp.sqrt(jnp.mean(x**2))) + - jnp.exp(jnp.mean(jnp.cos(c * x))) + + a + + jnp.e + ) + +@ex.jit_class +class Ackley(ex.Problem): + def __init__(self, a=20, b=0.2, c=2*jnp.pi): + self.a = a + self.b = b + self.c = c + + def evaluate(self, state, X): + return jax.vmap(jax.vmap(partial(_ackley_func, self.a, self.b, self.c)))(X), state + +problem = Ackley() + +monitor = monitors.FitnessMonitor() + +# create a pipeline + +pipeline = pipelines.StdPipeline( + algorithm=algorithm, + problem=problem, + fitness_transform=monitor.update, +) + +# init the pipeline +key = jax.random.PRNGKey(42) +state = pipeline.init(key) + +# run the pipeline for 100 steps +for i in range(100): + state = pipeline.step(state) + print(monitor.get_min_fitness()) \ No newline at end of file diff --git a/examples/algorithm/so/pso_variants/eg_sl_pso_gs.py b/examples/algorithm/so/pso_variants/eg_sl_pso_gs.py index 0e255406..8cf8a8c6 100644 --- a/examples/algorithm/so/pso_variants/eg_sl_pso_gs.py +++ b/examples/algorithm/so/pso_variants/eg_sl_pso_gs.py @@ -6,8 +6,8 @@ lb=jnp.full(shape=(10,), fill_value=-32), ub=jnp.full(shape=(10,), fill_value=32), pop_size=100, - epsilon=0.1, - theta=0.1, + social_influence_factor=0.1, + demonstrator_choice_factor=0.1, ) problem = problems.classic.Ackley() diff --git a/examples/algorithm/so/pso_variants/eg_sl_pso_us.py b/examples/algorithm/so/pso_variants/eg_sl_pso_us.py index 24a447d3..a2bc9868 100644 --- a/examples/algorithm/so/pso_variants/eg_sl_pso_us.py +++ b/examples/algorithm/so/pso_variants/eg_sl_pso_us.py @@ -6,8 +6,8 @@ lb=jnp.full(shape=(10,), fill_value=-32), ub=jnp.full(shape=(10,), fill_value=32), pop_size=100, - epsilon=0.1, - _lambda=0.1, + social_influence_factor=0.1, + demonstrator_choice_factor=0.1, ) problem = problems.classic.Ackley() diff --git a/src/evox/algorithms/so/pso_vatients/__init__.py b/src/evox/algorithms/so/pso_vatients/__init__.py index df527c6e..fecc4b18 100644 --- a/src/evox/algorithms/so/pso_vatients/__init__.py +++ b/src/evox/algorithms/so/pso_vatients/__init__.py @@ -1,3 +1,4 @@ from .sl_pso_us import SL_PSO_US from .sl_pso_gs import SL_PSO_GS -from .clpso import CLPSO \ No newline at end of file +from .clpso import CLPSO +from .cpso_s import CPSO_S \ No newline at end of file diff --git a/src/evox/algorithms/so/pso_vatients/clpso.py b/src/evox/algorithms/so/pso_vatients/clpso.py index db779c2e..618ca39e 100644 --- a/src/evox/algorithms/so/pso_vatients/clpso.py +++ b/src/evox/algorithms/so/pso_vatients/clpso.py @@ -64,7 +64,7 @@ def tell(self, state, fitness): # ----------------- Update gbest ----------------- gbest_position, gbest_fitness = min_by( [state.gbest_position[jnp.newaxis, :], state.population], - [state.gbest_position, fitness], + [state.gbest_fitness, fitness], ) gbest_fitness = jnp.atleast_1d(gbest_fitness) diff --git a/src/evox/algorithms/so/pso_vatients/cpso_s.py b/src/evox/algorithms/so/pso_vatients/cpso_s.py new file mode 100644 index 00000000..42746a2c --- /dev/null +++ b/src/evox/algorithms/so/pso_vatients/cpso_s.py @@ -0,0 +1,129 @@ +import jax +import jax.numpy as jnp + +import evox as ex +from evox.utils import * + +# SL-PSO: Social Learning PSO +# SL-PSO-GS: Using Gaussian Sampling for Demonstator Choice +# https://ieeexplore.ieee.org/document/6900227 +@ex.jit_class +class CPSO_S(ex.Algorithm): + def __init__( + self, + lb, # lower bound of problem + ub, # upper bound of problem + pop_size, # population size for one swarm of a single dimension + inertia_weight, # w + pbest_coefficient, # c_pbest + gbest_coefficient, # c_gbest + ): + self.dim = lb.shape[0] + self.lb = lb + self.ub = ub + self.pop_size = pop_size + self.w = inertia_weight + self.c_pbest = pbest_coefficient + self.c_gbest = gbest_coefficient + + def setup(self, key): + state_key, init_pop_key, init_v_key = jax.random.split(key, 3) + ub = jnp.broadcast_to(self.ub[:, None], shape=(self.dim, self.pop_size)) + lb = jnp.broadcast_to(self.lb[:, None], shape=(self.dim, self.pop_size)) + length = ub - lb + _population = jax.random.uniform( + init_pop_key, shape=(self.dim, self.pop_size) + ) + _population = _population * length + lb + + context_vector = _population[:, 0] # b + broadcast_context_vector = jnp.broadcast_to(context_vector[:, None], (self.dim, self.pop_size)) + cond = jnp.broadcast_to(jnp.arange(self.dim)[:,None] == jnp.arange(self.dim), shape=(self.dim, self.dim))[:, :, None] + population = jnp.where(cond, _population[None, :], broadcast_context_vector[None, :]) + population = jnp.transpose(population, axes=(0,2,1)) + + velocity = jax.random.uniform(init_v_key, shape=(self.dim, self.pop_size)) + velocity = velocity * length * 2 - length + + return ex.State( + # _population/velocity: shape:(dim, pop_size) + # _population has dim different swarms, each swarm has pop_size particles + # each particle in a swarm only records one single number, which is the position of this particle in this dimension + _population=_population, + velocity=velocity, + + # population: shape:(dim, pop_size, dim) + # population has dim different swarms, each swarm has pop_size particles + # population is the combination of _population and context_vector + # it is used to calculate the fitness of each particle + population=population, + + # pbest: like other algorithms, pbest is the best position of each particle + pbest_position=_population, # shape:(dim, pop_size) + pbest_fitness=jnp.full((self.dim, self.pop_size,), jnp.inf), # shape:(dim, pop_size) + + # gbest: !!! gbest is the best position of the swarm, not the whole population !!! + # Because each swarm only focuses on one dimension, so the gbest in cpso_s should only record one single number, + # But for the convenience of coding, we still use a array with shape(dim, dim) to represent the gbest position. + # which represents the best position of this swarm. + # therefore, the shape of gbest position is (dim, dim) and gbest fitness is (dim,) + gbest_position=_population[:, 0], # shape:(dim,) + gbest_fitness=jnp.full((self.dim,), jnp.inf), # shape:(dim,) + + # in fact, gbest_position is the same as context_vector + # but we still use gbest_position to represent the best position of one swarm + context_vector=context_vector, + key=state_key, + ) + + def ask(self, state): + return state.population, state + + # fitness: shape:(dim, pop_size) + def tell(self, state, fitness): + state_key, rand_key_gbest, rand_key_pbest = jax.random.split(state.key, num=3) + + # ----------------- Update pbest ----------------- + compare = state.pbest_fitness > fitness + pbest_position = jnp.where( + compare, state._population, state.pbest_position + ) + pbest_fitness = jnp.minimum(state.pbest_fitness, fitness) + + # ----------------- Update gbest ----------------- + gbest_fitness = jnp.amin(pbest_fitness, axis=1) + gbest_index = jnp.argmin(pbest_fitness, axis=1) + gbest_position = pbest_position[jnp.arange(pbest_position.shape[0]), gbest_index] + + # ------------------------------------------------------ + + rand_pbest = jax.random.uniform(rand_key_pbest, shape=(self.dim, self.pop_size)) + rand_gbest = jax.random.uniform(rand_key_gbest, shape=(self.dim, self.pop_size)) + velocity = ( + self.w * state.velocity + + self.c_pbest * rand_pbest * (pbest_position - state._population) + + self.c_gbest * rand_gbest * (jnp.broadcast_to(gbest_position[:, None], shape=(self.dim, self.pop_size)) - state._population) + ) + _population = state._population + velocity + ub = jnp.broadcast_to(self.ub[:, None], shape=(self.dim, self.pop_size)) + lb = jnp.broadcast_to(self.lb[:, None], shape=(self.dim, self.pop_size)) + _population = jnp.clip(_population, lb, ub) + + # ----------------- Update population ----------------- + context_vector = gbest_position + broadcast_context_vector = jnp.broadcast_to(context_vector[:,None], (self.dim, self.pop_size)) + cond = jnp.broadcast_to(jnp.arange(self.dim)[:,None] == jnp.arange(self.dim), shape=(self.dim, self.dim))[:, :, None] + population = jnp.where(cond, _population[None, :], broadcast_context_vector[None, :]) + population = jnp.transpose(population, axes=(0,2,1)) + + return ex.State( + _population=_population, + velocity=velocity, + population=population, + pbest_position=pbest_position, + pbest_fitness=pbest_fitness, + gbest_position=gbest_position, + gbest_fitness=gbest_fitness, + context_vector=context_vector, + key=state_key, + )