Support Llama 2 in BetterTransformer. (#1235)
* Support Llama 2 in BetterTransformer.

* Add a test for Llama with GQA

* Fix formatting issues

* Fix style

---------

Co-authored-by: Félix Marty <[email protected]>
noamwies and fxmarty authored Aug 1, 2023
1 parent c7bd911 commit 637805d
Showing 2 changed files with 44 additions and 6 deletions.
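
For context, this commit lets BetterTransformer handle Llama 2 checkpoints, including those that use grouped-query attention (GQA), where the number of key/value heads is smaller than the number of query heads. A minimal usage sketch, assuming optimum with this change and transformers >= 4.31 are installed; the model id below is a placeholder for any Llama 2 checkpoint:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.bettertransformer import BetterTransformer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder: any Llama 2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

# Swap the eager attention for torch.nn.functional.scaled_dot_product_attention.
model = BetterTransformer.transform(model)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))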
49 changes: 43 additions & 6 deletions optimum/bettertransformer/models/attention.py
@@ -17,7 +17,7 @@
 import torch
 from transformers.models.llama.modeling_llama import _expand_mask as _llama_expand_mask
 from transformers.models.llama.modeling_llama import _make_causal_mask as _llama_make_causal_mask
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv


 def raise_on_head_mask(head_mask: Optional[torch.Tensor]):
@@ -577,9 +577,35 @@ def llama_forward(

     bsz, q_len, _ = hidden_states.size()

-    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    if self.config.pretraining_tp > 1:
+        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
+        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0)
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [
+            torch.nn.functional.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)
+        ]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [
+            torch.nn.functional.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)
+        ]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [
+            torch.nn.functional.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)
+        ]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -595,6 +621,10 @@ def llama_forward(

     past_key_value = (key_states, value_states) if use_cache else None

+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = repeat_kv(key_states, self.num_key_value_groups)
+    value_states = repeat_kv(value_states, self.num_key_value_groups)
+
     if bsz == 1 or self.training:
         # BEWARE: at this stage, attention_mask is not the same as in transformers llama
         if query_states.shape[2] > 1:
@@ -617,10 +647,17 @@ def llama_forward(
             query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

-    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.transpose(1, 2).contiguous()
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

-    attn_output = self.o_proj(attn_output)
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
+        attn_output = sum(
+            [torch.nn.functional.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]
+        )
+    else:
+        attn_output = self.o_proj(attn_output)

     if not output_attentions:
         attn_weights = None
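
The GQA-specific change above is the repeat_kv call: when num_key_value_heads is smaller than num_heads, the key/value heads are tiled so that scaled_dot_product_attention sees one key/value head per query head. A rough sketch of the equivalent tensor operation (illustrative only, not the exact transformers helper), assuming tensors shaped [batch, num_kv_heads, seq_len, head_dim]:

import torch

def repeat_kv_sketch(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # kv: [batch, num_kv_heads, seq_len, head_dim]; n_rep = num_heads // num_kv_heads
    if n_rep == 1:
        return kv
    batch, num_kv_heads, seq_len, head_dim = kv.shape
    # Add a repeat axis next to the head axis, expand it, then fold it back into the head axis.
    kv = kv[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return kv.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# Example: 2 key/value heads repeated 4x line up with 8 query heads.
kv = torch.randn(1, 2, 5, 16)
assert repeat_kv_sketch(kv, 4).shape == (1, 8, 5, 16)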
1 change: 1 addition & 0 deletions tests/bettertransformer/testing_utils.py
@@ -50,6 +50,7 @@
     "hubert": "ybelkada/hubert-tiny-random",
     "layoutlm": "hf-internal-testing/tiny-random-LayoutLMModel",
     "llama": "fxmarty/tiny-llama-fast-tokenizer",
+    "llama-gqa": "noamwies/llama-test-gqa-with-better-transformer",
     "m2m_100": "hf-internal-testing/tiny-random-nllb",
     "marian": "fxmarty/tiny-marian",  # the other tiny ones have a too small max_position_embeddings
     "markuplm": "hf-internal-testing/tiny-random-MarkupLMModel",
