Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lightning #35

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6110298
step 1: free the GPT module of config and flatten out the args. WIP, …
karpathy Aug 29, 2020
08f5b9a
refactor out the learning rate decay class as a callback
karpathy Aug 29, 2020
3fa57cd
data loaders are passed directly to fit() instead of the datasets
karpathy Aug 29, 2020
e88f076
data loaders are passed directly to fit() instead of the dataset, ver…
karpathy Aug 29, 2020
c0823ec
model is also passed into fit() instead of __init__ ,sure.
karpathy Aug 29, 2020
923b6fc
and finally get rid of Config object for the Trainer
karpathy Aug 29, 2020
a5a6d1a
add training_step to the model and remove DataParallel functionality …
karpathy Aug 29, 2020
990c0c7
final integration pieces, now runs with both, but it ain't super pret…
karpathy Aug 29, 2020
81650ae
one more refactor, this is better because the equivalence to lightnin…
karpathy Aug 29, 2020
fb37e03
refactor into a datamodule, attempt number 1
karpathy Aug 29, 2020
fa10298
use a standard benchmark (text8) and implement train/val/test splits
karpathy Aug 30, 2020
0ed3376
move instantiation of text dataset into the constructor so we don't h…
karpathy Aug 30, 2020
ebd40f1
support fp16/32 precision in bench
karpathy Aug 30, 2020
1aa67ca
switch to a faster version of zero_grad()
karpathy Aug 30, 2020
452a5ab
massive refactor yet again. this was all probably a pretty bad idea
karpathy Aug 30, 2020
9b1e5a4
delete Result structs in favor of dicts
karpathy Aug 30, 2020
4817231
testing now works with both lightning and minLightning
karpathy Aug 30, 2020
d91bb1c
make labels non-blocking transfer to overlap them, but i don't really…
karpathy Aug 30, 2020
492b79f
get rid of spurious function for the model
karpathy Aug 30, 2020
a796899
reorg the bench code to support multigpu training, have to indent pro…
karpathy Aug 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
"""
Temporary benchmarking script while integrating Lightning, will remove before merge to master
"""

import os
import time
import math
import logging
import argparse

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.backends.cudnn as cudnn

from mingpt.model import GPT
from mingpt.lr_decay import WarmupCosineLearningRateDecay
from mingpt.utils import sample

# -----------------------------------------------------------------------------
import os
if int(os.environ.get('USE_LIGHTNING', 0)):
import pytorch_lightning as pl
else:
import mingpt.fake_lightning as pl
# -----------------------------------------------------------------------------

class Text8Dataset(Dataset):
"""
e.g. Text8 dataset is often used: http://mattmahoney.net/dc/textdata.html
Vocabulary is lowercase English characters and space for total of 27.
Training data: First 90M characters.
Validation data: First 5M characters out of the last 10M characters.
Testing data: Last 5M characters.
"""

def __init__(self, data_path, block_size, crop=None, override_vocab=None):

# load the data and crop it appropriately
with open(data_path, 'r') as f:
if crop is None:
data = f.read()
else:
f.seek(crop[0])
data = f.read(crop[1])

# build a vocabulary from data or inherit it
vocab = sorted(list(set(data))) if override_vocab is None else override_vocab
data_size, vocab_size = len(data), len(vocab)
logging.info('data of crop %s has %d characters, vocab of size %d.' % (str(crop), data_size, vocab_size))

self.stoi = { ch:i for i,ch in enumerate(vocab) }
self.itos = { i:ch for i,ch in enumerate(vocab) }
self.block_size = block_size
self.vocab_size = vocab_size
self.data = data
self.vocab = vocab

def __len__(self):
return len(self.data) // self.block_size

def __getitem__(self, idx):
# attempt to fetch a chunk of (block_size + 1) items, but (block_size) will work too
chunk = self.data[idx*self.block_size : min(len(self.data), (idx+1)*self.block_size + 1)]
# map the string into a sequence of integers
ixes = [self.stoi[s] for s in chunk]
# if stars align (last idx and len(self.data) % self.block_size == 0), pad with -100, to skip training at the last position
if len(ixes) < self.block_size + 1:
assert len(ixes) == self.block_size # i believe this is the only way this could happen, make sure
ixes.append(-100)
dix = torch.tensor(ixes, dtype=torch.long)
return dix[:-1], dix[1:]

# -----------------------------------------------------------------------------

if __name__ == '__main__':

parser = argparse.ArgumentParser()
parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")
parser.add_argument('-b', '--batch-size', type=int, default=64, help="batch size to train with")
parser.add_argument('-l', '--block-size', type=int, default=128, help="block size for the model (length of window of context)")
parser.add_argument('-g', '--num-gpus', type=int, default=1, help="number of gpus to train on")
parser.add_argument('-n', '--num-workers', type=int, default=0, help="number of workers for dataloading")
parser.add_argument('-p', '--pin-memory', type=int, default=0, help="pin memory on dataloaders?")
parser.add_argument('-r', '--precision', type=int, default=32, help="fp precision to use, e.g. 32/16")
parser.add_argument('-o', '--default_root_dir', type=str, default='.', help="best model checkpoint will be written at this location")
args = parser.parse_args()
print(vars(args))

logger = logging.getLogger(__name__)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)

torch.backends.cudnn.benchmark = True # autotune kernels

logging.info("preparing the data loaders")
# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(90e6)))
val_dataset = Text8Dataset('text8', args.block_size, crop=(int(90e6), int(5e6)), override_vocab=train_dataset.vocab)
test_dataset = Text8Dataset('text8', args.block_size, crop=(int(95e6), int(5e6)), override_vocab=train_dataset.vocab)
common = {'batch_size': args.batch_size, 'pin_memory': bool(args.pin_memory), 'num_workers': args.num_workers}
train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
val_dataloader = DataLoader(val_dataset, shuffle=False, **common)

logging.info("creating the model")
model = GPT(train_dataset.vocab_size, args.block_size, n_layer=8, n_head=8, n_embd=256)

logging.info("preparing the learning rate schedule")
iter_tokens = args.batch_size * args.block_size # number of tokens backpropped in one iteration
epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens
lr_decay = WarmupCosineLearningRateDecay(learning_rate=6e-4, warmup_tokens=epoch_tokens//2,
final_tokens=args.num_epochs*epoch_tokens)

t0 = time.time()
logging.info("training...")
trainer = pl.Trainer(gpus=args.num_gpus, max_epochs=args.num_epochs, gradient_clip_val=1.0, callbacks=[lr_decay],
precision=args.precision, default_root_dir=args.default_root_dir)
trainer.fit(model, train_dataloader, val_dataloader)
t1 = time.time()
logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs)

logging.info("testing...")
test_dataloader = DataLoader(test_dataset, shuffle=False, **common)
trainer.test(test_dataloaders=test_dataloader)

logging.info("sampling:")
context = "anarchism originated as a term of"
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...]
if next(model.parameters()).is_cuda:
x = x.cuda()
y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)
150 changes: 150 additions & 0 deletions mingpt/fake_lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
A manual, minimal and non-full-featured implementation of boilerplate training loop.
Intentionally made to have the same API as PyTorch Lightning, giving two benefits:
1) Everyone can inspect/hack this simple implementation for educational purposes
2) Everyone can run the full Lightning implementation when they just want to go FAST
"""

import os
import math
import logging

from tqdm import tqdm
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------

class LightningModule(nn.Module):
pass

class Callback:
pass

# -----------------------------------------------------------------------------
"""
Simple Trainer object; Boilerplate that could apply to any arbitrary neural network,
so nothing here really has anything to do with GPT specifically. This is a
very basic Trainer class that will only train the model on up to one GPU.
"""

class Trainer:

def __init__(self, max_epochs, gpus=0, gradient_clip_val=None, default_root_dir='.', callbacks=None,
precision=32, **kwargs):
self.gpus = gpus
self.max_epochs = max_epochs
self.gradient_clip_val = gradient_clip_val
self.callbacks = [] if callbacks is None else callbacks
self.model = None

if default_root_dir is not None:
os.makedirs(default_root_dir, exist_ok = True)
self.default_root_dir = default_root_dir

if self.gpus > 1:
logger.error("This simple Trainer does not support > 1 GPUs, will just use one.")

if precision != 32:
logger.error("This simple Trainer does not support non-fp32 precision, will use fp32")

def save_checkpoint(self):
ckpt_path = os.path.join(self.default_root_dir, 'model.pt')
logger.info("saving model checkpoint to %s", ckpt_path)
torch.save(self.model.state_dict(), ckpt_path)

def load_checkpoint(self):
ckpt_path = os.path.join(self.default_root_dir, 'model.pt')
logger.info("loading model from %s", ckpt_path)
state_dict = torch.load(ckpt_path)
self.model.load_state_dict(state_dict)

def eval_split_(self, dataloader, split):

self.model.eval()
use_gpu = self.gpus > 0 and torch.cuda.is_available()
losses = []
for it, (x, y) in enumerate(dataloader):
# place data on the correct device
if use_gpu:
x, y = x.cuda(), y.cuda(non_blocking=True)
# forward the model
with torch.no_grad():
if split == 'val':
result = self.model.validation_step((x, y))
elif split == 'test':
result = self.model.test_step((x, y))
losses.append(result['loss'].item())
mean_loss = torch.mean(torch.tensor(losses)).item()

logger.info("%s loss: %f", split, mean_loss)
return mean_loss

def test(self, test_dataloaders): # note we expect a list of dataloaders here
self.load_checkpoint() # load the best checkpoint we found during optimization
return self.eval_split_(test_dataloaders, 'test')

def val(self, val_dataloader):
return self.eval_split_(val_dataloader, 'val')

def fit(self, model, train_dataloader, val_dataloader=None):
self.model = model # bind model to the class here
self.model.train()

# ship model to gpu if possible
use_gpu = self.gpus > 0 and torch.cuda.is_available()
if use_gpu:
logger.info("found CUDA device, shipping model to GPU")
self.model.cuda()

# prepare the optimizer
optimizer = self.model.configure_optimizers()
self.optimizers = [optimizer]

# start the training loop
best_val_loss = float('inf')
for epoch in range(self.max_epochs):

# do an epoch of training
pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
for it, (x, y) in pbar:

# place data on the correct device
if use_gpu:
x, y = x.cuda(), y.cuda(non_blocking=True)

# forward the model
result = self.model.training_step((x, y))
loss = result['loss']

# reset gradient
for param in self.model.parameters():
param.grad = None # a faster alternative to model.zero_grad()

# backward pass
loss.backward()

# clip the gradient to mitigate loss explosions
if self.gradient_clip_val is not None:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)

# update all parameters
optimizer.step() # todo: use fused optimizer

# notify all relevant callbacks that a batch update ended. e.g. a callback may decay learning rate
for cb in self.callbacks:
if hasattr(cb, 'on_train_batch_end'):
cb.on_train_batch_end(self, None, (x, y))

# report progress
lr = optimizer.param_groups[0]['lr']
pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

# calculate the current validation loss and checkpoint the model for early stopping
if val_dataloader is not None:
val_loss = self.val(val_dataloader)
if (self.default_root_dir is not None) and (val_loss < best_val_loss):
best_val_loss = val_loss
self.save_checkpoint()
40 changes: 40 additions & 0 deletions mingpt/lr_decay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import math

# -----------------------------------------------------------------------------
import os
if int(os.environ.get('USE_LIGHTNING', 0)):
import pytorch_lightning as pl
else:
import mingpt.fake_lightning as pl
# -----------------------------------------------------------------------------

class WarmupCosineLearningRateDecay(pl.Callback):
"""
based on the number of tokens seen during training will adjust the learning rate:
1. first it will start at zero and gradually ramp up to full learning rate
2. then it will decay down with the cosine learning rate decay down until 10% of original
"""

def __init__(self, learning_rate, warmup_tokens, final_tokens):
super().__init__()
self.learning_rate = learning_rate
self.warmup_tokens = warmup_tokens
self.final_tokens = final_tokens
# state in this class, will count number of tokens processed so far
self.tokens = 0

def on_train_batch_end(self, trainer, pl_module, batch, batch_idx=None, dataloader_idx=None):
_, y = batch
self.tokens += (y >= 0).sum() # y == -100 is "ignore", so don't count these
if self.tokens < self.warmup_tokens:
# linear warmup
lr_mult = float(self.tokens) / float(max(1, self.warmup_tokens))
else:
# followed by cosine learning rate decay
progress = float(self.tokens - self.warmup_tokens) / float(
max(1, self.final_tokens - self.warmup_tokens))
lr_mult = 0.1 + 0.5 * (1.0 + math.cos(math.pi * progress))
lr = self.learning_rate * lr_mult
for optimizer in trainer.optimizers:
for param_group in optimizer.param_groups:
param_group['lr'] = lr
Loading