-
Notifications
You must be signed in to change notification settings - Fork 0
/
pipeline.py
executable file
·212 lines (186 loc) · 8.61 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
import argparse
import pickle
import os
import deepwings.utils as utils
import deepwings.method_cnn.cnn_prediction as cnnp
from deepwings.method_features_extraction import classification as clf
from deepwings.method_cnn import cnn_training as cnnt
from deepwings.method_features_extraction import features_extractor as fe
# Program description handed to argparse.ArgumentParser(description=...)
# in main(); it summarises the two prediction methods exposed by this CLI.
DESCRIPTION = """
Predict bee species from images of their wings.
Two different methods:
- Convolutional Neural Network (based on DenseNet121)
- Features extraction + classifier (ANN)
"""
def _build_parser():
    """Create the argument parser holding every option of the pipeline."""
    parser = argparse.ArgumentParser(description=DESCRIPTION)
    parser.add_argument('-l', '--list_categories',
                        action='store_true',
                        help='Display list of genera/species complying with '
                             'min_images')
    parser.add_argument('-s', '--sort',
                        action='store_true',
                        help='Create subfolders train/test with the '
                             'different genera/species')
    parser.add_argument('-c', '--category',
                        type=str,
                        help="Category to classify: 'genus' or 'species'",
                        default='species')
    parser.add_argument('-m', '--min_images',
                        type=int,
                        help='Minimum number of images per genus/species',
                        default=20)
    parser.add_argument('-rs', '--random_seed',
                        type=int,
                        help='Random seed used for train/test split',
                        default=1234)
    parser.add_argument('-restart', '--restart',
                        action='store_true',
                        help='Restart feature extraction from beginning, '
                             'erasing the current csv files')
    parser.add_argument('-pred', '--model_prediction',
                        type=str,
                        help="Choose a model : 'cnn', 'ann' or "
                             "'random_forest'")
    parser.add_argument('-e', '--extraction',
                        type=str,
                        help="Run feature extraction process for 'train' or "
                             "'pred'")
    parser.add_argument('-fd', '--n_fourier_descriptors',
                        type=int,
                        help='Number of Fourier descriptors used for each '
                             'cell',
                        default=15)
    parser.add_argument('-p', '--plot',
                        action='store_true',
                        help="If True, plots figures in valid_images/ or "
                             "invalid_images/")
    parser.add_argument('-t', '--train',
                        type=str,
                        help="Choose a model : 'ann', 'cnn' or "
                             "'random_forest'",
                        required=False)
    parser.add_argument('-raw_train', '--path_raw_training',
                        type=str,
                        help='Input path for raw image used for training',
                        default='training/raw_images/')
    parser.add_argument('-ts', '--test_size',
                        type=float,
                        help='The ratio of dataset used for testing',
                        required=False,
                        default=0.3)
    parser.add_argument('-fp', '--folder_to_predict',
                        type=str,
                        help="Path to folder of images to predict",
                        default='prediction/raw_images')
    parser.add_argument('-cnn', '--name_cnn',
                        type=str,
                        help="Name of the CNN model",
                        default="DenseNet121")
    parser.add_argument('-raw_pred', '--raw_images_prediction',
                        type=str,
                        help="Path to the folder to predict",
                        default='prediction/raw_images')
    parser.add_argument('-pann', '--path_ann',
                        type=str,
                        help="Path to the ANN model")
    parser.add_argument('--n_epochs',
                        type=int,
                        help='Number of epochs for CNN training',
                        default=20)
    parser.add_argument('-bs_train', '--batch_size_train',
                        type=int,
                        help='Batch size for CNN training',
                        default=20)
    parser.add_argument('-bs_test', '--batch_size_test',
                        type=int,
                        help='Batch size for CNN validation',
                        default=20)
    parser.add_argument('--steps_epoch',
                        type=int,
                        help='Steps per epoch for CNN training',
                        default=100)
    return parser


def _validation_error(args):
    """Return the message for the first invalid option, or None if all pass.

    Checks run in a fixed order (category, extraction, train, prediction),
    so only the first problem found is reported.
    """
    classifiers = (None, 'cnn', 'ann', 'random_forest')
    if args.category not in ('genus', 'species'):
        return "ERROR: category must be 'genus' or 'species'"
    if args.extraction not in (None, 'train', 'pred'):
        return "ERROR: extraction must be 'pred' or 'train'"
    if args.train not in classifiers:
        return "ERROR: wrong classifier for training"
    if args.model_prediction not in classifiers:
        return "ERROR: wrong classifier for prediction"
    return None


