
Commit

[WIP]
michaelbenayoun committed Mar 13, 2024
1 parent caae3d4 commit 87aada5
Showing 3 changed files with 86 additions and 158 deletions.
76 changes: 37 additions & 39 deletions optimum/neuron/distributed/parallel_layers.py
@@ -35,8 +35,8 @@
embedding_to_parallel_embedding,
get_linear_weight_info,
linear_to_parallel_linear,
maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear,
maybe_load_weights_to_gqa_qkv_column_parallel_linear,
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear,
)


@@ -446,7 +446,7 @@ def _transform(

tp_size = get_tensor_model_parallel_size()

weight_map = getattr(model, "_weight_map", None)
weight_map = getattr(model, "_weight_map", {})
config = model.config
normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

@@ -492,13 +492,12 @@ def _transform(
)
else:
for name in [cls.QUERIES_NAME, cls.KEYS_NAME, cls.VALUES_NAME]:
linear_layer_weight_info, linear_layer_bias_weight_info = None, None
if weight_map is not None:
linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info(
weight_map,
f"{layer_qualified_name}.{name}",
device=device,
)
linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info(
weight_map,
f"{layer_qualified_name}.{name}",
device=device,
fail_if_not_found=False,
)
parallel_linear = linear_to_parallel_linear(
getattr(layer, name),
"column",
@@ -512,40 +511,39 @@ def _transform(
setattr(layer, name, parallel_linear)

if cls.OUTPUT_PROJECTION_NAME is not None:
linear_layer_weight_info, linear_layer_bias_weight_info = None, None
if weight_map is not None:
linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info(
weight_map,
f"{layer_qualified_name}.{cls.OUTPUT_PROJECTION_NAME}",
device=device,
)
setattr(
layer,
cls.OUTPUT_PROJECTION_NAME,
linear_to_parallel_linear(
getattr(layer, cls.OUTPUT_PROJECTION_NAME),
"row",
input_is_parallel=True,
linear_layer_weight_info=linear_layer_weight_info,
linear_layer_bias_weight_info=linear_layer_bias_weight_info,
sequence_parallel_enabled=sequence_parallel_enabled,
skip_weight_load=skip_linear_weight_load,
device=device,
),
linear_layer_weight_info, linear_layer_bias_weight_info = get_linear_weight_info(
weight_map,
f"{layer_qualified_name}.{cls.OUTPUT_PROJECTION_NAME}",
device=device,
fail_if_not_found=False,
)

if needs_gqa_qkv_column_parallel_linear:
qga_qkv_layer = getattr(layer, cls.GQA_QKV_PROJ_NAME)
maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
model,
parallel_output_proj = linear_to_parallel_linear(
getattr(layer, cls.OUTPUT_PROJECTION_NAME),
qga_qkv_layer.num_attention_heads,
qga_qkv_layer.num_key_value_heads,
qga_qkv_layer.kv_size_multiplier,
try_from_checkpoint=not skip_linear_weight_load,
try_from_original_layer=not skip_linear_weight_load,
"row",
input_is_parallel=True,
linear_layer_weight_info=linear_layer_weight_info,
linear_layer_bias_weight_info=linear_layer_bias_weight_info,
sequence_parallel_enabled=sequence_parallel_enabled,
skip_weight_load=skip_linear_weight_load,
device=device,
)

if needs_gqa_qkv_column_parallel_linear:
qga_qkv_layer = getattr(layer, cls.GQA_QKV_PROJ_NAME)
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
parallel_output_proj,
qga_qkv_layer.num_attention_heads,
qga_qkv_layer.num_key_value_heads,
qga_qkv_layer.kv_size_multiplier,
original_output_projection=getattr(layer, cls.OUTPUT_PROJECTION_NAME),
linear_layer_weight_info=linear_layer_weight_info,
linear_layer_bias_weight_info=linear_layer_bias_weight_info,
try_from_checkpoint=not skip_linear_weight_load,
try_from_original_layer=not skip_linear_weight_load,
)

setattr(layer, cls.OUTPUT_PROJECTION_NAME, parallel_output_proj)

setattr(
layer,
num_attention_heads_name,
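The hunks above drop the `if weight_map is not None:` guards: `_weight_map` now defaults to `{}` and `get_linear_weight_info` is always called with `fail_if_not_found=False`, so a missing checkpoint entry simply yields no weight info instead of requiring a guard at every call site. Below is a minimal sketch of that calling convention, using a hypothetical `lookup_weight_info` stand-in (and a made-up module name) rather than the library's actual implementation:

```python
# Sketch only: `lookup_weight_info` is a stand-in for the behavior the diff relies on,
# not the real `get_linear_weight_info`.
from typing import Optional, Tuple


def lookup_weight_info(
    weight_map: dict, prefix: str, fail_if_not_found: bool = True
) -> Tuple[Optional[str], Optional[str]]:
    weight = weight_map.get(f"{prefix}.weight")
    bias = weight_map.get(f"{prefix}.bias")
    if weight is None and fail_if_not_found:
        raise ValueError(f"Could not find weight information for {prefix}")
    return weight, bias


# With an empty weight map (the new default for `_weight_map`), the lookup degrades
# gracefully to (None, None) and the parallel layer falls back to other init paths.
weight_info, bias_info = lookup_weight_info(
    {}, "model.layers.0.self_attn.q_proj", fail_if_not_found=False
)
assert weight_info is None and bias_info is None
```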
151 changes: 32 additions & 119 deletions optimum/neuron/distributed/utils.py
@@ -494,7 +494,6 @@ def compute_query_indices_for_rank(
num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size
query_group_size = num_attention_heads // num_key_value_heads
query_group_size_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank
num_ranks_to_complete_group = query_group_size // query_group_size_per_rank

queries_indices = [
torch.arange(num_attention_heads_per_rank // num_key_value_heads_per_rank)
@@ -505,7 +504,6 @@ def compute_query_indices_for_rank(
keys_indices = torch.repeat_interleave(keys_indices, num_attention_heads_per_rank // num_key_value_heads_per_rank)
keys_indices = torch.chunk(keys_indices, tp_size)

# shift_per_key = torch.arange(0, num_attention_heads, num_key_value_heads)
shift_per_key = torch.arange(0, num_attention_heads, query_group_size)

shift_within_query_group = torch.arange(0, query_group_size, query_group_size_per_rank)
@@ -523,7 +521,6 @@ def compute_query_indices_for_rank(
indices.append(q_indices + k_shift + group_shift)

indices = torch.cat(indices, dim=0)
print(tp_rank, indices)
return indices
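The new parametrization added in tests/distributed/test_model_parallelization.py (see the last file in this commit) pins down the case where each rank holds one query head per query group: with tp_size=8, 32 attention heads, 4 KV heads, and kv_size_multiplier=8, rank r is expected to own query heads [r, r+8, r+16, r+24]. A quick check mirroring that parametrization; the import path is inferred from this file's location and assumes optimum-neuron and its Neuron dependencies are installed:

```python
import torch

# Import path assumed from optimum/neuron/distributed/utils.py.
from optimum.neuron.distributed.utils import compute_query_indices_for_rank

tp_size, num_attention_heads, num_key_value_heads, kv_size_multiplier = 8, 32, 4, 8
for tp_rank in range(tp_size):
    indices = compute_query_indices_for_rank(
        tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
    )
    # Expected values taken from the test parametrization added in this commit.
    expected = torch.tensor([tp_rank, tp_rank + 8, tp_rank + 16, tp_rank + 24])
    assert torch.equal(indices, expected)
```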


@@ -546,37 +543,6 @@ def create_query_or_output_projection_local_weight_from_regular_weight(
if query_or_output_proj == "output":
weight_data = weight_data.transpose(0, 1)

# indices = torch.arange(num_attention_heads // tp_size)

# key_index = (tp_rank % kv_size_multiplier)

# # Detailed computation

# shift = key_index * num_attention_heads_per_rank * num_ranks_per_group

# queries_indices = indices + shift
# num_groups_per_rank = num_attention_heads_per_rank // num_key_value_heads_per_rank
# queries_indices = [
# torch.arange(num_key_value_heads_per_rank)
# for _ in range(num_groups_per_rank)
# ]
# # TODO: should we split with num_attention_heads_per_rank // num_key_value_heads_per_rank or num_key_value_heads_per_rank?
# keys_indices = (
# torch.arange(num_key_value_heads)
# .repeat(kv_size_multiplier)
# .chunk(tp_size)
# # .split(num_attention_heads_per_rank // num_key_value_heads_per_rank)
# )
# print(keys_indices)
# final_queries_indices = []
# for ith_key_in_rank, indices in enumerate(queries_indices):
# tp_rank_shift = tp_rank * num_attention_heads_per_rank
# key_index = keys_indices[tp_rank][ith_key_in_rank]
# shift = key_index * num_attention_heads_per_rank * num_ranks_per_group
# final_queries_indices.append(indices + shift + tp_rank_shift)

# indices = torch.cat(final_queries_indices, dim=0)

reshaped_weight = weight_data.view(num_attention_heads, -1, weight_data.size(-1))
indices = compute_query_indices_for_rank(
tp_size, tp_rank, num_attention_heads, num_key_value_heads, kv_size_multiplier
@@ -627,38 +593,19 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
with torch.no_grad():
if not was_already_initialized_during_parallelization(weight):
if linear_layer_weight_info is not None:
if proj_name in ["k", "v"]:
tensor_slices = None
else:
tensor_slices = (
(tp_rank * row_size, (tp_rank + 1) * row_size),
None,
)
tensor_slices = None
# weight_data = load_tensor_for_weight(
# linear_layer_weight_info,
# tensor_slices=(
# (tp_rank * row_size, (tp_rank + 1) * row_size),
# None,
# ),
# )
weight_data = load_tensor_for_weight(linear_layer_weight_info, tensor_slices=tensor_slices)
weight_data = load_tensor_for_weight(linear_layer_weight_info)
if proj_name in "kv":
weight_data = create_kv_proj_local_weight_from_regular_weight(
weight, kv_size_multiplier, weight.size(0)
weight_data, kv_size_multiplier, weight.size(0)
)
else:
weight_data = create_query_or_output_projection_local_weight_from_regular_weight(
weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query"
weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query"
)
weight.copy_(weight_data)
mark_parameter_init_status_during_parallelization(weight, True)
del weight_data
elif linear_layer.weight.device != torch.device("meta"):
# if proj_name in ["k", "v"]:
# weight_data = create_kv_proj_local_weight_from_regular_weight(linear_layer.weight, layer.kv_size_multiplier, weight.size(0))
# else:
# weight_data = linear_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :]
elif linear_layer is not None and linear_layer.weight.device != torch.device("meta"):
if proj_name in "kv":
weight_data = create_kv_proj_local_weight_from_regular_weight(
linear_layer.weight, kv_size_multiplier, weight.size(0)
@@ -702,34 +649,6 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
mark_parameter_init_status_during_parallelization(bias, False)
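The hunk above now loads the full checkpoint tensor (no per-rank row slice) and derives the rank-local weight from `weight_data` instead of the not-yet-initialized `weight`. The K/V branch relies on the kv_size_multiplier replication idea; the toy sketch below illustrates that idea with made-up sizes and is not the library's `create_kv_proj_local_weight_from_regular_weight`:

```python
# Toy illustration (assumed semantics, not the library helper): with GQA, the KV
# projection has fewer heads than tensor-parallel ranks, so it is replicated
# kv_size_multiplier times before being sharded column-parallel, leaving every
# rank with whole KV heads.
import torch

num_key_value_heads, head_dim, hidden_size = 4, 2, 8
tp_size, kv_size_multiplier = 8, 8

kv_weight = torch.randn(num_key_value_heads * head_dim, hidden_size)
replicated = kv_weight.repeat(kv_size_multiplier, 1)    # (4 * 8 * head_dim, hidden_size)
local_shards = torch.chunk(replicated, tp_size, dim=0)  # one shard per TP rank

rows_per_rank = num_key_value_heads * kv_size_multiplier * head_dim // tp_size
assert all(shard.shape == (rows_per_rank, hidden_size) for shard in local_shards)
```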


def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
layer: "layers.ColumnParallelLinear",
num_attention_heads: int,
num_key_value_heads: int,
kv_size_multiplier: int,
linear_layer_weight_info: Optional[WeightInformation] = None,
linear_layer_bias_weight_info: Optional[WeightInformation] = None,
linear_layer: Optional["torch.nn.Linear"] = None,
):
weight = layer.weight
with torch.no_grad():
if not was_already_initialized_during_parallelization(weight):
weight_data = None
if linear_layer_weight_info is not None:
weight_data = load_tensor_for_weight(linear_layer_weight_info)
elif linear_layer is not None and linear_layer.weight.device != torch.device("meta"):
weight_data = linear_layer.weight.data
if weight_data is not None:
weight_data = create_query_or_output_projection_local_weight_from_regular_weight(
weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "output"
)
weight.copy_(weight_data.repeat(1, 32))
mark_parameter_init_status_during_parallelization(weight, True)
else:
mark_parameter_init_status_during_parallelization(weight, False)
# TODO: bias


def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
model: torch.nn.Module,
layer: OptimumGQAQKVColumnParallelLinear,
@@ -762,44 +681,38 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
)


def maybe_load_weights_from_checkpoint_or_original_layer_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
model: torch.nn.Module,
def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
output_projection: "layers.ColumnParallelLinear",
num_attention_heads: int,
num_key_value_heads: int,
kv_size_multiplier: int,
original_output_projection: Optional[torch.nn.Linear] = None,
linear_layer_weight_info: Optional[WeightInformation] = None,
linear_layer_bias_weight_info: Optional[WeightInformation] = None,
try_from_checkpoint: bool = True,
try_from_original_layer: bool = False,
):

# # If we were able to initiliaze the query projection, then we assume we should initialize the output projection
# # as well.
# if weight_name == "weight_q" and was_already_initialized_during_parallelization(getattr(layer, weight_name)):
# parent_qualified_name = named_modules[layer].rsplit(".", maxsplit=1)[0]
# output_projection_qualified_name = f"{parent_qualified_name}.{layer.output_proj_name}"
# output_projection = model.get_submodule(output_projection_qualified_name)
weight_map = getattr(model, "_weight_map", {})
module_to_name = {v: k for k, v in model.named_modules()}
linear_weight_info, linear_bias_weight_info = get_linear_weight_info(
weight_map, module_to_name[output_projection], fail_if_not_found=False
)
if try_from_checkpoint and linear_weight_info is not None:
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
output_projection,
num_attention_heads,
num_key_value_heads,
kv_size_multiplier,
linear_layer_weight_info=linear_weight_info,
linear_layer_bias_weight_info=linear_bias_weight_info,
)
elif try_from_original_layer:
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
output_projection,
num_attention_heads,
num_key_value_heads,
kv_size_multiplier,
linear_layer=output_projection,
)
weight = output_projection.weight
with torch.no_grad():
if not was_already_initialized_during_parallelization(weight):
weight_data = None
if try_from_checkpoint and linear_layer_weight_info is not None:
weight_data = load_tensor_for_weight(linear_layer_weight_info)
elif (
try_from_original_layer
and original_output_projection is not None
and original_output_projection.weight.device != torch.device("meta")
):
weight_data = original_output_projection.weight.data
if weight_data is not None:
weight_data = create_query_or_output_projection_local_weight_from_regular_weight(
weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "output"
)
weight.copy_(weight_data)
mark_parameter_init_status_during_parallelization(weight, True)
else:
mark_parameter_init_status_during_parallelization(weight, False)
# TODO: bias


@requires_neuronx_distributed
@@ -1004,9 +917,9 @@ def linear_to_parallel_linear(
if embedding_weight_to_tie is not None:
parallel_linear_layer.weight = embedding_weight_to_tie

del linear_layer.weight
if linear_layer.bias is not None:
del linear_layer.bias
# del linear_layer.weight
# if linear_layer.bias is not None:
# del linear_layer.bias

return parallel_linear_layer

17 changes: 17 additions & 0 deletions tests/distributed/test_model_parallelization.py
@@ -582,10 +582,27 @@ def test_llama_v2_gqa_with_qkv_parallel_collumn_linear(
[22, 23, 30, 31],
],
],
[
8,
32,
4,
8,
[
[0, 8, 16, 24],
[1, 9, 17, 25],
[2, 10, 18, 26],
[3, 11, 19, 27],
[4, 12, 20, 28],
[5, 13, 21, 29],
[6, 14, 22, 30],
[7, 15, 23, 31],
],
],
],
ids=[
"32-heads-4kv-heads-kv-mul-2,one kv head per rank",
"32-heads-4kv-heads-kv-mul-4,multiple kv heads per rank",
"32-heads-4kv-heads-kv-mul-8,all kv heads per rank",
],
)
def test_compute_query_indices_for_rank(
