Skip to content

Commit

Permalink
drop the test evaluation to adapt to the qianyan-lcqmc dataset (Paddl…
Browse files Browse the repository at this point in the history
  • Loading branch information
Steffy-zxf authored Jul 3, 2021
1 parent 0177aa5 commit b17a3bd
Showing 1 changed file with 1 addition and 12 deletions.
13 changes: 1 addition & 12 deletions examples/text_matching/sentence_transformers/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,7 @@ def do_train():

set_seed(args.seed)

train_ds, dev_ds, test_ds = load_dataset(
"lcqmc", splits=["train", "dev", "test"])
train_ds, dev_ds = load_dataset("lcqmc", splits=["train", "dev"])

# If you wanna use bert/roberta pretrained model,
# pretrained_model = ppnlp.transformers.BertModel.from_pretrained('bert-base-chinese')
Expand Down Expand Up @@ -205,12 +204,6 @@ def do_train():
batch_size=args.batch_size,
batchify_fn=batchify_fn,
trans_fn=trans_func)
test_data_loader = create_dataloader(
test_ds,
mode='test',
batch_size=args.batch_size,
batchify_fn=batchify_fn,
trans_fn=trans_func)

model = SentenceTransformer(pretrained_model)

Expand Down Expand Up @@ -274,10 +267,6 @@ def do_train():
paddle.save(model.state_dict(), save_param_path)
tokenizer.save_pretrained(save_dir)

if rank == 0:
print('Evaluating on test data.')
evaluate(model, criterion, metric, test_data_loader)


if __name__ == "__main__":
do_train()

0 comments on commit b17a3bd

Please sign in to comment.