import torch
import torch.nn as nn
import torch.nn.functional as F


class LSTM_PM(nn.Module):
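    """Convolutional LSTM Pose Machine: ConvNet1 encodes the first frame into
    an initial heatmap, ConvNet2 extracts shared per-frame features, a
    convolutional LSTM carries state across frames, and ConvNet3 decodes each
    hidden state into joint heatmaps."""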
    def __init__(self, n_map = (13 + 1), temporal = 5):
        # n_map: presumably 13 joint heatmaps plus 1 background map;
        # temporal: number of frames per input clip
super(LSTM_PM, self).__init__()
self.n_map = n_map
self.temporal = temporal
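
        # ConvNet1: encodes the first frame into the initial heatmap (n_map * 45 * 45)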
self.convnet1_conv_1 = nn.Conv2d(3, 128, kernel_size = 9, padding = 4)
self.convnet1_conv_2 = nn.Conv2d(128, 128, kernel_size = 9, padding = 4)
self.convnet1_conv_3 = nn.Conv2d(128, 128, kernel_size = 9, padding = 4)
self.convnet1_conv_4 = nn.Conv2d(128, 32, kernel_size = 5, padding = 2)
self.convnet1_conv_5 = nn.Conv2d(32, 512, kernel_size = 9, padding = 4)
self.convnet1_conv_6 = nn.Conv2d(512, 512, kernel_size = 1)
self.convnet1_conv_7 = nn.Conv2d(512, self.n_map, kernel_size = 1)
self.convnet1_pool_1 = nn.MaxPool2d(kernel_size = 3, stride = 2)
self.convnet1_pool_2 = nn.MaxPool2d(kernel_size = 3, stride = 2)
self.convnet1_pool_3 = nn.MaxPool2d(kernel_size = 3, stride = 2)
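
        # ConvNet2: shared per-frame feature encoder (32 * 45 * 45)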
self.convnet2_conv_1 = nn.Conv2d(3, 128, kernel_size = 9, padding = 4)
self.convnet2_conv_2 = nn.Conv2d(128, 128, kernel_size = 9, padding = 4)
self.convnet2_conv_3 = nn.Conv2d(128, 128, kernel_size = 9, padding = 4)
self.convnet2_conv_4 = nn.Conv2d(128, 32, kernel_size = 5, padding = 2)
self.convnet2_pool_1 = nn.MaxPool2d(kernel_size = 3, stride = 2)
self.convnet2_pool_2 = nn.MaxPool2d(kernel_size = 3, stride = 2)
self.convnet2_pool_3 = nn.MaxPool2d(kernel_size = 3, stride = 2)
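
        # ConvNet3: decodes the LSTM hidden state into a heatmap (n_map * 45 * 45)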
self.convnet3_conv_1 = nn.Conv2d(48, 128, kernel_size = 11, padding = 5)
self.convnet3_conv_2 = nn.Conv2d(128, 128, kernel_size = 11, padding = 5)
self.convnet3_conv_3 = nn.Conv2d(128, 128, kernel_size = 11, padding = 5)
self.convnet3_conv_4 = nn.Conv2d(128, 128, kernel_size = 1, padding = 0)
self.convnet3_conv_5 = nn.Conv2d(128, self.n_map, kernel_size = 1, padding = 0)
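
        # Gate convolutions for the initial LSTM step (input-to-state only)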
self.lstm1_gx = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1)
self.lstm1_ix = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1)
self.lstm1_ox = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1)
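
        # Gate convolutions for the common ConvLSTM steps
        # (*_x: input-to-state with bias, *_h: state-to-state without bias)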
self.lstm2_ix = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1, bias = True)
self.lstm2_ih = nn.Conv2d(48, 48, kernel_size = 3, padding = 1, bias = False)
self.lstm2_fx = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1, bias = True)
self.lstm2_fh = nn.Conv2d(48, 48, kernel_size = 3, padding = 1, bias = False)
self.lstm2_ox = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1, bias = True)
self.lstm2_oh = nn.Conv2d(48, 48, kernel_size = 3, padding = 1, bias = False)
self.lstm2_gx = nn.Conv2d(32 + 1 + self.n_map, 48, kernel_size = 3, padding = 1, bias = True)
self.lstm2_gh = nn.Conv2d(48, 48, kernel_size = 3, padding = 1, bias = False)
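
        # Downsample the 368 * 368 central Gaussian map to 45 * 45
        # (floor((368 - 9) / 8) + 1 = 45) to match the feature resolution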
self.central_map_pooling = nn.AvgPool2d(kernel_size = 9, stride = 8)

    def convnet1(self, image):
        """
        ConvNet 1: Initial heatmap generator network
        Input:
            Image -> 3 * 368 * 368
        Output:
            Initial heatmap -> n_map * 45 * 45
        """
x = self.convnet1_pool_1(F.relu(self.convnet1_conv_1(image)))
x = self.convnet1_pool_2(F.relu(self.convnet1_conv_2(x)))
x = self.convnet1_pool_3(F.relu(self.convnet1_conv_3(x)))
x = F.relu(self.convnet1_conv_4(x))
x = F.relu(self.convnet1_conv_5(x))
x = F.relu(self.convnet1_conv_6(x))
x = self.convnet1_conv_7(x)
return x

    def convnet2(self, image):
        """
        ConvNet 2: Common feature encoder network
        Input:
            Image -> 3 * 368 * 368
        Output:
            Features -> 32 * 45 * 45
        """
x = self.convnet2_pool_1(F.relu(self.convnet2_conv_1(image)))
x = self.convnet2_pool_2(F.relu(self.convnet2_conv_2(x)))
x = self.convnet2_pool_3(F.relu(self.convnet2_conv_3(x)))
x = F.relu(self.convnet2_conv_4(x))
return x

    def convnet3(self, hide_t):
        """
        ConvNet 3: Prediction generator network
        Input:
            Hidden state (t) -> 48 * 45 * 45
        Output:
            Heatmap -> n_map * 45 * 45
        """
x = F.relu(self.convnet3_conv_1(hide_t))
x = F.relu(self.convnet3_conv_2(x))
x = F.relu(self.convnet3_conv_3(x))
x = F.relu(self.convnet3_conv_4(x))
x = self.convnet3_conv_5(x)
return x

    def lstm(self, x, hide_t_1, cell_t_1):
        """
        Common (conv) LSTM unit
        Inputs:
            x -> (32 + n_map + 1) * 45 * 45
            Hidden state (t-1) -> 48 * 45 * 45
            Cell state (t-1) -> 48 * 45 * 45
        Outputs:
            Cell state (t) -> 48 * 45 * 45
            Hidden state (t) -> 48 * 45 * 45
        """
# Input gate
it = torch.sigmoid(self.lstm2_ix(x) + self.lstm2_ih(hide_t_1))
# Forget gate
ft = torch.sigmoid(self.lstm2_fx(x) + self.lstm2_fh(hide_t_1))
# Output gate
ot = torch.sigmoid(self.lstm2_ox(x) + self.lstm2_oh(hide_t_1))
        # Candidate cell state: g = c~(t)
gt = torch.tanh(self.lstm2_gx(x) + self.lstm2_gh(hide_t_1))
cell = ft * cell_t_1 + it * gt
hidden = ot * torch.tanh(cell)
return cell, hidden

    def lstm0(self, x):
        """
        Initial (conv) LSTM unit
        Input:
            x -> (32 + n_map + 1) * 45 * 45
        Outputs:
            Cell state -> 48 * 45 * 45
            Hidden state -> 48 * 45 * 45
        """
# Input gate
ix = torch.sigmoid(self.lstm1_ix(x))
# Output gate
ox = torch.sigmoid(self.lstm1_ox(x))
        # Candidate cell state: g = c~(t)
gx = torch.tanh(self.lstm1_gx(x))
        # The initial step has no previous cell state C(t-1), so no forget gate is needed
cell = torch.tanh(gx * ix)
hidden = ox * cell
return cell, hidden

    def initial_stage(self, image, centralmap):
        """
        Initial stage
        Inputs:
            Image -> 3 * 368 * 368
            Central Gaussian map -> 1 * 368 * 368
        Outputs:
            Initial heatmap -> n_map * 45 * 45
            Heatmap -> n_map * 45 * 45
            Cell state -> 48 * 45 * 45
            Hidden state -> 48 * 45 * 45
            Pooled central Gaussian map -> 1 * 45 * 45
        """
initial_heatmap = self.convnet1(image)
features = self.convnet2(image)
centralmap = self.central_map_pooling(centralmap)
        x = torch.cat([initial_heatmap, features, centralmap], dim = 1)  # LSTM input at the initial step
cell, hidden = self.lstm0(x)
heatmap = self.convnet3(hidden)
return initial_heatmap, heatmap, cell, hidden, centralmap

    def common_stage(self, image, centralmap, heatmap, cell_t_1, hide_t_1):
        """
        Common stage
        Inputs:
            Image -> 3 * 368 * 368
            Central Gaussian map -> 1 * 45 * 45
            Heatmap (t-1) -> n_map * 45 * 45
            Cell state (t-1) -> 48 * 45 * 45
            Hidden state (t-1) -> 48 * 45 * 45
        Outputs:
            New heatmap -> n_map * 45 * 45
            Cell state -> 48 * 45 * 45
            Hidden state -> 48 * 45 * 45
        """
features = self.convnet2(image)
        x = torch.cat([heatmap, features, centralmap], dim = 1)  # LSTM input at step t
cell, hidden = self.lstm(x, hide_t_1, cell_t_1)
new_heat_map = self.convnet3(hidden)
return new_heat_map, cell, hidden

    def forward(self, images, centralmap):
        """
        Forward pass over a temporal sequence
        Inputs:
            Images -> (temporal * channels) * w * h = (T * 3) * 368 * 368
            Central Gaussian map -> 1 * 368 * 368
        Output:
            Heatmaps -> (T + 1) * n_map * 45 * 45 (the +1 is the initial heatmap)
        """
heat_maps = []
        # Select the channels of the first frame for every sequence in the batch
image = images[:, 0:3, :, :]
        # Run the first frame through the initial stage to get the initial and first refined heatmaps
initial_heatmap, heatmap, cell, hide, centralmap = self.initial_stage(image, centralmap)
heat_maps.append(initial_heatmap)
heat_maps.append(heatmap)
        # Run the remaining frames through the common (recurrent) stage
for i in range(1, self.temporal):
image = images[:, (3 * i):(3 * i + 3), :, :]
heatmap, cell, hide = self.common_stage(image, centralmap, heatmap, cell, hide)
heat_maps.append(heatmap)
return heat_maps
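

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original model code): builds
# the network with its default 14 maps and runs one random 5-frame clip to
# check the documented output shapes. The batch size, tensor values, and the
# __main__ guard are illustrative assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    temporal = 5
    model = LSTM_PM(n_map = 14, temporal = temporal)
    images = torch.randn(2, temporal * 3, 368, 368)   # (batch, T * 3, H, W)
    centralmap = torch.randn(2, 1, 368, 368)          # central Gaussian map
    with torch.no_grad():
        heat_maps = model(images, centralmap)
    # temporal + 1 maps: the initial heatmap plus one refined map per frame
    assert len(heat_maps) == temporal + 1
    for hm in heat_maps:
        assert hm.shape == (2, 14, 45, 45)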