diff --git a/optimum/neuron/distributed/base.py b/optimum/neuron/distributed/base.py
index fc29d6b57..576e6e989 100644
--- a/optimum/neuron/distributed/base.py
+++ b/optimum/neuron/distributed/base.py
@@ -705,7 +705,6 @@ def should_parallelize_layer_predicate_func(layer):
 
         if checkpoint_dir is not None:
             cls.load_model_checkpoint(model, checkpoint_dir)
-        # model._original_parameter_names_to_gqa_qkv_names = original_parameter_names_to_gqa_qkv_names
         model._gqa_qkv_metadata = gqa_qkv_metadata
 
         return model
diff --git a/optimum/neuron/distributed/parallel_layers.py b/optimum/neuron/distributed/parallel_layers.py
index e226b1876..c0f97bc5d 100644
--- a/optimum/neuron/distributed/parallel_layers.py
+++ b/optimum/neuron/distributed/parallel_layers.py
@@ -447,7 +447,7 @@ def _transform(
         tp_size = get_tensor_model_parallel_size()
 
-        weight_map = getattr(model, "_weight_map", {})
+        weight_map = getattr(model, "_weight_map", None)
         config = model.config
         normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index c25b95b48..54088ff67 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -147,6 +147,11 @@ def __post_init__(self):
 
 
 class FakeProj(torch.nn.Module):
+    """
+    Dummy layer that replaces a Linear projection by gathering the result from its associated merged
+    GQAQKVColumnParallelLinear.
+    """
+
     def __init__(
         self,
         fully_qualified_name: str,
@@ -178,6 +183,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class OptimumGQAQKVColumnParallelLinear(GQAQKVColumnParallelLinear):
+    """
+    Same as GQAQKVColumnParallelLinear, but with the metadata needed by `optimum-neuron`.
+    """
+
     def __init__(
         self,
         query_proj_name: str,
@@ -238,6 +247,10 @@ def get_parameter_names_mapping(
 def get_parameter_names_mapping_after_gqa_qkv_replacement(
     model: torch.nn.Module, reversed: bool = False
 ) -> Dict[str, str]:
+    """
+    Returns the mapping between the original projection names and their names after replacing them with
+    GQAQKVColumnParallelLinear.
+    """
     from neuronx_distributed.pipeline import NxDPPModel
 
     mapping = {}
@@ -254,6 +267,10 @@ def get_parameter_names_mapping_after_gqa_qkv_replacement(
 
 @requires_neuronx_distributed
 def get_output_projection_qualified_names_after_qga_qkv_replacement(model: torch.nn.Module) -> Set[str]:
+    """
+    Returns the qualified names of the output projections inside the attention layers; these are needed when
+    using GQAQKVColumnParallelLinear.
+    """
     from neuronx_distributed.pipeline import NxDPPModel
 
     qualified_names = set()
@@ -430,12 +447,14 @@ def embedding_to_parallel_embedding(
 
 def get_linear_weight_info(
-    weight_map: Dict[str, Union[Path, str]],
+    weight_map: Optional[Dict[str, Union[Path, str]]],
     linear_layer_qualified_name: str,
     device: Optional[torch.device] = None,
     fail_if_not_found: bool = True,
 ) -> Tuple[Optional[WeightInformation], Optional[WeightInformation]]:
     linear_layer_weight_qualified_name = f"{linear_layer_qualified_name}.weight"
+    if weight_map is None:
+        weight_map = {}
     if linear_layer_weight_qualified_name not in weight_map:
         if fail_if_not_found:
             raise ValueError(
@@ -470,6 +489,11 @@ def get_linear_weight_info(
 def create_kv_proj_local_weight_from_regular_weight(
     weight_data: torch.Tensor, kv_size_multiplier: int, output_size_per_partition: int
 ) -> torch.Tensor:
+    """
+    Creates the local version of the key or value projection weight for the given TP rank when using
+    GQAQKVColumnParallelLinear.
+    """
+    assert not isinstance(weight_data, torch.nn.Parameter)
     from neuronx_distributed.parallel_layers.parallel_state import (
         get_tensor_model_parallel_rank,
         get_tensor_model_parallel_size,
     )
@@ -485,6 +509,9 @@ def compute_query_indicies_for_rank(
     tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int
 ):
+    """
+    Computes the permutation for the query weight when using GQAQKVColumnParallelLinear.
+    """
     num_attention_heads_per_rank = num_attention_heads // tp_size
     num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size
     query_group_size = num_attention_heads // num_key_value_heads
@@ -526,7 +553,12 @@ def create_query_or_output_projection_local_weight_from_regular_weight(
     kv_size_multiplier: int,
     query_or_output_proj: Union[Literal["query"], Literal["output"]],
 ) -> torch.Tensor:
+    """
+    Creates the local version of the query or output projection weight for the given TP rank when using
+    GQAQKVColumnParallelLinear.
+    """
     assert query_or_output_proj in ["query", "output"]
+    assert not isinstance(weight_data, torch.nn.Parameter)
 
     from neuronx_distributed.parallel_layers.parallel_state import (
         get_tensor_model_parallel_rank,
@@ -565,7 +597,12 @@ def create_local_bias_from_regular_bias(
     query_or_key_value_bias: Union[Literal["query"], Literal["key_value"]],
     gather_output: bool,
 ) -> torch.Tensor:
+    """
+    Creates the local version of the query, key and value projection bias for the given TP rank when using
+    GQAQKVColumnParallelLinear.
+    """
     assert query_or_key_value_bias in ["query", "key_value"]
+    assert not isinstance(bias_weigth_data, torch.nn.Parameter)
     from neuronx_distributed.parallel_layers.parallel_state import (
         get_tensor_model_parallel_rank,
         get_tensor_model_parallel_size,
     )
@@ -703,7 +740,7 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
 
 
 def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
-    output_projection: "layers.ColumnParallelLinear",
+    output_projection: "layers.RowParallelLinear",
     num_attention_heads: int,
     num_key_value_heads: int,
     kv_size_multiplier: int,
@@ -714,9 +751,10 @@ def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_l
     try_from_original_layer: bool = False,
 ):
     weight = output_projection.weight
+    bias = output_projection.bias
     with torch.no_grad():
-        weight_data = None
         if not was_already_initialized_during_parallelization(weight):
+            weight_data = None
            if try_from_checkpoint and linear_layer_weight_info is not None:
                 weight_data = load_tensor_for_weight(linear_layer_weight_info)
             elif (
@@ -733,7 +771,19 @@ def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_l
                 mark_parameter_init_status_during_parallelization(weight, True)
             else:
                 mark_parameter_init_status_during_parallelization(weight, False)
-        # TODO: bias
+        if bias is not None and not was_already_initialized_during_parallelization(bias):
+            bias_weight_data = None
+            if linear_layer_bias_weight_info is not None:
+                bias_weight_data = load_tensor_for_weight(linear_layer_bias_weight_info)
+            elif original_output_projection is not None and original_output_projection.bias.device != torch.device(
+                "meta"
+            ):
+                bias_weight_data = original_output_projection.bias.data
+            if bias_weight_data is not None:
+                output_projection.bias.copy_(bias_weight_data)
+                mark_parameter_init_status_during_parallelization(output_projection.bias, True)
+            else:
+                mark_parameter_init_status_during_parallelization(output_projection.bias, False)
 
 
 @requires_neuronx_distributed