"""Observation encoders.

Defines a convolutional ``PixelEncoder`` for image observations, a
pass-through ``IdentityEncoder`` for flat observations, and the
``make_encoder`` factory that selects between them by name.
"""
import torch
import torch.nn as nn
def tie_weights(src, trg):
    """Make layer ``trg`` share the parameters of layer ``src``.

    Both layers must be of exactly the same type. After the call,
    ``trg.weight`` and ``trg.bias`` are the very same objects as the ones
    held by ``src`` (nothing is copied).
    """
    assert type(trg) == type(src)
    # Same assignment order as before: weight first, then bias.
    for attr_name in ('weight', 'bias'):
        setattr(trg, attr_name, getattr(src, attr_name))
# Lookup tables: number of conv layers -> side length of the (square) feature
# map produced by the conv stack. PixelEncoder uses the looked-up value to size
# its linear layer (num_filters * out_dim * out_dim) and chooses the table from
# the last entry of obs_shape: 100 -> OUT_DIM_100, 84 -> OUT_DIM_84,
# anything else -> OUT_DIM.
#
# Values that work with Finism SAC (kept for reference, not active):
#   OUT_DIM     = {2: 39, 4: 35, 6: 31}
#   OUT_DIM_84  = {2: 39, 4: 35, 6: 31}
#   OUT_DIM_100 = {2: 39, 4: 43, 6: 31}
#
# Values that work with RAD SAC (active).
# NOTE(review): with unpadded 3x3 convs (first stride 2, the rest stride 1) an
# 84-pixel input would yield 39/35/31 and a 100-pixel input 47/43/39 for
# 2/4/6 layers, so some entries below do not match that arithmetic -- confirm
# they correspond to the image size actually fed to the encoder.
OUT_DIM = {
    2: 39,
    4: 35,
    6: 31,
}
OUT_DIM_84 = {
    2: 29,
    4: 35,
    6: 21,
}
OUT_DIM_100 = {
    4: 47,
}
class PixelEncoder(nn.Module):
    """Convolutional encoder that maps pixel observations to a feature vector.

    The network is a stack of ``num_layers`` 3x3 convolutions (the first with
    stride 2, the others with stride 1), each followed by ReLU, then a linear
    projection to ``feature_dim`` and a LayerNorm. Unless ``output_logits`` is
    true, the normalised features are finally squashed with ``tanh``.

    Activations of the most recent forward pass are stored in ``self.outputs``
    so that ``log`` can report them.
    """

    def __init__(self, obs_shape, feature_dim, num_layers=2, num_filters=32, output_logits=False):
        """Build the conv stack, the projection and the normalisation layer.

        ``obs_shape`` must have three entries; its first entry is used as the
        number of input channels and its last entry selects which output-size
        table is consulted for ``num_layers``.
        """
        super().__init__()
        assert len(obs_shape) == 3

        self.obs_shape = obs_shape
        self.feature_dim = feature_dim
        self.num_layers = num_layers

        # Idea to try: two 5x5 convs with stride 2. With "same" padding that
        # would reduce 84 to 21, so with "valid" padding it should end up even
        # smaller than 21.
        stem = nn.Conv2d(obs_shape[0], num_filters, 3, stride=2)
        self.convs = nn.ModuleList([stem])
        for _ in range(num_layers - 1):
            self.convs.append(nn.Conv2d(num_filters, num_filters, 3, stride=1))

        # Pick the table by observation width; only the 100-pixel table is
        # guarded by an assertion, the others fail with KeyError on a miss.
        width = obs_shape[-1]
        if width == 100:
            assert num_layers in OUT_DIM_100
            dim_table = OUT_DIM_100
        elif width == 84:
            dim_table = OUT_DIM_84
        else:
            dim_table = OUT_DIM
        out_dim = dim_table[num_layers]

        flat_features = num_filters * out_dim * out_dim
        self.fc = nn.Linear(flat_features, self.feature_dim)
        self.ln = nn.LayerNorm(self.feature_dim)

        self.outputs = {}
        self.output_logits = output_logits

    def reparameterize(self, mu, logstd):
        """Return ``mu + eps * exp(logstd)`` with ``eps`` drawn from N(0, 1)."""
        std = torch.exp(logstd)
        noise = torch.randn_like(std)
        return mu + noise * std

    def forward_conv(self, obs):
        """Run the conv stack and return the activations flattened per sample.

        Observations whose maximum exceeds 1 are divided by 255 first.
        """
        if obs.max() > 1.:
            obs = obs / 255.
        self.outputs['obs'] = obs

        feat = torch.relu(self.convs[0](obs))
        self.outputs['conv1'] = feat

        for layer_idx in range(1, self.num_layers):
            layer = self.convs[layer_idx]
            feat = torch.relu(layer(feat))
            self.outputs['conv%s' % (layer_idx + 1)] = feat

        batch_size = feat.size(0)
        return feat.view(batch_size, -1)

    def forward(self, obs, detach=False):
        """Encode ``obs``; with ``detach`` the conv features are cut from the graph."""
        flat = self.forward_conv(obs)
        if detach:
            # Gradients then only reach the linear and LayerNorm layers.
            flat = flat.detach()

        projected = self.fc(flat)
        self.outputs['fc'] = projected

        normed = self.ln(projected)
        self.outputs['ln'] = normed

        if self.output_logits:
            return normed

        squashed = torch.tanh(normed)
        self.outputs['tanh'] = squashed
        return squashed

    def copy_conv_weights_from(self, source):
        """Tie this encoder's conv layers to those of ``source``."""
        # Only the conv layers are shared; fc and ln stay separate.
        for layer_idx in range(self.num_layers):
            tie_weights(src=source.convs[layer_idx], trg=self.convs[layer_idx])

    def log(self, L, step, log_freq):
        """Send stored activations and layer parameters to logger ``L``.

        Does nothing unless ``step`` is a multiple of ``log_freq``.
        """
        if step % log_freq != 0:
            return

        for name, value in self.outputs.items():
            L.log_histogram('train_encoder/%s_hist' % name, value, step)
            # Tensors with more than two dimensions are also logged as an
            # image, using the first sample only.
            if len(value.shape) > 2:
                L.log_image('train_encoder/%s_img' % name, value[0], step)

        for layer_idx in range(self.num_layers):
            L.log_param('train_encoder/conv%s' % (layer_idx + 1), self.convs[layer_idx], step)
        L.log_param('train_encoder/fc', self.fc, step)
        L.log_param('train_encoder/ln', self.ln, step)
