Convert a TensorFlow pre-trained BERT model to PyTorch
Install the dependencies:
pip install -r requirement.txt
Download the TensorFlow pre-trained BERT model: chinese_L-12_H-768_A-12
Unpack it into the current directory:
├── convert_tf_bert_model_to_pytorch
| └── chinese_L-12_H-768_A-12
| | └── bert_config.json
| | └── bert_model.ckpt.data-00000-of-00001
| | └── bert_model.ckpt.index
| | └── ......
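Before running the conversion, it can help to confirm that the unpacked files are where the script expects them. A minimal sketch, assuming the directory name `chinese_L-12_H-768_A-12` and the three files listed in the tree above; `missing_files` is a hypothetical helper, not part of this repository:

```python
from pathlib import Path

# Files the conversion needs, taken from the directory tree in this README.
REQUIRED_FILES = (
    "bert_config.json",
    "bert_model.ckpt.data-00000-of-00001",
    "bert_model.ckpt.index",
)


def missing_files(model_dir: str = "chinese_L-12_H-768_A-12") -> list[str]:
    """Return the required checkpoint files that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All required checkpoint files found.")
```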
Run the script:
python convert_tf_checkpoint_to_pytorch.py
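The contents of `convert_tf_checkpoint_to_pytorch.py` are not reproduced in this README. As a hedged sketch of an equivalent conversion, the snippet below assumes the Hugging Face `transformers` library (4.x), which provides `BertConfig.from_json_file`, `BertForPreTraining` and `load_tf_weights_in_bert`; reading the TF checkpoint also requires TensorFlow. This may differ from what the repository's script actually does. The argument names mirror the `args.tf_checkpoint_path`, `args.bert_config_file` and `args.pytorch_dump_path` fields used by the script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Defaults follow the Linux-style paths given in this README.
    parser = argparse.ArgumentParser(description="Convert a TF BERT checkpoint to PyTorch.")
    parser.add_argument("--tf_checkpoint_path", default="./chinese_L-12_H-768_A-12/bert_model.ckpt")
    parser.add_argument("--bert_config_file", default="./chinese_L-12_H-768_A-12/bert_config.json")
    parser.add_argument("--pytorch_dump_path", default="./chinese_L-12_H-768_A-12/pytorch_model.bin")
    return parser


def convert(tf_checkpoint_path: str, bert_config_file: str, pytorch_dump_path: str) -> None:
    # Imported lazily so the argument parser works without torch/transformers installed.
    import torch
    from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert

    config = BertConfig.from_json_file(bert_config_file)        # hyperparameters of the TF model
    model = BertForPreTraining(config)                           # randomly initialised PyTorch BERT
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)   # copy TF variables (needs TensorFlow)
    torch.save(model.state_dict(), pytorch_dump_path)            # write pytorch_model.bin


if __name__ == "__main__":
    args = build_parser().parse_args()
    convert(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)
```

Exposing the paths as command-line flags avoids editing the source when the files live elsewhere.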
After a successful conversion, the PyTorch model is saved to:
├── convert_tf_bert_model_to_pytorch
| └── chinese_L-12_H-768_A-12
| | └── pytorch_model.bin
| | └── ......
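To sanity-check the result, the saved state dict can be loaded back with `torch.load` and its tensor shapes summed into a parameter count. A minimal sketch, assuming PyTorch is installed; `inspect_checkpoint` and `parameter_count` are hypothetical helpers, and no particular expected count is claimed here.

```python
import math
from typing import Iterable, Sequence


def parameter_count(shapes: Iterable[Sequence[int]]) -> int:
    """Total number of scalars across tensors with the given shapes."""
    return sum(math.prod(shape) for shape in shapes)


def inspect_checkpoint(path: str = "chinese_L-12_H-768_A-12/pytorch_model.bin") -> None:
    import torch  # imported lazily; only needed when actually reading the file

    state_dict = torch.load(path, map_location="cpu")  # load onto CPU even if saved from GPU
    print(f"{len(state_dict)} tensors")
    for name in list(state_dict)[:5]:  # show a few parameter names and shapes
        print(name, tuple(state_dict[name].shape))
    print("parameters:", parameter_count(t.shape for t in state_dict.values()))


if __name__ == "__main__":
    inspect_checkpoint()
```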
If a path error occurs, change the following lines in the script to point to your own paths:
Windows paths:
args.tf_checkpoint_path = "chinese_L-12_H-768_A-12\\bert_model.ckpt"
args.bert_config_file = "chinese_L-12_H-768_A-12\\bert_config.json"
args.pytorch_dump_path = "chinese_L-12_H-768_A-12\\pytorch_model.bin"
Linux paths:
args.tf_checkpoint_path = "./chinese_L-12_H-768_A-12/bert_model.ckpt"
args.bert_config_file = "./chinese_L-12_H-768_A-12/bert_config.json"
args.pytorch_dump_path = "./chinese_L-12_H-768_A-12/pytorch_model.bin"
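Instead of keeping separate Windows and Linux variants, the paths can be built with `os.path.join`, which inserts the correct separator for the current operating system. A sketch, assuming the script reads the same three `args` fields; `default_paths` is a hypothetical helper:

```python
import argparse
import os


def default_paths(model_dir: str = "chinese_L-12_H-768_A-12") -> dict[str, str]:
    """Build the three conversion paths with the platform's path separator."""
    return {
        # TF checkpoints are addressed by prefix; no file is literally named bert_model.ckpt.
        "tf_checkpoint_path": os.path.join(model_dir, "bert_model.ckpt"),
        "bert_config_file": os.path.join(model_dir, "bert_config.json"),
        "pytorch_dump_path": os.path.join(model_dir, "pytorch_model.bin"),
    }


if __name__ == "__main__":
    args = argparse.Namespace(**default_paths())
    print(args.pytorch_dump_path)  # chinese_L-12_H-768_A-12/pytorch_model.bin on Linux
```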
Pre-converted bert-base-chinese files hosted by Hugging Face:
config: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json
vocab: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt
model: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin
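The three files linked above (config, vocab, model) can be fetched and renamed so that `transformers` can load them from a local directory with `from_pretrained`, which looks for `config.json`, `vocab.txt` and `pytorch_model.bin`. A sketch using only the standard library for the download; the target directory `bert_base_chinese` and the helpers `download_plan` and `download` are hypothetical. If a link is unreachable, `urlretrieve` raises an error.

```python
import os
import urllib.request

BASE = "https://s3.amazonaws.com/models.huggingface.co/bert/"

# Remote file name -> local name expected by transformers' from_pretrained().
FILES = {
    "bert-base-chinese-config.json": "config.json",
    "bert-base-chinese-vocab.txt": "vocab.txt",
    "bert-base-chinese-pytorch_model.bin": "pytorch_model.bin",
}


def download_plan(target_dir: str = "bert_base_chinese") -> list[tuple[str, str]]:
    """(url, local_path) pairs for each file listed in this README."""
    return [(BASE + remote, os.path.join(target_dir, local)) for remote, local in FILES.items()]


def download(target_dir: str = "bert_base_chinese") -> None:
    os.makedirs(target_dir, exist_ok=True)
    for url, path in download_plan(target_dir):
        urllib.request.urlretrieve(url, path)  # raises on HTTP or network errors


if __name__ == "__main__":
    download()
    # Then, assuming transformers is installed:
    # from transformers import BertModel, BertTokenizer
    # model = BertModel.from_pretrained("./bert_base_chinese")
    # tokenizer = BertTokenizer.from_pretrained("./bert_base_chinese")
```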