diff --git a/examples/text_matching/sentence_transformers/train.py b/examples/text_matching/sentence_transformers/train.py index 17fb795d2c652..e9dc8864b4a15 100644 --- a/examples/text_matching/sentence_transformers/train.py +++ b/examples/text_matching/sentence_transformers/train.py @@ -166,8 +166,7 @@ def do_train(): set_seed(args.seed) - train_ds, dev_ds, test_ds = load_dataset( - "lcqmc", splits=["train", "dev", "test"]) + train_ds, dev_ds = load_dataset("lcqmc", splits=["train", "dev"]) # If you wanna use bert/roberta pretrained model, # pretrained_model = ppnlp.transformers.BertModel.from_pretrained('bert-base-chinese') @@ -205,12 +204,6 @@ def do_train(): batch_size=args.batch_size, batchify_fn=batchify_fn, trans_fn=trans_func) - test_data_loader = create_dataloader( - test_ds, - mode='test', - batch_size=args.batch_size, - batchify_fn=batchify_fn, - trans_fn=trans_func) model = SentenceTransformer(pretrained_model) @@ -274,10 +267,6 @@ def do_train(): paddle.save(model.state_dict(), save_param_path) tokenizer.save_pretrained(save_dir) - if rank == 0: - print('Evaluating on test data.') - evaluate(model, criterion, metric, test_data_loader) - if __name__ == "__main__": do_train()