Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize the NSGA-III and update IMMOEA for gpjax0.8.2 #123

Merged
merged 9 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/evox/algorithms/mo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@
from .bce_ibea import BCEIBEA
from .lmocso import LMOCSO
from .im_moea import IMMOEA

72 changes: 41 additions & 31 deletions src/evox/algorithms/mo/im_moea.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def setup(self, key):
key, subkey = jax.random.split(key)
self.pop_size = int(jnp.ceil(self.pop_size / self.k) * self.k)
W = UniformSampling(self.k, self.n_objs)()[0]
W = jnp.fliplr(jnp.sort(jnp.fliplr(W), axis=1))
W = jnp.fliplr(
jnp.sort(jnp.fliplr(W), axis=0)
) # unknown reason, but it is the same as the original code.
population = (
jax.random.uniform(subkey, shape=(self.pop_size, self.dim))
* (self.ub - self.lb)
Expand Down Expand Up @@ -108,7 +110,7 @@ def ask(self, state):
fitness = state.fitness
key, key1 = jax.random.split(state.key)
distances = cos_dist(fitness, state.reference_vector)
partition = jnp.argmax(distances, axis=1)
partition = jnp.argmin(distances, axis=1)

def get_sub_pop(i):
mask = partition == i
Expand All @@ -117,11 +119,7 @@ def get_sub_pop(i):

sub_pops = jax.vmap(get_sub_pop, out_axes=0)(jnp.arange(self.k))
next_generation = jnp.vstack(sub_pops)
nan_mask = jnp.isnan(next_generation).sum(axis=1).astype(jnp.bool_)
indices = jnp.arange(next_generation.shape[0])
masked_indices = jnp.where(nan_mask, jnp.inf, indices)
sorted_indices = jnp.sort(masked_indices).astype(jnp.int32)
next_generation = next_generation[sorted_indices, :]
next_generation = jnp.sort(next_generation, axis=0)
next_generation = lax.dynamic_slice(next_generation, (0, 0), population.shape)
next_generation = jnp.clip(next_generation, self.lb, self.ub)
return next_generation, state.update(next_generation=next_generation, key=key)
Expand Down Expand Up @@ -263,7 +261,7 @@ def normal_fun(x_key):
# Determine the indices of the population that will be generated. Unselected individuals' indices are marked with -1 and their value will be jnp.nan.
pop_indices = jnp.arange(N)
final_pop_indices = jnp.where(
pop_indices % jnp.floor(N / n) == 0, pop_indices, -1
pop_indices % jnp.floor(N / self.n_objs) == 0, pop_indices, -1
)

# Randomly classify the dimensions of the population.
Expand All @@ -289,7 +287,7 @@ def gp_body(m, shuffled_indices_group, _keys, new_pop, new_fit):
# Randomly select the group of the population used to train the GP model.
permutation = random.permutation(key_arr[0], pop_indices)
parents = jnp.where(
pop_indices <= jnp.ceil(n / self.n_objs), permutation, -1
pop_indices <= jnp.ceil(N / self.n_objs), permutation, -1
)
_mask = pop_indices == parents
sub_off, sub_fit = random_fill(
Expand All @@ -303,26 +301,31 @@ def gp_body(m, shuffled_indices_group, _keys, new_pop, new_fit):
likelihood = Gaussian(num_datapoints=len(sub_off))

# Build the prediction inputs for the GP model.
inputs = jnp.linspace(fmin[m], fmax[m], N)
inputs = jnp.where(final_pop_indices >= 0, inputs, jnp.inf)[
:, jnp.newaxis
]
inputs = jnp.linspace(fmin[m], fmax[m], N)[:, jnp.newaxis]

_sub_fit = lax.dynamic_slice(sub_fit, (0, m), (sub_fit.shape[0], 1))

# Generate the offspring from the GP model's predictions.
def get_off(i, dim_indices, _keys, sub_off, sub_fit):
def get_off(i, dim_indices, _keys, sub_off):
model = GPRegression(likelihood=likelihood, kernel=Linear())
_sub_pop = lax.dynamic_slice(
sub_off, (0, dim_indices[i]), (sub_fit.shape[0], 1)
sub_off, (0, dim_indices[i]), (sub_off.shape[0], 1)
)
model.fit(x=_sub_fit, y=_sub_pop, optimzer=ox.adam(0.001))
_, ymu, ystd = model.predict(inputs)
return ymu + random.normal(_keys[i], shape=ystd.shape) * ystd

res = jax.vmap(
lambda i: get_off(i, dim_indices, _keys, sub_off, sub_fit)
)(jnp.arange(self.l))
return res
res = jax.vmap(lambda i: get_off(i, dim_indices, _keys, sub_off))(
jnp.arange(self.l)
)

def get_offspring(i, sub_off):
index = dim_indices[i]
return sub_off.at[:, index].set(res[i, :])

sub_off = lax.fori_loop(0, self.l, get_offspring, sub_off)
sub_off = jnp.where((final_pop_indices >= 0)[:, None], sub_off, jnp.inf)
return sub_off

x_keys = jax.random.split(x_key, num=self.n_objs)
# sub_off is a list of offspring of each objective. M*L*N matrix
Expand All @@ -332,23 +335,30 @@ def get_off(i, dim_indices, _keys, sub_off, sub_fit):
)
)(arr)

