Skip to content

Commit

Permalink
add Algorithm:SL-PSO-US and SL-PSO-US
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoyuWang141 committed Jul 16, 2023
1 parent 007121a commit 8ab1abb
Show file tree
Hide file tree
Showing 6 changed files with 243 additions and 1 deletion.
32 changes: 32 additions & 0 deletions examples/algorithm/so/pso_variants/eg_sl_pso_gs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from evox import algorithms, problems, pipelines, monitors
import jax
import jax.numpy as jnp

algorithm = algorithms.so.pso_vatients.SL_PSO_GS(
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=100,
epsilon=0.1,
theta=0.1,
)

problem = problems.classic.Ackley()

monitor = monitors.FitnessMonitor()

# create a pipeline

pipeline = pipelines.StdPipeline(
algorithm=algorithm,
problem=problem,
fitness_transform=monitor.update,
)

# init the pipeline
key = jax.random.PRNGKey(42)
state = pipeline.init(key)

# run the pipeline for 100 steps
for i in range(100):
state = pipeline.step(state)
print(monitor.get_min_fitness())
32 changes: 32 additions & 0 deletions examples/algorithm/so/pso_variants/eg_sl_pso_us.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from evox import algorithms, problems, pipelines, monitors
import jax
import jax.numpy as jnp

algorithm = algorithms.so.pso_vatients.SL_PSO_US(
lb=jnp.full(shape=(10,), fill_value=-32),
ub=jnp.full(shape=(10,), fill_value=32),
pop_size=100,
epsilon=0.1,
_lambda=0.1,
)

problem = problems.classic.Ackley()

monitor = monitors.FitnessMonitor()

# create a pipeline

pipeline = pipelines.StdPipeline(
algorithm=algorithm,
problem=problem,
fitness_transform=monitor.update,
)

# init the pipeline
key = jax.random.PRNGKey(42)
state = pipeline.init(key)

# run the pipeline for 100 steps
for i in range(100):
state = pipeline.step(state)
print(monitor.get_min_fitness())
3 changes: 2 additions & 1 deletion src/evox/algorithms/so/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .de import DE
from .cma_es import CMAES, SepCMAES
from .nes import xNES, SeparableNES
from .open_es import OpenES
from .open_es import OpenES
from .pso_vatients import *
2 changes: 2 additions & 0 deletions src/evox/algorithms/so/pso_vatients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .sl_pso_us import SL_PSO_US
from .sl_pso_gs import SL_PSO_GS
88 changes: 88 additions & 0 deletions src/evox/algorithms/so/pso_vatients/sl_pso_gs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import jax
import jax.numpy as jnp

import evox as ex
from evox.utils import *

# SL-PSO: Social Learning PSO
# SL-PSO-GS: Using Gaussian Sampling for Demonstator Choice
# https://ieeexplore.ieee.org/document/6900227
@ex.jit_class
class SL_PSO_GS(ex.Algorithm):
def __init__(
self,
lb, # lower bound of problem
ub, # upper bound of problem
pop_size,
epsilon,
theta,
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.epsilon = epsilon
self.theta = theta

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
length = self.ub - self.lb
population = jax.random.uniform(
init_pop_key, shape=(self.pop_size, self.dim)
)
population = population * length + self.lb
velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))
velocity = velocity * length * 2 - length

return ex.State(
population=population,
velocity=velocity,
global_best_location=population[0],
global_best_fitness=jnp.array([jnp.inf]),
key=state_key,
)

def ask(self, state):
return state.population, state

def tell(self, state, fitness):
key, r1_key, r2_key, r3_key, demonstrator_choice_key = jax.random.split(state.key, num=5)

r1 = jax.random.uniform(r1_key, shape=(self.pop_size, self.dim))
r2 = jax.random.uniform(r2_key, shape=(self.pop_size, self.dim))
r3 = jax.random.uniform(r3_key, shape=(self.pop_size, self.dim))

global_best_location, global_best_fitness = min_by(
[state.global_best_location[jnp.newaxis, :], state.population],
[state.global_best_fitness, fitness],
)
global_best_fitness = jnp.atleast_1d(global_best_fitness)

