From 52bee5dba4af5c6b77f920f8ee6d13e4d0313060 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Thu, 7 Mar 2024 19:48:24 +0100
Subject: [PATCH] WIP works for queries

---
 optimum/neuron/distributed/parallel_layers.py |  2 +
 optimum/neuron/distributed/utils.py           | 95 +++++++++++++++++--
 2 files changed, 87 insertions(+), 10 deletions(-)

diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py
index e98e4b8de..c8c92e07f 100644
--- a/optimum/neuron/distributed/parallel_layers.py
+++ b/optimum/neuron/distributed/parallel_layers.py
@@ -352,6 +352,7 @@ def replace_qkv_by_gqa_qkv_column_parallel_linear(
                 "GQAQKVColumnParallelLinear is not needed."
             )
 
+        num_attention_heads = getattr(attention_layer, cls.NUM_ATTENTION_HEADS_NAME)
         query_linear = getattr(attention_layer, cls.QUERIES_NAME)
         key_linear = getattr(attention_layer, cls.KEYS_NAME)
 
@@ -370,6 +371,7 @@ def replace_qkv_by_gqa_qkv_column_parallel_linear(
             cls.QUERIES_NAME,
             cls.KEYS_NAME,
             cls.VALUES_NAME,
+            cls.OUTPUT_PROJECTION_NAME,
+            num_attention_heads,
             num_key_value_heads,
             hidden_size,
             [query_in_features, key_value_in_features],
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 2ee8d4689..88b62e5fc 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -203,6 +203,8 @@ def __init__(
         query_proj_name: str,
         key_proj_name: str,
         value_proj_name: str,
+        output_proj_name: str,
+        num_attention_heads: int,
         num_key_value_heads: int,
         input_size: int,
         output_sizes: int,
@@ -231,8 +233,10 @@ def __init__(
         self.query_proj_name = query_proj_name
         self.key_proj_name = key_proj_name
         self.value_proj_name = value_proj_name
+        self.output_proj_name = output_proj_name
 
         self._qkv_proj_name_to_proj_name = {"q": query_proj_name, "k": key_proj_name, "v": value_proj_name}
+        self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
 
     def get_parameter_names_mapping(
@@ -464,7 +468,7 @@ def get_linear_weight_info(
 
 
 @requires_neuronx_distributed
-def create_kv_proj_local_weight_from_regular_weight(weight: torch.nn.Parameter, kv_size_multiplier: int, output_size_per_partition: int):
+def create_kv_proj_local_weight_from_regular_weight(weight: torch.nn.Parameter, kv_size_multiplier: int, output_size_per_partition: int) -> torch.Tensor:
     from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_size
 
     tp_size = get_tensor_model_parallel_size()
@@ -477,6 +481,31 @@ def create_kv_proj_local_weight_from_regular_weight(weight: torch.nn.Parameter,
     split = torch.split(repeated_weight, output_size_per_partition, dim=0)
     return torch.cat(split[tp_rank::tp_size], dim=0)
 
+
+def create_q_or_o_proj_local_weight_from_regular_weight(weight: torch.nn.Parameter, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int, query_or_output_proj: Literal["query", "output"]) -> torch.Tensor:
+    # NOTE: `query_or_output_proj` is not used yet; output-projection support is still work in progress.
+    from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_size
+
+    tp_size = get_tensor_model_parallel_size()
+    tp_rank = get_tensor_model_parallel_rank()
+
+    indices = torch.arange(num_attention_heads // tp_size)
+
+    key_index = tp_rank % kv_size_multiplier
+
+    # Each TP rank holds `num_attention_heads // tp_size` query heads, and a GQA group of
+    # `num_attention_heads // num_key_value_heads` query heads spans `num_ranks_per_group` ranks.
+    num_attention_heads_per_rank = num_attention_heads // tp_size
+    num_ranks_per_group = (num_attention_heads // num_key_value_heads) // num_attention_heads_per_rank
+
+    shift = key_index * num_attention_heads_per_rank * num_ranks_per_group
+
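+    # When tp_size is a multiple of num_key_value_heads, `shift` simplifies to
+    # key_index * (num_attention_heads // num_key_value_heads), i.e. the first query head of the GQA group picked by
+    # `key_index`; `indices` then covers a contiguous window of num_attention_heads // tp_size heads from that offset.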
+    queries_indices = indices + shift
+
+    reshaped_weight = weight.view(num_attention_heads, -1, weight.size(-1))
+    queries = reshaped_weight[queries_indices]
+    queries = queries.reshape(-1, weight.size(-1))
+    return queries
+
+
 @requires_neuronx_distributed
 def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
     layer: OptimumGQAQKVColumnParallelLinear,
@@ -505,9 +534,8 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
     bias = getattr(layer, f"bias_{proj_name}")
     row_size, _ = weight.shape
 
-    # if proj_name in ["k", "v"]:
-    #     tp_rank = tp_rank // layer.kv_size_multiplier
-
+    num_attention_heads = layer.num_attention_heads
+    num_key_value_heads = layer.num_key_value_heads
     kv_size_multiplier = layer.kv_size_multiplier if proj_name in "kv" else 1
 
     with torch.no_grad():
@@ -532,9 +560,10 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
                     linear_layer_weight_info, tensor_slices=tensor_slices
                 )
-                if proj_name in ["k", "v"]:
-                    pass
-                weight_data = create_kv_proj_local_weight_from_regular_weight(weight, kv_size_multiplier, weight.size(0))
+                if proj_name in "kv":
+                    weight_data = create_kv_proj_local_weight_from_regular_weight(weight_data, kv_size_multiplier, weight.size(0))
+                else:
+                    weight_data = create_q_or_o_proj_local_weight_from_regular_weight(weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query")
                 weight.copy_(weight_data)
                 mark_parameter_init_status_during_parallelization(weight, True)
                 del weight_data
@@ -543,12 +572,16 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
                 # weight_data = create_kv_proj_local_weight_from_regular_weight(linear_layer.weight, layer.kv_size_multiplier, weight.size(0))
                 # else:
                 #     weight_data = linear_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :]
-                weight_data = create_kv_proj_local_weight_from_regular_weight(linear_layer.weight, kv_size_multiplier, weight.size(0))
+                if proj_name in "kv":
+                    weight_data = create_kv_proj_local_weight_from_regular_weight(linear_layer.weight, kv_size_multiplier, weight.size(0))
+                else:
+                    weight_data = create_q_or_o_proj_local_weight_from_regular_weight(linear_layer.weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query")
                 weight.copy_(weight_data)
                 mark_parameter_init_status_during_parallelization(weight, True)
             else:
                 mark_parameter_init_status_during_parallelization(weight, False)
 
+        # TODO: add support for bias.
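+        # The q/k/v biases would need the same per-head reordering as their weight rows (each weight row has a
+        # matching bias entry); the pre-existing slicing below is left untouched for now.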
         if bias is not None:
             if not was_already_initialized_during_parallelization(bias):
                 if linear_layer_bias_weight_info is not None:
@@ -578,6 +611,32 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
                     mark_parameter_init_status_during_parallelization(bias, False)
 
 
+def maybe_load_weights_to_output_proj_when_using_gqa_qkv_column_parallel_linear(
+    layer: torch.nn.Module,
+    associated_optimum_gqa_qkv_column_parallel_linear: OptimumGQAQKVColumnParallelLinear,
+    linear_layer_weight_info: Optional[WeightInformation] = None,
+    linear_layer_bias_weight_info: Optional[WeightInformation] = None,
+    linear_layer: Optional["torch.nn.Linear"] = None,
+):
+    weight = layer.weight
+    num_attention_heads = associated_optimum_gqa_qkv_column_parallel_linear.num_attention_heads
+    num_key_value_heads = associated_optimum_gqa_qkv_column_parallel_linear.num_key_value_heads
+    kv_size_multiplier = associated_optimum_gqa_qkv_column_parallel_linear.kv_size_multiplier
+    with torch.no_grad():
+        if not was_already_initialized_during_parallelization(weight):
+            weight_data = None
+            if linear_layer_weight_info is not None:
+                weight_data = load_tensor_for_weight(linear_layer_weight_info)
+            elif linear_layer is not None and linear_layer.weight.device != torch.device("meta"):
+                weight_data = linear_layer.weight
+            if weight_data is not None:
+                # NOTE: output-projection support is still work in progress; this currently reuses the query-head
+                # selection unchanged.
+                weight_data = create_q_or_o_proj_local_weight_from_regular_weight(weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "output")
+                weight.copy_(weight_data)
+                mark_parameter_init_status_during_parallelization(weight, True)
+            else:
+                mark_parameter_init_status_during_parallelization(weight, False)
+    # TODO: add support for bias.
+
+
 def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
     model: torch.nn.Module,
     layer: OptimumGQAQKVColumnParallelLinear,
@@ -589,12 +648,10 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
     original_to_gqa = layer.get_parameter_names_mapping(named_modules)
 
     for orig_name, gqa_name in original_to_gqa.items():
-        print(orig_name, gqa_name)
         linear_layer_qualified_name, _ = orig_name.rsplit(".", maxsplit=1)
         linear_weight_info, linear_bias_weight_info = get_linear_weight_info(
             weight_map, linear_layer_qualified_name, fail_if_not_found=False
         )
-        print(linear_weight_info, linear_layer_qualified_name, weight_map)
         weight_name = gqa_name.split(".")[-1]
         if try_from_checkpoint and linear_weight_info is not None:
             maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
@@ -611,6 +668,24 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
                 linear_layer=model.get_submodule(orig_layer_name),
             )
 
+        # If we were able to initialize the query projection, then we assume we should initialize the output
+        # projection as well.
+        if weight_name == "weight_q" and was_already_initialized_during_parallelization(getattr(layer, weight_name)):
+            parent_qualified_name = named_modules[layer].rsplit(".", maxsplit=1)[0]
+            output_projection_qualified_name = f"{parent_qualified_name}.{layer.output_proj_name}"
+            output_projection = model.get_submodule(output_projection_qualified_name)
+            linear_weight_info, linear_bias_weight_info = get_linear_weight_info(
+                weight_map, output_projection_qualified_name, fail_if_not_found=False
+            )
+            if try_from_checkpoint and linear_weight_info is not None:
+                maybe_load_weights_to_output_proj_when_using_gqa_qkv_column_parallel_linear(
+                    output_projection,
+                    layer,
+                    linear_layer_weight_info=linear_weight_info,
+                    linear_layer_bias_weight_info=linear_bias_weight_info,
+                )
+
+
 @requires_neuronx_distributed
 def maybe_load_linear_weight_to_parallel_linear(
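
A minimal, self-contained sketch of the query-head selection used by create_q_or_o_proj_local_weight_from_regular_weight, handy for checking the indexing on CPU. It mirrors the arithmetic in the patch but takes tp_size and tp_rank as explicit arguments instead of querying neuronx_distributed; the helper name local_query_weight and the toy sizes below are made up for the example.

    import torch

    def local_query_weight(weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, tp_size, tp_rank):
        # Same arithmetic as the patched helper: pick a contiguous window of query heads for this rank.
        heads_per_rank = num_attention_heads // tp_size
        ranks_per_group = (num_attention_heads // num_key_value_heads) // heads_per_rank
        key_index = tp_rank % kv_size_multiplier
        # shift == key_index * (num_attention_heads // num_key_value_heads) when tp_size % num_key_value_heads == 0.
        shift = key_index * heads_per_rank * ranks_per_group
        queries_indices = torch.arange(heads_per_rank) + shift
        reshaped = weight.view(num_attention_heads, -1, weight.size(-1))
        return reshaped[queries_indices].reshape(-1, weight.size(-1))

    # Toy check: 16 query heads, 4 KV heads, head_dim 8, hidden size 64, tp_size 4, KV heads replicated 2x.
    w = torch.randn(16 * 8, 64)
    local = local_query_weight(w, num_attention_heads=16, num_key_value_heads=4, kv_size_multiplier=2, tp_size=4, tp_rank=1)
    print(local.shape)  # torch.Size([32, 64]) -> 4 heads of head_dim 8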