# reshape sub_off: M*L*N matrix, to (M*L*N) matrix
off = jnp.vstack(sub_off).T

# reshape sub_off: L*N*M matrix, to (L*N)*M matrix
off = jnp.vstack(sub_off)
# replace the selected dimensions with the offspring.
def get_offspring(i, offspring):
index = shuffled_dim_indices[i]
return offspring.at[:, index].set(off[:, i])

offspring = lax.fori_loop(0, self.l * self.n_objs, get_offspring, new_pop)
off = jnp.sort(off, axis=0)[: self.pop_size, :]

# Convert invalid values to random values
rand_pop = jax.random.uniform(
pop_key, offspring.shape, minval=self.lb, maxval=self.ub
pop_key, off.shape, minval=self.lb, maxval=self.ub
)
invalid = (offspring < self.lb) | (offspring > self.ub)
offspring = jnp.where(invalid, rand_pop, offspring)
return offspring
invalid = (off < self.lb) | (off > self.ub)
valid_sum = jnp.sum(~invalid.all())
num = jnp.sum(mask)

def valid_fun(x):
x = jnp.sort(jnp.where(invalid.all(), jnp.inf, x), axis=0)
return x

def invalid_fun(x):
x = jnp.where(invalid, rand_pop, x)
return x

off = jax.lax.cond(valid_sum > num, valid_fun, invalid_fun, off)
off = jnp.where((jnp.arange(self.pop_size) > num)[:, None], jnp.inf, off)
return off

final_pop = lax.cond(
n >= 2 * self.n_objs, normal_fun, lambda x: population, x_key
Expand Down
196 changes: 101 additions & 95 deletions src/evox/algorithms/mo/nsga3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
sampling,
)
from evox import Algorithm, jit_class, State
from evox.utils import cos_dist


