-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
265 lines (229 loc) · 11.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import inspect
import math
from dataclasses import dataclass
from pathlib import Path
import pytorch_lightning as pl
import torch
import torch.nn as nn
from matplotlib import pyplot as plt
from torch.nn import functional as F
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model defined in this file.

    Field order is part of the interface (positional construction), so it
    is kept fixed.
    """

    block_size: int = 512  # longest sequence the position-embedding table supports
    vocab_size: int = 1024  # number of token ids (embedding rows / output classes)
    n_layer: int = 3  # number of transformer blocks
    n_head: int = 3  # attention heads per block; must divide n_embd
    n_embd: int = 48  # model width; GPT also requires it to be even
    d_in: int = 5  # per-step input width: d_in - 1 numeric features plus one token id
    dropout: float = 0.1  # dropout probability used throughout the model
    bias: bool = True  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention computed with scaled_dot_product_attention."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        width = config.n_embd
        use_bias = config.bias
        # one fused projection yields queries, keys and values for all heads
        self.c_attn = nn.Linear(width, 3 * width, bias=use_bias)
        # maps the concatenated head outputs back to the model width
        self.c_proj = nn.Linear(width, width, bias=use_bias)
        # attn_dropout is registered but not called in forward: dropout on the
        # attention weights is applied inside scaled_dot_product_attention via dropout_p
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = width
        self.dropout = config.dropout

    def forward(self, x):
        """Attend causally over the time axis of ``x`` (B, T, C) and return (B, T, C)."""
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head
        query, key, value = self.c_attn(x).split(self.n_embd, dim=2)

        def to_heads(t):
            # (B, T, C) -> (B, nh, T, hs): heads become a batch-like dimension
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        query, key, value = to_heads(query), to_heads(key), to_heads(value)
        # dropout on attention weights only while training
        p_drop = self.dropout if self.training else 0
        attended = F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=p_drop, is_causal=True
        )
        # (B, nh, T, hs) -> (B, T, C): lay the head outputs side by side again
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))
class MLP(nn.Module):
    """Position-wise feed-forward sub-layer: widen 4x, tanh, project back, dropout."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.tanh = nn.Tanh()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward transform to every position of ``x``."""
        return self.dropout(self.c_proj(self.tanh(self.c_fc(x))))
