diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index f531754b3..05b38c1b7 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -19,5 +19,8 @@ tiktoken
 torch<=2.1.2
 torchvision<=0.16.2
 # Minimum 4.36.0 to support `Cache` data structure used by KV Cache
-transformers>=4.36.0
+# Registering a causal mask in `LlamaModel` is not friendly for very large
+# `max_position_embeddings`. Refer to
+# https://github.com/huggingface/transformers/blob/v4.38.0/src/transformers/models/llama/modeling_llama.py#L921-L923
+transformers>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2
 transformers_stream_generator
diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index ab104a7dc..7da62ac0e 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -305,12 +305,12 @@ def dispatch_modules(model, use_varlen_attn=False):
         dispatch_internlm2_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm2_rmsnorm_forward(model)
-        # replace_internlm2_rote(model)
+        replace_internlm2_rote(model)
     elif 'internlm' in model_name:
         dispatch_internlm_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
             dispatch_internlm_rmsnorm_forward(model)
-        # replace_internlm_rote(model)
+        replace_internlm_rote(model)
     elif 'llama' in model_name:
         dispatch_llama_attn_forward(model, use_varlen_attn)
         if USE_TRITON_KERNEL:
diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py
index be9290451..c9febf34f 100644
--- a/xtuner/model/modules/dispatch/llama.py
+++ b/xtuner/model/modules/dispatch/llama.py
@@ -234,16 +234,11 @@ def llama_attn_forward_legacy(
     use_cache: bool = False,
     **kwargs,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-    # LlamaFlashAttention2 attention does not support output_attentions
+    # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331  # noqa:E501
     if 'padding_mask' in kwargs:
         warnings.warn(
-            'Passing `padding_mask` is deprecated and will be removed in v4.37'
-            ' Please make sure use `attention_mask` instead.`')
-
-        # overwrite attention_mask with padding_mask
-        attention_mask = kwargs.pop('padding_mask')
-
-    output_attentions = False
+            'Passing `padding_mask` is deprecated and will be removed in '
+            'v4.37. Please make sure use `attention_mask` instead.`')
 
     bsz, q_len, _ = hidden_states.size()
 
@@ -251,9 +246,6 @@ def llama_attn_forward_legacy(
     key_states = self.k_proj(hidden_states)
     value_states = self.v_proj(hidden_states)
 
-    # Flash attention requires the input to have the shape
-    # batch_size x seq_length x head_dim x hidden_dim
-    # therefore we just need to keep the original shape
     query_states = query_states.view(bsz, q_len, self.num_heads,
                                      self.head_dim).transpose(1, 2)
     key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -263,6 +255,13 @@ def llama_attn_forward_legacy(
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
+        if self.layer_idx is None:
+            raise ValueError(
+                'The cache structure has changed since version v4.36. '
+                f'If you are using {self.__class__.__name__} '
+                'for auto-regressive decoding with k/v caching, '
+                'please make sure to initialize the attention class '
+                'with a layer index.')
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
                                                        self.layer_idx)
     assert position_ids is not None
@@ -282,10 +281,6 @@ def llama_attn_forward_legacy(
     key_states = repeat_kv(key_states, self.num_key_value_groups)
     value_states = repeat_kv(value_states, self.num_key_value_groups)
 
-    # repeat kv for sequence parallel
-    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
-    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
     assert SUPPORT_FLASH2
     query_states = query_states.transpose(1, 2)
     key_states = key_states.transpose(1, 2)
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index ac86a8cca..a2a88b4cc 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -206,13 +206,6 @@ def _build_from_cfg_or_module(self, cfg_or_mod):
             return cfg_or_mod
         elif isinstance(cfg_or_mod, dict):
             traverse_dict(cfg_or_mod)
-            if SUPPORT_FLASH2:
-                cfg_or_mod.torch_dtype = torch.bfloat16 \
-                    if torch.cuda.is_bf16_supported() else torch.float16
-                cfg_or_mod.attn_implementation = 'flash_attention_2'
-            if max_position_embeddings is not None:
-                cfg_or_mod = self._prepare_for_long_context_training(
-                    cfg_or_mod, max_position_embeddings)
             return BUILDER.build(cfg_or_mod)
         else:
             raise NotImplementedError
@@ -265,4 +258,4 @@ def __getattr__(self, name: str):
         try:
             return super().__getattr__(name)
         except AttributeError:
-            return getattr(self.llm, name)
+            return getattr(self.llm, name)
\ No newline at end of file
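The requirements hunk above allows any transformers release from 4.36.0 onward except 4.38.0, 4.38.1, and 4.38.2, which register the full causal mask in `LlamaModel`. Below is a minimal sketch, not part of the patch, of how that constraint could be checked at runtime; it assumes the `packaging` library is installed, and the `TRANSFORMERS_SPEC` and `check_transformers_version` names are illustrative only.

```python
# Illustrative only: mirrors the specifier added to requirements/runtime.txt.
from packaging.specifiers import SpecifierSet
from packaging.version import Version

import transformers

# Same constraint as the requirements line: >=4.36.0, skipping 4.38.0-4.38.2.
TRANSFORMERS_SPEC = SpecifierSet('>=4.36.0,!=4.38.0,!=4.38.1,!=4.38.2')


def check_transformers_version() -> None:
    """Raise if the installed transformers release violates the pin."""
    installed = Version(transformers.__version__)
    if installed not in TRANSFORMERS_SPEC:
        raise RuntimeError(
            f'transformers=={installed} is not supported; install a version '
            f'matching "{TRANSFORMERS_SPEC}".')


if __name__ == '__main__':
    check_transformers_version()
```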