From 2b17497dc4ff7a67b94f5423428dec559f72b27e Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:19:23 -0800 Subject: [PATCH] Bump transformers to 4.38.2 (#1018) --- llmfoundry/models/layers/attention.py | 33 +++++++++++++++++++-------- setup.py | 2 +- tests/models/test_model.py | 2 -- 3 files changed, 25 insertions(+), 12 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 281f41753a..8fa8e6bc66 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -619,19 +619,34 @@ def forward( value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim) elif rotary_emb_w_meta_info['impl'] == 'hf': - (cos, sin) = rotary_emb(value, seq_len) - if is_transformers_version_gte('4.36'): - query, key = apply_rotary_pos_emb(query, - key, - cos, - sin, - offset_info, + if is_transformers_version_gte('4.38'): + (cos, sin) = rotary_emb(x=value, + position_ids=offset_info, + seq_len=None) + else: + (cos, sin) = rotary_emb(x=value, seq_len=seq_len) + if is_transformers_version_gte('4.38'): + query, key = apply_rotary_pos_emb(q=query, + k=key, + cos=cos, + sin=sin, + position_ids=None, + unsqueeze_dim=2) + elif is_transformers_version_gte('4.36'): + query, key = apply_rotary_pos_emb(q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info, unsqueeze_dim=2) else: query = query.transpose(1, 2) key = key.transpose(1, 2) - query, key = apply_rotary_pos_emb(query, key, cos, sin, - offset_info) + query, key = apply_rotary_pos_emb(q=query, + k=key, + cos=cos, + sin=sin, + position_ids=offset_info) query = query.transpose(1, 2) key = key.transpose(1, 2) diff --git a/setup.py b/setup.py index d11cf02ca2..7534d24503 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ 'mosaicml[libcloud,wandb,oci,gcs]>=0.20.1,<0.21', 'mlflow>=2.10,<3', 'accelerate>=0.25,<0.26', # for HF inference `device_map` - 'transformers>=4.37,<4.38', + 'transformers>=4.38.2,<4.39', 'mosaicml-streaming>=0.7.4,<0.8', 'torch>=2.2.1,<2.3', 'datasets>=2.16,<2.17', diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 5a2d1cd212..4765c4003b 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -1172,12 +1172,10 @@ def test_generate(attention_impl: str, precision: str, pos_emb_config: dict, max_new_tokens=5, use_cache=True) _ = mpt.generate(input_ids=None, - inputs_embeds=None, max_new_tokens=5, use_cache=False, bos_token_id=50256) _ = mpt.generate(input_ids=None, - inputs_embeds=None, max_new_tokens=5, use_cache=True, bos_token_id=50256)