diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py index 6c8f16f057..9dfa57844d 100644 --- a/optimum/bettertransformer/models/attention.py +++ b/optimum/bettertransformer/models/attention.py @@ -15,6 +15,9 @@ from typing import Optional, Tuple import torch +import torch.nn.functional as F + +from ...utils import check_if_transformers_greater # TODO (CRITICAL): Layer-wise attention scaling is broken for several archs. @@ -23,7 +26,7 @@ def raise_on_head_mask(head_mask: Optional[torch.Tensor]): if head_mask is not None: raise ValueError( - "layer_head_mask different than None is unsupported for now with BetterTransformer, please" + "layer_head_mask (or head_mask) different than None is unsupported for now with BetterTransformer, please" "open a PR or an issue at https://github.com/huggingface/optimum." ) @@ -534,88 +537,159 @@ def bart_forward( return attn_output, None, past_key_value -# Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward -def bloom_forward( - self, - hidden_states: torch.Tensor, - residual: torch.Tensor, - alibi: torch.Tensor, - attention_mask: torch.Tensor, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - **kwargs, -): - raise_on_head_mask(head_mask) +if check_if_transformers_greater("4.44"): + from transformers.cache_utils import Cache + from transformers.models.bloom.modeling_bloom import dropout_add + + # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward + def bloom_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + raise_on_head_mask(head_mask) + + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + + batch_size, q_length, _ = hidden_states.shape + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) + # 3 x [batch_size, num_heads, seq_length, head_dim] + query_layer, key_layer, value_layer = self._reshape(fused_qkv) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + + if attention_mask is not None: # no matter the length, we just slice it + kv_length = cache_position[-1] + 1 # cache position is 0-indexed while length should start from 1 + causal_mask = attention_mask[:, :, :, :kv_length] + alibi = torch.masked_fill(alibi, causal_mask.bool(), torch.finfo(alibi.dtype).min) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=alibi, + dropout_p=self.dropout_prob_attn if self.training else 0.0, + ) - if output_attentions is True: - raise ValueError("output_attentions=True can not be supported with BetterTransformer.") + # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape(batch_size, q_length, self.hidden_size) + + # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + F.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) - fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + output_tensor = dropout_add(output_tensor, residual, self.hidden_dropout, self.training) - # 3 x [batch_size, seq_length, num_heads, head_dim] - (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + outputs = (output_tensor, layer_past) - batch_size, q_length, _, _ = query_layer.shape + return outputs - # Permute to [batch_size, num_heads, seq_length, head_dim] - query_layer = query_layer.transpose(1, 2) +else: + # Adapted from transformers.models.bloom.modeling_bloom.BloomAttention.forward + def bloom_forward( + self, + hidden_states: torch.Tensor, + residual: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + **kwargs, + ): + raise_on_head_mask(head_mask) - if layer_past is not None: - past_key, past_value = layer_past - past_key = past_key.transpose(1, 2) + if output_attentions is True: + raise ValueError("output_attentions=True can not be supported with BetterTransformer.") - key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + # [batch_size, seq_length, 3 x hidden_size] + fused_qkv = self.query_key_value(hidden_states) - # concatenate along seq_length dimension - key_layer = torch.cat((past_key, key_layer), dim=1) - value_layer = torch.cat((past_value, value_layer), dim=1) + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) - # untangle batch_size from self.num_heads - key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:]) - value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:]) - else: - key_layer = key_layer.transpose(1, 2) - value_layer = value_layer.transpose(1, 2) - - alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) - alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min) - - context_layer = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=alibi, - dropout_p=self.dropout_prob_attn if self.training else 0.0, - ) + batch_size, q_length, _, _ = query_layer.shape - # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] - context_layer = context_layer.transpose(1, 2) - context_layer = context_layer.reshape(*context_layer.shape[:2], -1) - - # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 - if self.pretraining_tp > 1 and self.slow_but_exact: - slices = self.hidden_size / self.pretraining_tp - output_tensor = torch.zeros_like(context_layer) - for i in range(self.pretraining_tp): - output_tensor = output_tensor + torch.nn.functional.linear( - context_layer[:, :, int(i * slices) : int((i + 1) * slices)], - self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], - ) - else: - output_tensor = self.dense(context_layer) + # Permute to [batch_size, num_heads, seq_length, head_dim] + query_layer = query_layer.transpose(1, 2) + + if layer_past is not None: + past_key, past_value = layer_past + past_key = past_key.transpose(1, 2) - output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training) - output_tensor = residual + output_tensor + key_layer = key_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim) - if use_cache is True: - present = ( - key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2), - value_layer.reshape(-1, *value_layer.shape[2:]), + # concatenate along seq_length dimension + key_layer = torch.cat((past_key, key_layer), dim=1) + value_layer = torch.cat((past_value, value_layer), dim=1) + + # untangle batch_size from self.num_heads + key_layer = key_layer.reshape(batch_size, self.num_heads, *key_layer.shape[1:]) + value_layer = value_layer.reshape(batch_size, self.num_heads, *value_layer.shape[1:]) + else: + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:]) + alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=alibi, + dropout_p=self.dropout_prob_attn if self.training else 0.0, ) - else: - present = None - return (output_tensor, present) + # Transform [batch_size, num_heads, seq_length, head_dim] to [batch_size, seq_length, num_heads * head_dim] + context_layer = context_layer.transpose(1, 2) + context_layer = context_layer.reshape(*context_layer.shape[:2], -1) + + # aggregate results across tp ranks. 
See here: https://github.com/pytorch/pytorch/issues/76232 + if self.pretraining_tp > 1 and self.slow_but_exact: + slices = self.hidden_size / self.pretraining_tp + output_tensor = torch.zeros_like(context_layer) + for i in range(self.pretraining_tp): + output_tensor = output_tensor + torch.nn.functional.linear( + context_layer[:, :, int(i * slices) : int((i + 1) * slices)], + self.dense.weight[:, int(i * slices) : int((i + 1) * slices)], + ) + else: + output_tensor = self.dense(context_layer) + + output_tensor = torch.nn.functional.dropout(output_tensor, p=self.hidden_dropout, training=self.training) + output_tensor = residual + output_tensor + + if use_cache is True: + present = ( + key_layer.reshape(-1, *key_layer.shape[2:]).transpose(1, 2), + value_layer.reshape(-1, *value_layer.shape[2:]), + ) + else: + present = None + + return (output_tensor, present) diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py index 4bcc057373..b64b7f5a1e 100644 --- a/optimum/bettertransformer/models/decoder_models.py +++ b/optimum/bettertransformer/models/decoder_models.py @@ -216,6 +216,8 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"): self.dropout_prob_attn = config.attention_dropout self.module_mapping = None + self.layer_idx = getattr(layer, "layer_idx", None) + submodules = ["query_key_value", "dense", "attention_dropout"] for attr in submodules: setattr(self, attr, getattr(layer, attr)) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 3e11c7e614..d4b15b2968 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -338,27 +338,31 @@ class BloomOnnxConfig(TextDecoderOnnxConfig): ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES DUMMY_PKV_GENERATOR_CLASS = BloomDummyPastKeyValuesGenerator NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_layers="n_layer", num_attention_heads="n_head") + DEFAULT_ONNX_OPSET = 14 # Bloom uses aten::triu that requires opset>=14, and F.scaled_dot_product_attention def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): - if direction not in ["inputs", "outputs"]: - raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') - - if direction == "inputs": - decoder_sequence_name = "past_sequence_length" - name = "past_key_values" + if check_if_transformers_greater("4.44"): + super().add_past_key_values(inputs_or_outputs, direction) else: - decoder_sequence_name = "past_sequence_length + 1" - name = "present" + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') - for i in range(self._normalized_config.num_layers): - inputs_or_outputs[f"{name}.{i}.key"] = { - 0: "batch_size x num_heads", - 2: decoder_sequence_name, - } - inputs_or_outputs[f"{name}.{i}.value"] = { - 0: "batch_size x num_heads", - 1: decoder_sequence_name, - } + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = { + 0: "batch_size x num_heads", + 2: decoder_sequence_name, + } + inputs_or_outputs[f"{name}.{i}.value"] = { + 0: "batch_size x num_heads", + 1: decoder_sequence_name, + } class 
GPTBigCodeOnnxConfig(TextDecoderWithPositionIdsOnnxConfig): diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index 6a0dcbba2f..f6d4b7e20a 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -336,8 +336,7 @@ def prepare_past_key_values( dtype = constructor.float16 if self.use_fp16 else constructor.float32 # TODO: find a way to better handle this controlflow, this is EXTREMELY UGLY. - # "1" is the dummy sequence length - if self.model_type == "bloom": + if self.__class__.__name__ == "ORTBloomForCausalLM": shape_value = (batch_size * num_attention_heads, 0, embed_size_per_head) shape_key = (batch_size * num_attention_heads, embed_size_per_head, 0) key = constructor.zeros(shape_key, dtype=dtype) @@ -354,9 +353,9 @@ def prepare_past_key_values( for name, value in zip(self.key_value_output_names, past_key_values): shape = [*value.shape] index = 1 if "value" in name else 2 - shape[index] += sequence_length pkv_output_shape[name] = shape + elif self.model_type == "gpt_bigcode": # GPT BigCode uses muti-query attention, and has the specificity of putting both key and value in the same cache tensor. shape_key_and_value = (batch_size, 0, embed_size_per_head * 2) @@ -371,9 +370,9 @@ def prepare_past_key_values( shape = [*value.shape] shape[1] += sequence_length pkv_output_shape[name] = shape + else: num_key_value_heads = self.num_key_value_heads if self.model_type == "falcon" else num_attention_heads - shape = (batch_size, num_key_value_heads, 0, embed_size_per_head) key_or_value = constructor.zeros(shape, dtype=dtype) @@ -534,9 +533,9 @@ def _from_pretrained( # Since https://github.com/huggingface/optimum/pull/871/ # changed axis notation/naming during export, we need to update the dims - for dim in input_dims.keys(): - if "past" in dim and input_dims[dim][2] == "past_sequence_length + sequence_length": - input_dims[dim][2] = "past_sequence_length" + for input_name in input_dims.keys(): + if "past" in input_name and input_dims[input_name][2] == "past_sequence_length + sequence_length": + input_dims[input_name][2] = "past_sequence_length" override_dims = True if override_dims: @@ -559,6 +558,12 @@ def _from_pretrained( size_threshold=0, ) + # Since transformers 4.44, the bloom model has been updated to use the standard cache format + use_old_bloom_modeling = not check_if_transformers_greater("4.44") + for input_name in input_dims.keys(): + if input_dims[input_name][0] == "batch_size x num_heads": + use_old_bloom_modeling = True + del onnx_model model = ORTModel.load_model( @@ -568,7 +573,7 @@ def _from_pretrained( provider_options=provider_options, ) - if config.model_type == "bloom": + if config.model_type == "bloom" and use_old_bloom_modeling: init_cls = ORTBloomForCausalLM elif config.model_type == "falcon": init_cls = ORTFalconForCausalLM diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py index 36913f652a..dac14a3811 100644 --- a/optimum/utils/input_generators.py +++ b/optimum/utils/input_generators.py @@ -22,6 +22,7 @@ import numpy as np from transformers.utils import is_tf_available, is_torch_available +from ..utils import check_if_transformers_greater from .normalized_config import ( NormalizedConfig, NormalizedEncoderDecoderConfig, @@ -1026,23 +1027,26 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int class BloomDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def generate(self, input_name: str, framework: str = "pt", 
int_dtype: str = "int64", float_dtype: str = "fp32"): - past_key_shape = ( - self.batch_size * self.num_attention_heads, - self.hidden_size // self.num_attention_heads, - self.sequence_length, - ) - past_value_shape = ( - self.batch_size * self.num_attention_heads, - self.sequence_length, - self.hidden_size // self.num_attention_heads, - ) - return [ - ( - self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), - self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + if check_if_transformers_greater("4.44"): + return super().generate(input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype) + else: + past_key_shape = ( + self.batch_size * self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + self.sequence_length, ) - for _ in range(self.num_layers) - ] + past_value_shape = ( + self.batch_size * self.num_attention_heads, + self.sequence_length, + self.hidden_size // self.num_attention_heads, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] class MultiQueryPastKeyValuesGenerator(DummyPastKeyValuesGenerator): diff --git a/optimum/utils/preprocessing/token_classification.py b/optimum/utils/preprocessing/token_classification.py index 1c59aa2285..64a0bf2da8 100644 --- a/optimum/utils/preprocessing/token_classification.py +++ b/optimum/utils/preprocessing/token_classification.py @@ -28,7 +28,7 @@ class TokenClassificationProcessing(TaskProcessor): ACCEPTED_PREPROCESSOR_CLASSES = (PreTrainedTokenizerBase,) - DEFAULT_DATASET_ARGS = "conll2003" + DEFAULT_DATASET_ARGS = {"path": "conll2003", "trust_remote_code": True} DEFAUL_DATASET_DATA_KEYS = {"primary": "tokens"} ALLOWED_DATA_KEY_NAMES = {"primary"} DEFAULT_REF_KEYS = ["ner_tags", "pos_tags", "chunk_tags"] diff --git a/setup.py b/setup.py index 2e8c9489a8..3ac4315321 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ REQUIRED_PKGS = [ "coloredlogs", "sympy", - "transformers[sentencepiece]>=4.29.0,<4.44.0", + "transformers[sentencepiece]>=4.29,<4.45.0", "torch>=1.11", "packaging", "numpy<2.0", # transformers requires numpy<2.0 https://github.com/huggingface/transformers/pull/31569 @@ -24,6 +24,7 @@ ] # TODO: unpin pytest once https://github.com/huggingface/transformers/pull/29154 is merged & released +# pytest>=8.0.0 also fails with the transformers version pinned for exporters-tf TESTS_REQUIRE = [ "accelerate", "pytest<=8.0.0", @@ -72,7 +73,7 @@ "timm", "h5py", "numpy<1.24.0", - "transformers[sentencepiece]>=4.26.0,<4.38.0", + "transformers[sentencepiece]>=4.26,<4.38", ], "diffusers": ["diffusers"], "intel": "optimum-intel>=1.18.0", @@ -80,9 +81,9 @@ "nncf": "optimum-intel[nncf]>=1.18.0", "neural-compressor": "optimum-intel[neural-compressor]>=1.18.0", "ipex": "optimum-intel[ipex]>=1.18.0", - "habana": ["optimum-habana", "transformers >= 4.43.0, < 4.44.0"], - "neuron": ["optimum-neuron[neuron]>=0.0.20", "transformers >= 4.36.2, < 4.42.0"], - "neuronx": ["optimum-neuron[neuronx]>=0.0.20", "transformers >= 4.36.2, < 4.42.0"], + "habana": ["optimum-habana", "transformers>=4.43.0,<4.44.0"], + "neuron": ["optimum-neuron[neuron]>=0.0.20", "transformers>=4.36.2,<4.42.0"], + "neuronx": ["optimum-neuron[neuronx]>=0.0.20", "transformers>=4.36.2,<4.42.0"], "graphcore": "optimum-graphcore", "furiosa": "optimum-furiosa", "amd": "optimum-amd", diff --git 
a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py index 42340d3b3a..bab8f376fc 100644 --- a/tests/bettertransformer/test_decoder.py +++ b/tests/bettertransformer/test_decoder.py @@ -23,7 +23,6 @@ from optimum.bettertransformer import BetterTransformer from optimum.utils import ( - BloomDummyPastKeyValuesGenerator, DummyPastKeyValuesGenerator, NormalizedConfigManager, ) @@ -136,10 +135,7 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config) - if model_type == "bloom": - pkv_generator_class = BloomDummyPastKeyValuesGenerator - else: - pkv_generator_class = DummyPastKeyValuesGenerator + pkv_generator_class = DummyPastKeyValuesGenerator pkv_generator = pkv_generator_class( task="", normalized_config=normalized_config, batch_size=batch_size, sequence_length=seq_length
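
Note on the cache-layout change this diff gates on (illustrative sketch, not part of the PR): for transformers < 4.44, Bloom past key/values are stored as fused tensors of shape [batch_size * num_heads, head_dim, seq_length] for keys (transposed) and [batch_size * num_heads, seq_length, head_dim] for values, while transformers >= 4.44 switches Bloom to the standard [batch_size, num_heads, seq_length, head_dim] layout handled by the common Cache utilities. The helper below is purely hypothetical (legacy_bloom_cache_to_standard does not exist in optimum or transformers); the shapes are taken from BloomDummyPastKeyValuesGenerator and ORTBloomForCausalLM's prepare_past_key_values in the hunks above.

import torch


def legacy_bloom_cache_to_standard(past_key_values, num_heads):
    """Convert legacy Bloom-style past_key_values to the standard per-layer layout.

    Legacy (transformers < 4.44):
        key:   [batch_size * num_heads, head_dim, seq_length]
        value: [batch_size * num_heads, seq_length, head_dim]
    Standard (transformers >= 4.44):
        key/value: [batch_size, num_heads, seq_length, head_dim]
    """
    converted = []
    for key, value in past_key_values:
        fused_batch, head_dim, seq_length = key.shape
        batch_size = fused_batch // num_heads
        # Un-fuse batch and heads, then move head_dim after seq_length for the key.
        key = key.view(batch_size, num_heads, head_dim, seq_length).transpose(2, 3)
        value = value.view(batch_size, num_heads, seq_length, head_dim)
        converted.append((key, value))
    return tuple(converted)


if __name__ == "__main__":
    batch_size, num_heads, head_dim, seq_length, num_layers = 2, 4, 8, 5, 3
    legacy = [
        (
            torch.rand(batch_size * num_heads, head_dim, seq_length),
            torch.rand(batch_size * num_heads, seq_length, head_dim),
        )
        for _ in range(num_layers)
    ]
    standard = legacy_bloom_cache_to_standard(legacy, num_heads)
    assert standard[0][0].shape == (batch_size, num_heads, seq_length, head_dim)
    assert standard[0][1].shape == (batch_size, num_heads, seq_length, head_dim)

This layout difference is also why _from_pretrained in modeling_decoder.py inspects the exported input dims ("batch_size x num_heads" on axis 0) to decide between ORTBloomForCausalLM and the generic causal LM class: an ONNX model exported with the legacy layout must keep the legacy cache handling even when a newer transformers version is installed.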