From 61b94511b74a97b5b2c022ed18467a16450464e1 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Fri, 29 Mar 2024 18:22:16 +0800
Subject: [PATCH 1/2] [Fix] dispatch internlm rote (#530)

dispatch internlm rote
---
 xtuner/model/modules/dispatch/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index ab104a7dc..7da62ac0e 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -305,12 +305,12 @@ def dispatch_modules(model, use_varlen_attn=False):
         dispatch_internlm2_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm2_rmsnorm_forward(model)
-        # replace_internlm2_rote(model)
+        replace_internlm2_rote(model)
     elif 'internlm' in model_name:
         dispatch_internlm_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm_rmsnorm_forward(model)
-        # replace_internlm_rote(model)
+        replace_internlm_rote(model)
     elif 'llama' in model_name:
         dispatch_llama_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:

From 0b5708c49948355ae1406fb99ec6764417aee3fa Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Fri, 29 Mar 2024 18:22:28 +0800
Subject: [PATCH 2/2] Limit transformers != 4.38 (#531)

limit transformers != 4.38
---
 requirements/runtime.txt | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index f531754b3..05b38c1b7 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -19,5 +19,8 @@ tiktoken
 torch<=2.1.2
 torchvision<=0.16.2
 # Minimum 4.36.0 to support `Cache` data structure used by KV Cache
-transformers>=4.36.0
+# Registering a causal mask in `LlamaModel` is not friendly for very large
+# `max_position_embeddings`. Refer to
+# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
+transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2
 transformers_stream_generator