class Block(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each wrapped in a residual add."""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates, in that order."""
        for norm, sublayer in ((self.ln_1, self.attn), (self.ln_2, self.mlp)):
            x = x + sublayer(norm(x))
        return x
class GPT(nn.Module):
    """GPT-style causal transformer over mixed numeric/token time steps.

    Every input time step is a vector of ``config.d_in`` values. The first
    ``d_in - 1`` entries are numeric features, projected by a linear layer to
    ``n_embd // 2`` dims. The last entry is a token id, looked up in an
    ``n_embd // 2``-dim embedding table. The two halves are concatenated, a
    learned position embedding is added, and the result goes through
    ``n_layer`` transformer blocks and a linear head over ``vocab_size`` classes.
    """

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        # the numeric and token embeddings each take half of the model width
        assert config.n_embd % 2 == 0, "n_embd must be even"
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Linear(config.d_in - 1, config.n_embd // 2, bias=config.bias),
            wte2=nn.Embedding(config.vocab_size, config.n_embd // 2),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # No weight tying: the input token table (wte2) is only n_embd // 2 wide,
        # so it cannot share a weight matrix with lm_head (vocab_size x n_embd).

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, mask=None):
        """Run the model and optionally compute the cross-entropy loss.

        Args:
            idx: floating-point tensor of shape ``(b, t, d_in)``. Columns
                ``[:-1]`` are numeric features; column ``-1`` holds token ids
                (cast to long before the embedding lookup).
            targets: optional token-id targets, presumably shape ``(b, t)``
                (TODO confirm against the datamodule). Entries equal to -1
                are ignored by the loss.
            mask: optional boolean tensor indexing the same positions as the
                logits/targets. Positions where it is True are dropped before
                the loss is computed.

        Returns:
            ``(logits, loss)``.
            - With targets: logits cover every position, shape
              ``(b, t, vocab_size)``, or ``(n_kept, vocab_size)`` when a mask
              is given; loss is a scalar tensor.
            - Without targets: only the last position is projected, so logits
              are ``(b, 1, vocab_size)`` and loss is None.
        """
        device = idx.device
        b, t, _ = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device)  # shape (t)

        # forward the GPT model itself
        num_emb = self.transformer.wte(idx[:, :, :-1])  # numerical "embeddings", (b, t, n_embd // 2)
        tok_emb = self.transformer.wte2(idx[:, :, -1].to(dtype=torch.long))  # token embeddings, (b, t, n_embd // 2)
        emb = torch.cat((num_emb, tok_emb), dim=-1)  # (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (t, n_embd), broadcast over b
        x = self.transformer.drop(emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            if mask is not None:
                logits = logits[~mask]
                targets = targets[~mask]
            # reshape (not view) so non-contiguous targets are accepted too
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim
            loss = None
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, exog, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` new time steps.

        Each new step pairs the next row of ``exog`` with a token sampled from
        the model, and is fed back in as context for the following step.

        Args:
            idx: conditioning sequence of shape ``(b, t, d_in)``, in the same
                layout as :meth:`forward` (numeric features, then token id).
            max_new_tokens: number of steps to sample.
            exog: numeric features for the future steps, shape
                ``(b, >= max_new_tokens, d_in - 1)``. Row ``i`` is paired
                with the ``i``-th sampled token.
            temperature: the last-step logits are divided by this before softmax.
            top_k: if given, sample only among the ``top_k`` highest logits.

        Returns:
            ``idx`` extended to shape ``(b, t + max_new_tokens, d_in)``.

        Call ``model.eval()`` first if dropout should be disabled.
        """
        assert exog.size(2) == self.config.d_in - 1, f"Expected {self.config.d_in - 1} exogenous features, got {exog.size(2)}"
        assert exog.size(1) >= max_new_tokens, f"Expected exogenous features to cover at least {max_new_tokens} time steps, got {exog.size(1)}"
        for i in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)  # (b, 1), long
            # new time step: this step's exogenous features followed by the sampled token id
            step = torch.cat((exog[:, i, :], idx_next.to(exog.dtype)), dim=1)  # (b, d_in)
            # insert the time axis at dim 1 so each batch row gets its own step;
            # unsqueeze(0) only yielded the right shape when b == 1
            idx = torch.cat((idx, step.unsqueeze(1).to(idx.dtype)), dim=1)
        return idx
class MLControlsSim(pl.LightningModule):
    """Lightning module that trains a GPT model and plots sampled rollouts during validation."""

    def __init__(
        self,
        n_layers: int,
        n_head: int,
        n_embd: int,
        lr: float,
        weight_decay: float,
    ):
        super().__init__()
        self.lr = lr
        self.weight_decay = weight_decay
        self.model = GPT(GPTConfig(n_layer=n_layers, n_head=n_head, n_embd=n_embd))
        self.save_hyperparameters()

    def forward(self, *args):
        """Delegate to the wrapped GPT model; returns ``(logits, loss)``."""
        logits, loss = self.model(*args)
        return logits, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        _, loss = self(*batch)
        self.log('train_loss', loss, on_step=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss; on every 5th batch also save a rollout plot."""
        _, loss = self(*batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        if batch_idx % 5 == 0:
            # pass batch_idx so plots from different batches don't overwrite each other
            self.plot_predictions(*batch, batch_idx=batch_idx)
        return loss

    def plot_predictions(self, x, y, mask, batch_idx=None) -> None:
        """Sample 20 steps for the first sequence in the batch and save a true-vs-pred plot.

        The first ``dm.CONTEXT_SIZE`` steps of ``x[0]`` are the conditioning
        context. The numeric features of the remaining steps are supplied to
        ``generate`` as exogenous inputs. ``mask`` is accepted for signature
        compatibility with the batch layout and is not used.

        The plot is saved under ``<log_dir>/val_predictions/`` as
        ``{epoch:03}.png``, or ``{epoch:03}_{batch_idx:04}.png`` when
        ``batch_idx`` is given.
        """
        dm = self.trainer.datamodule
        ctx = dm.CONTEXT_SIZE
        idx = x[[0], :ctx, :]
        exog = x[[0], ctx:, :-1]
        preds = self.model.generate(
            idx=idx,
            exog=exog,
            max_new_tokens=20,
        )
        # the last column of each step holds the token id
        y_pred = dm.tokenizer.decode(preds[0, :, -1].cpu().numpy().astype(int))
        # NOTE(review): y_pred starts at the first context step; if y holds
        # next-step targets, y_true is shifted by one step — confirm with the datamodule
        y_true = y[0, :].cpu().numpy().astype(int)[:len(y_pred)]
        y_true = dm.tokenizer.decode(y_true)

        # use a dedicated figure (not global pyplot state) and always close it
        fig, ax = plt.subplots()
        try:
            ax.plot(y_true, label="True")
            ax.plot(y_pred, label="Pred")
            ax.set_xlabel("Time step")
            ax.set_ylabel("LatAccel")
            ax.legend()
            stem = f"{self.current_epoch:03}" if batch_idx is None else f"{self.current_epoch:03}_{batch_idx:04}"
            save_path = Path(self.logger.log_dir) / "val_predictions" / f"{stem}.png"
            save_path.parent.mkdir(parents=True, exist_ok=True)
            fig.savefig(save_path)
        finally:
            plt.close(fig)

    def configure_optimizers(self):
        """Build AdamW with weight decay applied only to parameters with >= 2 dims."""
        # all trainable parameters
        param_dict = {pn: p for pn, p in self.named_parameters() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for p in param_dict.values() if p.dim() >= 2]
        nodecay_params = [p for p in param_dict.values() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': self.weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and self.device.type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=self.lr, betas=(0.9, 0.95), **extra_args)
        print(f"using fused AdamW: {use_fused}")
        return optimizer