diff --git a/optimum/bettertransformer/models/__init__.py b/optimum/bettertransformer/models/__init__.py
index 7ef029bbdd..286e90231f 100644
--- a/optimum/bettertransformer/models/__init__.py
+++ b/optimum/bettertransformer/models/__init__.py
@@ -20,6 +20,7 @@
     BlenderbotAttentionLayerBetterTransformer,
     BloomAttentionLayerBetterTransformer,
     CodegenAttentionLayerBetterTransformer,
+    FalconAttentionLayerBetterTransformer,
     GPT2AttentionLayerBetterTransformer,
     GPTBigCodeAttentionLayerBetterTransformer,
     GPTJAttentionLayerBetterTransformer,
@@ -77,6 +78,7 @@ class BetterTransformerManager:
         "electra": {"ElectraLayer": BertLayerBetterTransformer},
         "ernie": {"ErnieLayer": BertLayerBetterTransformer},
         "fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
+        "falcon": {"FalconAttention": FalconAttentionLayerBetterTransformer},
         "gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
         "gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
         "gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
@@ -150,6 +152,7 @@ class BetterTransformerManager:
         "opt",
         "pegasus",
         "t5",
+        "falcon",
     }
 
     NOT_REQUIRES_STRICT_VALIDATION = {
@@ -166,6 +169,7 @@ class BetterTransformerManager:
         "opt",
         "pegasus",
         "t5",
+        "falcon",
     }
 
     @staticmethod
diff --git a/optimum/bettertransformer/models/attention.py b/optimum/bettertransformer/models/attention.py
index 86ed1cce6e..702aca3257 100644
--- a/optimum/bettertransformer/models/attention.py
+++ b/optimum/bettertransformer/models/attention.py
@@ -899,3 +899,97 @@ def bloom_forward(
         present = None
 
     return (output_tensor, present)
+
+
+def falcon_forward(
+    self,
+    hidden_states: torch.Tensor,
+    alibi: Optional[torch.Tensor],
+    attention_mask: torch.Tensor,
+    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    use_cache: bool = False,
+    output_attentions: bool = False,
+):
+    fused_qkv = self.query_key_value(hidden_states)  # [batch_size, seq_length, 3 x hidden_size]
+    num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
+    # 3 x [batch_size, seq_length, num_heads, head_dim]
+    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+    if output_attentions is True:
+        raise ValueError("output_attentions=True can not be supported with BetterTransformer.")
+
+    batch_size, query_length, _, _ = query_layer.shape
+
+    query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
+    key_layer = key_layer.transpose(1, 2).reshape(
+        batch_size * num_kv_heads,
+        query_length,
+        self.head_dim,
+    )
+    value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
+
+    past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
+    query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
+
+    if layer_past is not None:
+        past_key, past_value = layer_past
+        # concatenate along seq_length dimension:
+        #  - key: [batch_size * self.num_heads, kv_length, head_dim]
+        #  - value: [batch_size * self.num_heads, kv_length, head_dim]
+        key_layer = torch.cat((past_key, key_layer), dim=1)
+        value_layer = torch.cat((past_value, value_layer), dim=1)
+
+    if use_cache is True:
+        present = (key_layer, value_layer)
+    else:
+        present = None
+
+    attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
+
+    query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
+    key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+    value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
+
+    if alibi is None:
+        if batch_size == 1 or self.training:
+            if query_length > 1:
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=True
+                )
+            else:
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+        else:
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
+            )
+
+        attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
+        attn_output = attn_output.permute(0, 2, 1, 3)
+        attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+
+        output_tensor = self.dense(attn_output)
+
+        return output_tensor, present
+
+    else:
+        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+        alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)
+
+        context_layer = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=alibi,
+            dropout_p=self.dropout_prob_attn if self.training else 0.0,
+        )
+        context_layer = context_layer.transpose(1, 2)
+
+        # change view [batch_size, q_length, num_heads * head_dim]
+        context_layer = self._merge_heads(context_layer)
+
+        output_tensor = self.dense(context_layer)
+
+        return output_tensor, present
diff --git a/optimum/bettertransformer/models/decoder_models.py b/optimum/bettertransformer/models/decoder_models.py
index fc23e1b9b2..47aacd17b6 100644
--- a/optimum/bettertransformer/models/decoder_models.py
+++ b/optimum/bettertransformer/models/decoder_models.py
@@ -44,11 +44,18 @@
 else:
     from ...utils.dummy_bettertransformer_objects import BarkSelfAttention
 
+if check_if_transformers_greater("4.32"):
+    from transformers.models.falcon.modeling_falcon import FalconAttention
+else:
+    from ...utils.dummy_bettertransformer_objects import FalconAttention
+
+
 from .attention import (
     bark_wrapped_scaled_dot_product,
     bart_forward,
     bloom_forward,
     codegen_wrapped_scaled_dot_product,
+    falcon_forward,
     gpt2_wrapped_scaled_dot_product,
     gpt_bigcode_forward,
     gpt_bigcode_wrapped_scaled_dot_product,
@@ -236,6 +243,26 @@ def forward(self, *args, **kwargs):
         return bloom_forward(self, *args, **kwargs)
 
 
+class FalconAttentionLayerBetterTransformer(BetterTransformerBaseLayer, FalconAttention, nn.Module):
+    def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
+        super().__init__(config)
+
+        with torch.device("meta"):
+            super(BetterTransformerBaseLayer, self).__init__(config)
+
+        self.dropout_prob_attn = config.attention_dropout
+
+        self.module_mapping = None
+        submodules = ["query_key_value", "dense", "attention_dropout", "maybe_rotary"]
+        for attr in submodules:
+            setattr(self, attr, getattr(layer, attr))
+
+        self.original_layers_mapping = {submodule: submodule for submodule in submodules}
+
+    def forward(self, *args, **kwargs):
+        return falcon_forward(self, *args, **kwargs)
+
+
 class CodegenAttentionLayerBetterTransformer(BetterTransformerBaseLayer, CodeGenAttention, nn.Module):
     _attn = codegen_wrapped_scaled_dot_product
diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py
index 36e60e89a2..09a968eeec 100644
--- a/optimum/utils/__init__.py
+++ b/optimum/utils/__init__.py
@@ -58,6 +58,7 @@
     DummyTimestepInputGenerator,
     DummyVisionEmbeddingsGenerator,
     DummyVisionInputGenerator,
+    FalconDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
 )
 from .modeling_utils import recurse_getattr, recurse_setattr
diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py
index 351477abbd..899f3e504f 100644
--- a/optimum/utils/import_utils.py
+++ b/optimum/utils/import_utils.py
@@ -192,6 +192,10 @@ def require_numpy_strictly_lower(version: str, message: str):
             "transformers_431",
             (lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")),
         ),
+        (
+            "transformers_432",
+            (lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")),
+        ),
     ]
 )
diff --git a/optimum/utils/input_generators.py b/optimum/utils/input_generators.py
index f69bb39848..a11101ccbd 100644
--- a/optimum/utils/input_generators.py
+++ b/optimum/utils/input_generators.py
@@ -858,3 +858,29 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
             )
             for _ in range(self.num_layers)
         ]
+
+
+class FalconDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        self.num_kv_heads = 1
+        head_dim = self.hidden_size // self.num_attention_heads
+
+        past_key_shape = (
+            self.batch_size,
+            self.num_kv_heads,
+            self.sequence_length,
+            head_dim,
+        )
+        past_value_shape = (
+            self.batch_size,
+            self.num_kv_heads,
+            self.sequence_length,
+            head_dim,
+        )
+        return [
+            (
+                self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
+                self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
+            )
+            for _ in range(self.num_layers)
+        ]
diff --git a/optimum/utils/normalized_config.py b/optimum/utils/normalized_config.py
index e65c3c42d6..30bbec030a 100644
--- a/optimum/utils/normalized_config.py
+++ b/optimum/utils/normalized_config.py
@@ -211,6 +211,7 @@ class NormalizedConfigManager:
         "blenderbot": BartLikeNormalizedTextConfig,
         "blenderbot_small": BartLikeNormalizedTextConfig,
         "bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
+        "falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
         "camembert": NormalizedTextConfig,
         "codegen": GPT2LikeNormalizedTextConfig,
         "cvt": NormalizedVisionConfig,
diff --git a/tests/bettertransformer/test_decoder.py b/tests/bettertransformer/test_decoder.py
index b06ebf5879..c3b9178125 100644
--- a/tests/bettertransformer/test_decoder.py
+++ b/tests/bettertransformer/test_decoder.py
@@ -25,6 +25,7 @@
 from optimum.utils import (
     BloomDummyPastKeyValuesGenerator,
     DummyPastKeyValuesGenerator,
+    FalconDummyPastKeyValuesGenerator,
     GPTBigCodeDummyPastKeyValuesGenerator,
     NormalizedConfigManager,
 )
@@ -32,7 +33,18 @@
 
 
 class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
-    SUPPORTED_ARCH = ["bloom", "codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"]
+    SUPPORTED_ARCH = [
+        "bloom",
+        "codegen",
+        "gpt2",
+        "gpt_bigcode",
+        "gptj",
+        "gpt_neo",
+        "gpt_neox",
+        "falcon",
+        "llama",
+        "opt",
+    ]
 
     FULL_GRID = {
         "model_type": SUPPORTED_ARCH,
@@ -133,12 +145,15 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in
             pkv_generator_class = GPTBigCodeDummyPastKeyValuesGenerator
         elif model_type == "bloom":
             pkv_generator_class = BloomDummyPastKeyValuesGenerator
+        elif model_type == "falcon":
+            pkv_generator_class = FalconDummyPastKeyValuesGenerator
         else:
             pkv_generator_class = DummyPastKeyValuesGenerator
 
         pkv_generator = pkv_generator_class(
             task="", normalized_config=normalized_config, batch_size=batch_size, sequence_length=seq_length
         )
+
         past_key_values = pkv_generator.generate(input_name="past_key_values")
 
         result_vanilla = model(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
diff --git a/tests/bettertransformer/testing_utils.py b/tests/bettertransformer/testing_utils.py
index e5699d0e5b..113c59f63c 100644
--- a/tests/bettertransformer/testing_utils.py
+++ b/tests/bettertransformer/testing_utils.py
@@ -43,6 +43,7 @@
     "distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
     "electra": "hf-internal-testing/tiny-random-ElectraModel",
     "ernie": "hf-internal-testing/tiny-random-ErnieModel",
+    "falcon": "Rocketknight1/tiny-random-falcon-7b",
     "fsmt": "hf-internal-testing/tiny-random-FSMTModel",
     "gpt2": "hf-internal-testing/tiny-random-GPT2Model",
     # NOTE: this tiny model does not use attention_softmax_in_fp32=True (contrary to e.g. starcoder)
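
Not part of the patch above: a minimal usage sketch of the Falcon path this diff enables, assuming optimum with this change installed and transformers >= 4.32. The checkpoint name is taken from the testing_utils.py mapping in the diff; the prompt and generation settings are illustrative only.

```python
# Illustrative sketch only (not part of the diff): converting a Falcon model with BetterTransformer.
# Assumes transformers >= 4.32 and the tiny test checkpoint referenced in testing_utils.py above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer

model_id = "Rocketknight1/tiny-random-falcon-7b"  # tiny random Falcon used in the tests
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Swap the FalconAttention layers for the SDPA-based BetterTransformer implementation.
model = BetterTransformer.transform(model)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```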