
Commit

[WIP]
michaelbenayoun committed Mar 12, 2024
1 parent 50af237 commit caae3d4
Showing 3 changed files with 138 additions and 31 deletions.
6 changes: 5 additions & 1 deletion optimum/neuron/distributed/parallel_layers.py
@@ -35,8 +35,8 @@
embedding_to_parallel_embedding,
get_linear_weight_info,
linear_to_parallel_linear,
maybe_load_weights_to_gqa_qkv_column_parallel_linear,
maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear,
maybe_load_weights_to_gqa_qkv_column_parallel_linear,
)


@@ -535,9 +535,13 @@ def _transform(
)

if needs_gqa_qkv_column_parallel_linear:
gqa_qkv_layer = getattr(layer, cls.GQA_QKV_PROJ_NAME)
maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
model,
getattr(layer, cls.OUTPUT_PROJECTION_NAME),
gqa_qkv_layer.num_attention_heads,
gqa_qkv_layer.num_key_value_heads,
gqa_qkv_layer.kv_size_multiplier,
try_from_checkpoint=not skip_linear_weight_load,
try_from_original_layer=not skip_linear_weight_load,
)
107 changes: 77 additions & 30 deletions optimum/neuron/distributed/utils.py
@@ -487,7 +487,48 @@ def create_kv_proj_local_weight_from_regular_weight(
return torch.cat(split[tp_rank::tp_size], dim=0)


def create_q_or_o_proj_local_weight_from_regular_weight(
def compute_query_indices_for_rank(
tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int
):
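"""
Compute the indices of the query heads that tensor-parallel rank `tp_rank` (out of `tp_size`)
should hold when the key/value heads are replicated `kv_size_multiplier` times by the GQA QKV
column-parallel linear. The returned indices are used to slice the query and output projection
weights so that each rank's query heads line up with its local key/value heads.
"""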
num_attention_heads_per_rank = num_attention_heads // tp_size
num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size
query_group_size = num_attention_heads // num_key_value_heads
query_group_size_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank
num_ranks_to_complete_group = query_group_size // query_group_size_per_rank

queries_indices = [
torch.arange(num_attention_heads_per_rank // num_key_value_heads_per_rank)
for _ in range(num_key_value_heads_per_rank)
]

keys_indices = torch.arange(num_key_value_heads).repeat(kv_size_multiplier)
keys_indices = torch.repeat_interleave(keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank)
keys_indices = torch.chunk(keys_indices, tp_size)

# shift_per_key = torch.arange(0, num_attention_heads, num_key_value_heads)
shift_per_key = torch.arange(0, num_attention_heads, query_group_size)

shift_within_query_group = torch.arange(0, query_group_size, query_group_size_per_rank)
shift_within_query_group = torch.repeat_interleave(
shift_within_query_group, num_attention_heads_per_rank * num_key_value_heads_per_rank
)
shift_within_query_group = torch.chunk(shift_within_query_group, tp_size)

indices = []
for idx, q_indices in enumerate(queries_indices):
s = slice(idx * num_key_value_heads_per_rank, (idx + 1) * num_key_value_heads_per_rank)
k_indices = keys_indices[tp_rank][s]
k_shift = shift_per_key[k_indices]
group_shift = shift_within_query_group[tp_rank][s]
indices.append(q_indices + k_shift + group_shift)

indices = torch.cat(indices, dim=0)
print(tp_rank, indices)
return indices


@requires_neuronx_distributed
def create_query_or_output_projection_local_weight_from_regular_weight(
weight_data: torch.Tensor,
num_attention_heads: int,
num_key_value_heads: int,
@@ -510,34 +551,37 @@ def create_q_or_o_proj_local_weight_from_regular_weight(
# key_index = (tp_rank % kv_size_multiplier)

# # Detailed computation
num_attention_heads_per_rank = num_attention_heads // tp_size
num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size
num_ranks_per_group = (num_attention_heads // num_key_value_heads) // num_attention_heads_per_rank

# shift = key_index * num_attention_heads_per_rank * num_ranks_per_group

# queries_indices = indices + shift
queries_indices = [
torch.arange(num_key_value_heads_per_rank)
for _ in range(num_attention_heads_per_rank // num_key_value_heads_per_rank)
]
# TODO: should we split with num_attention_heads_per_rank // num_key_value_heads_per_rank or num_key_value_heads_per_rank?
keys_indices = (
torch.arange(num_key_value_heads)
.repeat(kv_size_multiplier)
.split(num_attention_heads_per_rank // num_key_value_heads_per_rank)
)
final_queries_indices = []
for ith_key_in_rank, indices in enumerate(queries_indices):
tp_rank_shift = tp_rank * num_attention_heads_per_rank
key_index = keys_indices[tp_rank][ith_key_in_rank]
shift = key_index * num_attention_heads_per_rank * num_ranks_per_group
final_queries_indices.append(indices + shift + tp_rank_shift)

indices = torch.cat(final_queries_indices, dim=0)
# num_groups_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank
# queries_indices = [
# torch.arange(num_key_value_heads_per_rank)
# for _ in range(num_groups_per_rank)
# ]
# # TODO: should we split with num_attention_heads_per_rank // num_key_value_heads_per_rank or num_key_value_heads_per_rank?
# keys_indices = (
# torch.arange(num_key_value_heads)
# .repeat(kv_size_multiplier)
# .chunk(tp_size)
# # .split(num_attention_heads_per_rank // num_key_value_heads_per_rank)
# )
# print(keys_indices)
# final_queries_indices = []
# for ith_key_in_rank, indices in enumerate(queries_indices):
# tp_rank_shift = tp_rank * num_attention_heads_per_rank
# key_index = keys_indices[tp_rank][ith_key_in_rank]
# shift = key_index * num_attention_heads_per_rank * num_ranks_per_group
# final_queries_indices.append(indices + shift + tp_rank_shift)

# indices = torch.cat(final_queries_indices, dim=0)

reshaped_weight = weight_data.view(num_attention_heads, -1, weight_data.size(-1))
queries = reshaped_weight[queries_indices]
indices = compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)
queries = reshaped_weight[indices]
queries = queries.reshape(-1, weight_data.size(-1))

if query_or_output_proj == "output":
@@ -604,7 +648,7 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
weight, kv_size_multiplier, weight.size(0)
)
else:
weight_data = create_q_or_o_proj_local_weight_from_regular_weight(
weight_data = create_query_or_output_projection_local_weight_from_regular_weight(
weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query"
)
weight.copy_(weight_data)
@@ -620,7 +664,7 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
linear_layer.weight, kv_size_multiplier, weight.size(0)
)
else:
weight_data = create_q_or_o_proj_local_weight_from_regular_weight(
weight_data = create_query_or_output_projection_local_weight_from_regular_weight(
linear_layer.weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query"
)
weight.copy_(weight_data)
@@ -676,7 +720,7 @@ def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
elif linear_layer is not None and linear_layer.weight.device != torch.device("meta"):
weight_data = linear_layer.weight.data
if weight_data is not None:
weight_data = create_q_or_o_proj_local_weight_from_regular_weight(
weight_data = create_query_or_output_projection_local_weight_from_regular_weight(
weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "output"
)
weight.copy_(weight_data.repeat(1, 32))
@@ -717,16 +761,20 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
linear_layer=model.get_submodule(orig_layer_name),
)


def maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
model: torch.nn.Module,
output_projection: "layers.ColumnParallelLinear",
num_attention_heads: int,
num_key_value_heads: int,
kv_size_multiplier: int,
try_from_checkpoint: bool = True,
try_from_original_layer: bool = False,
):

# # If we were able to initiliaze the query projection, then we assume we should initialize the output projection
# # as well.
# if weight_name == "weight_q" and was_already_initialized_during_parallelization(getattr(layer, weight_name)):
# # If we were able to initiliaze the query projection, then we assume we should initialize the output projection
# # as well.
# if weight_name == "weight_q" and was_already_initialized_during_parallelization(getattr(layer, weight_name)):
# parent_qualified_name = named_modules[layer].rsplit(".", maxsplit=1)[0]
# output_projection_qualified_name = f"{parent_qualified_name}.{layer.output_proj_name}"
# output_projection = model.get_submodule(output_projection_qualified_name)
@@ -746,7 +794,6 @@ def maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
)
elif try_from_original_layer:
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
output_projection,
output_projection,
num_attention_heads,
num_key_value_heads,
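
For context (this snippet is not part of the diff), here is a minimal standalone sketch of how the new compute_query_indices_for_rank helper and the reshape/indexing logic in create_query_or_output_projection_local_weight_from_regular_weight fit together. The head counts match the first case of the parametrized test added below; the head dimension and hidden size are illustrative assumptions only.

import torch

from optimum.neuron.distributed.utils import compute_query_indices_for_rank

tp_size, tp_rank = 8, 0
num_attention_heads, num_key_value_heads, kv_size_multiplier = 32, 4, 2
head_dim, hidden_size = 64, 2048  # illustrative values, not taken from the commit

# Full (unsharded) query projection weight: (num_attention_heads * head_dim, hidden_size).
weight_data = torch.randn(num_attention_heads * head_dim, hidden_size)

# Query heads owned by this rank once the KV heads are replicated kv_size_multiplier times.
# For this configuration the test below expects rank 0 to get heads [0, 1, 2, 3].
indices = compute_query_indices_for_rank(
    tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)

# Group the rows by head, gather the selected heads, then flatten back to 2D,
# mirroring create_query_or_output_projection_local_weight_from_regular_weight.
reshaped_weight = weight_data.view(num_attention_heads, -1, weight_data.size(-1))
local_weight = reshaped_weight[indices].reshape(-1, weight_data.size(-1))
assert local_weight.shape == (num_attention_heads // tp_size * head_dim, hidden_size)
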
56 changes: 56 additions & 0 deletions tests/distributed/test_model_parallelization.py
@@ -45,6 +45,7 @@
import optimum
from optimum.neuron.accelerate.accelerator import NeuronAccelerator
from optimum.neuron.distributed.parallelizers_manager import ParallelizersManager
from optimum.neuron.distributed.utils import compute_query_indices_for_rank
from optimum.neuron.utils.cache_utils import (
get_num_neuron_cores,
)
@@ -544,3 +545,58 @@ def test_llama_v2_gqa_with_qkv_parallel_collumn_linear(
sequence_parallel_enabled,
parallelize_embeddings,
)


@pytest.mark.parametrize(
"tp_size,num_attention_heads,num_key_value_heads,kv_size_multiplier,ground_truth",
[
[
8,
32,
4,
2,
[
[0, 1, 2, 3],
[8, 9, 10, 11],
[16, 17, 18, 19],
[24, 25, 26, 27],
[4, 5, 6, 7],
[12, 13, 14, 15],
[20, 21, 22, 23],
[28, 29, 30, 31],
],
],
[
8,
32,
4,
4,
[
[0, 1, 8, 9],
[16, 17, 24, 25],
[2, 3, 10, 11],
[18, 19, 26, 27],
[4, 5, 12, 13],
[20, 21, 28, 29],
[6, 7, 14, 15],
[22, 23, 30, 31],
],
],
],
ids=[
"32-heads-4kv-heads-kv-mul-2,one kv head per rank",
"32-heads-4kv-heads-kv-mul-4,multiple kv heads per rank",
],
)
def test_compute_query_indices_for_rank(
tp_size, num_attention_heads, num_key_value_heads, kv_size_multiplier, ground_truth
):
for tp_rank in range(tp_size):
expected = torch.tensor(ground_truth[tp_rank])
computed = compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
)
print(f"TP rank = {tp_rank}")
print(f"Expected {expected}")
print(f"Computed {computed}")
torch.testing.assert_close(expected, computed)
