Skip to content

Commit

Permalink
add algorithm clpso
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoyuWang141 committed Jul 24, 2023
1 parent 73f7be0 commit aefd81a
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 16 deletions.
4 changes: 1 addition & 3 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,4 @@ tests/mujoco_envs/*.gif
tests/mujoco_envs/*test*.py
tests/mujoco_envs/*learn*.py

tests/gym_env_test.py

why/
tests/gym_env_test.py
3 changes: 2 additions & 1 deletion src/evox/algorithms/so/pso_vatients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .sl_pso_us import SL_PSO_US
from .sl_pso_gs import SL_PSO_GS
from .sl_pso_gs import SL_PSO_GS
from .clpso import CLPSO
98 changes: 98 additions & 0 deletions src/evox/algorithms/so/pso_vatients/clpso.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import jax
import jax.numpy as jnp

import evox as ex
from evox.utils import *

# SL-PSO: Social Learning PSO
# SL-PSO-GS: Using Gaussian Sampling for Demonstator Choice
# https://ieeexplore.ieee.org/document/6900227
@ex.jit_class
class CLPSO(ex.Algorithm):
def __init__(
self,
lb, # lower bound of problem
ub, # upper bound of problem
pop_size, # population size
inertia_weight, # w
const_coefficient, # c
learning_probability, # P_c. shape:(pop_size,). It can be different for each particle
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.w = inertia_weight
self.c = const_coefficient
self.P_c = learning_probability

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
length = self.ub - self.lb
population = jax.random.uniform(
init_pop_key, shape=(self.pop_size, self.dim)
)
population = population * length + self.lb
velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))
velocity = velocity * length * 2 - length

return ex.State(
population=population,
velocity=velocity,
pbest_position=population,
pbest_fitness=jnp.full((self.pop_size,), jnp.inf),
gbest_position=population[0],
gbest_fitness=jnp.array([jnp.inf]),
key=state_key,
)

def ask(self, state):
return state.population, state

def tell(self, state, fitness):
key, random_coefficient_key, rand1_key, rand2_key, rand_key = jax.random.split(state.key, num=5)

random_coefficient = jax.random.uniform(random_coefficient_key, shape=(self.pop_size, self.dim))

# ----------------- Update pbest -----------------
compare = state.pbest_fitness > fitness
pbest_position = jnp.where(
compare[:, jnp.newaxis], state.population, state.pbest_position
)
pbest_fitness = jnp.minimum(state.pbest_fitness, fitness)

# ----------------- Update gbest -----------------
gbest_position, gbest_fitness = min_by(
[state.gbest_position[jnp.newaxis, :], state.population],
[state.gbest_position, fitness],
)
gbest_fitness = jnp.atleast_1d(gbest_fitness)

# ------------------ Choose pbest ----------------------

rand1_index = jnp.floor(jax.random.uniform(rand1_key, shape=(self.pop_size,), minval=0, maxval=self.pop_size)).astype(int)
rand2_index = jnp.floor(jax.random.uniform(rand2_key, shape=(self.pop_size,), minval=0, maxval=self.pop_size)).astype(int)
learning_index = jnp.where(pbest_fitness[rand1_index] < pbest_fitness[rand2_index], rand1_index, rand2_index)
learning_pbest = state.pbest_position[learning_index, :]
rand_possibility = jax.random.uniform(rand_key, shape=(self.pop_size,))
rand_possibility = jnp.broadcast_to(rand_possibility[:, jnp.newaxis], shape=(self.pop_size, self.dim))
P_c = jnp.broadcast_to(self.P_c[:, jnp.newaxis], shape=(self.pop_size, self.dim))
pbest = jnp.where(rand_possibility < P_c, learning_pbest, state.pbest_position)

# ------------------------------------------------------

velocity = (
self.w * state.velocity
+ self.c * random_coefficient * (pbest - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)
return ex.State(
population=population,
velocity=velocity,
pbest_position=pbest_position,
pbest_fitness=pbest_fitness,
gbest_position=gbest_position,
gbest_fitness=gbest_fitness,
key=key,
)
12 changes: 6 additions & 6 deletions src/evox/algorithms/so/pso_vatients/sl_pso_gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def __init__(
lb, # lower bound of problem
ub, # upper bound of problem
pop_size,
epsilon,
theta,
social_influence_factor, # epsilon
demonstrator_choice_factor, # theta
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.epsilon = epsilon
self.theta = theta
self.social_influence_factor = social_influence_factor
self.demonstrator_choice_factor = demonstrator_choice_factor

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
Expand Down Expand Up @@ -61,7 +61,7 @@ def tell(self, state, fitness):
# ----------------- Demonstator Choice -----------------
# sort from largest fitness to smallest fitness (worst to best)
ranked_population = state.population[jnp.argsort(-fitness)]
sigma = self.theta * (self.pop_size - (jnp.arange(self.pop_size) + 1))
sigma = self.demonstrator_choice_factor * (self.pop_size - (jnp.arange(self.pop_size) + 1))
standard_normal_distribution = jax.random.normal(demonstrator_choice_key, shape=(self.pop_size,))
# normal distribution (shape=(self.pop_size,)) means
# each individual choose a demonstrator by normal distribution
Expand All @@ -75,7 +75,7 @@ def tell(self, state, fitness):
velocity = (
r1 * state.velocity
+ r2 * (X_k - state.population)
+ r3 * self.epsilon * (X_avg - state.population)
+ r3 * self.social_influence_factor * (X_avg - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)
Expand Down
12 changes: 6 additions & 6 deletions src/evox/algorithms/so/pso_vatients/sl_pso_us.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def __init__(
lb, # lower bound of problem
ub, # upper bound of problem
pop_size,
epsilon,
_lambda,
social_influence_factor, # epsilon
demonstrator_choice_factor, # lambda
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.epsilon = epsilon
self._lambda = _lambda
self.social_influence_factor = social_influence_factor
self.demonstrator_choice_factor = demonstrator_choice_factor

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
Expand Down Expand Up @@ -62,7 +62,7 @@ def tell(self, state, fitness):
# sort from largest fitness to smallest fitness (worst to best)
ranked_population = state.population[jnp.argsort(-fitness)]
# demonstator choice: q to pop_size
q = jnp.clip(self.pop_size - jnp.ceil(self._lambda * (self.pop_size - (jnp.arange(self.pop_size) + 1) - 1)), a_min=1, a_max=self.pop_size)
q = jnp.clip(self.pop_size - jnp.ceil(self.demonstrator_choice_factor * (self.pop_size - (jnp.arange(self.pop_size) + 1) - 1)), a_min=1, a_max=self.pop_size)
# uniform distribution (shape: (pop_size,)) means
# each individual choose a demonstator by uniform distribution in the range of q to pop_size
uniform_distribution = jax.random.uniform(demonstrator_choice_key, (self.pop_size,), minval=q, maxval=self.pop_size + 1)
Expand All @@ -74,7 +74,7 @@ def tell(self, state, fitness):
velocity = (
r1 * state.velocity
+ r2 * (X_k - state.population)
+ r3 * self.epsilon * (X_avg - state.population)
+ r3 * self.social_influence_factor * (X_avg - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)
Expand Down

0 comments on commit aefd81a

Please sign in to comment.