# ----------------- Demonstator Choice -----------------
# sort from largest fitness to smallest fitness (worst to best)
ranked_population = state.population[jnp.argsort(-fitness)]
sigma = self.theta * (self.pop_size - (jnp.arange(self.pop_size) + 1))
standard_normal_distribution = jax.random.normal(demonstrator_choice_key, shape=(self.pop_size,))
# normal distribution (shape=(self.pop_size,)) means
# each individual choose a demonstrator by normal distribution
# with mean = pop_size and std = sigma
normal_distribution = sigma * (-jnp.abs(standard_normal_distribution)) + self.pop_size
index_k = jnp.floor(jnp.clip(normal_distribution, 1, self.pop_size)).astype(int) - 1
X_k = ranked_population[index_k]
# ------------------------------------------------------

X_avg = jnp.mean(state.population, axis=0)
velocity = (
r1 * state.velocity
+ r2 * (X_k - state.population)
+ r3 * self.epsilon * (X_avg - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)
return ex.State(
population=population,
velocity=velocity,
global_best_location=global_best_location,
global_best_fitness=global_best_fitness,
key=key,
)
87 changes: 87 additions & 0 deletions src/evox/algorithms/so/pso_vatients/sl_pso_us.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import jax
import jax.numpy as jnp

import evox as ex
from evox.utils import *

# SL-PSO: Social Learning PSO
# SL-PSO-US: Using Uniform Sampling for Demonstator Choice
# https://ieeexplore.ieee.org/document/6900227
@ex.jit_class
class SL_PSO_US(ex.Algorithm):
def __init__(
self,
lb, # lower bound of problem
ub, # upper bound of problem
pop_size,
epsilon,
_lambda,
):
self.dim = lb.shape[0]
self.lb = lb
self.ub = ub
self.pop_size = pop_size
self.epsilon = epsilon
self._lambda = _lambda

def setup(self, key):
state_key, init_pop_key, init_v_key = jax.random.split(key, 3)
length = self.ub - self.lb
population = jax.random.uniform(
init_pop_key, shape=(self.pop_size, self.dim)
)
population = population * length + self.lb
velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))
velocity = velocity * length * 2 - length

return ex.State(
population=population,
velocity=velocity,
global_best_location=population[0],
global_best_fitness=jnp.array([jnp.inf]),
key=state_key,
)

def ask(self, state):
return state.population, state

def tell(self, state, fitness):
key, r1_key, r2_key, r3_key, demonstrator_choice_key = jax.random.split(state.key, num=5)

r1 = jax.random.uniform(r1_key, shape=(self.pop_size, self.dim))
r2 = jax.random.uniform(r2_key, shape=(self.pop_size, self.dim))
r3 = jax.random.uniform(r3_key, shape=(self.pop_size, self.dim))

global_best_location, global_best_fitness = min_by(
[state.global_best_location[jnp.newaxis, :], state.population],
[state.global_best_fitness, fitness],
)
global_best_fitness = jnp.atleast_1d(global_best_fitness)

# ----------------- Demonstator Choice -----------------
# sort from largest fitness to smallest fitness (worst to best)
ranked_population = state.population[jnp.argsort(-fitness)]
# demonstator choice: q to pop_size
q = jnp.clip(self.pop_size - jnp.ceil(self._lambda * (self.pop_size - (jnp.arange(self.pop_size) + 1) - 1)), a_min=1, a_max=self.pop_size)
# uniform distribution (shape: (pop_size,)) means
# each individual choose a demonstator by uniform distribution in the range of q to pop_size
uniform_distribution = jax.random.uniform(demonstrator_choice_key, (self.pop_size,), minval=q, maxval=self.pop_size + 1)
index_k = jnp.floor(uniform_distribution).astype(int) - 1
X_k = ranked_population[index_k]
# ------------------------------------------------------

X_avg = jnp.mean(state.population, axis=0)
velocity = (
r1 * state.velocity
+ r2 * (X_k - state.population)
+ r3 * self.epsilon * (X_avg - state.population)
)
population = state.population + velocity
population = jnp.clip(population, self.lb, self.ub)
return ex.State(
population=population,
velocity=velocity,
global_best_location=global_best_location,
global_best_fitness=global_best_fitness,
key=key,
)

0 comments on commit 8ab1abb

Please sign in to comment.