Skip to content

Commit

Permalink
add algorithm cpso_s
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoyuWang141 committed Aug 1, 2023
1 parent aefd81a commit 65f0194
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 6 deletions.
53 changes: 53 additions & 0 deletions examples/algorithm/so/pso_variants/eg_cpso_s.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from evox import algorithms, problems, pipelines, monitors
import jax
import jax.numpy as jnp
from functools import partial
import evox as ex

algorithm = algorithms.so.pso_vatients.CPSO_S(
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=15,
inertia_weight=0.4,
pbest_coefficient=2.5,
gbest_coefficient=0.8,
)

def _ackley_func(a, b, c, x):
return (
-a * jnp.exp(-b * jnp.sqrt(jnp.mean(x**2)))
- jnp.exp(jnp.mean(jnp.cos(c * x)))
+ a
+ jnp.e
)

@ex.jit_class
class Ackley(ex.Problem):
def __init__(self, a=20, b=0.2, c=2*jnp.pi):
self.a = a
self.b = b
self.c = c

def evaluate(self, state, X):
return jax.vmap(jax.vmap(partial(_ackley_func, self.a, self.b, self.c)))(X), state

problem = Ackley()

monitor = monitors.FitnessMonitor()

# create a pipeline

pipeline = pipelines.StdPipeline(
algorithm=algorithm,
problem=problem,
fitness_transform=monitor.update,
)

# init the pipeline
key = jax.random.PRNGKey(42)
state = pipeline.init(key)

# run the pipeline for 100 steps
for i in range(100):
state = pipeline.step(state)
print(monitor.get_min_fitness())
4 changes: 2 additions & 2 deletions examples/algorithm/so/pso_variants/eg_sl_pso_gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=100,
epsilon=0.1,
theta=0.1,
social_influence_factor=0.1,
demonstrator_choice_factor=0.1,
)

problem = problems.classic.Ackley()
Expand Down
4 changes: 2 additions & 2 deletions examples/algorithm/so/pso_variants/eg_sl_pso_us.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=100,
epsilon=0.1,
_lambda=0.1,
social_influence_factor=0.1,
demonstrator_choice_factor=0.1,
)

problem = problems.classic.Ackley()
Expand Down
3 changes: 2 additions & 1 deletion src/evox/algorithms/so/pso_vatients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .sl_pso_us import SL_PSO_US
from .sl_pso_gs import SL_PSO_GS
from .clpso import CLPSO
from .clpso import CLPSO
from .cpso_s import CPSO_S
2 changes: 1 addition & 1 deletion src/evox/algorithms/so/pso_vatients/clpso.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def tell(self, state, fitness):
# ----------------- Update gbest -----------------
gbest_position, gbest_fitness = min_by(
[state.gbest_position[jnp.newaxis, :], state.population],
[state.gbest_position, fitness],
[state.gbest_fitness, fitness],
)
gbest_fitness = jnp.atleast_1d(gbest_fitness)

Expand Down
129 changes: 129 additions & 0 deletions src/evox/algorithms/so/pso_vatients/cpso_s.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import jax
import jax.numpy as jnp

import evox as ex
from evox.utils import *

