diff --git a/README.md b/README.md index 0bfa3a02d..352e7215e 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,7 @@ 🔍 Explore our models on [![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤗%20Huggingface)](https://huggingface.co/xtuner) [![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤖%20ModelScope)](https://www.modelscope.cn/organization/xtuner) +[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🧰%20OpenXLab)](https://openxlab.org.cn/usercenter/xtuner) English | [įŽ€äŊ“中文](README_zh-CN.md) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 5ccada772..f531754b3 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -18,8 +18,6 @@ tiktoken # limit pytorch version <= 2.1.2 as there may be some bugs in triton 2.2 torch<=2.1.2 torchvision<=0.16.2 -# Minimum 4.34.0 to support added_tokens_decoder of tokenizer -# Exclude 4.34.1, 4.35.0, 4.35.1, 4.35.2 to avoid BC-break, -# see https://github.com/huggingface/transformers/pull/27020, https://github.com/huggingface/transformers/pull/27073 -transformers>=4.34.0,!=4.34.1,!=4.35.0,!=4.35.1,!=4.35.2 +# Minimum 4.36.0 to support `Cache` data structure used by KV Cache +transformers>=4.36.0 transformers_stream_generator diff --git a/xtuner/engine/hooks/evaluate_chat_hook.py b/xtuner/engine/hooks/evaluate_chat_hook.py index efa1bc69f..8e6a86822 100644 --- a/xtuner/engine/hooks/evaluate_chat_hook.py +++ b/xtuner/engine/hooks/evaluate_chat_hook.py @@ -29,7 +29,8 @@ def __init__(self, every_n_iters=None, max_new_tokens=600, stop_word=None, - stop_words=[]): + stop_words=[], + generation_kwargs={}): self.evaluation_inputs = evaluation_inputs if isinstance(self.evaluation_inputs, str): self.evaluation_inputs = [self.evaluation_inputs] @@ -69,8 +70,9 @@ def __init__(self, if image_processor is not None: self.image_processor = BUILDER.build(image_processor) self.stop_criteria = StoppingCriteriaList() + # default generation config - self.gen_config = GenerationConfig( + default_generation_kwargs = dict( max_new_tokens=max_new_tokens, do_sample=True, temperature=0.1, @@ -79,8 +81,10 @@ def __init__(self, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else - self.tokenizer.eos_token_id, - ) + self.tokenizer.eos_token_id) + default_generation_kwargs.update(generation_kwargs) + self.gen_config = GenerationConfig(**default_generation_kwargs) + self.stop_criteria = StoppingCriteriaList() for word in stop_words: self.stop_criteria.append( diff --git a/xtuner/engine/hooks/throughput_hook.py b/xtuner/engine/hooks/throughput_hook.py index cf31414c0..a07e216fe 100644 --- a/xtuner/engine/hooks/throughput_hook.py +++ b/xtuner/engine/hooks/throughput_hook.py @@ -1,11 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved.
+import logging from typing import Optional, Union import torch +from mmengine import print_log from mmengine.hooks import Hook from mmengine.model.wrappers import is_model_wrapper from torch.utils._pytree import tree_flatten +from xtuner.parallel.sequence import get_sequence_parallel_world_size + DATA_BATCH = Optional[Union[dict, tuple, list]] @@ -20,12 +24,39 @@ def __init__(self, hidden_size=None, num_layers=None, vocab_size=None, - mlp_ratio=None): + mlp_ratio=None, + is_casual=None): self.use_activation_checkpointing = use_activation_checkpointing self.hidden_size = hidden_size self.num_layers = num_layers self.vocab_size = vocab_size self.mlp_ratio = mlp_ratio + self.is_casual = is_casual + + @staticmethod + def _guess_is_casual_attn(model): + for module in model.modules(): + if hasattr(module, 'is_causal'): + return module.is_causal + print_log( + 'It\'s impossible to speculate whether casual attention was used, ' + 'and FLOPs will be calculated as `casual = True`.', 'current') + return True + + @staticmethod + def _get_batch_size_and_sequence_len(data_batch): + data_list, _ = tree_flatten(data_batch) + for data in data_list: + if isinstance(data, torch.Tensor): + return data.size(0), data.size(1) + raise RuntimeError('No tensor found in the batch') + + @staticmethod + def _guess_use_activation_checkpointing(model): + for module in model.modules(): + if hasattr(module, 'gradient_checkpointing'): + return module.gradient_checkpointing + return False def before_run(self, runner) -> None: if is_model_wrapper(runner.model): @@ -41,20 +72,18 @@ def before_run(self, runner) -> None: self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size / model.config.hidden_size) self.mlp_ratio *= 1.5 # has gate_proj - return + self.is_casual = self.is_casual if self.is_casual is not None \ + else self._guess_is_casual_attn(model) - def _get_batch_size_and_sequence_len(self, data_batch): - data_list, _ = tree_flatten(data_batch) - for data in data_list: - if isinstance(data, torch.Tensor): - return data.size(0), data.size(1) - raise RuntimeError('No tensor found in the batch') + use_varlen_attn = getattr(model, 'use_varlen_attn', False) + if use_varlen_attn: + print_log( + 'Using variable-length Flash Attention causes an inflation' + ' in the FLOPs calculation.', + 'current', + level=logging.WARNING) - def _guess_use_activation_checkpointing(self, model): - for module in model.modules(): - if hasattr(module, 'gradient_checkpointing'): - return module.gradient_checkpointing - return False + return def after_train_iter(self, runner, @@ -66,17 +95,50 @@ def after_train_iter(self, batch_size, sequence_len = self._get_batch_size_and_sequence_len( data_batch) + sequence_parallel_size = get_sequence_parallel_world_size() message_hub = runner.message_hub iter_time = message_hub.get_scalar('train/time').current() - flops_per_iteration = ( - (3 + int(self.use_activation_checkpointing)) * - ((8 + self.mlp_ratio * 4) * batch_size * sequence_len * - self.hidden_size**2 + - 4 * batch_size * sequence_len**2 * self.hidden_size) - ) * self.num_layers + \ - 6 * batch_size * sequence_len * self.hidden_size * self.vocab_size + # We consider a language model with 𝑙 transformer layers, + # hidden size h, sequence length s, vocabulary size V, and + # training batch size B. + # A $A_{mxk}$ x $X_{kxn}$ matrix multiplication requires 2m × k × n FLOPs + # (factor of 2 needed to account for multiplies and adds).
+ + # Attention Layer: + # qkv_proj + o_proj: 8B * s * h^2 + # attn: 2B * s^2 * h (casual=False) and 2B * s^2 * h / 2 (casual=True) + + # MLP Layer: + # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio + # (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5 + # (has gate_proj)) + + # The backward pass requires double the number of FLOPs since we + # need to calculate the gradients with respect to both input and + # weight tensors. In addition, we are using activation recomputation, + # which requires an additional forward pass before the backward pass. + + # While sequence parallel will affect the FLOPs calculation in attn. + # Suppose the sequence length in one GPU is s and the sequence + # parallel world size is `sp_size`, which means the total + # sequence length in the attention calculation is + # `s * sp_size` and the number of attention heads decrease to + # `num_heads / sp_size`. Hence, the FLOPs in attn calculation is: + # 2B * (s * sp_size)^2 * (h / sp_size) (casual=False) and + # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (casual=True) + + flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2 + flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \ + sequence_parallel_size / (int(self.is_casual) + 1) + flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \ + self.hidden_size**2 + flops_wo_head = (3 + int(self.use_activation_checkpointing)) * ( + flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers + flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \ + self.vocab_size + flops_per_iteration = flops_wo_head + flops_head avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12) tokens_per_sec_per_gpu = batch_size * sequence_len / ( diff --git a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py index f2b23d3fe..f7a95a09c 100644 --- a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py +++ b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Mapping, Optional, Sequence, Union +from typing import Optional, Union -import torch import torch.distributed as dist from mmengine import MessageHub from mmengine.hooks import Hook @@ -11,20 +10,6 @@ class VarlenAttnArgsToMessageHubHook(Hook): - args = ('cumulative_len', 'max_seqlen') - - def cast_data(self, data): - if isinstance(data, Mapping): - return {key: self.cast_data(data[key]) for key in data} - elif isinstance(data, (str, bytes)) or data is None: - return data - elif isinstance(data, Sequence): - return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable - elif isinstance(data, torch.Tensor): - return data.cuda() - else: - return data - def before_train_iter(self, runner, batch_idx: int, @@ -35,10 +20,13 @@ def before_train_iter(self, assert 'data' in data_batch.keys() data = data_batch['data'] - for arg in self.args: - assert arg in data - message_hub.update_info(f'{arg}_rank_{rank}', - self.cast_data(data.pop(arg))) + cumulative_len = data.pop('cumulative_len') + assert len(cumulative_len) == 1 + cumulative_len = cumulative_len[0].cuda() + message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len) + + max_seqlen = data.pop('max_seqlen') + message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen) def after_train_iter(self, runner, @@ -47,6 +35,5 @@ def after_train_iter(self, outputs: Optional[dict] = None) -> None: rank = dist.get_rank() message_hub = MessageHub.get_instance('varlen_attn_args') - - for arg in self.args: - message_hub.update_info(f'{arg}_rank_{rank}', None) + message_hub.update_info(f'cumulative_len_rank_{rank}', None) + message_hub.update_info(f'max_seqlen_rank_{rank}', None) diff --git a/xtuner/model/llava.py b/xtuner/model/llava.py index d7e39a804..19b427a75 100644 --- a/xtuner/model/llava.py +++ b/xtuner/model/llava.py @@ -1,13 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math from collections import OrderedDict +import torch import torch.nn as nn from mmengine.config import Config, ConfigDict from mmengine.model import BaseModel from peft import get_peft_model, prepare_model_for_kbit_training +from transformers import AutoConfig from xtuner.registry import BUILDER from .modules import ProjectorConfig, ProjectorModel, dispatch_modules +from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 from .utils import (LoadWoInit, find_all_linear_names, get_peft_model_state_dict, guess_load_checkpoint, make_inputs_require_grad, @@ -26,11 +30,15 @@ def __init__(self, projector_depth=2, llm_lora=None, visual_encoder_lora=None, - use_activation_checkpointing=True): + use_activation_checkpointing=True, + max_position_embeddings=None): super().__init__() self.freeze_llm = freeze_llm self.freeze_visual_encoder = freeze_visual_encoder with LoadWoInit(): + if isinstance(llm, dict): + llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings) + self.llm = self._build_from_cfg_or_module(llm) self.visual_encoder = self._build_from_cfg_or_module( visual_encoder) @@ -157,6 +165,62 @@ def state_dict(self, *args, **kwargs): for k, v in state_dict.items() if 'projector.' 
in k}) return to_return + @staticmethod + def _prepare_for_long_context_training(cfg, llm_cfg, + max_position_embeddings): + + orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None) + if orig_rope_scaling is None: + orig_rope_scaling = {'factor': 1} + + orig_rope_scaling_factor = orig_rope_scaling[ + 'factor'] if 'factor' in orig_rope_scaling.keys() else 1 + orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None) + if orig_ctx_len: + orig_ctx_len *= orig_rope_scaling_factor + if max_position_embeddings > orig_ctx_len: + scaling_factor = float( + math.ceil(max_position_embeddings / orig_ctx_len)) + llm_cfg.rope_scaling = { + 'type': 'linear', + 'factor': scaling_factor + } + + # hardcode for internlm2 + llm_cfg.attn_implementation = 'flash_attention_2' + cfg.config = llm_cfg + + return cfg, llm_cfg + + @staticmethod + def _prepare_for_flash_attn(cfg, llm_cfg): + cls_name = type(llm_cfg).__name__ + SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig', + 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + + if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2: + cfg.torch_dtype = torch.bfloat16 \ + if torch.cuda.is_bf16_supported() else torch.float16 + cfg.attn_implementation = 'flash_attention_2' + elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN: + cfg.attn_implementation = 'sdpa' + + return cfg, llm_cfg + + def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None): + pretrained_model_name_or_path = cfg.pretrained_model_name_or_path + llm_cfg = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True) + cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg) + if max_position_embeddings is not None: + cfg, llm_cfg = self._prepare_for_long_context_training( + cfg, llm_cfg, max_position_embeddings) + return cfg + def _build_from_cfg_or_module(self, cfg_or_mod): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index 6fbe37fb6..ab104a7dc 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -13,7 +13,7 @@ from .yi import yi_attn_forward IS_LOW_VERSION_TRANSFORMERS = digit_version( - transformers.__version__) < digit_version('4.36') + transformers.__version__) < digit_version('4.38') SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.0.0') SUPPORT_FLASH2 = False @@ -48,7 +48,7 @@ def dispatch_llama_attn_forward(model, use_varlen_attn): if use_varlen_attn: assert SUPPORT_FLASH2 and SUPPORT_TRITON, \ 'flash_attn and triton is required if you want to use varlen_attn.' 
- elif not SUPPORT_FLASH: + elif not SUPPORT_FLASH2: return from .llama import (llama_attn_forward, llama_attn_forward_legacy, @@ -57,8 +57,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn): print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING) for module in model.modules(): - if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2', - 'LlamaSdpaAttention'): + # Do not need to dispatch if + # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is + # required when using sequence parallel + if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'): if use_varlen_attn: print_log('dispatch llama varlen attn forward', 'current') if IS_LOW_VERSION_TRANSFORMERS: diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index 8aa664ad5..93a43229e 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -156,24 +156,11 @@ def flash_attn_w_mask( return attn_output -def flash_attn1_pytorch(query_states, key_states, value_states, *args, - **kwargs): - # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query_states, key_states, - value_states, *args, **kwargs) - attn_output = attn_output.transpose(1, 2) - return attn_output - - @sequence_parallel_wrapper def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, max_seqlen): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) - attn_output = flash_attn_varlen_func( q_unpad, k_unpad, @@ -251,12 +238,12 @@ def internlm2_attn_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # flash attn 2 need (bs, seq_len, nhead, h_dim) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if SUPPORT_FLASH2: + # flash attn 2 need (bs, seq_len, nhead, h_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + causal = self.is_causal and q_len != 1 if attention_mask is not None: @@ -276,12 +263,10 @@ def internlm2_attn_forward( training=self.training) else: # use flash attention implemented by pytorch - attn_output = flash_attn1_pytorch( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - training=self.training) + # do not support sequence parallel + attn_output = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask) + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.wo(attn_output) @@ -309,7 +294,6 @@ def internlm2_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py index 27b1f33d6..be9290451 100644 --- a/xtuner/model/modules/dispatch/llama.py +++ 
b/xtuner/model/modules/dispatch/llama.py @@ -4,10 +4,10 @@ import torch import torch.distributed as dist -import torch.nn.functional as F from mmengine import MessageHub -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -from transformers.utils import logging +from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb, + repeat_kv) +from transformers.utils import is_flash_attn_greater_or_equal_2_10 from xtuner.parallel.sequence import sequence_parallel_wrapper from .triton_kernels import apply_rotary_emb @@ -30,34 +30,6 @@ class Cache: pass -logger = logging.get_logger(__name__) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """This is the equivalent of torch.repeat_interleave(x, dim=1, - repeats=n_rep). - - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to - (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, - None, :, :].expand(batch, - num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, - head_dim) - - def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)""" @@ -113,24 +85,15 @@ def flash_attn_w_mask( return attn_output -def flash_attn1_pytorch(query_states, key_states, value_states, *args, - **kwargs): - # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query_states, key_states, - value_states, *args, **kwargs) - attn_output = attn_output.transpose(1, 2) - return attn_output - - @sequence_parallel_wrapper -def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, - max_seqlen): +def varlen_flash_attn(query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_rate=0.): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) attn_output = flash_attn_varlen_func( q_unpad, k_unpad, @@ -139,7 +102,7 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, cumulative_len, max_seqlen, max_seqlen, - 0, + dropout_p=dropout_rate, return_attn_probs=False, causal=True, ) @@ -147,58 +110,29 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, return attn_output -def llama_attn_forward_legacy( +def llama_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # Modified from 
https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501 +): + # Modified from https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/llama/modeling_llama.py#L422 # noqa:E501 + output_attentions = False - if 'padding_mask' in kwargs: - warnings.warn('Passing `padding_mask` is deprecated and will be ' - 'removed in v4.37. Please make sure use ' - '`attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * # noqa: W504 - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, @@ -206,77 +140,90 @@ def llama_attn_forward_legacy( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) + cos, sin) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = getattr(self, 'past_key_value', past_key_value) - past_key_value = (key_states, value_states) if use_cache else None + if past_key_value is not None: + # sin and cos are specific to RoPE models; + # cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) - # repeat kv for sequence parallel key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # flash attn 2 need (bs, seq_len, nhead, h_dim) + assert SUPPORT_FLASH2 query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = 
value_states.transpose(1, 2) - if SUPPORT_FLASH2: - causal = self.is_causal and q_len != 1 + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently + # casted in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) - if attention_mask is not None: - attn_output = flash_attn_w_mask( - query_states, - key_states, - value_states, - attention_mask, - causal, - training=self.training) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal, - training=self.training) + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + if is_flash_attn_greater_or_equal_2_10(): + causal = self.is_causal else: - # use flash attention implemented by pytorch - attn_output = flash_attn1_pytorch( + # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm + # is bumped to 2.1. For details, please see the comment in + # LlamaFlashAttention2 __init__. + causal = self.is_causal and q_len != 1 + + if attention_mask is not None: + attn_output = flash_attn_w_mask( query_states, key_states, value_states, - attn_mask=attention_mask, + attention_mask, + causal, + dropout_rate, + training=self.training) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal, + dropout_rate, training=self.training) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) + if not output_attentions: + attn_weights = None - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. 
- return attn_output, None, past_key_value + return attn_output, attn_weights, past_key_value -def llama_attn_forward( +def llama_attn_forward_legacy( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -318,9 +265,8 @@ def llama_attn_forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - + assert position_ids is not None if self.training: - assert position_ids is not None cos, sin = self.rotary_emb( value_states, seq_len=position_ids.max() + 1) else: @@ -333,42 +279,42 @@ def llama_attn_forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention - # requires the layout [batch_size, sequence_length, num_heads, head_dim]. - # We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + + assert SUPPORT_FLASH2 query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - dropout_rate = self.attention_dropout if self.training else 0.0 - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons, therefore the input hidden states gets silently + # stability reasons therefore the input hidden states gets silently # casted in float32. Hence, we need cast them back in the correct dtype # just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized - if hasattr(self.config, '_pre_quantization_dtype'): + elif hasattr(self.config, '_pre_quantization_dtype'): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype - logger.warning_once( - f'The input hidden states seems to be silently casted in float32, ' - f'this might be related to the fact you have upcasted embedding ' - f'or layer norm layers in float32. We will cast back the input in' - f' {target_dtype}.') - query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - # flash attn - if not self._flash_attn_uses_top_left_mask: + dropout_rate = self.attention_dropout if self.training else 0.0 + + if is_flash_attn_greater_or_equal_2_10(): causal = self.is_causal else: # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm @@ -376,10 +322,6 @@ def llama_attn_forward( # LlamaFlashAttention2 __init__. 
causal = self.is_causal and q_len != 1 - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) - if attention_mask is not None: attn_output = flash_attn_w_mask( query_states, @@ -401,20 +343,21 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + # Due to the implementation of the PyTorch version of flash attention, + # even when the output_attentions flag is set to True, it is not possible + # to return the attn_weights. + return attn_output, None, past_key_value -def llama_varlen_attn_forward_legacy( +def llama_varlen_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -423,7 +366,6 @@ def llama_varlen_attn_forward_legacy( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) @@ -433,82 +375,70 @@ def llama_varlen_attn_forward_legacy( '`attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * # noqa: W504 - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) + self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) + self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin) + + past_key_value = getattr(self, 'past_key_value', 
past_key_value) - kv_seq_len = key_states.shape[-3] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + # sin and cos are specific to RoPE models; + # cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) - if is_training: - cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, - cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - cos, sin = self.rotary_emb(value_states, kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + dropout_rate = self.attention_dropout if self.training else 0.0 - past_key_value = (key_states, value_states) if use_cache else None - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently casted + # in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) assert SUPPORT_FLASH2 if is_training: - attn_output = varlen_flash_attn(query_states, key_states, value_states, - cumulative_len, max_seqlen) + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_rate=dropout_rate) else: attn_output = flash_attn_wo_mask( query_states, @@ -518,26 +448,12 @@ def llama_varlen_attn_forward_legacy( training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) - - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. 
return attn_output, None, past_key_value -def llama_varlen_attn_forward( +def llama_varlen_attn_forward_legacy( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -553,7 +469,6 @@ def llama_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) @@ -563,38 +478,9 @@ def llama_varlen_attn_forward( '`attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * # noqa: W504 - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, @@ -616,11 +502,12 @@ def llama_varlen_attn_forward( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, - cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) + # position_ids (1, seq_len) + # cos, sin (1, seq_len, dim) -> (seq_len, dim) + cos = cos[position_ids].squeeze(0) + sin = sin[position_ids].squeeze(0) + query_states = apply_rotary_emb(query_states, cos, sin) + key_states = apply_rotary_emb(key_states, cos, sin) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -642,31 +529,50 @@ def llama_varlen_attn_forward( key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently casted + # in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + assert SUPPORT_FLASH2 if is_training: - attn_output = varlen_flash_attn(query_states, key_states, value_states, - cumulative_len, max_seqlen) + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_rate=dropout_rate) else: attn_output = flash_attn_wo_mask( query_states, key_states, value_states, causal=True, + dropout_rate=dropout_rate, training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) # Due to the implementation of the PyTorch version of flash attention, # even when the output_attentions flag is set to True, it is not possible diff --git a/xtuner/model/modules/dispatch/utils.py b/xtuner/model/modules/dispatch/utils.py index 5355bce74..4cfa26cd1 100644 --- a/xtuner/model/modules/dispatch/utils.py +++ b/xtuner/model/modules/dispatch/utils.py @@ -25,6 +25,7 @@ def upad_qkv(query_layer, key_layer, value_layer, attention_mask, indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + key_layer = index_first_axis( key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py index 7aa0ec63c..ac86a8cca 100644 --- a/xtuner/model/sft.py +++ b/xtuner/model/sft.py @@ -17,7 +17,7 @@ reduce_sequence_parallel_loss) from xtuner.registry import BUILDER from .modules import dispatch_modules -from .modules.dispatch import SUPPORT_FLASH2 +from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 from .utils import (LoadWoInit, find_all_linear_names, get_peft_model_state_dict, make_inputs_require_grad, traverse_dict) @@ -78,8 +78,9 @@ def __init__(self, max_position_embeddings=None): super().__init__() with LoadWoInit(): - self.llm = self._build_from_cfg_or_module(llm, - max_position_embeddings) + if isinstance(llm, dict): + llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings) + self.llm = self._build_from_cfg_or_module(llm) if tokenizer is not None: if isinstance(tokenizer, dict): @@ -144,37 +145,63 @@ def _prepare_for_lora(self, def init_weights(self): pass - def _prepare_for_long_context_training(self, cfg, max_position_embeddings): - pretrained_model_name_or_path = cfg.pretrained_model_name_or_path - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=True) + @staticmethod + def _prepare_for_long_context_training(cfg, llm_cfg, + max_position_embeddings): - orig_rope_scaling = getattr(config, 'rope_scaling', None) + 
orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None) if orig_rope_scaling is None: orig_rope_scaling = {'factor': 1} orig_rope_scaling_factor = orig_rope_scaling[ 'factor'] if 'factor' in orig_rope_scaling.keys() else 1 - orig_ctx_len = getattr(config, 'max_position_embeddings', None) + orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None) if orig_ctx_len: orig_ctx_len *= orig_rope_scaling_factor if max_position_embeddings > orig_ctx_len: scaling_factor = float( math.ceil(max_position_embeddings / orig_ctx_len)) - config.rope_scaling = { + llm_cfg.rope_scaling = { 'type': 'linear', 'factor': scaling_factor } # hardcode for internlm2 - config.attn_implementation = 'flash_attention_2' - - cfg.config = config + llm_cfg.attn_implementation = 'flash_attention_2' + cfg.config = llm_cfg + + return cfg, llm_cfg + + @staticmethod + def _prepare_for_flash_attn(cfg, llm_cfg): + cls_name = type(llm_cfg).__name__ + SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig', + 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + + if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2: + cfg.torch_dtype = torch.bfloat16 \ + if torch.cuda.is_bf16_supported() else torch.float16 + cfg.attn_implementation = 'flash_attention_2' + elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN: + cfg.attn_implementation = 'sdpa' + + return cfg, llm_cfg + + def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None): + pretrained_model_name_or_path = cfg.pretrained_model_name_or_path + llm_cfg = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True) + cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg) + if max_position_embeddings is not None: + cfg, llm_cfg = self._prepare_for_long_context_training( + cfg, llm_cfg, max_position_embeddings) return cfg - def _build_from_cfg_or_module(self, - cfg_or_mod, - max_position_embeddings=None): + def _build_from_cfg_or_module(self, cfg_or_mod): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod elif isinstance(cfg_or_mod, dict): diff --git a/xtuner/parallel/sequence/data_collate.py b/xtuner/parallel/sequence/data_collate.py index f61b481b9..15b242d73 100644 --- a/xtuner/parallel/sequence/data_collate.py +++ b/xtuner/parallel/sequence/data_collate.py @@ -59,6 +59,9 @@ def pad_for_sequence_parallel(tokens, def split_for_sequence_parallel(tokens, labels=None, position_ids=None): seq_parallel_world_size = get_sequence_parallel_world_size() + if seq_parallel_world_size == 1: + return tokens, labels, position_ids + seq_parallel_world_rank = get_sequence_parallel_rank() seq_len = tokens.size(1) assert seq_len % seq_parallel_world_size == 0 diff --git a/xtuner/parallel/sequence/setup_distributed.py b/xtuner/parallel/sequence/setup_distributed.py index ea207bf10..9eb159e66 100644 --- a/xtuner/parallel/sequence/setup_distributed.py +++ b/xtuner/parallel/sequence/setup_distributed.py @@ -59,8 +59,11 @@ def get_sequence_parallel_world_size(): global _SEQUENCE_PARALLEL_WORLD_SIZE if _SEQUENCE_PARALLEL_WORLD_SIZE is not None: return _SEQUENCE_PARALLEL_WORLD_SIZE - _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size( - group=get_sequence_parallel_group()) + if not dist.is_initialized(): + _SEQUENCE_PARALLEL_WORLD_SIZE = 1 + else: + _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size( + group=get_sequence_parallel_group()) return 
_SEQUENCE_PARALLEL_WORLD_SIZE @@ -69,8 +72,11 @@ def get_sequence_parallel_rank(): global _SEQUENCE_PARALLEL_RANK if _SEQUENCE_PARALLEL_RANK is not None: return _SEQUENCE_PARALLEL_RANK - _SEQUENCE_PARALLEL_RANK = dist.get_rank( - group=get_sequence_parallel_group()) + if not dist.is_initialized(): + _SEQUENCE_PARALLEL_RANK = 0 + else: + _SEQUENCE_PARALLEL_RANK = dist.get_rank( + group=get_sequence_parallel_group()) return _SEQUENCE_PARALLEL_RANK @@ -86,8 +92,11 @@ def get_data_parallel_world_size(): global _DATA_PARALLEL_WORLD_SIZE if _DATA_PARALLEL_WORLD_SIZE is not None: return _DATA_PARALLEL_WORLD_SIZE - _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size( - group=get_data_parallel_group()) + if not dist.is_initialized(): + _DATA_PARALLEL_WORLD_SIZE = 1 + else: + _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size( + group=get_data_parallel_group()) return _DATA_PARALLEL_WORLD_SIZE @@ -96,5 +105,8 @@ def get_data_parallel_rank(): global _DATA_PARALLEL_RANK if _DATA_PARALLEL_RANK is not None: return _DATA_PARALLEL_RANK - _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group()) + if not dist.is_initialized(): + _DATA_PARALLEL_RANK = 0 + else: + _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group()) return _DATA_PARALLEL_RANK diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py index 7acbbf21f..23e3d2a3f 100644 --- a/xtuner/tools/train.py +++ b/xtuner/tools/train.py @@ -19,6 +19,7 @@ from xtuner.configs import cfgs_name_path from xtuner.dataset.collate_fns import default_collate_fn from xtuner.model.modules import dispatch_modules +from xtuner.model.modules.dispatch import SUPPORT_FLASH2 from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict from xtuner.registry import BUILDER, MAP_FUNC from xtuner.tools.utils import (auto_dtype_of_deepspeed_config, @@ -100,6 +101,10 @@ def check_cfg(cfg): f'max_length = {max_length} and sequence_parallel = ' f'{sequence_parallel}') + if getattr(cfg, 'sequence_parallel_size', 1) > 1: + assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use ' + 'sequence parallel.') + def main(): args = parse_args() diff --git a/xtuner/version.py b/xtuner/version.py index 11029f49f..ae73ce92a 100644 --- a/xtuner/version.py +++ b/xtuner/version.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -__version__ = '0.1.16.dev0' +__version__ = '0.1.16' short_version = __version__
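
Note (not part of the patch above): a minimal, standalone sketch of the per-iteration FLOPs estimate that the comments in the updated ThroughputHook describe. All names below are illustrative; the real hook reads batch_size and sequence_len from the data batch, takes hidden_size, num_layers, vocab_size and mlp_ratio from the model config, and queries get_sequence_parallel_world_size() itself.

def estimate_flops_per_iter(batch_size,
                            seq_len,
                            hidden_size,
                            num_layers,
                            vocab_size,
                            mlp_ratio,
                            sequence_parallel_size=1,
                            is_causal=True,
                            use_activation_checkpointing=True):
    # qkv_proj + o_proj: 8 * B * s * h^2
    flops_qkvo_proj = 8 * batch_size * seq_len * hidden_size**2
    # attention scores + weighted sum of values: 4 * B * s^2 * h, halved
    # when attention is causal; multiplied by sp_size because with sequence
    # parallel each rank attends over the full sequence of length
    # s * sp_size with only num_heads / sp_size heads, and
    # (s * sp_size)^2 * (h / sp_size) = s^2 * h * sp_size
    flops_attn = (4 * batch_size * seq_len**2 * hidden_size *
                  sequence_parallel_size / (int(is_causal) + 1))
    # gate_proj + up_proj + down_proj: 4 * B * s * h^2 * mlp_ratio
    flops_mlp = 4 * mlp_ratio * batch_size * seq_len * hidden_size**2
    # backward is roughly 2x the forward cost, plus one extra forward when
    # activations are recomputed for gradient checkpointing
    flops_wo_head = ((3 + int(use_activation_checkpointing)) *
                     (flops_qkvo_proj + flops_attn + flops_mlp) * num_layers)
    # LM head: 2 * B * s * h * V per forward, times 3 for forward + backward
    flops_head = 3 * 2 * batch_size * seq_len * hidden_size * vocab_size
    return flops_wo_head + flops_head

For a Llama-2-7B-like configuration (hidden_size=4096, num_layers=32, vocab_size=32000, mlp_ratio=11008/4096*1.5) with batch_size=1 and seq_len=2048, this comes out to roughly 1.1e14 FLOPs per iteration; dividing by the measured iteration time and by 1e12, as the hook does, gives the logged avg_tflops_per_gpu.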