-
Notifications
You must be signed in to change notification settings - Fork 4
/
predict.py
132 lines (108 loc) · 4.3 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
'''
predict.py: script for making predictions
'''
import argparse
import sys
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm, trange
from transformers import *
from modeling import make_tensor_dataset
from modeling import BertForMatres, RobertaForMatres#, ElectraForMatres
from constants import CLASSES, MAX_SEQ_LENGTH, DOC_STRIDE
from load_data import *
UDST_DIR = "udst/all_annotations/"
def get_dummy_data(tokenizer, lm):
    """Build a tiny hard-coded two-sentence example for smoke-testing.

    tokenizer: PreTrainedTokenizer matching the chosen model
    lm: str, model type forwarded to make_tensor_dataset
    Returns (examples, tensor_dataset).
    """
    first_sent = "Today I went to the store.".split()
    second_sent = "I came home.".split()
    # load_data.py:IndexedExamplePartial
    # - for single-sentence examples sent1 & sent2 are the same
    #   --> pass in the SAME LIST (reference/pointer) for both
    example = IndexedExamplePartial(
        label="BEFORE",   # {AFTER, BEFORE, EQUALS, VAGUE}
        sent1=first_sent,
        sent2=second_sent,
        tags1=None,       # none, unless you want to mask timexes
        tags2=None,
        e1_idx=2,         # "went" = sent1[e1_idx]
        e2_idx=1,         # "came" = sent2[e2_idx]
        doc_name=None,    # specify doc_name if you need it later
    )
    examples = [example]
    # load_data.py:convert_distant_examples_to_features
    # - should automatically generate the right model-specific
    #   input features according to tokenizer type
    features = convert_distant_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=MAX_SEQ_LENGTH,
        doc_stride=DOC_STRIDE,
    )
    dataset = make_tensor_dataset(features, model=lm)
    return examples, dataset
def get_data(tokenizer, lm, data):
    '''
    Load the examples and tensor dataset named by `data`.

    tokenizer: PreTrainedTokenizer
    lm: str, {bert, roberta}
    data: str, one of {matres_dev, udst_dev_maj, udst_test_maj,
          udst_train, dummy_data}

    Returns (examples, tensor_dataset).
    Raises NotImplementedError for an unrecognized `data` value.
    '''
    if data == 'matres_dev':
        # BUG FIX: previously passed `args.lm`, a global that only exists
        # when this module is run as a script — importing this function and
        # hitting this branch raised NameError. Use the `lm` parameter.
        exs, data = matres_dev_examples(tokenizer,
                                        lm=lm)
    elif data == 'udst_dev_maj':
        exs, data = udst_majority(tokenizer,
                                  lm=lm,
                                  example_dir=UDST_DIR,
                                  split="dev")
    elif data == 'udst_test_maj':
        exs, data = udst_majority(tokenizer,
                                  lm=lm,
                                  example_dir=UDST_DIR,
                                  split="test")
    elif data == 'udst_train':
        exs, data = udst(tokenizer,
                         lm=lm,
                         split="train",
                         example_dir=UDST_DIR)
    elif data == 'dummy_data':
        exs, data = get_dummy_data(tokenizer, lm)
    else:
        # FIX: the message said "model not yet supported" but this branch
        # rejects an unknown *dataset* name, not a model.
        print("dataset not yet supported, try {udst_train,udst_dev_maj,udst_test_maj}")
        raise NotImplementedError
    return exs, data
def load_model_from_directory(lm, model_dir):
    """Load a fine-tuned model and its tokenizer from `model_dir`.

    lm: str starting with 'bert' or 'roberta'; any other value prints a
        hint and raises NotImplementedError.
    model_dir: str, path to a saved model/tokenizer directory.
    Returns (model, tokenizer).
    """
    if lm.startswith('bert'):
        return (BertForMatres.from_pretrained(model_dir),
                BertTokenizer.from_pretrained(model_dir))
    if lm.startswith('roberta'):
        return (RobertaForMatres.from_pretrained(model_dir),
                RobertaTokenizer.from_pretrained(model_dir))
    print("model not yet supported, try 'bert' model")
    raise NotImplementedError
def predict(model, data, device):
    """Run `model` over `data` in order and print one class label per example.

    model: a *ForMatres model whose forward returns (loss, logits, hidden)
    data: torch Dataset of pre-built feature tensors
    device: torch.device to run inference on
    """
    loader = DataLoader(data,
                        sampler=SequentialSampler(data),  # keep example order
                        batch_size=20)
    model.to(device)
    model.eval()
    for batch in tqdm(loader, desc="Evaluating"):
        inputs = tuple(tensor.to(device) for tensor in batch)
        with torch.no_grad():
            _, logits, _hidden = model(*inputs)
        _scores, predictions = logits.max(1)
        for prediction in predictions:
            print(CLASSES[prediction.item()])
def main(lm, model_dir, data):
    """Load a model, build the requested dataset, and print predictions.

    lm: str, transformer model type ({bert, roberta})
    model_dir: str, path to the saved model directory
    data: str, dataset name understood by get_data
    """
    model, tokenizer = load_model_from_directory(lm, model_dir)
    _exs, dataset = get_data(tokenizer, lm, data)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    predict(model, dataset, device)
if __name__ == '__main__':
    # Command-line entry point: parse flags and run prediction.
    parser = argparse.ArgumentParser()
    for flag, description in (
        ('--lm', 'transformer model type, from {bert,roberta,electra}'),
        ('--model_dir', 'path to model directory'),
        ('--data', 'udst_dev_maj,udst_train,udst_test_maj,dummy_data'),
    ):
        parser.add_argument(flag, help=description)
    args = parser.parse_args()
    main(args.lm, args.model_dir, args.data)