# SL-PSO: Social Learning PSO
# SL-PSO-GS: Using Gaussian Sampling for Demonstator Choice
# https://ieeexplore.ieee.org/document/6900227
@ex.jit_class
class CPSO_S(ex.Algorithm):
def __init__(
self,
lb, # lower bound of problem
ub, # upper bound of problem
pop_size, # population size for one swarm of a single dimension
inertia_weight, # w
pbest_coefficient, # c_pbest
gbest_coefficient, # c_gbest
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.w = inertia_weight
self.c_pbest = pbest_coefficient
self.c_gbest = gbest_coefficient

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
ub = jnp.broadcast_to(self.ub[:, None], shape=(self.dim, self.pop_size))
lb = jnp.broadcast_to(self.lb[:, None], shape=(self.dim, self.pop_size))
length = ub - lb
_population = jax.random.uniform(
init_pop_key, shape=(self.dim, self.pop_size)
)
_population = _population * length + lb

context_vector = _population[:, 0] # b
broadcast_context_vector = jnp.broadcast_to(context_vector[:, None], (self.dim, self.pop_size))
cond = jnp.broadcast_to(jnp.arange(self.dim)[:,None] == jnp.arange(self.dim), shape=(self.dim, self.dim))[:, :, None]
population = jnp.where(cond, _population[None, :], broadcast_context_vector[None, :])
population = jnp.transpose(population, axes=(0,2,1))

velocity = jax.random.uniform(init_v_key, shape=(self.dim, self.pop_size))
velocity = velocity * length * 2 - length

return ex.State(
# _population/velocity: shape:(dim, pop_size)
# _population has dim different swarms, each swarm has pop_size particles
# each particle in a swarm only records one single number, which is the position of this particle in this dimension
_population=_population,
velocity=velocity,

# population: shape:(dim, pop_size, dim)
# population has dim different swarms, each swarm has pop_size particles
# population is the combination of _population and context_vector
# it is used to calculate the fitness of each particle
population=population,

# pbest: like other algorithms, pbest is the best position of each particle
pbest_position=_population, # shape:(dim, pop_size)
pbest_fitness=jnp.full((self.dim, self.pop_size,), jnp.inf), # shape:(dim, pop_size)

# gbest: !!! gbest is the best position of the swarm, not the whole population !!!
# Because each swarm only focuses on one dimension, so the gbest in cpso_s should only record one single number,
# But for the convenience of coding, we still use a array with shape(dim, dim) to represent the gbest position.
# which represents the best position of this swarm.
# therefore, the shape of gbest position is (dim, dim) and gbest fitness is (dim,)
gbest_position=_population[:, 0], # shape:(dim,)
gbest_fitness=jnp.full((self.dim,), jnp.inf), # shape:(dim,)

# in fact, gbest_position is the same as context_vector
# but we still use gbest_position to represent the best position of one swarm
context_vector=context_vector,
key=state_key,
)

def ask(self, state):
return state.population, state

# fitness: shape:(dim, pop_size)
def tell(self, state, fitness):
state_key, rand_key_gbest, rand_key_pbest = jax.random.split(state.key, num=3)

# ----------------- Update pbest -----------------
compare = state.pbest_fitness > fitness
pbest_position = jnp.where(
compare, state._population, state.pbest_position
)
pbest_fitness = jnp.minimum(state.pbest_fitness, fitness)

# ----------------- Update gbest -----------------
gbest_fitness = jnp.amin(pbest_fitness, axis=1)
gbest_index = jnp.argmin(pbest_fitness, axis=1)
gbest_position = pbest_position[jnp.arange(pbest_position.shape[0]), gbest_index]

# ------------------------------------------------------

rand_pbest = jax.random.uniform(rand_key_pbest, shape=(self.dim, self.pop_size))
rand_gbest = jax.random.uniform(rand_key_gbest, shape=(self.dim, self.pop_size))
velocity = (
self.w * state.velocity
+ self.c_pbest * rand_pbest * (pbest_position - state._population)
+ self.c_gbest * rand_gbest * (jnp.broadcast_to(gbest_position[:, None], shape=(self.dim, self.pop_size)) - state._population)
)
_population = state._population + velocity
ub = jnp.broadcast_to(self.ub[:, None], shape=(self.dim, self.pop_size))
lb = jnp.broadcast_to(self.lb[:, None], shape=(self.dim, self.pop_size))
_population = jnp.clip(_population, lb, ub)

# ----------------- Update population -----------------
context_vector = gbest_position
broadcast_context_vector = jnp.broadcast_to(context_vector[:,None], (self.dim, self.pop_size))
cond = jnp.broadcast_to(jnp.arange(self.dim)[:,None] == jnp.arange(self.dim), shape=(self.dim, self.dim))[:, :, None]
population = jnp.where(cond, _population[None, :], broadcast_context_vector[None, :])
population = jnp.transpose(population, axes=(0,2,1))

return ex.State(
_population=_population,
velocity=velocity,
population=population,
pbest_position=pbest_position,
pbest_fitness=pbest_fitness,
gbest_position=gbest_position,
gbest_fitness=gbest_fitness,
context_vector=context_vector,
key=state_key,
)

0 comments on commit 65f0194

Please sign in to comment.