-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
38 lines (34 loc) · 860 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from dev import *
import pickle
# Bare file names of the test inputs to classify: test1.txt .. test8.txt.
# The prediction loop resolves each name under "./test/".
test_data = [f"test{n}.txt" for n in range(1, 9)]
# Load the trained models. The loop below iterates ``Models.values()`` and
# uses each model's ``.name`` and ``.predict(path)``.
# NOTE(review): pickle.load can execute arbitrary code -- only load
# "all_models.bin" if it comes from a trusted source. Unpickling also needs
# the model classes (and ``pd``) in scope, presumably via ``from dev import *``.
with open("all_models.bin", "rb") as models_file:
    Models = pickle.load(models_file)

# predict classes
TOP_K = 3  # number of best-scoring model names kept per test file
prob_all = {}  # path -> {model name: model.predict(path)}
max_probs = {}  # path -> list of the TOP_K highest-scoring model names
for file_name in test_data:
    path = "./test/" + file_name
    # Score this file with every model. If two models share a ``name``, the
    # later one's score overwrites the earlier one's (same as a dict update).
    scores = {model.name: model.predict(path) for model in Models.values()}
    prob_all[path] = scores
    # Rank model names by score, highest first, and keep the top TOP_K.
    ranked = pd.Series(scores).sort_values(ascending=False)
    top_names = ranked.index[:TOP_K].to_list()
    print(path, ":", top_names)
    max_probs[path] = top_names
# Save / Document: echo the per-file top-ranked model names, then persist
# both result dicts with pickle. ``with`` guarantees each file is flushed and
# closed (even if dump raises), instead of relying on garbage collection to
# close the handle -- an unclosed "wb" file may otherwise be left truncated.
print(max_probs)
with open("test_max_probs.bin", "wb") as out_file:
    pickle.dump(max_probs, out_file)
with open("test_all_probs.bin", "wb") as out_file:
    pickle.dump(prob_all, out_file)