[BetterTransformer] Add falcon to BetterTransformer (#1343)
* add falcon to BT

* add falcon to BT

* update
younesbelkada authored Sep 4, 2023
1 parent 5663aae commit 659cf02
Showing 9 changed files with 174 additions and 1 deletion.
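As context for the diff below: once this change lands, a Falcon model can be converted with the existing BetterTransformer API. A minimal usage sketch, not part of this commit; the checkpoint name is the tiny test model registered by this commit, and the prompt/generation settings are illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

# Tiny Falcon checkpoint referenced by the new test; any Falcon model should work with transformers>=4.32.
model_id = "Rocketknight1/tiny-random-falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Swap every FalconAttention module for the FalconAttentionLayerBetterTransformer added in this commit.
model = BetterTransformer.transform(model)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output[0]))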
4 changes: 4 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -20,6 +20,7 @@
BlenderbotAttentionLayerBetterTransformer,
BloomAttentionLayerBetterTransformer,
CodegenAttentionLayerBetterTransformer,
FalconAttentionLayerBetterTransformer,
GPT2AttentionLayerBetterTransformer,
GPTBigCodeAttentionLayerBetterTransformer,
GPTJAttentionLayerBetterTransformer,
@@ -77,6 +78,7 @@ class BetterTransformerManager:
"electra": {"ElectraLayer": BertLayerBetterTransformer},
"ernie": {"ErnieLayer": BertLayerBetterTransformer},
"fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
"falcon": {"FalconAttention": FalconAttentionLayerBetterTransformer},
"gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
"gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
"gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
@@ -150,6 +152,7 @@ class BetterTransformerManager:
"opt",
"pegasus",
"t5",
"falcon",
}

NOT_REQUIRES_STRICT_VALIDATION = {
@@ -166,6 +169,7 @@ class BetterTransformerManager:
"opt",
"pegasus",
"t5",
"falcon",
}

@staticmethod
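With the two "falcon" registrations in this file, the conversion code can resolve the replacement class from the model type. A quick interactive check, shown here as a sketch rather than the library's exact internals:

from optimum.bettertransformer.models import BetterTransformerManager

# The new entry maps the model type to the attention module to replace and its BetterTransformer wrapper.
print("falcon" in BetterTransformerManager.MODEL_MAPPING)   # True
print(BetterTransformerManager.MODEL_MAPPING["falcon"])     # {"FalconAttention": FalconAttentionLayerBetterTransformer}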
94 changes: 94 additions & 0 deletions optimum/bettertransformer/models/attention.py
@@ -899,3 +899,97 @@ def bloom_forward(
present = None

return (output_tensor, present)


def falcon_forward(
self,
hidden_states: torch.Tensor,
alibi: Optional[torch.Tensor],
attention_mask: torch.Tensor,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
head_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
output_attentions: bool = False,
):
fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
# 3 x [batch_size, seq_length, num_heads, head_dim]
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

batch_size, query_length, _, _ = query_layer.shape

query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
key_layer = key_layer.transpose(1, 2).reshape(
batch_size * num_kv_heads,
query_length,
self.head_dim,
)
value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)

past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)

if layer_past is not None:
past_key, past_value = layer_past
# concatenate along seq_length dimension:
# - key: [batch_size * self.num_heads, kv_length, head_dim]
# - value: [batch_size * self.num_heads, kv_length, head_dim]
key_layer = torch.cat((past_key, key_layer), dim=1)
value_layer = torch.cat((past_value, value_layer), dim=1)

if use_cache is True:
present = (key_layer, value_layer)
else:
present = None

attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)

query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)

if alibi is None:
if batch_size == 1 or self.training:
if query_length > 1:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=True
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attn_mask=None, dropout_p=0.0, is_causal=False
)
else:
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
)

attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
attn_output = attn_output.permute(0, 2, 1, 3)
attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)

output_tensor = self.dense(attn_output)

return output_tensor, present

else:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
alibi = torch.masked_fill(alibi, attention_mask, torch.finfo(alibi.dtype).min)

context_layer = torch.nn.functional.scaled_dot_product_attention(
query_layer,
key_layer,
value_layer,
attn_mask=alibi,
dropout_p=self.attention_dropout if self.training else 0.0,
)
context_layer = context_layer.transpose(1, 2)

# change view [batch_size, q_length, num_heads * head_dim]
context_layer = self._merge_heads(context_layer)

output_tensor = self.dense(context_layer)

return output_tensor, present
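For readers unfamiliar with the SDPA code paths in falcon_forward above: when alibi is None, the query/key/value tensors are reshaped to (batch_size, num_heads, seq_len, head_dim) and torch.nn.functional.scaled_dot_product_attention is called with is_causal=True during prefill, with is_causal=False for single-token decoding, or with an explicit float mask otherwise. A standalone shape sketch with made-up sizes, illustrative only and not Falcon code:

import torch

batch_size, num_heads, q_len, head_dim = 1, 4, 8, 16

query = torch.randn(batch_size, num_heads, q_len, head_dim)
key = torch.randn(batch_size, num_heads, q_len, head_dim)
value = torch.randn(batch_size, num_heads, q_len, head_dim)

# Prefill-style call: no explicit mask, causal masking handled inside the kernel.
prefill_out = torch.nn.functional.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True
)

# Single-token decode: the new query attends to the whole cache, so no causal mask is needed.
single_query = query[:, :, -1:, :]
decode_out = torch.nn.functional.scaled_dot_product_attention(
    single_query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
)

print(prefill_out.shape, decode_out.shape)  # torch.Size([1, 4, 8, 16]) torch.Size([1, 4, 1, 16])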
27 changes: 27 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
@@ -44,11 +44,18 @@
else:
from ...utils.dummy_bettertransformer_objects import BarkSelfAttention

if check_if_transformers_greater("4.32"):
from transformers.models.falcon.modeling_falcon import FalconAttention
else:
from ...utils.dummy_bettertransformer_objects import FalconAttention


from .attention import (
bark_wrapped_scaled_dot_product,
bart_forward,
bloom_forward,
codegen_wrapped_scaled_dot_product,
falcon_forward,
gpt2_wrapped_scaled_dot_product,
gpt_bigcode_forward,
gpt_bigcode_wrapped_scaled_dot_product,
@@ -236,6 +243,26 @@ def forward(self, *args, **kwargs):
return bloom_forward(self, *args, **kwargs)


class FalconAttentionLayerBetterTransformer(BetterTransformerBaseLayer, FalconAttention, nn.Module):
def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):
super().__init__(config)

with torch.device("meta"):
super(BetterTransformerBaseLayer, self).__init__(config)

self.dropout_prob_attn = config.attention_dropout

self.module_mapping = None
submodules = ["query_key_value", "dense", "attention_dropout", "maybe_rotary"]
for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.original_layers_mapping = {submodule: submodule for submodule in submodules}

def forward(self, *args, **kwargs):
return falcon_forward(self, *args, **kwargs)


class CodegenAttentionLayerBetterTransformer(BetterTransformerBaseLayer, CodeGenAttention, nn.Module):
_attn = codegen_wrapped_scaled_dot_product

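The FalconAttentionLayerBetterTransformer constructor in this file follows the same pattern as the other decoder wrappers: the parent FalconAttention is initialized under the meta device so no real weights are allocated, and the trained submodules are then shared from the original layer. A minimal standalone illustration of that pattern with hypothetical module names, not the actual optimum classes:

import torch
from torch import nn

class OriginalAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)

class WrappedAttention(OriginalAttention):
    def __init__(self, layer: OriginalAttention):
        # Build the parent's parameters on the meta device: no real memory is allocated.
        with torch.device("meta"):
            super().__init__()
        # Share (do not copy) the trained submodules from the existing layer.
        self.dense = layer.dense

original = OriginalAttention()
wrapped = WrappedAttention(original)
# The wrapper reuses the original weights instead of holding its own copy.
assert wrapped.dense.weight.data_ptr() == original.dense.weight.data_ptr()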
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -58,6 +58,7 @@
DummyTimestepInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionInputGenerator,
FalconDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
)
from .modeling_utils import recurse_getattr, recurse_setattr
4 changes: 4 additions & 0 deletions optimum/utils/import_utils.py
@@ -192,6 +192,10 @@ def require_numpy_strictly_lower(version: str, message: str):
"transformers_431",
(lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")),
),
(
"transformers_432",
(lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")),
),
]
)

26 changes: 26 additions & 0 deletions optimum/utils/input_generators.py
@@ -858,3 +858,29 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
)
for _ in range(self.num_layers)
]


class FalconDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
self.num_kv_heads = 1
head_dim = self.hidden_size // self.num_attention_heads

past_key_shape = (
self.batch_size,
self.num_kv_heads,
self.sequence_length,
head_dim,
)
past_value_shape = (
self.batch_size,
self.num_kv_heads,
self.sequence_length,
head_dim,
)
return [
(
self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]
1 change: 1 addition & 0 deletions optimum/utils/normalized_config.py
@@ -211,6 +211,7 @@ class NormalizedConfigManager:
"blenderbot": BartLikeNormalizedTextConfig,
"blenderbot_small": BartLikeNormalizedTextConfig,
"bloom": NormalizedTextConfig.with_args(num_layers="n_layer"),
"falcon": NormalizedTextConfig.with_args(num_layers="num_hidden_layers", num_attention_heads="num_kv_heads"),
"camembert": NormalizedTextConfig,
"codegen": GPT2LikeNormalizedTextConfig,
"cvt": NormalizedVisionConfig,
17 changes: 16 additions & 1 deletion tests/bettertransformer/test_decoder.py
@@ -25,14 +25,26 @@
from optimum.utils import (
BloomDummyPastKeyValuesGenerator,
DummyPastKeyValuesGenerator,
FalconDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
NormalizedConfigManager,
)
from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_gpu


class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
SUPPORTED_ARCH = ["bloom", "codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"]
SUPPORTED_ARCH = [
"bloom",
"codegen",
"gpt2",
"gpt_bigcode",
"gptj",
"gpt_neo",
"gpt_neox",
"falcon",
"llama",
"opt",
]

FULL_GRID = {
"model_type": SUPPORTED_ARCH,
@@ -133,12 +145,15 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in
pkv_generator_class = GPTBigCodeDummyPastKeyValuesGenerator
elif model_type == "bloom":
pkv_generator_class = BloomDummyPastKeyValuesGenerator
elif model_type == "falcon":
pkv_generator_class = FalconDummyPastKeyValuesGenerator
else:
pkv_generator_class = DummyPastKeyValuesGenerator

pkv_generator = pkv_generator_class(
task="", normalized_config=normalized_config, batch_size=batch_size, sequence_length=seq_length
)

past_key_values = pkv_generator.generate(input_name="past_key_values")

result_vanilla = model(input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values)
1 change: 1 addition & 0 deletions tests/bettertransformer/testing_utils.py
@@ -43,6 +43,7 @@
"distilbert": "hf-internal-testing/tiny-random-DistilBertModel",
"electra": "hf-internal-testing/tiny-random-ElectraModel",
"ernie": "hf-internal-testing/tiny-random-ErnieModel",
"falcon": "Rocketknight1/tiny-random-falcon-7b",
"fsmt": "hf-internal-testing/tiny-random-FSMTModel",
"gpt2": "hf-internal-testing/tiny-random-GPT2Model",
# NOTE: this tiny model does not use attention_softmax_in_fp32=True (contrary to e.g. starcoder)
