
Update Llama attention
michaelbenayoun committed Apr 16, 2024
1 parent 973a9a6 commit 17cf147
Showing 4 changed files with 25 additions and 51 deletions.
2 changes: 0 additions & 2 deletions optimum/neuron/__init__.py
@@ -100,7 +100,5 @@
)


import os

from .utils import is_neuron_available, is_neuronx_available, patch_transformers_for_neuron_sdk
from .version import __version__
50 changes: 11 additions & 39 deletions optimum/neuron/distributed/decoder_models.py
@@ -15,7 +15,7 @@
"""Classes related to `neuronx-distributed` to perform parallelism."""

import warnings
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Optional, Tuple

import torch
from transformers.cache_utils import Cache
@@ -455,6 +455,7 @@ def attention_forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if "padding_mask" in kwargs:
@@ -502,45 +503,29 @@ def attention_forward(
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
"The cache structure has changed since version `transformers v4.36. If you are using "
f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to "
"initialize the attention class with a layer index."
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_key_value = getattr(self, "past_key_value", past_key_value)
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)

# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)

if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
if attention_mask is not None: # no matter the length, we just slice it
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)

if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -581,19 +566,6 @@ class LlamaPipelineParallelismSpecs(PipelineParallelismSpecs):
DEFAULT_INPUT_NAMES = ("input_ids", "attention_mask", "labels")
LEAF_MODULE_CLASSES_NAMES = [LlamaRMSNorm]

@classmethod
def get_patching_specs(cls) -> List[Tuple[str, Any]]:
return []
# leaf_prepare_4d_causal_attention_mask = torch.fx._symbolic_trace._create_wrapped_func(
# _prepare_4d_causal_attention_mask
# )
# return [
# (
# "transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask",
# leaf_prepare_4d_causal_attention_mask,
# ),
# ]


class LlamaParallelizer(Parallelizer):
SEQUENCE_PARALLELSIM_SPECS_CLS = LlamaSequenceParallelismSpecs
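The decoder_models.py change tracks the newer transformers Llama attention path: the rotary embedding is computed from `position_ids` rather than a `seq_len`, `apply_rotary_pos_emb` no longer takes `position_ids`, the cache update receives `cache_position`, and the explicit `attn_weights`/`attention_mask` shape checks are replaced by slicing the causal mask to the key length. The sketch below is not the optimum-neuron code; it only illustrates that call pattern with transformers' `DynamicCache`, toy shapes, and the sin/cos cache kwargs omitted.

```python
import math

import torch
from transformers.cache_utils import DynamicCache

bsz, num_heads, head_dim = 1, 2, 8
past_len, q_len = 5, 1  # decode one new token on top of 5 cached ones
layer_idx = 0

cache = DynamicCache()
# Pretend 5 tokens were already processed for this layer.
cache.update(
    torch.randn(bsz, num_heads, past_len, head_dim),
    torch.randn(bsz, num_heads, past_len, head_dim),
    layer_idx,
    {},
)

query_states = torch.randn(bsz, num_heads, q_len, head_dim)
key_states = torch.randn(bsz, num_heads, q_len, head_dim)
value_states = torch.randn(bsz, num_heads, q_len, head_dim)

# cache_position marks where the new token(s) sit in the full sequence; the
# real attention code also passes the rotary sin/cos in these cache kwargs.
cache_position = torch.arange(past_len, past_len + q_len)
key_states, value_states = cache.update(
    key_states, value_states, layer_idx, {"cache_position": cache_position}
)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

# Slice the (possibly longer) causal mask to the actual key length instead of
# asserting exact shapes.
attention_mask = torch.zeros(bsz, 1, q_len, past_len + q_len)
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attn_weights = attn_weights + causal_mask

attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
assert attn_output.shape == (bsz, num_heads, q_len, head_dim)
```

The `repeat_kv` step for grouped-query attention is left out here because the toy query and key tensors already share the same number of heads.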
9 changes: 7 additions & 2 deletions optimum/neuron/distributed/utils.py
@@ -28,6 +28,7 @@
import torch
from transformers import PretrainedConfig
from transformers.utils import is_peft_available
from transformers.utils.fx import HFTracer

from ...utils import logging
from ..utils import DynamicPatch, Patcher
@@ -44,7 +45,7 @@
if is_neuronx_distributed_available():
from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
from neuronx_distributed.parallel_layers import layers
from neuronx_distributed.pipeline.trace import HFTracerWrapper
from neuronx_distributed.pipeline.trace import HFTracerWrapper, NxDTracer
else:

class GQAQKVColumnParallelLinear(torch.nn.Module):
@@ -1361,7 +1362,11 @@ def is_sharded(self):

class OptimumNeuronFXTracer(HFTracerWrapper):
def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
return super().is_leaf_module(m, module_qualified_name) or isinstance(m, FakeProj)
return (
NxDTracer.is_leaf_module(self, m, module_qualified_name)
or HFTracer.is_leaf_module(self, m, module_qualified_name)
or isinstance(m, FakeProj)
)


class SavedModelInTemporaryDirectory:
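The `is_leaf_module` override now calls `NxDTracer.is_leaf_module` and `HFTracer.is_leaf_module` explicitly instead of a single `super()` call, so a module counts as a leaf if either base tracer accepts it (plus the `FakeProj` case), independent of how the wrapper's method resolution order would route `super()`. A rough illustration of that pattern with stand-in tracer classes, not the real `NxDTracer`/`HFTracer`, is below.

```python
import torch
from torch.fx import Tracer


class KeepNormsAsLeaves(Tracer):
    """Stand-in for one base tracer: treats normalization layers as leaves."""

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return isinstance(m, (torch.nn.LayerNorm, torch.nn.BatchNorm1d))


class DefaultLeaves(Tracer):
    """Stand-in for the other base tracer: keeps torch.fx's default rule."""

    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        return super().is_leaf_module(m, module_qualified_name)


class CombinedTracer(KeepNormsAsLeaves, DefaultLeaves):
    def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
        # Call each base class by name so both checks always run, regardless of MRO.
        return KeepNormsAsLeaves.is_leaf_module(self, m, module_qualified_name) or DefaultLeaves.is_leaf_module(
            self, m, module_qualified_name
        )


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
graph = CombinedTracer().trace(model)
print(graph)  # both submodules show up as call_module nodes (leaves)
```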
15 changes: 7 additions & 8 deletions optimum/neuron/utils/training_utils.py
@@ -19,7 +19,6 @@
import torch
import transformers
from accelerate import skip_first_batches as accelerate_skip_first_batches
from packaging import version
from transformers import GenerationMixin
from transformers.models.auto.modeling_auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -45,7 +44,7 @@

from ...utils.logging import set_verbosity as set_verbosity_optimum
from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
from . import is_neuronx_distributed_available, is_torch_xla_available
from . import is_neuronx_distributed_available
from .require_utils import requires_neuronx_distributed, requires_torch_xla


@@ -183,13 +182,13 @@ def patch_transformers_for_neuron_sdk():
"""
transformers.utils.logging.set_verbosity = set_verbosity

if version.parse(transformers.__version__) >= version.parse("4.39.3"):
raise RuntimeError("This should be removed since it is not needed.")
elif is_torch_xla_available():
import sys
# if version.parse(transformers.__version__) >= version.parse("4.39.3"):
# raise RuntimeError("This should be removed since it is not needed.")
# elif is_torch_xla_available():
# import sys

sys.modules["torch_xla.distributed.spmd"] = object()
sys.modules["torch_xla.runtime"] = object()
# sys.modules["torch_xla.distributed.spmd"] = object()
# sys.modules["torch_xla.runtime"] = object()


@requires_torch_xla
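The block now commented out in `patch_transformers_for_neuron_sdk` was a version-gated shim. For reference, a generic sketch of that gating pattern is below; the 4.39.3 threshold mirrors the commented-out check, while the placeholder module shim is illustrative only, not optimum-neuron's actual patch.

```python
import sys
import types

import transformers
from packaging import version

if version.parse(transformers.__version__) < version.parse("4.39.3"):
    # Register an empty stand-in module so `import torch_xla.runtime`-style
    # imports in older integration code do not fail at import time.
    sys.modules.setdefault("torch_xla.runtime", types.ModuleType("torch_xla.runtime"))
```

The commented-out code assigned a bare `object()` into `sys.modules` and additionally gated on `is_torch_xla_available()`; the sketch uses an empty module stand-in instead.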
