-
Notifications
You must be signed in to change notification settings - Fork 988
/
mnist.py
178 lines (148 loc) · 4.54 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import struct
from fc import *
from datetime import datetime
# Base class for the data-file loaders.
class Loader(object):
    '''
    Hold the location and sample count of a data file and provide the
    low-level read/decode helpers shared by the concrete loaders.
    '''

    def __init__(self, path, count):
        '''
        Initialise the loader.
        path: path of the data file
        count: number of samples stored in the file
        '''
        self.path = path
        self.count = count

    def get_file_content(self):
        '''
        Read the whole file in binary mode and return its raw content.
        '''
        # The context manager closes the handle even when read() raises.
        with open(self.path, 'rb') as f:
            return f.read()

    def to_int(self, byte):
        '''
        Convert one unsigned byte to an integer.
        byte: either a length-1 byte string (what indexing a Python 2
              str yields) or an integer (what indexing Python 3 bytes
              yields); an integer is returned unchanged.
        '''
        if isinstance(byte, int):
            return byte
        return struct.unpack('B', byte)[0]
# Loader for the image data file.
class ImageLoader(Loader):
    '''
    Decode the images of an image data file into flat input vectors.
    '''

    # Each image is ROWS x COLS pixels, one unsigned byte per pixel, and the
    # pixel data of the first image starts HEADER_SIZE bytes into the file.
    ROWS = 28
    COLS = 28
    HEADER_SIZE = 16

    def get_picture(self, content, index):
        '''
        Internal helper: extract one image from the file content.
        content: raw content of the whole image file
        index: zero-based position of the image within the file
        Returns the image as a list of ROWS rows, each a list of COLS
        integer pixel values.
        Raises IndexError when the content is too short to hold the image.
        '''
        size = self.ROWS * self.COLS
        start = self.HEADER_SIZE + index * size
        # A bytearray yields integers on both Python 2 and Python 3, so the
        # whole image is decoded in one step instead of byte by byte.
        pixels = bytearray(content[start:start + size])
        if len(pixels) != size:
            raise IndexError('image %d lies outside the file content' % index)
        return [list(pixels[row * self.COLS:(row + 1) * self.COLS])
                for row in range(self.ROWS)]

    def get_one_sample(self, picture):
        '''
        Internal helper: flatten an image into the input vector of a sample.
        picture: the image as a list of rows
        Returns one flat list holding the rows one after another.
        '''
        return [pixel for row in picture for pixel in row]

    def load(self):
        '''
        Load the data file and return the input vectors of all samples.
        '''
        content = self.get_file_content()
        return [self.get_one_sample(self.get_picture(content, index))
                for index in range(self.count)]
# Loader for the label data file.
class LabelLoader(Loader):
    '''
    Decode the labels of a label data file into 10-dimensional vectors.
    '''

    # The first label is stored HEADER_SIZE bytes into the file.
    HEADER_SIZE = 8
    # Length of a label vector, and the values written at the labelled
    # position (ON_VALUE) and at every other position (OFF_VALUE).
    CLASS_COUNT = 10
    ON_VALUE = 0.9
    OFF_VALUE = 0.1

    def load(self):
        '''
        Load the data file and return the label vectors of all samples.
        '''
        content = self.get_file_content()
        return [self.norm(content[index + self.HEADER_SIZE])
                for index in range(self.count)]

    def norm(self, label):
        '''
        Internal helper: convert one label value into a 10-dimensional
        label vector.
        label: the label as a length-1 byte string (Python 2 indexing) or
               as an integer (Python 3 indexing of bytes)
        Returns a list holding ON_VALUE at the position of the label and
        OFF_VALUE everywhere else.
        '''
        if isinstance(label, int):
            label_value = label
        else:
            label_value = self.to_int(label)
        return [self.ON_VALUE if i == label_value else self.OFF_VALUE
                for i in range(self.CLASS_COUNT)]
def get_training_data_set(image_path='train-images-idx3-ubyte',
                          label_path='train-labels-idx1-ubyte',
                          count=60000):
    '''
    Load the training data set.
    image_path: path of the training image file
    label_path: path of the training label file
    count: number of samples to read from each file
    Returns a tuple (input vectors, label vectors).
    '''
    image_loader = ImageLoader(image_path, count)
    label_loader = LabelLoader(label_path, count)
    return image_loader.load(), label_loader.load()
def get_test_data_set(image_path='t10k-images-idx3-ubyte',
                      label_path='t10k-labels-idx1-ubyte',
                      count=10000):
    '''
    Load the test data set.
    image_path: path of the test image file
    label_path: path of the test label file
    count: number of samples to read from each file
    Returns a tuple (input vectors, label vectors).
    '''
    image_loader = ImageLoader(image_path, count)
    label_loader = LabelLoader(label_path, count)
    return image_loader.load(), label_loader.load()
def show(sample):
    '''
    Print a sample as 28 lines of 28 characters.
    sample: flat sequence of 784 pixel values, one row after another;
            a non-zero pixel is drawn as '*', a zero pixel as a space.
    '''
    lines = [
        ''.join('*' if sample[row * 28 + col] != 0 else ' '
                for col in range(28))
        for row in range(28)
    ]
    # Every row is terminated by a newline and print adds one more, so the
    # picture is followed by an empty line.
    print(''.join(line + '\n' for line in lines))
def get_result(vec):
    '''
    Return the index of the largest element of vec.
    vec: sequence of mutually comparable values
    When several elements share the largest value the first of them wins.
    An empty vec yields 0.
    '''
    max_value_index = 0
    max_value = None
    for index, value in enumerate(vec):
        # Seed the running maximum with the first element rather than with
        # 0, so that vectors without any positive entry are handled too.
        if max_value is None or value > max_value:
            max_value = value
            max_value_index = index
    return max_value_index
def evaluate(network, test_data_set, test_labels):
    '''
    Compute the error ratio of a network on a test set.
    network: object whose predict(sample) returns an output vector
    test_data_set: sequence of input samples
    test_labels: sequence of label vectors, parallel to test_data_set
    Returns the fraction of samples for which the position of the largest
    label entry differs from the position of the largest predicted entry.
    '''
    sample_count = len(test_data_set)
    mistakes = sum(
        1
        for idx in range(sample_count)
        if get_result(test_labels[idx])
        != get_result(network.predict(test_data_set[idx]))
    )
    return float(mistakes) / float(sample_count)
def now():
    '''
    Return the current date and time as a string formatted with '%c'.
    '''
    current = datetime.now()
    return current.strftime('%c')
def train_and_evaluate():
    '''
    Train a network epoch by epoch until the test error stops improving.
    After every epoch the loss on the last training sample is printed.
    After every second epoch the error ratio on the test set is printed,
    and training stops as soon as that ratio is higher than the one
    measured at the previous evaluation.
    '''
    last_error_ratio = 1.0
    epoch = 0
    train_data_set, train_labels = transpose(get_training_data_set())
    test_data_set, test_labels = transpose(get_test_data_set())
    # NOTE(review): presumably the layer sizes (784 inputs, 100 hidden,
    # 10 outputs) -- confirm against Network in fc.
    network = Network([784, 100, 10])
    while True:
        epoch += 1
        # NOTE(review): 0.01 and 1 look like the learning rate and the number
        # of passes per call -- confirm against Network.train in fc.
        network.train(train_labels, train_data_set, 0.01, 1)
        # print(...) with a single parenthesised argument prints the same
        # text as a Python 2 print statement and is also valid on Python 3.
        print('%s epoch %d finished, loss %f' % (now(), epoch,
            network.loss(train_labels[-1], network.predict(train_data_set[-1]))))
        if epoch % 2 == 0:
            error_ratio = evaluate(network, test_data_set, test_labels)
            print('%s after epoch %d, error ratio is %f' % (now(), epoch, error_ratio))
            if error_ratio > last_error_ratio:
                break
            last_error_ratio = error_ratio
# Run the training loop only when this file is executed as a script, not
# when it is imported as a module.
if __name__ == '__main__':
    train_and_evaluate()