capsNet.py
"""
License: Apache-2.0
Author: Huadong Liao
E-mail: [email protected]
"""
import tensorflow as tf
from config import cfg
from utils import get_batch_data
from utils import softmax
from utils import reduce_sum
from capsLayer import CapsLayer
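# small constant added inside sqrt() so the gradient stays finite when a capsule length is zero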
epsilon = 1e-9
class CapsNet(object):
def __init__(self, is_training=True, height=28, width=28, channels=1, num_label=10):
"""
Args:
height: Integer, the height of inputs.
width: Integer, the width of inputs.
channels: Integer, the channels of inputs.
num_label: Integer, the category number.
"""
self.height = height
self.width = width
self.channels = channels
self.num_label = num_label
self.graph = tf.Graph()
with self.graph.as_default():
if is_training:
self.X, self.labels = get_batch_data(cfg.dataset, cfg.batch_size, cfg.num_threads)
self.Y = tf.one_hot(self.labels, depth=self.num_label, axis=1, dtype=tf.float32)
self.build_arch()
self.loss()
self._summary()
# t_vars = tf.trainable_variables()
self.global_step = tf.Variable(0, name='global_step', trainable=False)
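                # Adam with TensorFlow's default hyperparameters (learning rate 0.001)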
self.optimizer = tf.train.AdamOptimizer()
self.train_op = self.optimizer.minimize(self.total_loss, global_step=self.global_step)
else:
self.X = tf.placeholder(tf.float32, shape=(cfg.batch_size, self.height, self.width, self.channels))
self.labels = tf.placeholder(tf.int32, shape=(cfg.batch_size, ))
                # one-hot encode the fed labels so Y has the same layout as in the training path
                self.Y = tf.reshape(tf.one_hot(self.labels, depth=self.num_label, axis=1, dtype=tf.float32),
                                    shape=(cfg.batch_size, self.num_label, 1))
self.build_arch()
        tf.logging.info('Setting up the main structure')
def build_arch(self):
with tf.variable_scope('Conv1_layer'):
# Conv1, return tensor with shape [batch_size, 20, 20, 256]
conv1 = tf.contrib.layers.conv2d(self.X, num_outputs=256,
kernel_size=9, stride=1,
padding='VALID')
# Primary Capsules layer, return tensor with shape [batch_size, 1152, 8, 1]
with tf.variable_scope('PrimaryCaps_layer'):
primaryCaps = CapsLayer(num_outputs=32, vec_len=8, with_routing=False, layer_type='CONV')
caps1 = primaryCaps(conv1, kernel_size=9, stride=2)
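            # 1152 capsules = 32 channel types * 6 * 6 spatial positions
            # (9x9 kernel, stride 2, VALID padding over the 20x20 conv1 output)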
# DigitCaps layer, return shape [batch_size, 10, 16, 1]
with tf.variable_scope('DigitCaps_layer'):
digitCaps = CapsLayer(num_outputs=self.num_label, vec_len=16, with_routing=True, layer_type='FC')
self.caps2 = digitCaps(caps1)
        # Decoder structure from Fig. 2 of the paper
        # 1. Masking
with tf.variable_scope('Masking'):
# a). calc ||v_c||, then do softmax(||v_c||)
# [batch_size, 10, 16, 1] => [batch_size, 10, 1, 1]
self.v_length = tf.sqrt(reduce_sum(tf.square(self.caps2),
axis=2, keepdims=True) + epsilon)
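            # the length of each capsule's output vector acts as that class's score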
self.softmax_v = softmax(self.v_length, axis=1)
# assert self.softmax_v.get_shape() == [cfg.batch_size, self.num_label, 1, 1]
# b). pick out the index of max softmax val of the 10 caps
# [batch_size, 10, 1, 1] => [batch_size] (index)
self.argmax_idx = tf.to_int32(tf.argmax(self.softmax_v, axis=1))
# assert self.argmax_idx.get_shape() == [cfg.batch_size, 1, 1]
self.argmax_idx = tf.reshape(self.argmax_idx, shape=(cfg.batch_size, ))
# Method 1.
if not cfg.mask_with_y:
# c). indexing
                # Indexing self.caps2 directly with argmax_idx is awkward, so gather
                # the predicted capsule with an explicit loop over the batch.
                masked_v = []
                for b in range(cfg.batch_size):
                    v = self.caps2[b][self.argmax_idx[b], :]
                    masked_v.append(tf.reshape(v, shape=(1, 1, 16, 1)))
self.masked_v = tf.concat(masked_v, axis=0)
assert self.masked_v.get_shape() == [cfg.batch_size, 1, 16, 1]
# Method 2. masking with true label, default mode
else:
self.masked_v = tf.multiply(tf.squeeze(self.caps2), tf.reshape(self.Y, (-1, self.num_label, 1)))
self.v_length = tf.sqrt(reduce_sum(tf.square(self.caps2), axis=2, keepdims=True) + epsilon)
        # 2. Reconstruct the MNIST images with 3 FC layers
# [batch_size, 1, 16, 1] => [batch_size, 16] => [batch_size, 512]
with tf.variable_scope('Decoder'):
vector_j = tf.reshape(self.masked_v, shape=(cfg.batch_size, -1))
fc1 = tf.contrib.layers.fully_connected(vector_j, num_outputs=512)
fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
self.decoded = tf.contrib.layers.fully_connected(fc2,
num_outputs=self.height * self.width * self.channels,
activation_fn=tf.sigmoid)
def loss(self):
# 1. The margin loss
# [batch_size, 10, 1, 1]
# max_l = max(0, m_plus-||v_c||)^2
max_l = tf.square(tf.maximum(0., cfg.m_plus - self.v_length))
# max_r = max(0, ||v_c||-m_minus)^2
max_r = tf.square(tf.maximum(0., self.v_length - cfg.m_minus))
assert max_l.get_shape() == [cfg.batch_size, self.num_label, 1, 1]
# reshape: [batch_size, 10, 1, 1] => [batch_size, 10]
max_l = tf.reshape(max_l, shape=(cfg.batch_size, -1))
max_r = tf.reshape(max_r, shape=(cfg.batch_size, -1))
        # calc T_c: [batch_size, 10]
        # T_c = Y, the one-hot encoding of the true label
        T_c = self.Y
# [batch_size, 10], element-wise multiply
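        # margin loss from the paper:
        # L_c = T_c * max(0, m_plus - ||v_c||)^2 + lambda_val * (1 - T_c) * max(0, ||v_c|| - m_minus)^2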
L_c = T_c * max_l + cfg.lambda_val * (1 - T_c) * max_r
self.margin_loss = tf.reduce_mean(tf.reduce_sum(L_c, axis=1))
# 2. The reconstruction loss
        origin = tf.reshape(self.X, shape=(cfg.batch_size, -1))
        squared = tf.square(self.decoded - origin)
self.reconstruction_err = tf.reduce_mean(squared)
# 3. Total loss
        # The paper uses the sum of squared errors as the reconstruction loss, but
        # reduce_mean above gives the mean squared error instead. To stay in line
        # with the paper, the regularization scale should therefore be
        # 0.0005 * 784 = 0.392.
self.total_loss = self.margin_loss + cfg.regularization_scale * self.reconstruction_err
# Summary
def _summary(self):
train_summary = []
train_summary.append(tf.summary.scalar('train/margin_loss', self.margin_loss))
train_summary.append(tf.summary.scalar('train/reconstruction_loss', self.reconstruction_err))
train_summary.append(tf.summary.scalar('train/total_loss', self.total_loss))
recon_img = tf.reshape(self.decoded, shape=(cfg.batch_size, self.height, self.width, self.channels))
train_summary.append(tf.summary.image('reconstruction_img', recon_img))
self.train_summary = tf.summary.merge(train_summary)
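        # Note: this is the count of correct predictions in the batch, not a ratio;
        # divide by cfg.batch_size to get the accuracy as a fraction.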
correct_prediction = tf.equal(tf.to_int32(self.labels), self.argmax_idx)
self.accuracy = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))
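

# Usage sketch: build the inference graph and run one forward pass on random data.
# This is only an illustration; it assumes `cfg.batch_size` is defined in config.py
# (as used above) and omits restoring trained weights from a checkpoint.
if __name__ == '__main__':
    import numpy as np

    model = CapsNet(is_training=False)
    with model.graph.as_default():
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # random images standing in for a real MNIST batch
            images = np.random.rand(cfg.batch_size, 28, 28, 1).astype(np.float32)
            caps_out = sess.run(model.caps2, feed_dict={model.X: images})
            print('DigitCaps output shape:', caps_out.shape)  # (batch_size, 10, 16, 1)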