Commit 3369874

[WIP] sequence parallelism

michaelbenayoun committed Sep 20, 2023
1 parent 0b06bed commit 3369874
Showing 6 changed files with 338 additions and 44 deletions.
15 changes: 10 additions & 5 deletions optimum/neuron/distributed/base.py
@@ -27,7 +27,7 @@

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
from .parallel_layers import IOSequenceParallelizer, LayerNormSequenceParallelizer, LayerNormType
from .parallel_layers import IOSequenceParallelizer, LayerNormSequenceParallelizer, LayerNormType, SequenceCollectiveOpInfo
from .utils import TENSOR_PARALLEL_SHARDS_DIR_NAME, ParameterMetadata, WeightInformation, load_tensor_for_weight


@@ -67,9 +67,10 @@ class Parallelizer(ABC):
SEQUENCE_PARALLEL_LAYERNORM_PATTERNS: Optional[List[str]] = None
LAYERNORM_TYPE: LayerNormType = LayerNormType.REGULAR
SCATTER_SEQUENCE_AT_FIRST_LAYER_OF_TYPE: Optional[Type["torch.nn.Module"]] = None
SCATTER_BEFORE_FIRST_LAYER: bool = True
GATHER_SEQUENCE_AT_LAST_LAYER_OF_TYPE: Optional[Type["torch.nn.Module"]] = None
PARALLELIZE_INPUT_OF_FIRST_LAYERNORM: bool = False
GATHER_OUTPUT_OF_LAST_LAYERNORM: bool = False
GATHER_AFTER_LAST_LAYER: bool = True
SEQUENCE_COLLECTIVE_OPS_INFOS: Optional[List[SequenceCollectiveOpInfo]] = None

def __init__(self):
self._validate_required_libaries_are_available()
@@ -173,8 +174,12 @@ def parallelize(

# 2. Taking care of scattering / gathering on the sequence axis in the model via the IOSequenceParallelizer.
io_sequence_parallelizer = IOSequenceParallelizer(
sequence_scatter_at_first_layer_of_type=cls.SCATTER_SEQUENCE_AT_FIRST_LAYER_OF_TYPE,
sequence_gather_at_last_layer_of_type=cls.GATHER_SEQUENCE_AT_LAST_LAYER_OF_TYPE,
sequence_parallel_enabled,
sequence_collective_op_infos=cls.SEQUENCE_COLLECTIVE_OPS_INFOS,
# scatter_sequence_at_first_layer_of_type=cls.SCATTER_SEQUENCE_AT_FIRST_LAYER_OF_TYPE,
# scatter_before_first_layer=cls.SCATTER_BEFORE_FIRST_LAYER,
# gather_sequence_at_last_layer_of_type=cls.GATHER_SEQUENCE_AT_LAST_LAYER_OF_TYPE,
# gather_after_last_layer=cls.GATHER_AFTER_LAST_LAYER,
)
io_sequence_parallelizer.sequence_parallelize(model)

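The hunk above replaces the per-layer scatter/gather class attributes with a single SEQUENCE_COLLECTIVE_OPS_INFOS list that the IOSequenceParallelizer consumes to scatter and gather activations along the sequence axis. As a rough single-process sketch of what that scattering and gathering mean (not the neuronx_distributed implementation; tp_rank and tp_size are hypothetical stand-ins for the tensor-parallel rank and world size):

import torch

def scatter_sequence(hidden_states: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Keep only this rank's shard of the sequence dimension (dim 0 in the [S, B, H] layout).
    return torch.chunk(hidden_states, tp_size, dim=0)[tp_rank].contiguous()

def gather_sequence(shards: list) -> torch.Tensor:
    # Reassemble the full sequence from every rank's shard (stand-in for an all-gather).
    return torch.cat(shards, dim=0)

# Toy check with tp_size=2: an [8, 2, 4] activation splits into two [4, 2, 4] shards and back.
full = torch.randn(8, 2, 4)
shards = [scatter_sequence(full, rank, 2) for rank in range(2)]
assert torch.equal(gather_sequence(shards), full)

The commented-out keyword arguments in the diff suggest these collectives are attached around a first layer of a given type and a last layer of a given type; the sketch only illustrates the data movement itself.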
4 changes: 2 additions & 2 deletions optimum/neuron/distributed/decoder_models.py
@@ -78,12 +78,12 @@ def _split_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.view(new_shape)
if sequence_parallel_enabled:
# [S, B, num_heads, head_dim] -> [B, num_heads, S, head_dim]
return tensor.permute(1, 2, 0, 3).contiguous()
return tensor.permute(1, 2, 0, 3)
return tensor.permute(0, 2, 1, 3)

def _merge_heads(self, tensor, num_heads, attn_head_size):
if sequence_parallel_enabled:
# [B, num_heads, S, head_dim] -> [S, B, num_heads, hidden_dim]
# [B, num_heads, S, head_dim] -> [S, B, num_heads, head_dim]
tensor = tensor.permute(2, 0, 1, 3).contiguous()
else:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
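The _split_heads/_merge_heads change above switches to a sequence-first ([S, B, hidden]) layout when sequence parallelism is enabled and returns the permuted view without forcing a copy via .contiguous(). A small standalone check of that layout round-trip, with arbitrary toy shapes rather than the model's real ones:

import torch

S, B, num_heads, head_dim = 6, 2, 4, 8
hidden = torch.randn(S, B, num_heads * head_dim)  # sequence-first activation

# _split_heads equivalent: [S, B, H*D] -> [S, B, H, D] -> [B, H, S, D]
heads = hidden.view(S, B, num_heads, head_dim).permute(1, 2, 0, 3)

# _merge_heads equivalent: [B, H, S, D] -> [S, B, H, D] -> [S, B, H*D]
merged = heads.permute(2, 0, 1, 3).contiguous().view(S, B, num_heads * head_dim)

assert torch.equal(merged, hidden)  # the two permutations are inverses, so the round-trip recovers the input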
192 changes: 181 additions & 11 deletions optimum/neuron/distributed/encoder_decoder_models.py
@@ -17,11 +17,11 @@
from typing import TYPE_CHECKING, Dict, Optional

import torch
from transformers import T5ForSequenceClassification
from transformers.models.t5.modeling_t5 import T5ForSequenceClassification, T5Attention, T5LayerNorm

from ...utils import NormalizedConfigManager
from .base import Parallelizer
from .parallel_layers import ParallelCrossEntropy, ParallelEmbedding, ParallelMLP, ParallelSelfAttention
from .parallel_layers import LayerNormType, ParallelCrossEntropy, ParallelEmbedding, ParallelMLP, ParallelSelfAttention, SequenceCollectiveOpInfo
from .utils import linear_to_parallel_linear


@@ -47,8 +47,8 @@ def transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]] = None,
device: Optional["torch.device"] = None,
sequence_parallel_enabled: bool = False,
) -> "torch.nn.Module":
from neuronx_distributed.parallel_layers.parallel_state import (
get_tensor_model_parallel_rank,
@@ -77,7 +77,7 @@ def transform(
layer.relative_attention_bias.num_embeddings = num_attention_heads_per_rank
set_tensor_model_parallel_attributes(layer.relative_attention_bias.weight, True, 1, stride=1)

layer = super().transform(model, layer, orig_to_parallel=orig_to_parallel, device=device)
layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)

return layer

@@ -91,7 +91,7 @@ def transform(
cls,
model: "PreTrainedModel",
layer: "torch.nn.Module",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]] = None,
sequence_parallel_enabled: bool = False,
device: Optional["torch.device"] = None,
) -> "torch.nn.Module":
from transformers.models.t5.modeling_t5 import T5DenseGatedActDense
@@ -106,7 +106,7 @@ def transform(
cls.FIRST_LINEAR_NAME = f"{orig_first_linear_name}_0"

# This will parallelize both wi_0 and wo.
layer = super().transform(model, layer, orig_to_parallel=orig_to_parallel, device=device)
layer = super().transform(model, layer, sequence_parallel_enabled=sequence_parallel_enabled, device=device)

if isinstance(layer, T5DenseGatedActDense):
# In this case, only wi_1 remains to be parallelized, we do it here.
@@ -133,7 +133,7 @@ def transform(
gather_output=False,
linear_layer_weight_info=linear_layer_weight_info,
linear_layer_bias_weight_info=linear_layer_bias_weight_info,
orig_to_parallel=orig_to_parallel,
sequence_parallel_enabled=sequence_parallel_enabled,
device=device,
),
)
@@ -148,13 +148,181 @@ class T5ParallelCrossEntropy(ParallelCrossEntropy):


class T5Parallelizer(Parallelizer):
SEQUENCE_PARALLEL_LAYERNORM_PATTERNS = [
"encoder.block.[0-9]+.layer.[0-9]+.layer_norm",
"encoder.final_layer_norm",
"decoder.block.[0-9]+.layer.[0-9]+.layer_norm",
"decoder.final_layer_norm",
]

LAYERNORM_TYPE = LayerNormType.RMS_NORM
SEQUENCE_COLLECTIVE_OPS_INFOS = [
SequenceCollectiveOpInfo(torch.nn.Embedding, "first", "

]
# SCATTER_SEQUENCE_AT_FIRST_LAYER_OF_TYPE = torch.nn.Embedding
# # Scattering needs to happen before the embeddings computations.
# SCATTER_BEFORE_FIRST_LAYER = False
# GATHER_SEQUENCE_AT_LAST_LAYER_OF_TYPE = T5LayerNorm

@classmethod
def patch_for_sequence_paralelism(cls, model: "PreTrainedModel", sequence_parallel_enabled: bool):
from torch import nn

def sequence_parallel_forward(
self,
hidden_states,
mask=None,
key_value_states=None,
position_bias=None,
past_key_value=None,
layer_head_mask=None,
query_length=None,
use_cache=False,
output_attentions=False,
):
# Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
if sequence_parallel_enabled:
batch_size = hidden_states.shape[1]
else:
batch_size = hidden_states.shape[0]


def shape(states):
"""projection"""
if sequence_parallel_enabled:
return states.view(-1, batch_size, self.n_heads, self.key_value_proj_dim).permute(1, 2, 0, 3)
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

def unshape(states):
"""reshape"""
if sequence_parallel_enabled:
return states.permute(2, 0, 1, 3).view(-1, batch_size, self.inner_dim)
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))

if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
if sequence_parallel_enabled:
hidden_states = torch.cat([past_key_value, hidden_states.transpose(0, 1)], dim=2)
else:
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else:
# cross-attn
hidden_states = past_key_value
return hidden_states

# get query states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)

# get key/value states
key_states = project(
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

real_seq_length = key_states.shape[2]

if past_key_value is not None:
if len(past_key_value) != 2:
raise ValueError(
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
)
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

# if sequence_parallel_enabled:
# key_length = real_seq_length if key_value_states is None else key_value_states.shape[0]
# else:
print("Key value states shape", key_value_states.shape if key_value_states is not None else None, key_states.shape)
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

# compute scores
scores = torch.matmul(
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

if position_bias is None:
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
)
if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True
else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)

# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

if mask is not None:
print(position_bias.shape, mask.shape)
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)

if self.pruned_heads:
mask = torch.ones(position_bias.shape[1])
mask[list(self.pruned_heads)] = 0
position_bias_masked = position_bias[:, mask.bool()]
else:
position_bias_masked = position_bias

scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores
) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)

# Mask heads if we want to
if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask

attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
attn_output = self.o(attn_output)

present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

if output_attentions:
outputs = outputs + (attn_weights,)
return outputs

for module in model.modules():
if isinstance(module, T5Attention):
module.forward = sequence_parallel_forward.__get__(module)


@classmethod
def _parallelize(
cls,
model: "PreTrainedModel",
orig_to_parallel: Optional[Dict[int, "torch.nn.Parameter"]],
device: Optional["torch.device"] = None,
parallelize_embeddings: bool = True,
sequence_parallel_enabled: bool = False,
) -> "PreTrainedModel":
if isinstance(model, T5ForSequenceClassification):
raise NotImplementedError(
@@ -172,6 +340,7 @@ def _parallelize(
block.layer[0].SelfAttention = T5ParallelSelfAttention.transform(
model,
block.layer[0].SelfAttention,
sequence_parallel_enabled=sequence_parallel_enabled,
device=device,
)
block.layer[1].DenseReluDense = T5ParallelMLP.transform(
@@ -181,11 +350,12 @@
block.layer[0].SelfAttention = T5ParallelSelfAttention.transform(
model,
block.layer[0].SelfAttention,
sequence_parallel_enabled=sequence_parallel_enabled,
device=device,
)
block.layer[2].DenseReluDense = T5ParallelMLP.transform(
model, block.layer[2].DenseReluDense, device=device
)
model, block.layer[2].DenseReluDense, sequence_parallel_enabled=sequence_parallel_enabled, device=device
)
if parallelize_embeddings:
model = T5ParallelCrossEntropy.transform(model, model, device=device)
model = T5ParallelCrossEntropy.transform(model, model, sequence_parallel_enabled=sequence_parallel_enabled, device=device)
return model
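patch_for_sequence_paralelism above swaps in the new forward by binding it to each T5Attention instance with __get__ rather than touching the class. A minimal illustration of that instance-level rebinding pattern (the Toy module is hypothetical, not part of the diff):

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # `self` is the module instance because __get__ returns a bound method.
    return x * 2

toy = Toy()
toy.forward = patched_forward.__get__(toy)  # rebinds forward on this instance only
print(toy(torch.tensor(3)))    # tensor(6): the patched forward runs
print(Toy()(torch.tensor(3)))  # tensor(4): other instances keep the class's forward

This mirrors the loop at the end of the method, which assigns sequence_parallel_forward.__get__(module) to every T5Attention found in model.modules().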
3 changes: 2 additions & 1 deletion optimum/neuron/distributed/encoder_models.py
@@ -16,14 +16,15 @@

from typing import TYPE_CHECKING, Optional

import torch

from ..utils.require_utils import requires_neuronx_distributed
from .base import Parallelizer
from .parallel_layers import ParallelCrossEntropy, ParallelEmbedding, ParallelSelfAttention, ParallelSelfOutput
from .utils import create_sequence_parallel_attention_forward


if TYPE_CHECKING:
import torch
from transformers import PreTrainedModel


