From 3f0ab8d24424b6accca882b8e0da4606028d9bf9 Mon Sep 17 00:00:00 2001
From: lugimzzz <63761690+lugimzzz@users.noreply.github.com>
Date: Thu, 17 Aug 2023 16:26:25 +0800
Subject: [PATCH] fix (#6756)
---
llm/README.md | 97 ++++++++++++-------
llm/bloom/README.md | 2 +-
llm/chatglm/README.md | 2 +-
llm/chatglm2/README.md | 16 +++
.../gptq_argument.json | 4 +-
.../lora_argument.json | 2 +-
llm/{chatglm_v2 => chatglm2}/pt_argument.json | 2 +-
.../ptq_argument.json | 4 +-
.../sft_argument.json | 6 +-
llm/chatglm_v2/README.md | 16 ---
llm/llama/README.md | 2 +-
llm/opt/sft_argument.json | 4 +-
paddlenlp/trainer/training_args.py | 4 +-
13 files changed, 93 insertions(+), 68 deletions(-)
create mode 100644 llm/chatglm2/README.md
rename llm/{chatglm_v2 => chatglm2}/gptq_argument.json (74%)
rename llm/{chatglm_v2 => chatglm2}/lora_argument.json (93%)
rename llm/{chatglm_v2 => chatglm2}/pt_argument.json (93%)
rename llm/{chatglm_v2 => chatglm2}/ptq_argument.json (81%)
rename llm/{chatglm_v2 => chatglm2}/sft_argument.json (85%)
delete mode 100644 llm/chatglm_v2/README.md
diff --git a/llm/README.md b/llm/README.md
index 8dcfffab599e..77797310d66d 100644
--- a/llm/README.md
+++ b/llm/README.md
@@ -4,8 +4,8 @@
| Model | Pretrain | SFT | LoRA | PrefixTuning | Generation | Quantization |
| --- | --- | --- | --- | --- | --- | --- |
| [LLaMA v1/v2](./llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [ChatGLM-6B v1](./chatglm) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ |
-| [ChatGLM-6B v2](./chatglm_v2) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [ChatGLM-6B](./chatglm) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ |
+| [ChatGLM2-6B](./chatglm2) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Bloom](./bloom) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ |
| [GPT-3](./gpt-3) | ✅ | ✅ | ✅ | WIP | ✅ | WIP |
| [OPT](./opt) | WIP | ✅ | ✅ | WIP| ✅ | WIP |
@@ -34,7 +34,7 @@
[LLaMA v1/v2](./llama)、[GPT-3](./gpt-3) 目录中提供了模型预训练的数据准备和训练细节,后续我们将支持更多的模型预训练。
## 3. 精调
-目前精调统一脚本只支持[LLaMA v1/v2](./llama)、[ChatGLM-6B](./chatglm)、[ChatGLM-6B v2](./chatglm_v2)、[Bloom](./bloom)、[OPT](./opt),其他模型精调使用详见对应模型目录。接下来我们将以**ChatGLM-6B v2**为例介绍如何使用统一脚本进行SFT、LoRA、Prefix Tuning。更多LoRA、Prefix Tuning请参见[PEFT文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/peft.md)。
+目前精调统一脚本只支持[LLaMA v1/v2](./llama)、[ChatGLM-6B](./chatglm)、[ChatGLM2-6B](./chatglm2)、[Bloom](./bloom)、[OPT](./opt),其他模型精调使用详见对应模型目录。接下来我们将以**Llama 2**为例介绍如何使用统一脚本进行SFT、LoRA、Prefix Tuning。更多LoRA、Prefix Tuning请参见[PEFT文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/peft.md)。
### 3.1 精调训练数据格式
@@ -59,8 +59,10 @@ SFT(Supervised Fine-Tuning)依托飞桨提出的[4D混合分布式并行](https:
```
# 张量并行分布式训练(常用)
-# 目前OPT不支持张量并行
-python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./chatglm_v2/sft_argument.json
+python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./llama/sft_argument.json
+
+# 目前ChatGLM2、OPT不支持张量并行,默认使用Sharding策略(Paddle 2.5.1支持Sharding Stage2,Sharding Stage3需要使用Paddle develop版本)
+python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./chatglm2/sft_argument.json
# 张量并行&流水线并行分布式训练(目前仅支持Llama)
python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./llama/sft_pp_argument.json
@@ -81,11 +83,11 @@ PaddleNLP LoRA API支持数据并行、张量并行等多种分布式训练策
```
# 单卡训练
-python finetune_generation.py ./chatglm_v2/lora_argument.json
+python finetune_generation.py ./llama/lora_argument.json
-# 张量并行分布式训练
+# 张量并行分布式训练(ChatGLM2、OPT不支持张量并行)
# 将lora_argument.json中tensor_parallel_degree修改为2
-python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./chatglm_v2/lora_argument.json
+python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./llama/lora_argument.json
```
@@ -100,24 +102,27 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./
PaddleNLP Prefix Tuning API支持数据并行、张量并行等多种分布式训练策略,可以通过控制`tensor_parallel_degree` 调整并行训练策略。
```
# 单卡训练
-python finetune_generation.py ./chatglm_v2/pt_argument.json
+python finetune_generation.py ./llama/pt_argument.json
-# 张量并行分布式训练
+# 张量并行分布式训练(ChatGLM2、OPT不支持张量并行)
# 将pt_argument.json中tensor_parallel_degree修改为2
-python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./chatglm_v2/pt_argument.json
+python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./llama/pt_argument.json
```
### 3.5 精调参数介绍
+ 模型参数(ModelArgument)
-**模型参数(ModelArgument):**
-
-- `model_name_or_path`: 预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。
+- `model_name_or_path`: 预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。每个模型**支持模型权重**详见各模型目录。
- `lora`: 是否开启LoRA微调策略,默认为False。
- `lora_path`: LoRA参数和配置路径,对LoRA参数进行初始化,默认为None。
- `lora_rank`: LoRA算法中rank(秩)的值,默认为8。
- `prefix_tuning`: 是否使用Prefix Tuning策略,默认为False。
- `num_prefix_tokens`: Prefix Tuning策略中Prefix Token数量,默认为128。
-**数据参数(DataArgument):**
+
+
+ 数据参数(DataArgument)
+
+
- `dataset_name_or_path`: 本地数据集目录或内置数据集名称,默认为None。
- `task_name`: 用于选择内置数据集中的具体任务,默认为None。
- `src_length`: 模型输入上下文最大长度,默认为1024。
@@ -126,15 +131,18 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./
- `save_generation_output`: 当`eval_with_do_generation`设为True,是否将生成结果保存在`generated_output.json`文件中,默认为False。
- `intokens`:是否使用InToken数据流(减少Padding冗余计算,大幅提升有效Token计算效率),默认为False。当`eval_with_do_generation`设为True,评估过程不支持InToken数据流。
- `intokens_max_length`: InToken数据流模型训练最大长度,默认为2048。
+
+
-**生成参数(GenerateArgument):**
+ 生成参数(GenerateArgument)
注:以下参数仅在`eval_with_do_generation`为True,调用model.generate()时生效。
- `top_k`: “采样”策略中为 top-k 过滤保留的最高概率标记的数量。默认为1,等价于贪心策略。
- `top_p`:“采样”策略中 top-p 过滤的累积概率。默认为1.0,表示不起作用。
+
-**训练参数(TrainingArguments):**
+ 训练参数(TrainingArguments)
以下仅介绍TrainingArguments部分常用参数,详情请参见[TrainingArguments文档](https://paddlenlp.readthedocs.io/zh/latest/trainer.html)。
@@ -162,57 +170,63 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./
- `tensor_parallel_degree`: 此参数tensor_parallel_degree表示将一层transformer结构的份数,该方法对通信开销较大, 建议 tensor_parallel_degree<=8, 尽量使用机器内部通信。默认为-1,表示不启用张量并行。
- `pipeline_parallel_degree`: 表示划分流水线的大小.(假设该参数为4, 模型12层, 则每一个pp stage 包含3层模型) 默认值-1, 表示不启用流水线并行。
+
+
### 3.6 张量并行参数合并
我们使用张量并行(TP,Tensor Parallelism)训练过程中,为了节省TP参数合并时间往往在中间checkpoint将参数存储为多个TP参数分片,可以使用提供的分片合并参数脚本进行参数合并。
```
python merge_tp_params.py \
- --model_name_or_path ./checkpoints/chatglm_v2_sft_ckpts/checkpoint-100
+ --model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100
```
-**参数:**
+ 脚本参数介绍
- `model_name_or_path`: 必须,本地的TP模型参数路径,默认为None。
- `device`: 运行环境,默认为gpu。
+
### 3.7 LoRA参数合并
为了后续的**压缩**和**静态图推理**方便,我们提供LoRA参数合并脚本,可以将LoRA参数合并到主干模型并保存相应的权重。
```
python merge_lora_params.py \
- --model_name_or_path THUDM/chatglm2-6b \
- --lora_path ./checkpoints/chatglm_v2_lora_ckpts
+ --model_name_or_path meta-llama/Llama-2-7b-chat \
+ --lora_path ./checkpoints/llama_lora_ckpts
```
-**参数:**
+ 脚本参数介绍
+
- `model_name_or_path`: 必须,预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。
- `lora_path`: LoRA参数和配置路径,对LoRA参数进行初始化,默认为None。
- `merge_model_path`: 必须,合并参数后保存路径,默认为None。
- `device`: 运行环境,默认为gpu。
+
## 4. 动态图推理
```
python predict_generation.py \
- --model_name_or_path THUDM/chatglm2-6b \
+ --model_name_or_path meta-llama/Llama-2-7b-chat \
--batch_size 1 \
--data_file ./data/dev.json \
--dtype "float16"
# 加载LoRA参数
python predict_generation.py \
- --model_name_or_path THUDM/chatglm2-6b \
+ --model_name_or_path meta-llama/Llama-2-7b-chat \
--batch_size 1 \
--data_file ./data/dev.json \
- --lora_path ./checkpoints/chatglm_v2_lora_ckpts
+ --lora_path ./checkpoints/llama_lora_ckpts
# 加载Prefix Tuning参数
python predict_generation.py \
- --model_name_or_path THUDM/chatglm2-6b \
+ --model_name_or_path meta-llama/Llama-2-7b-chat \
--batch_size 1 \
--data_file ./data/dev.json \
- --prefix_path ./checkpoints/chatglm_v2_pt_ckpts
+ --prefix_path ./checkpoints/llama_pt_ckpts
```
-**参数:**
+ 脚本参数介绍
+
- `model_name_or_path`: 必须,预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。
- `batch_size`: 批处理大小,默认为8。该参数越大,占用显存越高;该参数越小,占用显存越低。
- `src_length`: 模型输入上下文最大长度,默认为1024。
@@ -228,6 +242,8 @@ python predict_generation.py \
- `dtype`: 模型参数dtype,默认为None。如果没有传入`lora_path`、`prefix_path`则必须传入
- `gpt`: 是否使用GPTForCausalLM模型,默认为False。
+
+
## 5. 服务化部署
### 5.1 环境准备
@@ -243,17 +259,21 @@ python predict_generation.py \
```
python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \
- --model_name_or_path THUDM/chatglm2-6b \
+ --model_name_or_path meta-llama/Llama-2-7b-chat \
--port 8010 \
--flask_port 8011 \
--src_length 1024 \
--dtype "float16"
```
-**参数:**
-其他参数请参见动态图推理中参数。
+ 脚本参数介绍
+
+
- `port`: Gradio UI 服务端口号,默认8011。
- `flask_port`: Flask服务端口号,默认8010。
+- 其他参数请参见动态图推理中参数。
+
+
## 6. 量化
@@ -271,18 +291,19 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \
### 6.3 PTQ量化
```
-python finetune_generation.py ./chatglm_v2/ptq_argument.json
+python finetune_generation.py ./llama/ptq_argument.json
```
### 6.4 GPTQ量化
```
-python finetune_generation.py ./chatglm_v2/gptq_argument.json
+python finetune_generation.py ./llama/gptq_argument.json
```
### 6.5 量化参数介绍
-**生成参数(QuantArgument):**
+ 量化参数(QuantArgument)
+
- `quant_type`: PTQ,QAT量化类型,默认为A8W8。支持A8W8,WINT4,WINT8:A8W8指对激活(输入)进行INT8量化,对模型权重进行INT8量化;WINT4指仅对模型权重进行INT4量化,后续使用WeightOnly进行推理;WINT8指仅对模型权重进行INT8量化,后续使用WeightOnly进行推理。
- `do_ptq`: 是否进行PTQ量化,默认为False。
- `ptq_step`: PTQ量化步数,也即模型前向次数,默认为32。
@@ -299,12 +320,16 @@ python finetune_generation.py ./chatglm_v2/gptq_argument.json
- `smooth_search_piece`: 使用分段搜索功能时,是否搜索分段数量,默认为False。设为True时,`smooth_k_piece`建议设为6,搜索分段数量耗时较长,如需加速Smooth过程建议关闭。
- `do_gptq`: 是否进行GPTQ量化,GPTQ对模型进行WINT4量化,相比于普通PTQ量化精度更高,量化时间较长。默认为False。
- `gptq_step`: GPTQ量化步数,也即模型前向次数,默认为8。
+
-**其他参数:**
+
+ 其他参数
- `per_device_train_batch_size`: 量化前向批大小,默认为8。量化过程只有模型前向,相比于普通训练需要显存较少。
-其他参数详见精调参数介绍。
+- 更多参数详见精调参数介绍。
+
+
## 7. 静态图推理
diff --git a/llm/bloom/README.md b/llm/bloom/README.md
index b7976e2236bb..346d7b826b84 100644
--- a/llm/bloom/README.md
+++ b/llm/bloom/README.md
@@ -1,4 +1,4 @@
-# Bloom
+# BLOOM
## 1.模型介绍
diff --git a/llm/chatglm/README.md b/llm/chatglm/README.md
index a96a1b062e03..c8cfb4f8b28b 100644
--- a/llm/chatglm/README.md
+++ b/llm/chatglm/README.md
@@ -1,4 +1,4 @@
-# ChatGLM
+# ChatGLM-6B
## 1. 模型介绍
diff --git a/llm/chatglm2/README.md b/llm/chatglm2/README.md
new file mode 100644
index 000000000000..9c07a20f9e1a
--- /dev/null
+++ b/llm/chatglm2/README.md
@@ -0,0 +1,16 @@
+# ChatGLM2-6B
+
+## 1. 模型介绍
+
+ChatGLM2-6B 是开源中英双语对话模型 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM2-6B 引入了[FlashAttention](https://github.com/HazyResearch/flash-attention)和[Multi-Query Attention](https://arxiv.org/abs/1911.02150v1)等新特性。更详细的模型介绍见[ChatGLM2-6B GitHub](https://github.com/THUDM/ChatGLM2-6B)
+
+**支持模型权重:**
+
+| Model |
+|----------------------------------|
+| THUDM/chatglm2-6b |
+
+## 2. 模型协议
+
+
+ChatGLM2-6B 模型的权重的使用需要遵循[License](../../paddlenlp/transformers/chatglm_v2/LICENSE)。
diff --git a/llm/chatglm_v2/gptq_argument.json b/llm/chatglm2/gptq_argument.json
similarity index 74%
rename from llm/chatglm_v2/gptq_argument.json
rename to llm/chatglm2/gptq_argument.json
index e5ccaf776ad5..4c8aff64c9d2 100644
--- a/llm/chatglm_v2/gptq_argument.json
+++ b/llm/chatglm2/gptq_argument.json
@@ -1,5 +1,5 @@
{
- "model_name_or_path": "./checkpoints/chatglm_v2_sft_ckpts",
+ "model_name_or_path": "./checkpoints/chatglm2_sft_ckpts",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
@@ -8,7 +8,7 @@
"fp16": true,
"fp16_opt_level": "O2",
"dataset_name_or_path": "./data",
- "output_dir": "./checkpoints/chatglm_v2_gptq_ckpts",
+ "output_dir": "./checkpoints/chatglm2_gptq_ckpts",
"do_eval": true,
"eval_with_do_generation": false,
"do_gptq": true,
diff --git a/llm/chatglm_v2/lora_argument.json b/llm/chatglm2/lora_argument.json
similarity index 93%
rename from llm/chatglm_v2/lora_argument.json
rename to llm/chatglm2/lora_argument.json
index 1d1389a9081b..4af5769909ba 100644
--- a/llm/chatglm_v2/lora_argument.json
+++ b/llm/chatglm2/lora_argument.json
@@ -1,7 +1,7 @@
{
"model_name_or_path": "THUDM/chatglm2-6b",
"dataset_name_or_path": "./data",
- "output_dir": "./checkpoints/chatglm_v2_lora_ckpts",
+ "output_dir": "./checkpoints/chatglm2_lora_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
diff --git a/llm/chatglm_v2/pt_argument.json b/llm/chatglm2/pt_argument.json
similarity index 93%
rename from llm/chatglm_v2/pt_argument.json
rename to llm/chatglm2/pt_argument.json
index 7a95868a1657..60ff7b9c4082 100644
--- a/llm/chatglm_v2/pt_argument.json
+++ b/llm/chatglm2/pt_argument.json
@@ -1,7 +1,7 @@
{
"model_name_or_path": "THUDM/chatglm2-6b",
"dataset_name_or_path": "./data",
- "output_dir": "./checkpoints/chatglm_v2_pt_ckpts",
+ "output_dir": "./checkpoints/chatglm2_pt_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
diff --git a/llm/chatglm_v2/ptq_argument.json b/llm/chatglm2/ptq_argument.json
similarity index 81%
rename from llm/chatglm_v2/ptq_argument.json
rename to llm/chatglm2/ptq_argument.json
index f6b5323ffa24..98a759517fd0 100644
--- a/llm/chatglm_v2/ptq_argument.json
+++ b/llm/chatglm2/ptq_argument.json
@@ -1,5 +1,5 @@
{
- "model_name_or_path": "./checkpoints/chatglm_v2_sft_ckpts",
+ "model_name_or_path": "./checkpoints/chatglm2_sft_ckpts",
"per_device_train_batch_size": 8,
"per_device_eval_batch_size": 8,
"eval_accumulation_steps":16,
@@ -8,7 +8,7 @@
"fp16": true,
"fp16_opt_level": "O2",
"dataset_name_or_path": "./data",
- "output_dir": "./checkpoints/chatglm_v2_ptq_ckpts",
+ "output_dir": "./checkpoints/chatglm2_ptq_ckpts",
"do_eval": true,
"eval_with_do_generation": false,
"do_ptq": true,
diff --git a/llm/chatglm_v2/sft_argument.json b/llm/chatglm2/sft_argument.json
similarity index 85%
rename from llm/chatglm_v2/sft_argument.json
rename to llm/chatglm2/sft_argument.json
index 6cd164c739c8..048a00b871b8 100644
--- a/llm/chatglm_v2/sft_argument.json
+++ b/llm/chatglm2/sft_argument.json
@@ -1,7 +1,7 @@
{
"model_name_or_path": "THUDM/chatglm2-6b",
"dataset_name_or_path": "./data",
- "output_dir": "./checkpoints/chatglm_v2_sft_ckpts",
+ "output_dir": "./checkpoints/chatglm2_sft_ckpts",
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 4,
"per_device_eval_batch_size": 8,
@@ -24,6 +24,6 @@
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
- "tensor_parallel_degree": 4,
- "pipeline_parallel_degree": 1
+ "sharding_parallel_degree": 4,
+ "sharding": "stage3"
}
\ No newline at end of file
diff --git a/llm/chatglm_v2/README.md b/llm/chatglm_v2/README.md
deleted file mode 100644
index 829ccc23d9f3..000000000000
--- a/llm/chatglm_v2/README.md
+++ /dev/null
@@ -1,16 +0,0 @@
-# ChatGLM
-
-## 1. 模型介绍
-
-ChatGLM**2**-6B 是开源中英双语对话模型 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B 引入了[FlashAttention](https://github.com/HazyResearch/flash-attention)和[Multi-Query Attention]等新特性。更详细的模型介绍见[ChatGLM2-6B GitHub](https://github.com/THUDM/ChatGLM2-6B)
-
-**支持模型权重:**
-
-| Model |
-|----------------------------------|
-| THUDM/chatglm2-6b |
-
-## 2. 模型协议
-
-
-ChatGLM2-6B 模型的权重的使用需要遵循[License](../../paddlenlp/transformers/chatglm_v2/LICENSE)。
diff --git a/llm/llama/README.md b/llm/llama/README.md
index c4483d77de99..09eb07fafac2 100644
--- a/llm/llama/README.md
+++ b/llm/llama/README.md
@@ -37,7 +37,7 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat")
## 2. 模型协议
-Llama 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/LICENSE)。
+LLaMA 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/LICENSE)。
Llama2 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/Llama2.LICENSE)。
diff --git a/llm/opt/sft_argument.json b/llm/opt/sft_argument.json
index 10dfdc459d1d..5ae0ec047545 100644
--- a/llm/opt/sft_argument.json
+++ b/llm/opt/sft_argument.json
@@ -24,6 +24,6 @@
"metric_for_best_model": "accuracy",
"recompute": true,
"save_total_limit": 1,
- "tensor_parallel_degree": 1,
- "pipeline_parallel_degree": 1
+ "sharding_parallel_degree": 4,
+ "sharding": "stage2"
}
\ No newline at end of file
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 39dc07a7512a..459d1102377f 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -809,8 +809,8 @@ def __post_init__(self):
if sharding_parallel_degree > 1 and ShardingOption.SHARD_OP in self.sharding:
assert self.data_parallel_degree == 1, "sharding stage1 can not coexist with dp for now"
- if ShardingOption.OFFLOAD in self.sharding or ShardingOption.FULL_SHARD in self.sharding:
- warnings.warn("`offload` and `stage3` is not supported NOW!")
+ if ShardingOption.OFFLOAD in self.sharding:
+ warnings.warn("`offload` is not supported NOW!")
if pipeline_parallel_degree > 1:
if ShardingOption.FULL_SHARD in self.sharding or ShardingOption.SHARD_GRAD_OP in self.sharding: