Support gpt_bigcode in bettertransformer (#1252)
support gpt_bigcode in bettertransformer
fxmarty authored Aug 4, 2023
1 parent e1be6e8 commit 393113f
Showing 8 changed files with 189 additions and 13 deletions.
4 changes: 4 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -20,6 +20,7 @@
BlenderbotAttentionLayerBetterTransformer,
CodegenAttentionLayerBetterTransformer,
GPT2AttentionLayerBetterTransformer,
GPTBigCodeAttentionLayerBetterTransformer,
GPTJAttentionLayerBetterTransformer,
GPTNeoAttentionLayerBetterTransformer,
GPTNeoXAttentionLayerBetterTransformer,
@@ -68,6 +69,7 @@ class BetterTransformerManager:
"ernie": {"ErnieLayer": BertLayerBetterTransformer},
"fsmt": {"EncoderLayer": FSMTEncoderLayerBetterTransformer},
"gpt2": {"GPT2Attention": GPT2AttentionLayerBetterTransformer},
"gpt_bigcode": {"GPTBigCodeAttention": GPTBigCodeAttentionLayerBetterTransformer},
"gptj": {"GPTJAttention": GPTJAttentionLayerBetterTransformer},
"gpt_neo": {"GPTNeoSelfAttention": GPTNeoAttentionLayerBetterTransformer},
"gpt_neox": {"GPTNeoXAttention": GPTNeoXAttentionLayerBetterTransformer},
@@ -130,6 +132,7 @@ class BetterTransformerManager:
"blenderbot",
"codegen",
"gpt2",
"gpt_bigcode",
"gptj",
"gpt_neo",
"gpt_neox",
@@ -144,6 +147,7 @@ class BetterTransformerManager:
"blip-2",
"codegen",
"gpt2",
"gpt_bigcode",
"gptj",
"gpt_neo",
"gpt_neox",
136 changes: 136 additions & 0 deletions optimum/bettertransformer/models/attention.py
@@ -20,6 +20,9 @@
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv


# TODO (CRITICAL): Layer-wise attention scaling is broken for several architectures (see the fix applied in gpt_bigcode_wrapped_scaled_dot_product).


def raise_on_head_mask(head_mask: Optional[torch.Tensor]):
if head_mask is not None:
raise ValueError(
@@ -663,3 +666,136 @@ def llama_forward(
attn_weights = None

return attn_output, attn_weights, past_key_value


def gpt_bigcode_wrapped_scaled_dot_product(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
):
raise_on_head_mask(head_mask)

# TODO: remove once PyTorch 2.1 is released with the scale argument to SDPA
if self.scale_attn_weights:
softmax_dtype = torch.float32 if self.attention_softmax_in_fp32 else query.dtype
if self.scale_attention_softmax_in_fp32 and query.dtype != softmax_dtype:
query = query / (self.layer_idx + 1)
else:
query = query / self.head_dim**0.5

# MQA models: (batch_size, query_length, num_heads * head_dim)
# MHA models: (batch_size, num_heads, query_length, head_dim)
query_shape = query.shape
batch_size = query_shape[0]

if self.multi_query:
query_length = query_shape[1]

# NOTE: There may be a better way to do this.
query = query.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
else:
raise NotImplementedError(
"BetterTransformer integration with GPT BigCode without Multi-Query Attention (MQA) has not been implemented. Please open an issue or PR at https://github.com/huggingface/optimum."
)

dropout_p = self.dropout_prob_attn if self.training else 0.0

# These unsqueeze calls could not be avoided; SDPA complains otherwise.
key = key.unsqueeze(1)
value = value.unsqueeze(1)

if batch_size == 1 or self.training:
if query_length > 1:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
)
else:
sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
)
else:
if attention_mask is not None:
mask_value = self._get_mask_value(query.device, query.dtype)

# gpt_bigcode uses a causal mask of shape [batch_size, target_length, 1, source_length],
# which differs from all other architectures and is not compatible with SDPA.
# We could avoid this transpose by overriding the forward from GPTBigCodeModel,
# but it is probably not worth it.
attention_mask = attention_mask.transpose(1, 2)
attention_mask = torch.where(attention_mask, 0.0, mask_value)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
)

if self.multi_query:
# (batch_size, num_heads, seq_len, head_dim) --> (batch_size, seq_len, num_heads, head_dim)
sdpa_result = sdpa_result.transpose(1, 2)

# The reshape is somewhat expensive here (it performs a memory copy),
# but there does not seem to be a way to avoid it.
# (batch_size, seq_len, num_heads, head_dim) --> (batch_size, seq_len, num_heads * head_dim)
sdpa_result = sdpa_result.reshape(query_shape)

return sdpa_result


def gpt_bigcode_forward(
self,
hidden_states: torch.Tensor,
layer_past: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if output_attentions is True:
raise ValueError("output_attentions=True can not be supported with BetterTransformer.")

if encoder_hidden_states is not None:
if not hasattr(self, "q_attn") or not self.is_cross_attention:
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
"Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
)

query = self.q_attn(hidden_states)
key_value = self.c_attn(encoder_hidden_states)
attention_mask = encoder_attention_mask
elif self.multi_query:
query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
else:
# Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
# i.e., the memory layout is not the same as GPT2.
# This makes the concatenation with past_key_value more efficient.
query, key_value = (
self.c_attn(hidden_states)
.view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
.transpose(1, 2)
.split((self.head_dim, 2 * self.head_dim), dim=3)
)

if layer_past is not None:
key_value = torch.cat((layer_past, key_value), dim=-2)
present = key_value if use_cache else None

key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)

# Unlike the transformers implementation, there is no need to transpose the key here,
# since SDPA expects seq_length to be at index -2.
attn_output = self._attn(query, key, value, attention_mask, head_mask)

if not self.multi_query:
attn_output = attn_output.transpose(1, 2).reshape(hidden_states.shape)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)

outputs = (attn_output, present)

return outputs
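
A minimal, self-contained sketch of the broadcasting behaviour that gpt_bigcode_wrapped_scaled_dot_product relies on: with multi-query attention there is a single key/value head, which is given a size-1 head dimension and broadcasts against the per-head queries inside SDPA. All shapes below are illustrative assumptions, not taken from a real checkpoint, and a PyTorch version whose SDPA accepts the size-1 head dimension (as the code above assumes) is required.

import torch

# Illustrative shapes only (assumptions, not from a real checkpoint).
batch, num_heads, q_len, kv_len, head_dim = 2, 4, 5, 5, 8
query = torch.randn(batch, num_heads, q_len, head_dim)
key = torch.randn(batch, kv_len, head_dim).unsqueeze(1)    # (batch, 1, kv_len, head_dim)
value = torch.randn(batch, kv_len, head_dim).unsqueeze(1)  # (batch, 1, kv_len, head_dim)

# The single key/value head broadcasts against the num_heads query heads inside SDPA...
out_broadcast = torch.nn.functional.scaled_dot_product_attention(query, key, value)

# ...and matches explicitly repeating the key/value head for every query head.
out_expanded = torch.nn.functional.scaled_dot_product_attention(
    query, key.expand(-1, num_heads, -1, -1), value.expand(-1, num_heads, -1, -1)
)
print(torch.allclose(out_broadcast, out_expanded, atol=1e-6))  # expected: True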
26 changes: 26 additions & 0 deletions optimum/bettertransformer/models/decoder_models.py
Expand Up @@ -20,6 +20,7 @@
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotAttention
from transformers.models.codegen.modeling_codegen import CodeGenAttention
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoSelfAttention
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
from transformers.models.gptj.modeling_gptj import GPTJAttention
@@ -35,6 +36,8 @@
bart_forward,
codegen_wrapped_scaled_dot_product,
gpt2_wrapped_scaled_dot_product,
gpt_bigcode_forward,
gpt_bigcode_wrapped_scaled_dot_product,
gpt_neo_wrapped_scaled_dot_product,
llama_forward,
opt_forward,
@@ -360,3 +363,26 @@ def __init__(self, layer: "nn.Module", config: "PretrainedConfig"):

def forward(self, *args, **kwargs):
return llama_forward(self, *args, **kwargs)


class GPTBigCodeAttentionLayerBetterTransformer(BetterTransformerBaseLayer, GPTBigCodeAttention):
_attn = gpt_bigcode_wrapped_scaled_dot_product

def __init__(self, layer: nn.Module, config: "PretrainedConfig"):
with torch.device("meta"):
super(BetterTransformerBaseLayer, self).__init__(config)

self.module_mapping = None
submodules = ["c_attn", "c_proj"]

if layer.is_cross_attention:
submodules.append("q_attn")

for attr in submodules:
setattr(self, attr, getattr(layer, attr))

self.original_layers_mapping = {submodule: submodule for submodule in submodules}
self.dropout_prob_attn = config.attn_pdrop

def forward(self, *args, **kwargs):
return gpt_bigcode_forward(self, *args, **kwargs)
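
A short, hypothetical sketch (plain nn.Linear modules, not the optimum classes) of the torch.device("meta") pattern used in the constructor above: parameters created under the meta device carry only shape and dtype and allocate no storage, so building the replacement layer is cheap, and the trained submodules are then attached by reference.

import torch
import torch.nn as nn

with torch.device("meta"):
    replacement = nn.Linear(64, 64)   # meta parameters: no real memory is allocated
print(replacement.weight.is_meta)     # True

trained = nn.Linear(64, 64)           # stands in for the original, trained submodule
replacement.weight = trained.weight   # attach the real parameters by reference
replacement.bias = trained.bias
print(replacement.weight.is_meta)     # False: now backed by real storage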
11 changes: 1 addition & 10 deletions optimum/exporters/onnx/model_configs.py
@@ -31,6 +31,7 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
NormalizedConfig,
NormalizedEncoderDecoderConfig,
NormalizedSeq2SeqConfig,
@@ -269,16 +270,6 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
}


class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_value_shape = (
self.batch_size,
self.sequence_length,
self.hidden_size // self.num_attention_heads * 2,
)
return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)]


class GPTBigCodeOnnxConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (
GPTBigCodeDummyPastKeyValuesGenerator,
1 change: 1 addition & 0 deletions optimum/utils/__init__.py
@@ -54,6 +54,7 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
)
from .modeling_utils import recurse_getattr, recurse_setattr
from .normalized_config import (
10 changes: 10 additions & 0 deletions optimum/utils/input_generators.py
@@ -729,3 +729,13 @@ def __init__(
def generate(self, input_name: str, framework: str = "pt"):
shape = [self.batch_size, self.max_patches, self.flattened_patch_size]
return self.random_float_tensor(shape, framework=framework)


class GPTBigCodeDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def generate(self, input_name: str, framework: str = "pt"):
past_key_value_shape = (
self.batch_size,
self.sequence_length,
self.hidden_size // self.num_attention_heads * 2,
)
return [self.random_float_tensor(past_key_value_shape, framework=framework) for _ in range(self.num_layers)]
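
For context on the shape above: with multi-query attention, gpt_bigcode caches key and value fused along the last dimension, one tensor per layer of shape (batch_size, sequence_length, 2 * head_dim). A small sketch with made-up sizes, splitting the cache back into key and value exactly as the attention code in this commit does:

import torch

# Made-up sizes for illustration only.
batch_size, sequence_length, num_layers = 2, 7, 3
hidden_size, num_attention_heads = 32, 4
head_dim = hidden_size // num_attention_heads

# One fused key/value tensor per layer, matching the generator above.
past_key_values = [
    torch.rand(batch_size, sequence_length, 2 * head_dim) for _ in range(num_layers)
]

# The attention splits the cache back into key and value along the last dimension.
key, value = past_key_values[0].split((head_dim, head_dim), dim=-1)
print(key.shape, value.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7, 8])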
12 changes: 9 additions & 3 deletions tests/bettertransformer/test_decoder.py
@@ -22,12 +22,12 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer
from optimum.utils import DummyPastKeyValuesGenerator, NormalizedConfigManager
from optimum.utils import DummyPastKeyValuesGenerator, GPTBigCodeDummyPastKeyValuesGenerator, NormalizedConfigManager
from optimum.utils.testing_utils import grid_parameters, require_accelerate, require_torch_gpu


class BetterTransformersDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
SUPPORTED_ARCH = ["codegen", "gpt2", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"]
SUPPORTED_ARCH = ["codegen", "gpt2", "gpt_bigcode", "gptj", "gpt_neo", "gpt_neox", "llama", "opt"]

FULL_GRID = {
"model_type": SUPPORTED_ARCH,
@@ -123,7 +123,13 @@ def test_logits_with_cache(self, test_name: str, model_type: str, batch_size: in
model = AutoModelForCausalLM.from_pretrained(model_id)

normalized_config = NormalizedConfigManager.get_normalized_config_class(model.config.model_type)(model.config)
pkv_generator = DummyPastKeyValuesGenerator(

if model_type == "gpt_bigcode":
pkv_generator_class = GPTBigCodeDummyPastKeyValuesGenerator
else:
pkv_generator_class = DummyPastKeyValuesGenerator

pkv_generator = pkv_generator_class(
task="", normalized_config=normalized_config, batch_size=batch_size, sequence_length=seq_length
)
past_key_values = pkv_generator.generate(input_name="past_key_values")
2 changes: 2 additions & 0 deletions tests/bettertransformer/testing_utils.py
@@ -44,6 +44,8 @@
"ernie": "hf-internal-testing/tiny-random-ErnieModel",
"fsmt": "hf-internal-testing/tiny-random-FSMTModel",
"gpt2": "hf-internal-testing/tiny-random-GPT2Model",
# NOTE: this tiny model does not use attention_softmax_in_fp32=True (unlike e.g. starcoder)
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM",
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
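
Finally, a minimal usage sketch (not part of the diff) exercising the new support end to end with the tiny test checkpoint listed above; the prompt and generation settings are illustrative only.

from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.bettertransformer import BetterTransformer

model_id = "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Swap GPTBigCodeAttention for the SDPA-based BetterTransformer layer.
model = BetterTransformer.transform(model)

inputs = tokenizer("def hello_world():", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(outputs[0]))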
