Cleanup optimum/neuron/distributed
michaelbenayoun committed Apr 8, 2024
1 parent 39c8372 commit 73aa74f
Showing 2 changed files with 15 additions and 119 deletions.
14 changes: 1 addition & 13 deletions optimum/neuron/distributed/base.py
@@ -73,19 +73,6 @@
logger = logging.get_logger()


class SavedModelInTemporaryDirectory:
def __init__(self, model: "PreTrainedModel"):
self.tmpdir = TemporaryDirectory()
self.model = model

def __enter__(self):
self.model.save_pretrained(self.tmpdir.name)
return self.tmpdir.name

def __exit__(self, *exc):
self.tmpdir.cleanup()


class SequenceParallelismSpecs:
SEQUENCE_PARALLEL_LAYERNORM_PATTERNS: Optional[List[str]] = None
LAYERNORM_TYPE: LayerNormType = LayerNormType.REGULAR
@@ -280,6 +267,7 @@ def _parallelize(
Returns:
`PreTrainedModel`: The parallelized model.
"""
pass

@classmethod
@requires_neuronx_distributed
120 changes: 14 additions & 106 deletions optimum/neuron/distributed/utils.py
@@ -22,6 +22,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Set, Tuple, Type, Union

import torch
@@ -125,27 +126,6 @@ def __post_init__(self):
self.qualified_name = self.qualified_name[len(prefix) :]


@dataclass
class GroupedQueryAttentionInfo:
"""
Describes the Grouped Query Attention configuration of a layer.
Attributes:
- num_attention_heads (`int`) -- The number of query heads in the layer.
- num_key_value_heads (`int`) -- The number of key value heads in the layer.
"""

num_attention_heads: int
num_key_value_heads: int

def __post_init__(self):
if self.num_attention_heads % self.num_key_value_heads != 0:
raise ValueError(
f"The number of key value heads ({self.num_key_value_heads}) does not divide the number of query heads"
f"({self.num_attention_heads})"
)
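
For context, an illustrative use of this dataclass's divisibility check (the head counts are made up; this snippet is not part of the commit):

# 32 query heads grouped over 8 key/value heads: 32 % 8 == 0, so the check passes.
gqa_info = GroupedQueryAttentionInfo(num_attention_heads=32, num_key_value_heads=8)

# 32 query heads over 5 key/value heads does not divide evenly, so __post_init__ raises a ValueError.
GroupedQueryAttentionInfo(num_attention_heads=32, num_key_value_heads=5)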


class FakeProj(torch.nn.Module):
"""
Dummy layer that replaces a Linear projection by gathering the result from its associated merged
@@ -995,91 +975,6 @@ def linear_to_parallel_linear(
return parallel_linear_layer


@requires_neuronx_distributed
def gqa_key_value_slicing_when_tp_size_greater_than_num_key_value_heads(
gqa_info: GroupedQueryAttentionInfo,
linear_layer: "torch.nn.Linear",
linear_layer_weight_info: Optional[WeightInformation] = None,
linear_layer_bias_weight_info: Optional[WeightInformation] = None,
device: Optional["torch.device"] = None,
) -> "torch.nn.Linear":
"""
Helper function that splits the key and value projections when performing Grouped Query Attention and the TP size
is greater than or equal to the number of key value heads.
Args:
gqa_info (`GroupedQueryAttentionInfo`):
The dataclass containing the information related to Grouped Query Attention.
linear_layer (`torch.nn.Linear`):
The linear layer to split.
linear_layer_weight_info (`Optional[WeightInformation]`, defaults to `None`):
Information about which checkpoint file the linear layer weights are stored in.
linear_layer_bias_weight_info (`Optional[WeightInformation]`, defaults to `None`):
Information about which checkpoint file the linear layer bias is stored in.
device (`Optional[torch.device]`, defaults to `None`):
The device where the new split layer should be put.
Returns:
`torch.nn.Linear`: The split linear layer.
"""
from neuronx_distributed.parallel_layers.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_size,
)

tp_size = get_tensor_model_parallel_size()
tp_rank = get_tensor_model_parallel_rank()
if tp_size < gqa_info.num_key_value_heads:
raise ValueError(
f"This function can only be used in the case where the TP size ({tp_size}) is smalled than thue number of "
f"key value heads ({gqa_info.num_key_value_heads})."
)
num_key_value_heads_x_head_dim, hidden_size = linear_layer.weight.shape
head_dim = num_key_value_heads_x_head_dim // gqa_info.num_key_value_heads
if device is None:
device = linear_layer.weight.device
sliced_linear_layer = torch.nn.Linear(
hidden_size, head_dim, device=device, dtype=linear_layer.weight.dtype, bias=linear_layer.bias is not None
)
key_value_head_index = gqa_info.num_key_value_heads * tp_rank // tp_size
with torch.no_grad():
if linear_layer_weight_info is not None:
weight_data = load_tensor_for_weight(
linear_layer_weight_info,
tensor_slices=(
(key_value_head_index * head_dim, (key_value_head_index + 1) * head_dim),
None,
),
)
sliced_linear_layer.weight.copy_(weight_data)
mark_parameter_init_status_during_parallelization(sliced_linear_layer.weight, True)

elif linear_layer.weight.device != torch.device("meta"):
sliced_linear_layer.weight.copy_(
linear_layer.weight[key_value_head_index * head_dim : (key_value_head_index + 1) * head_dim, :]
)
mark_parameter_init_status_during_parallelization(sliced_linear_layer.weight, True)
else:
mark_parameter_init_status_during_parallelization(sliced_linear_layer.weight, False)

if linear_layer.bias is not None:
if linear_layer_bias_weight_info is not None:
bias_weight_data = load_tensor_for_weight(
linear_layer_bias_weight_info,
tensor_slices=((key_value_head_index * head_dim, (key_value_head_index + 1) * head_dim),),
)
sliced_linear_layer.bias.copy_(bias_weight_data)
mark_parameter_init_status_during_parallelization(sliced_linear_layer.bias, True)
elif sliced_linear_layer.bias.device != torch.device("meta"):
sliced_linear_layer.bias.copy_(
linear_layer.bias[key_value_head_index * head_dim : (key_value_head_index + 1) * head_dim]
)
mark_parameter_init_status_during_parallelization(sliced_linear_layer.bias, True)
else:
mark_parameter_init_status_during_parallelization(sliced_linear_layer.bias, False)
return sliced_linear_layer
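
As an aside, a worked illustration of the `key_value_head_index` arithmetic used above (the sizes are made up; this snippet is not part of the commit):

num_key_value_heads, tp_size = 8, 32  # illustrative sizes only
for tp_rank in range(tp_size):
    key_value_head_index = num_key_value_heads * tp_rank // tp_size
    # Ranks 0-3 map to head 0, ranks 4-7 to head 1, ..., ranks 28-31 to head 7:
    # each key/value head is shared by tp_size // num_key_value_heads = 4 consecutive ranks.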


@requires_neuronx_distributed
def delete_tensor_model_parallel_attributes(tensor: torch.Tensor):
from neuronx_distributed.parallel_layers.utils import _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS
@@ -1496,3 +1391,16 @@ def is_sharded(self):
class OptimumNeuronFXTracer(HFTracerWrapper):
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
return super().is_leaf_module(m, module_qualified_name) or isinstance(m, FakeProj)


class SavedModelInTemporaryDirectory:
def __init__(self, model: "PreTrainedModel"):
self.tmpdir = TemporaryDirectory()
self.model = model

def __enter__(self):
self.model.save_pretrained(self.tmpdir.name)
return self.tmpdir.name

def __exit__(self, *exc):
self.tmpdir.cleanup()
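
For reference, a minimal usage sketch of this context manager (illustrative, not part of the commit); `model` is assumed to be an already loaded `PreTrainedModel`:

import os

# `model` is assumed to be a loaded PreTrainedModel instance.
with SavedModelInTemporaryDirectory(model) as tmp_path:
    # Inside the block, the serialized checkpoint lives under `tmp_path`.
    print(os.listdir(tmp_path))
# On exit, the temporary directory and its contents are deleted.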