def _planned_steps(args):
    """Translate the parsed options into the ordered list of step names.

    The order is fixed regardless of the order of flags on the command
    line: listing, sorting, extraction, training, then prediction.
    """
    steps = []
    if args.list_categories:
        steps.append('list_categories')
    if args.sort:
        steps.append('sort')
    if args.extraction == 'pred':
        steps.append('extraction_pred')
    if args.extraction == 'train':
        steps.append('extraction_training')
    if args.train:
        steps.append(f'train_{args.train}')
    if args.model_prediction:
        steps.append(f'pred_{args.model_prediction}')
    return steps


def _training_image_paths(path_raw_training):
    """Return paths of the train+test images listed in the split pickle.

    Reads 'training/info_train_test.p' and joins each listed image name
    onto *path_raw_training*. A missing file raises FileNotFoundError.
    """
    # NOTE(review): pickle.load can execute arbitrary code; this file is
    # presumably written by the 'sort' step and must stay a trusted,
    # locally generated artifact.
    with open('training/info_train_test.p', 'rb') as info_file:
        dict_info = pickle.load(info_file)
    selected_images = dict_info['train'] + dict_info['test']
    return [os.path.join(path_raw_training, img_name)
            for img_name in selected_images]


def _run_step(step, args):
    """Execute one named pipeline step using the parsed options."""
    if step == 'list_categories':
        sorter = utils.Sorter(args.path_raw_training, args.category,
                              args.min_images)
        sorter.filter_categories(verbose=True)
    elif step == 'sort':
        sorter = utils.Sorter(args.path_raw_training, args.category,
                              args.min_images, args.test_size,
                              args.random_seed)
        sorter.filter_categories(verbose=True)
        sorter.train_test_split(verbose=True)
        sorter.create_subfolders('train')
        sorter.create_subfolders('test')
        sorter.pickle_train_test()
    elif step == 'extraction_training':
        fe.extract_pictures(
            paths_images=_training_image_paths(args.path_raw_training),
            plot=args.plot,
            n_descriptors=args.n_fourier_descriptors,
            continue_csv=not args.restart)
    elif step == 'extraction_pred':
        # NOTE(review): extraction reads --folder_to_predict while the
        # prediction steps read --raw_images_prediction; both default to
        # the same folder. Confirm whether two options are intended.
        paths_images = [os.path.join(args.folder_to_predict, image_name)
                        for image_name in os.listdir(args.folder_to_predict)]
        fe.extract_pictures(paths_images=paths_images,
                            plot=args.plot,
                            n_descriptors=args.n_fourier_descriptors,
                            continue_csv=False)
    elif step == 'train_ann':
        clf.train_ann(args.category)
    elif step == 'train_random_forest':
        clf.train_rf(args.category)
    elif step == 'train_cnn':
        model = cnnt.build_model()
        cnnt.train_model(model,
                         epochs=args.n_epochs,
                         bs_train=args.batch_size_train,
                         bs_test=args.batch_size_test,
                         steps_per_epoch=args.steps_epoch)
    elif step == 'pred_ann':
        clf.predict_ann(category=args.category,
                        path_raw=args.raw_images_prediction,
                        path_model=args.path_ann)
    elif step == 'pred_cnn':
        cnnp.cnn_pred(model_name=args.name_cnn,
                      path_raw=args.raw_images_prediction)
    elif step == 'pred_random_forest':
        clf.predict_rf(category=args.category,
                       path_raw=args.raw_images_prediction)


def main():
    """Parse the command line and run the requested pipeline steps.

    Arguments are read from sys.argv. An invalid value for --category,
    --extraction, --train or --model_prediction prints an error message
    and returns without running anything. When no step is requested a
    usage hint is printed. Always returns None.
    """
    args = _build_parser().parse_args()

    error = _validation_error(args)
    if error is not None:
        print(error)
        return

    pipeline_process = _planned_steps(args)
    if not pipeline_process:
        print("No argument entered, type 'python pipeline.py -h' for"
              " further information")
        return

    for step in pipeline_process:
        _run_step(step, args)
# Run the pipeline only when this file is executed as a script, so that
# importing the module does not trigger argument parsing.
if __name__ == "__main__":
    main()