From 59834032c82d39994c13252aea9b00011d1b2457 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Fri, 22 Mar 2024 12:32:01 +0800
Subject: [PATCH 1/9] [Feature] Support Sequence parallel (#456)
* support sequence
* add configs
* add sp example to custom dataset
* WIP
* add dispatch utils
* delete useless codes
* move xtuner/engine/sequence_parallel to xtuner/parallel/sequence
* fix lint
* fix lint
* add init_dist to xtuner and add trust_remote_code=True to AutoConfig
* add internlm2 custom_dataset sp4 config
* Sequence Parallel doc V1
* Sequence Parallel doc V1
* Sequence Parallel doc V1
* fix bugs in llama_varlen_attn_forward
* rename indexes to position_ids
* add attn_implementation to config
* delete useless codes
* fix lint
* refine default_collate_fn
* refine doc
* refine doc
* refine doc
* delete replace_internlm2_rote
* add repeat_kv_bshd
* fix apply_rotary_pos_emb bug
* add enable_sequence_parallel flag
* refine doc
* assert {'input_ids', 'labels'}.issubset(dataset.column_names)
* refine doc
---
.pre-commit-config.yaml | 2 +
docs/zh_cn/user_guides/sequence_parallel.md | 188 +++++++++
...nlm2_7b_full_finetune_custom_dataset_e1.py | 1 +
...e_custom_dataset_e1_sequence_parallel_4.py | 220 ++++++++++
.../llama2_7b_full_pgbooks_400iters_sp1.py | 196 +++++++++
.../llama2_7b_full_pgbooks_400iters_sp4.py | 196 +++++++++
.../dataset/collate_fns/defalut_collate_fn.py | 32 +-
xtuner/dataset/huggingface.py | 3 +-
xtuner/dataset/intern_repo.py | 8 +-
xtuner/dataset/samplers/intern_repo.py | 7 +-
xtuner/dataset/utils.py | 14 +-
xtuner/engine/_strategy/deepspeed.py | 15 +
.../varlen_attn_args_to_messagehub_hook.py | 2 +-
xtuner/model/modules/dispatch/__init__.py | 10 +-
xtuner/model/modules/dispatch/internlm.py | 11 +-
xtuner/model/modules/dispatch/internlm2.py | 203 +++++++--
xtuner/model/modules/dispatch/llama.py | 394 +++++++++++-------
xtuner/model/modules/dispatch/mistral.py | 15 +-
xtuner/model/modules/dispatch/utils.py | 64 +++
xtuner/model/sft.py | 68 ++-
xtuner/parallel/sequence/__init__.py | 24 ++
xtuner/parallel/sequence/attention.py | 102 +++++
xtuner/parallel/sequence/data_collate.py | 75 ++++
xtuner/parallel/sequence/reduce_loss.py | 17 +
xtuner/parallel/sequence/sampler.py | 38 ++
xtuner/parallel/sequence/setup_distributed.py | 100 +++++
xtuner/tools/train.py | 32 +-
27 files changed, 1804 insertions(+), 233 deletions(-)
create mode 100644 docs/zh_cn/user_guides/sequence_parallel.md
create mode 100644 xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py
create mode 100644 xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py
create mode 100644 xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py
create mode 100644 xtuner/model/modules/dispatch/utils.py
create mode 100644 xtuner/parallel/sequence/__init__.py
create mode 100644 xtuner/parallel/sequence/attention.py
create mode 100644 xtuner/parallel/sequence/data_collate.py
create mode 100644 xtuner/parallel/sequence/reduce_loss.py
create mode 100644 xtuner/parallel/sequence/sampler.py
create mode 100644 xtuner/parallel/sequence/setup_distributed.py
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f2fdba321..acfe43b66 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -12,6 +12,7 @@ repos:
rev: v0.32.0
hooks:
- id: yapf
+ exclude: 'xtuner/parallel/sequence/__init__.py'
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
hooks:
@@ -38,6 +39,7 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
+ exclude: 'docs/zh_cn/user_guides/sequence_parallel.md'
- repo: https://github.com/myint/docformatter
rev: v1.3.1
hooks:
diff --git a/docs/zh_cn/user_guides/sequence_parallel.md b/docs/zh_cn/user_guides/sequence_parallel.md
new file mode 100644
index 000000000..ba29d2830
--- /dev/null
+++ b/docs/zh_cn/user_guides/sequence_parallel.md
@@ -0,0 +1,188 @@
+
+
+# Sequence Parallelism: System Optimization for Training Large Models on Extremely Long Sequences
+
+
+
+The sequence parallel design in XTuner draws on DeepSpeed's work [DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509) and further optimizes it, with the goal of training on ultra-long sequences of more than 1M tokens directly on top of the transformers library or open-source models from the Hugging Face Hub.
+
+## Introduction
+
+From generative AI to scientific models, long-sequence training is becoming increasingly important.
+
+In generative AI, conversational AI, long-document summarization, codebase understanding, and video generation tasks such as Sora all require reasoning over long contexts in both space and time.
+
+For AI for science, long sequences are equally crucial: they open the door to a better understanding of structural biology, healthcare, climate and weather prediction, and macromolecular simulation.
+
+However, despite the growing importance of sequence length, XTuner's existing memory optimization strategies (such as the ZeRO series) are not sufficient for training large models on long sequences.
+
+Meanwhile, limited by communication efficiency, many existing sequence parallel methods are not efficient enough.
+
+In addition, existing sequence parallel methods generally require intrusive code modifications, which greatly hurts usability and maintainability. They also fail to meet XTuner's requirement of training directly with the transformers library or open-source models from the Hugging Face Hub.
+
+
+
+
+
+
+To address the long-sequence training problems described above, XTuner adopts a simple, easy-to-use, and efficient sequence parallel algorithm. Because the Transformer architecture is quite regular, tokens do not affect each other in any computation other than attention (i.e., each token is computed independently), which makes sequence parallelism feasible. The core design is as follows: suppose P GPUs jointly process a long sequence of length N. In the first stage of the attention computation, each sub-sequence of length N / P is projected into Query, Key, and Value by linear layers. Next, the QKV tensors are exchanged among the GPUs participating in sequence parallelism through a highly optimized all-to-all communication operator, yielding sub-sequences with the full sequence length N but fewer attention heads. After the attention computation, another all-to-all operator converts the result back into sub-sequences of length N / P for the subsequent computation.
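+
+As a rough illustration of the all-to-all exchange described above, the sequence-to-head transposition can be sketched as follows. This is a minimal sketch under assumptions, not XTuner's actual implementation: it presumes `torch.distributed` is already initialized, `sp_group` is the sequence parallel process group, and the sequence chunks are assigned to the group's ranks in order.
+
+```python
+# Conceptual sketch of the Ulysses-style all-to-all: gather the full sequence,
+# scatter the attention heads. Illustrative only; `sp_group` is assumed to be
+# the sequence parallel process group created during initialization.
+import torch
+import torch.distributed as dist
+
+
+def all_to_all_seq_to_head(x: torch.Tensor, sp_group) -> torch.Tensor:
+    """(b, s/P, h, d) -> (b, s, h/P, d), assuming h is divisible by P."""
+    p = dist.get_world_size(sp_group)
+    b, s_per_p, h, d = x.shape
+    # Split the head dim into P chunks and move the chunk dim to the front so
+    # that all_to_all_single exchanges one head-chunk per rank.
+    x = x.reshape(b, s_per_p, p, h // p, d).permute(2, 0, 1, 3, 4).contiguous()
+    out = torch.empty_like(x)
+    dist.all_to_all_single(out, x, group=sp_group)
+    # Chunk i of the output now holds sequence slice i with the local heads.
+    return out.permute(1, 0, 2, 3, 4).reshape(b, s_per_p * p, h // p, d)
+```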
+
+Overall, XTuner's sequence parallel algorithm has the following key features:
+
+* Supports full-parameter training on sequences of **more than one million tokens**
+* Supports training models at the hundred-billion-parameter scale: XTuner's sequence parallelism not only enables long-sequence training but can also be combined with the ZeRO-3 memory optimization strategy to train very large models
+* A fully general sequence parallel **API abstraction**
+
+## Sequence parallel training with XTuner
+
+### Step 1 Modify the config file
+
+1. Set the `sequence_parallel_size` field in the config to adjust $sequence\\_parallel\\_world\\_size$.
+2. To keep the training behavior close to that without sequence parallelism, the gradient accumulation steps must also be multiplied by $sequence\\_parallel\\_world\\_size$, because $data\\_parallel\\_world\\_size$ drops to $\frac{1}{sequence\\_parallel\\_world\\_size}$ of its original value when sequence parallel training is used.
+3. Replace DefaultSampler with SequenceParallelSampler, which supports sequence parallelism.
+
+**Note: the total number of GPUs must be divisible by `sequence_parallel_size`.**
+
+```diff
++ from xtuner.parallel.sequence import SequenceParallelSampler
+
+- sequence_parallel_size = 1
++ sequence_parallel_size = 4 # take `sequence_parallel_size = 4` as an example
+
+- accumulative_counts = 1
++ accumulative_counts = 4 # accumulative_counts = accumulative_counts * sequence_parallel_size
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+train_dataloader = dict(
+- sampler=dict(type=DefaultSampler, shuffle=True),
++ sampler=dict(type=SequenceParallelSampler, seed=1024, shuffle=True),
+ ...)
+```
+
+In addition, to further extend the model's long-context capability, the `max_position_embeddings` field in the config needs to be modified as well. For example, to extend the model's context length to 64K, make the following change:
+
+```diff
++ max_position_embeddings = 65536
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+model = dict(
+ type=SupervisedFinetune,
++ max_position_embeddings = max_position_embeddings,
+ ...)
+```
+
+### Step 2 Start training
+
+Training needs to be launched with DeepSpeed:
+
+```bash
+(DIST) NPROC_PER_NODE=${GPU_NUM} xtuner train ${CONFIG_PATH} --deepspeed deepspeed_zero2
+(SLURM) srun ${SRUN_ARGS} xtuner train ${CONFIG_PATH} --launcher slurm --deepspeed deepspeed_zero2
+```
+
+- ${CONFIG_PATH} is the path to the config file modified in Step 1
+- Different ZeRO strategies can be chosen according to your actual needs
+
+## Sequence Parallel API Abstraction
+
+To make the algorithm easier to port, XTuner abstracts the five APIs required for sequence parallelism:
+- Sequence parallel distributed environment initialization (init_sequence_parallel)
+- Data sampler adapted for sequence parallelism (SequenceParallelSampler)
+- Data padding and splitting (pad_for_sequence_parallel, split_for_sequence_parallel)
+- Attention adapted for sequence parallelism (dispatch_modules)
+- Loss reduction for correctly logging the training loss (reduce_sequence_parallel_loss)
+
+### Initializing the sequence parallel distributed environment
+
+The sequence parallel algorithm splits a long sequence into $sequence\\_parallel\\_world\\_size$ chunks and dispatches each sub-sequence to its corresponding GPU for independent computation. Therefore, the sequence parallel distributed environment has to be initialized before training starts, to specify which GPUs jointly handle one long input sequence.
+
+An example with $sequence\\_parallel\\_world\\_size = 4$:
+
+```python
+# We have to initialize the distributed training environment first.
+# Here is an example when training on slurm scheduler
+# from xtuner.parallel.sequence import init_dist
+# init_dist('slurm', 'nccl', init_backend='deepspeed')
+from xtuner.parallel.sequence import init_sequence_parallel
+sequence_parallel_world_size = 4
+init_sequence_parallel(sequence_parallel_world_size)
+```
+
+The above process is implemented in xtuner/engine/_strategy/deepspeed.py.
+
+### Adapting the data sampler for sequence parallelism
+
+With sequence parallelism enabled, the sampling strategy of the dataloader needs to be adjusted as well. For example, when $sequence\\_parallel\\_world\\_size = 4$, the 4 GPUs in one sequence parallel group must receive exactly the same data from the dataloader.
+
+Simply use the SequenceParallelSampler provided by XTuner when building the dataloader:
+
+```python
+from xtuner.parallel.sequence import SequenceParallelSampler
+dataloader = DataLoader(
+ train_dataset, sampler=SequenceParallelSampler(train_dataset),
+ **other_dataloader_params)
+```
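+
+The key idea is that sample indices are sharded by the data-parallel rank rather than the global rank, so every rank inside one sequence parallel group draws identical samples. A minimal sketch of this idea (an illustration under the assumption that sequence parallel groups are formed from consecutive ranks; it is not the actual SequenceParallelSampler source):
+
+```python
+# Illustrative only: shard dataset indices by data-parallel rank so that all
+# ranks within one sequence parallel group receive the same sample indices.
+import torch.distributed as dist
+
+
+def data_parallel_indices(num_samples: int, sp_size: int):
+    rank, world_size = dist.get_rank(), dist.get_world_size()
+    dp_rank = rank // sp_size            # ranks of one SP group share dp_rank
+    dp_world_size = world_size // sp_size
+    # DistributedSampler-style striding, but over data-parallel ranks only.
+    return list(range(num_samples))[dp_rank::dp_world_size]
+```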
+
+### Data padding and splitting
+
+Since training samples may have different lengths, the data needs to be padded so that the sequence length is divisible by $sequence\\_parallel\\_world\\_size$; only then can a long sample be evenly distributed across the GPUs.
+
+The tensors that usually need padding during training are input_ids, labels, position_ids and attention_mask. Padding can be done as follows:
+
+```python
+from xtuner.parallel.sequence import pad_for_sequence_parallel
+input_ids, labels, position_ids, attention_mask = pad_for_sequence_parallel(
+ input_ids, labels, position_ids, attention_mask)
+```
+
+If attention_mask is not needed during training, you can instead write:
+
+```python
+input_ids, labels, position_ids, _ = pad_for_sequence_parallel(
+ input_ids, labels, position_ids)
+```
+
+After padding, the long sequence needs to be split evenly:
+
+```python
+from xtuner.parallel.sequence import split_for_sequence_parallel
+# attention mask should not be split
+input_ids, labels, position_ids = split_for_sequence_parallel(
+ input_ids, labels, position_ids)
+```
+
+The above two steps are implemented in xtuner/dataset/collate_fns/defalut_collate_fn.py.
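+
+For intuition, the split step amounts to each rank keeping only its own contiguous chunk along the sequence dimension. A rough sketch of that idea (an illustration, not the XTuner source; `sp_group` is assumed to be the sequence parallel process group):
+
+```python
+# Conceptual sketch: after padding, each rank narrows the tensor to its own
+# contiguous slice of the (now divisible) sequence dimension.
+import torch
+import torch.distributed as dist
+
+
+def split_along_seq_dim(tensor: torch.Tensor, sp_group, dim: int = 1):
+    sp_size = dist.get_world_size(sp_group)
+    sp_rank = dist.get_rank(sp_group)
+    assert tensor.size(dim) % sp_size == 0, 'pad first so the length divides'
+    chunk = tensor.size(dim) // sp_size
+    return tensor.narrow(dim, sp_rank * chunk, chunk)
+```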
+
+### Adapting attention for sequence parallelism
+
+During the attention computation, tokens in a sequence cannot be processed independently, but different attention heads can. Therefore, as described in [the introduction](#introduction), an *all-to-all* operation needs to be inserted both before and after the attention computation (i.e., after qkv_proj and before o_proj).
+
+XTuner provides the dispatch_modules interface to modify how the model's attention is computed:
+
+```python
+from xtuner.model.modules import dispatch_modules
+model: AutoModelForCausalLM
+dispatch_modules(model)
+```
+
+The above process is implemented in xtuner/model/sft.py.
+
+### Reducing the loss to log the training loss correctly
+
+This API is not required for training correctness, but it is very useful for logging the training loss and monitoring the training state.
+
+```python
+from xtuner.parallel.sequence import reduce_sequence_parallel_loss
+outputs = llm(input_ids=input_ids, labels=labels, **kwargs)
+num_tokens_per_rank = (labels != -100).sum()
+# Suppose sequence parallel world size equals to 4,
+# losses on rank0, rank1, rank2, rank3 are different.
+loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens_per_rank)
+# After loss reduction, losses on rank0, rank1, rank2, rank3 are the same.
+```
+
+The above process is implemented in xtuner/model/sft.py.
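+
+Conceptually, the reduction is a token-count-weighted average of the per-rank losses over the sequence parallel group. A minimal sketch of that computation (an illustration only, not the code in xtuner/parallel/sequence/reduce_loss.py; `sp_group` is assumed to be the sequence parallel process group):
+
+```python
+# Illustrative only: combine the per-rank losses into the loss the full,
+# un-split sequence would have produced, weighting by per-rank token counts.
+import torch
+import torch.distributed as dist
+
+
+def reduce_loss_over_sp_group(loss: torch.Tensor, num_tokens: torch.Tensor,
+                              sp_group) -> torch.Tensor:
+    # Detached because, as noted above, this value is only used for logging.
+    buf = torch.stack([(loss * num_tokens).float(),
+                       num_tokens.float()]).detach()
+    dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=sp_group)
+    weighted_loss_sum, token_sum = buf[0], buf[1]
+    return weighted_loss_sum / token_sum
+```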
diff --git a/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py
index 40c9b1692..58863f06b 100644
--- a/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py
+++ b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1.py
@@ -42,6 +42,7 @@
# Model
pretrained_model_name_or_path = 'internlm/internlm2-7b'
use_varlen_attn = True
+sequence_parallel_size = 1
# Data
data_files = ['/path/to/json/file.json']
diff --git a/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py
new file mode 100644
index 000000000..aa7e4b014
--- /dev/null
+++ b/xtuner/configs/internlm/internlm2_7b/internlm2_7b_full_finetune_custom_dataset_e1_sequence_parallel_4.py
@@ -0,0 +1,220 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Data format:
+[
+ {
+ "conversation": [
+ {
+ "system": "",
+ "input": "xxx",
+ "output": "xxx"
+ },
+ {
+ "input": "xxx",
+ "output": "xxx"
+ }
+ ]
+ },
+...
+]
+Please refer to https://github.com/InternLM/xtuner/blob/main/docs/en/user_guides/dataset_format.md for details.
+""" # noqa: E501
+from datasets import load_dataset
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR
+from torch.optim import AdamW
+from torch.utils.data import BatchSampler
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.dataset.map_fns import template_map_fn_factory
+from xtuner.dataset.samplers import InternRepoSampler
+from xtuner.engine import (DatasetInfoHook, EvaluateChatHook, ThroughputHook,
+ VarlenAttnArgsToMessageHubHook)
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import SupervisedFinetune
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'internlm/internlm2-7b'
+use_varlen_attn = True
+sequence_parallel_size = 4
+
+# Data
+data_files = ['/path/to/json/file.json']
+prompt_template = PROMPT_TEMPLATE.internlm2_chat
+max_length = 32768
+pack_to_max_length = True
+
+# Scheduler & Optimizer
+batch_size = 1 # per_device
+# accumulative_counts = accumulative_counts * sequence_parallel_size
+accumulative_counts = 4
+dataloader_num_workers = 4
+max_epochs = 1
+optim_type = AdamW
+lr = 4e-5
+betas = (0.9, 0.95)
+weight_decay = 0.01
+max_norm = 1 # grad clip
+warm_up_ratio = 0.025
+
+# Save
+save_steps = 500
+save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
+
+# Evaluate the generation performance during the training
+evaluation_freq = 500
+SYSTEM = ''
+evaluation_inputs = [
+ '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
+]
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=SupervisedFinetune,
+ use_varlen_attn=use_varlen_attn,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+train_dataset = dict(
+ type=process_hf_dataset,
+ use_varlen_attn=use_varlen_attn,
+ dataset=dict(type=load_dataset, path='json', data_files=data_files),
+ tokenizer=tokenizer,
+ max_length=max_length,
+ dataset_map_fn=None,
+ template_map_fn=dict(
+ type=template_map_fn_factory, template=prompt_template),
+ remove_unused_columns=True,
+ shuffle_before_pack=True,
+ pack_to_max_length=pack_to_max_length)
+
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=train_dataset,
+ sampler=dict(type=InternRepoSampler, shuffle=True, seed=1024),
+ batch_sampler=dict(type=BatchSampler, drop_last=True, batch_size=1),
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic',
+)
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1 / 40,
+ by_epoch=True,
+ begin=0,
+ end=warm_up_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=lr * 0.15,
+ by_epoch=True,
+ begin=warm_up_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(
+ type=DatasetInfoHook, tokenizer=tokenizer,
+ is_intern_repo_dataset=True),
+ dict(
+ type=EvaluateChatHook,
+ tokenizer=tokenizer,
+ every_n_iters=evaluation_freq,
+ evaluation_inputs=evaluation_inputs,
+ system=SYSTEM,
+ prompt_template=prompt_template),
+ dict(type=ThroughputHook)
+]
+
+if use_varlen_attn:
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every iteration.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+ # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+log_processor = dict(
+ by_epoch=False,
+ window_size=1,
+ mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')
diff --git a/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py
new file mode 100644
index 000000000..9c6e4fe05
--- /dev/null
+++ b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp1.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from datasets import load_dataset
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.engine.hooks import (DatasetInfoHook, ThroughputHook,
+ VarlenAttnArgsToMessageHubHook)
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import SupervisedFinetune
+from xtuner.parallel.sequence import SequenceParallelSampler
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'meta-llama/Llama-2-7b-hf'
+use_varlen_attn = False
+sequence_parallel_size = 1
+
+# Data
+data_path = 'emozilla/pg_books-tokenized-bos-eos-chunked-65536'
+data_files = [
+ 'data/train-00000-of-00136-877a1768c20d5900.parquet',
+ 'data/train-00001-of-00136-70d7d139dca61754.parquet',
+ 'data/train-00002-of-00136-62d53594e098f3d8.parquet',
+ 'data/train-00003-of-00136-8bd300fecc4c720e.parquet',
+ 'data/train-00004-of-00136-2a9456b5f975ae95.parquet',
+ 'data/train-00005-of-00136-ca38cf7907bb7555.parquet',
+ 'data/train-00006-of-00136-1ae2e4c63f3966da.parquet',
+ 'data/train-00007-of-00136-a00cc39a4ee65ab6.parquet',
+]
+prompt_template = PROMPT_TEMPLATE.llama2_chat
+max_length = 65536
+max_position_embeddings = 65536
+pack_to_max_length = False
+
+# Scheduler & Optimizer
+batch_size = 1 # per_device
+accumulative_counts = 8
+dataloader_num_workers = 0
+max_epochs = 1
+optim_type = AdamW
+lr = 2e-5
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1 # grad clip
+warmup_ratio = 0.05
+
+# Save
+save_steps = 500
+save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited)
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=SupervisedFinetune,
+ use_varlen_attn=use_varlen_attn,
+ max_position_embeddings=max_position_embeddings,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ attn_implementation='flash_attention_2'))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+train_dataset = dict(
+ type=process_hf_dataset,
+ dataset=dict(
+ type=load_dataset,
+ path=data_path,
+ data_files=data_files,
+ ignore_verifications=True),
+ do_dataset_tokenization=False,
+ remove_unused_columns=True,
+ pack_to_max_length=pack_to_max_length,
+ use_varlen_attn=use_varlen_attn)
+
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=train_dataset,
+ sampler=dict(type=SequenceParallelSampler, seed=1024),
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1 / 40,
+ by_epoch=True,
+ begin=0,
+ end=warmup_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=lr * 0.15,
+ by_epoch=True,
+ begin=warmup_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
+ dict(type=ThroughputHook)
+]
+
+if use_varlen_attn:
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every iteration.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ save_optimizer=False,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+ # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(
+ by_epoch=False,
+ window_size=1,
+ mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')
diff --git a/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py
new file mode 100644
index 000000000..5e87acffe
--- /dev/null
+++ b/xtuner/configs/llama/llama2_7b/llama2_7b_full_pgbooks_400iters_sp4.py
@@ -0,0 +1,196 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from datasets import load_dataset
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
+ LoggerHook, ParamSchedulerHook)
+from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
+from torch.optim import AdamW
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from xtuner.dataset import process_hf_dataset
+from xtuner.dataset.collate_fns import default_collate_fn
+from xtuner.engine.hooks import (DatasetInfoHook, ThroughputHook,
+ VarlenAttnArgsToMessageHubHook)
+from xtuner.engine.runner import TrainLoop
+from xtuner.model import SupervisedFinetune
+from xtuner.parallel.sequence import SequenceParallelSampler
+from xtuner.utils import PROMPT_TEMPLATE
+
+#######################################################################
+# PART 1 Settings #
+#######################################################################
+# Model
+pretrained_model_name_or_path = 'meta-llama/Llama-2-7b-hf'
+use_varlen_attn = False
+sequence_parallel_size = 4
+
+# Data
+data_path = 'emozilla/pg_books-tokenized-bos-eos-chunked-65536'
+data_files = [
+ 'data/train-00000-of-00136-877a1768c20d5900.parquet',
+ 'data/train-00001-of-00136-70d7d139dca61754.parquet',
+ 'data/train-00002-of-00136-62d53594e098f3d8.parquet',
+ 'data/train-00003-of-00136-8bd300fecc4c720e.parquet',
+ 'data/train-00004-of-00136-2a9456b5f975ae95.parquet',
+ 'data/train-00005-of-00136-ca38cf7907bb7555.parquet',
+ 'data/train-00006-of-00136-1ae2e4c63f3966da.parquet',
+ 'data/train-00007-of-00136-a00cc39a4ee65ab6.parquet',
+]
+prompt_template = PROMPT_TEMPLATE.llama2_chat
+max_length = 65536
+max_position_embeddings = 65536
+pack_to_max_length = False
+
+# Scheduler & Optimizer
+batch_size = 1 # per_device
+accumulative_counts = 32
+dataloader_num_workers = 0
+max_epochs = 1
+optim_type = AdamW
+lr = 2e-5
+betas = (0.9, 0.999)
+weight_decay = 0
+max_norm = 1 # grad clip
+warmup_ratio = 0.05
+
+# Save
+save_steps = 500
+save_total_limit = 1 # Maximum checkpoints to keep (-1 means unlimited)
+
+#######################################################################
+# PART 2 Model & Tokenizer #
+#######################################################################
+tokenizer = dict(
+ type=AutoTokenizer.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ padding_side='right')
+
+model = dict(
+ type=SupervisedFinetune,
+ use_varlen_attn=use_varlen_attn,
+ max_position_embeddings=max_position_embeddings,
+ llm=dict(
+ type=AutoModelForCausalLM.from_pretrained,
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
+ trust_remote_code=True,
+ torch_dtype=torch.bfloat16,
+ attn_implementation='flash_attention_2'))
+
+#######################################################################
+# PART 3 Dataset & Dataloader #
+#######################################################################
+train_dataset = dict(
+ type=process_hf_dataset,
+ dataset=dict(
+ type=load_dataset,
+ path=data_path,
+ data_files=data_files,
+ ignore_verifications=True),
+ do_dataset_tokenization=False,
+ remove_unused_columns=True,
+ pack_to_max_length=pack_to_max_length,
+ use_varlen_attn=use_varlen_attn)
+
+train_dataloader = dict(
+ batch_size=batch_size,
+ num_workers=dataloader_num_workers,
+ dataset=train_dataset,
+ sampler=dict(type=SequenceParallelSampler, seed=1024),
+ collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
+
+#######################################################################
+# PART 4 Scheduler & Optimizer #
+#######################################################################
+# optimizer
+optim_wrapper = dict(
+ type=AmpOptimWrapper,
+ optimizer=dict(
+ type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
+ clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
+ accumulative_counts=accumulative_counts,
+ loss_scale='dynamic')
+
+# learning policy
+# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
+param_scheduler = [
+ dict(
+ type=LinearLR,
+ start_factor=1 / 40,
+ by_epoch=True,
+ begin=0,
+ end=warmup_ratio * max_epochs,
+ convert_to_iter_based=True),
+ dict(
+ type=CosineAnnealingLR,
+ eta_min=lr * 0.15,
+ by_epoch=True,
+ begin=warmup_ratio * max_epochs,
+ end=max_epochs,
+ convert_to_iter_based=True)
+]
+
+# train, val, test setting
+train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
+
+#######################################################################
+# PART 5 Runtime #
+#######################################################################
+# Log the dialogue periodically during the training process, optional
+custom_hooks = [
+ dict(type=DatasetInfoHook, tokenizer=tokenizer),
+ dict(type=ThroughputHook)
+]
+
+if use_varlen_attn:
+ custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
+
+# configure default hooks
+default_hooks = dict(
+ # record the time of every iteration.
+ timer=dict(type=IterTimerHook),
+ # print log every iteration.
+ logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=1),
+ # enable the parameter scheduler.
+ param_scheduler=dict(type=ParamSchedulerHook),
+ # save checkpoint per `save_steps`.
+ checkpoint=dict(
+ type=CheckpointHook,
+ save_optimizer=False,
+ by_epoch=False,
+ interval=save_steps,
+ max_keep_ckpts=save_total_limit),
+ # set sampler seed in distributed environment.
+ sampler_seed=dict(type=DistSamplerSeedHook),
+)
+
+# configure environment
+env_cfg = dict(
+ # whether to enable cudnn benchmark
+ cudnn_benchmark=False,
+ # set multi process parameters
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ # set distributed parameters
+ dist_cfg=dict(backend='nccl'),
+)
+
+# set visualizer
+visualizer = None
+
+# set log level
+log_level = 'INFO'
+
+# load from which checkpoint
+load_from = None
+
+# whether to resume training from the loaded checkpoint
+resume = False
+
+# Defaults to use random seed and disable `deterministic`
+randomness = dict(seed=None, deterministic=False)
+
+# set log processor
+log_processor = dict(
+ by_epoch=False,
+ window_size=1,
+ mean_pattern=r'.*(loss|time|data_time|grad_norm|tflops).*')
diff --git a/xtuner/dataset/collate_fns/defalut_collate_fn.py b/xtuner/dataset/collate_fns/defalut_collate_fn.py
index 294cd4870..f644df9cf 100644
--- a/xtuner/dataset/collate_fns/defalut_collate_fn.py
+++ b/xtuner/dataset/collate_fns/defalut_collate_fn.py
@@ -4,6 +4,9 @@
import torch
from torch.nn.utils.rnn import pad_sequence
+from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
+ pad_for_sequence_parallel,
+ split_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
@@ -11,13 +14,14 @@ def default_collate_fn(instances: Sequence[Dict],
pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
return_hf_format: bool = False,
use_varlen_attn: bool = False):
+ seq_parallel_world_size = get_sequence_parallel_world_size()
input_ids, labels = [], []
has_image = any(inst.get('pixel_values') is not None for inst in instances)
if use_varlen_attn:
- cumulative_len, indexes = [], []
+ position_ids, cumulative_len = [], []
assert len(instances) == 1, (
- f'If utilizing local attention, the batch size should be'
+ f'If utilizing varlen attention, the batch size should be'
f' set to 1, but got {len(instances)}')
assert not has_image, 'Currently, it is not configured to '
'accommodate the use of varlen Attention in multimodal training'
@@ -30,7 +34,7 @@ def default_collate_fn(instances: Sequence[Dict],
labels.append(torch.LongTensor(example['labels']))
if use_varlen_attn:
cumulative_len.append(torch.IntTensor(example['cumulative_len']))
- indexes.append(torch.LongTensor(example['indexes']))
+ position_ids.append(torch.LongTensor(example['position_ids']))
if has_image:
pixel_values.append(example['pixel_values'])
@@ -45,21 +49,37 @@ def default_collate_fn(instances: Sequence[Dict],
labels = torch.stack(labels)
if use_varlen_attn:
- indexes = torch.stack(indexes, dim=0)
+ assert input_ids.size(1) % seq_parallel_world_size == 0
+ attention_mask = None
+ position_ids = torch.stack(position_ids, dim=0)
+ else:
+ attention_mask = input_ids.ne(pad_index)
+ position_ids = attention_mask.long().cumsum(-1) - 1
+
+ input_ids, labels, position_ids, attention_mask = \
+ pad_for_sequence_parallel(input_ids, labels, position_ids,
+ attention_mask)
+
+ # attention mask should not be split
+ input_ids, labels, position_ids = split_for_sequence_parallel(
+ input_ids, labels, position_ids)
+
+ if use_varlen_attn:
max_seqlen = (
cumulative_len[0][1:] - # noqa: W504
cumulative_len[0][:-1]).max().item()
data_dict = {
'input_ids': input_ids,
'cumulative_len': cumulative_len,
- 'indexes': indexes,
+ 'position_ids': position_ids,
'labels': labels,
'max_seqlen': max_seqlen
}
else:
data_dict = {
'input_ids': input_ids,
- 'attention_mask': input_ids.ne(pad_index),
+ 'attention_mask': attention_mask,
+ 'position_ids': position_ids,
'labels': labels
}
diff --git a/xtuner/dataset/huggingface.py b/xtuner/dataset/huggingface.py
index 349cce2f6..30f6bc394 100644
--- a/xtuner/dataset/huggingface.py
+++ b/xtuner/dataset/huggingface.py
@@ -199,10 +199,9 @@ def process(dataset,
dataset = tokenize_dataset(dataset, tokenizer, max_length,
with_image_token, input_ids_with_output,
remove_unused_columns, map_num_proc)
- else:
- assert {'input_ids', 'labels'}.issubset(dataset.column_names)
if input_ids_with_output:
+ assert {'input_ids', 'labels'}.issubset(dataset.column_names)
# remove data that does not have the valid labels.
dataset = dataset.filter(
lambda example: any(label >= 0 for label in example['labels']),
diff --git a/xtuner/dataset/intern_repo.py b/xtuner/dataset/intern_repo.py
index b1034cc31..95cd7cf99 100644
--- a/xtuner/dataset/intern_repo.py
+++ b/xtuner/dataset/intern_repo.py
@@ -191,7 +191,7 @@ def mapping(self, pack_idx: int = 0):
def build_pack(self, begin_sample_idx: int, begin_token_id: int,
end_sample_idx: int, end_token_id: int):
- pack, cumulative_len, indexes, labels = [], [0], [], []
+ pack, cumulative_len, position_ids, labels = [], [0], [], []
while begin_sample_idx < end_sample_idx:
sample_idx = self.shuffled_indices[begin_sample_idx]
@@ -202,7 +202,7 @@ def build_pack(self, begin_sample_idx: int, begin_token_id: int,
assert len(_labels) == len(chunk), (_labels, chunk)
labels.extend(_labels)
cumulative_len.append(cumulative_len[-1] + len(chunk))
- indexes.extend(list(range(len(chunk))))
+ position_ids.extend(list(range(len(chunk))))
begin_sample_idx = begin_sample_idx + 1
begin_token_id = 0
@@ -215,12 +215,12 @@ def build_pack(self, begin_sample_idx: int, begin_token_id: int,
assert len(_labels) == len(chunk), (_labels, chunk)
labels.extend(_labels)
cumulative_len.append(cumulative_len[-1] + len(chunk))
- indexes.extend(list(range(len(chunk))))
+ position_ids.extend(list(range(len(chunk))))
out = {
'input_ids': pack,
'cumulative_len': cumulative_len,
- 'indexes': indexes,
+ 'position_ids': position_ids,
'labels': labels
}
return out
diff --git a/xtuner/dataset/samplers/intern_repo.py b/xtuner/dataset/samplers/intern_repo.py
index 3ca470c21..933719a58 100644
--- a/xtuner/dataset/samplers/intern_repo.py
+++ b/xtuner/dataset/samplers/intern_repo.py
@@ -4,9 +4,11 @@
import numpy as np
from mmengine import print_log
-from mmengine.dist import get_dist_info
from torch.utils.data import Sampler
+from xtuner.parallel.sequence import (get_data_parallel_rank,
+ get_data_parallel_world_size)
+
class InternRepoSampler(Sampler):
@@ -17,7 +19,8 @@ def __init__(self,
if seed is not None and seed != 1024:
warnings.warn('For alignment accuracy, seed in InternRepoSampler'
'must be set to 1024.')
- rank, world_size = get_dist_info()
+ world_size = get_data_parallel_world_size()
+ rank = get_data_parallel_rank()
self.rank = rank
self.world_size = world_size
diff --git a/xtuner/dataset/utils.py b/xtuner/dataset/utils.py
index 470a95303..84336ddb2 100644
--- a/xtuner/dataset/utils.py
+++ b/xtuner/dataset/utils.py
@@ -176,8 +176,8 @@ def get_cumulative_len(self, chunk_num):
return cumulative_len
- def get_indexes(self, cumulative_len):
- indexes = []
+ def get_position_ids(self, cumulative_len):
+ position_ids = []
for cumulative_len_cur in cumulative_len:
index_cur = []
for i in range(len(cumulative_len_cur) - 1):
@@ -185,8 +185,8 @@ def get_indexes(self, cumulative_len):
list(
range(cumulative_len_cur[i + 1] - # noqa: W504
cumulative_len_cur[i])))
- indexes.append(index_cur)
- return indexes
+ position_ids.append(index_cur)
+ return position_ids
def __call__(self, batch):
concatenated_samples = {
@@ -222,7 +222,7 @@ def __call__(self, batch):
if self.use_varlen_attn:
cumulative_len = self.get_cumulative_len(chunk_num)
result['cumulative_len'] = cumulative_len
- result['indexes'] = self.get_indexes(cumulative_len)
+ result['position_ids'] = self.get_position_ids(cumulative_len)
else:
if self.drop_last:
result = {k: [] for k, v in concatenated_samples.items()}
@@ -235,8 +235,8 @@ def __call__(self, batch):
result['cumulative_len'] = [] if self.drop_last else [
self.residual_cumulative_len
]
- result['indexes'] = [] if self.drop_last else self.get_indexes(
- [self.residual_cumulative_len])
+ result['position_ids'] = [] if self.drop_last \
+ else self.get_position_ids([self.residual_cumulative_len])
self.residual_cumulative_len = [0]
return result
diff --git a/xtuner/engine/_strategy/deepspeed.py b/xtuner/engine/_strategy/deepspeed.py
index afa0cc57c..42b7f5590 100644
--- a/xtuner/engine/_strategy/deepspeed.py
+++ b/xtuner/engine/_strategy/deepspeed.py
@@ -1,13 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
from mmengine._strategy import DeepSpeedStrategy as MMEngineDeepSpeedStrategy
from xtuner import DS_CEPH_DIR
+from xtuner.parallel.sequence import init_sequence_parallel
from xtuner.utils.fileio import patch_fileio
class DeepSpeedStrategy(MMEngineDeepSpeedStrategy):
def __init__(self, *args, **kwargs):
+ sequence_parallel_size = kwargs.pop('sequence_parallel_size', 1)
+ self.sequence_parallel_size = sequence_parallel_size
+
super().__init__(*args, **kwargs)
from transformers.integrations.deepspeed import HfDeepSpeedConfig
@@ -53,3 +59,12 @@ def resume(self, *args, **kwargs) -> None:
else:
checkpoint = super().resume(*args, **kwargs)
return checkpoint
+
+ def _setup_distributed( # type: ignore
+ self,
+ launcher: Optional[str] = None,
+ backend: str = 'nccl',
+ **kwargs,
+ ):
+ super()._setup_distributed(launcher, backend, **kwargs)
+ init_sequence_parallel(self.sequence_parallel_size)
diff --git a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
index 9aa7a91bd..f2b23d3fe 100644
--- a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
+++ b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
@@ -11,7 +11,7 @@
class VarlenAttnArgsToMessageHubHook(Hook):
- args = ('cumulative_len', 'indexes', 'max_seqlen')
+ args = ('cumulative_len', 'max_seqlen')
def cast_data(self, data):
if isinstance(data, Mapping):
diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index 79ff88b08..6fbe37fb6 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -123,7 +123,8 @@ def dispatch_internlm2_attn_forward(model, use_varlen_attn):
print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
for module in model.modules():
- if type(module).__name__ == 'InternLM2Attention':
+ if type(module).__name__ in ('InternLM2Attention',
+ 'InternLM2FlashAttention2'):
if use_varlen_attn:
print_log('dispatch internlm2 varlen attn forward', 'current')
module.forward = types.MethodType(
@@ -188,11 +189,12 @@ def traverse(module):
for name, child in module.named_children():
if type(child).__name__ in (
'InternLM2RotaryEmbedding',
+ 'InternLM2LinearScalingRotaryEmbedding',
'InternLM2DynamicNTKScalingRotaryEmbedding'):
print_log('replace internlm2 rope', 'current')
dim_model = child.inv_freq.shape[0] * 2
child_new = InternLM2RotaryEmbedding(
- dim_model, child.max_seq_len_cached, rotary_base).to(
+ dim_model, child.max_position_embeddings, rotary_base).to(
device=child.inv_freq.device,
dtype=child.inv_freq.dtype)
setattr(module, name, child_new)
@@ -301,12 +303,12 @@ def dispatch_modules(model, use_varlen_attn=False):
dispatch_internlm2_attn_forward(model, use_varlen_attn)
if USE_TRITON_KERNEL:
dispatch_internlm2_rmsnorm_forward(model)
- replace_internlm2_rote(model)
+ # replace_internlm2_rote(model)
elif 'internlm' in model_name:
dispatch_internlm_attn_forward(model, use_varlen_attn)
if USE_TRITON_KERNEL:
dispatch_internlm_rmsnorm_forward(model)
- replace_internlm_rote(model)
+ # replace_internlm_rote(model)
elif 'llama' in model_name:
dispatch_llama_attn_forward(model, use_varlen_attn)
if USE_TRITON_KERNEL:
diff --git a/xtuner/model/modules/dispatch/internlm.py b/xtuner/model/modules/dispatch/internlm.py
index 8bee27b67..fd06def33 100644
--- a/xtuner/model/modules/dispatch/internlm.py
+++ b/xtuner/model/modules/dispatch/internlm.py
@@ -154,7 +154,7 @@ def internlm_varlen_attn_forward(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- indexes = message_hub.get_info(f'indexes_rank_{rank}')
+ # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
@@ -175,10 +175,11 @@ def internlm_varlen_attn_forward(
if is_training:
cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
+ query_states = apply_rotary_emb(query_states,
+ cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
+ key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index 66aa4f391..a166e8bae 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -8,13 +8,15 @@
from einops import rearrange
from mmengine import MessageHub
+from xtuner.parallel.sequence import sequence_parallel_wrapper
from .triton_kernels import apply_rotary_emb
+from .utils import upad_qkv
SUPPORT_FLASH2 = False
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
-
+ from flash_attn.bert_padding import pad_input
SUPPORT_FLASH2 = True
except ImportError:
pass
@@ -28,6 +30,9 @@ def __init__(self,
base=1000000,
device=None):
super().__init__()
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
self.inv_freq = 1.0 / (
base**(torch.arange(0, dim, 2).float().to(device) / dim))
@@ -96,22 +101,115 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
head_dim)
+def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
+ to (batch, seqlen, num_attention_heads, head_dim)"""
+ batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, :,
+ None, :].expand(batch, slen,
+ num_key_value_heads, n_rep,
+ head_dim)
+ return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
+ head_dim)
+
+
+@sequence_parallel_wrapper
+def flash_attn_wo_mask(query_states,
+ key_states,
+ value_states,
+ causal,
+ dropout_rate=0.0):
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout_rate, causal=causal)
+ return attn_output
+
+
+@sequence_parallel_wrapper
+def flash_attn_w_mask(
+ query_states, # bs, q_len, nhead, h_dim
+ key_states,
+ value_states,
+ attention_mask,
+ causal,
+ dropout_rate=0.0):
+ batch_size, q_len = query_states.shape[:2]
+ query_states, key_states, value_states, indices_q, \
+ cu_seq_lens, max_seq_lens = upad_qkv(
+ query_states, key_states, value_states, attention_mask, q_len)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout_rate,
+ causal=causal,
+ )
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+ return attn_output
+
+
+def flash_attn1_pytorch(query_states, key_states, value_states, *args,
+ **kwargs):
+ # hacky: PyTorch flash attn needs (bs, n_head, seq_len, h_dim)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = F.scaled_dot_product_attention(query_states, key_states,
+ value_states, *args, **kwargs)
+ attn_output = attn_output.transpose(1, 2)
+ return attn_output
+
+
+@sequence_parallel_wrapper
+def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
+ max_seqlen):
+ q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
+ 0, 1), value_states.flatten(0, 1)
+ cumulative_len = torch.cat(cumulative_len, dim=0)
+ attn_output = flash_attn_varlen_func(
+ q_unpad,
+ k_unpad,
+ v_unpad,
+ cumulative_len,
+ cumulative_len,
+ max_seqlen,
+ max_seqlen,
+ 0,
+ return_attn_probs=False,
+ causal=True,
+ )
+ attn_output = attn_output.unsqueeze(0)
+ return attn_output
+
+
def internlm2_attn_forward(
self,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
+):
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37'
'Please make sure use `attention_mask` instead.`')
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop('padding_mask')
+
+ output_attentions = False
+
bsz, q_len, _ = hidden_states.size()
qkv_states = self.wqkv(hidden_states)
@@ -135,7 +233,10 @@ def internlm2_attn_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ # This modification is necessary for sequence parallel
+ assert position_ids is not None and (position_ids.max() + 1) >= kv_seq_len
+ cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
@@ -146,29 +247,49 @@ def internlm2_attn_forward(
past_key_value = (key_states, value_states) if use_cache else None
+ # repeat kv for sequence parallel
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # flash attn 2 needs (bs, seq_len, nhead, h_dim)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
if SUPPORT_FLASH2:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
- attn_output = attn_output.contiguous()
+ causal = self.is_causal and q_len != 1
+
+ if attention_mask is not None:
+ attn_output = flash_attn_w_mask(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ causal,
+ training=self.training)
+ else:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal,
+ training=self.training)
else:
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
# use flash attention implemented by pytorch
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
- attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = flash_attn1_pytorch(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ training=self.training)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
attn_output = self.wo(attn_output)
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
def internlm2_varlen_attn_forward(
@@ -188,7 +309,7 @@ def internlm2_varlen_attn_forward(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- indexes = message_hub.get_info(f'indexes_rank_{rank}')
+ # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
@@ -216,10 +337,11 @@ def internlm2_varlen_attn_forward(
if is_training:
cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
+ query_states = apply_rotary_emb(query_states,
+ cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
+ key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -238,26 +360,21 @@ def internlm2_varlen_attn_forward(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
+ # repeat kv for sequence parallel
+ key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+ value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
assert SUPPORT_FLASH2
if is_training:
- q_unpad, k_unpad, v_unpad = query_states.flatten(
- 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
- attn_output = flash_attn_varlen_func(
- q_unpad,
- k_unpad,
- v_unpad,
- cumulative_len,
- cumulative_len,
- max_seqlen,
- max_seqlen,
- 0,
- return_attn_probs=False,
- causal=True,
- )
+ attn_output = varlen_flash_attn(query_states, key_states, value_states,
+ cumulative_len, max_seqlen)
else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ training=False)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py
index 94ddedaec..27b1f33d6 100644
--- a/xtuner/model/modules/dispatch/llama.py
+++ b/xtuner/model/modules/dispatch/llama.py
@@ -6,14 +6,18 @@
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+from transformers.utils import logging
+from xtuner.parallel.sequence import sequence_parallel_wrapper
from .triton_kernels import apply_rotary_emb
+from .utils import upad_qkv
SUPPORT_FLASH2 = False
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
-
+ from flash_attn.bert_padding import pad_input
SUPPORT_FLASH2 = True
except ImportError:
pass
@@ -26,6 +30,9 @@ class Cache:
pass
+logger = logging.get_logger(__name__)
+
+
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., :x.shape[-1] // 2]
@@ -33,18 +40,6 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
- # The first two dimensions of cos and sin are always 1,
- # so we can `squeeze` them.
- cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
- sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
- cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""This is the equivalent of torch.repeat_interleave(x, dim=1,
repeats=n_rep).
@@ -63,6 +58,95 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
head_dim)
+def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
+ to (batch, seqlen, num_attention_heads, head_dim)"""
+ batch, slen, num_key_value_heads, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, :,
+ None, :].expand(batch, slen,
+ num_key_value_heads, n_rep,
+ head_dim)
+ return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
+ head_dim)
+
+
+@sequence_parallel_wrapper
+def flash_attn_wo_mask(query_states,
+ key_states,
+ value_states,
+ causal,
+ dropout_rate=0.0):
+ attn_output = flash_attn_func(
+ query_states, key_states, value_states, dropout_rate, causal=causal)
+ return attn_output
+
+
+@sequence_parallel_wrapper
+def flash_attn_w_mask(
+ query_states, # bs, q_len, nhead, h_dim
+ key_states,
+ value_states,
+ attention_mask,
+ causal,
+ dropout_rate=0.0):
+ batch_size, q_len = query_states.shape[:2]
+ query_states, key_states, value_states, indices_q, \
+ cu_seq_lens, max_seq_lens = upad_qkv(
+ query_states, key_states, value_states, attention_mask, q_len)
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+ attn_output_unpad = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ dropout_p=dropout_rate,
+ causal=causal,
+ )
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len)
+ return attn_output
+
+
+def flash_attn1_pytorch(query_states, key_states, value_states, *args,
+ **kwargs):
+ # hacky: PyTorch flash attn needs (bs, n_head, seq_len, h_dim)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ attn_output = F.scaled_dot_product_attention(query_states, key_states,
+ value_states, *args, **kwargs)
+ attn_output = attn_output.transpose(1, 2)
+ return attn_output
+
+
+@sequence_parallel_wrapper
+def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
+ max_seqlen):
+ q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
+ 0, 1), value_states.flatten(0, 1)
+ cumulative_len = torch.cat(cumulative_len, dim=0)
+ attn_output = flash_attn_varlen_func(
+ q_unpad,
+ k_unpad,
+ v_unpad,
+ cumulative_len,
+ cumulative_len,
+ max_seqlen,
+ max_seqlen,
+ 0,
+ return_attn_probs=False,
+ causal=True,
+ )
+ attn_output = attn_output.unsqueeze(0)
+ return attn_output
+
+
def llama_attn_forward_legacy(
self,
hidden_states: torch.Tensor,
@@ -136,21 +220,41 @@ def llama_attn_forward_legacy(
past_key_value = (key_states, value_states) if use_cache else None
+ # repeat kv for sequence parallel
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # flash attn 2 needs (bs, seq_len, nhead, h_dim)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
if SUPPORT_FLASH2:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
- attn_output = attn_output.contiguous()
+ causal = self.is_causal and q_len != 1
+
+ if attention_mask is not None:
+ attn_output = flash_attn_w_mask(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ causal,
+ training=self.training)
+ else:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal,
+ training=self.training)
else:
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
# use flash attention implemented by pytorch
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
- attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = flash_attn1_pytorch(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=attention_mask,
+ training=self.training)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -183,46 +287,26 @@ def llama_attn_forward(
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
- # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501
+ # LlamaFlashAttention2 attention does not support output_attentions
if 'padding_mask' in kwargs:
warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in '
- 'v4.37. Please make sure use `attention_mask` instead.`')
+                'Passing `padding_mask` is deprecated and will be removed in '
+                'v4.37. Please make sure to use `attention_mask` instead.')
- bsz, q_len, _ = hidden_states.size()
+ # overwrite attention_mask with padding_mask
+ attention_mask = kwargs.pop('padding_mask')
- if self.config.pretraining_tp > 1:
- key_value_slicing = (self.num_key_value_heads *
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+ output_attentions = False
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
+ bsz, q_len, _ = hidden_states.size()
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
-
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+    # Flash attention 2 expects inputs of shape
+    # (batch_size, seq_length, num_heads, head_dim); the states are
+    # transposed back to that layout after RoPE is applied below.
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -232,16 +316,15 @@ def llama_attn_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
- if self.layer_idx is None:
- raise ValueError(
- 'The cache structure has changed since version v4.36. '
- f'If you are using {self.__class__.__name__} '
- 'for auto-regressive decoding with k/v caching, '
- 'please make sure to initialize the attention class '
- 'with a layer index.')
kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+ if self.training:
+ assert position_ids is not None
+ cos, sin = self.rotary_emb(
+ value_states, seq_len=position_ids.max() + 1)
+ else:
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
@@ -250,40 +333,78 @@ def llama_attn_forward(
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs)
- if SUPPORT_FLASH2:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
- attn_output = attn_output.contiguous()
+ # TODO: These transpose are quite inefficient but Flash Attention
+ # requires the layout [batch_size, sequence_length, num_heads, head_dim].
+ # We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training
+ # stability reasons, therefore the input hidden states gets silently
+ # casted in float32. Hence, we need cast them back in the correct dtype
+ # just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not
+ # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ # Handle the case where the model is quantized
+ if hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+            'The input hidden states seem to have been silently cast to '
+            'float32; this is probably because you have upcasted the '
+            'embedding or layer norm layers to float32. We will cast the '
+            f'input back to {target_dtype}.')
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # flash attn
+ if not self._flash_attn_uses_top_left_mask:
+ causal = self.is_causal
else:
- # repeat k/v heads if n_kv_heads < n_heads
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
- # use flash attention implemented by pytorch
- attn_output = F.scaled_dot_product_attention(
- query_states, key_states, value_states, attn_mask=attention_mask)
- attn_output = attn_output.transpose(1, 2).contiguous()
+ # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
+ # is bumped to 2.1. For details, please see the comment in
+ # LlamaFlashAttention2 __init__.
+ causal = self.is_causal and q_len != 1
+
+ # repeat kv for sequence parallel
+ key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+ value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
+ if attention_mask is not None:
+ attn_output = flash_attn_w_mask(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ causal,
+ dropout_rate,
+ training=self.training)
+ else:
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal,
+ dropout_rate,
+ training=self.training)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(
- self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.o_proj.weight.split(
- self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
+ if not output_attentions:
+ attn_weights = None
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
def llama_varlen_attn_forward_legacy(
@@ -302,7 +423,7 @@ def llama_varlen_attn_forward_legacy(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- indexes = message_hub.get_info(f'indexes_rank_{rank}')
+ # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
@@ -357,10 +478,11 @@ def llama_varlen_attn_forward_legacy(
if is_training:
cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
+ query_states = apply_rotary_emb(query_states,
+ cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
+ key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -379,26 +501,21 @@ def llama_varlen_attn_forward_legacy(
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
+ # repeat kv for sequence parallel
+ key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+ value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
assert SUPPORT_FLASH2
if is_training:
- q_unpad, k_unpad, v_unpad = query_states.flatten(
- 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
- attn_output = flash_attn_varlen_func(
- q_unpad,
- k_unpad,
- v_unpad,
- cumulative_len,
- cumulative_len,
- max_seqlen,
- max_seqlen,
- 0,
- return_attn_probs=False,
- causal=True,
- )
+ attn_output = varlen_flash_attn(query_states, key_states, value_states,
+ cumulative_len, max_seqlen)
else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ training=False)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
@@ -436,7 +553,7 @@ def llama_varlen_attn_forward(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- indexes = message_hub.get_info(f'indexes_rank_{rank}')
+ # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
@@ -499,10 +616,11 @@ def llama_varlen_attn_forward(
if is_training:
cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
+ query_states = apply_rotary_emb(query_states,
+ cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
+ key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -516,33 +634,25 @@ def llama_varlen_attn_forward(
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs)
- key_states = repeat_kv(key_states, self.num_key_value_groups)
- value_states = repeat_kv(value_states, self.num_key_value_groups)
-
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
+ # repeat kv for sequence parallel
+ key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
+ value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+
assert SUPPORT_FLASH2
if is_training:
- q_unpad, k_unpad, v_unpad = query_states.flatten(
- 0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
- attn_output = flash_attn_varlen_func(
- q_unpad,
- k_unpad,
- v_unpad,
- cumulative_len,
- cumulative_len,
- max_seqlen,
- max_seqlen,
- 0,
- return_attn_probs=False,
- causal=True,
- )
+ attn_output = varlen_flash_attn(query_states, key_states, value_states,
+ cumulative_len, max_seqlen)
else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, causal=True)
+ attn_output = flash_attn_wo_mask(
+ query_states,
+ key_states,
+ value_states,
+ causal=True,
+ training=False)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
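The `repeat_kv_bshd` helper introduced above mirrors the stock `repeat_kv`, but works on the (batch, seqlen, heads, head_dim) layout that flash attention expects. A minimal self-contained sketch (only `torch` assumed) of how it behaves:

```python
import torch


def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, seqlen, num_key_value_heads, head_dim)
    # -> (batch, seqlen, num_key_value_heads * n_rep, head_dim)
    batch, slen, num_kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, None, :].expand(
        batch, slen, num_kv_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, num_kv_heads * n_rep, head_dim)


kv = torch.randn(2, 8, 4, 16)          # 4 KV heads
out = repeat_kv_bshd(kv, n_rep=8)      # repeated up to 32 query heads
assert out.shape == (2, 8, 32, 16)
# each KV head is repeated contiguously along the head dimension
assert torch.equal(out[:, :, 0], out[:, :, 7])
```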
diff --git a/xtuner/model/modules/dispatch/mistral.py b/xtuner/model/modules/dispatch/mistral.py
index 8c65cbec6..92245230c 100644
--- a/xtuner/model/modules/dispatch/mistral.py
+++ b/xtuner/model/modules/dispatch/mistral.py
@@ -84,13 +84,9 @@ def mistral_varlen_attn_forward(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- indexes = message_hub.get_info(f'indexes_rank_{rank}')
+ # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
- # cumulative_len = message_hub.get_info(f'cumulative_len')
- # indexes = message_hub.get_info(f'indexes')
- # max_seqlen = message_hub.get_info(f'max_seqlen')
-
assert is_training == (cumulative_len is not None) == (
past_key_value is None)
@@ -136,10 +132,11 @@ def mistral_varlen_attn_forward(
if is_training:
cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[indexes].squeeze(0),
- sin[indexes].squeeze(0))
+ query_states = apply_rotary_emb(query_states,
+ cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
+ key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
+ sin[position_ids].squeeze(0))
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
diff --git a/xtuner/model/modules/dispatch/utils.py b/xtuner/model/modules/dispatch/utils.py
new file mode 100644
index 000000000..4cfa26cd1
--- /dev/null
+++ b/xtuner/model/modules/dispatch/utils.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn.functional as F
+
+try:
+ from flash_attn.bert_padding import index_first_axis, unpad_input
+except ImportError:
+ pass
+
+
+def _get_unpad_data(attention_mask):
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(
+        torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def upad_qkv(query_layer, key_layer, value_layer, attention_mask,
+ query_length):
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(
+ attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads,
+ head_dim), indices_k)
+ if query_length == kv_seq_len:
+        # Different from the original version: sequence parallel changes
+        # the number of attention heads on each rank, so infer it with -1.
+ query_layer = index_first_axis(
+ query_layer.reshape(batch_size * kv_seq_len, -1, head_dim),
+ indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = \
+ unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
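To see what the un-padding helpers compute, here is a toy run of the `_get_unpad_data` logic on a right-padded mask (a standalone sketch, not part of the patch):

```python
import torch
import torch.nn.functional as F

# two sequences of lengths 3 and 2, padded to length 4
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]], dtype=torch.int32)

seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)       # tensor([3, 2])
indices = torch.nonzero(attention_mask.flatten(),
                        as_tuple=False).flatten()             # tensor([0, 1, 2, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

print(cu_seqlens)            # tensor([0, 3, 5]) -> boundaries for flash_attn_varlen_func
print(seqlens.max().item())  # 3 == max_seqlen_in_batch
```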
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index c433ff1e9..7aa0ec63c 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -1,18 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import math
from collections import OrderedDict
from contextlib import nullcontext
+import torch
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from mmengine.runner import load_checkpoint
from peft import get_peft_model, prepare_model_for_kbit_training
from torch import nn
-from transformers import PreTrainedModel, PreTrainedTokenizer
+from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer
from transformers.integrations import is_deepspeed_zero3_enabled
+from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
+ reduce_sequence_parallel_loss)
from xtuner.registry import BUILDER
from .modules import dispatch_modules
+from .modules.dispatch import SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, make_inputs_require_grad,
traverse_dict)
@@ -69,10 +74,12 @@ def __init__(self,
peft_model=None,
use_activation_checkpointing=True,
use_varlen_attn=False,
- tokenizer=None):
+ tokenizer=None,
+ max_position_embeddings=None):
super().__init__()
with LoadWoInit():
- self.llm = self._build_from_cfg_or_module(llm)
+ self.llm = self._build_from_cfg_or_module(llm,
+ max_position_embeddings)
if tokenizer is not None:
if isinstance(tokenizer, dict):
@@ -137,11 +144,48 @@ def _prepare_for_lora(self,
def init_weights(self):
pass
- def _build_from_cfg_or_module(self, cfg_or_mod):
+ def _prepare_for_long_context_training(self, cfg, max_position_embeddings):
+ pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
+ config = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=True)
+
+ orig_rope_scaling = getattr(config, 'rope_scaling', None)
+ if orig_rope_scaling is None:
+ orig_rope_scaling = {'factor': 1}
+
+ orig_rope_scaling_factor = orig_rope_scaling[
+ 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
+ orig_ctx_len = getattr(config, 'max_position_embeddings', None)
+ if orig_ctx_len:
+ orig_ctx_len *= orig_rope_scaling_factor
+ if max_position_embeddings > orig_ctx_len:
+ scaling_factor = float(
+ math.ceil(max_position_embeddings / orig_ctx_len))
+ config.rope_scaling = {
+ 'type': 'linear',
+ 'factor': scaling_factor
+ }
+
+ # hardcode for internlm2
+ config.attn_implementation = 'flash_attention_2'
+
+ cfg.config = config
+ return cfg
+
+ def _build_from_cfg_or_module(self,
+ cfg_or_mod,
+ max_position_embeddings=None):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
+ if SUPPORT_FLASH2:
+ cfg_or_mod.torch_dtype = torch.bfloat16 \
+ if torch.cuda.is_bf16_supported() else torch.float16
+ cfg_or_mod.attn_implementation = 'flash_attention_2'
+ if max_position_embeddings is not None:
+ cfg_or_mod = self._prepare_for_long_context_training(
+ cfg_or_mod, max_position_embeddings)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError
@@ -168,10 +212,20 @@ def predict(self, data, data_samples=None):
logits_dict = [{'logits': logits} for logits in outputs.logits]
return logits_dict
- def compute_loss(self, data, data_samples=None):
+ def compute_sequence_parallel_loss(self, data):
outputs = self.llm(**data)
- loss_dict = {'loss': outputs.loss}
- return loss_dict
+ labels = data['labels']
+ num_tokens = (labels != -100).sum()
+ loss = reduce_sequence_parallel_loss(outputs.loss, num_tokens)
+ return {'loss': loss}
+
+ def compute_loss(self, data, data_samples=None):
+ if get_sequence_parallel_world_size() > 1:
+ return self.compute_sequence_parallel_loss(data)
+ else:
+ outputs = self.llm(**data)
+ loss_dict = {'loss': outputs.loss}
+ return loss_dict
def state_dict(self, *args, **kwargs):
state_dict = super().state_dict(*args, **kwargs)
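The `_prepare_for_long_context_training` helper added to `sft.py` above only enlarges the RoPE range when the requested context exceeds what the checkpoint already covers. A small numeric sketch of that rule (toy numbers, not from the patch):

```python
import math


def rope_scaling_factor(max_position_embeddings, orig_ctx_len, orig_factor=1):
    """Mirrors the linear-scaling rule in _prepare_for_long_context_training."""
    effective_ctx = orig_ctx_len * orig_factor
    if max_position_embeddings <= effective_ctx:
        return None  # no extra scaling needed
    return float(math.ceil(max_position_embeddings / effective_ctx))


# e.g. a 4k-context model fine-tuned on 32k-token packed samples
print(rope_scaling_factor(32768, 4096))
# 8.0 -> config.rope_scaling = {'type': 'linear', 'factor': 8.0}
```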
diff --git a/xtuner/parallel/sequence/__init__.py b/xtuner/parallel/sequence/__init__.py
new file mode 100644
index 000000000..a50921336
--- /dev/null
+++ b/xtuner/parallel/sequence/__init__.py
@@ -0,0 +1,24 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmengine.dist import init_dist
+
+from .attention import sequence_parallel_wrapper
+from .data_collate import (pad_for_sequence_parallel,
+ split_for_sequence_parallel)
+from .reduce_loss import reduce_sequence_parallel_loss
+from .sampler import SequenceParallelSampler
+from .setup_distributed import (get_data_parallel_group,
+ get_data_parallel_rank,
+ get_data_parallel_world_size,
+ get_sequence_parallel_group,
+ get_sequence_parallel_rank,
+ get_sequence_parallel_world_size,
+ init_sequence_parallel)
+
+__all__ = [
+ 'sequence_parallel_wrapper', 'pad_for_sequence_parallel',
+ 'split_for_sequence_parallel', 'SequenceParallelSampler',
+ 'init_sequence_parallel', 'get_sequence_parallel_group',
+ 'get_sequence_parallel_world_size', 'get_sequence_parallel_rank',
+ 'get_data_parallel_group', 'get_data_parallel_world_size',
+ 'get_data_parallel_rank', 'reduce_sequence_parallel_loss', 'init_dist'
+]
diff --git a/xtuner/parallel/sequence/attention.py b/xtuner/parallel/sequence/attention.py
new file mode 100644
index 000000000..b1b1ebcee
--- /dev/null
+++ b/xtuner/parallel/sequence/attention.py
@@ -0,0 +1,102 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Tuple
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+from .setup_distributed import (get_sequence_parallel_group,
+ get_sequence_parallel_world_size)
+
+
+def all_to_all_scatter_nhead(input):
+ # bs, seq, nhead, dim ==>
+ # bs, seq * sp_world_size, nhead / sp_world_size, dim
+ sp_world_size = get_sequence_parallel_world_size()
+ sp_group = get_sequence_parallel_group()
+ bs, seq, nhead, dim = input.shape
+ input_t = input.reshape(bs, seq, sp_world_size, nhead // sp_world_size,
+ dim)
+ input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
+ output = torch.empty_like(input_t)
+ dist.all_to_all_single(output, input_t, group=sp_group)
+ output = output.transpose(0, 1)
+ return output.reshape(bs, seq * sp_world_size, nhead // sp_world_size, dim)
+
+
+def all_to_all_scatter_seq(input):
+ # bs, seq * sp_world_size, nhead / sp_world_size, dim ==>
+ # bs, seq, nhead, dim
+ sp_world_size = get_sequence_parallel_world_size()
+ sp_group = get_sequence_parallel_group()
+ bs, seq, nhead, dim = input.shape
+ input_t = input.reshape(bs, sp_world_size, seq // sp_world_size, nhead,
+ dim)
+ input_t = input_t.transpose(0, 1).contiguous()
+ output = torch.empty_like(input_t)
+ dist.all_to_all_single(output, input_t, group=sp_group)
+ output = output.permute(1, 2, 0, 3, 4)
+ return output.reshape(bs, seq // sp_world_size, nhead * sp_world_size, dim)
+
+
+class _SeqAllToAll(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx: Any, input: Tensor, scatter_seq) -> Tensor:
+ ctx.scatter_seq = scatter_seq
+ ctx.input_shape = input.shape
+ if scatter_seq:
+ return all_to_all_scatter_seq(input)
+ return all_to_all_scatter_nhead(input)
+
+ @staticmethod
+ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
+ grad = _SeqAllToAll.apply(*grad_output, not ctx.scatter_seq)
+ return (grad, None)
+
+
+def pre_process_for_sequence_parallel_attn(query_states, key_states,
+ value_states):
+ sequence_parallel_world_size = get_sequence_parallel_world_size()
+ n_head = query_states.shape[2]
+ assert n_head % sequence_parallel_world_size == 0, \
+ ('The number of attention heads should be divisible by '
+ f'sequence_parallel_world_size. But got n_head = {n_head} and '
+ f'sequence_parallel_world_size = {sequence_parallel_world_size}.')
+
+ # (b, s // sp_world_size, nd, dim) -> (b, s, nd // sp_world_size, dim)
+ query_states = _SeqAllToAll.apply(query_states, False)
+ key_states = _SeqAllToAll.apply(key_states, False)
+ value_states = _SeqAllToAll.apply(value_states, False)
+
+ return query_states, key_states, value_states
+
+
+def post_process_for_sequence_parallel_attn(attn_output):
+ # (b, s, nd // sp_world_size, dim) -> (b, s // sp_world_size, nd, dim)
+ output = _SeqAllToAll.apply(attn_output, True)
+ return output
+
+
+def sequence_parallel_wrapper(local_attn):
+
+ def sequence_parallel_attn(query_states, key_states, value_states, *args,
+ **kwargs):
+ training = kwargs.pop('training', True)
+ enable_sequence_parallel = (
+ dist.is_initialized() and get_sequence_parallel_world_size() > 1
+ and training)
+ if enable_sequence_parallel:
+ query_states, key_states, value_states = \
+ pre_process_for_sequence_parallel_attn(
+ query_states, key_states, value_states)
+
+ out = local_attn(query_states, key_states, value_states, *args,
+ **kwargs)
+
+ if enable_sequence_parallel:
+ out = post_process_for_sequence_parallel_attn(out).contiguous()
+
+ return out
+
+ return sequence_parallel_attn
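The two all-to-all helpers above only reshuffle layouts: before attention each rank trades its sequence chunk for a subset of attention heads, and afterwards the exchange is undone. The expected shapes, sketched for a hypothetical setup with `sequence_parallel_world_size = 4` and 32 heads:

```python
bs, sub_seq, nhead, dim = 2, 1024, 32, 128   # each rank holds seq_len / 4 tokens

# pre_process_for_sequence_parallel_attn (scatter heads, gather sequence)
q_in = (bs, sub_seq, nhead, dim)             # (2, 1024, 32, 128) per rank
q_out = (bs, sub_seq * 4, nhead // 4, dim)   # (2, 4096,  8, 128) per rank

# post_process_for_sequence_parallel_attn (the inverse exchange)
o_in = (bs, sub_seq * 4, nhead // 4, dim)
o_out = (bs, sub_seq, nhead, dim)

print(q_in, '->', q_out)
print(o_in, '->', o_out)
```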
diff --git a/xtuner/parallel/sequence/data_collate.py b/xtuner/parallel/sequence/data_collate.py
new file mode 100644
index 000000000..f61b481b9
--- /dev/null
+++ b/xtuner/parallel/sequence/data_collate.py
@@ -0,0 +1,75 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX
+from .setup_distributed import (get_sequence_parallel_rank,
+ get_sequence_parallel_world_size)
+
+
+def pad_for_sequence_parallel(tokens,
+ labels=None,
+ position_ids=None,
+ attention_mask=None,
+ tokens_pad_index=DEFAULT_PAD_TOKEN_INDEX,
+ labels_pad_index=IGNORE_INDEX,
+ position_ids_pad_index=0,
+ attention_mask_pad_index=0):
+ if labels is not None:
+ assert tokens.shape == labels.shape
+ if position_ids is not None:
+ assert tokens.shape == position_ids.shape
+ if attention_mask is not None:
+ assert tokens.shape == attention_mask.shape
+
+ bs, seq_len = tokens.shape
+ seq_parallel_world_size = get_sequence_parallel_world_size()
+ if seq_len % seq_parallel_world_size == 0:
+ return tokens, labels, position_ids, attention_mask
+
+ pad_num = seq_parallel_world_size - (seq_len % seq_parallel_world_size)
+ pad = torch.full((bs, pad_num),
+ tokens_pad_index,
+ dtype=tokens.dtype,
+ device=tokens.device)
+ tokens = torch.cat([tokens, pad], dim=1)
+
+ if labels is not None:
+ pad = torch.full((bs, pad_num),
+ labels_pad_index,
+ dtype=labels.dtype,
+ device=labels.device)
+ labels = torch.cat([labels, pad], dim=1)
+
+ if position_ids is not None:
+ pad = torch.full((bs, pad_num),
+ position_ids_pad_index,
+ dtype=position_ids.dtype,
+ device=position_ids.device)
+ position_ids = torch.cat([position_ids, pad], dim=1)
+
+ if attention_mask is not None:
+ pad = torch.full((bs, pad_num),
+ attention_mask_pad_index,
+ dtype=attention_mask.dtype,
+ device=attention_mask.device)
+ attention_mask = torch.cat([attention_mask, pad], dim=1)
+
+ return tokens, labels, position_ids, attention_mask
+
+
+def split_for_sequence_parallel(tokens, labels=None, position_ids=None):
+ seq_parallel_world_size = get_sequence_parallel_world_size()
+ seq_parallel_world_rank = get_sequence_parallel_rank()
+ seq_len = tokens.size(1)
+ assert seq_len % seq_parallel_world_size == 0
+ sub_seq_len = seq_len // seq_parallel_world_size
+ sub_seq_start = seq_parallel_world_rank * sub_seq_len
+ sub_seq_end = (seq_parallel_world_rank + 1) * sub_seq_len
+
+ tokens = tokens[:, sub_seq_start:sub_seq_end]
+ if labels is not None:
+ labels = labels[:, sub_seq_start:sub_seq_end]
+ if position_ids is not None:
+ position_ids = position_ids[:, sub_seq_start:sub_seq_end]
+
+ return tokens, labels, position_ids
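A toy walk-through of the padding and splitting arithmetic above (a standalone sketch; the real helpers read the world size and rank from the process group):

```python
import torch

sp_world_size, sp_rank = 4, 1                    # pretend values
tokens = torch.arange(10).unsqueeze(0)           # (1, 10), and 10 % 4 != 0

# pad_for_sequence_parallel: pad up to the next multiple of sp_world_size
pad_num = sp_world_size - tokens.size(1) % sp_world_size            # 2
tokens = torch.cat(
    [tokens, torch.zeros(1, pad_num, dtype=tokens.dtype)], dim=1)   # (1, 12)

# split_for_sequence_parallel: each rank keeps a contiguous 1/sp slice
sub = tokens.size(1) // sp_world_size                               # 3
local = tokens[:, sp_rank * sub:(sp_rank + 1) * sub]
print(local)   # tensor([[3, 4, 5]]) on rank 1
```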
diff --git a/xtuner/parallel/sequence/reduce_loss.py b/xtuner/parallel/sequence/reduce_loss.py
new file mode 100644
index 000000000..56a8389f4
--- /dev/null
+++ b/xtuner/parallel/sequence/reduce_loss.py
@@ -0,0 +1,17 @@
+import torch
+import torch.distributed as dist
+
+from .setup_distributed import get_sequence_parallel_group
+
+
+def reduce_sequence_parallel_loss(mean_loss, num_tokens_for_loss):
+ sequence_parallel_group = get_sequence_parallel_group()
+ if num_tokens_for_loss == 0:
+ # convert nan to 0 just for logging
+ mean_loss = torch.nan_to_num(mean_loss)
+ loss_sum = mean_loss * num_tokens_for_loss
+ dist.all_reduce(loss_sum, group=sequence_parallel_group)
+ dist.all_reduce(num_tokens_for_loss, group=sequence_parallel_group)
+
+ loss = loss_sum / num_tokens_for_loss
+ return loss
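The reduction above is a token-weighted average across sequence-parallel ranks, so ranks that happened to receive more unmasked labels count for more. A two-rank sketch of the arithmetic (no process group involved):

```python
# rank 0: mean loss 2.0 over 30 loss tokens; rank 1: mean loss 1.0 over 10
losses, tokens = [2.0, 1.0], [30, 10]

loss_sum = sum(l * n for l, n in zip(losses, tokens))   # 70.0 (all_reduce of loss * num_tokens)
num_tokens = sum(tokens)                                # 40   (all_reduce of num_tokens)
print(loss_sum / num_tokens)                            # 1.75, not the naive mean 1.5
```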
diff --git a/xtuner/parallel/sequence/sampler.py b/xtuner/parallel/sequence/sampler.py
new file mode 100644
index 000000000..69adb7cc9
--- /dev/null
+++ b/xtuner/parallel/sequence/sampler.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import math
+from typing import Optional, Sized
+
+from mmengine.dataset import DefaultSampler
+from mmengine.dist import sync_random_seed
+
+from .setup_distributed import (get_data_parallel_rank,
+ get_data_parallel_world_size)
+
+
+class SequenceParallelSampler(DefaultSampler):
+
+ def __init__(self,
+ dataset: Sized,
+ shuffle: bool = True,
+ seed: Optional[int] = None,
+ round_up: bool = True) -> None:
+ rank = get_data_parallel_rank()
+ world_size = get_data_parallel_world_size()
+ self.rank = rank
+ self.world_size = world_size
+
+ self.dataset = dataset
+ self.shuffle = shuffle
+ if seed is None:
+ seed = sync_random_seed()
+ self.seed = seed
+ self.epoch = 0
+ self.round_up = round_up
+
+ if self.round_up:
+ self.num_samples = math.ceil(len(self.dataset) / world_size)
+ self.total_size = self.num_samples * self.world_size
+ else:
+ self.num_samples = math.ceil(
+ (len(self.dataset) - rank) / world_size)
+ self.total_size = len(self.dataset)
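The sampler shards indices across data-parallel ranks only, so all ranks inside one sequence-parallel group consume identical batches. A quick check of the sizes it produces, using toy numbers:

```python
import math

dataset_len, dp_world_size = 10, 4

# round_up=True: every rank gets the same number of samples (some repeated)
num_samples = math.ceil(dataset_len / dp_world_size)                 # 3
total_size = num_samples * dp_world_size                             # 12 (2 indices reused)

# round_up=False: ranks may end up with unequal sample counts
per_rank = [math.ceil((dataset_len - r) / dp_world_size)
            for r in range(dp_world_size)]
print(num_samples, total_size, per_rank)                             # 3 12 [3, 3, 2, 2]
```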
diff --git a/xtuner/parallel/sequence/setup_distributed.py b/xtuner/parallel/sequence/setup_distributed.py
new file mode 100644
index 000000000..ea207bf10
--- /dev/null
+++ b/xtuner/parallel/sequence/setup_distributed.py
@@ -0,0 +1,100 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.distributed as dist
+
+_SEQUENCE_PARALLEL_GROUP = None
+_SEQUENCE_PARALLEL_WORLD_SIZE = None
+_SEQUENCE_PARALLEL_RANK = None
+
+_DATA_PARALLEL_GROUP = None
+_DATA_PARALLEL_WORLD_SIZE = None
+_DATA_PARALLEL_RANK = None
+
+
+def init_sequence_parallel(sequence_parallel_size: int = 1):
+ assert dist.is_initialized()
+ world_size: int = dist.get_world_size()
+
+ # enable_ds_sequence_parallel = sequence_parallel_size > 1
+ # if enable_ds_sequence_parallel:
+ if world_size % sequence_parallel_size != 0:
+ raise RuntimeError(f'world_size ({world_size}) is not divisible by '
+ f'sequence_parallel_size {sequence_parallel_size}')
+
+ num_sequence_parallel_groups: int = world_size // sequence_parallel_size
+
+ rank = dist.get_rank()
+
+ # Build the sequence parallel groups.
+ global _SEQUENCE_PARALLEL_GROUP
+ assert _SEQUENCE_PARALLEL_GROUP is None, \
+ 'sequence parallel group is already initialized'
+ for i in range(num_sequence_parallel_groups):
+ ranks = range(i * sequence_parallel_size,
+ (i + 1) * sequence_parallel_size)
+ group = dist.new_group(ranks)
+ if rank in ranks:
+ _SEQUENCE_PARALLEL_GROUP = group
+
+ global _DATA_PARALLEL_GROUP
+ assert _DATA_PARALLEL_GROUP is None, \
+ 'data parallel group is already initialized'
+ all_data_parallel_group_ranks = []
+ start_rank = 0
+ end_rank = world_size
+ for j in range(sequence_parallel_size):
+ ranks = range(start_rank + j, end_rank, sequence_parallel_size)
+ all_data_parallel_group_ranks.append(list(ranks))
+ group = dist.new_group(ranks)
+ if rank in ranks:
+ _DATA_PARALLEL_GROUP = group
+
+
+def get_sequence_parallel_group():
+ """Get the sequence parallel group the caller rank belongs to."""
+ return _SEQUENCE_PARALLEL_GROUP
+
+
+def get_sequence_parallel_world_size():
+ """Return world size for the sequence parallel group."""
+ global _SEQUENCE_PARALLEL_WORLD_SIZE
+ if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
+ return _SEQUENCE_PARALLEL_WORLD_SIZE
+ _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
+ group=get_sequence_parallel_group())
+ return _SEQUENCE_PARALLEL_WORLD_SIZE
+
+
+def get_sequence_parallel_rank():
+ """Return my rank for the sequence parallel group."""
+ global _SEQUENCE_PARALLEL_RANK
+ if _SEQUENCE_PARALLEL_RANK is not None:
+ return _SEQUENCE_PARALLEL_RANK
+ _SEQUENCE_PARALLEL_RANK = dist.get_rank(
+ group=get_sequence_parallel_group())
+ return _SEQUENCE_PARALLEL_RANK
+
+
+def get_data_parallel_group():
+ """Get the data parallel group the caller rank belongs to."""
+ assert _DATA_PARALLEL_GROUP is not None, \
+ 'data parallel group is not initialized'
+ return _DATA_PARALLEL_GROUP
+
+
+def get_data_parallel_world_size():
+ """Return world size for the data parallel group."""
+ global _DATA_PARALLEL_WORLD_SIZE
+ if _DATA_PARALLEL_WORLD_SIZE is not None:
+ return _DATA_PARALLEL_WORLD_SIZE
+ _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size(
+ group=get_data_parallel_group())
+ return _DATA_PARALLEL_WORLD_SIZE
+
+
+def get_data_parallel_rank():
+ """Return my rank for the data parallel group."""
+ global _DATA_PARALLEL_RANK
+ if _DATA_PARALLEL_RANK is not None:
+ return _DATA_PARALLEL_RANK
+ _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group())
+ return _DATA_PARALLEL_RANK
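For a concrete picture of the groups built by `init_sequence_parallel` (pure bookkeeping, no `dist` calls), with 8 GPUs and `sequence_parallel_size=4`:

```python
world_size, sp_size = 8, 4

sp_groups = [list(range(i * sp_size, (i + 1) * sp_size))
             for i in range(world_size // sp_size)]
dp_groups = [list(range(j, world_size, sp_size)) for j in range(sp_size)]

print(sp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]    -> ranks sharing one sequence
print(dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]] -> ranks sharing one data shard
```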
diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py
index 922089dae..7acbbf21f 100644
--- a/xtuner/tools/train.py
+++ b/xtuner/tools/train.py
@@ -76,6 +76,31 @@ def register_function(cfg_dict):
register_function(value)
+def check_cfg(cfg):
+ if getattr(cfg, 'use_varlen_attn',
+ False) and cfg.train_dataloader.batch_size > 1:
+ raise NotImplementedError(
+ f'If utilizing varlen attention, the batch size should be'
+ f' set to 1, but got {cfg.train_dataloader.batch_size}')
+
+ if getattr(cfg, 'use_varlen_attn', False) and (not getattr(
+ cfg.train_dataloader.dataset, 'pack_to_max_length', True)):
+ raise AssertionError(
+            'When using varlen attention, `pack_to_max_length` '
+            'should be set to True, but got use_varlen_attn = True and '
+ 'pack_to_max_length = False.')
+
+ if getattr(cfg, 'use_varlen_attn', False):
+ sequence_parallel = getattr(cfg, 'sequence_parallel', 1)
+ max_length = getattr(cfg.train_dataloader.dataset, 'max_length', None)
+ if max_length is not None:
+ assert max_length % sequence_parallel == 0, \
+ ('When using varlen attention, `max_length` should be evenly '
+ 'divided by sequence parallel world size, but got '
+ f'max_length = {max_length} and sequence_parallel = '
+ f'{sequence_parallel}')
+
+
def main():
args = parse_args()
@@ -96,6 +121,8 @@ def main():
# change these FunctionType object to str
register_function(cfg._cfg_dict)
+ check_cfg(cfg)
+
if cfg.get('framework', 'mmengine').lower() == 'huggingface':
# set default training_args
if cfg.get('training_args', None) is None:
@@ -277,7 +304,10 @@ def main():
gradient_accumulation_steps=grad_accum,
train_micro_batch_size_per_gpu=train_bs,
gradient_clipping=grad_clip,
- exclude_frozen_parameters=exclude_frozen_parameters)
+ exclude_frozen_parameters=exclude_frozen_parameters,
+ sequence_parallel_size=getattr(cfg,
+ 'sequence_parallel_size',
+ 1))
cfg.__setitem__('strategy', strategy)
optim_wrapper = dict(
type='DeepSpeedOptimWrapper',
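Taken together with the sampler and strategy changes, enabling the feature from a config is expected to look roughly like the following. This is an illustrative fragment only; `sequence_parallel_size`, `SequenceParallelSampler` and the batch-size-1 constraint come from this patch, every other value is a placeholder.

```python
from xtuner.parallel.sequence import SequenceParallelSampler

sequence_parallel_size = 4   # forwarded to the DeepSpeed strategy in tools/train.py

train_dataloader = dict(
    batch_size=1,            # required by check_cfg when use_varlen_attn=True
    num_workers=0,           # placeholder value
    sampler=dict(type=SequenceParallelSampler, shuffle=True),
    # dataset and collate_fn omitted for brevity
)
```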
From 32e3e5f0581998fd84f30f8a1847554a287c161a Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Mon, 25 Mar 2024 20:37:35 +0800
Subject: [PATCH 2/9] [Bug] Fix bugs in flash_attn1_pytorch (#513)
add @sequence_parallel_wrapper
---
xtuner/model/modules/dispatch/internlm2.py | 1 +
xtuner/model/modules/dispatch/llama.py | 1 +
2 files changed, 2 insertions(+)
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index a166e8bae..b354e6fd5 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -156,6 +156,7 @@ def flash_attn_w_mask(
return attn_output
+@sequence_parallel_wrapper
def flash_attn1_pytorch(query_states, key_states, value_states, *args,
**kwargs):
# hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim)
diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py
index 27b1f33d6..4077efb6c 100644
--- a/xtuner/model/modules/dispatch/llama.py
+++ b/xtuner/model/modules/dispatch/llama.py
@@ -113,6 +113,7 @@ def flash_attn_w_mask(
return attn_output
+@sequence_parallel_wrapper
def flash_attn1_pytorch(query_states, key_states, value_states, *args,
**kwargs):
# hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim)
From 6004335ead98496a2e84f59e1b2d2c3e41168b0d Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Tue, 26 Mar 2024 15:30:03 +0800
Subject: [PATCH 3/9] [Fix] delete cat in varlen attn (#508)
delete cat in varlen attn
---
.../varlen_attn_args_to_messagehub_hook.py | 33 ++++++-------------
xtuner/model/modules/dispatch/internlm2.py | 2 --
xtuner/model/modules/dispatch/llama.py | 3 --
3 files changed, 10 insertions(+), 28 deletions(-)
diff --git a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
index f2b23d3fe..f7a95a09c 100644
--- a/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
+++ b/xtuner/engine/hooks/varlen_attn_args_to_messagehub_hook.py
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Mapping, Optional, Sequence, Union
+from typing import Optional, Union
-import torch
import torch.distributed as dist
from mmengine import MessageHub
from mmengine.hooks import Hook
@@ -11,20 +10,6 @@
class VarlenAttnArgsToMessageHubHook(Hook):
- args = ('cumulative_len', 'max_seqlen')
-
- def cast_data(self, data):
- if isinstance(data, Mapping):
- return {key: self.cast_data(data[key]) for key in data}
- elif isinstance(data, (str, bytes)) or data is None:
- return data
- elif isinstance(data, Sequence):
- return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable
- elif isinstance(data, torch.Tensor):
- return data.cuda()
- else:
- return data
-
def before_train_iter(self,
runner,
batch_idx: int,
@@ -35,10 +20,13 @@ def before_train_iter(self,
assert 'data' in data_batch.keys()
data = data_batch['data']
- for arg in self.args:
- assert arg in data
- message_hub.update_info(f'{arg}_rank_{rank}',
- self.cast_data(data.pop(arg)))
+ cumulative_len = data.pop('cumulative_len')
+ assert len(cumulative_len) == 1
+ cumulative_len = cumulative_len[0].cuda()
+ message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
+
+ max_seqlen = data.pop('max_seqlen')
+ message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)
def after_train_iter(self,
runner,
@@ -47,6 +35,5 @@ def after_train_iter(self,
outputs: Optional[dict] = None) -> None:
rank = dist.get_rank()
message_hub = MessageHub.get_instance('varlen_attn_args')
-
- for arg in self.args:
- message_hub.update_info(f'{arg}_rank_{rank}', None)
+ message_hub.update_info(f'cumulative_len_rank_{rank}', None)
+ message_hub.update_info(f'max_seqlen_rank_{rank}', None)
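For context, `cumulative_len` carried through the MessageHub is a per-batch list holding a single tensor of packed-sequence boundaries (the `cu_seqlens` consumed by `flash_attn_varlen_func`); the simplified hook now just moves that one tensor to GPU. A toy value, purely illustrative:

```python
import torch

# one packed sample made of three sequences of lengths 5, 3 and 4
cumulative_len = [torch.tensor([0, 5, 8, 12], dtype=torch.int32)]

assert len(cumulative_len) == 1                  # the hook asserts exactly this
boundaries = cumulative_len[0]
max_seqlen = (boundaries[1:] - boundaries[:-1]).max().item()
print(max_seqlen)                                # 5
```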
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index b354e6fd5..9bc77177c 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -174,7 +174,6 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
max_seqlen):
q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
attn_output = flash_attn_varlen_func(
q_unpad,
k_unpad,
@@ -310,7 +309,6 @@ def internlm2_varlen_attn_forward(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py
index 4077efb6c..df17e1b49 100644
--- a/xtuner/model/modules/dispatch/llama.py
+++ b/xtuner/model/modules/dispatch/llama.py
@@ -131,7 +131,6 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
max_seqlen):
q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
0, 1), value_states.flatten(0, 1)
- cumulative_len = torch.cat(cumulative_len, dim=0)
attn_output = flash_attn_varlen_func(
q_unpad,
k_unpad,
@@ -424,7 +423,6 @@ def llama_varlen_attn_forward_legacy(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
@@ -554,7 +552,6 @@ def llama_varlen_attn_forward(
message_hub = MessageHub.get_instance('varlen_attn_args')
rank = dist.get_rank()
cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
- # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
assert is_training == (cumulative_len is not None)
From 957abf63338216ef6144989873cc164219799fbe Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Wed, 27 Mar 2024 11:43:14 +0800
Subject: [PATCH 4/9] bump version to 0.1.16 (#520)
---
xtuner/version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/xtuner/version.py b/xtuner/version.py
index 11029f49f..ae73ce92a 100644
--- a/xtuner/version.py
+++ b/xtuner/version.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-__version__ = '0.1.16.dev0'
+__version__ = '0.1.16'
short_version = __version__
From 2c0fa5acc75567a5c56267b5c03f1123717f3fae Mon Sep 17 00:00:00 2001
From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com>
Date: Wed, 27 Mar 2024 11:43:49 +0800
Subject: [PATCH 5/9] [Improve] Add `generation_kwargs` for `EvaluateChatHook`
(#501)
* update
* update
---
xtuner/engine/hooks/evaluate_chat_hook.py | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/xtuner/engine/hooks/evaluate_chat_hook.py b/xtuner/engine/hooks/evaluate_chat_hook.py
index efa1bc69f..8e6a86822 100644
--- a/xtuner/engine/hooks/evaluate_chat_hook.py
+++ b/xtuner/engine/hooks/evaluate_chat_hook.py
@@ -29,7 +29,8 @@ def __init__(self,
every_n_iters=None,
max_new_tokens=600,
stop_word=None,
- stop_words=[]):
+ stop_words=[],
+ generation_kwargs={}):
self.evaluation_inputs = evaluation_inputs
if isinstance(self.evaluation_inputs, str):
self.evaluation_inputs = [self.evaluation_inputs]
@@ -69,8 +70,9 @@ def __init__(self,
if image_processor is not None:
self.image_processor = BUILDER.build(image_processor)
self.stop_criteria = StoppingCriteriaList()
+
# default generation config
- self.gen_config = GenerationConfig(
+ default_generation_kwargs = dict(
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.1,
@@ -79,8 +81,10 @@ def __init__(self,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
if self.tokenizer.pad_token_id is not None else
- self.tokenizer.eos_token_id,
- )
+ self.tokenizer.eos_token_id)
+ default_generation_kwargs.update(generation_kwargs)
+ self.gen_config = GenerationConfig(**default_generation_kwargs)
+
self.stop_criteria = StoppingCriteriaList()
for word in stop_words:
self.stop_criteria.append(
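With `generation_kwargs` exposed, a config can override any of the default sampling settings before the `GenerationConfig` is built. An illustrative hook entry follows; the values and the surrounding fields are placeholders, not taken from the patch:

```python
custom_hooks = [
    dict(
        type='EvaluateChatHook',
        evaluation_inputs=['Give three tips for staying healthy.'],
        every_n_iters=500,
        max_new_tokens=600,
        # merged on top of the hook's defaults (temperature, top_k, ...)
        generation_kwargs=dict(do_sample=False, repetition_penalty=1.1),
        # tokenizer, prompt_template, etc. omitted
    )
]
```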
From 62e7d80d9761c66446bcc46180aef936c101becb Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Wed, 27 Mar 2024 19:09:32 +0800
Subject: [PATCH 6/9] [Bugs] Fix bugs when training in non-distributed env
(#522)
fix bugs when training in non-distributed env
---
xtuner/parallel/sequence/data_collate.py | 3 +++
xtuner/parallel/sequence/setup_distributed.py | 26 ++++++++++++++-----
2 files changed, 22 insertions(+), 7 deletions(-)
diff --git a/xtuner/parallel/sequence/data_collate.py b/xtuner/parallel/sequence/data_collate.py
index f61b481b9..15b242d73 100644
--- a/xtuner/parallel/sequence/data_collate.py
+++ b/xtuner/parallel/sequence/data_collate.py
@@ -59,6 +59,9 @@ def pad_for_sequence_parallel(tokens,
def split_for_sequence_parallel(tokens, labels=None, position_ids=None):
seq_parallel_world_size = get_sequence_parallel_world_size()
+ if seq_parallel_world_size == 1:
+ return tokens, labels, position_ids
+
seq_parallel_world_rank = get_sequence_parallel_rank()
seq_len = tokens.size(1)
assert seq_len % seq_parallel_world_size == 0
diff --git a/xtuner/parallel/sequence/setup_distributed.py b/xtuner/parallel/sequence/setup_distributed.py
index ea207bf10..9eb159e66 100644
--- a/xtuner/parallel/sequence/setup_distributed.py
+++ b/xtuner/parallel/sequence/setup_distributed.py
@@ -59,8 +59,11 @@ def get_sequence_parallel_world_size():
global _SEQUENCE_PARALLEL_WORLD_SIZE
if _SEQUENCE_PARALLEL_WORLD_SIZE is not None:
return _SEQUENCE_PARALLEL_WORLD_SIZE
- _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
- group=get_sequence_parallel_group())
+ if not dist.is_initialized():
+ _SEQUENCE_PARALLEL_WORLD_SIZE = 1
+ else:
+ _SEQUENCE_PARALLEL_WORLD_SIZE = dist.get_world_size(
+ group=get_sequence_parallel_group())
return _SEQUENCE_PARALLEL_WORLD_SIZE
@@ -69,8 +72,11 @@ def get_sequence_parallel_rank():
global _SEQUENCE_PARALLEL_RANK
if _SEQUENCE_PARALLEL_RANK is not None:
return _SEQUENCE_PARALLEL_RANK
- _SEQUENCE_PARALLEL_RANK = dist.get_rank(
- group=get_sequence_parallel_group())
+ if not dist.is_initialized():
+ _SEQUENCE_PARALLEL_RANK = 0
+ else:
+ _SEQUENCE_PARALLEL_RANK = dist.get_rank(
+ group=get_sequence_parallel_group())
return _SEQUENCE_PARALLEL_RANK
@@ -86,8 +92,11 @@ def get_data_parallel_world_size():
global _DATA_PARALLEL_WORLD_SIZE
if _DATA_PARALLEL_WORLD_SIZE is not None:
return _DATA_PARALLEL_WORLD_SIZE
- _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size(
- group=get_data_parallel_group())
+ if not dist.is_initialized():
+ _DATA_PARALLEL_WORLD_SIZE = 1
+ else:
+ _DATA_PARALLEL_WORLD_SIZE = dist.get_world_size(
+ group=get_data_parallel_group())
return _DATA_PARALLEL_WORLD_SIZE
@@ -96,5 +105,8 @@ def get_data_parallel_rank():
global _DATA_PARALLEL_RANK
if _DATA_PARALLEL_RANK is not None:
return _DATA_PARALLEL_RANK
- _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group())
+ if not dist.is_initialized():
+ _DATA_PARALLEL_RANK = 0
+ else:
+ _DATA_PARALLEL_RANK = dist.get_rank(group=get_data_parallel_group())
return _DATA_PARALLEL_RANK
From 1dd5cbd367f34257d7f4e53910df0399005e5247 Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Thu, 28 Mar 2024 18:30:26 +0800
Subject: [PATCH 7/9] [Fix] Support transformers>=4.38 and require
transformers>=4.36.0 (#494)
* support transformers>=4.38 and require transformers>=4.36.0
* add flash attn2 to config if SUPPORT_FLASH2
* fix lint
* fix comments
* fix lint
* dispatch config if flash_attn or torch.nn.functional.scaled_dot_product_attention is supported
* do not dispatch attn forward if using scaled_dot_product_attention and require flash_attn installed to use sequence parallel
* dispatch config in llava model if flash_attn or torch.nn.functional.scaled_dot_product_attention is supported
---
requirements/runtime.txt | 6 +-
xtuner/model/llava.py | 66 ++-
xtuner/model/modules/dispatch/__init__.py | 10 +-
xtuner/model/modules/dispatch/internlm2.py | 33 +-
xtuner/model/modules/dispatch/llama.py | 513 +++++++++------------
xtuner/model/sft.py | 66 ++-
xtuner/tools/train.py | 5 +
7 files changed, 338 insertions(+), 361 deletions(-)
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 5ccada772..f531754b3 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -18,8 +18,6 @@ tiktoken
# limit pytorch version <= 2.1.2 as there may be some bugs in triton 2.2
torch<=2.1.2
torchvision<=0.16.2
-# Minimum 4.34.0 to support added_tokens_decoder of tokenizer
-# Exclude 4.34.1, 4.35.0, 4.35.1, 4.35.2 to avoid BC-break,
-# see https://github.com/huggingface/transformers/pull/27020, https://github.com/huggingface/transformers/pull/27073
-transformers>=4.34.0,!=4.34.1,!=4.35.0,!=4.35.1,!=4.35.2
+# Minimum 4.36.0 to support `Cache` data structure used by KV Cache
+transformers>=4.36.0
transformers_stream_generator
diff --git a/xtuner/model/llava.py b/xtuner/model/llava.py
index d7e39a804..19b427a75 100644
--- a/xtuner/model/llava.py
+++ b/xtuner/model/llava.py
@@ -1,13 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import math
from collections import OrderedDict
+import torch
import torch.nn as nn
from mmengine.config import Config, ConfigDict
from mmengine.model import BaseModel
from peft import get_peft_model, prepare_model_for_kbit_training
+from transformers import AutoConfig
from xtuner.registry import BUILDER
from .modules import ProjectorConfig, ProjectorModel, dispatch_modules
+from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, guess_load_checkpoint,
make_inputs_require_grad,
@@ -26,11 +30,15 @@ def __init__(self,
projector_depth=2,
llm_lora=None,
visual_encoder_lora=None,
- use_activation_checkpointing=True):
+ use_activation_checkpointing=True,
+ max_position_embeddings=None):
super().__init__()
self.freeze_llm = freeze_llm
self.freeze_visual_encoder = freeze_visual_encoder
with LoadWoInit():
+ if isinstance(llm, dict):
+ llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
+
self.llm = self._build_from_cfg_or_module(llm)
self.visual_encoder = self._build_from_cfg_or_module(
visual_encoder)
@@ -157,6 +165,62 @@ def state_dict(self, *args, **kwargs):
for k, v in state_dict.items() if 'projector.' in k})
return to_return
+ @staticmethod
+ def _prepare_for_long_context_training(cfg, llm_cfg,
+ max_position_embeddings):
+
+ orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
+ if orig_rope_scaling is None:
+ orig_rope_scaling = {'factor': 1}
+
+ orig_rope_scaling_factor = orig_rope_scaling[
+ 'factor'] if 'factor' in orig_rope_scaling.keys() else 1
+ orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
+ if orig_ctx_len:
+ orig_ctx_len *= orig_rope_scaling_factor
+ if max_position_embeddings > orig_ctx_len:
+ scaling_factor = float(
+ math.ceil(max_position_embeddings / orig_ctx_len))
+ llm_cfg.rope_scaling = {
+ 'type': 'linear',
+ 'factor': scaling_factor
+ }
+
+ # hardcode for internlm2
+ llm_cfg.attn_implementation = 'flash_attention_2'
+ cfg.config = llm_cfg
+
+ return cfg, llm_cfg
+
+ @staticmethod
+ def _prepare_for_flash_attn(cfg, llm_cfg):
+ cls_name = type(llm_cfg).__name__
+        SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
+                             'MixtralConfig', 'Qwen2Config',
+                             'Starcoder2Config')
+        SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
+                               'MistralConfig', 'MixtralConfig', 'Qwen2Config',
+                               'Starcoder2Config')
+
+ if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+ cfg.torch_dtype = torch.bfloat16 \
+ if torch.cuda.is_bf16_supported() else torch.float16
+ cfg.attn_implementation = 'flash_attention_2'
+ elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
+ cfg.attn_implementation = 'sdpa'
+
+ return cfg, llm_cfg
+
+ def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
+ pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
+ llm_cfg = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=True)
+ cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
+ if max_position_embeddings is not None:
+ cfg, llm_cfg = self._prepare_for_long_context_training(
+ cfg, llm_cfg, max_position_embeddings)
+ return cfg
+
def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
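A condensed sketch of the choice made in `_prepare_for_flash_attn` (standalone; the real method additionally checks the concrete config class against the two allow-lists):

```python
import torch


def pick_attn_backend(support_flash2: bool, support_sdpa: bool):
    """Return (torch_dtype, attn_implementation) the way the model cfg is patched."""
    if support_flash2:
        dtype = (torch.bfloat16
                 if torch.cuda.is_bf16_supported() else torch.float16)
        return dtype, 'flash_attention_2'
    if support_sdpa:
        return None, 'sdpa'        # dtype left untouched
    return None, None              # fall back to the default eager attention


print(pick_attn_backend(False, True))   # (None, 'sdpa')
```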
diff --git a/xtuner/model/modules/dispatch/__init__.py b/xtuner/model/modules/dispatch/__init__.py
index 6fbe37fb6..ab104a7dc 100644
--- a/xtuner/model/modules/dispatch/__init__.py
+++ b/xtuner/model/modules/dispatch/__init__.py
@@ -13,7 +13,7 @@
from .yi import yi_attn_forward
IS_LOW_VERSION_TRANSFORMERS = digit_version(
- transformers.__version__) < digit_version('4.36')
+ transformers.__version__) < digit_version('4.38')
SUPPORT_FLASH1 = digit_version(torch.__version__) >= digit_version('2.0.0')
SUPPORT_FLASH2 = False
@@ -48,7 +48,7 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
if use_varlen_attn:
assert SUPPORT_FLASH2 and SUPPORT_TRITON, \
'flash_attn and triton is required if you want to use varlen_attn.'
- elif not SUPPORT_FLASH:
+ elif not SUPPORT_FLASH2:
return
from .llama import (llama_attn_forward, llama_attn_forward_legacy,
@@ -57,8 +57,10 @@ def dispatch_llama_attn_forward(model, use_varlen_attn):
print_log(NO_ATTN_WEIGHTS_MSG, 'current', logging.WARNING)
for module in model.modules():
- if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2',
- 'LlamaSdpaAttention'):
+ # Do not need to dispatch if
+ # type(module).__name__ == 'LlamaSdpaAttention', as flash_attn is
+ # required when using sequence parallel
+ if type(module).__name__ in ('LlamaAttention', 'LlamaFlashAttention2'):
if use_varlen_attn:
print_log('dispatch llama varlen attn forward', 'current')
if IS_LOW_VERSION_TRANSFORMERS:
diff --git a/xtuner/model/modules/dispatch/internlm2.py b/xtuner/model/modules/dispatch/internlm2.py
index 9bc77177c..93a43229e 100644
--- a/xtuner/model/modules/dispatch/internlm2.py
+++ b/xtuner/model/modules/dispatch/internlm2.py
@@ -156,19 +156,6 @@ def flash_attn_w_mask(
return attn_output
-@sequence_parallel_wrapper
-def flash_attn1_pytorch(query_states, key_states, value_states, *args,
- **kwargs):
- # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim)
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = F.scaled_dot_product_attention(query_states, key_states,
- value_states, *args, **kwargs)
- attn_output = attn_output.transpose(1, 2)
- return attn_output
-
-
@sequence_parallel_wrapper
def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
max_seqlen):
@@ -251,12 +238,12 @@ def internlm2_attn_forward(
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- # flash attn 2 need (bs, seq_len, nhead, h_dim)
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
if SUPPORT_FLASH2:
+ # flash attn 2 need (bs, seq_len, nhead, h_dim)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
causal = self.is_causal and q_len != 1
if attention_mask is not None:
@@ -276,12 +263,10 @@ def internlm2_attn_forward(
training=self.training)
else:
# use flash attention implemented by pytorch
- attn_output = flash_attn1_pytorch(
- query_states,
- key_states,
- value_states,
- attn_mask=attention_mask,
- training=self.training)
+ # do not support sequence parallel
+ attn_output = F.scaled_dot_product_attention(
+ query_states, key_states, value_states, attn_mask=attention_mask)
+ attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.wo(attn_output)
diff --git a/xtuner/model/modules/dispatch/llama.py b/xtuner/model/modules/dispatch/llama.py
index df17e1b49..c9febf34f 100644
--- a/xtuner/model/modules/dispatch/llama.py
+++ b/xtuner/model/modules/dispatch/llama.py
@@ -4,10 +4,10 @@
import torch
import torch.distributed as dist
-import torch.nn.functional as F
from mmengine import MessageHub
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-from transformers.utils import logging
+from transformers.models.llama.modeling_llama import (apply_rotary_pos_emb,
+ repeat_kv)
+from transformers.utils import is_flash_attn_greater_or_equal_2_10
from xtuner.parallel.sequence import sequence_parallel_wrapper
from .triton_kernels import apply_rotary_emb
@@ -30,34 +30,6 @@ class Cache:
pass
-logger = logging.get_logger(__name__)
-
-
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., :x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2:]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
- """This is the equivalent of torch.repeat_interleave(x, dim=1,
- repeats=n_rep).
-
- The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
- (batch, num_attention_heads, seqlen, head_dim)
- """
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
- if n_rep == 1:
- return hidden_states
- hidden_states = hidden_states[:, :,
- None, :, :].expand(batch,
- num_key_value_heads,
- n_rep, slen, head_dim)
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
- head_dim)
-
-
def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""The hidden states go from (batch, seqlen, num_key_value_heads, head_dim)
to (batch, seqlen, num_attention_heads, head_dim)"""
@@ -114,21 +86,12 @@ def flash_attn_w_mask(
@sequence_parallel_wrapper
-def flash_attn1_pytorch(query_states, key_states, value_states, *args,
- **kwargs):
- # hacky: pytorch flash attn need (bs, n_head, seq_len, h_dim)
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- attn_output = F.scaled_dot_product_attention(query_states, key_states,
- value_states, *args, **kwargs)
- attn_output = attn_output.transpose(1, 2)
- return attn_output
-
-
-@sequence_parallel_wrapper
-def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
- max_seqlen):
+def varlen_flash_attn(query_states,
+ key_states,
+ value_states,
+ cumulative_len,
+ max_seqlen,
+ dropout_rate=0.):
q_unpad, k_unpad, v_unpad = query_states.flatten(0, 1), key_states.flatten(
0, 1), value_states.flatten(0, 1)
attn_output = flash_attn_varlen_func(
@@ -139,7 +102,7 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
cumulative_len,
max_seqlen,
max_seqlen,
- 0,
+ dropout_p=dropout_rate,
return_attn_probs=False,
causal=True,
)
@@ -147,58 +110,29 @@ def varlen_flash_attn(query_states, key_states, value_states, cumulative_len,
return attn_output
-def llama_attn_forward_legacy(
+def llama_attn_forward(
self,
hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
**kwargs,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
- # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501
+):
+ # Modified from https://github.com/huggingface/transformers/blob/66ce9593fdb8e340df546ddd0774eb444f17a12c/src/transformers/models/llama/modeling_llama.py#L422 # noqa:E501
+ output_attentions = False
- if 'padding_mask' in kwargs:
- warnings.warn('Passing `padding_mask` is deprecated and will be '
- 'removed in v4.37. Please make sure use '
- '`attention_mask` instead.`')
bsz, q_len, _ = hidden_states.size()
- if self.config.pretraining_tp > 1:
- key_value_slicing = (
- self.num_key_value_heads * # noqa: W504
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
-
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim,
+ # therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -206,77 +140,90 @@ def llama_attn_forward_legacy(
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
self.head_dim).transpose(1, 2)
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
- cos, sin, position_ids)
+ cos, sin)
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
- past_key_value = (key_states, value_states) if use_cache else None
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models;
+ # cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
- # repeat kv for sequence parallel
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
- # flash attn 2 need (bs, seq_len, nhead, h_dim)
+ assert SUPPORT_FLASH2
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- if SUPPORT_FLASH2:
- causal = self.is_causal and q_len != 1
+ # In PEFT, usually we cast the layer norms in float32 for training
+ # stability reasons therefore the input hidden states gets silently
+ # casted in float32. Hence, we need cast them back in the correct dtype
+ # just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not
+ # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
- if attention_mask is not None:
- attn_output = flash_attn_w_mask(
- query_states,
- key_states,
- value_states,
- attention_mask,
- causal,
- training=self.training)
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
else:
- attn_output = flash_attn_wo_mask(
- query_states,
- key_states,
- value_states,
- causal,
- training=self.training)
+ target_dtype = self.q_proj.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ if is_flash_attn_greater_or_equal_2_10():
+ causal = self.is_causal
+ else:
+ # TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
+ # is bumped to 2.1. For details, please see the comment in
+ # LlamaFlashAttention2 __init__.
+ causal = self.is_causal and q_len != 1
+
+ if attention_mask is not None:
+ attn_output = flash_attn_w_mask(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ causal,
+ dropout_rate,
+ training=self.training)
else:
- # use flash attention implemented by pytorch
- attn_output = flash_attn1_pytorch(
+ attn_output = flash_attn_wo_mask(
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
+ causal,
+ dropout_rate,
training=self.training)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(
- self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.o_proj.weight.split(
- self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
+ if not output_attentions:
+ attn_weights = None
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
-def llama_attn_forward(
+def llama_attn_forward_legacy(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -287,16 +234,11 @@ def llama_attn_forward(
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
- # LlamaFlashAttention2 attention does not support output_attentions
+ # Modified from https://github.com/huggingface/transformers/blob/ced9fd86f55ebb6b656c273f6e23f8ba50652f83/src/transformers/models/llama/modeling_llama.py#L331 # noqa:E501
if 'padding_mask' in kwargs:
warnings.warn(
- 'Passing `padding_mask` is deprecated and will be removed in v4.37'
- ' Please make sure use `attention_mask` instead.`')
-
- # overwrite attention_mask with padding_mask
- attention_mask = kwargs.pop('padding_mask')
-
- output_attentions = False
+ 'Passing `padding_mask` is deprecated and will be removed in '
+ 'v4.37. Please make sure use `attention_mask` instead.`')
bsz, q_len, _ = hidden_states.size()
@@ -304,9 +246,6 @@ def llama_attn_forward(
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
- # Flash attention requires the input to have the shape
- # batch_size x seq_length x head_dim x hidden_dim
- # therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads,
self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -316,11 +255,17 @@ def llama_attn_forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ 'The cache structure has changed since version v4.36. '
+ f'If you are using {self.__class__.__name__} '
+ 'for auto-regressive decoding with k/v caching, '
+ 'please make sure to initialize the attention class '
+ 'with a layer index.')
kv_seq_len += past_key_value.get_usable_length(kv_seq_len,
self.layer_idx)
-
+ assert position_ids is not None
if self.training:
- assert position_ids is not None
cos, sin = self.rotary_emb(
value_states, seq_len=position_ids.max() + 1)
else:
@@ -333,42 +278,38 @@ def llama_attn_forward(
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs)
- # TODO: These transpose are quite inefficient but Flash Attention
- # requires the layout [batch_size, sequence_length, num_heads, head_dim].
- # We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ assert SUPPORT_FLASH2
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
- dropout_rate = self.attention_dropout if self.training else 0.0
-
# In PEFT, usually we cast the layer norms in float32 for training
- # stability reasons, therefore the input hidden states gets silently
+ # stability reasons therefore the input hidden states gets silently
# casted in float32. Hence, we need cast them back in the correct dtype
# just to be sure everything works as expected.
# This might slowdown training & inference so it is recommended to not
# cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
+
input_dtype = query_states.dtype
if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
- if hasattr(self.config, '_pre_quantization_dtype'):
+ elif hasattr(self.config, '_pre_quantization_dtype'):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
- logger.warning_once(
- f'The input hidden states seems to be silently casted in float32, '
- f'this might be related to the fact you have upcasted embedding '
- f'or layer norm layers in float32. We will cast back the input in'
- f' {target_dtype}.')
-
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
- # flash attn
- if not self._flash_attn_uses_top_left_mask:
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ if is_flash_attn_greater_or_equal_2_10():
causal = self.is_causal
else:
# TODO: Remove the `q_len != 1` check once Flash Attention for RoCm
@@ -376,10 +317,6 @@ def llama_attn_forward(
# LlamaFlashAttention2 __init__.
causal = self.is_causal and q_len != 1
- # repeat kv for sequence parallel
- key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
- value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
-
if attention_mask is not None:
attn_output = flash_attn_w_mask(
query_states,
@@ -401,20 +338,21 @@ def llama_attn_forward(
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
+ # Due to the implementation of the PyTorch version of flash attention,
+ # even when the output_attentions flag is set to True, it is not possible
+ # to return the attn_weights.
+ return attn_output, None, past_key_value
-def llama_varlen_attn_forward_legacy(
+def llama_varlen_attn_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
@@ -432,82 +370,70 @@ def llama_varlen_attn_forward_legacy(
'`attention_mask` instead.`')
bsz, q_len, _ = hidden_states.size()
- if self.config.pretraining_tp > 1:
- key_value_slicing = (
- self.num_key_value_heads * # noqa: W504
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
-
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
- query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ query_states = query_states.view(bsz, q_len, self.num_heads,
+ self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
+ self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
- self.head_dim)
+ self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+ cos, sin)
+
+ past_key_value = getattr(self, 'past_key_value', past_key_value)
- kv_seq_len = key_states.shape[-3]
if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
+ # sin and cos are specific to RoPE models;
+ # cache_position needed for the static cache
+ cache_kwargs = {
+ 'sin': sin,
+ 'cos': cos,
+ 'cache_position': cache_position
+ }
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, cache_kwargs)
- if is_training:
- cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states,
- cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- else:
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
- cos, sin = self.rotary_emb(value_states, kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids)
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
- if past_key_value is not None:
- # reuse k, v, self_attention
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ dropout_rate = self.attention_dropout if self.training else 0.0
- past_key_value = (key_states, value_states) if use_cache else None
- query_states = query_states.transpose(1, 2)
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
+ # In PEFT, usually we cast the layer norms in float32 for training
+ # stability reasons therefore the input hidden states gets silently casted
+ # in float32. Hence, we need cast them back in the correct dtype
+ # just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not
+ # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
- # repeat kv for sequence parallel
- key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
- value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
assert SUPPORT_FLASH2
if is_training:
- attn_output = varlen_flash_attn(query_states, key_states, value_states,
- cumulative_len, max_seqlen)
+ attn_output = varlen_flash_attn(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_len,
+ max_seqlen,
+ dropout_rate=dropout_rate)
else:
attn_output = flash_attn_wo_mask(
query_states,
@@ -517,26 +443,12 @@ def llama_varlen_attn_forward_legacy(
training=False)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(
- self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.o_proj.weight.split(
- self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
-
- # Due to the implementation of the PyTorch version of flash attention,
- # even when the output_attentions flag is set to True, it is not possible
- # to return the attn_weights.
return attn_output, None, past_key_value
-def llama_varlen_attn_forward(
+def llama_varlen_attn_forward_legacy(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
@@ -561,38 +473,9 @@ def llama_varlen_attn_forward(
'`attention_mask` instead.`')
bsz, q_len, _ = hidden_states.size()
- if self.config.pretraining_tp > 1:
- key_value_slicing = (
- self.num_key_value_heads * # noqa: W504
- self.head_dim) // self.config.pretraining_tp
- query_slices = self.q_proj.weight.split(
- (self.num_heads * self.head_dim) // self.config.pretraining_tp,
- dim=0)
- key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
- value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
- query_states = [
- F.linear(hidden_states, query_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- query_states = torch.cat(query_states, dim=-1)
-
- key_states = [
- F.linear(hidden_states, key_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- key_states = torch.cat(key_states, dim=-1)
-
- value_states = [
- F.linear(hidden_states, value_slices[i])
- for i in range(self.config.pretraining_tp)
- ]
- value_states = torch.cat(value_states, dim=-1)
-
- else:
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
@@ -614,11 +497,12 @@ def llama_varlen_attn_forward(
if is_training:
cos, sin = self.rotary_emb(value_states, max_seqlen)
- query_states = apply_rotary_emb(query_states,
- cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
- key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
- sin[position_ids].squeeze(0))
+ # position_ids (1, seq_len)
+ # cos, sin (1, seq_len, dim) -> (seq_len, dim)
+ cos = cos[position_ids].squeeze(0)
+ sin = sin[position_ids].squeeze(0)
+ query_states = apply_rotary_emb(query_states, cos, sin)
+ key_states = apply_rotary_emb(key_states, cos, sin)
else:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@@ -640,31 +524,50 @@ def llama_varlen_attn_forward(
key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)
+ dropout_rate = self.attention_dropout if self.training else 0.0
+
+ # In PEFT, usually we cast the layer norms in float32 for training
+ # stability reasons therefore the input hidden states gets silently casted
+ # in float32. Hence, we need cast them back in the correct dtype
+ # just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not
+ # cast the LayerNorms in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, '_pre_quantization_dtype'):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
assert SUPPORT_FLASH2
if is_training:
- attn_output = varlen_flash_attn(query_states, key_states, value_states,
- cumulative_len, max_seqlen)
+ attn_output = varlen_flash_attn(
+ query_states,
+ key_states,
+ value_states,
+ cumulative_len,
+ max_seqlen,
+ dropout_rate=dropout_rate)
else:
attn_output = flash_attn_wo_mask(
query_states,
key_states,
value_states,
causal=True,
+ dropout_rate=dropout_rate,
training=False)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- if self.config.pretraining_tp > 1:
- attn_output = attn_output.split(
- self.hidden_size // self.config.pretraining_tp, dim=2)
- o_proj_slices = self.o_proj.weight.split(
- self.hidden_size // self.config.pretraining_tp, dim=1)
- attn_output = sum([
- F.linear(attn_output[i], o_proj_slices[i])
- for i in range(self.config.pretraining_tp)
- ])
- else:
- attn_output = self.o_proj(attn_output)
+ attn_output = self.o_proj(attn_output)
# Due to the implementation of the PyTorch version of flash attention,
# even when the output_attentions flag is set to True, it is not possible
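The diff above drops the local `rotate_half`/`repeat_kv` copies in favour of the transformers import, while the varlen paths keep using `repeat_kv_bshd`, whose body is not shown in this hunk. As a hedged sketch only (the actual xtuner helper may differ in detail), an equivalent implementation for the flash-attn `(batch, seqlen, heads, dim)` layout could look like this:

```python
import torch


def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Sketch: expand (batch, seqlen, num_kv_heads, head_dim) to
    (batch, seqlen, num_kv_heads * n_rep, head_dim), equivalent to
    torch.repeat_interleave(hidden_states, repeats=n_rep, dim=2)."""
    batch, slen, num_kv_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, None, :].expand(
        batch, slen, num_kv_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, num_kv_heads * n_rep, head_dim)


kv = torch.randn(2, 16, 8, 64)      # 8 key/value heads
print(repeat_kv_bshd(kv, 4).shape)  # torch.Size([2, 16, 32, 64])
```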
diff --git a/xtuner/model/sft.py b/xtuner/model/sft.py
index 7aa0ec63c..e1a29ab8a 100644
--- a/xtuner/model/sft.py
+++ b/xtuner/model/sft.py
@@ -17,7 +17,7 @@
reduce_sequence_parallel_loss)
from xtuner.registry import BUILDER
from .modules import dispatch_modules
-from .modules.dispatch import SUPPORT_FLASH2
+from .modules.dispatch import SUPPORT_FLASH1, SUPPORT_FLASH2
from .utils import (LoadWoInit, find_all_linear_names,
get_peft_model_state_dict, make_inputs_require_grad,
traverse_dict)
@@ -78,8 +78,9 @@ def __init__(self,
max_position_embeddings=None):
super().__init__()
with LoadWoInit():
- self.llm = self._build_from_cfg_or_module(llm,
- max_position_embeddings)
+ if isinstance(llm, dict):
+ llm = self._dispatch_lm_model_cfg(llm, max_position_embeddings)
+ self.llm = self._build_from_cfg_or_module(llm)
if tokenizer is not None:
if isinstance(tokenizer, dict):
@@ -144,48 +145,67 @@ def _prepare_for_lora(self,
def init_weights(self):
pass
- def _prepare_for_long_context_training(self, cfg, max_position_embeddings):
- pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
- config = AutoConfig.from_pretrained(
- pretrained_model_name_or_path, trust_remote_code=True)
+ @staticmethod
+ def _prepare_for_long_context_training(cfg, llm_cfg,
+ max_position_embeddings):
- orig_rope_scaling = getattr(config, 'rope_scaling', None)
+ orig_rope_scaling = getattr(llm_cfg, 'rope_scaling', None)
if orig_rope_scaling is None:
orig_rope_scaling = {'factor': 1}
orig_rope_scaling_factor = orig_rope_scaling[
'factor'] if 'factor' in orig_rope_scaling.keys() else 1
- orig_ctx_len = getattr(config, 'max_position_embeddings', None)
+ orig_ctx_len = getattr(llm_cfg, 'max_position_embeddings', None)
if orig_ctx_len:
orig_ctx_len *= orig_rope_scaling_factor
if max_position_embeddings > orig_ctx_len:
scaling_factor = float(
math.ceil(max_position_embeddings / orig_ctx_len))
- config.rope_scaling = {
+ llm_cfg.rope_scaling = {
'type': 'linear',
'factor': scaling_factor
}
# hardcode for internlm2
- config.attn_implementation = 'flash_attention_2'
-
- cfg.config = config
+ llm_cfg.attn_implementation = 'flash_attention_2'
+ cfg.config = llm_cfg
+
+ return cfg, llm_cfg
+
+ @staticmethod
+ def _prepare_for_flash_attn(cfg, llm_cfg):
+ cls_name = type(llm_cfg).__name__
+ SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
+ 'MixtralConfig', 'Qwen2Config', 'Starcoder2Config')
+ SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
+ 'MistralConfig', 'MixtralConfig', 'Qwen2Config',
+ 'Starcoder2Config')
+
+ if SUPPORT_FLASH2 and cls_name in SUPPORT_FLASH_ATTN2:
+ cfg.torch_dtype = torch.bfloat16 \
+ if torch.cuda.is_bf16_supported() else torch.float16
+ cfg.attn_implementation = 'flash_attention_2'
+ elif SUPPORT_FLASH1 and cls_name in SUPPORT_SDPA_ATTN:
+ cfg.attn_implementation = 'sdpa'
+
+ return cfg, llm_cfg
+
+ def _dispatch_lm_model_cfg(self, cfg, max_position_embeddings=None):
+ pretrained_model_name_or_path = cfg.pretrained_model_name_or_path
+ llm_cfg = AutoConfig.from_pretrained(
+ pretrained_model_name_or_path, trust_remote_code=True)
+ cfg, llm_cfg = self._prepare_for_flash_attn(cfg, llm_cfg)
+ if max_position_embeddings is not None:
+ cfg, llm_cfg = self._prepare_for_long_context_training(
+ cfg, llm_cfg, max_position_embeddings)
return cfg
- def _build_from_cfg_or_module(self,
- cfg_or_mod,
- max_position_embeddings=None):
+ def _build_from_cfg_or_module(self, cfg_or_mod):
if isinstance(cfg_or_mod, nn.Module):
return cfg_or_mod
elif isinstance(cfg_or_mod, dict):
traverse_dict(cfg_or_mod)
- if SUPPORT_FLASH2:
- cfg_or_mod.torch_dtype = torch.bfloat16 \
- if torch.cuda.is_bf16_supported() else torch.float16
- cfg_or_mod.attn_implementation = 'flash_attention_2'
- if max_position_embeddings is not None:
- cfg_or_mod = self._prepare_for_long_context_training(
- cfg_or_mod, max_position_embeddings)
return BUILDER.build(cfg_or_mod)
else:
raise NotImplementedError
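To make the new dispatch flow in `sft.py` concrete: `_dispatch_lm_model_cfg` loads the HF config with `AutoConfig`, then `_prepare_for_flash_attn` picks an attention implementation from the config class name and the locally available kernels. The sketch below is a simplified stand-in rather than the xtuner code itself; `choose_attn_impl` and its boolean arguments are hypothetical names that mirror `SUPPORT_FLASH1`/`SUPPORT_FLASH2`:

```python
import torch

# Tuples mirror the diff above (duplicate entry removed).
SUPPORT_SDPA_ATTN = ('LlamaConfig', 'GemmaConfig', 'MistralConfig',
                     'MixtralConfig', 'Qwen2Config', 'Starcoder2Config')
SUPPORT_FLASH_ATTN2 = ('InternLM2Config', 'LlamaConfig', 'GemmaConfig',
                       'MistralConfig', 'MixtralConfig', 'Qwen2Config',
                       'Starcoder2Config')


def choose_attn_impl(cls_name, support_flash1, support_flash2):
    """Return (attn_implementation, torch_dtype) for a given HF config class."""
    if support_flash2 and cls_name in SUPPORT_FLASH_ATTN2:
        dtype = torch.bfloat16 if (torch.cuda.is_available()
                                   and torch.cuda.is_bf16_supported()) \
            else torch.float16
        return 'flash_attention_2', dtype
    if support_flash1 and cls_name in SUPPORT_SDPA_ATTN:
        return 'sdpa', None
    return None, None  # fall back to the default eager attention


print(choose_attn_impl('LlamaConfig', support_flash1=True, support_flash2=True))
```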
diff --git a/xtuner/tools/train.py b/xtuner/tools/train.py
index 7acbbf21f..23e3d2a3f 100644
--- a/xtuner/tools/train.py
+++ b/xtuner/tools/train.py
@@ -19,6 +19,7 @@
from xtuner.configs import cfgs_name_path
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.model.modules import dispatch_modules
+from xtuner.model.modules.dispatch import SUPPORT_FLASH2
from xtuner.model.utils import LoadWoInit, find_all_linear_names, traverse_dict
from xtuner.registry import BUILDER, MAP_FUNC
from xtuner.tools.utils import (auto_dtype_of_deepspeed_config,
@@ -100,6 +101,10 @@ def check_cfg(cfg):
f'max_length = {max_length} and sequence_parallel = '
f'{sequence_parallel}')
+ if getattr(cfg, 'sequence_parallel_size', 1) > 1:
+ assert SUPPORT_FLASH2, ('`flash_attn` is required if you want to use '
+ 'sequence parallel.')
+
def main():
args = parse_args()
From 520ce99965cae6c2badb0b0b4d6eecc45b92a29d Mon Sep 17 00:00:00 2001
From: whcao <41630003+HIT-cwh@users.noreply.github.com>
Date: Fri, 29 Mar 2024 16:13:47 +0800
Subject: [PATCH 8/9] [Fix] Fix throughput hook (#527)
fix throughput hook
---
xtuner/engine/hooks/throughput_hook.py | 102 ++++++++++++++++++++-----
1 file changed, 82 insertions(+), 20 deletions(-)
diff --git a/xtuner/engine/hooks/throughput_hook.py b/xtuner/engine/hooks/throughput_hook.py
index cf31414c0..a07e216fe 100644
--- a/xtuner/engine/hooks/throughput_hook.py
+++ b/xtuner/engine/hooks/throughput_hook.py
@@ -1,11 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import logging
from typing import Optional, Union
import torch
+from mmengine import print_log
from mmengine.hooks import Hook
from mmengine.model.wrappers import is_model_wrapper
from torch.utils._pytree import tree_flatten
+from xtuner.parallel.sequence import get_sequence_parallel_world_size
+
DATA_BATCH = Optional[Union[dict, tuple, list]]
@@ -20,12 +24,39 @@ def __init__(self,
hidden_size=None,
num_layers=None,
vocab_size=None,
- mlp_ratio=None):
+ mlp_ratio=None,
+ is_casual=None):
self.use_activation_checkpointing = use_activation_checkpointing
self.hidden_size = hidden_size
self.num_layers = num_layers
self.vocab_size = vocab_size
self.mlp_ratio = mlp_ratio
+ self.is_casual = is_casual
+
+ @staticmethod
+ def _guess_is_casual_attn(model):
+ for module in model.modules():
+ if hasattr(module, 'is_causal'):
+ return module.is_causal
+ print_log(
+ 'Unable to determine whether causal attention is used; '
+ 'FLOPs will be calculated assuming `is_casual = True`.', 'current')
+ return True
+
+ @staticmethod
+ def _get_batch_size_and_sequence_len(data_batch):
+ data_list, _ = tree_flatten(data_batch)
+ for data in data_list:
+ if isinstance(data, torch.Tensor):
+ return data.size(0), data.size(1)
+ raise RuntimeError('No tensor found in the batch')
+
+ @staticmethod
+ def _guess_use_activation_checkpointing(model):
+ for module in model.modules():
+ if hasattr(module, 'gradient_checkpointing'):
+ return module.gradient_checkpointing
+ return False
def before_run(self, runner) -> None:
if is_model_wrapper(runner.model):
@@ -41,20 +72,18 @@ def before_run(self, runner) -> None:
self.mlp_ratio = self.mlp_ratio or (model.config.intermediate_size /
model.config.hidden_size)
self.mlp_ratio *= 1.5 # has gate_proj
- return
+ self.is_casual = self.is_casual if self.is_casual is not None \
+ else self._guess_is_casual_attn(model)
- def _get_batch_size_and_sequence_len(self, data_batch):
- data_list, _ = tree_flatten(data_batch)
- for data in data_list:
- if isinstance(data, torch.Tensor):
- return data.size(0), data.size(1)
- raise RuntimeError('No tensor found in the batch')
+ use_varlen_attn = getattr(model, 'use_varlen_attn', False)
+ if use_varlen_attn:
+ print_log(
+ 'Using variable-length Flash Attention causes an inflation'
+ ' in the FLOPs calculation.',
+ 'current',
+ level=logging.WARNING)
- def _guess_use_activation_checkpointing(self, model):
- for module in model.modules():
- if hasattr(module, 'gradient_checkpointing'):
- return module.gradient_checkpointing
- return False
+ return
def after_train_iter(self,
runner,
@@ -66,17 +95,50 @@ def after_train_iter(self,
batch_size, sequence_len = self._get_batch_size_and_sequence_len(
data_batch)
+ sequence_parallel_size = get_sequence_parallel_world_size()
message_hub = runner.message_hub
iter_time = message_hub.get_scalar('train/time').current()
- flops_per_iteration = (
- (3 + int(self.use_activation_checkpointing)) *
- ((8 + self.mlp_ratio * 4) * batch_size * sequence_len *
- self.hidden_size**2 +
- 4 * batch_size * sequence_len**2 * self.hidden_size)
- ) * self.num_layers + \
- 6 * batch_size * sequence_len * self.hidden_size * self.vocab_size
+ # We consider a language model with l transformer layers,
+ # hidden size h, sequence length s, vocabulary size V, and
+ # training batch size B.
+ # An (m x k) by (k x n) matrix multiplication requires 2 * m * k * n
+ # FLOPs (the factor of 2 accounts for multiplies and adds).
+
+ # Attention Layer:
+ # qkv_proj + o_proj: 8B * s * h^2
+ # attn: 4B * s^2 * h (causal=False) and 2B * s^2 * h (causal=True)
+
+ # MLP Layer:
+ # up_proj + down_proj + gate_proj: 4B * s * h^2 * mlp_ratio
+ # (In Llama mlp_ratio = intermediate_size / hidden_size * 1.5
+ # (has gate_proj))
+
+ # The backward pass requires double the number of FLOPs since we
+ # need to calculate the gradients with respect to both input and
+ # weight tensors. In addition, we are using activation recomputation,
+ # which requires an additional forward pass before the backward pass.
+
+ # Sequence parallel also affects the FLOPs of the attn calculation.
+ # Suppose the sequence length on one GPU is s and the sequence
+ # parallel world size is `sp_size`. The total sequence length in the
+ # attention calculation is then `s * sp_size`, while the number of
+ # attention heads per GPU decreases to `num_heads / sp_size`.
+ # Hence, the attn FLOPs become:
+ # 4B * (s * sp_size)^2 * (h / sp_size) (causal=False) and
+ # 2B * (s * sp_size)^2 * (h / sp_size) (causal=True)
+
+ flops_qkvo_proj = 8 * batch_size * sequence_len * self.hidden_size**2
+ flops_attn = 4 * batch_size * sequence_len**2 * self.hidden_size * \
+ sequence_parallel_size / (int(self.is_casual) + 1)
+ flops_mlp = 4 * self.mlp_ratio * batch_size * sequence_len * \
+ self.hidden_size**2
+ flops_wo_head = (3 + int(self.use_activation_checkpointing)) * (
+ flops_qkvo_proj + flops_attn + flops_mlp) * self.num_layers
+ flops_head = 3 * 2 * batch_size * sequence_len * self.hidden_size * \
+ self.vocab_size
+ flops_per_iteration = flops_wo_head + flops_head
avg_tflops_per_gpu = flops_per_iteration / 1e12 / (iter_time + 1e-12)
tokens_per_sec_per_gpu = batch_size * sequence_len / (
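As a sanity check on the revised estimate, here is a hedged sketch that reproduces the per-iteration FLOPs formula outside the hook. `flops_per_iteration` and the Llama-2-7B-like numbers (hidden size 4096, 32 layers, vocab 32000, 32k tokens split across `sp_size = 4`) are illustrative assumptions, not values taken from the patch:

```python
def flops_per_iteration(batch_size, seq_len, hidden_size, num_layers,
                        vocab_size, mlp_ratio, sp_size=1, causal=True,
                        activation_checkpointing=True):
    # Mirrors the hook: qkv/o projections, attention scores, MLP and LM head.
    flops_qkvo_proj = 8 * batch_size * seq_len * hidden_size**2
    flops_attn = (4 * batch_size * seq_len**2 * hidden_size * sp_size
                  / (int(causal) + 1))
    flops_mlp = 4 * mlp_ratio * batch_size * seq_len * hidden_size**2
    flops_wo_head = (3 + int(activation_checkpointing)) * (
        flops_qkvo_proj + flops_attn + flops_mlp) * num_layers
    flops_head = 3 * 2 * batch_size * seq_len * hidden_size * vocab_size
    return flops_wo_head + flops_head


tflops = flops_per_iteration(
    batch_size=1, seq_len=32768 // 4, hidden_size=4096, num_layers=32,
    vocab_size=32000, mlp_ratio=11008 / 4096 * 1.5, sp_size=4) / 1e12
print(f'~{tflops:.1f} TFLOPs per iteration per GPU (before dividing by time)')
```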
From 7de581c3702822d2f57a583577d95f955e311b1e Mon Sep 17 00:00:00 2001
From: Juicy <95841578+JianxinDong@users.noreply.github.com>
Date: Fri, 29 Mar 2024 16:25:46 +0800
Subject: [PATCH 9/9] Update README.md (#528)
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 0bfa3a02d..352e7215e 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,7 @@
🔍 Explore our models on
[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤗%20Huggingface)](https://huggingface.co/xtuner)
[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🤖%20ModelScope)](https://www.modelscope.cn/organization/xtuner)
+[![Static Badge](https://img.shields.io/badge/-gery?style=social&label=🧰%20OpenXLab)](https://openxlab.org.cn/usercenter/xtuner)
English | [简体中文](README_zh-CN.md)