Merge pull request #18 from airaria/v0.2.0dev
v0.2.0dev
Showing 27 changed files with 1,280 additions and 55 deletions.
[**中文说明**](README_ZH.md) | [**English**](README.md)

This example demonstrates distilling a [Chinese-ELECTRA-base](https://github.com/ymcui/Chinese-ELECTRA) model on the MSRA NER task with **distributed data-parallel training** (single node, multi-GPU).

* ner_ElectraTrain_dist.sh : trains a teacher model (Chinese-ELECTRA-base) on MSRA NER.
* ner_ElectraDistill_dist.sh : distills the teacher to an ELECTRA-small model.

Set the following variables in the shell scripts before running:

* ELECTRA_DIR_BASE : the directory where Chinese-ELECTRA-base is located; it should include vocab.txt, pytorch_model.bin and config.json.
* OUTPUT_DIR : this directory stores the logs and the trained model weights.
* DATA_DIR : it includes the MSRA NER dataset:
  * msra_train_bio.txt
  * msra_test_bio.txt

For distillation:

* ELECTRA_DIR_SMALL : the directory where the pretrained Chinese-ELECTRA-small weights are located; it should include pytorch_model.bin. This is optional: if you don't provide the ELECTRA-small weights, the student model will be initialized randomly.
* student_config_file : the model config file (i.e., config.json) for the student. Usually it is located in ${ELECTRA_DIR_SMALL}.
* trained_teacher_model_file : the ELECTRA-base teacher model that has been fine-tuned.

The scripts have been tested under **PyTorch==1.2, Transformers==2.8**.
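As a concrete illustration, the variable block at the top of the scripts might look like the following. All directory and file names here are hypothetical placeholders, not paths shipped with the example; substitute your own locations:

```shell
# Hypothetical example paths -- adjust them to your environment.
ELECTRA_DIR_BASE=/models/chinese-electra-base    # contains vocab.txt, pytorch_model.bin, config.json
OUTPUT_DIR=output_msra_ner                       # logs and trained weights are written here
DATA_DIR=/data/msra_ner                          # contains msra_train_bio.txt, msra_test_bio.txt

# Needed for distillation only:
ELECTRA_DIR_SMALL=/models/chinese-electra-small  # optional pretrained student weights (pytorch_model.bin)
student_config_file=${ELECTRA_DIR_SMALL}/config.json
trained_teacher_model_file=${OUTPUT_DIR}/teacher_model.bin  # fine-tuned teacher checkpoint (name is illustrative)
```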
[**中文说明**](README_ZH.md) | [**English**](README.md)

This example demonstrates distilling a [Chinese-ELECTRA-base](https://github.com/ymcui/Chinese-ELECTRA) model on the MSRA NER (Chinese named entity recognition) task in **distributed data-parallel** (DDP) mode (single node, multi-GPU).

* ner_ElectraTrain_dist.sh : trains the teacher model (ELECTRA-base).
* ner_ElectraDistill_dist.sh : distills the teacher model into the student model (ELECTRA-small).

Before running the scripts, set the following variables for your environment:

* ELECTRA_DIR_BASE : the directory of the Chinese-ELECTRA-base model, containing vocab.txt, pytorch_model.bin and config.json.
* OUTPUT_DIR : stores the trained model weights and logs.
* DATA_DIR : the MSRA NER dataset directory, containing
  * msra_train_bio.txt
  * msra_test_bio.txt

For distillation, also set:

* ELECTRA_DIR_SMALL : the directory of the pretrained Chinese-ELECTRA-small weights; it should contain pytorch_model.bin. This is optional: without pretrained weights, the student model is initialized randomly.
* student_config_file : the student model config file, usually named config.json and also located in ${ELECTRA_DIR_SMALL}.
* trained_teacher_model_file : the ELECTRA-base teacher model fine-tuned on MSRA NER.

The scripts have been tested under **PyTorch==1.2, Transformers==2.8**.
import argparse

# Parsed command-line arguments; populated by parse() below.
args = None

def parse(opt=None):
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--vocab_file", default=None, type=str, required=True,
                        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--train_file", default=None, type=str)
    parser.add_argument("--predict_file", default=None, type=str)
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Whether to lower case the input text. Should be True for uncased "
                             "models and False for cased models.")
    parser.add_argument("--max_seq_length", default=416, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--do_train", default=False, action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", default=False, action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
                             "of training.")
    parser.add_argument("--verbose_logging", default=False, action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of update steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")

    parser.add_argument('--random_seed', type=int, default=10236797)
    parser.add_argument('--load_model_type', type=str, default='bert', choices=['bert', 'all', 'none'])
    parser.add_argument('--weight_decay_rate', type=float, default=0.01)
    parser.add_argument('--do_eval', action='store_true')
    parser.add_argument('--PRINT_EVERY', type=int, default=200)
    parser.add_argument('--weight', type=float, default=1.0)
    parser.add_argument('--ckpt_frequency', type=int, default=2)

    parser.add_argument('--tuned_checkpoint_T', type=str, default=None)
    parser.add_argument('--tuned_checkpoint_S', type=str, default=None)
    parser.add_argument("--init_checkpoint_S", default=None, type=str)
    parser.add_argument("--bert_config_file_T", default=None, type=str, required=True)
    parser.add_argument("--bert_config_file_S", default=None, type=str, required=True)
    parser.add_argument("--temperature", default=1, type=float, required=False)
    parser.add_argument("--teacher_cached", action='store_true')

    parser.add_argument('--schedule', type=str, default='warmup_linear_release')

    parser.add_argument('--no_inputs_mask', action='store_true')
    parser.add_argument('--no_logits', action='store_true')
    parser.add_argument('--output_encoded_layers', default='true', choices=['true', 'false'])
    parser.add_argument('--output_attention_layers', default='true', choices=['true', 'false'])
    parser.add_argument('--matches', nargs='*', type=str)

    parser.add_argument('--lr_decay', default=None, type=float)
    parser.add_argument('--official_schedule', default='linear', type=str)

    global args
    if opt is None:
        args = parser.parse_args()
    else:
        args = parser.parse_args(opt)

if __name__ == '__main__':
    # Smoke test: parse a minimal list covering the required arguments
    # (the file names below are placeholders) and print the result.
    parse(['--vocab_file', 'vocab.txt',
           '--output_dir', 'test',
           '--bert_config_file_T', 'bert_config_T.json',
           '--bert_config_file_S', 'bert_config_S.json'])
    print(args)