From 17cf1478234c8b0c82620eb286fe36504ec206c7 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Tue, 16 Apr 2024 14:20:54 +0200
Subject: [PATCH] Update Llama attention

---
 optimum/neuron/__init__.py                   |  2 -
 optimum/neuron/distributed/decoder_models.py | 50 +++++---------
 optimum/neuron/distributed/utils.py          |  9 +++-
 optimum/neuron/utils/training_utils.py       | 15 +++---
 4 files changed, 25 insertions(+), 51 deletions(-)

diff --git a/optimum/neuron/__init__.py b/optimum/neuron/__init__.py
index dca82b384..f2b43ff74 100644
--- a/optimum/neuron/__init__.py
+++ b/optimum/neuron/__init__.py
@@ -100,7 +100,5 @@
     )
 
 
-import os
-
 from .utils import is_neuron_available, is_neuronx_available, patch_transformers_for_neuron_sdk
 from .version import __version__
diff --git a/optimum/neuron/distributed/decoder_models.py b/optimum/neuron/distributed/decoder_models.py
index 7ba6784a7..153093b1f 100644
--- a/optimum/neuron/distributed/decoder_models.py
+++ b/optimum/neuron/distributed/decoder_models.py
@@ -15,7 +15,7 @@
 """Classes related to `neuronx-distributed` to perform parallelism."""
 
 import warnings
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Callable, Optional, Tuple
 
 import torch
 from transformers.cache_utils import Cache
@@ -455,6 +455,7 @@ def attention_forward(
         past_key_value: Optional[Cache] = None,
         output_attentions: bool = False,
         use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
@@ -502,45 +503,29 @@
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
 
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    "The cache structure has changed since version `transformers v4.36. If you are using "
-                    f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to "
-                    "initialize the attention class with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        past_key_value = getattr(self, "past_key_value", past_key_value)
+        cos, sin = self.rotary_emb(value_states, position_ids)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(
                 key_states, value_states, self.layer_idx, cache_kwargs
             )
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
-        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-            raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
-            )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            attn_weights = attn_weights + attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+            attn_weights = attn_weights + causal_mask
 
         # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
@@ -581,19 +566,6 @@ class LlamaPipelineParallelismSpecs(PipelineParallelismSpecs):
     DEFAULT_INPUT_NAMES = ("input_ids", "attention_mask", "labels")
     LEAF_MODULE_CLASSES_NAMES = [LlamaRMSNorm]
 
-    @classmethod
-    def get_patching_specs(cls) -> List[Tuple[str, Any]]:
-        return []
-        # leaf_prepare_4d_causal_attention_mask = torch.fx._symbolic_trace._create_wrapped_func(
-        #     _prepare_4d_causal_attention_mask
-        # )
-        # return [
-        #     (
-        #         "transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask",
-        #         leaf_prepare_4d_causal_attention_mask,
-        #     ),
-        # ]
-
 
 class LlamaParallelizer(Parallelizer):
     SEQUENCE_PARALLELSIM_SPECS_CLS = LlamaSequenceParallelismSpecs
diff --git a/optimum/neuron/distributed/utils.py b/optimum/neuron/distributed/utils.py
index 742005a20..7dd4aaaf1 100644
--- a/optimum/neuron/distributed/utils.py
+++ b/optimum/neuron/distributed/utils.py
@@ -28,6 +28,7 @@
 import torch
 from transformers import PretrainedConfig
 from transformers.utils import is_peft_available
+from transformers.utils.fx import HFTracer
 
 from ...utils import logging
 from ..utils import DynamicPatch, Patcher
@@ -44,7 +45,7 @@
 if is_neuronx_distributed_available():
     from neuronx_distributed.modules.qkv_linear import GQAQKVColumnParallelLinear
     from neuronx_distributed.parallel_layers import layers
-    from neuronx_distributed.pipeline.trace import HFTracerWrapper
+    from neuronx_distributed.pipeline.trace import HFTracerWrapper, NxDTracer
 else:
 
     class GQAQKVColumnParallelLinear(torch.nn.Module):
@@ -1361,7 +1362,11 @@ def is_sharded(self):
 
 class OptimumNeuronFXTracer(HFTracerWrapper):
     def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
-        return super().is_leaf_module(m, module_qualified_name) or isinstance(m, FakeProj)
+        return (
+            NxDTracer.is_leaf_module(self, m, module_qualified_name)
+            or HFTracer.is_leaf_module(self, m, module_qualified_name)
+            or isinstance(m, FakeProj)
+        )
 
 
 class SavedModelInTemporaryDirectory:
diff --git a/optimum/neuron/utils/training_utils.py b/optimum/neuron/utils/training_utils.py
index f20d279fd..ca3baa15e 100644
--- a/optimum/neuron/utils/training_utils.py
+++ b/optimum/neuron/utils/training_utils.py
@@ -19,7 +19,6 @@
 import torch
 import transformers
 from accelerate import skip_first_batches as accelerate_skip_first_batches
-from packaging import version
 from transformers import GenerationMixin
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
@@ -45,7 +44,7 @@
 
 from ...utils.logging import set_verbosity as set_verbosity_optimum
 from ..generation import GeneralNeuronGenerationMixin, NeuronGenerationMixin
-from . import is_neuronx_distributed_available, is_torch_xla_available
+from . import is_neuronx_distributed_available
 from .require_utils import requires_neuronx_distributed, requires_torch_xla
 
 
@@ -183,13 +182,13 @@ def patch_transformers_for_neuron_sdk():
     """
     transformers.utils.logging.set_verbosity = set_verbosity
 
-    if version.parse(transformers.__version__) >= version.parse("4.39.3"):
-        raise RuntimeError("This should be removed since it is not needed.")
-    elif is_torch_xla_available():
-        import sys
+    # if version.parse(transformers.__version__) >= version.parse("4.39.3"):
+    #     raise RuntimeError("This should be removed since it is not needed.")
+    # elif is_torch_xla_available():
+    #     import sys
 
-        sys.modules["torch_xla.distributed.spmd"] = object()
-        sys.modules["torch_xla.runtime"] = object()
+    #     sys.modules["torch_xla.distributed.spmd"] = object()
+    #     sys.modules["torch_xla.runtime"] = object()
 
 
 @requires_torch_xla