class IdentityEncoder(nn.Module):
    """Pass-through encoder for flat (one-dimensional) observations.

    It has no layers of its own: ``forward`` hands the observation back
    untouched, and the weight-tying and logging hooks are no-ops.
    """

    def __init__(self, obs_shape, feature_dim, num_layers, num_filters, *args):
        """Record the feature size, which is the single entry of ``obs_shape``.

        ``feature_dim``, ``num_layers``, ``num_filters`` and any extra
        positional arguments are accepted but not used.
        """
        super().__init__()
        assert len(obs_shape) == 1
        self.feature_dim = obs_shape[0]

    def forward(self, obs, detach=False):
        """Return ``obs`` unchanged; ``detach`` is ignored."""
        return obs

    def copy_conv_weights_from(self, source):
        """Do nothing: there are no conv layers to tie."""
        return None

    def log(self, L, step, log_freq):
        """Do nothing: there is nothing to log."""
        return None
# Registry of encoder classes, keyed by the name accepted by make_encoder.
_AVAILABLE_ENCODERS = dict(
    pixel=PixelEncoder,
    identity=IdentityEncoder,
)
def make_encoder(
    encoder_type, obs_shape, feature_dim, num_layers, num_filters, output_logits=False
):
    """Instantiate the encoder registered under ``encoder_type``.

    ``encoder_type`` must be a key of ``_AVAILABLE_ENCODERS``. The remaining
    arguments are forwarded positionally, in the order given, to the selected
    encoder class, and the new instance is returned.
    """
    assert encoder_type in _AVAILABLE_ENCODERS
    encoder_cls = _AVAILABLE_ENCODERS[encoder_type]
    ctor_args = (obs_shape, feature_dim, num_layers, num_filters, output_logits)
    return encoder_cls(*ctor_args)