@jit_class
class NSGA3(Algorithm):
"""NSGA-III algorithm

Expand Down Expand Up @@ -51,11 +51,11 @@ def __init__(
self.crossover = crossover_op

if self.selection is None:
self.selection = selection.UniformRand(0.5)
self.selection = selection.UniformRand(1)
if self.mutation is None:
self.mutation = mutation.Gaussian()
self.mutation = mutation.Polynomial((self.lb, self.ub))
if self.crossover is None:
self.crossover = crossover.UniformRand()
self.crossover = crossover.SimulatedBinary()

self.sampling = sampling.UniformSampling(self.pop_size, self.n_objs)

Expand All @@ -67,7 +67,7 @@ def setup(self, key):
+ self.lb
)
self.ref = self.sampling(subkey)[0]
self.ref = self.ref / jnp.linalg.norm(self.ref, axis=1)[:, None]
# self.pop_size = len(self.ref)
return State(
population=population,
fitness=jnp.zeros((self.pop_size, self.n_objs)),
Expand All @@ -83,123 +83,129 @@ def init_tell(self, state, fitness):
return state

def ask(self, state):
key, sel_key1, mut_key, sel_key2, x_key = jax.random.split(state.key, 5)
selected = self.selection(sel_key1, state.population)
mutated = self.mutation(mut_key, selected)
selected = self.selection(sel_key2, state.population)
crossovered = self.crossover(x_key, selected)

next_generation = jnp.clip(
jnp.concatenate([mutated, crossovered], axis=0), self.lb, self.ub
)
key, mut_key, x_key = jax.random.split(state.key, 3)
crossovered = self.crossover(x_key, state.population)
next_generation = self.mutation(mut_key, crossovered)
next_generation = jnp.clip(next_generation, self.lb, self.ub)

return next_generation, state.update(next_generation=next_generation, key=key)

def tell(self, state, fitness):
merged_pop = jnp.concatenate([state.population, state.next_generation], axis=0)
merged_fitness = jnp.concatenate([state.fitness, fitness], axis=0)

rank = non_dominated_sort(merged_fitness)
order = jnp.argsort(rank)
rank = rank[order]
ranked_pop = merged_pop[order]
ranked_fitness = merged_fitness[order]
last_rank = rank[self.pop_size]
last_rank = rank[order[self.pop_size]]
ranked_fitness = jnp.where(
jnp.repeat((rank <= last_rank)[:, None], self.n_objs, axis=1),
ranked_fitness,
(rank <= last_rank)[:, None],
merged_fitness,
jnp.nan,
)

# Normalize
ideal = jnp.nanmin(ranked_fitness, axis=0)
offset_fitness = ranked_fitness - ideal
weight = jnp.eye(self.n_objs, self.n_objs) + 1e-6
weighted = (
jnp.repeat(offset_fitness, self.n_objs, axis=0).reshape(
len(offset_fitness), self.n_objs, self.n_objs
)
/ weight
)
asf = jnp.nanmax(weighted, axis=2)
ex_idx = jnp.argmin(asf, axis=0)
extreme = offset_fitness[ex_idx]
ideal_points = jnp.nanmin(ranked_fitness, axis=0)
ranked_fitness = ranked_fitness - ideal_points
weight = jnp.eye(self.n_objs) + 1e-6

def extreme_point(val):
extreme = val[0]
plane = jnp.linalg.solve(extreme, jnp.ones(self.n_objs))
def get_extreme(i):
return jnp.nanargmin(jnp.nanmax(ranked_fitness / weight[i], axis=1))

extreme_ind = jax.vmap(get_extreme)(jnp.arange(self.n_objs))
extreme = ranked_fitness[extreme_ind]

def get_intercept(val):
# Calculate the intercepts of the hyperplane constructed by the extreme points
_extreme = val[0]
plane = jnp.linalg.solve(_extreme, jnp.ones(self.n_objs))
intercept = 1 / plane
return intercept

def worst_point(val):
return jnp.nanmax(ranked_fitness, axis=0)
def worst_intercept(val):
_ranked_fitness = val[1]
return jnp.nanmax(_ranked_fitness, axis=0)

nadir_point = jax.lax.cond(
jnp.linalg.matrix_rank(extreme) == self.n_objs,
extreme_point,
worst_point,
(extreme, offset_fitness),
get_intercept,
worst_intercept,
(extreme, ranked_fitness),
)
normalized_fitness = offset_fitness / nadir_point

# Associate
def perpendicular_distance(x, y):
proj_len = x @ y.T
proj_vec = proj_len.reshape(proj_len.size, 1) * jnp.tile(y, (len(x), 1))
prep_vec = jnp.repeat(x, len(y), axis=0) - proj_vec
dist = jnp.reshape(jnp.linalg.norm(prep_vec, axis=1), (len(x), len(y)))
return dist

dist = perpendicular_distance(ranked_fitness, self.ref)
pi = jnp.nanargmin(dist, axis=1)
d = dist[jnp.arange(len(normalized_fitness)), pi]
normalized_fitness = ranked_fitness / nadir_point
cos_distance = cos_dist(normalized_fitness, self.ref)
dist = jnp.linalg.norm(normalized_fitness, axis=-1, keepdims=True) * jnp.sqrt(
1 - cos_distance**2
)
# Associate each solution with its nearest reference point
group_id = jnp.nanargmin(dist, axis=1)
group_id = jnp.where(group_id == -1, len(self.ref), group_id)
group_dist = jnp.nanmin(dist, axis=1)
rho = jnp.bincount(
jnp.where(rank < last_rank, group_id, len(self.ref)), length=len(self.ref)
)
rho_last = jnp.bincount(
jnp.where(rank == last_rank, group_id, len(self.ref)), length=len(self.ref)
)
group_id = jnp.where(rank == last_rank, group_id, jnp.inf)
group_dist = jnp.where(rank == last_rank, group_dist, jnp.inf)
selected_number = jnp.sum(rho)
rho = jnp.where(rho_last == 0, jnp.inf, rho)
keys = jax.random.split(state.key, self.pop_size + 1)

# Niche
def niche_loop(val):
def nope(val):
idx, i, rho, j = val
rho = rho.at[j].set(self.pop_size)
return idx, i, rho, j

def have(val):
def zero(val):
idx, i, rho, j = val
idx = idx.at[i].set(jnp.nanargmin(jnp.where(pi == j, d, jnp.nan)))
rho = rho.at[j].add(1)
return idx, i + 1, rho, j

def already(val):
idx, i, rho, j = val
key = jax.random.PRNGKey(i * j)
temp = jax.random.randint(
key, (1, len(ranked_pop)), 0, self.pop_size
def select_loop(vals):
selected_number, rank, group_id, rho, rho_last = vals
group = jnp.argmin(rho)
candidates = jnp.where(group_id == group, group_dist, jnp.inf)

def get_rand_candidate(candidates):
order = jnp.sort(
jnp.where(
jnp.isinf(candidates), jnp.inf, jnp.arange(candidates.size)
)
temp = temp + (pi == j) * self.pop_size
idx = idx.at[i].set(jnp.argmax(temp))
rho = rho.at[j].add(1)
return idx, i + 1, rho, j

return jax.lax.cond(rho[val[3]], already, zero, val)

idx, i, rho = val
j = jnp.argmin(rho)
idx, i, rho, j = jax.lax.cond(
jnp.sum(pi == j), have, nope, (idx, i, rho, j)
)
rand_index = jax.random.randint(
keys[selected_number], (), 0, rho_last[group]
)
return order[rand_index].astype(jnp.int32)

def get_min_candidate(candidates):
return jnp.argmin(candidates)

candidate = jax.lax.cond(
(rho[group] == 0) | (rho_last[group] == 1),
get_min_candidate,
get_rand_candidate,
candidates,
)
return idx, i, rho

survivor_idx = jnp.arange(self.pop_size)
rho = jnp.bincount(
jnp.where(rank < last_rank, pi, len(self.ref)), length=len(self.ref)
)
pi = jnp.where(rank == last_rank, pi, -1)
d = jnp.where(rank == last_rank, d, jnp.nan)
survivor_idx, _, _ = jax.lax.while_loop(
lambda val: val[1] < self.pop_size,
niche_loop,
(survivor_idx, jnp.sum(rho), rho),
rank = rank.at[candidate].set(last_rank - 1)
group_id = group_id.at[candidate].set(jnp.nan)
rho_last = rho_last.at[group].set(rho_last[group] - 1)

def update_(vals):
idx, matrix = vals
return matrix.at[idx].set(jnp.inf)

def add_(vals):
idx, matrix = vals
return matrix.at[idx].set(matrix[idx] + 1)

rho = jax.lax.cond(rho_last[group] == 0, update_, add_, (group, rho))
selected_number += 1
return selected_number, rank, group_id, rho, rho_last

selected_number, rank, group_id, rho, rho_last = jax.lax.while_loop(
lambda val: jnp.nansum(val[0]) < self.pop_size,
select_loop,
(selected_number, rank, group_id, rho, rho_last),
)

selected_idx = jnp.sort(
jnp.where(rank < last_rank, jnp.arange(ranked_fitness.shape[0]), jnp.inf)
)[: self.pop_size].astype(jnp.int32)
state = state.update(
population=ranked_pop[survivor_idx], fitness=ranked_fitness[survivor_idx]
population=merged_pop[selected_idx],
fitness=merged_fitness[selected_idx],
key=keys[0],
)
return state
Loading
Loading