run.py
'''
* @author Waldinsamkeit
* @email [email protected]
* @create date 2020-11-16 11:12:26
* @desc Entry script: build the synonym classifier, then train, test, and run
*       cluster prediction on a chosen dataset (NYT / PubMed / Wiki).
'''
from typing import Any, Dict

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import config
from config import TrainingConfig, OperateConfig, DataConfig, ModelConfig, generate_register_hparams
from dataloader import DataSetDir, DataSet, Dataloader, DataItemSet, select_sampler
from wrapper import ModelWrapper
from model import Embedding_layer, Attention_layer, BinarySynClassifierBaseOnAttention
from log import logger
from utils import set_random_seed
from args import parser

args = parser.parse_args()
SEED = 2020
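

# test_clustertask drives the whole experiment. Depending on the flags in
# OperateConfig it optionally resumes from a checkpoint, trains on
# negative-sampled synonym pairs, measures test-set performance, and runs
# cluster prediction, logging everything to TensorBoard when enabled.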
def test_clustertask(operateconfig: Dict, dataconfig: Dict, trainingconfig: Dict, modelconfig: Dict):
    # Register hyperparameters so they can be attached to the TensorBoard run
    logger.info("Register Hyper Parameter")
    hparams = generate_register_hparams(modelconfig, trainingconfig, dataconfig)
    dir_path = dataconfig['data_dir_path']
    if not dir_path:
        raise KeyError("dataconfig['data_dir_path'] is not set")
    comment = '_' + dir_path.name + '_' + modelconfig['name'] + '_' + modelconfig['version']
    metric_dict = {}
    # The TensorBoard writer is only created when plotting is requested via the CLI
    w = SummaryWriter(comment=comment) if args.p else None
logger.info("Load Embedding Vector")
datasetdir = DataSetDir(dir_path,word_emb_select=dataconfig['word_emb_select'])
# combine model
embedding_layer = Embedding_layer.from_pretrained(datasetdir.embedding_vec)
embedding_layer.freeze_parameters()
attenion_layer = Attention_layer(embedding_layer.dim,modelconfig['attention_hidden_size'])
modelconfig['attention'] = attenion_layer
modelconfig['embedding'] = embedding_layer
model = BinarySynClassifierBaseOnAttention(
config = modelconfig
)
    # Only optimize parameters that require gradients (the embeddings are frozen)
    optimizer = optim.Adam(filter(lambda x: x.requires_grad, model.parameters()), lr=trainingconfig['lr'], amsgrad=True)
    trainingconfig['optim'] = optimizer
    trainingconfig['loss_fn'] = torch.nn.BCELoss()
    wrapper = ModelWrapper(model, trainingconfig)

    if operateconfig['resume']:
        # Continue training from the last saved checkpoint
        wrapper.load_check_point()
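
    # Stage 1: supervised training on positive/negative synonym pairs, with
    # per-epoch metrics streamed to TensorBoard when plotting is enabled.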
    if operateconfig['train']:
        logger.info("Generate DataLoader")
        train_datasetitem = DataItemSet(
            dataset=datasetdir.train_dataset,
            sampler=select_sampler(dataconfig['sample_strategy']),
            negative_sample_size=dataconfig['negative_sample_size']
        )
        dev_datasetitem = DataItemSet(
            dataset=datasetdir.test_dataset,
            sampler=select_sampler(dataconfig['test_sample_strategy']),
            negative_sample_size=dataconfig['test_negative_sample_size']
        )
        train_dataloader = Dataloader(
            dataitems=train_datasetitem,
            word2id=datasetdir.word2id,
            batch_size=trainingconfig['batch_size']
        )
        dev_dataloader = Dataloader(
            dataitems=dev_datasetitem,
            word2id=datasetdir.word2id,
            batch_size=trainingconfig['batch_size']
        )
        logger.info("Start to Train !!")
        for ix, item in enumerate(wrapper.train(train_dataloader=train_dataloader, dev_dataloader=dev_dataloader)):
            ep_loss, t_ac, t_p, t_r, t_f1, v_loss, v_ac, v_p, v_r, v_f1, cluster_unit, b_score = item
            if w:
                w.add_scalar("Training/Loss", ep_loss, ix)
                w.add_scalar("Training/Accuracy", t_ac, ix)
                w.add_scalar("Training/Precision", t_p, ix)
                w.add_scalar("Training/Recall", t_r, ix)
                w.add_scalar("Training/F1_score", t_f1, ix)
                w.add_scalar("Validation/Loss", v_loss, ix)
                w.add_scalar("Validation/Accuracy", v_ac, ix)
                w.add_scalar("Validation/Precision", v_p, ix)
                w.add_scalar("Validation/Recall", v_r, ix)
                w.add_scalar("Validation/F1_score", v_f1, ix)
                w.add_scalar("Validation/FMI", cluster_unit['FMI'], ix)
                w.add_scalar("Validation/ARI", cluster_unit['ARI'], ix)
                w.add_scalar("Validation/NMI", cluster_unit['NMI'], ix)
                w.add_scalar("Best Score Update", b_score, ix)
    if operateconfig['test']:
        test_datasetitem = DataItemSet(
            dataset=datasetdir.test_dataset,
            sampler=select_sampler(dataconfig['test_sample_strategy']),
            negative_sample_size=dataconfig['test_negative_sample_size']
        )
        test_dataloader = Dataloader(
            dataitems=test_datasetitem,
            word2id=datasetdir.word2id,
            batch_size=trainingconfig['batch_size']
        )
        d = wrapper.test_performance(test_dataloader=test_dataloader)
        metric_dict = {**metric_dict, **d}
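
    # Stage 3: predict synonym clusters on the test set, score them, and write
    # the clustering to a result file.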
    if operateconfig['predict']:
        pred_word_set = wrapper.cluster_predict(
            dataset=datasetdir.test_dataset,
            word2id=datasetdir.word2id,
            outputfile=trainingconfig['result_out_dir'].joinpath(datasetdir.name + '_result.txt')
        )
        ans = wrapper.evaluate(datasetdir.test_dataset, pred_word_set)
        logger.info("{} DataSet Cluster Prediction".format(datasetdir.test_dataset.name))
        for name, f in ans:
            logger.info("{} : {:.5f}".format(name, f))
        if w:
            d = {i: j for i, j in ans}
            metric_dict = {**metric_dict, **d}
            w.add_hparams(hparams, metric_dict=metric_dict)
            w.close()

    wrapper.save(config.WRAPPER_DIR_PATH.joinpath(datasetdir.name))
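

# Dataset entry points: each binds the matching data directory from config and
# then runs the full task with the shared config dictionaries.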
def NYT():
    DataConfig['data_dir_path'] = config.NYT_DIR_PATH
    test_clustertask(OperateConfig, DataConfig, TrainingConfig, ModelConfig)


def PubMed():
    DataConfig['data_dir_path'] = config.PubMed_DIR_PATH
    test_clustertask(OperateConfig, DataConfig, TrainingConfig, ModelConfig)


def Wiki():
    DataConfig['data_dir_path'] = config.Wiki_DIR_PATH
    test_clustertask(OperateConfig, DataConfig, TrainingConfig, ModelConfig)


def run():
    # Fix all RNG seeds for reproducibility, then dispatch on the chosen dataset
    set_random_seed(seed=SEED)
    if args.dataset == 'NYT':
        NYT()
    elif args.dataset == 'PubMed':
        PubMed()
    elif args.dataset == 'Wiki':
        Wiki()
    else:
        raise KeyError("Unsupported dataset: {}".format(args.dataset))


if __name__ == '__main__':
    run()
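
# Example invocation (a sketch: it assumes the parser in args.py defines a
# --dataset option taking NYT / PubMed / Wiki and a boolean -p flag that turns
# on TensorBoard logging, as the uses of args.dataset and args.p above suggest):
#
#   python run.py --dataset NYT -p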