
Commit

Add docs
michaelbenayoun committed Mar 15, 2024
1 parent 1412c4a commit ea774bb
Showing 3 changed files with 55 additions and 6 deletions.
1 change: 0 additions & 1 deletion optimum/neuron/distributed/base.py
@@ -705,7 +705,6 @@ def should_parallelize_layer_predicate_func(layer):
if checkpoint_dir is not None:
cls.load_model_checkpoint(model, checkpoint_dir)

# model._original_parameter_names_to_gqa_qkv_names = original_parameter_names_to_gqa_qkv_names
model._gqa_qkv_metadata = gqa_qkv_metadata

return model
2 changes: 1 addition & 1 deletion optimum/neuron/distributed/parallel_layers.py
@@ -447,7 +447,7 @@ def _transform(

tp_size = get_tensor_model_parallel_size()

weight_map = getattr(model, "_weight_map", {})
weight_map = getattr(model, "_weight_map", None)
config = model.config
normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

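Together with the get_linear_weight_info change further down, a model without a _weight_map attribute now yields None instead of {}, and the helper normalizes that back to an empty dict before the lookup. A minimal sketch of the pattern (a hypothetical helper, not the library code):

def lookup_weight_sketch(weight_map, name, fail_if_not_found=True):
    # A missing map is treated as empty, mirroring get_linear_weight_info:
    # the lookup either raises or returns nothing.
    weight_map = weight_map or {}
    if name not in weight_map:
        if fail_if_not_found:
            raise ValueError(f"Could not find {name} in the weight map.")
        return None
    return weight_map[name]

lookup_weight_sketch(None, "lm_head.weight", fail_if_not_found=False)  # returns None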
58 changes: 54 additions & 4 deletions optimum/neuron/distributed/utils.py
@@ -147,6 +147,11 @@ def __post_init__(self):


class FakeProj(torch.nn.Module):
"""
Dummy layer standing in for a Linear projection; it gathers its result from the associated merged
GQAQKVColumnParallelLinear.
"""

def __init__(
self,
fully_qualified_name: str,
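The rest of the class is collapsed in this hunk. Purely as an illustration of the idea in the docstring (a simplified, hypothetical sketch, not the actual implementation), a fake projection can simply hand back the slot of the fused layer's output it stands in for:

import torch

class FakeProjSketch(torch.nn.Module):
    # Simplified stand-in: the real FakeProj also tracks qualified names and how to
    # reach its parent module. Here we only keep which output of the fused layer
    # (query, key, or value) this fake projection is responsible for.
    def __init__(self, output_index: int):
        super().__init__()
        self.output_index = output_index

    def forward(self, fused_outputs):
        # `fused_outputs` is assumed to be the tuple produced by the merged
        # GQAQKVColumnParallelLinear, so the original q_proj/k_proj/v_proj call
        # sites keep receiving a single tensor.
        return fused_outputs[self.output_index]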
@@ -178,6 +183,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:


class OptimumGQAQKVColumnParallelLinear(GQAQKVColumnParallelLinear):
"""
Same as GQAQKVColumnParallelLinear with the needed metadata for `optimum-neuron`.
"""

def __init__(
self,
query_proj_name: str,
@@ -238,6 +247,10 @@ def get_parameter_names_mapping(
def get_parameter_names_mapping_after_gqa_qkv_replacement(
model: torch.nn.Module, reversed: bool = False
) -> Dict[str, str]:
"""
Returns the mapping between the original projection names and their names after replacing them with
GQAQKVColumnParallelLinear.
"""
from neuronx_distributed.pipeline import NxDPPModel

mapping = {}
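For illustration, the mapping relates each original projection parameter to its counterpart on the fused layer. The names below are hypothetical and depend on the model and on the parameter names used by GQAQKVColumnParallelLinear:

# Hypothetical example of the kind of mapping returned:
mapping = {
    "model.layers.0.self_attn.q_proj.weight": "model.layers.0.self_attn.qkv_proj.weight_q",
    "model.layers.0.self_attn.k_proj.weight": "model.layers.0.self_attn.qkv_proj.weight_k",
    "model.layers.0.self_attn.v_proj.weight": "model.layers.0.self_attn.qkv_proj.weight_v",
}
# As the argument name suggests, reversed=True presumably returns the inverse mapping.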
@@ -254,6 +267,10 @@ def get_parameter_names_mapping_after_gqa_qkv_replacement(

@requires_neuronx_distributed
def get_output_projection_qualified_names_after_qga_qkv_replacement(model: torch.nn.Module) -> Set[str]:
"""
Returns the qualified names of the output projections inside the attention layers; these are needed when using
GQAQKVColumnParallelLinear.
"""
from neuronx_distributed.pipeline import NxDPPModel

qualified_names = set()
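Likewise, purely for illustration (module names are hypothetical), the returned set holds the qualified names of the attention output projections:

# Hypothetical contents for a decoder whose output projections are named "o_proj":
qualified_names = {
    "model.layers.0.self_attn.o_proj",
    "model.layers.1.self_attn.o_proj",
}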
@@ -430,12 +447,14 @@ def embedding_to_parallel_embedding(


def get_linear_weight_info(
weight_map: Dict[str, Union[Path, str]],
weight_map: Optional[Dict[str, Union[Path, str]]],
linear_layer_qualified_name: str,
device: Optional[torch.device] = None,
fail_if_not_found: bool = True,
) -> Tuple[Optional[WeightInformation], Optional[WeightInformation]]:
linear_layer_weight_qualified_name = f"{linear_layer_qualified_name}.weight"
if weight_map is None:
weight_map = {}
if linear_layer_weight_qualified_name not in weight_map:
if fail_if_not_found:
raise ValueError(
@@ -470,6 +489,11 @@ def create_kv_proj_local_weight_from_regular_weight(
def create_kv_proj_local_weight_from_regular_weight(
weight_data: torch.Tensor, kv_size_multiplier: int, output_size_per_partition: int
) -> torch.Tensor:
"""
Creates the local version of the key or value projection weight for the given TP rank when using
GQAQKVColumnParallelLinear.
"""
assert not isinstance(weight_data, torch.nn.Parameter)
from neuronx_distributed.parallel_layers.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_size,
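A rough sketch of the replicate-then-shard step described in the docstring (illustrative only; it assumes the key/value weight is laid out as [num_kv_heads * head_dim, hidden_size] and is not the library code):

import torch

def kv_local_weight_sketch(weight: torch.Tensor, kv_size_multiplier: int,
                           output_size_per_partition: int, tp_rank: int) -> torch.Tensor:
    # Replicate the key/value rows so that every TP rank can hold full heads,
    # then keep this rank's contiguous slice.
    repeated = weight.repeat(kv_size_multiplier, 1)
    start = tp_rank * output_size_per_partition
    return repeated[start : start + output_size_per_partition]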
@@ -485,6 +509,9 @@ def compute_query_indicies_for_rank(
def compute_query_indicies_for_rank(
tp_size: int, tp_rank: int, num_attention_heads: int, num_key_value_heads: int, kv_size_multiplier: int
):
"""
Computes the permutation for the query weight when using GQAQKVColumnParallelLinear.
"""
num_attention_heads_per_rank = num_attention_heads // tp_size
num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size
query_group_size = num_attention_heads // num_key_value_heads
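A quick numeric illustration of the quantities computed above, with values chosen for illustration: 8 TP ranks, 32 attention heads, 2 key/value heads, and a kv_size_multiplier of 4. Each rank then owns 4 query heads and 1 replicated key/value head, and each original key/value head serves a group of 16 queries; the permutation reorders the query heads so that each rank receives the queries belonging to the key/value head replicated on that rank.

tp_size, num_attention_heads, num_key_value_heads, kv_size_multiplier = 8, 32, 2, 4
num_attention_heads_per_rank = num_attention_heads // tp_size                         # 4
num_key_value_heads_per_rank = (num_key_value_heads * kv_size_multiplier) // tp_size  # 1
query_group_size = num_attention_heads // num_key_value_heads                         # 16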
@@ -526,7 +553,12 @@ def create_query_or_output_projection_local_weight_from_regular_weight(
kv_size_multiplier: int,
query_or_output_proj: Union[Literal["query"], Literal["output"]],
) -> torch.Tensor:
"""
Creates the local version of the query or output projection weight for the given TP rank when using
GQAQKVColumnParallelLinear.
"""
assert query_or_output_proj in ["query", "output"]
assert not isinstance(weight_data, torch.nn.Parameter)

from neuronx_distributed.parallel_layers.parallel_state import (
get_tensor_model_parallel_rank,
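For intuition only, a simplified, hypothetical helper covering the query case; the actual function also handles the output projection, whose heads live along the column dimension of the weight:

import torch

def query_local_weight_sketch(weight: torch.Tensor, head_dim: int,
                              head_permutation: torch.Tensor, tp_rank: int,
                              heads_per_rank: int) -> torch.Tensor:
    # View the weight as per-head blocks, reorder the heads with the permutation
    # computed by compute_query_indicies_for_rank, then keep this rank's block.
    num_heads = weight.shape[0] // head_dim
    per_head = weight.reshape(num_heads, head_dim, -1)
    reordered = per_head[head_permutation]
    start = tp_rank * heads_per_rank
    return reordered[start : start + heads_per_rank].reshape(heads_per_rank * head_dim, -1)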
@@ -565,7 +597,12 @@ def create_local_bias_from_regular_bias(
query_or_key_value_bias: Union[Literal["query"], Literal["key_value"]],
gather_output: bool,
) -> torch.Tensor:
"""
Creates the local version of the query, key, and value projection bias for the given TP rank when using
GQAQKVColumnParallelLinear.
"""
assert query_or_key_value_bias in ["query", "key_value"]
assert not isinstance(bias_weigth_data, torch.nn.Parameter)
from neuronx_distributed.parallel_layers.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_size,
@@ -703,7 +740,7 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(


def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
output_projection: "layers.ColumnParallelLinear",
output_projection: "layers.RowParallelLinear",
num_attention_heads: int,
num_key_value_heads: int,
kv_size_multiplier: int,
@@ -714,9 +751,10 @@ def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
try_from_original_layer: bool = False,
):
weight = output_projection.weight
bias = output_projection.bias
with torch.no_grad():
weight_data = None
if not was_already_initialized_during_parallelization(weight):
weight_data = None
if try_from_checkpoint and linear_layer_weight_info is not None:
weight_data = load_tensor_for_weight(linear_layer_weight_info)
elif (
@@ -733,7 +771,19 @@ def maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear(
mark_parameter_init_status_during_parallelization(weight, True)
else:
mark_parameter_init_status_during_parallelization(weight, False)
# TODO: bias
if bias is not None and not was_already_initialized_during_parallelization(bias):
bias_weight_data = None
if linear_layer_bias_weight_info is not None:
bias_weight_data = load_tensor_for_weight(linear_layer_bias_weight_info)
elif original_output_projection is not None and original_output_projection.bias.device != torch.device(
"meta"
):
bias_weight_data = original_output_projection.bias.data
if bias_weight_data is not None:
output_projection.bias.copy_(bias_weight_data)
mark_parameter_init_status_during_parallelization(output_projection.bias, True)
else:
mark_parameter_init_status_during_parallelization(output_projection.bias, False)


@requires_neuronx_distributed
