-
Notifications
You must be signed in to change notification settings - Fork 0
/
datagenerator.py
executable file
·100 lines (80 loc) · 2.95 KB
/
datagenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import numpy as np
import cv2
class ImageDataGenerator:
    """Sequential mini-batch loader driven by an image list file.

    The list file holds one sample per line: an image path followed by an
    integer class label, separated by whitespace. Batches are served in list
    order starting at ``pointer``; images are read with OpenCV, optionally
    mirrored, resized to ``scale_size`` and mean-subtracted.

    Attributes:
        images: List of image paths read from the list file.
        labels: List of integer labels, parallel to ``images``.
        data_size: Number of samples read from the list file.
        pointer: Index of the first sample of the next batch.
    """

    def __init__(self, class_list, horizontal_flip=False, shuffle=False,
                 mean=np.array([104., 117., 124.]), scale_size=(227, 227),
                 nb_classes=2):
        """Read the list file and optionally shuffle it once.

        Args:
            class_list: Path of the text file listing ``<path> <label>`` pairs.
            horizontal_flip: If true, each loaded image is mirrored with
                probability 0.5.
            shuffle: If true, shuffle the samples now and on every
                ``reset_pointer`` call.
            mean: Value(s) subtracted from every resized image; must broadcast
                against an array of shape ``(height, width, 3)``.
            scale_size: Target size; index 0 is used as the row count and
                index 1 as the column count of the returned images.
            nb_classes: Width of the one-hot label vectors.
        """
        # Init params
        self.horizontal_flip = horizontal_flip
        self.n_classes = nb_classes
        self.shuffle = shuffle
        # Keep a private copy: the default array is created once at function
        # definition time and would otherwise be shared by every instance.
        self.mean = np.array(mean)
        self.scale_size = scale_size
        self.pointer = 0

        self.read_class_list(class_list)

        if self.shuffle:
            self.shuffle_data()

    def read_class_list(self, class_list):
        """Parse the list file into ``self.images`` and ``self.labels``.

        Each non-blank line is split on whitespace; the first field is the
        image path and the second is converted to an ``int`` label. Blank
        lines (for example a trailing newline at the end of the file) are
        skipped. Also sets ``self.data_size``.

        Args:
            class_list: Path of the list file to read.

        Raises:
            IndexError: If a non-blank line has fewer than two fields.
            ValueError: If the label field is not an integer.
        """
        images = []
        labels = []
        with open(class_list) as f:
            for line in f:
                items = line.split()
                if not items:
                    # Blank or whitespace-only line: nothing to record.
                    continue
                images.append(items[0])
                labels.append(int(items[1]))

        self.images = images
        self.labels = labels

        # store total number of data
        self.data_size = len(self.labels)

    def shuffle_data(self):
        """Randomly reorder images and labels with one shared permutation."""
        # A single permutation is applied to both lists so every path keeps
        # its own label.
        idx = np.random.permutation(len(self.labels))
        self.images = [self.images[i] for i in idx]
        self.labels = [self.labels[i] for i in idx]

    def reset_pointer(self):
        """Rewind to the start of the list, reshuffling if enabled."""
        self.pointer = 0

        if self.shuffle:
            self.shuffle_data()

    def next_batch(self, batch_size):
        """Load the next ``batch_size`` samples into memory.

        The pointer is advanced by ``batch_size`` before any image is read.
        If fewer than ``batch_size`` samples remain, the returned arrays keep
        their full size and the unused trailing rows are all zeros.

        Args:
            batch_size: Number of samples requested.

        Returns:
            A tuple ``(images, one_hot_labels)`` where ``images`` has shape
            ``(batch_size, scale_size[0], scale_size[1], 3)`` and
            ``one_hot_labels`` has shape ``(batch_size, n_classes)``.

        Raises:
            IOError: If an image file cannot be read by OpenCV.
            IndexError: If a label does not fit in ``n_classes`` columns.
        """
        # Get next batch of image (path) and labels
        start = self.pointer
        stop = start + batch_size
        paths = self.images[start:stop]
        labels = self.labels[start:stop]

        # update pointer
        self.pointer = stop

        # Zero-initialised so that rows of a short final batch are
        # well-defined instead of uninitialised memory.
        height, width = self.scale_size[0], self.scale_size[1]
        images = np.zeros([batch_size, height, width, 3])

        for i, path in enumerate(paths):
            img = cv2.imread(path)
            if img is None:
                # cv2.imread signals failure by returning None, not raising.
                raise IOError("Could not read image: {}".format(path))

            # flip image at random if flag is selected
            if self.horizontal_flip and np.random.random() < 0.5:
                img = cv2.flip(img, 1)

            # cv2.resize expects dsize as (width, height), which is the
            # reverse of the (rows, cols) order of the output buffer.
            img = cv2.resize(img, (width, height))
            img = img.astype(np.float32)

            # subtract mean
            img -= self.mean

            images[i] = img

        # Expand labels to one hot encoding
        one_hot_labels = np.zeros((batch_size, self.n_classes))
        rows = np.arange(len(labels))
        # Explicit integer dtype keeps the index valid for an empty batch.
        cols = np.asarray(labels, dtype=np.intp)
        one_hot_labels[rows, cols] = 1

        # return array of images and labels
        return images, one_hot_labels