From 59834032c82d39994c13252aea9b00011d1b2457 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Fri, 22 Mar 2024 12:32:01 +0800 Subject: [PATCH 1/9] [Feature] Support Sequence parallel (#456) * support sequence * add configs * add sp example to custom dataset * WIP * add dispatch utils * delete useless codes * move xtuner/engine/sequence_parallel to xtuner/parallel/sequence * fix lint * fix lint * add init_dist to xtuner and add trust_remote_code=True to AutoConfig * add internlm2 custom_dataset sp4 config * Sequence Parallel doc V1 * Sequence Parallel doc V1 * Sequence Parallel doc V1 * fix bugs in llama_varlen_attn_forward * rename indexes to position_ids * add attn_implementation to config * delete useless codes * fix lint * refine default_collate_fn * refine doc * refine doc * refine doc * delete replace_internlm2_rote * add repeat_kv_bshd * fix apply_rotary_pos_emb bug * add enable_sequence_parallel flag * refine doc * assert {'input_ids', 'labels'}.issubset(dataset.column_names) * refine doc --- .pre-commit-config.yaml | 2 + docs/zh_cn/user_guides/sequence_parallel.md | 188 +++++++++ ...nlm2_7b_full_finetune_custom_dataset_e1.py | 1 + ...e_custom_dataset_e1_sequence_parallel_4.py | 220 ++++++++++ .../llama2_7b_full_pgbooks_400iters_sp1.py | 196 +++++++++ .../llama2_7b_full_pgbooks_400iters_sp4.py | 196 +++++++++ .../dataset/collate_fns/defalut_collate_fn.py | 32 +- xtuner/dataset/huggingface.py | 3 +- xtuner/dataset/intern_repo.py | 8 +- xtuner/dataset/samplers/intern_repo.py | 7 +- xtuner/dataset/utils.py | 14 +- xtuner/engine/_strategy/deepspeed.py | 15 + .../varlen_attn_args_to_messagehub_hook.py | 2 +- xtuner/model/modules/dispatch/__init__.py | 10 +- xtuner/model/modules/dispatch/internlm.py | 11 +- xtuner/model/modules/dispatch/internlm2.py | 203 +++++++-- xtuner/model/modules/dispatch/llama.py | 394 +++++++++++------- xtuner/model/modules/dispatch/mistral.py | 15 +- xtuner/model/modules/dispatch/utils.py | 64 +++ xtuner/model/sft.py | 68 ++- xtuner/parallel/sequence/__init__.py | 24 ++ xtuner/parallel/sequence/attention.py | 102 +++++ xtuner/parallel/sequence/data_collate.py | 75 ++++ xtuner/parallel/sequence/reduce_loss.py | 17 + xtuner/parallel/sequence/sampler.py | 38 ++ xtuner/parallel/sequence/setup_distributed.py | 100 +++++ xtuner/tools/train.py | 32 +- 27 files changed, 1804 insertions(+), 233 deletions(-) create mode 100644 docs/zh_cn/user_guides/sequence_parallel.md create mode 100644 xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py create mode 100644 xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py create mode 100644 xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py create mode 100644 xtuner/model/modules/dispatch/utils.py create mode 100644 xtuner/parallel/sequence/__init__.py create mode 100644 xtuner/parallel/sequence/attention.py create mode 100644 xtuner/parallel/sequence/data_collate.py create mode 100644 xtuner/parallel/sequence/reduce_loss.py create mode 100644 xtuner/parallel/sequence/sampler.py create mode 100644 xtuner/parallel/sequence/setup_distributed.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2fdba321..acfe43b66 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,6 +12,7 @@ repos: rev: v0.32.0 hooks: - id: yapf + exclude: 'xtuner/parallel/sequence/__init__.py' - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: @@ -38,6 +39,7 @@ repos: - 
mdformat-openmmlab - mdformat_frontmatter - linkify-it-py + exclude: 'docs/zh_cn/user_guides/sequence_parallel.md' - repo: https://github.com/myint/docformatter rev: v1.3.1 hooks: diff --git a/docs/zh_cn/user_guides/sequence_parallel.md b/docs/zh_cn/user_guides/sequence_parallel.md new file mode 100644 index 000000000..ba29d2830 --- /dev/null +++ b/docs/zh_cn/user_guides/sequence_parallel.md @@ -0,0 +1,188 @@ +
+ +# 序列并行:训练极长序列大模型的系统优化 + +
+

XTuner 中的序列并行设计思路参考了 DeepSpeed 的工作 [DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509),并加以优化,以达到直接基于 transformers 算法库或 Huggingface Hub 上的开源模型训练 1M 以上超长序列的目标。

## 简介

从生成式 AI 到科学研究模型,长序列训练正在变得非常重要。

在生成式 AI 领域,会话式 AI、长文档摘要、代码库理解以及 Sora 这类视频生成任务,都需要在空间和时间层面对长上下文进行推理。

对于科学 AI 来说,长序列同样至关重要:它为更好地理解结构生物学、医疗保健、气候和天气预测以及大分子模拟打开了大门。

然而,尽管序列长度的重要性不断增长,XTuner 现有的显存优化策略(如 ZeRO 系列)仍不足以解决大模型、长序列的训练问题。

同时,受限于通信效率,现有的许多序列并行方法也不够高效。

另外,现有的序列并行方法普遍需要较多的代码侵入式修改,易用性和维护性大打折扣,也难以满足 XTuner 直接基于 transformers 算法库或 Huggingface Hub 上的开源模型进行训练的要求。

(图:序列并行核心设计示意图)
+ +为了解决上述长序列训练带来的问题,XTuner 采用了一种简单、易用且高效的序列并行算法。由于 Transformer 结构较为规整,除 attention 计算外,其他计算过程中 token 之间不会互相影响(即每个 token 的计算是独立的),这一条件为序列并行提供了有利条件。上图展示了序列并行的核心设计。设由 P 个 GPUs 共同计算一个长度为 N 的长序列,在 Attention 计算的第一阶段,长度为 N / P 的子序列会通过线性层投影为 Query、Key、Value。接下来, QKV Tensor 会在参与序列并行计算的多个 GPUs 之间通过高度优化的 all-to-all 通信算子汇聚,得到序列长度为 N ,但更少注意力头的子序列。注意力计算后,通过另一个 all-to-all 通信算子将其转换为长度为 N / P 的子序列,进行后续计算。 + +总体而言,XTuner 的序列并行算法具有以下关键特性: + +* 支持全量训练**超过百万个token**的序列 +* 支持百 B 级模型训练:XTuner 的序列并行不仅支持长序列训练,还可结合 zero3 显存优化策略训练大尺寸模型 +* 完全通用的序列并行 **API 抽象** + +## 使用 XTuner 进行序列并行训练 + +### Step 1 修改 config 文件 + +1. 在 config 中修改 `sequence_parallel_size` 字段即可调整 $sequence\\_parallel\\_world\\_size$ 。 +2. 同时若想保证与不使用序列并行的训练效果类似,需要同步增大梯度累积的数值为原来的 $sequence\\_parallel\\_world\\_size$ 倍,因为在使用序列并行训练时, $data\\_parallel\\_world\\_size$ 降为了原来的 $\frac{1}{sequence\\_parallel\\_world\\_size}$。 +3. 替换 DefaultSampler 为支持序列并行的 SequenceParallelSampler。 + +**注:需要保证所使用的 GPU 总数可以被 `sequence_parallel_size` 整除。** + +```diff ++ from xtuner.parallel.sequence import SequenceParallelSampler + +- sequence_parallel_size = 1 ++ sequence_parallel_size = 4 # take `sequence_parallel_size = 4`` as an example + +- accumulative_counts = 1 ++ accumulative_counts = 4 # accumulative_counts = accumulative_counts * sequence_parallel_size + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataloader = dict( +- sampler=dict(type=DefaultSampler, shuffle=True), ++ sampler=dict(type=SequenceParallelSampler, seed=1024, shuffle=True), + ...) +``` + +另外,若需要进一步拓展模型的长文本处理能力,需要进一步修改 config 中的 `max_position_embeddings` 字段。例如需要将模型的上下文长度拓展为 64K 时,可进行如下修改: + +```diff ++ max_position_embeddings = 65536 + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +model = dict( + type=SupervisedFinetune, ++ max_position_embeddings = max_position_embeddings, + ...) +``` + +### Step 2 开始训练 + +需要使用 DeepSpeed 进行训练: + +```bash +(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train ${CONFIG_PATH} --deepspeed deepspeed_zero2 +(SLURM) srun ${SRUN_ARGS} xtuner train ${CONFIG_PATH} --launcher slurm --deepspeed deepspeed_zero2 +``` + +- ${CONFIG_PATH} 为 Step 1 中修改得到的 config 文件路径 +- 可根据实际情况选择使用不同的 zero 策略 + +## 序列并行 API 抽象 + +为了提升算法的可迁移性,XTuner 中抽象出了序列并行所必须的五个 API 接口: +- 序列并行分布式环境初始化 (init_sequence_parallel) +- 适配序列并行的 Data Sampler (SequenceParallelSampler) +- 数据 Pad 与切分 (pad_for_sequence_parallel, split_for_sequence_parallel) +- 适配序列并行的 Attention (dispatch_modules) +- reduce loss 以正确打印训练损失 (reduce_sequence_parallel_loss) + +### 序列并行分布式环境初始化 + +由于序列并行算法会将长序列切分为 $sequence\\_parallel\\_world\\_size$ 块,并将每个子序列分发给对应的 GPU 独立进行计算。因此需要在训练开始前初始化序列并行分布式环境,以指定哪几块 GPU 共同负责一个长序列输入的计算。 + +一个 $sequence\\_parallel\\_world\\_size = 4$ 的示例如下: + +```python +# We have to initialize the distributed training environment first. 
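# (A sketch of the assumed grouping: with 8 GPUs and a sequence parallel
# world size of 4, consecutive ranks {0, 1, 2, 3} and {4, 5, 6, 7} would
# each form one sequence parallel group, and the data parallel world size
# correspondingly drops from 8 to 2.)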
+# Here is an example when training on slurm scheduler +# from xtuner.parallel.sequence import init_dist +# init_dist('slurm', 'nccl', init_backend='deepspeed') +from xtuner.parallel.sequence import init_sequence_parallel +sequence_parallel_world_size = 4 +init_sequence_parallel(sequence_parallel_world_size) +``` + +上述过程在 xtuner/engine/_strategy/deepspeed.py 中实现。 + +### Data Sampler 适配序列并行 + +在使用序列并行后,Dataloader 的采样策略需要进一步调整。例如当 $sequence\\_parallel\\_world\\_size = 4$ 时,4 块 GPU 从 Dataloader 拿到的数据需要是完全一样的。 + +在构建 Dataloader 时搭配 XTuner 中提供的 SequenceParallelSampler 使用即可: + +```python +from xtuner.parallel.sequence import SequenceParallelSampler +dataloader = DataLoader( + train_dataset, sampler=SequenceParallelSampler(train_dataset), + **other_dataloader_params) +``` + +### 数据 Pad 与切分 + +由于每条训练数据的长度可能不尽相同,我们需要将数据进行 Pad 以使得序列长度可以被 $sequence\\_parallel\\_world\\_size$ 整除,这样一条长数据才能被均等地分发给不同的 GPU 上。 + +训练过程中需要被 Pad 的 Tensor 往往有 input_ids, labels, position_ids, attention_mask 四个,pad 的过程可以通过以下方式实现: + +```python +from xtuner.parallel.sequence import pad_for_sequence_parallel +input_ids, labels, position_ids, attention_mask = pad_for_sequence_parallel( + input_ids, labels, position_ids, attention_mask) +``` + +如果训练过程用不到 attention_mask,那么可以: + +```python +input_ids, labels, position_ids, _ = pad_for_sequence_parallel( + input_ids, labels, position_ids) +``` + +Pad 后,我们需要对长序列均等切分: + +```python +from xtuner.parallel.sequence import split_for_sequence_parallel +# attention mask should not be split +input_ids, labels, position_ids = split_for_sequence_parallel( + input_ids, labels, position_ids) +``` + +以上两步在 xtuner/dataset/collate_fns/defalut_collate_fn.py 中实现。 + +### Attention 适配序列并行 + +在 Attention 的计算过程中,序列中的不同 token 是不能独立运算的,但不同的 attention head 之间的计算却是独立的。因此,如[第一节](#简介)所述,需要在计算 Attention 前后(即 qkv_proj 后和 o_proj 前)分别插入一个 *all-to-all* 操作。 + +XTuner 提供了 dispatch_modules 接口以支持修改模型 Attention 的计算方式: + +```python +from xtuner.model.modules import dispatch_modules +model: AutoModelForCausalLM +dispatch_modules(model) +``` + +上述过程在 xtuner/model/sft.py 中实现。 + +### Reduce Loss 以正确打印训练损失 + +这个 API 对于保证训练的正确性不是必须的,但对于观测模型训练状态,打印训练 loss 是非常有用的。 + +```python +from xtuner.parallel.sequence import reduce_sequence_parallel_loss +outputs = llm(input_ids=input_ids, labels=labels, **kwargs) +num_tokens_per_rank = (labels != -100).sum() +# Suppose sequence parallel world size equals to 4, +# losses on rank0, rank1, rank2, rank3 are different. +loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank) +# After loss reduction, losses on rank0, rank1, rank2, rank3 are the same. 
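# (A sketch of the assumed behavior: ranks within one sequence parallel group
# all-reduce `loss * num_tokens_per_rank` together with `num_tokens_per_rank`
# and then divide, i.e. a token-weighted average, so every rank reports the
# loss of the full sequence rather than of its own sub-sequence.)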
+``` + +上述过程在 xtuner/model/sft.py 中实现。 diff --git a/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py index 40c9b1692..58863f06b 100644 --- a/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py +++ b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py @@ -42,6 +42,7 @@ # Model pretrained_model_name_or_path = 'internlm/internlm2-7b' use_varlen_attn = True +sequence_parallel_size = 1 # Data data_files = ['/path/to/json/file.json'] diff --git a/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py new file mode 100644 index 000000000..aa7e4b014 --- /dev/null +++ b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py @@ -0,0 +1,220 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Data format: +[ + { + "conversation": [ + { + "system": "", + "input": "xxx", + "output": "xxx" + }, + { + "input": "xxx", + "output": "xxx" + } + ] + }, +... +] +Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details. +""" # noqa: E501 +from datasets import load_dataset +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR +from torch.optim import AdamW +from torch.utils.data import BatchSampler +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.dataset.map_fns import template_map_fn_factory +from xtuner.dataset.samplers import InternRepoSampler +from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.utils import PROMPT_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'internlm/internlm2-7b' +use_varlen_attn = True +sequence_parallel_size = 4 + +# Data +data_files = ['/path/to/json/file.json'] +prompt_template = PROMPT_TEMPLATE.internlm2_chat +max_length = 32768 +pack_to_max_length = True + +# Scheduler & Optimizer +batch_size = 1 # per_device +# accumulative_counts = accumulative_counts * sequence_parallel_size +accumulative_counts = 4 +dataloader_num_workers = 4 +max_epochs = 1 +optim_type = AdamW +lr = 4e-5 +betas = (0.9, 0.95) +weight_decay = 0.01 +max_norm = 1 # grad clip +warm_up_ratio = 0.025 + +# Save +save_steps = 500 +save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) + +# Evaluate the generation performance during the training +evaluation_freq = 500 +SYSTEM = '' +evaluation_inputs = [ + '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' +] + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + 
trust_remote_code=True, + padding_side='right') + +model = dict( + type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True)) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + use_varlen_attn=use_varlen_attn, + dataset=dict(type=load_dataset, path='json', data_files=data_files), + tokenizer=tokenizer, + max_length=max_length, + dataset_map_fn=None, + template_map_fn=dict( + type=template_map_fn_factory, template=prompt_template), + remove_unused_columns=True, + shuffle_before_pack=True, + pack_to_max_length=pack_to_max_length) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024), + batch_sampler=dict(type=BatchSampler, drop_last=True, batch_size=1), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic', +) + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1 / 40, + by_epoch=True, + begin=0, + end=warm_up_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=lr * 0.15, + by_epoch=True, + begin=warm_up_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict( + type=DatasetInfoHook, tokenizer=tokenizer, + is_intern_repo_dataset=True), + dict( + type=EvaluateChatHook, + tokenizer=tokenizer, + every_n_iters=evaluation_freq, + evaluation_inputs=evaluation_inputs, + system=SYSTEM, + prompt_template=prompt_template), + dict(type=ThroughputHook) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 100 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +log_processor = dict( + by_epoch=False, + window_size=1, + mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*') diff --git a/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py new file mode 100644 index 000000000..9c6e4fe05 --- /dev/null +++ b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from datasets import load_dataset +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.engine.hooks import (DatasetInfoHook, ThroughputHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.parallel.sequence import SequenceParallelSampler +from xtuner.utils import PROMPT_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'meta-llama/Llama-2-7b-hf' +use_varlen_attn = False +sequence_parallel_size = 1 + +# Data +data_path = 'emozilla/pg_books-tokenized-bos-eos-chunked-65536' +data_files = [ + 'data/train-00000-of-00136-877a1768c20d5900.parquet', + 'data/train-00001-of-00136-70d7d139dca61754.parquet', + 'data/train-00002-of-00136-62d53594e098f3d8.parquet', + 'data/train-00003-of-00136-8bd300fecc4c720e.parquet', + 'data/train-00004-of-00136-2a9456b5f975ae95.parquet', + 'data/train-00005-of-00136-ca38cf7907bb7555.parquet', + 'data/train-00006-of-00136-1ae2e4c63f3966da.parquet', + 'data/train-00007-of-00136-a00cc39a4ee65ab6.parquet', +] +prompt_template = PROMPT_TEMPLATE.llama2_chat +max_length = 65536 +max_position_embeddings = 65536 +pack_to_max_length = False + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 8 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-5 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.05 + +# Save +save_steps = 500 +save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + type=SupervisedFinetune, + 
use_varlen_attn=use_varlen_attn, + max_position_embeddings=max_position_embeddings, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + attn_implementation='flash_attention_2')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + dataset=dict( + type=load_dataset, + path=data_path, + data_files=data_files, + ignore_verifications=True), + do_dataset_tokenization=False, + remove_unused_columns=True, + pack_to_max_length=pack_to_max_length, + use_varlen_attn=use_varlen_attn) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=SequenceParallelSampler, seed=1024), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1 / 40, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=lr * 0.15, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=ThroughputHook) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + save_optimizer=False, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict( + by_epoch=False, + window_size=1, + mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*') diff --git a/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py new file mode 100644 index 000000000..5e87acffe --- /dev/null +++ b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py @@ -0,0 +1,196 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from datasets import load_dataset +from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) +from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR +from torch.optim import AdamW +from transformers import AutoModelForCausalLM, AutoTokenizer + +from xtuner.dataset import process_hf_dataset +from xtuner.dataset.collate_fns import default_collate_fn +from xtuner.engine.hooks import (DatasetInfoHook, ThroughputHook, + VarlenAttnArgsToMessageHubHook) +from xtuner.engine.runner import TrainLoop +from xtuner.model import SupervisedFinetune +from xtuner.parallel.sequence import SequenceParallelSampler +from xtuner.utils import PROMPT_TEMPLATE + +####################################################################### +# PART 1 Settings # +####################################################################### +# Model +pretrained_model_name_or_path = 'meta-llama/Llama-2-7b-hf' +use_varlen_attn = False +sequence_parallel_size = 4 + +# Data +data_path = 'emozilla/pg_books-tokenized-bos-eos-chunked-65536' +data_files = [ + 'data/train-00000-of-00136-877a1768c20d5900.parquet', + 'data/train-00001-of-00136-70d7d139dca61754.parquet', + 'data/train-00002-of-00136-62d53594e098f3d8.parquet', + 'data/train-00003-of-00136-8bd300fecc4c720e.parquet', + 'data/train-00004-of-00136-2a9456b5f975ae95.parquet', + 'data/train-00005-of-00136-ca38cf7907bb7555.parquet', + 'data/train-00006-of-00136-1ae2e4c63f3966da.parquet', + 'data/train-00007-of-00136-a00cc39a4ee65ab6.parquet', +] +prompt_template = PROMPT_TEMPLATE.llama2_chat +max_length = 65536 +max_position_embeddings = 65536 +pack_to_max_length = False + +# Scheduler & Optimizer +batch_size = 1 # per_device +accumulative_counts = 32 +dataloader_num_workers = 0 +max_epochs = 1 +optim_type = AdamW +lr = 2e-5 +betas = (0.9, 0.999) +weight_decay = 0 +max_norm = 1 # grad clip +warmup_ratio = 0.05 + +# Save +save_steps = 500 +save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited) + +####################################################################### +# PART 2 Model & Tokenizer # +####################################################################### +tokenizer = dict( + type=AutoTokenizer.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + padding_side='right') + +model = dict( + 
type=SupervisedFinetune, + use_varlen_attn=use_varlen_attn, + max_position_embeddings=max_position_embeddings, + llm=dict( + type=AutoModelForCausalLM.from_pretrained, + pretrained_model_name_or_path=pretrained_model_name_or_path, + trust_remote_code=True, + torch_dtype=torch.bfloat16, + attn_implementation='flash_attention_2')) + +####################################################################### +# PART 3 Dataset & Dataloader # +####################################################################### +train_dataset = dict( + type=process_hf_dataset, + dataset=dict( + type=load_dataset, + path=data_path, + data_files=data_files, + ignore_verifications=True), + do_dataset_tokenization=False, + remove_unused_columns=True, + pack_to_max_length=pack_to_max_length, + use_varlen_attn=use_varlen_attn) + +train_dataloader = dict( + batch_size=batch_size, + num_workers=dataloader_num_workers, + dataset=train_dataset, + sampler=dict(type=SequenceParallelSampler, seed=1024), + collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) + +####################################################################### +# PART 4 Scheduler & Optimizer # +####################################################################### +# optimizer +optim_wrapper = dict( + type=AmpOptimWrapper, + optimizer=dict( + type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), + clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), + accumulative_counts=accumulative_counts, + loss_scale='dynamic') + +# learning policy +# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 +param_scheduler = [ + dict( + type=LinearLR, + start_factor=1 / 40, + by_epoch=True, + begin=0, + end=warmup_ratio * max_epochs, + convert_to_iter_based=True), + dict( + type=CosineAnnealingLR, + eta_min=lr * 0.15, + by_epoch=True, + begin=warmup_ratio * max_epochs, + end=max_epochs, + convert_to_iter_based=True) +] + +# train, val, test setting +train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) + +####################################################################### +# PART 5 Runtime # +####################################################################### +# Log the dialogue periodically during the training process, optional +custom_hooks = [ + dict(type=DatasetInfoHook, tokenizer=tokenizer), + dict(type=ThroughputHook) +] + +if use_varlen_attn: + custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] + +# configure default hooks +default_hooks = dict( + # record the time of every iteration. + timer=dict(type=IterTimerHook), + # print log every 10 iterations. + logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1), + # enable the parameter scheduler. + param_scheduler=dict(type=ParamSchedulerHook), + # save checkpoint per `save_steps`. + checkpoint=dict( + type=CheckpointHook, + save_optimizer=False, + by_epoch=False, + interval=save_steps, + max_keep_ckpts=save_total_limit), + # set sampler seed in distributed evrionment. 
+ sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set visualizer +visualizer = None + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) + +# set log processor +log_processor = dict( + by_epoch=False, + window_size=1, + mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*') diff --git a/xtuner/dataset/collate_fns/defalut_collate_fn.py b/xtuner/dataset/collate_fns/defalut_collate_fn.py index 294cd4870..f644df9cf 100644 --- a/xtuner/dataset/collate_fns/defalut_collate_fn.py +++ b/xtuner/dataset/collate_fns/defalut_collate_fn.py @@ -4,6 +4,9 @@ import torch from torch.nn.utils.rnn import pad_sequence +from xtuner.parallel.sequence import (get_sequence_parallel_world_size, + pad_for_sequence_parallel, + split_for_sequence_parallel) from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX @@ -11,13 +14,14 @@ def default_collate_fn(instances: Sequence[Dict], pad_index: int = DEFAULT_PAD_TOKEN_INDEX, return_hf_format: bool = False, use_varlen_attn: bool = False): + seq_parallel_world_size = get_sequence_parallel_world_size() input_ids, labels = [], [] has_image = any(inst.get('pixel_values') is not None for inst in instances) if use_varlen_attn: - cumulative_len, indexes = [], [] + position_ids, cumulative_len = [], [] assert len(instances) == 1, ( - f'If utilizing local attention, the batch size should be' + f'If utilizing varlen attention, the batch size should be' f' set to 1, but got {len(instances)}') assert not has_image, 'Currently, it is not configured to ' 'accommodate the use of varlen Attention in multimodal training' @@ -30,7 +34,7 @@ def default_collate_fn(instances: Sequence[Dict], labels.append(torch.LongTensor(example['labels'])) if use_varlen_attn: cumulative_len.append(torch.IntTensor(example['cumulative_len'])) - indexes.append(torch.LongTensor(example['indexes'])) + position_ids.append(torch.LongTensor(example['position_ids'])) if has_image: pixel_values.append(example['pixel_values']) @@ -45,21 +49,37 @@ def default_collate_fn(instances: Sequence[Dict], labels = torch.stack(labels) if use_varlen_attn: - indexes = torch.stack(indexes, dim=0) + assert input_ids.size(1) % seq_parallel_world_size == 0 + attention_mask = None + position_ids = torch.stack(position_ids, dim=0) + else: + attention_mask = input_ids.ne(pad_index) + position_ids = attention_mask.long().cumsum(-1) - 1 + + input_ids, labels, position_ids, attention_mask = \ + pad_for_sequence_parallel(input_ids, labels, position_ids, + attention_mask) + + # attention mask should not be split + input_ids, labels, position_ids = split_for_sequence_parallel( + input_ids, labels, position_ids) + + if use_varlen_attn: max_seqlen = ( cumulative_len[0][1:] - # noqa: W504 cumulative_len[0][:-1]).max().item() data_dict = { 'input_ids': input_ids, 'cumulative_len': cumulative_len, - 'indexes': indexes, + 'position_ids': position_ids, 'labels': labels, 'max_seqlen': max_seqlen } else: data_dict = { 'input_ids': input_ids, - 'attention_mask': input_ids.ne(pad_index), + 'attention_mask': attention_mask, + 'position_ids': position_ids, 
'labels': labels } diff --git a/xtuner/dataset/huggingface.py b/xtuner/dataset/huggingface.py index 349cce2f6..30f6bc394 100644 --- a/xtuner/dataset/huggingface.py +++ b/xtuner/dataset/huggingface.py @@ -199,10 +199,9 @@ def process(dataset, dataset = tokenize_dataset(dataset, tokenizer, max_length, with_image_token, input_ids_with_output, remove_unused_columns, map_num_proc) - else: - assert {'input_ids', 'labels'}.issubset(dataset.column_names) if input_ids_with_output: + assert {'input_ids', 'labels'}.issubset(dataset.column_names) # remove data that does not have the valid labels. dataset = dataset.filter( lambda example: any(label >= 0 for label in example['labels']), diff --git a/xtuner/dataset/intern_repo.py b/xtuner/dataset/intern_repo.py index b1034cc31..95cd7cf99 100644 --- a/xtuner/dataset/intern_repo.py +++ b/xtuner/dataset/intern_repo.py @@ -191,7 +191,7 @@ def mapping(self, pack_idx: int = 0): def build_pack(self, begin_sample_idx: int, begin_token_id: int, end_sample_idx: int, end_token_id: int): - pack, cumulative_len, indexes, labels = [], [0], [], [] + pack, cumulative_len, position_ids, labels = [], [0], [], [] while begin_sample_idx < end_sample_idx: sample_idx = self.shuffled_indices[begin_sample_idx] @@ -202,7 +202,7 @@ def build_pack(self, begin_sample_idx: int, begin_token_id: int, assert len(_labels) == len(chunk), (_labels, chunk) labels.extend(_labels) cumulative_len.append(cumulative_len[-1] + len(chunk)) - indexes.extend(list(range(len(chunk)))) + position_ids.extend(list(range(len(chunk)))) begin_sample_idx = begin_sample_idx + 1 begin_token_id = 0 @@ -215,12 +215,12 @@ def build_pack(self, begin_sample_idx: int, begin_token_id: int, assert len(_labels) == len(chunk), (_labels, chunk) labels.extend(_labels) cumulative_len.append(cumulative_len[-1] + len(chunk)) - indexes.extend(list(range(len(chunk)))) + position_ids.extend(list(range(len(chunk)))) out = { 'input_ids': pack, 'cumulative_len': cumulative_len, - 'indexes': indexes, + 'position_ids': position_ids, 'labels': labels } return out diff --git a/xtuner/dataset/samplers/intern_repo.py b/xtuner/dataset/samplers/intern_repo.py index 3ca470c21..933719a58 100644 --- a/xtuner/dataset/samplers/intern_repo.py +++ b/xtuner/dataset/samplers/intern_repo.py @@ -4,9 +4,11 @@ import numpy as np from mmengine import print_log -from mmengine.dist import get_dist_info from torch.utils.data import Sampler +from xtuner.parallel.sequence import (get_data_parallel_rank, + get_data_parallel_world_size) + class InternRepoSampler(Sampler): @@ -17,7 +19,8 @@ def __init__(self, if seed is not None and seed != 1024: warnings.warn('For alignment accuracy, seed in InternRepoSampler' 'must be set to 1024.') - rank, world_size = get_dist_info() + world_size = get_data_parallel_world_size() + rank = get_data_parallel_rank() self.rank = rank self.world_size = world_size diff --git a/xtuner/dataset/utils.py b/xtuner/dataset/utils.py index 470a95303..84336ddb2 100644 --- a/xtuner/dataset/utils.py +++ b/xtuner/dataset/utils.py @@ -176,8 +176,8 @@ def get_cumulative_len(self, chunk_num): return cumulative_len - def get_indexes(self, cumulative_len): - indexes = [] + def get_position_ids(self, cumulative_len): + position_ids = [] for cumulative_len_cur in cumulative_len: index_cur = [] for i in range(len(cumulative_len_cur) - 1): @@ -185,8 +185,8 @@ def get_indexes(self, cumulative_len): list( range(cumulative_len_cur[i + 1] - # noqa: W504 cumulative_len_cur[i]))) - indexes.append(index_cur) - return indexes + 
position_ids.append(index_cur) + return position_ids def __call__(self, batch): concatenated_samples = { @@ -222,7 +222,7 @@ def __call__(self, batch): if self.use_varlen_attn: cumulative_len = self.get_cumulative_len(chunk_num) result['cumulative_len'] = cumulative_len - result['indexes'] = self.get_indexes(cumulative_len) + result['position_ids'] = self.get_position_ids(cumulative_len) else: if self.drop_last: result = {k: [] for k, v in concatenated_samples.items()} @@ -235,8 +235,8 @@ def __call__(self, batch): result['cumulative_len'] = [] if self.drop_last else [ self.residual_cumulative_len ] - result['indexes'] = [] if self.drop_last else self.get_indexes( - [self.residual_cumulative_len]) + result['position_ids'] = [] if self.drop_last \ + else self.get_position_ids([self.residual_cumulative_len]) self.residual_cumulative_len = [0] return result diff --git a/xtuner/engine/_strategy/deepspeed.py b/xtuner/engine/_strategy/deepspeed.py index afa0cc57c..42b7f5590 100644 --- a/xtuner/engine/_strategy/deepspeed.py +++ b/xtuner/engine/_strategy/deepspeed.py @@ -1,13 +1,19 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + from mmengine._strategy import DeepSpeedStrategy as MMEngineDeepSpeedStrategy from xtuner import DS_CEPH_DIR +from xtuner.parallel.sequence import init_sequence_parallel from xtuner.utils.fileio import patch_fileio class DeepSpeedStrategy(MMEngineDeepSpeedStrategy): def __init__(self, *args, **kwargs): + sequence_parallel_size = kwargs.pop('sequence_parallel_size', 1) + self.sequence_parallel_size = sequence_parallel_size + super().__init__(*args, **kwargs) from transformers.integrations.deepspeed import HfDeepSpeedConfig @@ -53,3 +59,12 @@ def resume(self, *args, **kwargs) -> None: else: checkpoint = super().resume(*args, **kwargs) return checkpoint + + def _setup_distributed( # type: ignore + self, + launcher: Optional[str] = None, + backend: str = 'nccl', + **kwargs, + ): + super()._setup_distributed(launcher, backend, **kwargs) + init_sequence_parallel(self.sequence_parallel_size) diff --git a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py index 9aa7a91bd..f2b23d3fe 100644 --- a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py +++ b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py @@ -11,7 +11,7 @@ class VarlenAttnArgsToMessageHubHook(Hook): - args = ('cumulative_len', 'indexes', 'max_seqlen') + args = ('cumulative_len', 'max_seqlen') def cast_data(self, data): if isinstance(data, Mapping): diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index 79ff88b08..6fbe37fb6 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -123,7 +123,8 @@ def dispatch_internlm2_attn_forward(model, use_varlen_attn): print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING) for module in model.modules(): - if type(module).__name__ == 'InternLM2Attention': + if type(module).__name__ in ('InternLM2Attention', + 'InternLM2FlashAttention2'): if use_varlen_attn: print_log('dispatch internlm2 varlen attn forward', 'current') module.forward = types.MethodType( @@ -188,11 +189,12 @@ def traverse(module): for name, child in module.named_children(): if type(child).__name__ in ( 'InternLM2RotaryEmbedding', + 'InternLM2LinearScalingRotaryEmbedding', 'InternLM2DynamicNTKScalingRotaryEmbedding'): print_log('replace internlm2 rope', 'current') dim_model = child.inv_freq.shape[0] * 
2 child_new = InternLM2RotaryEmbedding( - dim_model, child.max_seq_len_cached, rotary_base).to( + dim_model, child.max_position_embeddings, rotary_base).to( device=child.inv_freq.device, dtype=child.inv_freq.dtype) setattr(module, name, child_new) @@ -301,12 +303,12 @@ def dispatch_modules(model, use_varlen_attn=False): dispatch_internlm2_attn_forward(model, use_varlen_attn) if USE_TRITON_KERNEL: dispatch_internlm2_rmsnorm_forward(model) - replace_internlm2_rote(model) + # replace_internlm2_rote(model) elif 'internlm' in model_name: dispatch_internlm_attn_forward(model, use_varlen_attn) if USE_TRITON_KERNEL: dispatch_internlm_rmsnorm_forward(model) - replace_internlm_rote(model) + # replace_internlm_rote(model) elif 'llama' in model_name: dispatch_llama_attn_forward(model, use_varlen_attn) if USE_TRITON_KERNEL: diff --git a/xtuner/model/modules/dispatch/internlm.py b/xtuner/model/modules/dispatch/internlm.py index 8bee27b67..fd06def33 100644 --- a/xtuner/model/modules/dispatch/internlm.py +++ b/xtuner/model/modules/dispatch/internlm.py @@ -154,7 +154,7 @@ def internlm_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - indexes = message_hub.get_info(f'indexes_rank_{rank}') + # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) @@ -175,10 +175,11 @@ def internlm_varlen_attn_forward( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) + query_states = apply_rotary_emb(query_states, + cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) + key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index 66aa4f391..a166e8bae 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -8,13 +8,15 @@ from einops import rearrange from mmengine import MessageHub +from xtuner.parallel.sequence import sequence_parallel_wrapper from .triton_kernels import apply_rotary_emb +from .utils import upad_qkv SUPPORT_FLASH2 = False try: from flash_attn import flash_attn_func, flash_attn_varlen_func - + from flash_attn.bert_padding import pad_input SUPPORT_FLASH2 = True except ImportError: pass @@ -28,6 +30,9 @@ def __init__(self, base=1000000, device=None): super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base self.inv_freq = 1.0 / ( base**(torch.arange(0, dim, 2).float().to(device) / dim)) @@ -96,22 +101,115 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim) +def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) + to (batch, seqlen, num_attention_heads, head_dim)""" + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, + None, :].expand(batch, slen, + num_key_value_heads, n_rep, + 
head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, + head_dim) + + +@sequence_parallel_wrapper +def flash_attn_wo_mask(query_states, + key_states, + value_states, + causal, + dropout_rate=0.0): + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout_rate, causal=causal) + return attn_output + + +@sequence_parallel_wrapper +def flash_attn_w_mask( + query_states, # bs, q_len, nhead, h_dim + key_states, + value_states, + attention_mask, + causal, + dropout_rate=0.0): + batch_size, q_len = query_states.shape[:2] + query_states, key_states, value_states, indices_q, \ + cu_seq_lens, max_seq_lens = upad_qkv( + query_states, key_states, value_states, attention_mask, q_len) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout_rate, + causal=causal, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + return attn_output + + +def flash_attn1_pytorch(query_states, key_states, value_states, *args, + **kwargs): + # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = F.scaled_dot_product_attention(query_states, key_states, + value_states, *args, **kwargs) + attn_output = attn_output.transpose(1, 2) + return attn_output + + +@sequence_parallel_wrapper +def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, + max_seqlen): + q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( + 0, 1), value_states.flatten(0, 1) + cumulative_len = torch.cat(cumulative_len, dim=0) + attn_output = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cumulative_len, + cumulative_len, + max_seqlen, + max_seqlen, + 0, + return_attn_probs=False, + causal=True, + ) + attn_output = attn_output.unsqueeze(0) + return attn_output + + def internlm2_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: +): if 'padding_mask' in kwargs: warnings.warn( 'Passing `padding_mask` is deprecated and will be removed in v4.37' 'Please make sure use `attention_mask` instead.`') + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') + + output_attentions = False + bsz, q_len, _ = hidden_states.size() qkv_states = self.wqkv(hidden_states) @@ -135,7 +233,10 @@ def internlm2_attn_forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # This modification is necessary for sequential parallel + assert position_ids is not None and (position_ids.max() + 1) >= kv_seq_len + cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -146,29 +247,49 @@ 
def internlm2_attn_forward( past_key_value = (key_states, value_states) if use_cache else None + # repeat kv for sequence parallel + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # flash attn 2 need (bs, seq_len, nhead, h_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if SUPPORT_FLASH2: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = flash_attn_func( - query_states, key_states, value_states, causal=True) - attn_output = attn_output.contiguous() + causal = self.is_causal and q_len != 1 + + if attention_mask is not None: + attn_output = flash_attn_w_mask( + query_states, + key_states, + value_states, + attention_mask, + causal, + training=self.training) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal, + training=self.training) else: - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) # use flash attention implemented by pytorch - attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = flash_attn1_pytorch( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + training=self.training) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - attn_output = self.wo(attn_output) - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. 
- return attn_output, None, past_key_value + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value def internlm2_varlen_attn_forward( @@ -188,7 +309,7 @@ def internlm2_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - indexes = message_hub.get_info(f'indexes_rank_{rank}') + # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) @@ -216,10 +337,11 @@ def internlm2_varlen_attn_forward( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) + query_states = apply_rotary_emb(query_states, + cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) + key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -238,26 +360,21 @@ def internlm2_varlen_attn_forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + assert SUPPORT_FLASH2 if is_training: - q_unpad, k_unpad, v_unpad = query_states.flatten( - 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) - attn_output = flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - cumulative_len, - cumulative_len, - max_seqlen, - max_seqlen, - 0, - return_attn_probs=False, - causal=True, - ) + attn_output = varlen_flash_attn(query_states, key_states, value_states, + cumulative_len, max_seqlen) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, causal=True) + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py index 94ddedaec..27b1f33d6 100644 --- a/xtuner/model/modules/dispatch/llama.py +++ b/xtuner/model/modules/dispatch/llama.py @@ -6,14 +6,18 @@ import torch.distributed as dist import torch.nn.functional as F from mmengine import MessageHub +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb +from transformers.utils import logging +from xtuner.parallel.sequence import sequence_parallel_wrapper from .triton_kernels import apply_rotary_emb +from .utils import upad_qkv SUPPORT_FLASH2 = False try: from flash_attn import flash_attn_func, flash_attn_varlen_func - + from flash_attn.bert_padding import pad_input SUPPORT_FLASH2 = True except ImportError: pass @@ -26,6 +30,9 @@ class Cache: pass +logger = logging.get_logger(__name__) + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., :x.shape[-1] // 2] @@ -33,18 +40,6 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q, k, cos, sin, position_ids): - # The first two dimensions of cos and sin are always 1, - # so we can `squeeze` them. 
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] - sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] - cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). @@ -63,6 +58,95 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: head_dim) +def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) + to (batch, seqlen, num_attention_heads, head_dim)""" + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, :, + None, :].expand(batch, slen, + num_key_value_heads, n_rep, + head_dim) + return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, + head_dim) + + +@sequence_parallel_wrapper +def flash_attn_wo_mask(query_states, + key_states, + value_states, + causal, + dropout_rate=0.0): + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout_rate, causal=causal) + return attn_output + + +@sequence_parallel_wrapper +def flash_attn_w_mask( + query_states, # bs, q_len, nhead, h_dim + key_states, + value_states, + attention_mask, + causal, + dropout_rate=0.0): + batch_size, q_len = query_states.shape[:2] + query_states, key_states, value_states, indices_q, \ + cu_seq_lens, max_seq_lens = upad_qkv( + query_states, key_states, value_states, attention_mask, q_len) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout_rate, + causal=causal, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) + return attn_output + + +def flash_attn1_pytorch(query_states, key_states, value_states, *args, + **kwargs): + # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + attn_output = F.scaled_dot_product_attention(query_states, key_states, + value_states, *args, **kwargs) + attn_output = attn_output.transpose(1, 2) + return attn_output + + +@sequence_parallel_wrapper +def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, + max_seqlen): + q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( + 0, 1), value_states.flatten(0, 1) + cumulative_len = torch.cat(cumulative_len, dim=0) + attn_output = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cumulative_len, + cumulative_len, + max_seqlen, + max_seqlen, + 0, + return_attn_probs=False, + causal=True, + ) + attn_output = attn_output.unsqueeze(0) + return attn_output + + def llama_attn_forward_legacy( self, hidden_states: torch.Tensor, @@ -136,21 +220,41 @@ def llama_attn_forward_legacy( past_key_value = (key_states, value_states) if use_cache else None + # repeat kv for sequence parallel + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = 
repeat_kv(value_states, self.num_key_value_groups) + + # flash attn 2 need (bs, seq_len, nhead, h_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + if SUPPORT_FLASH2: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = flash_attn_func( - query_states, key_states, value_states, causal=True) - attn_output = attn_output.contiguous() + causal = self.is_causal and q_len != 1 + + if attention_mask is not None: + attn_output = flash_attn_w_mask( + query_states, + key_states, + value_states, + attention_mask, + causal, + training=self.training) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal, + training=self.training) else: - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) # use flash attention implemented by pytorch - attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = flash_attn1_pytorch( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + training=self.training) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -183,46 +287,26 @@ def llama_attn_forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501 + # LlamaFlashAttention2 attention does not support output_attentions if 'padding_mask' in kwargs: warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in ' - 'v4.37. 
Please make sure use `attention_mask` instead.`') + 'Passing `padding_mask` is deprecated and will be removed in v4.37' + ' Please make sure use `attention_mask` instead.`') - bsz, q_len, _ = hidden_states.size() + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop('padding_mask') - if self.config.pretraining_tp > 1: - key_value_slicing = (self.num_key_value_heads * - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + output_attentions = False - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) + bsz, q_len, _ = hidden_states.size() - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, @@ -232,16 +316,15 @@ def llama_attn_forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - 'The cache structure has changed since version v4.36. ' - f'If you are using {self.__class__.__name__} ' - 'for auto-regressive decoding with k/v caching, ' - 'please make sure to initialize the attention class ' - 'with a layer index.') kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + if self.training: + assert position_ids is not None + cos, sin = self.rotary_emb( + value_states, seq_len=position_ids.max() + 1) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -250,40 +333,78 @@ def llama_attn_forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) - if SUPPORT_FLASH2: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = flash_attn_func( - query_states, key_states, value_states, causal=True) - attn_output = attn_output.contiguous() + # TODO: These transpose are quite inefficient but Flash Attention + # requires the layout [batch_size, sequence_length, num_heads, head_dim]. + # We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons, therefore the input hidden states gets silently + # casted in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f'The input hidden states seems to be silently casted in float32, ' + f'this might be related to the fact you have upcasted embedding ' + f'or layer norm layers in float32. We will cast back the input in' + f' {target_dtype}.') + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # flash attn + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal else: - # repeat k/v heads if n_kv_heads < n_heads - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - # use flash attention implemented by pytorch - attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, attn_mask=attention_mask) - attn_output = attn_output.transpose(1, 2).contiguous() + # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm + # is bumped to 2.1. For details, please see the comment in + # LlamaFlashAttention2 __init__. + causal = self.is_causal and q_len != 1 + + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + + if attention_mask is not None: + attn_output = flash_attn_w_mask( + query_states, + key_states, + value_states, + attention_mask, + causal, + dropout_rate, + training=self.training) + else: + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal, + dropout_rate, + training=self.training) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) + if not output_attentions: + attn_weights = None - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. 
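# --- Editorial sketch, not part of the patch: shape behaviour of the
# `repeat_kv_bshd` helper introduced in this file. All sizes below are
# hypothetical. The helper expands grouped KV heads in the
# (batch, seqlen, nhead, head_dim) layout expected by flash_attn_func, so
# query/key/value keep matching head counts once the sequence-parallel
# wrapper scatters heads across ranks.
import torch

bs, seqlen, num_kv_heads, n_rep, head_dim = 2, 8, 4, 2, 16
kv = torch.randn(bs, seqlen, num_kv_heads, head_dim)
out = kv[:, :, :, None, :].expand(bs, seqlen, num_kv_heads, n_rep, head_dim)
out = out.reshape(bs, seqlen, num_kv_heads * n_rep, head_dim)
assert out.shape == (2, 8, 8, 16)  # (batch, seqlen, num_attention_heads, head_dim)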
- return attn_output, None, past_key_value + return attn_output, attn_weights, past_key_value def llama_varlen_attn_forward_legacy( @@ -302,7 +423,7 @@ def llama_varlen_attn_forward_legacy( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - indexes = message_hub.get_info(f'indexes_rank_{rank}') + # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) @@ -357,10 +478,11 @@ def llama_varlen_attn_forward_legacy( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) + query_states = apply_rotary_emb(query_states, + cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) + key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -379,26 +501,21 @@ def llama_varlen_attn_forward_legacy( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + assert SUPPORT_FLASH2 if is_training: - q_unpad, k_unpad, v_unpad = query_states.flatten( - 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) - attn_output = flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - cumulative_len, - cumulative_len, - max_seqlen, - max_seqlen, - 0, - return_attn_probs=False, - causal=True, - ) + attn_output = varlen_flash_attn(query_states, key_states, value_states, + cumulative_len, max_seqlen) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, causal=True) + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -436,7 +553,7 @@ def llama_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - indexes = message_hub.get_info(f'indexes_rank_{rank}') + # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) @@ -499,10 +616,11 @@ def llama_varlen_attn_forward( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) + query_states = apply_rotary_emb(query_states, + cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) + key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -516,33 +634,25 @@ def llama_varlen_attn_forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) - key_states = 
repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) + # repeat kv for sequence parallel + key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) + value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + assert SUPPORT_FLASH2 if is_training: - q_unpad, k_unpad, v_unpad = query_states.flatten( - 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) - attn_output = flash_attn_varlen_func( - q_unpad, - k_unpad, - v_unpad, - cumulative_len, - cumulative_len, - max_seqlen, - max_seqlen, - 0, - return_attn_probs=False, - causal=True, - ) + attn_output = varlen_flash_attn(query_states, key_states, value_states, + cumulative_len, max_seqlen) else: - attn_output = flash_attn_func( - query_states, key_states, value_states, causal=True) + attn_output = flash_attn_wo_mask( + query_states, + key_states, + value_states, + causal=True, + training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) diff --git a/xtuner/model/modules/dispatch/mistral.py b/xtuner/model/modules/dispatch/mistral.py index 8c65cbec6..92245230c 100644 --- a/xtuner/model/modules/dispatch/mistral.py +++ b/xtuner/model/modules/dispatch/mistral.py @@ -84,13 +84,9 @@ def mistral_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - indexes = message_hub.get_info(f'indexes_rank_{rank}') + # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') - # cumulative_len = message_hub.get_info(f'cumulative_len') - # indexes = message_hub.get_info(f'indexes') - # max_seqlen = message_hub.get_info(f'max_seqlen') - assert is_training == (cumulative_len is not None) == ( past_key_value is None) @@ -136,10 +132,11 @@ def mistral_varlen_attn_forward( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0), - sin[indexes].squeeze(0)) + query_states = apply_rotary_emb(query_states, + cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) + key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), + sin[position_ids].squeeze(0)) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) diff --git a/xtuner/model/modules/dispatch/utils.py b/xtuner/model/modules/dispatch/utils.py new file mode 100644 index 000000000..4cfa26cd1 --- /dev/null +++ b/xtuner/model/modules/dispatch/utils.py @@ -0,0 +1,64 @@ +import torch +import torch.nn.functional as F + +try: + from flash_attn.bert_padding import index_first_axis, unpad_input +except ImportError: + pass + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def upad_qkv(query_layer, key_layer, value_layer, attention_mask, + query_length): + indices_k, 
cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, + head_dim), indices_k) + if query_length == kv_seq_len: + # Different from the origin version as sequence parallel change + # the number of attention heads. + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), + indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \ + unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py index c433ff1e9..7aa0ec63c 100644 --- a/xtuner/model/sft.py +++ b/xtuner/model/sft.py @@ -1,18 +1,23 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math from collections import OrderedDict from contextlib import nullcontext +import torch from mmengine import print_log from mmengine.config import Config, ConfigDict from mmengine.model import BaseModel from mmengine.runner import load_checkpoint from peft import get_peft_model, prepare_model_for_kbit_training from torch import nn -from transformers import PreTrainedModel, PreTrainedTokenizer +from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer from transformers.integrations import is_deepspeed_zero3_enabled +from xtuner.parallel.sequence import (get_sequence_parallel_world_size, + reduce_sequence_parallel_loss) from xtuner.registry import BUILDER from .modules import dispatch_modules +from .modules.dispatch import SUPPORT_FLASH2 from .utils import (LoadWoInit, find_all_linear_names, get_peft_model_state_dict, make_inputs_require_grad, traverse_dict) @@ -69,10 +74,12 @@ def __init__(self, peft_model=None, use_activation_checkpointing=True, use_varlen_attn=False, - tokenizer=None): + tokenizer=None, + max_position_embeddings=None): super().__init__() with LoadWoInit(): - self.llm = self._build_from_cfg_or_module(llm) + self.llm = self._build_from_cfg_or_module(llm, + max_position_embeddings) if tokenizer is not None: if isinstance(tokenizer, dict): @@ -137,11 +144,48 @@ def _prepare_for_lora(self, def init_weights(self): pass - def _build_from_cfg_or_module(self, cfg_or_mod): + def _prepare_for_long_context_training(self, cfg, max_position_embeddings): + pretrained_model_name_or_path = cfg.pretrained_model_name_or_path + config = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True) + + orig_rope_scaling = getattr(config, 'rope_scaling', None) + if orig_rope_scaling is None: + orig_rope_scaling = {'factor': 1} + + orig_rope_scaling_factor = orig_rope_scaling[ + 'factor'] if 'factor' in orig_rope_scaling.keys() else 1 + orig_ctx_len = getattr(config, 
'max_position_embeddings', None) + if orig_ctx_len: + orig_ctx_len *= orig_rope_scaling_factor + if max_position_embeddings > orig_ctx_len: + scaling_factor = float( + math.ceil(max_position_embeddings / orig_ctx_len)) + config.rope_scaling = { + 'type': 'linear', + 'factor': scaling_factor + } + + # hardcode for internlm2 + config.attn_implementation = 'flash_attention_2' + + cfg.config = config + return cfg + + def _build_from_cfg_or_module(self, + cfg_or_mod, + max_position_embeddings=None): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod elif isinstance(cfg_or_mod, dict): traverse_dict(cfg_or_mod) + if SUPPORT_FLASH2: + cfg_or_mod.torch_dtype = torch.bfloat16 \ + if torch.cuda.is_bf16_supported() else torch.float16 + cfg_or_mod.attn_implementation = 'flash_attention_2' + if max_position_embeddings is not None: + cfg_or_mod = self._prepare_for_long_context_training( + cfg_or_mod, max_position_embeddings) return BUILDER.build(cfg_or_mod) else: raise NotImplementedError @@ -168,10 +212,20 @@ def predict(self, data, data_samples=None): logits_dict = [{'logits': logits} for logits in outputs.logits] return logits_dict - def compute_loss(self, data, data_samples=None): + def compute_sequence_parallel_loss(self, data): outputs = self.llm(**data) - loss_dict = {'loss': outputs.loss} - return loss_dict + labels = data['labels'] + num_tokens = (labels != -100).sum() + loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens) + return {'loss': loss} + + def compute_loss(self, data, data_samples=None): + if get_sequence_parallel_world_size() > 1: + return self.compute_sequence_parallel_loss(data) + else: + outputs = self.llm(**data) + loss_dict = {'loss': outputs.loss} + return loss_dict def state_dict(self, *args, **kwargs): state_dict = super().state_dict(*args, **kwargs) diff --git a/xtuner/parallel/sequence/__init__.py b/xtuner/parallel/sequence/__init__.py new file mode 100644 index 000000000..a50921336 --- /dev/null +++ b/xtuner/parallel/sequence/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.dist import init_dist + +from .attention import sequence_parallel_wrapper +from .data_collate import (pad_for_sequence_parallel, + split_for_sequence_parallel) +from .reduce_loss import reduce_sequence_parallel_loss +from .sampler import SequenceParallelSampler +from .setup_distributed import (get_data_parallel_group, + get_data_parallel_rank, + get_data_parallel_world_size, + get_sequence_parallel_group, + get_sequence_parallel_rank, + get_sequence_parallel_world_size, + init_sequence_parallel) + +__all__ = [ + 'sequence_parallel_wrapper', 'pad_for_sequence_parallel', + 'split_for_sequence_parallel', 'SequenceParallelSampler', + 'init_sequence_parallel', 'get_sequence_parallel_group', + 'get_sequence_parallel_world_size', 'get_sequence_parallel_rank', + 'get_data_parallel_group', 'get_data_parallel_world_size', + 'get_data_parallel_rank', 'reduce_sequence_parallel_loss', 'init_dist' +] diff --git a/xtuner/parallel/sequence/attention.py b/xtuner/parallel/sequence/attention.py new file mode 100644 index 000000000..b1b1ebcee --- /dev/null +++ b/xtuner/parallel/sequence/attention.py @@ -0,0 +1,102 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
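# --- Editorial sketch, not part of the patch: a worked example of the linear
# RoPE scaling factor that `_prepare_for_long_context_training` (added to
# xtuner/model/sft.py above) writes into the model config. The context
# lengths are hypothetical.
import math

orig_ctx_len = 4096               # the model's original max_position_embeddings
orig_rope_scaling_factor = 1      # no pre-existing rope_scaling on the config
max_position_embeddings = 32768   # context length requested for training

orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
    scaling_factor = float(math.ceil(max_position_embeddings / orig_ctx_len))
    rope_scaling = {'type': 'linear', 'factor': scaling_factor}
assert rope_scaling == {'type': 'linear', 'factor': 8.0}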
+from typing import Any, Tuple + +import torch +import torch.distributed as dist +from torch import Tensor + +from .setup_distributed import (get_sequence_parallel_group, + get_sequence_parallel_world_size) + + +def all_to_all_scatter_nhead(input): + # bs, seq, nhead, dim ==> + # bs, seq * sp_world_size, nhead / sp_world_size, dim + sp_world_size = get_sequence_parallel_world_size() + sp_group = get_sequence_parallel_group() + bs, seq, nhead, dim = input.shape + input_t = input.reshape(bs, seq, sp_world_size, nhead // sp_world_size, + dim) + input_t = input_t.permute(2, 0, 1, 3, 4).contiguous() + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=sp_group) + output = output.transpose(0, 1) + return output.reshape(bs, seq * sp_world_size, nhead // sp_world_size, dim) + + +def all_to_all_scatter_seq(input): + # bs, seq * sp_world_size, nhead / sp_world_size, dim ==> + # bs, seq, nhead, dim + sp_world_size = get_sequence_parallel_world_size() + sp_group = get_sequence_parallel_group() + bs, seq, nhead, dim = input.shape + input_t = input.reshape(bs, sp_world_size, seq // sp_world_size, nhead, + dim) + input_t = input_t.transpose(0, 1).contiguous() + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=sp_group) + output = output.permute(1, 2, 0, 3, 4) + return output.reshape(bs, seq // sp_world_size, nhead * sp_world_size, dim) + + +class _SeqAllToAll(torch.autograd.Function): + + @staticmethod + def forward(ctx: Any, input: Tensor, scatter_seq) -> Tensor: + ctx.scatter_seq = scatter_seq + ctx.input_shape = input.shape + if scatter_seq: + return all_to_all_scatter_seq(input) + return all_to_all_scatter_nhead(input) + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]: + grad = _SeqAllToAll.apply(*grad_output, not ctx.scatter_seq) + return (grad, None) + + +def pre_process_for_sequence_parallel_attn(query_states, key_states, + value_states): + sequence_parallel_world_size = get_sequence_parallel_world_size() + n_head = query_states.shape[2] + assert n_head % sequence_parallel_world_size == 0, \ + ('The number of attention heads should be divisible by ' + f'sequence_parallel_world_size. 
But got n_head = {n_head} and ' + f'sequence_parallel_world_size = {sequence_parallel_world_size}.') + + # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim) + query_states = _SeqAllToAll.apply(query_states, False) + key_states = _SeqAllToAll.apply(key_states, False) + value_states = _SeqAllToAll.apply(value_states, False) + + return query_states, key_states, value_states + + +def post_process_for_sequence_parallel_attn(attn_output): + # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim) + output = _SeqAllToAll.apply(attn_output, True) + return output + + +def sequence_parallel_wrapper(local_attn): + + def sequence_parallel_attn(query_states, key_states, value_states, *args, + **kwargs): + training = kwargs.pop('training', True) + enable_sequence_parallel = ( + dist.is_initialized() and get_sequence_parallel_world_size() > 1 + and training) + if enable_sequence_parallel: + query_states, key_states, value_states = \ + pre_process_for_sequence_parallel_attn( + query_states, key_states, value_states) + + out = local_attn(query_states, key_states, value_states, *args, + **kwargs) + + if enable_sequence_parallel: + out = post_process_for_sequence_parallel_attn(out).contiguous() + + return out + + return sequence_parallel_attn diff --git a/xtuner/parallel/sequence/data_collate.py b/xtuner/parallel/sequence/data_collate.py new file mode 100644 index 000000000..f61b481b9 --- /dev/null +++ b/xtuner/parallel/sequence/data_collate.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX +from .setup_distributed import (get_sequence_parallel_rank, + get_sequence_parallel_world_size) + + +def pad_for_sequence_parallel(tokens, + labels=None, + position_ids=None, + attention_mask=None, + tokens_pad_index=DEFAULT_PAD_TOKEN_INDEX, + labels_pad_index=IGNORE_INDEX, + position_ids_pad_index=0, + attention_mask_pad_index=0): + if labels is not None: + assert tokens.shape == labels.shape + if position_ids is not None: + assert tokens.shape == position_ids.shape + if attention_mask is not None: + assert tokens.shape == attention_mask.shape + + bs, seq_len = tokens.shape + seq_parallel_world_size = get_sequence_parallel_world_size() + if seq_len % seq_parallel_world_size == 0: + return tokens, labels, position_ids, attention_mask + + pad_num = seq_parallel_world_size - (seq_len % seq_parallel_world_size) + pad = torch.full((bs, pad_num), + tokens_pad_index, + dtype=tokens.dtype, + device=tokens.device) + tokens = torch.cat([tokens, pad], dim=1) + + if labels is not None: + pad = torch.full((bs, pad_num), + labels_pad_index, + dtype=labels.dtype, + device=labels.device) + labels = torch.cat([labels, pad], dim=1) + + if position_ids is not None: + pad = torch.full((bs, pad_num), + position_ids_pad_index, + dtype=position_ids.dtype, + device=position_ids.device) + position_ids = torch.cat([position_ids, pad], dim=1) + + if attention_mask is not None: + pad = torch.full((bs, pad_num), + attention_mask_pad_index, + dtype=attention_mask.dtype, + device=attention_mask.device) + attention_mask = torch.cat([attention_mask, pad], dim=1) + + return tokens, labels, position_ids, attention_mask + + +def split_for_sequence_parallel(tokens, labels=None, position_ids=None): + seq_parallel_world_size = get_sequence_parallel_world_size() + seq_parallel_world_rank = get_sequence_parallel_rank() + seq_len = tokens.size(1) + assert seq_len % seq_parallel_world_size == 0 + sub_seq_len = seq_len // 
seq_parallel_world_size + sub_seq_start = seq_parallel_world_rank * sub_seq_len + sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_len + + tokens = tokens[:, sub_seq_start:sub_seq_end] + if labels is not None: + labels = labels[:, sub_seq_start:sub_seq_end] + if position_ids is not None: + position_ids = position_ids[:, sub_seq_start:sub_seq_end] + + return tokens, labels, position_ids diff --git a/xtuner/parallel/sequence/reduce_loss.py b/xtuner/parallel/sequence/reduce_loss.py new file mode 100644 index 000000000..56a8389f4 --- /dev/null +++ b/xtuner/parallel/sequence/reduce_loss.py @@ -0,0 +1,17 @@ +import torch +import torch.distributed as dist + +from .setup_distributed import get_sequence_parallel_group + + +def reduce_sequence_parallel_loss(mean_loss, num_tokens_for_loss): + sequence_parallel_group = get_sequence_parallel_group() + if num_tokens_for_loss == 0: + # convert nan to 0 just for logging + mean_loss = torch.nan_to_num(mean_loss) + loss_sum = mean_loss * num_tokens_for_loss + dist.all_reduce(loss_sum, group=sequence_parallel_group) + dist.all_reduce(num_tokens_for_loss, group=sequence_parallel_group) + + loss = loss_sum / num_tokens_for_loss + return loss diff --git a/xtuner/parallel/sequence/sampler.py b/xtuner/parallel/sequence/sampler.py new file mode 100644 index 000000000..69adb7cc9 --- /dev/null +++ b/xtuner/parallel/sequence/sampler.py @@ -0,0 +1,38 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import math +from typing import Optional, Sized + +from mmengine.dataset import DefaultSampler +from mmengine.dist import sync_random_seed + +from .setup_distributed import (get_data_parallel_rank, + get_data_parallel_world_size) + + +class SequenceParallelSampler(DefaultSampler): + + def __init__(self, + dataset: Sized, + shuffle: bool = True, + seed: Optional[int] = None, + round_up: bool = True) -> None: + rank = get_data_parallel_rank() + world_size = get_data_parallel_world_size() + self.rank = rank + self.world_size = world_size + + self.dataset = dataset + self.shuffle = shuffle + if seed is None: + seed = sync_random_seed() + self.seed = seed + self.epoch = 0 + self.round_up = round_up + + if self.round_up: + self.num_samples = math.ceil(len(self.dataset) / world_size) + self.total_size = self.num_samples * self.world_size + else: + self.num_samples = math.ceil( + (len(self.dataset) - rank) / world_size) + self.total_size = len(self.dataset) diff --git a/xtuner/parallel/sequence/setup_distributed.py b/xtuner/parallel/sequence/setup_distributed.py new file mode 100644 index 000000000..ea207bf10 --- /dev/null +++ b/xtuner/parallel/sequence/setup_distributed.py @@ -0,0 +1,100 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.distributed as dist + +_SEQUENCE_PARALLEL_GROUP = None +_SEQUENCE_PARALLEL_WORLD_SIZE = None +_SEQUENCE_PARALLEL_RANK = None + +_DATA_PARALLEL_GROUP = None +_DATA_PARALLEL_WORLD_SIZE = None +_DATA_PARALLEL_RANK = None + + +def init_sequence_parallel(sequence_parallel_size: int = 1): + assert dist.is_initialized() + world_size: int = dist.get_world_size() + + # enable_ds_sequence_parallel = sequence_parallel_size > 1 + # if enable_ds_sequence_parallel: + if world_size % sequence_parallel_size != 0: + raise RuntimeError(f'world_size ({world_size}) is not divisible by ' + f'sequence_parallel_size {sequence_parallel_size}') + + num_sequence_parallel_groups: int = world_size // sequence_parallel_size + + rank = dist.get_rank() + + # Build the sequence parallel groups. 
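# --- Editorial sketch, not part of the patch: the rank layout produced by the
# two group-building loops below, for a hypothetical world_size = 8 and
# sequence_parallel_size = 4. Every rank belongs to exactly one group of each
# kind: sequence-parallel groups take contiguous ranks, data-parallel groups
# take strided ranks.
world_size, sequence_parallel_size = 8, 4
sp_groups = [
    list(range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size))
    for i in range(world_size // sequence_parallel_size)
]
dp_groups = [
    list(range(j, world_size, sequence_parallel_size))
    for j in range(sequence_parallel_size)
]
assert sp_groups == [[0, 1, 2, 3], [4, 5, 6, 7]]
assert dp_groups == [[0, 4], [1, 5], [2, 6], [3, 7]]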
+ global _SEQUENCE_PARALLEL_GROUP + assert _SEQUENCE_PARALLEL_GROUP is None, \ + 'sequence parallel group is already initialized' + for i in range(num_sequence_parallel_groups): + ranks = range(i * sequence_parallel_size, + (i + 1) * sequence_parallel_size) + group = dist.new_group(ranks) + if rank in ranks: + _SEQUENCE_PARALLEL_GROUP = group + + global _DATA_PARALLEL_GROUP + assert _DATA_PARALLEL_GROUP is None, \ + 'data parallel group is already initialized' + all_data_parallel_group_ranks = [] + start_rank = 0 + end_rank = world_size + for j in range(sequence_parallel_size): + ranks = range(start_rank + j, end_rank, sequence_parallel_size) + all_data_parallel_group_ranks.append(list(ranks)) + group = dist.new_group(ranks) + if rank in ranks: + _DATA_PARALLEL_GROUP = group + + +def get_sequence_parallel_group(): + """Get the sequence parallel group the caller rank belongs to.""" + return _SEQUENCE_PARALLEL_GROUP + + +def get_sequence_parallel_world_size(): + """Return world size for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_WORLD_SIZE + if _SEQUENCE_PARALLEL_WORLD_SIZE is not None: + return _SEQUENCE_PARALLEL_WORLD_SIZE + _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size( + group=get_sequence_parallel_group()) + return _SEQUENCE_PARALLEL_WORLD_SIZE + + +def get_sequence_parallel_rank(): + """Return my rank for the sequence parallel group.""" + global _SEQUENCE_PARALLEL_RANK + if _SEQUENCE_PARALLEL_RANK is not None: + return _SEQUENCE_PARALLEL_RANK + _SEQUENCE_PARALLEL_RANK = dist.get_rank( + group=get_sequence_parallel_group()) + return _SEQUENCE_PARALLEL_RANK + + +def get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert _DATA_PARALLEL_GROUP is not None, \ + 'data parallel group is not initialized' + return _DATA_PARALLEL_GROUP + + +def get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + global _DATA_PARALLEL_WORLD_SIZE + if _DATA_PARALLEL_WORLD_SIZE is not None: + return _DATA_PARALLEL_WORLD_SIZE + _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size( + group=get_data_parallel_group()) + return _DATA_PARALLEL_WORLD_SIZE + + +def get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + global _DATA_PARALLEL_RANK + if _DATA_PARALLEL_RANK is not None: + return _DATA_PARALLEL_RANK + _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group()) + return _DATA_PARALLEL_RANK diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py index 922089dae..7acbbf21f 100644 --- a/xtuner/tools/train.py +++ b/xtuner/tools/train.py @@ -76,6 +76,31 @@ def register_function(cfg_dict): register_function(value) +def check_cfg(cfg): + if getattr(cfg, 'use_varlen_attn', + False) and cfg.train_dataloader.batch_size > 1: + raise NotImplementedError( + f'If utilizing varlen attention, the batch size should be' + f' set to 1, but got {cfg.train_dataloader.batch_size}') + + if getattr(cfg, 'use_varlen_attn', False) and (not getattr( + cfg.train_dataloader.dataset, 'pack_to_max_length', True)): + raise AssertionError( + 'When using varlen attention, `pack_to_max_length`' + 'should be set to True, but got use_varlen_attn = True and ' + 'pack_to_max_length = False.') + + if getattr(cfg, 'use_varlen_attn', False): + sequence_parallel = getattr(cfg, 'sequence_parallel', 1) + max_length = getattr(cfg.train_dataloader.dataset, 'max_length', None) + if max_length is not None: + assert max_length % sequence_parallel == 0, \ + ('When using varlen attention, `max_length` should be evenly ' + 
'divided by sequence parallel world size, but got ' + f'max_length = {max_length} and sequence_parallel = ' + f'{sequence_parallel}') + + def main(): args = parse_args() @@ -96,6 +121,8 @@ def main(): # change these FunctionType object to str register_function(cfg._cfg_dict) + check_cfg(cfg) + if cfg.get('framework', 'mmengine').lower() == 'huggingface': # set default training_args if cfg.get('training_args', None) is None: @@ -277,7 +304,10 @@ def main(): gradient_accumulation_steps=grad_accum, train_micro_batch_size_per_gpu=train_bs, gradient_clipping=grad_clip, - exclude_frozen_parameters=exclude_frozen_parameters) + exclude_frozen_parameters=exclude_frozen_parameters, + sequence_parallel_size=getattr(cfg, + 'sequence_parallel_size', + 1)) cfg.__setitem__('strategy', strategy) optim_wrapper = dict( type='DeepSpeedOptimWrapper', From 32e3e5f0581998fd84f30f8a1847554a287c161a Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Mon, 25 Mar 2024 20:37:35 +0800 Subject: [PATCH 2/9] [Bug] Fix bugs in flash_attn1_pytorch (#513) add @sequence_parallel_wrapper --- xtuner/model/modules/dispatch/internlm2.py | 1 + xtuner/model/modules/dispatch/llama.py | 1 + 2 files changed, 2 insertions(+) diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index a166e8bae..b354e6fd5 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -156,6 +156,7 @@ def flash_attn_w_mask( return attn_output +@sequence_parallel_wrapper def flash_attn1_pytorch(query_states, key_states, value_states, *args, **kwargs): # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py index 27b1f33d6..4077efb6c 100644 --- a/xtuner/model/modules/dispatch/llama.py +++ b/xtuner/model/modules/dispatch/llama.py @@ -113,6 +113,7 @@ def flash_attn_w_mask( return attn_output +@sequence_parallel_wrapper def flash_attn1_pytorch(query_states, key_states, value_states, *args, **kwargs): # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) From 6004335ead98496a2e84f59e1b2d2c3e41168b0d Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Tue, 26 Mar 2024 15:30:03 +0800 Subject: [PATCH 3/9] [Fix] delete cat in varlen attn (#508) delete cat in varlen attn --- .../varlen_attn_args_to_messagehub_hook.py | 33 ++++++------------- xtuner/model/modules/dispatch/internlm2.py | 2 -- xtuner/model/modules/dispatch/llama.py | 3 -- 3 files changed, 10 insertions(+), 28 deletions(-) diff --git a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py index f2b23d3fe..f7a95a09c 100644 --- a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py +++ b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
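# --- Editorial sketch, not part of the patch: how the
# VarlenAttnArgsToMessageHubHook changes below hand the varlen-attention
# arguments to the dispatched attention forwards. The tensor values and rank
# are hypothetical; MessageHub is the mmengine key-value store already used
# by this hook.
import torch
from mmengine import MessageHub

rank = 0  # would normally come from torch.distributed.get_rank()
message_hub = MessageHub.get_instance('varlen_attn_args')

# before_train_iter: publish the packed-sequence metadata for this rank
cumulative_len = torch.tensor([0, 2048, 4096], dtype=torch.int32)
message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
message_hub.update_info(f'max_seqlen_rank_{rank}', 2048)

# inside llama_varlen_attn_forward / internlm2_varlen_attn_forward
assert message_hub.get_info(f'max_seqlen_rank_{rank}') == 2048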
-from typing import Mapping, Optional, Sequence, Union +from typing import Optional, Union -import torch import torch.distributed as dist from mmengine import MessageHub from mmengine.hooks import Hook @@ -11,20 +10,6 @@ class VarlenAttnArgsToMessageHubHook(Hook): - args = ('cumulative_len', 'max_seqlen') - - def cast_data(self, data): - if isinstance(data, Mapping): - return {key: self.cast_data(data[key]) for key in data} - elif isinstance(data, (str, bytes)) or data is None: - return data - elif isinstance(data, Sequence): - return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable - elif isinstance(data, torch.Tensor): - return data.cuda() - else: - return data - def before_train_iter(self, runner, batch_idx: int, @@ -35,10 +20,13 @@ def before_train_iter(self, assert 'data' in data_batch.keys() data = data_batch['data'] - for arg in self.args: - assert arg in data - message_hub.update_info(f'{arg}_rank_{rank}', - self.cast_data(data.pop(arg))) + cumulative_len = data.pop('cumulative_len') + assert len(cumulative_len) == 1 + cumulative_len = cumulative_len[0].cuda() + message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len) + + max_seqlen = data.pop('max_seqlen') + message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen) def after_train_iter(self, runner, @@ -47,6 +35,5 @@ def after_train_iter(self, outputs: Optional[dict] = None) -> None: rank = dist.get_rank() message_hub = MessageHub.get_instance('varlen_attn_args') - - for arg in self.args: - message_hub.update_info(f'{arg}_rank_{rank}', None) + message_hub.update_info(f'cumulative_len_rank_{rank}', None) + message_hub.update_info(f'max_seqlen_rank_{rank}', None) diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index b354e6fd5..9bc77177c 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -174,7 +174,6 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, max_seqlen): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) attn_output = flash_attn_varlen_func( q_unpad, k_unpad, @@ -310,7 +309,6 @@ def internlm2_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py index 4077efb6c..df17e1b49 100644 --- a/xtuner/model/modules/dispatch/llama.py +++ b/xtuner/model/modules/dispatch/llama.py @@ -131,7 +131,6 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, max_seqlen): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) - cumulative_len = torch.cat(cumulative_len, dim=0) attn_output = flash_attn_varlen_func( q_unpad, k_unpad, @@ -424,7 +423,6 @@ def llama_varlen_attn_forward_legacy( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert 
is_training == (cumulative_len is not None) @@ -554,7 +552,6 @@ def llama_varlen_attn_forward( message_hub = MessageHub.get_instance('varlen_attn_args') rank = dist.get_rank() cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}') - # position_ids = message_hub.get_info(f'position_ids_rank_{rank}') max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}') assert is_training == (cumulative_len is not None) From 957abf63338216ef6144989873cc164219799fbe Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 27 Mar 2024 11:43:14 +0800 Subject: [PATCH 4/9] bump version to 0.1.16 (#520) --- xtuner/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xtuner/version.py b/xtuner/version.py index 11029f49f..ae73ce92a 100644 --- a/xtuner/version.py +++ b/xtuner/version.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -__version__ = '0.1.16.dev0' +__version__ = '0.1.16' short_version = __version__ From 2c0fa5acc75567a5c56267b5c03f1123717f3fae Mon Sep 17 00:00:00 2001 From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com> Date: Wed, 27 Mar 2024 11:43:49 +0800 Subject: [PATCH 5/9] [Improve] Add `generation_kwargs` for `EvaluateChatHook` (#501) * update * update --- xtuner/engine/hooks/evaluate_chat_hook.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/xtuner/engine/hooks/evaluate_chat_hook.py b/xtuner/engine/hooks/evaluate_chat_hook.py index efa1bc69f..8e6a86822 100644 --- a/xtuner/engine/hooks/evaluate_chat_hook.py +++ b/xtuner/engine/hooks/evaluate_chat_hook.py @@ -29,7 +29,8 @@ def __init__(self, every_n_iters=None, max_new_tokens=600, stop_word=None, - stop_words=[]): + stop_words=[], + generation_kwargs={}): self.evaluation_inputs = evaluation_inputs if isinstance(self.evaluation_inputs, str): self.evaluation_inputs = [self.evaluation_inputs] @@ -69,8 +70,9 @@ def __init__(self, if image_processor is not None: self.image_processor = BUILDER.build(image_processor) self.stop_criteria = StoppingCriteriaList() + # default generation config - self.gen_config = GenerationConfig( + default_generation_kwargs = dict( max_new_tokens=max_new_tokens, do_sample=True, temperature=0.1, @@ -79,8 +81,10 @@ def __init__(self, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else - self.tokenizer.eos_token_id, - ) + self.tokenizer.eos_token_id) + default_generation_kwargs.update(generation_kwargs) + self.gen_config = GenerationConfig(**default_generation_kwargs) + self.stop_criteria = StoppingCriteriaList() for word in stop_words: self.stop_criteria.append( From 62e7d80d9761c66446bcc46180aef936c101becb Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Wed, 27 Mar 2024 19:09:32 +0800 Subject: [PATCH 6/9] [Bugs] Fix bugs when training in non-distributed env (#522) fix bugs when training in non-distributed env --- xtuner/parallel/sequence/data_collate.py | 3 +++ xtuner/parallel/sequence/setup_distributed.py | 26 ++++++++++++++----- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/xtuner/parallel/sequence/data_collate.py b/xtuner/parallel/sequence/data_collate.py index f61b481b9..15b242d73 100644 --- a/xtuner/parallel/sequence/data_collate.py +++ b/xtuner/parallel/sequence/data_collate.py @@ -59,6 +59,9 @@ def pad_for_sequence_parallel(tokens, def split_for_sequence_parallel(tokens, labels=None, position_ids=None): seq_parallel_world_size = 
get_sequence_parallel_world_size() + if seq_parallel_world_size == 1: + return tokens, labels, position_ids + seq_parallel_world_rank = get_sequence_parallel_rank() seq_len = tokens.size(1) assert seq_len % seq_parallel_world_size == 0 diff --git a/xtuner/parallel/sequence/setup_distributed.py b/xtuner/parallel/sequence/setup_distributed.py index ea207bf10..9eb159e66 100644 --- a/xtuner/parallel/sequence/setup_distributed.py +++ b/xtuner/parallel/sequence/setup_distributed.py @@ -59,8 +59,11 @@ def get_sequence_parallel_world_size(): global _SEQUENCE_PARALLEL_WORLD_SIZE if _SEQUENCE_PARALLEL_WORLD_SIZE is not None: return _SEQUENCE_PARALLEL_WORLD_SIZE - _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size( - group=get_sequence_parallel_group()) + if not dist.is_initialized(): + _SEQUENCE_PARALLEL_WORLD_SIZE = 1 + else: + _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size( + group=get_sequence_parallel_group()) return _SEQUENCE_PARALLEL_WORLD_SIZE @@ -69,8 +72,11 @@ def get_sequence_parallel_rank(): global _SEQUENCE_PARALLEL_RANK if _SEQUENCE_PARALLEL_RANK is not None: return _SEQUENCE_PARALLEL_RANK - _SEQUENCE_PARALLEL_RANK = dist.get_rank( - group=get_sequence_parallel_group()) + if not dist.is_initialized(): + _SEQUENCE_PARALLEL_RANK = 0 + else: + _SEQUENCE_PARALLEL_RANK = dist.get_rank( + group=get_sequence_parallel_group()) return _SEQUENCE_PARALLEL_RANK @@ -86,8 +92,11 @@ def get_data_parallel_world_size(): global _DATA_PARALLEL_WORLD_SIZE if _DATA_PARALLEL_WORLD_SIZE is not None: return _DATA_PARALLEL_WORLD_SIZE - _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size( - group=get_data_parallel_group()) + if not dist.is_initialized(): + _DATA_PARALLEL_WORLD_SIZE = 1 + else: + _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size( + group=get_data_parallel_group()) return _DATA_PARALLEL_WORLD_SIZE @@ -96,5 +105,8 @@ def get_data_parallel_rank(): global _DATA_PARALLEL_RANK if _DATA_PARALLEL_RANK is not None: return _DATA_PARALLEL_RANK - _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group()) + if not dist.is_initialized(): + _DATA_PARALLEL_RANK = 0 + else: + _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group()) return _DATA_PARALLEL_RANK From 1dd5cbd367f34257d7f4e53910df0399005e5247 Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Thu, 28 Mar 2024 18:30:26 +0800 Subject: [PATCH 7/9] [Fix] Support transformers>=4.38 and require transformers>=4.36.0 (#494) * support transformers>=4.38 and require transformers>=4.36.0 * add flash attn2 to config if SUPPORT_FLASH2 * fix lint * fix comments * fix lint * dispatch config if flash_attn or torch.nn.functional.scaled_dot_product_attention is supported * do not dispatch attn forward if using scaled_dot_product_attention and require flash_attn installed to use sequence parallel * dispatch config in llava model if flash_attn or torch.nn.functional.scaled_dot_product_attention is supported --- requirements/runtime.txt | 6 +- xtuner/model/llava.py | 66 ++- xtuner/model/modules/dispatch/__init__.py | 10 +- xtuner/model/modules/dispatch/internlm2.py | 33 +- xtuner/model/modules/dispatch/llama.py | 513 +++++++++------------ xtuner/model/sft.py | 66 ++- xtuner/tools/train.py | 5 + 7 files changed, 338 insertions(+), 361 deletions(-) diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 5ccada772..f531754b3 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -18,8 +18,6 @@ tiktoken # limit pytorch version <= 2.1.2 as there may be some bugs in triton 2.2 torch<=2.1.2 
torchvision<=0.16.2 -# Minimum 4.34.0 to support added_tokens_decoder of tokenizer -# Exclude 4.34.1, 4.35.0, 4.35.1, 4.35.2 to avoid BC-break, -# see https://github.com/huggingface/transformers/pull/27020, https://github.com/huggingface/transformers/pull/27073 -transformers>=4.34.0,!=4.34.1,!=4.35.0,!=4.35.1,!=4.35.2 +# Minimum 4.36.0 to support `Cache` data structure used by KV Cache +transformers>=4.36.0 transformers_stream_generator diff --git a/xtuner/model/llava.py b/xtuner/model/llava.py index d7e39a804..19b427a75 100644 --- a/xtuner/model/llava.py +++ b/xtuner/model/llava.py @@ -1,13 +1,17 @@ # Copyright (c) OpenMMLab. All rights reserved. +import math from collections import OrderedDict +import torch import torch.nn as nn from mmengine.config import Config, ConfigDict from mmengine.model import BaseModel from peft import get_peft_model, prepare_model_for_kbit_training +from transformers import AutoConfig from xtuner.registry import BUILDER from .modules import ProjectorConfig, ProjectorModel, dispatch_modules +from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 from .utils import (LoadWoInit, find_all_linear_names, get_peft_model_state_dict, guess_load_checkpoint, make_inputs_require_grad, @@ -26,11 +30,15 @@ def __init__(self, projector_depth=2, llm_lora=None, visual_encoder_lora=None, - use_activation_checkpointing=True): + use_activation_checkpointing=True, + max_position_embeddings=None): super().__init__() self.freeze_llm = freeze_llm self.freeze_visual_encoder = freeze_visual_encoder with LoadWoInit(): + if isinstance(llm, dict): + llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings) + self.llm = self._build_from_cfg_or_module(llm) self.visual_encoder = self._build_from_cfg_or_module( visual_encoder) @@ -157,6 +165,62 @@ def state_dict(self, *args, **kwargs): for k, v in state_dict.items() if 'projector.' 
in k}) return to_return + @staticmethod + def _prepare_for_long_context_training(cfg, llm_cfg, + max_position_embeddings): + + orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None) + if orig_rope_scaling is None: + orig_rope_scaling = {'factor': 1} + + orig_rope_scaling_factor = orig_rope_scaling[ + 'factor'] if 'factor' in orig_rope_scaling.keys() else 1 + orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None) + if orig_ctx_len: + orig_ctx_len *= orig_rope_scaling_factor + if max_position_embeddings > orig_ctx_len: + scaling_factor = float( + math.ceil(max_position_embeddings / orig_ctx_len)) + llm_cfg.rope_scaling = { + 'type': 'linear', + 'factor': scaling_factor + } + + # hardcode for internlm2 + llm_cfg.attn_implementation = 'flash_attention_2' + cfg.config = llm_cfg + + return cfg, llm_cfg + + @staticmethod + def _prepare_for_flash_attn(cfg, llm_cfg): + cls_name = type(llm_cfg).__name__ + SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig', + 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + + if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2: + cfg.torch_dtype = torch.bfloat16 \ + if torch.cuda.is_bf16_supported() else torch.float16 + cfg.attn_implementation = 'flash_attention_2' + elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN: + cfg.attn_implementation = 'sdpa' + + return cfg, llm_cfg + + def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None): + pretrained_model_name_or_path = cfg.pretrained_model_name_or_path + llm_cfg = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True) + cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg) + if max_position_embeddings is not None: + cfg, llm_cfg = self._prepare_for_long_context_training( + cfg, llm_cfg, max_position_embeddings) + return cfg + def _build_from_cfg_or_module(self, cfg_or_mod): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py index 6fbe37fb6..ab104a7dc 100644 --- a/xtuner/model/modules/dispatch/__init__.py +++ b/xtuner/model/modules/dispatch/__init__.py @@ -13,7 +13,7 @@ from .yi import yi_attn_forward IS_LOW_VERSION_TRANSFORMERS = digit_version( - transformers.__version__) < digit_version('4.36') + transformers.__version__) < digit_version('4.38') SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.0.0') SUPPORT_FLASH2 = False @@ -48,7 +48,7 @@ def dispatch_llama_attn_forward(model, use_varlen_attn): if use_varlen_attn: assert SUPPORT_FLASH2 and SUPPORT_TRITON, \ 'flash_attn and triton is required if you want to use varlen_attn.' 
- elif not SUPPORT_FLASH: + elif not SUPPORT_FLASH2: return from .llama import (llama_attn_forward, llama_attn_forward_legacy, @@ -57,8 +57,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn): print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING) for module in model.modules(): - if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2', - 'LlamaSdpaAttention'): + # Do not need to dispatch if + # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is + # required when using sequence parallel + if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'): if use_varlen_attn: print_log('dispatch llama varlen attn forward', 'current') if IS_LOW_VERSION_TRANSFORMERS: diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py index 9bc77177c..93a43229e 100644 --- a/xtuner/model/modules/dispatch/internlm2.py +++ b/xtuner/model/modules/dispatch/internlm2.py @@ -156,19 +156,6 @@ def flash_attn_w_mask( return attn_output -@sequence_parallel_wrapper -def flash_attn1_pytorch(query_states, key_states, value_states, *args, - **kwargs): - # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query_states, key_states, - value_states, *args, **kwargs) - attn_output = attn_output.transpose(1, 2) - return attn_output - - @sequence_parallel_wrapper def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, max_seqlen): @@ -251,12 +238,12 @@ def internlm2_attn_forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # flash attn 2 need (bs, seq_len, nhead, h_dim) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - if SUPPORT_FLASH2: + # flash attn 2 need (bs, seq_len, nhead, h_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + causal = self.is_causal and q_len != 1 if attention_mask is not None: @@ -276,12 +263,10 @@ def internlm2_attn_forward( training=self.training) else: # use flash attention implemented by pytorch - attn_output = flash_attn1_pytorch( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - training=self.training) + # do not support sequence parallel + attn_output = F.scaled_dot_product_attention( + query_states, key_states, value_states, attn_mask=attention_mask) + attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.wo(attn_output) diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py index df17e1b49..c9febf34f 100644 --- a/xtuner/model/modules/dispatch/llama.py +++ b/xtuner/model/modules/dispatch/llama.py @@ -4,10 +4,10 @@ import torch import torch.distributed as dist -import torch.nn.functional as F from mmengine import MessageHub -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb -from transformers.utils import logging +from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb, + repeat_kv) +from transformers.utils import is_flash_attn_greater_or_equal_2_10 from xtuner.parallel.sequence import sequence_parallel_wrapper from .triton_kernels import apply_rotary_emb @@ -30,34 
+30,6 @@ class Cache: pass -logger = logging.get_logger(__name__) - - -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., :x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] - return torch.cat((-x2, x1), dim=-1) - - -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """This is the equivalent of torch.repeat_interleave(x, dim=1, - repeats=n_rep). - - The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to - (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape - if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, - None, :, :].expand(batch, - num_key_value_heads, - n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, - head_dim) - - def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)""" @@ -114,21 +86,12 @@ def flash_attn_w_mask( @sequence_parallel_wrapper -def flash_attn1_pytorch(query_states, key_states, value_states, *args, - **kwargs): - # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim) - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - attn_output = F.scaled_dot_product_attention(query_states, key_states, - value_states, *args, **kwargs) - attn_output = attn_output.transpose(1, 2) - return attn_output - - -@sequence_parallel_wrapper -def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, - max_seqlen): +def varlen_flash_attn(query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_rate=0.): q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten( 0, 1), value_states.flatten(0, 1) attn_output = flash_attn_varlen_func( @@ -139,7 +102,7 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, cumulative_len, max_seqlen, max_seqlen, - 0, + dropout_p=dropout_rate, return_attn_probs=False, causal=True, ) @@ -147,58 +110,29 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len, return attn_output -def llama_attn_forward_legacy( +def llama_attn_forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, -) -> Tuple[torch.Tensor, Optional[torch.Tensor], - Optional[Tuple[torch.Tensor]]]: - # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501 +): + # Modified from https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/llama/modeling_llama.py#L422 # noqa:E501 + output_attentions = False - if 'padding_mask' in kwargs: - warnings.warn('Passing `padding_mask` is deprecated and will be ' - 'removed in v4.37. 
Please make sure use ' - '`attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * # noqa: W504 - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, @@ -206,77 +140,90 @@ def llama_attn_forward_legacy( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids) + cos, sin) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + past_key_value = getattr(self, 'past_key_value', past_key_value) - past_key_value = (key_states, value_states) if use_cache else None + if past_key_value is not None: + # sin and cos are specific to RoPE models; + # cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs) - # repeat kv for sequence parallel key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - # flash attn 2 need (bs, seq_len, nhead, h_dim) + assert SUPPORT_FLASH2 query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - if SUPPORT_FLASH2: - causal = self.is_causal and q_len != 1 + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently + # casted in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) - if attention_mask is not None: - attn_output = flash_attn_w_mask( - query_states, - key_states, - value_states, - attention_mask, - causal, - training=self.training) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype else: - attn_output = flash_attn_wo_mask( - query_states, - key_states, - value_states, - causal, - training=self.training) + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + if is_flash_attn_greater_or_equal_2_10(): + causal = self.is_causal + else: + # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm + # is bumped to 2.1. For details, please see the comment in + # LlamaFlashAttention2 __init__. + causal = self.is_causal and q_len != 1 + + if attention_mask is not None: + attn_output = flash_attn_w_mask( + query_states, + key_states, + value_states, + attention_mask, + causal, + dropout_rate, + training=self.training) else: - # use flash attention implemented by pytorch - attn_output = flash_attn1_pytorch( + attn_output = flash_attn_wo_mask( query_states, key_states, value_states, - attn_mask=attention_mask, + causal, + dropout_rate, training=self.training) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) + if not output_attentions: + attn_weights = None - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. - return attn_output, None, past_key_value + return attn_output, attn_weights, past_key_value -def llama_attn_forward( +def llama_attn_forward_legacy( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -287,16 +234,11 @@ def llama_attn_forward( **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # LlamaFlashAttention2 attention does not support output_attentions + # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501 if 'padding_mask' in kwargs: warnings.warn( - 'Passing `padding_mask` is deprecated and will be removed in v4.37' - ' Please make sure use `attention_mask` instead.`') - - # overwrite attention_mask with padding_mask - attention_mask = kwargs.pop('padding_mask') - - output_attentions = False + 'Passing `padding_mask` is deprecated and will be removed in ' + 'v4.37. 
Please make sure use `attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() @@ -304,9 +246,6 @@ def llama_attn_forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - # Flash attention requires the input to have the shape - # batch_size x seq_length x head_dim x hidden_dim - # therefore we just need to keep the original shape query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, @@ -316,11 +255,17 @@ def llama_attn_forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + 'The cache structure has changed since version v4.36. ' + f'If you are using {self.__class__.__name__} ' + 'for auto-regressive decoding with k/v caching, ' + 'please make sure to initialize the attention class ' + 'with a layer index.') kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - + assert position_ids is not None if self.training: - assert position_ids is not None cos, sin = self.rotary_emb( value_states, seq_len=position_ids.max() + 1) else: @@ -333,42 +278,38 @@ def llama_attn_forward( key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) - # TODO: These transpose are quite inefficient but Flash Attention - # requires the layout [batch_size, sequence_length, num_heads, head_dim]. - # We would need to refactor the KV cache - # to be able to avoid many of these transpose/reshape/view. + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + assert SUPPORT_FLASH2 query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - dropout_rate = self.attention_dropout if self.training else 0.0 - # In PEFT, usually we cast the layer norms in float32 for training - # stability reasons, therefore the input hidden states gets silently + # stability reasons therefore the input hidden states gets silently # casted in float32. Hence, we need cast them back in the correct dtype # just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() # Handle the case where the model is quantized - if hasattr(self.config, '_pre_quantization_dtype'): + elif hasattr(self.config, '_pre_quantization_dtype'): target_dtype = self.config._pre_quantization_dtype else: target_dtype = self.q_proj.weight.dtype - logger.warning_once( - f'The input hidden states seems to be silently casted in float32, ' - f'this might be related to the fact you have upcasted embedding ' - f'or layer norm layers in float32. We will cast back the input in' - f' {target_dtype}.') - query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - # flash attn - if not self._flash_attn_uses_top_left_mask: + dropout_rate = self.attention_dropout if self.training else 0.0 + + if is_flash_attn_greater_or_equal_2_10(): causal = self.is_causal else: # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm @@ -376,10 +317,6 @@ def llama_attn_forward( # LlamaFlashAttention2 __init__. 
causal = self.is_causal and q_len != 1 - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) - if attention_mask is not None: attn_output = flash_attn_w_mask( query_states, @@ -401,20 +338,21 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + # Due to the implementation of the PyTorch version of flash attention, + # even when the output_attentions flag is set to True, it is not possible + # to return the attn_weights. + return attn_output, None, past_key_value -def llama_varlen_attn_forward_legacy( +def llama_varlen_attn_forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -432,82 +370,70 @@ def llama_varlen_attn_forward_legacy( '`attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * # noqa: W504 - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + query_states = query_states.view(bsz, q_len, self.num_heads, + self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) + self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, - self.head_dim) + self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin) + + past_key_value = getattr(self, 'past_key_value', past_key_value) - kv_seq_len = key_states.shape[-3] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + # sin and cos are specific to RoPE models; + # cache_position needed for the static cache + cache_kwargs = { + 'sin': sin, + 'cos': cos, + 'cache_position': cache_position + } + key_states, value_states = past_key_value.update( + key_states, value_states, 
self.layer_idx, cache_kwargs) - if is_training: - cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, - cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - else: - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - cos, sin = self.rotary_emb(value_states, kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) - if past_key_value is not None: - # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + dropout_rate = self.attention_dropout if self.training else 0.0 - past_key_value = (key_states, value_states) if use_cache else None - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently casted + # in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly) - # repeat kv for sequence parallel - key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) - value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) assert SUPPORT_FLASH2 if is_training: - attn_output = varlen_flash_attn(query_states, key_states, value_states, - cumulative_len, max_seqlen) + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_rate=dropout_rate) else: attn_output = flash_attn_wo_mask( query_states, @@ -517,26 +443,12 @@ def llama_varlen_attn_forward_legacy( training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) - - # Due to the implementation of the PyTorch version of flash attention, - # even when the output_attentions flag is set to True, it is not possible - # to return the attn_weights. 
return attn_output, None, past_key_value -def llama_varlen_attn_forward( +def llama_varlen_attn_forward_legacy( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, @@ -561,38 +473,9 @@ def llama_varlen_attn_forward( '`attention_mask` instead.`') bsz, q_len, _ = hidden_states.size() - if self.config.pretraining_tp > 1: - key_value_slicing = ( - self.num_key_value_heads * # noqa: W504 - self.head_dim) // self.config.pretraining_tp - query_slices = self.q_proj.weight.split( - (self.num_heads * self.head_dim) // self.config.pretraining_tp, - dim=0) - key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) - value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) - - query_states = [ - F.linear(hidden_states, query_slices[i]) - for i in range(self.config.pretraining_tp) - ] - query_states = torch.cat(query_states, dim=-1) - - key_states = [ - F.linear(hidden_states, key_slices[i]) - for i in range(self.config.pretraining_tp) - ] - key_states = torch.cat(key_states, dim=-1) - - value_states = [ - F.linear(hidden_states, value_slices[i]) - for i in range(self.config.pretraining_tp) - ] - value_states = torch.cat(value_states, dim=-1) - - else: - query_states = self.q_proj(hidden_states) - key_states = self.k_proj(hidden_states) - value_states = self.v_proj(hidden_states) + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, @@ -614,11 +497,12 @@ def llama_varlen_attn_forward( if is_training: cos, sin = self.rotary_emb(value_states, max_seqlen) - query_states = apply_rotary_emb(query_states, - cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) - key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0), - sin[position_ids].squeeze(0)) + # position_ids (1, seq_len) + # cos, sin (1, seq_len, dim) -> (seq_len, dim) + cos = cos[position_ids].squeeze(0) + sin = sin[position_ids].squeeze(0) + query_states = apply_rotary_emb(query_states, cos, sin) + key_states = apply_rotary_emb(key_states, cos, sin) else: query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -640,31 +524,50 @@ def llama_varlen_attn_forward( key_states = repeat_kv_bshd(key_states, self.num_key_value_groups) value_states = repeat_kv_bshd(value_states, self.num_key_value_groups) + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training + # stability reasons therefore the input hidden states gets silently casted + # in float32. Hence, we need cast them back in the correct dtype + # just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not + # cast the LayerNorms in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, '_pre_quantization_dtype'): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + assert SUPPORT_FLASH2 if is_training: - attn_output = varlen_flash_attn(query_states, key_states, value_states, - cumulative_len, max_seqlen) + attn_output = varlen_flash_attn( + query_states, + key_states, + value_states, + cumulative_len, + max_seqlen, + dropout_rate=dropout_rate) else: attn_output = flash_attn_wo_mask( query_states, key_states, value_states, causal=True, + dropout_rate=dropout_rate, training=False) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) - if self.config.pretraining_tp > 1: - attn_output = attn_output.split( - self.hidden_size // self.config.pretraining_tp, dim=2) - o_proj_slices = self.o_proj.weight.split( - self.hidden_size // self.config.pretraining_tp, dim=1) - attn_output = sum([ - F.linear(attn_output[i], o_proj_slices[i]) - for i in range(self.config.pretraining_tp) - ]) - else: - attn_output = self.o_proj(attn_output) + attn_output = self.o_proj(attn_output) # Due to the implementation of the PyTorch version of flash attention, # even when the output_attentions flag is set to True, it is not possible diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py index 7aa0ec63c..e1a29ab8a 100644 --- a/xtuner/model/sft.py +++ b/xtuner/model/sft.py @@ -17,7 +17,7 @@ reduce_sequence_parallel_loss) from xtuner.registry import BUILDER from .modules import dispatch_modules -from .modules.dispatch import SUPPORT_FLASH2 +from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2 from .utils import (LoadWoInit, find_all_linear_names, get_peft_model_state_dict, make_inputs_require_grad, traverse_dict) @@ -78,8 +78,9 @@ def __init__(self, max_position_embeddings=None): super().__init__() with LoadWoInit(): - self.llm = self._build_from_cfg_or_module(llm, - max_position_embeddings) + if isinstance(llm, dict): + llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings) + self.llm = self._build_from_cfg_or_module(llm) if tokenizer is not None: if isinstance(tokenizer, dict): @@ -144,48 +145,67 @@ def _prepare_for_lora(self, def init_weights(self): pass - def _prepare_for_long_context_training(self, cfg, max_position_embeddings): - pretrained_model_name_or_path = cfg.pretrained_model_name_or_path - config = AutoConfig.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=True) + @staticmethod + def _prepare_for_long_context_training(cfg, llm_cfg, + max_position_embeddings): - orig_rope_scaling = getattr(config, 'rope_scaling', None) + orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None) if orig_rope_scaling is None: orig_rope_scaling = {'factor': 1} orig_rope_scaling_factor = orig_rope_scaling[ 'factor'] if 'factor' in orig_rope_scaling.keys() else 1 - orig_ctx_len = getattr(config, 'max_position_embeddings', None) + orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None) if orig_ctx_len: orig_ctx_len *= orig_rope_scaling_factor if max_position_embeddings > orig_ctx_len: scaling_factor = float( math.ceil(max_position_embeddings / orig_ctx_len)) - config.rope_scaling = { + llm_cfg.rope_scaling = { 
'type': 'linear', 'factor': scaling_factor } # hardcode for internlm2 - config.attn_implementation = 'flash_attention_2' - - cfg.config = config + llm_cfg.attn_implementation = 'flash_attention_2' + cfg.config = llm_cfg + + return cfg, llm_cfg + + @staticmethod + def _prepare_for_flash_attn(cfg, llm_cfg): + cls_name = type(llm_cfg).__name__ + SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig', + 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig', + 'MistralConfig', 'MixtralConfig', 'Qwen2Config', + 'Starcoder2Config', 'Starcoder2Config') + + if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2: + cfg.torch_dtype = torch.bfloat16 \ + if torch.cuda.is_bf16_supported() else torch.float16 + cfg.attn_implementation = 'flash_attention_2' + elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN: + cfg.attn_implementation = 'sdpa' + + return cfg, llm_cfg + + def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None): + pretrained_model_name_or_path = cfg.pretrained_model_name_or_path + llm_cfg = AutoConfig.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=True) + cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg) + if max_position_embeddings is not None: + cfg, llm_cfg = self._prepare_for_long_context_training( + cfg, llm_cfg, max_position_embeddings) return cfg - def _build_from_cfg_or_module(self, - cfg_or_mod, - max_position_embeddings=None): + def _build_from_cfg_or_module(self, cfg_or_mod): if isinstance(cfg_or_mod, nn.Module): return cfg_or_mod elif isinstance(cfg_or_mod, dict): traverse_dict(cfg_or_mod) - if SUPPORT_FLASH2: - cfg_or_mod.torch_dtype = torch.bfloat16 \ - if torch.cuda.is_bf16_supported() else torch.float16 - cfg_or_mod.attn_implementation = 'flash_attention_2' - if max_position_embeddings is not None: - cfg_or_mod = self._prepare_for_long_context_training( - cfg_or_mod, max_position_embeddings) return BUILDER.build(cfg_or_mod) else: raise NotImplementedError diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py index 7acbbf21f..23e3d2a3f 100644 --- a/xtuner/tools/train.py +++ b/xtuner/tools/train.py @@ -19,6 +19,7 @@ from xtuner.configs import cfgs_name_path from xtuner.dataset.collate_fns import default_collate_fn from xtuner.model.modules import dispatch_modules +from xtuner.model.modules.dispatch import SUPPORT_FLASH2 from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict from xtuner.registry import BUILDER, MAP_FUNC from xtuner.tools.utils import (auto_dtype_of_deepspeed_config, @@ -100,6 +101,10 @@ def check_cfg(cfg): f'max_length = {max_length} and sequence_parallel = ' f'{sequence_parallel}') + if getattr(cfg, 'sequence_parallel_size', 1) > 1: + assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use ' + 'sequence parallel.') + def main(): args = parse_args() From 520ce99965cae6c2badb0b0b4d6eecc45b92a29d Mon Sep 17 00:00:00 2001 From: whcao <41630003+HIT-cwh@users.noreply.github.com> Date: Fri, 29 Mar 2024 16:13:47 +0800 Subject: [PATCH 8/9] [Fix] Fix throughput hook (#527) fix throughput hook --- xtuner/engine/hooks/throughput_hook.py | 102 ++++++++++++++++++++----- 1 file changed, 82 insertions(+), 20 deletions(-) diff --git a/xtuner/engine/hooks/throughput_hook.py b/xtuner/engine/hooks/throughput_hook.py index cf31414c0..a07e216fe 100644 --- a/xtuner/engine/hooks/throughput_hook.py +++ b/xtuner/engine/hooks/throughput_hook.py @@ -1,11 +1,15 @@ # Copyright (c) OpenMMLab. 
All rights reserved.
+import logging
 from typing import Optional, Union
 
 import torch
+from mmengine import print_log
 from mmengine.hooks import Hook
 from mmengine.model.wrappers import is_model_wrapper
 from torch.utils._pytree import tree_flatten
 
+from xtuner.parallel.sequence import get_sequence_parallel_world_size
+
 DATA_BATCH = Optional[Union[dict, tuple, list]]
 
 
@@ -20,12 +24,39 @@ def __init__(self,
                  hidden_size=None,
                  num_layers=None,
                  vocab_size=None,
-                 mlp_ratio=None):
+                 mlp_ratio=None,
+                 is_causal=None):
         self.use_activation_checkpointing = use_activation_checkpointing
         self.hidden_size = hidden_size
         self.num_layers = num_layers
         self.vocab_size = vocab_size
         self.mlp_ratio = mlp_ratio
+        self.is_causal = is_causal
+
+    @staticmethod
+    def _guess_is_causal_attn(model):
+        for module in model.modules():
+            if hasattr(module, 'is_causal'):
+                return module.is_causal
+        print_log(
+            'Could not determine whether causal attention is used; FLOPs '
+            'will be calculated assuming `is_causal = True`.', 'current')
+        return True
+
+    @staticmethod
+    def _get_batch_size_and_sequence_len(data_batch):
+        data_list, _ = tree_flatten(data_batch)
+        for data in data_list:
+            if isinstance(data, torch.Tensor):
+                return data.size(0), data.size(1)
+        raise RuntimeError('No tensor found in the batch')
+
+    @staticmethod
+    def _guess_use_activation_checkpointing(model):
+        for module in model.modules():
+            if hasattr(module, 'gradient_checkpointing'):
+                return module.gradient_checkpointing
+        return False
 
     def before_run(self, runner) -> None:
         if is_model_wrapper(runner.model):
@@ -41,20 +72,18 @@ def before_run(self, runner) -> None:
         self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
                                             model.config.hidden_size)
         self.mlp_ratio *= 1.5  # has gate_proj
-        return
+        self.is_causal = self.is_causal if self.is_causal is not None \
+            else self._guess_is_causal_attn(model)
 
-    def _get_batch_size_and_sequence_len(self, data_batch):
-        data_list, _ = tree_flatten(data_batch)
-        for data in data_list:
-            if isinstance(data, torch.Tensor):
-                return data.size(0), data.size(1)
-        raise RuntimeError('No tensor found in the batch')
+        use_varlen_attn = getattr(model, 'use_varlen_attn', False)
+        if use_varlen_attn:
+            print_log(
+                'Using variable-length Flash Attention causes the FLOPs'
+                ' estimate to be inflated.',
+                'current',
+                level=logging.WARNING)
 
-    def _guess_use_activation_checkpointing(self, model):
-        for module in model.modules():
-            if hasattr(module, 'gradient_checkpointing'):
-                return module.gradient_checkpointing
-        return False
+        return
 
     def after_train_iter(self,
                          runner,
@@ -66,17 +95,50 @@ def after_train_iter(self,
 
         batch_size, sequence_len = self._get_batch_size_and_sequence_len(
             data_batch)
+        sequence_parallel_size = get_sequence_parallel_world_size()
 
         message_hub = runner.message_hub
         iter_time = message_hub.get_scalar('train/time').current()
 
-        flops_per_iteration = (
-            (3 + int(self.use_activation_checkpointing)) *
-            ((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
-             self.hidden_size**2 +
-             4 * batch_size * sequence_len**2 * self.hidden_size)
-        ) * self.num_layers + \
-            6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
+        # We consider a language model with l transformer layers,
+        # hidden size h, sequence length s, vocabulary size V, and
+        # training batch size B.
+        # An A_{m x k} @ X_{k x n} matrix multiplication requires
+        # 2 * m * k * n FLOPs (the factor of 2 accounts for multiplies
+        # and adds).
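+        # In the code below: B = batch_size, s = sequence_len,
+        # h = self.hidden_size, l = self.num_layers, V = self.vocab_size,
+        # and sp_size = sequence_parallel_size (the sequence parallel
+        # world size).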
+
+        # Attention Layer:
+        # qkv_proj + o_proj: 8B * s * h^2
+        # attn: 2B * s^2 * h (causal=False) and 2B * s^2 * h / 2 (causal=True)
+
+        # MLP Layer:
+        # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
+        # (in Llama, mlp_ratio = intermediate_size / hidden_size * 1.5
+        # because of the extra gate_proj)
+
+        # The backward pass requires double the number of FLOPs since we
+        # need to calculate the gradients with respect to both the input
+        # and the weight tensors. If activation recomputation (gradient
+        # checkpointing) is enabled, one additional forward pass is needed
+        # before the backward pass.
+
+        # Sequence parallel also affects the FLOPs of the attention
+        # calculation. Suppose the sequence length on one GPU is s and the
+        # sequence parallel world size is `sp_size`. Then the total
+        # sequence length in the attention calculation is
+        # `s * sp_size` and the number of attention heads decreases to
+        # `num_heads / sp_size`. Hence, the attention FLOPs become:
+        # 2B * (s * sp_size)^2 * (h / sp_size) (causal=False) and
+        # 2B * (s * sp_size)^2 * (h / sp_size) / 2 (causal=True)
+
+        flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
+        flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
+            sequence_parallel_size / (int(self.is_causal) + 1)
+        flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
+            self.hidden_size**2
+        flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
+            flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
+        flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
+            self.vocab_size
+        flops_per_iteration = flops_wo_head + flops_head
 
         avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
         tokens_per_sec_per_gpu = batch_size * sequence_len / (

From 7de581c3702822d2f57a583577d95f955e311b1e Mon Sep 17 00:00:00 2001
From: Juicy <95841578+JianxinDong@users.noreply.github.com>
Date: Fri, 29 Mar 2024 16:25:46 +0800
Subject: [PATCH 9/9] Update README.md (#528)

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 0bfa3a02d..352e7215e 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@
 
 🔍 Explore our models on
 [![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤗%20Huggingface)](https://huggingface.co/xtuner)
 [![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤖%20ModelScope)](https://www.modelscope.cn/organization/xtuner)
+[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🧰%20OpenXLab)](https://openxlab.org.cn/usercenter/xtuner)
 
 English | [简体中文](README_zh-CN.md)
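To make the FLOPs accounting in the throughput hook patch above easier to check, the following is a minimal, self-contained sketch that mirrors the `after_train_iter` computation. It is not part of the patch: the function name `estimate_flops_per_iter` and the example shapes below are illustrative assumptions, not XTuner API.

def estimate_flops_per_iter(batch_size, seq_len, hidden_size, num_layers,
                            vocab_size, mlp_ratio, sp_size=1, is_causal=True,
                            activation_checkpointing=False):
    """Approximate training FLOPs of one iteration on one GPU."""
    # qkv_proj + o_proj: 8 * B * s * h^2
    flops_qkvo = 8 * batch_size * seq_len * hidden_size**2
    # attention scores + weighted sum, adjusted for sequence parallel;
    # causal masking halves the useful work
    flops_attn = (4 * batch_size * seq_len**2 * hidden_size * sp_size /
                  (int(is_causal) + 1))
    # gate_proj + up_proj + down_proj: 4 * B * s * h^2 * mlp_ratio
    flops_mlp = 4 * mlp_ratio * batch_size * seq_len * hidden_size**2
    # x3 for forward + backward, +1 forward if activations are recomputed
    flops_per_layer = (3 + int(activation_checkpointing)) * (
        flops_qkvo + flops_attn + flops_mlp)
    # output head, forward + backward: 3 * 2 * B * s * h * V
    flops_head = 3 * 2 * batch_size * seq_len * hidden_size * vocab_size
    return flops_per_layer * num_layers + flops_head


# Llama-2-7B-like shapes with sequence parallel size 4 (illustrative numbers)
flops = estimate_flops_per_iter(batch_size=1, seq_len=8192, hidden_size=4096,
                                num_layers=32, vocab_size=32000,
                                mlp_ratio=11008 / 4096 * 1.5, sp_size=4)
print(f'{flops / 1e12:.1f} TFLOPs per iteration')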