From 3f0ab8d24424b6accca882b8e0da4606028d9bf9 Mon Sep 17 00:00:00 2001 From: lugimzzz <63761690+lugimzzz@users.noreply.github.com> Date: Thu, 17 Aug 2023 16:26:25 +0800 Subject: [PATCH] fix (#6756) --- llm/README.md | 97 ++++++++++++------- llm/bloom/README.md | 2 +- llm/chatglm/README.md | 2 +- llm/chatglm2/README.md | 16 +++ .../gptq_argument.json | 4 +- .../lora_argument.json | 2 +- llm/{chatglm_v2 => chatglm2}/pt_argument.json | 2 +- .../ptq_argument.json | 4 +- .../sft_argument.json | 6 +- llm/chatglm_v2/README.md | 16 --- llm/llama/README.md | 2 +- llm/opt/sft_argument.json | 4 +- paddlenlp/trainer/training_args.py | 4 +- 13 files changed, 93 insertions(+), 68 deletions(-) create mode 100644 llm/chatglm2/README.md rename llm/{chatglm_v2 => chatglm2}/gptq_argument.json (74%) rename llm/{chatglm_v2 => chatglm2}/lora_argument.json (93%) rename llm/{chatglm_v2 => chatglm2}/pt_argument.json (93%) rename llm/{chatglm_v2 => chatglm2}/ptq_argument.json (81%) rename llm/{chatglm_v2 => chatglm2}/sft_argument.json (85%) delete mode 100644 llm/chatglm_v2/README.md diff --git a/llm/README.md b/llm/README.md index 8dcfffab599e..77797310d66d 100644 --- a/llm/README.md +++ b/llm/README.md @@ -4,8 +4,8 @@ | Model | Pretrain | SFT | LoRA | PrefixTuning | Generation | Quantization | | --- | --- | --- | --- | --- | --- | --- | | [LLaMA v1/v2](./llama) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | -| [ChatGLM-6B v1](./chatglm) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ | -| [ChatGLM-6B v2](./chatglm_v2) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ | +| [ChatGLM-6B](./chatglm) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ | +| [ChatGLM2-6B](./chatglm2) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ | | [Bloom](./bloom) | N/A | ✅ | ✅ | ✅ | ✅ | ✅ | | [GPT-3](./gpt-3) | ✅ | ✅ | ✅ | WIP | ✅ | WIP | | [OPT](./opt) | WIP | ✅ | ✅ | WIP| ✅ | WIP | @@ -34,7 +34,7 @@ [LLaMA v1/v2](./llama)、[GPT-3](./gpt-3) 目录中提供了模型预训练的数据准备和训练细节,后续我们将支持更多的模型预训练。 ## 3. 精调 -目前精调统一脚本只支持[LLaMA v1/v2](./llama)、[ChatGLM-6B](./chatglm)、[ChatGLM-6B v2](./chatglm_v2)、[Bloom](./bloom)、[OPT](./opt),其他模型精调使用详见对应模型目录。接下来我们将以**ChatGLM-6B v2**为例介绍如何使用统一脚本进行SFT、LoRA、Prefix Tuning。更多LoRA、Prefix Tuning请参见[PEFT文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/peft.md)。 +目前精调统一脚本只支持[LLaMA v1/v2](./llama)、[ChatGLM-6B](./chatglm)、[ChatGLM2-6B](./chatglm2)、[Bloom](./bloom)、[OPT](./opt),其他模型精调使用详见对应模型目录。接下来我们将以**Llama 2**为例介绍如何使用统一脚本进行SFT、LoRA、Prefix Tuning。更多LoRA、Prefix Tuning请参见[PEFT文档](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/docs/peft.md)。 ### 3.1 精调训练数据格式 @@ -59,8 +59,10 @@ SFT(Supervised Fine-Tuning)依托飞桨提出的[4D混合分布式并行](https: ``` # 张量并行分布式训练(常用) -# 目前OPT不支持张量并行 -python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./chatglm_v2/sft_argument.json +python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./llama/sft_argument.json + +# 目前ChatGLM2、OPT不支持张量并行,默认使用Sharding策略(Paddle 2.5.1支持Sharding Stage2,Sharding Stage3需要使用Paddle develop版本) +python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./chatglm2/sft_argument.json # 张量并行&流水线并行分布式训练(目前仅支持Llama) python -u -m paddle.distributed.launch --gpus "0,1,2,3" finetune_generation.py ./llama/sft_pp_argument.json @@ -81,11 +83,11 @@ PaddleNLP LoRA API支持数据并行、张量并行等多种分布式训练策 ``` # 单卡训练 -python finetune_generation.py ./chatglm_v2/lora_argument.json +python finetune_generation.py ./llama/lora_argument.json -# 张量并行分布式训练 +# 张量并行分布式训练(ChatGLM2、OPT不支持张量并行) # 将lora_argument.json中tensor_parallel_degree修改为2 -python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./chatglm_v2/lora_argument.json +python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./llama/lora_argument.json ``` @@ -100,24 +102,27 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./ PaddleNLP Prefix Tuning API支持数据并行、张量并行等多种分布式训练策略,可以通过控制`tensor_parallel_degree` 调整并行训练策略。 ``` # 单卡训练 -python finetune_generation.py ./chatglm_v2/pt_argument.json +python finetune_generation.py ./llama/pt_argument.json -# 张量并行分布式训练 +# 张量并行分布式训练(ChatGLM2、OPT不支持张量并行) # 将pt_argument.json中tensor_parallel_degree修改为2 -python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./chatglm_v2/pt_argument.json +python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./llama/pt_argument.json ``` ### 3.5 精调参数介绍 +
  模型参数(ModelArgument)
-**模型参数(ModelArgument):** - -- `model_name_or_path`: 预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。 +- `model_name_or_path`: 预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。每个模型**支持模型权重**详见各模型目录。 - `lora`: 是否开启LoRA微调策略,默认为False。 - `lora_path`: LoRA参数和配置路径,对LoRA参数进行初始化,默认为None。 - `lora_rank`: LoRA算法中rank(秩)的值,默认为8。 - `prefix_tuning`: 是否使用Prefix Tuning策略,默认为False。 - `num_prefix_tokens`: Prefix Tuning策略中Prefix Token数量,默认为128。 -**数据参数(DataArgument):** +
+ +
  数据参数(DataArgument)
+ + - `dataset_name_or_path`: 本地数据集目录或内置数据集名称,默认为None。 - `task_name`: 用于选择内置数据集中的具体任务,默认为None。 - `src_length`: 模型输入上下文最大长度,默认为1024。 @@ -126,15 +131,18 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./ - `save_generation_output`: 当`eval_with_do_generation`设为True,是否将生成结果保存在`generated_output.json`文件中,默认为False。 - `intokens`:是否使用InToken数据流(减少Padding冗余计算,大幅提升有效Token计算效率),默认为False。当`eval_with_do_generation`设为True,评估过程不支持InToken数据流。 - `intokens_max_length`: InToken数据流模型训练最大长度,默认为2048。 +
+ -**生成参数(GenerateArgument):** +
  生成参数(GenerateArgument)
注:以下参数仅在`eval_with_do_generation`为True,调用model.generate()时生效。 - `top_k`: “采样”策略中为 top-k 过滤保留的最高概率标记的数量。默认为1,等价于贪心策略。 - `top_p`:“采样”策略中 top-p 过滤的累积概率。默认为1.0,表示不起作用。 +
-**训练参数(TrainingArguments):** +
  训练参数(TrainingArguments)
以下仅介绍TrainingArguments部分常用参数,详情请参见[TrainingArguments文档](https://paddlenlp.readthedocs.io/zh/latest/trainer.html)。 @@ -162,57 +170,63 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./ - `tensor_parallel_degree`: 此参数tensor_parallel_degree表示将一层transformer结构的份数,该方法对通信开销较大, 建议 tensor_parallel_degree<=8, 尽量使用机器内部通信。默认为-1,表示不启用张量并行。 - `pipeline_parallel_degree`: 表示划分流水线的大小.(假设该参数为4, 模型12层, 则每一个pp stage 包含3层模型) 默认值-1, 表示不启用流水线并行。 +
+ ### 3.6 张量并行参数合并 我们使用张量并行(TP,Tensor Parallelism)训练过程中,为了节省TP参数合并时间往往在中间checkpoint将参数存储为多个TP参数分片,可以使用提供的分片合并参数脚本进行参数合并。 ``` python merge_tp_params.py \ - --model_name_or_path ./checkpoints/chatglm_v2_sft_ckpts/checkpoint-100 + --model_name_or_path ./checkpoints/llama_sft_ckpts/checkpoint-100 ``` -**参数:** +
  脚本参数介绍
- `model_name_or_path`: 必须,本地的TP模型参数路径,默认为None。 - `device`: 运行环境,默认为gpu。 +
### 3.7 LoRA参数合并 为了后续的**压缩**和**静态图推理**方便,我们提供LoRA参数合并脚本,可以将LoRA参数合并到主干模型并保存相应的权重。 ``` python merge_lora_params.py \ - --model_name_or_path THUDM/chatglm2-6b \ - --lora_path ./checkpoints/chatglm_v2_lora_ckpts + --model_name_or_path meta-llama/Llama-2-7b-chat \ + --lora_path ./checkpoints/llama_lora_ckpts ``` -**参数:** +
  脚本参数介绍
+ - `model_name_or_path`: 必须,预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。 - `lora_path`: LoRA参数和配置路径,对LoRA参数进行初始化,默认为None。 - `merge_model_path`: 必须,合并参数后保存路径,默认为None。 - `device`: 运行环境,默认为gpu。 +
## 4. 动态图推理 ``` python predict_generation.py \ - --model_name_or_path THUDM/chatglm2-6b \ + --model_name_or_path meta-llama/Llama-2-7b-chat \ --batch_size 1 \ --data_file ./data/dev.json \ --dtype "float16" # 加载LoRA参数 python predict_generation.py \ - --model_name_or_path THUDM/chatglm2-6b \ + --model_name_or_path meta-llama/Llama-2-7b-chat \ --batch_size 1 \ --data_file ./data/dev.json \ - --lora_path ./checkpoints/chatglm_v2_lora_ckpts + --lora_path ./checkpoints/llama_lora_ckpts # 加载Prefix Tuning参数 python predict_generation.py \ - --model_name_or_path THUDM/chatglm2-6b \ + --model_name_or_path meta-llama/Llama-2-7b-chat \ --batch_size 1 \ --data_file ./data/dev.json \ - --prefix_path ./checkpoints/chatglm_v2_pt_ckpts + --prefix_path ./checkpoints/llama_pt_ckpts ``` -**参数:** +
  脚本参数介绍
+ - `model_name_or_path`: 必须,预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。 - `batch_size`: 批处理大小,默认为8。该参数越大,占用显存越高;该参数越小,占用显存越低。 - `src_length`: 模型输入上下文最大长度,默认为1024。 @@ -228,6 +242,8 @@ python predict_generation.py \ - `dtype`: 模型参数dtype,默认为None。如果没有传入`lora_path`、`prefix_path`则必须传入 - `gpt`: 是否使用GPTForCausalLM模型,默认为False。 +
+ ## 5. 服务化部署 ### 5.1 环境准备 @@ -243,17 +259,21 @@ python predict_generation.py \ ``` python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \ - --model_name_or_path THUDM/chatglm2-6b \ + --model_name_or_path meta-llama/Llama-2-7b-chat \ --port 8010 \ --flask_port 8011 \ --src_length 1024 \ --dtype "float16" ``` -**参数:** -其他参数请参见动态图推理中参数。 +
  脚本参数介绍
+ + - `port`: Gradio UI 服务端口号,默认8011。 - `flask_port`: Flask服务端口号,默认8010。 +- 其他参数请参见动态图推理中参数。 + +
## 6. 量化 @@ -271,18 +291,19 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \ ### 6.3 PTQ量化 ``` -python finetune_generation.py ./chatglm_v2/ptq_argument.json +python finetune_generation.py ./llama/ptq_argument.json ``` ### 6.4 GPTQ量化 ``` -python finetune_generation.py ./chatglm_v2/gptq_argument.json +python finetune_generation.py ./llama/gptq_argument.json ``` ### 6.5 量化参数介绍 -**生成参数(QuantArgument):** +
  量化参数(QuantArgument)
+ - `quant_type`: PTQ,QAT量化类型,默认为A8W8。支持A8W8,WINT4,WINT8:A8W8指对激活(输入)进行INT8量化,对模型权重进行INT8量化;WINT4指仅对模型权重进行INT4量化,后续使用WeightOnly进行推理;WINT8指仅对模型权重进行INT8量化,后续使用WeightOnly进行推理。 - `do_ptq`: 是否进行PTQ量化,默认为False。 - `ptq_step`: PTQ量化步数,也即模型前向次数,默认为32。 @@ -299,12 +320,16 @@ python finetune_generation.py ./chatglm_v2/gptq_argument.json - `smooth_search_piece`: 使用分段搜索功能时,是否搜索分段数量,默认为False。设为True时,`smooth_k_piece`建议设为6,搜索分段数量耗时较长,如需加速Smooth过程建议关闭。 - `do_gptq`: 是否进行GPTQ量化,GPTQ对模型进行WINT4量化,相比于普通PTQ量化精度更高,量化时间较长。默认为False。 - `gptq_step`: GPTQ量化步数,也即模型前向次数,默认为8。 +
-**其他参数:** + +
  其他参数
- `per_device_train_batch_size`: 量化前向批大小,默认为8。量化过程只有模型前向,相比于普通训练需要显存较少。 -其他参数详见精调参数介绍。 +- 更多参数详见精调参数介绍。 + +
## 7. 静态图推理 diff --git a/llm/bloom/README.md b/llm/bloom/README.md index b7976e2236bb..346d7b826b84 100644 --- a/llm/bloom/README.md +++ b/llm/bloom/README.md @@ -1,4 +1,4 @@ -# Bloom +# BLOOM ## 1.模型介绍 diff --git a/llm/chatglm/README.md b/llm/chatglm/README.md index a96a1b062e03..c8cfb4f8b28b 100644 --- a/llm/chatglm/README.md +++ b/llm/chatglm/README.md @@ -1,4 +1,4 @@ -# ChatGLM +# ChatGLM-6B ## 1. 模型介绍 diff --git a/llm/chatglm2/README.md b/llm/chatglm2/README.md new file mode 100644 index 000000000000..9c07a20f9e1a --- /dev/null +++ b/llm/chatglm2/README.md @@ -0,0 +1,16 @@ +# ChatGLM2-6B + +## 1. 模型介绍 + +ChatGLM2-6B 是开源中英双语对话模型 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM2-6B 引入了[FlashAttention](https://github.com/HazyResearch/flash-attention)和[Multi-Query Attention](https://arxiv.org/abs/1911.02150v1)等新特性。更详细的模型介绍见[ChatGLM2-6B GitHub](https://github.com/THUDM/ChatGLM2-6B) + +**支持模型权重:** + +| Model | +|----------------------------------| +| THUDM/chatglm2-6b | + +## 2. 模型协议 + + +ChatGLM2-6B 模型的权重的使用需要遵循[License](../../paddlenlp/transformers/chatglm_v2/LICENSE)。 diff --git a/llm/chatglm_v2/gptq_argument.json b/llm/chatglm2/gptq_argument.json similarity index 74% rename from llm/chatglm_v2/gptq_argument.json rename to llm/chatglm2/gptq_argument.json index e5ccaf776ad5..4c8aff64c9d2 100644 --- a/llm/chatglm_v2/gptq_argument.json +++ b/llm/chatglm2/gptq_argument.json @@ -1,5 +1,5 @@ { - "model_name_or_path": "./checkpoints/chatglm_v2_sft_ckpts", + "model_name_or_path": "./checkpoints/chatglm2_sft_ckpts", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, @@ -8,7 +8,7 @@ "fp16": true, "fp16_opt_level": "O2", "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/chatglm_v2_gptq_ckpts", + "output_dir": "./checkpoints/chatglm2_gptq_ckpts", "do_eval": true, "eval_with_do_generation": false, "do_gptq": true, diff --git a/llm/chatglm_v2/lora_argument.json b/llm/chatglm2/lora_argument.json similarity index 93% rename from llm/chatglm_v2/lora_argument.json rename to llm/chatglm2/lora_argument.json index 1d1389a9081b..4af5769909ba 100644 --- a/llm/chatglm_v2/lora_argument.json +++ b/llm/chatglm2/lora_argument.json @@ -1,7 +1,7 @@ { "model_name_or_path": "THUDM/chatglm2-6b", "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/chatglm_v2_lora_ckpts", + "output_dir": "./checkpoints/chatglm2_lora_ckpts", "per_device_train_batch_size": 4, "gradient_accumulation_steps": 4, "per_device_eval_batch_size": 8, diff --git a/llm/chatglm_v2/pt_argument.json b/llm/chatglm2/pt_argument.json similarity index 93% rename from llm/chatglm_v2/pt_argument.json rename to llm/chatglm2/pt_argument.json index 7a95868a1657..60ff7b9c4082 100644 --- a/llm/chatglm_v2/pt_argument.json +++ b/llm/chatglm2/pt_argument.json @@ -1,7 +1,7 @@ { "model_name_or_path": "THUDM/chatglm2-6b", "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/chatglm_v2_pt_ckpts", + "output_dir": "./checkpoints/chatglm2_pt_ckpts", "per_device_train_batch_size": 4, "gradient_accumulation_steps": 4, "per_device_eval_batch_size": 8, diff --git a/llm/chatglm_v2/ptq_argument.json b/llm/chatglm2/ptq_argument.json similarity index 81% rename from llm/chatglm_v2/ptq_argument.json rename to llm/chatglm2/ptq_argument.json index f6b5323ffa24..98a759517fd0 100644 --- a/llm/chatglm_v2/ptq_argument.json +++ b/llm/chatglm2/ptq_argument.json @@ -1,5 +1,5 @@ { - "model_name_or_path": "./checkpoints/chatglm_v2_sft_ckpts", + "model_name_or_path": "./checkpoints/chatglm2_sft_ckpts", "per_device_train_batch_size": 8, "per_device_eval_batch_size": 8, "eval_accumulation_steps":16, @@ -8,7 +8,7 @@ "fp16": true, "fp16_opt_level": "O2", "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/chatglm_v2_ptq_ckpts", + "output_dir": "./checkpoints/chatglm2_ptq_ckpts", "do_eval": true, "eval_with_do_generation": false, "do_ptq": true, diff --git a/llm/chatglm_v2/sft_argument.json b/llm/chatglm2/sft_argument.json similarity index 85% rename from llm/chatglm_v2/sft_argument.json rename to llm/chatglm2/sft_argument.json index 6cd164c739c8..048a00b871b8 100644 --- a/llm/chatglm_v2/sft_argument.json +++ b/llm/chatglm2/sft_argument.json @@ -1,7 +1,7 @@ { "model_name_or_path": "THUDM/chatglm2-6b", "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/chatglm_v2_sft_ckpts", + "output_dir": "./checkpoints/chatglm2_sft_ckpts", "per_device_train_batch_size": 4, "gradient_accumulation_steps": 4, "per_device_eval_batch_size": 8, @@ -24,6 +24,6 @@ "metric_for_best_model": "accuracy", "recompute": true, "save_total_limit": 1, - "tensor_parallel_degree": 4, - "pipeline_parallel_degree": 1 + "sharding_parallel_degree": 4, + "sharding": "stage3" } \ No newline at end of file diff --git a/llm/chatglm_v2/README.md b/llm/chatglm_v2/README.md deleted file mode 100644 index 829ccc23d9f3..000000000000 --- a/llm/chatglm_v2/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# ChatGLM - -## 1. 模型介绍 - -ChatGLM**2**-6B 是开源中英双语对话模型 [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) 的第二代版本,在保留了初代模型对话流畅、部署门槛较低等众多优秀特性的基础之上,ChatGLM**2**-6B 引入了[FlashAttention](https://github.com/HazyResearch/flash-attention)和[Multi-Query Attention]等新特性。更详细的模型介绍见[ChatGLM2-6B GitHub](https://github.com/THUDM/ChatGLM2-6B) - -**支持模型权重:** - -| Model | -|----------------------------------| -| THUDM/chatglm2-6b | - -## 2. 模型协议 - - -ChatGLM2-6B 模型的权重的使用需要遵循[License](../../paddlenlp/transformers/chatglm_v2/LICENSE)。 diff --git a/llm/llama/README.md b/llm/llama/README.md index c4483d77de99..09eb07fafac2 100644 --- a/llm/llama/README.md +++ b/llm/llama/README.md @@ -37,7 +37,7 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat") ## 2. 模型协议 -Llama 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/LICENSE)。 +LLaMA 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/LICENSE)。 Llama2 模型的权重的使用则需要遵循[License](../../paddlenlp/transformers/llama/Llama2.LICENSE)。 diff --git a/llm/opt/sft_argument.json b/llm/opt/sft_argument.json index 10dfdc459d1d..5ae0ec047545 100644 --- a/llm/opt/sft_argument.json +++ b/llm/opt/sft_argument.json @@ -24,6 +24,6 @@ "metric_for_best_model": "accuracy", "recompute": true, "save_total_limit": 1, - "tensor_parallel_degree": 1, - "pipeline_parallel_degree": 1 + "sharding_parallel_degree": 4, + "sharding": "stage2" } \ No newline at end of file diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py index 39dc07a7512a..459d1102377f 100644 --- a/paddlenlp/trainer/training_args.py +++ b/paddlenlp/trainer/training_args.py @@ -809,8 +809,8 @@ def __post_init__(self): if sharding_parallel_degree > 1 and ShardingOption.SHARD_OP in self.sharding: assert self.data_parallel_degree == 1, "sharding stage1 can not coexist with dp for now" - if ShardingOption.OFFLOAD in self.sharding or ShardingOption.FULL_SHARD in self.sharding: - warnings.warn("`offload` and `stage3` is not supported NOW!") + if ShardingOption.OFFLOAD in self.sharding: + warnings.warn("`offload` is not supported NOW!") if pipeline_parallel_degree > 1: if ShardingOption.FULL_SHARD in self.sharding or ShardingOption.SHARD_GRAD_OP in self.sharding: