-
Notifications
You must be signed in to change notification settings - Fork 2.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/lightning #35
base: master
Are you sure you want to change the base?
Changes from 15 commits
6110298
08f5b9a
3fa57cd
e88f076
c0823ec
923b6fc
a5a6d1a
990c0c7
81650ae
fb37e03
fa10298
0ed3376
ebd40f1
1aa67ca
452a5ab
9b1e5a4
4817231
d91bb1c
492b79f
a796899
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,140 @@ | ||||||||||||
""" | ||||||||||||
Temporary benchmarking script while integrating Lightning, will remove before merge to master | ||||||||||||
""" | ||||||||||||
|
||||||||||||
import os | ||||||||||||
import time | ||||||||||||
import math | ||||||||||||
import logging | ||||||||||||
import argparse | ||||||||||||
|
||||||||||||
import numpy as np | ||||||||||||
import torch | ||||||||||||
from torch.utils.data import Dataset | ||||||||||||
from torch.utils.data.dataloader import DataLoader | ||||||||||||
import torch.backends.cudnn as cudnn | ||||||||||||
|
||||||||||||
from mingpt.model import GPT | ||||||||||||
from mingpt.lr_decay import WarmupCosineLearningRateDecay | ||||||||||||
from mingpt.utils import sample | ||||||||||||
|
||||||||||||
logger = logging.getLogger(__name__) | ||||||||||||
logging.basicConfig( | ||||||||||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", | ||||||||||||
datefmt="%m/%d/%Y %H:%M:%S", | ||||||||||||
level=logging.INFO, | ||||||||||||
) | ||||||||||||
|
||||||||||||
torch.backends.cudnn.benchmark = True # autotune kernels | ||||||||||||
|
||||||||||||
# ----------------------------------------------------------------------------- | ||||||||||||
import os | ||||||||||||
if int(os.environ.get('USE_LIGHTNING', 0)): | ||||||||||||
logging.info("USING LIGHTNING!!") | ||||||||||||
import pytorch_lightning as pl | ||||||||||||
else: | ||||||||||||
import mingpt.fake_lightning as pl | ||||||||||||
logging.info("using our humble trainer") | ||||||||||||
# ----------------------------------------------------------------------------- | ||||||||||||
|
||||||||||||
class Text8Dataset(Dataset): | ||||||||||||
""" | ||||||||||||
e.g. Text8 dataset is often used: http://mattmahoney.net/dc/textdata.html | ||||||||||||
Vocabulary is lowercase English characters and space for total of 27. | ||||||||||||
Training data: First 90M characters. | ||||||||||||
Validation data: First 5M characters out of the last 10M characters. | ||||||||||||
Testing data: Last 5M characters. | ||||||||||||
""" | ||||||||||||
|
||||||||||||
def __init__(self, data_path, block_size, crop=None, override_vocab=None): | ||||||||||||
|
||||||||||||
# load the data and crop it appropriately | ||||||||||||
with open(data_path, 'r') as f: | ||||||||||||
if crop is None: | ||||||||||||
data = f.read() | ||||||||||||
else: | ||||||||||||
f.seek(crop[0]) | ||||||||||||
data = f.read(crop[1]) | ||||||||||||
|
||||||||||||
# build a vocabulary from data or inherit it | ||||||||||||
vocab = sorted(list(set(data))) if override_vocab is None else override_vocab | ||||||||||||
data_size, vocab_size = len(data), len(vocab) | ||||||||||||
logging.info('data of crop %s has %d characters, vocab of size %d.' % (str(crop), data_size, vocab_size)) | ||||||||||||
|
||||||||||||
self.stoi = { ch:i for i,ch in enumerate(vocab) } | ||||||||||||
self.itos = { i:ch for i,ch in enumerate(vocab) } | ||||||||||||
self.block_size = block_size | ||||||||||||
self.vocab_size = vocab_size | ||||||||||||
self.data = data | ||||||||||||
self.vocab = vocab | ||||||||||||
|
||||||||||||
def __len__(self): | ||||||||||||
return len(self.data) // self.block_size | ||||||||||||
|
||||||||||||
def __getitem__(self, idx): | ||||||||||||
# attempt to fetch a chunk of (block_size + 1) items, but (block_size) will work too | ||||||||||||
chunk = self.data[idx*self.block_size : min(len(self.data), (idx+1)*self.block_size + 1)] | ||||||||||||
# map the string into a sequence of integers | ||||||||||||
ixes = [self.stoi[s] for s in chunk] | ||||||||||||
# if stars align (last idx and len(self.data) % self.block_size == 0), pad with -100, to skip training at the last position | ||||||||||||
if len(ixes) < self.block_size + 1: | ||||||||||||
assert len(ixes) == self.block_size # i believe this is the only way this could happen, make sure | ||||||||||||
ixes.append(-100) | ||||||||||||
dix = torch.tensor(ixes, dtype=torch.long) | ||||||||||||
return dix[:-1], dix[1:] | ||||||||||||
|
||||||||||||
# ----------------------------------------------------------------------------- | ||||||||||||
|
||||||||||||
parser = argparse.ArgumentParser() | ||||||||||||
parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for") | ||||||||||||
parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size to train with") | ||||||||||||
parser.add_argument('-l', '--block-size', type=int, default=128, help="block size for the model (length of window of context)") | ||||||||||||
parser.add_argument('-n', '--num-workers', type=int, default=0, help="number of workers for dataloading") | ||||||||||||
parser.add_argument('-g', '--num-gpus', type=int, default=1, help="number of gpus to train on") | ||||||||||||
parser.add_argument('-p', '--pin-memory', type=int, default=1, help="pin memory on dataloaders?") | ||||||||||||
parser.add_argument('-r', '--precision', type=int, default=32, help="fp precision to use, e.g. 32/16") | ||||||||||||
parser.add_argument('-o', '--default_root_dir', type=str, default='.', help="best model checkpoint will be written at this location") | ||||||||||||
args = parser.parse_args() | ||||||||||||
print(vars(args)) | ||||||||||||
|
||||||||||||
logging.info("preparing the data loaders") | ||||||||||||
# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER | ||||||||||||
train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(1e6))) | ||||||||||||
val_dataset = Text8Dataset('text8', args.block_size, crop=(int(90e6), int(1e5)), override_vocab=train_dataset.vocab) | ||||||||||||
test_dataset = Text8Dataset('text8', args.block_size, crop=(int(95e6), int(1e5)), override_vocab=train_dataset.vocab) | ||||||||||||
common = {'batch_size': args.batch_size, 'pin_memory': bool(args.pin_memory), 'num_workers': args.num_workers} | ||||||||||||
train_dataloader = DataLoader(train_dataset, shuffle=True, **common) | ||||||||||||
val_dataloader = DataLoader(val_dataset, shuffle=False, **common) | ||||||||||||
test_dataloader = DataLoader(test_dataset, shuffle=False, **common) | ||||||||||||
|
||||||||||||
logging.info("creating the model") | ||||||||||||
model = GPT(train_dataset.vocab_size, args.block_size, n_layer=6, n_head=8, n_embd=256) | ||||||||||||
|
||||||||||||
logging.info("preparing the learning rate schedule") | ||||||||||||
iter_tokens = args.batch_size * args.block_size # number of tokens backpropped in one iteration | ||||||||||||
epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens | ||||||||||||
lr_decay = WarmupCosineLearningRateDecay(learning_rate=6e-4, warmup_tokens=epoch_tokens//2, | ||||||||||||
final_tokens=args.num_epochs*epoch_tokens) | ||||||||||||
|
||||||||||||
t0 = time.time() | ||||||||||||
logging.info("training...") | ||||||||||||
trainer = pl.Trainer(gpus=args.num_gpus, max_epochs=args.num_epochs, gradient_clip_val=1.0, callbacks=[lr_decay], | ||||||||||||
precision=args.precision, default_root_dir=args.default_root_dir) | ||||||||||||
trainer.fit(model, train_dataloader, val_dataloader) | ||||||||||||
t1 = time.time() | ||||||||||||
logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs) | ||||||||||||
|
||||||||||||
# todo below: I don't yet understand the Lightning checkpoint schema | ||||||||||||
# logging.info("testing...") | ||||||||||||
# ckpt_path = os.path.join(args.default_root_dir, 'model.pt') | ||||||||||||
# model.load_from_checkpoint(ckpt_path) # load the best checkpoint we found | ||||||||||||
# trainer.test(test_dataloader=test_dataloader) | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. got it. it looks like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes. we enable multiple dataloaders for val and test. coming support for train. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some people always need something, which is why frameworks are so hard. Next thing you know you can't use a list of data loaders and have to introduce a |
||||||||||||
|
||||||||||||
logging.info("sampling:") | ||||||||||||
context = "anarchism originated as a term of" | ||||||||||||
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...] | ||||||||||||
if next(model.parameters()).is_cuda: | ||||||||||||
x = x.cuda() | ||||||||||||
y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0] | ||||||||||||
completion = ''.join([train_dataset.itos[int(i)] for i in y]) | ||||||||||||
print(completion) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
""" | ||
A manual, minimal and non-full-featured implementation of boilerplate training loop. | ||
Intentionally made to have the same API as PyTorch Lightning, giving two benefits: | ||
1) Everyone can inspect/hack this simple implementation for educational purposes | ||
2) Everyone can run the full Lightning implementation when they just want to go FAST | ||
""" | ||
|
||
import os | ||
import math | ||
import logging | ||
|
||
from tqdm import tqdm | ||
import torch | ||
import torch.nn as nn | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# ----------------------------------------------------------------------------- | ||
|
||
class Result: | ||
""" very thin wrapper around a result of a train/val/test step of the model """ | ||
def __init__(self, minimize=None, checkpoint_on=None): | ||
self.minimize = minimize | ||
self.checkpoint_on = checkpoint_on | ||
|
||
def log(self, key, val): | ||
setattr(self, key, val) | ||
|
||
class TrainResult(Result): | ||
pass | ||
|
||
class EvalResult(Result): | ||
pass | ||
|
||
class LightningModule(nn.Module): | ||
|
||
def load_from_checkpoint(self, checkpoint_path): | ||
logger.info("loading the best model checkpoint from %s", checkpoint_path) | ||
state_dict = torch.load(checkpoint_path) | ||
self.load_state_dict(state_dict) | ||
|
||
class Callback: | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry! this is 100% optional. This is a new addition and I see we forgot to include the simple case and doc examples using a dict or the loss directly There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. got it, ok converted to use of dicts with latest commit |
||
|
||
# ----------------------------------------------------------------------------- | ||
""" | ||
Simple Trainer object; Boilerplate that could apply to any arbitrary neural network, | ||
so nothing here really has anything to do with GPT specifically. This is a | ||
very basic Trainer class that will only train the model on up to one GPU. | ||
""" | ||
|
||
class Trainer: | ||
|
||
def __init__(self, max_epochs, gpus=0, gradient_clip_val=None, default_root_dir='.', callbacks=None, | ||
precision=32, **kwargs): | ||
self.gpus = gpus | ||
self.max_epochs = max_epochs | ||
self.gradient_clip_val = gradient_clip_val | ||
self.callbacks = [] if callbacks is None else callbacks | ||
self.model = None | ||
|
||
if default_root_dir is not None: | ||
os.makedirs(default_root_dir, exist_ok = True) | ||
self.default_root_dir = default_root_dir | ||
|
||
if self.gpus > 1: | ||
logger.error("This simple Trainer does not support > 1 GPUs, will just use one.") | ||
|
||
if precision != 32: | ||
logger.error("This simple Trainer does not support non-fp32 precision, will use fp32") | ||
|
||
def save_checkpoint(self): | ||
ckpt_path = os.path.join(self.default_root_dir, 'model.pt') | ||
logger.info("saving model checkpoint to %s", ckpt_path) | ||
torch.save(self.model.state_dict(), ckpt_path) | ||
|
||
def eval_split_(self, dataloader, split): | ||
|
||
self.model.eval() | ||
use_gpu = self.gpus > 0 and torch.cuda.is_available() | ||
losses = [] | ||
for it, (x, y) in enumerate(dataloader): | ||
# place data on the correct device | ||
if use_gpu: | ||
x, y = x.cuda(), y.cuda() | ||
# forward the model | ||
with torch.no_grad(): | ||
if split == 'val': | ||
result = self.model.validation_step((x, y)) | ||
loss = result.val_loss | ||
elif split == 'test': | ||
result = self.model.test_step((x, y)) | ||
loss = result.test_loss | ||
losses.append(loss.item()) | ||
mean_loss = torch.mean(torch.tensor(losses)).item() | ||
logger.info("%s loss: %f", split, mean_loss) | ||
return mean_loss | ||
|
||
def test(self, test_dataloader): | ||
return self.eval_split_(test_dataloader, 'test') | ||
|
||
def val(self, val_dataloader): | ||
return self.eval_split_(val_dataloader, 'val') | ||
|
||
def fit(self, model, train_dataloader, val_dataloader=None): | ||
self.model = model # bind model to the class here | ||
self.model.train() | ||
|
||
# ship model to gpu if possible | ||
use_gpu = self.gpus > 0 and torch.cuda.is_available() | ||
if use_gpu: | ||
logger.info("found CUDA device, shipping model to GPU") | ||
self.model.cuda() | ||
|
||
# prepare the optimizer | ||
optimizer = self.model.configure_optimizers() | ||
self.optimizers = [optimizer] | ||
|
||
# start the training loop | ||
best_val_loss = float('inf') | ||
for epoch in range(self.max_epochs): | ||
|
||
# do an epoch of training | ||
pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader)) | ||
for it, (x, y) in pbar: | ||
|
||
# place data on the correct device | ||
if use_gpu: | ||
x, y = x.cuda(), y.cuda() | ||
|
||
# forward the model | ||
result = self.model.training_step((x, y)) | ||
loss = result.minimize | ||
|
||
# reset gradient | ||
for param in self.model.parameters(): | ||
param.grad = None # a faster alternative to model.zero_grad() | ||
|
||
# backward pass | ||
loss.backward() | ||
|
||
# clip the gradient to mitigate loss explosions | ||
if self.gradient_clip_val is not None: | ||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val) | ||
|
||
# update all parameters | ||
optimizer.step() # todo: use fused optimizer | ||
|
||
# notify all relevant callbacks that a batch update ended. e.g. a callback may decay learning rate | ||
for cb in self.callbacks: | ||
if hasattr(cb, 'on_train_batch_end'): | ||
cb.on_train_batch_end(self, None, (x, y)) | ||
|
||
# report progress | ||
lr = optimizer.param_groups[0]['lr'] | ||
pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}") | ||
|
||
# calculate the current validation loss and checkpoint the model for early stopping | ||
if val_dataloader is not None: | ||
val_loss = self.val(val_dataloader) | ||
if (self.default_root_dir is not None) and (val_loss < best_val_loss): | ||
best_val_loss = val_loss | ||
self.save_checkpoint() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import math | ||
|
||
# ----------------------------------------------------------------------------- | ||
import os | ||
if int(os.environ.get('USE_LIGHTNING', 0)): | ||
import pytorch_lightning as pl | ||
else: | ||
import mingpt.fake_lightning as pl | ||
# ----------------------------------------------------------------------------- | ||
|
||
class WarmupCosineLearningRateDecay(pl.Callback): | ||
""" | ||
based on the number of tokens seen during training will adjust the learning rate: | ||
1. first it will start at zero and gradually ramp up to full learning rate | ||
2. then it will decay down with the cosine learning rate decay down until 10% of original | ||
""" | ||
|
||
def __init__(self, learning_rate, warmup_tokens, final_tokens): | ||
super().__init__() | ||
self.learning_rate = learning_rate | ||
self.warmup_tokens = warmup_tokens | ||
self.final_tokens = final_tokens | ||
# state in this class, will count number of tokens processed so far | ||
self.tokens = 0 | ||
|
||
def on_train_batch_end(self, trainer, pl_module, batch, batch_idx=None, dataloader_idx=None): | ||
_, y = batch | ||
self.tokens += (y >= 0).sum() # y == -100 is "ignore", so don't count these | ||
if self.tokens < self.warmup_tokens: | ||
# linear warmup | ||
lr_mult = float(self.tokens) / float(max(1, self.warmup_tokens)) | ||
else: | ||
# followed by cosine learning rate decay | ||
progress = float(self.tokens - self.warmup_tokens) / float( | ||
max(1, self.final_tokens - self.warmup_tokens)) | ||
lr_mult = 0.1 + 0.5 * (1.0 + math.cos(math.pi * progress)) | ||
lr = self.learning_rate * lr_mult | ||
for optimizer in trainer.optimizers: | ||
for param_group in optimizer.param_groups: | ||
param_group['lr'] = lr |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Datamodules can do:
Which lets you do things like:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neat! I'll have to read more of the docs