# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
#
# Please set DATA_FOLDER before running this script; the data, the pre-trained
# model, and the fine-tuned model will be downloaded to DATA_FOLDER automatically.
import argparse
import logging
import os
import subprocess as sp
import sys
from functools import partial

from demo_utils import download_model_folder
PROJECT_FOLDER = os.path.dirname(os.path.realpath(__file__))
PYTHON_EXE = 'python'
MODEL_FOLDER = os.path.join(PROJECT_FOLDER, 'models')
DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data')
print(f'PROJECT_FOLDER = {PROJECT_FOLDER}')
parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='dummy',
                    help='choose from dummy, small and full')
dargs = parser.parse_args()
assert dargs.data in ('dummy', 'small', 'full'), \
    'The specified data option is not supported!'
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO
)
logger = logging.getLogger(__name__)
if os.path.exists(MODEL_FOLDER):
    print(f'Found existing models folder at {MODEL_FOLDER}, skip creating a new one!')
    os.makedirs(MODEL_FOLDER, exist_ok=True)  # exist_ok makes this a no-op for an existing folder
else:
    os.makedirs(MODEL_FOLDER)
#########################################################################
# Download Model
#########################################################################
logger.info('Downloading models...')
download_model = partial(download_model_folder, DATA_FOLDER=MODEL_FOLDER)
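# Note: download_model_folder's destination parameter is named DATA_FOLDER, but the
# models are saved under MODEL_FOLDER here; partial() binds that destination once so
# the call below only has to specify the model parameters.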
# model_size: one of 'small' (GPT2 with 117M), 'medium' (345M) or 'large' (1542M)
# dataset: one of 'multiref' or 'dstc'
# from_scratch: True loads a model trained from scratch; False loads a model fine-tuned from GPT-2
target_folder = download_model(model_size='small', dataset='multiref', from_scratch=False)
logger.info('Done!\n')
#########################################################################
# Prepare Data
#########################################################################
logger.info('Downloading and Extracting Data...')
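# 'dummy' runs prepare4db.sh to fetch a small sample dataset; 'small' and 'full'
# rebuild the Reddit corpus via the reddit_extractor Makefile, which presumably
# takes much longer and needs considerably more disk space.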
if dargs.data == 'dummy':
    cmd = 'bash prepare4db.sh'
    ret = sp.run(cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=DATA_FOLDER)
elif dargs.data == 'small':
    myCmd = os.popen('cd reddit_extractor; SIZE=small make -j 8; cd ..').read()
elif dargs.data == 'full':
    myCmd = os.popen('cd reddit_extractor; SIZE=full make -j 8; cd ..').read()
else:
    raise ValueError('you need to implement your own data type, or use either dummy, small, or full')
logger.info('Preparing Data...')
data_path = os.path.join(DATA_FOLDER, 'train.tsv')
MAX_LEN = 128
data_db = f'{data_path[:-4]}.{MAX_LEN}len.db'
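# prepro.py appears to write its output as a folder named '<corpus>.<MAX_LEN>len.db'
# (e.g. data/train.tsv -> data/train.128len.db), which is what the check below relies on.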
if os.path.isdir(data_db):
    print(f'{data_db} exists, skip prepro.py')
else:
    cmd = ['prepro.py', '--corpus', data_path, '--max_seq_len', f'{MAX_LEN}']
    cmd = ' '.join(cmd)
    print(cmd)
    ret = sp.run([PYTHON_EXE] + cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
    if ret.returncode != 0:
        print(f'error occurred, {ret.stdout}')
        sys.exit(ret.returncode)
logger.info('Done!\n')
#########################################################################
# Train !
#########################################################################
logger.info('Generating training CMD!')
logger.info('If there is any problem, please copy (and modify) the command below and run it manually')
logger.info('#########################################################################')
train_cmd = 'LSP_train.py'
args = [
    '--model_name_or_path', target_folder,
    '--init_checkpoint', os.path.join(target_folder, 'pytorch_model.bin'),
    '--train_input_file', data_db,  # file generated in the previous step
    '--eval_input_file', './data/dummy_data.tsv',  # dummy test data
    '--output_dir', os.path.join(MODEL_FOLDER, 'output_model'),
    '--seed', '42',
    '--max_seq_length', '128',
    '--train_batch_size', '512',
    '--gradient_accumulation_steps', '8',
    '--eval_batch_size', '64',
    '--learning_rate', '1e-5',
    '--num_optim_steps', '10000',
    '--valid_step', '5000',
    '--warmup_steps', '4000',
    '--normalize_data', 'true',
    '--fp16', 'true',
    '--lr_schedule', 'noam',
    '--loss_scale', '0.0',
    '--no_token_id', 'true',
    '--pbar', 'true'
]
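# With train_batch_size 512 and gradient_accumulation_steps 8, each optimizer step
# presumably accumulates 8 micro-batches; '--fp16 true' with '--loss_scale 0.0'
# likely selects dynamic loss scaling (assumption: the trainer follows the common
# Apex convention where a loss scale of 0 means dynamic).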
train_cmd = train_cmd + ' ' + ' '.join(args)
print(PYTHON_EXE + ' ' + train_cmd)
logger.info('#########################################################################')
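# Stream the trainer's combined stdout/stderr to the console while teeing it into
# ./output.log, so progress is visible live and preserved for later inspection.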
with open('./output.log', 'wb') as f:
    process = sp.Popen([PYTHON_EXE] + train_cmd.split(' '), stdout=sp.PIPE, stderr=sp.STDOUT, cwd=PROJECT_FOLDER)
    for line in iter(process.stdout.readline, b''):
        sys.stdout.write(line.decode(sys.stdout.encoding))
        f.write(line)
    process.wait()  # ensure the trainer has exited before reporting completion
logger.info('Done!\n')