Skip to content

Commit

Permalink
update xlnet onnx. transformers<=4.37.2
Browse files Browse the repository at this point in the history
  • Loading branch information
shibing624 committed Sep 25, 2024
1 parent c64fd3b commit 381042e
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/onnx_predict_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
('sports', '米兰客场8战不败国米10年连胜4'),
('sports', '米兰客场8战不败国米10年连胜5'),
]
m.train(data * 10)
# m.train(data * 10)
m.load_model()

samples = ['名师指导托福语法技巧',
Expand Down
69 changes: 69 additions & 0 deletions examples/onnx_xlnet_predict_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""

import os
import shutil
import sys
import time

import torch

sys.path.append('..')
from pytextclassifier import BertClassifier

if __name__ == '__main__':
    # Demo: train an XLNet classifier, time a batch prediction with the
    # standard model, export it to ONNX, then time the same prediction again
    # with the ONNX model.
    m = BertClassifier(output_dir='models/xlnet-chinese-v1', num_classes=2,
                       model_type='xlnet', model_name='hfl/chinese-xlnet-base', num_epochs=1)
    # Tiny in-line training set of (label, text) pairs covering two labels.
    data = [
        ('education', '名师指导托福语法技巧:名词的复数形式'),
        ('education', '中国高考成绩海外认可 是“狼来了”吗?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事1?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事2?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事3?'),
        ('education', '公务员考虑越来越吃香,这是怎么回事4?'),
        ('sports', '图文:法网孟菲尔斯苦战进16强 孟菲尔斯怒吼'),
        ('sports', '四川丹棱举行全国长距登山挑战赛 近万人参与'),
        ('sports', '米兰客场8战不败国米10年连胜1'),
        ('sports', '米兰客场8战不败国米10年连胜2'),
        ('sports', '米兰客场8战不败国米10年连胜3'),
        ('sports', '米兰客场8战不败国米10年连胜4'),
        ('sports', '米兰客场8战不败国米10年连胜5'),
    ]
    m.train(data * 1)
    m.load_model()

    # Repeat the three samples so the timed batch holds 30 texts.
    samples = ['名师指导托福语法技巧',
               '米兰客场8战不败',
               '恒生AH溢指收平 A股对H股折价1.95%'] * 10

    # Time only the predict() call. perf_counter() is the clock intended for
    # measuring intervals, and the timestamp is taken before any printing so
    # console I/O is not counted as inference time.
    start_time = time.perf_counter()
    predict_label_bert, predict_proba_bert = m.predict(samples)
    elapsed_time_bert = time.perf_counter() - start_time
    print(f'predict_label_bert size: {len(predict_label_bert)}')
    print(f'Standard xlnet model prediction time: {elapsed_time_bert} seconds')

    # Convert the trained model to ONNX so it can be reloaded for comparison.
    save_onnx_dir = 'models/xlnet-chinese-v1/onnx'
    m.model.convert_to_onnx(save_onnx_dir)
    # Copy label_vocab.json next to the exported model when it exists.
    if os.path.exists(m.label_vocab_path):
        shutil.copy(m.label_vocab_path, save_onnx_dir)

    # Manually delete the model and clear the CUDA cache before reloading.
    del m
    torch.cuda.empty_cache()

    m = BertClassifier(output_dir=save_onnx_dir, num_classes=2, model_type='xlnet', model_name=save_onnx_dir,
                       args={"onnx": True})
    m.load_model()

    # Same measurement as above, this time against the ONNX model.
    start_time = time.perf_counter()
    predict_label_bert, predict_proba_bert = m.predict(samples)
    elapsed_time_onnx = time.perf_counter() - start_time
    print(f'predict_label_bert size: {len(predict_label_bert)}')
    print(f'ONNX model prediction time: {elapsed_time_onnx} seconds')
6 changes: 3 additions & 3 deletions pytextclassifier/bert_classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,7 @@ def predict(self, to_predict):
to_predict, return_tensors="pt", padding=True, truncation=True
)

if self.args.model_type in ["bert", "xlnet", "albert"]:
if self.args.model_type in ["bert", "albert"]:
for i, (input_ids, attention_mask, token_type_ids) in enumerate(
zip(
model_inputs["input_ids"],
Expand Down Expand Up @@ -2008,7 +2008,7 @@ def convert_to_onnx(self, output_dir=None, set_onnx_arg=True):
tokenizer=self.tokenizer,
output=Path(onnx_model_name),
pipeline_name="sentiment-analysis",
opset=11,
opset=14,
)

self.args.onnx = True
Expand Down Expand Up @@ -2058,7 +2058,7 @@ def _get_inputs_dict(self, batch, no_hf=False):
if self.args.model_type != "distilbert":
inputs["token_type_ids"] = (
batch[2]
if self.args.model_type in ["bert", "xlnet", "albert"]
if self.args.model_type in ["bert", "albert"]
else None
)

Expand Down

0 comments on commit 381042e

Please sign in to comment.