-
Notifications
You must be signed in to change notification settings - Fork 0
/
voc.py
310 lines (238 loc) · 10.6 KB
/
voc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import os
import random
import numpy as np
import cv2
class VOCDataset(Dataset):
    """Pascal VOC detection dataset that yields YOLO-style grid targets.

    Every non-blank line of a label file describes one image as
    whitespace-separated tokens::

        fname x1 y1 x2 y2 c [x1 y1 x2 y2 c ...]

    where (x1, y1, x2, y2) are box corners in pixels and ``c`` is an integer
    class index.

    Args:
        is_train: (bool) apply the random augmentations in ``__getitem__``.
        image_dir: (str) directory the image file names are joined onto.
        label_txt: (str | list | tuple) one label file, or several files that
            are read one after another (e.g. VOC2007 + VOC2012).
        image_size: (int) side length of the square image returned.
        grid_size: (int) number of grid cells per side (S).
        num_bboxes: (int) number of box slots encoded per cell (B).
        num_classes: (int) number of classes (C).
        debug_dir: (str | None) directory into which every augmented training
            sample is written with its boxes drawn. Pass None to disable.
    """

    def __init__(self, is_train, image_dir, label_txt, image_size=448, grid_size=7,
                 num_bboxes=2, num_classes=20, debug_dir='tmp/'):
        self.is_train = is_train
        self.image_size = image_size
        self.debug_dir = debug_dir

        self.S = grid_size
        self.B = num_bboxes
        self.C = num_classes

        mean_rgb = [122.67891434, 116.66876762, 104.00698793]
        self.mean = np.array(mean_rgb, dtype=np.float32)

        self.to_tensor = transforms.ToTensor()

        # Read every list file directly instead of shelling out to `cat`:
        # no shell quoting issues, no shared temporary file, and a missing
        # trailing newline cannot glue two samples together.
        if isinstance(label_txt, (list, tuple)):
            label_files = list(label_txt)
        else:
            label_files = [label_txt]

        self.paths, self.boxes, self.labels = [], [], []
        for label_file in label_files:
            with open(label_file) as f:
                for line in f:
                    self._add_sample(image_dir, line)

        self.num_samples = len(self.paths)

    def _add_sample(self, image_dir, line):
        """Parse one label line and append its path, boxes and labels."""
        tokens = line.split()
        if not tokens:
            return  # blank line

        self.paths.append(os.path.join(image_dir, tokens[0]))

        num_boxes = (len(tokens) - 1) // 5
        box, label = [], []
        for i in range(num_boxes):
            base = 5 * i + 1
            box.append([float(t) for t in tokens[base:base + 4]])
            label.append(int(tokens[base + 4]))

        # Explicit [n, 4] shape so that an image without boxes still supports
        # column slicing further down the pipeline.
        self.boxes.append(torch.tensor(box, dtype=torch.float32).reshape(num_boxes, 4))
        self.labels.append(torch.tensor(label, dtype=torch.long))

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(img, target)``.

        Returns:
            img: the output of ``transforms.ToTensor`` applied to the resized,
                mean-subtracted RGB image (float array of shape
                [image_size, image_size, 3] before conversion).
            target: (tensor) sized [S, S, 5 x B + C], see ``encode``.
        Raises:
            OSError: if the image file cannot be read by OpenCV.
        """
        path = self.paths[idx]
        img = cv2.imread(path)
        if img is None:
            raise OSError('Failed to read image: {}'.format(path))
        boxes = self.boxes[idx].clone()    # [n, 4]
        labels = self.labels[idx].clone()  # [n,]

        if self.is_train:
            img, boxes = self.random_flip(img, boxes)
            img, boxes = self.random_scale(img, boxes)

            img = self.random_blur(img)
            img = self.random_brightness(img)
            img = self.random_hue(img)
            img = self.random_saturation(img)

            img, boxes, labels = self.random_shift(img, boxes, labels)
            img, boxes, labels = self.random_crop(img, boxes, labels)

            if self.debug_dir is not None:
                self._dump_debug_image(img, boxes, idx)

        h, w, _ = img.shape
        # Normalize (x1, y1, x2, y2) w.r.t. image width/height.
        boxes = boxes / torch.tensor([w, h, w, h], dtype=boxes.dtype)
        target = self.encode(boxes, labels)  # [S, S, 5 x B + C]

        img = cv2.resize(img, dsize=(self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # assuming the model is pretrained with RGB images.
        # Subtract the per-channel mean and scale by 1/255; for 8-bit input the
        # result lies roughly within [-0.5, 0.6].
        img = (img - self.mean) / 255.0
        img = self.to_tensor(img)

        return img, target

    def __len__(self):
        """Return the number of samples read from the label file(s)."""
        return self.num_samples

    def _dump_debug_image(self, img, boxes, idx):
        """Write ``img`` with ``boxes`` drawn to ``debug_dir/test_<idx>.jpg``."""
        os.makedirs(self.debug_dir, exist_ok=True)
        img_show = img.copy()
        for x1, y1, x2, y2 in boxes.tolist():
            cv2.rectangle(img_show, pt1=(int(x1), int(y1)), pt2=(int(x2), int(y2)),
                          color=(0, 255, 0), thickness=1)
        cv2.imwrite(os.path.join(self.debug_dir, 'test_{}.jpg'.format(idx)), img_show)

    def encode(self, boxes, labels):
        """ Encode box coordinates and class labels as one target tensor.
        Args:
            boxes: (tensor) [[x1, y1, x2, y2]_obj1, ...], normalized from 0.0 to 1.0 w.r.t. image width/height.
            labels: (tensor) [c_obj1, c_obj2, ...]
        Returns:
            An encoded tensor sized [S, S, 5 x B + C]. For each cell, channels
            [0, C) hold the one-hot class and channels [C, C + 5 x B) hold B
            copies of (x, y, w, h, conf).
        """
        S, B, C = self.S, self.B, self.C
        N = 5 * B + C

        target = torch.zeros(S, S, N)
        cell_size = 1.0 / float(S)
        boxes_wh = boxes[:, 2:] - boxes[:, :2]          # width and height for each box, [n, 2]
        boxes_xy = (boxes[:, 2:] + boxes[:, :2]) / 2.0  # center x & y for each box, [n, 2]
        for b in range(boxes.size(0)):
            xy, wh, label = boxes_xy[b], boxes_wh[b], int(labels[b])

            # ceil(.) - 1 gives -1 for a center at exactly 0, which would wrap
            # around to the last cell; keep the index inside [0, S - 1].
            ij = ((xy / cell_size).ceil() - 1.0).clamp(min=0, max=S - 1)
            i, j = int(ij[0]), int(ij[1])  # x (column) & y (row) index of the cell on the grid.
            x0y0 = ij * cell_size  # x & y of the cell left-top corner.
            xy_normalized = (xy - x0y0) / cell_size  # x & y of the box on the cell, normalized from 0.0 to 1.0.

            # TBM, remove redundant dimensions from target tensor.
            # To remove these, loss implementation also has to be modified.
            for k in range(B):
                s = C + 5 * k
                target[j, i, s:s + 2] = xy_normalized
                target[j, i, s + 2:s + 4] = wh
                target[j, i, s + 4] = 1.0
            target[j, i, label] = 1.0

        return target

    def random_flip(self, img, boxes):
        """Mirror the image and boxes horizontally with probability 0.5.

        ``boxes`` (pixel coordinates, [n, 4]) is modified in place when the
        flip is applied.
        """
        if random.random() < 0.5:
            return img, boxes

        h, w, _ = img.shape

        img = np.fliplr(img).copy()  # copy: fliplr alone returns a reversed view

        x1_new = w - boxes[:, 2]
        x2_new = w - boxes[:, 0]
        boxes[:, 0], boxes[:, 2] = x1_new, x2_new

        return img, boxes

    def random_scale(self, img, boxes):
        """Stretch the image width by a factor in [0.8, 1.2] with probability 0.5."""
        if random.random() < 0.5:
            return img, boxes

        scale = random.uniform(0.8, 1.2)
        h, w, _ = img.shape
        new_w = int(w * scale)
        img = cv2.resize(img, dsize=(new_w, h), interpolation=cv2.INTER_LINEAR)

        # Scale boxes by the ratio actually applied to the image; the integer
        # truncation above makes it differ slightly from `scale`.
        ratio = new_w / float(w)
        boxes = boxes * torch.tensor([ratio, 1.0, ratio, 1.0], dtype=boxes.dtype)

        return img, boxes

    def random_blur(self, bgr):
        """Apply a box blur with a random kernel size with probability 0.5."""
        if random.random() < 0.5:
            return bgr

        ksize = random.choice([2, 3, 4, 5])
        bgr = cv2.blur(bgr, (ksize, ksize))
        return bgr

    def _scale_hsv_channel(self, bgr, channel, low, high):
        """With probability 0.5, multiply one HSV channel by a factor in [low, high].

        The scaled channel is clipped to [0, 255] and cast back to the image
        dtype before converting to BGR again.
        """
        if random.random() < 0.5:
            return bgr

        hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
        adjust = random.uniform(low, high)
        scaled = hsv[:, :, channel] * adjust
        hsv[:, :, channel] = np.clip(scaled, 0, 255).astype(hsv.dtype)
        return cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    def random_brightness(self, bgr):
        """Randomly scale the V channel by a factor in [0.5, 1.5]."""
        return self._scale_hsv_channel(bgr, 2, 0.5, 1.5)

    def random_hue(self, bgr):
        """Randomly scale the H channel by a factor in [0.8, 1.2]."""
        # NOTE(review): OpenCV documents 8-bit hue as 0..179, yet the value is
        # only clipped to 255 (kept as in the original) - confirm intent.
        return self._scale_hsv_channel(bgr, 0, 0.8, 1.2)

    def random_saturation(self, bgr):
        """Randomly scale the S channel by a factor in [0.5, 1.5]."""
        return self._scale_hsv_channel(bgr, 1, 0.5, 1.5)

    @staticmethod
    def _translate_boxes(boxes, labels, dx, dy, w, h):
        """Translate boxes by (dx, dy), keep those centered inside w x h, and clip.

        Returns:
            (boxes_out, labels_out); both are empty when no box center falls
            inside [0, w) x [0, h) after the translation.
        """
        offset = torch.tensor([dx, dy, dx, dy], dtype=boxes.dtype)
        center = (boxes[:, 2:] + boxes[:, :2]) / 2.0 + offset[:2]  # [n, 2]
        keep = ((center[:, 0] >= 0) & (center[:, 0] < w) &
                (center[:, 1] >= 0) & (center[:, 1] < h))          # [n,]

        boxes_out = boxes[keep] + offset  # [m, 4]
        boxes_out[:, 0::2].clamp_(min=0, max=w)  # x1, x2
        boxes_out[:, 1::2].clamp_(min=0, max=h)  # y1, y2
        return boxes_out, labels[keep]

    def random_shift(self, img, boxes, labels):
        """Translate the image by up to 20% of its size with probability 0.5.

        Uncovered pixels are filled with the dataset mean color. If no box
        center stays inside the image, the inputs are returned unchanged.
        """
        if random.random() < 0.5:
            return img, boxes, labels

        h, w, c = img.shape
        img_out = np.zeros((h, w, c), dtype=img.dtype)
        img_out[:, :] = self.mean[::-1]  # self.mean is RGB, the image is BGR

        dx = random.uniform(-w * 0.2, w * 0.2)
        dy = random.uniform(-h * 0.2, h * 0.2)
        dx, dy = int(dx), int(dy)

        # Paste the part of the source that remains visible after the shift.
        img_out[max(dy, 0):min(h, h + dy), max(dx, 0):min(w, w + dx)] = \
            img[max(-dy, 0):min(h, h - dy), max(-dx, 0):min(w, w - dx)]

        boxes_out, labels_out = self._translate_boxes(boxes, labels, dx, dy, w, h)
        if len(boxes_out) == 0:
            return img, boxes, labels

        return img_out, boxes_out, labels_out

    def random_crop(self, img, boxes, labels):
        """Crop a random window of 60-100% of each side with probability 0.5.

        If no box center falls inside the window, the inputs are returned
        unchanged.
        """
        if random.random() < 0.5:
            return img, boxes, labels

        h_orig, w_orig, _ = img.shape
        h = random.uniform(0.6 * h_orig, h_orig)
        w = random.uniform(0.6 * w_orig, w_orig)
        y = random.uniform(0, h_orig - h)
        x = random.uniform(0, w_orig - w)
        h, w, x, y = int(h), int(w), int(x), int(y)

        boxes_out, labels_out = self._translate_boxes(boxes, labels, -x, -y, w, h)
        if len(boxes_out) == 0:
            return img, boxes, labels

        img_out = img[y:y + h, x:x + w, :]

        return img_out, boxes_out, labels_out
def test():
    """Smoke test: build a training dataset and print every sample's tensor sizes."""
    from torch.utils.data import DataLoader

    dataset = VOCDataset(True, 'data/images/', ['data/8samples.txt'])
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    # With batch_size=1 the loader yields exactly one batch per sample.
    for img, target in loader:
        print(img.size(), target.size())
# Run the smoke test only when this file is executed as a script.
if __name__ == '__main__':
    test()