WIP works for queries
michaelbenayoun committed Mar 7, 2024
1 parent 83f27af commit 52bee5d
Showing 2 changed files with 87 additions and 10 deletions.
2 changes: 2 additions & 0 deletions optimum/neuron/distributed/parallel_layers.py
@@ -352,6 +352,7 @@ def replace_qkv_by_gqa_qkv_column_parallel_linear(
"GQAQKVColumnParallelLinear is not needed."
)

num_attention_heads = getattr(attention_layer, cls.NUM_ATTENTION_HEADS_NAME)
query_linear = getattr(attention_layer, cls.QUERIES_NAME)
key_linear = getattr(attention_layer, cls.KEYS_NAME)

@@ -370,6 +371,7 @@ def replace_qkv_by_gqa_qkv_column_parallel_linear(
cls.QUERIES_NAME,
cls.KEYS_NAME,
cls.VALUES_NAME,
num_attention_heads,
num_key_value_heads,
hidden_size,
[query_in_features, key_value_in_features],
95 changes: 85 additions & 10 deletions optimum/neuron/distributed/utils.py
@@ -203,6 +203,8 @@ def __init__(
query_proj_name: str,
key_proj_name: str,
value_proj_name: str,
output_proj_name: str,
num_attention_heads: int,
num_key_value_heads: int,
input_size: int,
output_sizes: int,
@@ -231,8 +233,10 @@ def __init__(
self.query_proj_name = query_proj_name
self.key_proj_name = key_proj_name
self.value_proj_name = value_proj_name
self.output_proj_name = output_proj_name

self._qkv_proj_name_to_proj_name = {"q": query_proj_name, "k": key_proj_name, "v": value_proj_name}
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads

def get_parameter_names_mapping(
@@ -464,7 +468,7 @@ def get_linear_weight_info(


@requires_neuronx_distributed
def create_kv_proj_local_weight_from_regular_weight(weight: torch.nn.Parameter, kv_size_multiplier: int, output_size_per_partition: int) -> torch.Tensor:
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_size

tp_size = get_tensor_model_parallel_size()
@@ -477,6 +481,31 @@ def create_kv_proj_local_weight_from_regular_weight(weight: torch.nn.Parameter,
split = torch.split(repeated_weight, output_size_per_partition, dim=0)
return torch.cat(split[tp_rank::tp_size], dim=0)

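For context (not part of the commit): the helper above replicates each key/value shard across the tensor-parallel ranks. The standalone sketch below emulates that row selection on CPU for a toy configuration; the sizes and the `repeated_weight = weight.repeat(kv_size_multiplier, 1)` construction are assumptions, since that line is collapsed in the hunk above.

import torch

def toy_kv_local_weight(weight, kv_size_multiplier, output_size_per_partition, tp_rank, tp_size):
    # Assumed construction of `repeated_weight`: the full checkpoint weight tiled along dim 0.
    repeated_weight = weight.repeat(kv_size_multiplier, 1)
    split = torch.split(repeated_weight, output_size_per_partition, dim=0)
    return torch.cat(split[tp_rank::tp_size], dim=0)

# Toy setup: 2 KV heads of 2 rows each, hidden size 4, TP size 4, so kv_size_multiplier = 2.
weight = torch.arange(16, dtype=torch.float32).view(4, 4)
for rank in range(4):
    local = toy_kv_local_weight(weight, kv_size_multiplier=2, output_size_per_partition=2, tp_rank=rank, tp_size=4)
    print(rank, local[:, 0].tolist())  # ranks 0 and 2 share KV head 0, ranks 1 and 3 share KV head 1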

def create_q_or_o_proj_local_weight_from_regular_weight(
    weight: torch.nn.Parameter,
    num_attention_heads: int,
    num_key_value_heads: int,
    kv_size_multiplier: int,
    query_or_output_proj: Union[Literal["query"], Literal["output"]],
) -> torch.Tensor:
    # NOTE: `query_or_output_proj` is accepted but not used yet; only the query path is exercised in this WIP.
    from neuronx_distributed.parallel_layers.parallel_state import (
        get_tensor_model_parallel_rank,
        get_tensor_model_parallel_size,
    )

    tp_size = get_tensor_model_parallel_size()
    tp_rank = get_tensor_model_parallel_rank()

    indices = torch.arange(num_attention_heads // tp_size)

    # Index of this rank within its group of kv_size_multiplier ranks; it selects which contiguous
    # block of attention heads this rank takes.
    key_index = tp_rank % kv_size_multiplier

    num_attention_heads_per_rank = num_attention_heads // tp_size
    num_ranks_per_group = (num_attention_heads // num_key_value_heads) // num_attention_heads_per_rank

    shift = key_index * num_attention_heads_per_rank * num_ranks_per_group

    queries_indices = indices + shift

    # Select the heads assigned to this rank and flatten them back into a 2D weight.
    reshaped_weight = weight.view(num_attention_heads, -1, weight.size(-1))
    queries = reshaped_weight[queries_indices]
    return queries.reshape(-1, weight.size(-1))

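To make the index arithmetic above concrete, here is a standalone sketch (not part of the commit) that evaluates the same computation for an illustrative configuration (8 attention heads, 2 key/value heads, TP size 4, kv_size_multiplier 2); the numbers are assumptions, and since the commit is marked WIP the mapping itself may still evolve.

import torch

def toy_query_head_indices(tp_rank, tp_size, num_attention_heads, num_key_value_heads, kv_size_multiplier):
    # Same arithmetic as create_q_or_o_proj_local_weight_from_regular_weight, without the weight slicing.
    indices = torch.arange(num_attention_heads // tp_size)
    key_index = tp_rank % kv_size_multiplier
    num_attention_heads_per_rank = num_attention_heads // tp_size
    num_ranks_per_group = (num_attention_heads // num_key_value_heads) // num_attention_heads_per_rank
    shift = key_index * num_attention_heads_per_rank * num_ranks_per_group
    return (indices + shift).tolist()

for rank in range(4):
    print(rank, toy_query_head_indices(rank, tp_size=4, num_attention_heads=8, num_key_value_heads=2, kv_size_multiplier=2))
# Prints: 0 [0, 1] / 1 [4, 5] / 2 [0, 1] / 3 [4, 5]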

@requires_neuronx_distributed
def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
layer: OptimumGQAQKVColumnParallelLinear,
@@ -505,9 +534,8 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
bias = getattr(layer, f"bias_{proj_name}")
row_size, _ = weight.shape


num_attention_heads = layer.num_attention_heads
num_key_value_heads = layer.num_key_value_heads
kv_size_multiplier = layer.kv_size_multiplier if proj_name in "kv" else 1

with torch.no_grad():
@@ -532,9 +560,10 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
linear_layer_weight_info,
tensor_slices=tensor_slices
)
if proj_name in "kv":
    weight_data = create_kv_proj_local_weight_from_regular_weight(weight, kv_size_multiplier, weight.size(0))
else:
    weight_data = create_q_or_o_proj_local_weight_from_regular_weight(weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query")
weight.copy_(weight_data)
mark_parameter_init_status_during_parallelization(weight, True)
del weight_data
@@ -543,12 +572,16 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
# weight_data = create_kv_proj_local_weight_from_regular_weight(linear_layer.weight, layer.kv_size_multiplier, weight.size(0))
# else:
# weight_data = linear_layer.weight[tp_rank * row_size : (tp_rank + 1) * row_size, :]
if proj_name in "kv":
    weight_data = create_kv_proj_local_weight_from_regular_weight(linear_layer.weight, kv_size_multiplier, weight.size(0))
else:
    weight_data = create_q_or_o_proj_local_weight_from_regular_weight(linear_layer.weight, num_attention_heads, num_key_value_heads, kv_size_multiplier, "query")
weight.copy_(weight_data)
mark_parameter_init_status_during_parallelization(weight, True)
else:
mark_parameter_init_status_during_parallelization(weight, False)

# TODO: add support for bias.
if bias is not None:
if not was_already_initialized_during_parallelization(bias):
if linear_layer_bias_weight_info is not None:
@@ -578,6 +611,32 @@ def maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
mark_parameter_init_status_during_parallelization(bias, False)


def maybe_load_weights_to_output_proj_when_using_gqa_qkv_column_parallel_linear(
    layer: torch.nn.Module,
    associated_optimum_gqa_qkv_column_parallel_linear: OptimumGQAQKVColumnParallelLinear,
    linear_layer_weight_info: Optional[WeightInformation] = None,
    linear_layer_bias_weight_info: Optional[WeightInformation] = None,
    linear_layer: Optional["torch.nn.Linear"] = None,
):
    num_attention_heads = associated_optimum_gqa_qkv_column_parallel_linear.num_attention_heads
    num_key_value_heads = associated_optimum_gqa_qkv_column_parallel_linear.num_key_value_heads
    kv_size_multiplier = associated_optimum_gqa_qkv_column_parallel_linear.kv_size_multiplier
    weight = layer.weight
    with torch.no_grad():
        if not was_already_initialized_during_parallelization(weight):
            weight_data = None
            if linear_layer_weight_info is not None:
                weight_data = load_tensor_for_weight(linear_layer_weight_info)
            elif linear_layer is not None and linear_layer.weight.device != torch.device("meta"):
                weight_data = linear_layer.weight
            if weight_data is not None:
                # The output projection must follow the same head permutation as the queries.
                weight_data = create_q_or_o_proj_local_weight_from_regular_weight(
                    weight_data, num_attention_heads, num_key_value_heads, kv_size_multiplier, "output"
                )
                weight.copy_(weight_data)
                mark_parameter_init_status_during_parallelization(weight, True)
            else:
                mark_parameter_init_status_during_parallelization(weight, False)
    # TODO: add support for bias.

def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
model: torch.nn.Module,
layer: OptimumGQAQKVColumnParallelLinear,
@@ -589,12 +648,10 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
original_to_gqa = layer.get_parameter_names_mapping(named_modules)

for orig_name, gqa_name in original_to_gqa.items():
linear_layer_qualified_name, _ = orig_name.rsplit(".", maxsplit=1)
linear_weight_info, linear_bias_weight_info = get_linear_weight_info(
weight_map, linear_layer_qualified_name, fail_if_not_found=False
)
weight_name = gqa_name.split(".")[-1]
if try_from_checkpoint and linear_weight_info is not None:
maybe_load_linear_weight_to_gqa_qkv_column_parallel_linear(
@@ -611,6 +668,24 @@ def maybe_load_weights_to_gqa_qkv_column_parallel_linear(
linear_layer=model.get_submodule(orig_layer_name),
)

# If we were able to initialize the query projection, we assume the output projection should be
# initialized as well (a standalone naming sketch follows this block).
if weight_name == "weight_q" and was_already_initialized_during_parallelization(getattr(layer, weight_name)):
parent_qualified_name = named_modules[layer].rsplit(".", maxsplit=1)[0]
output_projection_qualified_name = f"{parent_qualified_name}.{layer.output_proj_name}"
output_projection = model.get_submodule(output_projection_qualified_name)
linear_weight_info, linear_bias_weight_info = get_linear_weight_info(
weight_map, output_projection_qualified_name, fail_if_not_found=False
)
if try_from_checkpoint and linear_weight_info is not None:
maybe_load_weights_to_output_proj_when_using_gqa_qkv_column_parallel_linear(
output_projection,
layer,
linear_layer_weight_info=linear_weight_info,
linear_layer_bias_weight_info=linear_bias_weight_info,
)
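For illustration only (not part of the commit): the block above locates the output projection that sits next to the GQA QKV layer by rebuilding its qualified name. The names below ("model.layers.0.self_attn.qkv_proj" and an output_proj_name of "o_proj", as in Llama-style models) are assumptions.

# Hypothetical qualified name of the OptimumGQAQKVColumnParallelLinear module inside the model.
gqa_qkv_qualified_name = "model.layers.0.self_attn.qkv_proj"
parent_qualified_name = gqa_qkv_qualified_name.rsplit(".", maxsplit=1)[0]
output_projection_qualified_name = f"{parent_qualified_name}.o_proj"  # assumes output_proj_name == "o_proj"
print(output_projection_qualified_name)  # model.layers.0.self_attn.o_proj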



@requires_neuronx_distributed
def maybe_load_linear_weight_to_parallel_linear(