Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lightning #35

Open
wants to merge 20 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6110298
step 1: free the GPT module of config and flatten out the args. WIP, …
karpathy Aug 29, 2020
08f5b9a
refactor out the learning rate decay class as a callback
karpathy Aug 29, 2020
3fa57cd
data loaders are passed directly to fit() instead of the datasets
karpathy Aug 29, 2020
e88f076
data loaders are passed directly to fit() instead of the dataset, ver…
karpathy Aug 29, 2020
c0823ec
model is also passed into fit() instead of __init__ ,sure.
karpathy Aug 29, 2020
923b6fc
and finally get rid of Config object for the Trainer
karpathy Aug 29, 2020
a5a6d1a
add training_step to the model and remove DataParallel functionality …
karpathy Aug 29, 2020
990c0c7
final integration pieces, now runs with both, but it ain't super pret…
karpathy Aug 29, 2020
81650ae
one more refactor, this is better because the equivalence to lightnin…
karpathy Aug 29, 2020
fb37e03
refactor into a datamodule, attempt number 1
karpathy Aug 29, 2020
fa10298
use a standard benchmark (text8) and implement train/val/test splits
karpathy Aug 30, 2020
0ed3376
move instantiation of text dataset into the constructor so we don't h…
karpathy Aug 30, 2020
ebd40f1
support fp16/32 precision in bench
karpathy Aug 30, 2020
1aa67ca
switch to a faster version of zero_grad()
karpathy Aug 30, 2020
452a5ab
massive refactor yet again. this was all probably a pretty bad idea
karpathy Aug 30, 2020
9b1e5a4
delete Result structs in favor of dicts
karpathy Aug 30, 2020
4817231
testing now works with both lightning and minLightning
karpathy Aug 30, 2020
d91bb1c
make labels non-blocking transfer to overlap them, but i don't really…
karpathy Aug 30, 2020
492b79f
get rid of spurious function for the model
karpathy Aug 30, 2020
a796899
reorg the bench code to support multigpu training, have to indent pro…
karpathy Aug 30, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
"""
Temporary benchmarking script while integrating Lightning, will remove before merge to master
"""

import os
import time
import math
import logging
import argparse

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.backends.cudnn as cudnn

from mingpt.model import GPT
from mingpt.lr_decay import WarmupCosineLearningRateDecay
from mingpt.utils import sample

# module-level logger, plus a root logging configuration that timestamps records
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

torch.backends.cudnn.benchmark = True  # autotune kernels

# -----------------------------------------------------------------------------
# Select the trainer backend from the USE_LIGHTNING environment variable:
# a non-zero integer picks PyTorch Lightning, 0 (or unset) picks the minimal
# in-repo stand-in. Either way the chosen module is bound to the name `pl`.
import os
if not int(os.environ.get('USE_LIGHTNING', 0)):
    import mingpt.fake_lightning as pl
    logging.info("using our humble trainer")
else:
    logging.info("USING LIGHTNING!!")
    import pytorch_lightning as pl
# -----------------------------------------------------------------------------

class Text8Dataset(Dataset):
    """
    Character-level language modeling dataset read from a plain text file.

    e.g. Text8 dataset is often used: http://mattmahoney.net/dc/textdata.html
    Vocabulary is lowercase English characters and space for total of 27.
    Training data: First 90M characters.
    Validation data: First 5M characters out of the last 10M characters.
    Testing data: Last 5M characters.

    The text is cut into consecutive, non-overlapping windows of `block_size`
    characters. Item `idx` is a pair (x, y) of long tensors of length
    `block_size`, where y is x shifted one character to the left.
    """

    def __init__(self, data_path, block_size, crop=None, override_vocab=None):
        """
        data_path: path of the text file to read
        block_size: number of characters in one example
        crop: optional (offset, length) pair; the file is seeked to `offset`
              and `length` characters are read. None reads the whole file.
        override_vocab: optional sequence of characters to use as the vocabulary
                        instead of deriving it from the loaded data
        """
        # load the data and crop it appropriately. The encoding is explicit so that
        # decoding does not depend on the platform locale (text8 itself is plain ASCII,
        # for which the seek offset is simply a character offset)
        with open(data_path, 'r', encoding='utf-8') as f:
            if crop is None:
                data = f.read()
            else:
                offset, length = crop
                f.seek(offset)
                data = f.read(length)

        # build a vocabulary from data or inherit it
        vocab = sorted(set(data)) if override_vocab is None else override_vocab
        data_size, vocab_size = len(data), len(vocab)
        logging.info('data of crop %s has %d characters, vocab of size %d.', crop, data_size, vocab_size)

        self.stoi = {ch: i for i, ch in enumerate(vocab)}  # character -> integer index
        self.itos = {i: ch for i, ch in enumerate(vocab)}  # integer index -> character
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
        self.vocab = vocab

    def __len__(self):
        """ number of full, non-overlapping blocks available in the data """
        return len(self.data) // self.block_size

    def __getitem__(self, idx):
        """
        Return the (input, target) tensors of block `idx`.

        Raises IndexError when idx is outside [0, len(self)), which also lets plain
        iteration over the dataset terminate cleanly.
        """
        if not 0 <= idx < len(self):
            raise IndexError("index %d is out of range for %d blocks" % (idx, len(self)))
        # attempt to fetch a chunk of (block_size + 1) items, but (block_size) will work too;
        # slicing clamps at the end of the data on its own
        start = idx * self.block_size
        chunk = self.data[start : start + self.block_size + 1]
        # map the string into a sequence of integers
        ixes = [self.stoi[s] for s in chunk]
        # if stars align (last idx and len(self.data) % self.block_size == 0), pad with -100, to skip training at the last position
        if len(ixes) < self.block_size + 1:
            assert len(ixes) == self.block_size  # the only way a short chunk can happen for a valid idx
            ixes.append(-100)
        dix = torch.tensor(ixes, dtype=torch.long)
        return dix[:-1], dix[1:]

# -----------------------------------------------------------------------------

# command line interface of the benchmark script
parser = argparse.ArgumentParser()
parser.add_argument(
    '-x', '--num-epochs',
    type=int,
    default=5,
    help="number of epochs to train for",
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Datamodules can do:

parser = argparse.ArgumentParser()

# enables whatever you have in your init in argparse :) 
parser = CharDataModule.add_argparse_args(parser)

# enable all the trainer flags in argparse
parser = Trainer.add_argparse_args(parser)

args = parser.parse_args()

# now you can init whatever objects automatically as well:
trainer = Trainer.from_argparse_args(args, any_flag_to_override=...)

dm = CharDataModule.from_argparse_args(args)

Which lets you do things like:

python main.py --gpus 2 --num_nodes 3 --batch_size 32

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Neat! I'll have to read more of the docs

# remaining command line options: data loading, hardware and output location
parser.add_argument('-b', '--batch-size', default=64, type=int,
                    help="batch size to train with")
parser.add_argument('-l', '--block-size', default=128, type=int,
                    help="block size for the model (length of window of context)")
parser.add_argument('-n', '--num-workers', default=0, type=int,
                    help="number of workers for dataloading")
parser.add_argument('-g', '--num-gpus', default=1, type=int,
                    help="number of gpus to train on")
parser.add_argument('-p', '--pin-memory', default=1, type=int,
                    help="pin memory on dataloaders?")
parser.add_argument('-r', '--precision', default=32, type=int,
                    help="fp precision to use, e.g. 32/16")
parser.add_argument('-o', '--default_root_dir', default='.', type=str,
                    help="best model checkpoint will be written at this location")
args = parser.parse_args()
print(vars(args))

logging.info("preparing the data loaders")
# NOTE: REDUCED DATA SIZE FOR DEBUGGING, TODO CLEAN BEFORE MERGE IF EVER
# every crop is (start offset, number of characters) into the 'text8' file;
# val and test reuse the training vocabulary so that indices agree across splits
train_dataset = Text8Dataset('text8', args.block_size, crop=(0, int(1e6)))
val_dataset = Text8Dataset(
    'text8', args.block_size,
    crop=(int(90e6), int(1e5)),
    override_vocab=train_dataset.vocab,
)
test_dataset = Text8Dataset(
    'text8', args.block_size,
    crop=(int(95e6), int(1e5)),
    override_vocab=train_dataset.vocab,
)
# loader settings shared by all three splits; only the training split is shuffled
common = dict(
    batch_size=args.batch_size,
    pin_memory=bool(args.pin_memory),
    num_workers=args.num_workers,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, **common)
val_dataloader = DataLoader(val_dataset, shuffle=False, **common)
test_dataloader = DataLoader(test_dataset, shuffle=False, **common)

logging.info("creating the model")
model = GPT(train_dataset.vocab_size, args.block_size, n_layer=6, n_head=8, n_embd=256)

logging.info("preparing the learning rate schedule")
# number of tokens backpropped in one iteration
iter_tokens = args.batch_size * args.block_size
# tokens in one epoch, counting the last (possibly partial) batch as a full one
epoch_tokens = math.ceil(len(train_dataset) / args.batch_size) * iter_tokens
# warm up over the first half epoch, then decay until the end of training
lr_decay = WarmupCosineLearningRateDecay(
    learning_rate=6e-4,
    warmup_tokens=epoch_tokens//2,
    final_tokens=args.num_epochs*epoch_tokens,
)

t0 = time.time()
logging.info("training...")
trainer = pl.Trainer(
    gpus=args.num_gpus,
    max_epochs=args.num_epochs,
    gradient_clip_val=1.0,
    callbacks=[lr_decay],
    precision=args.precision,
    default_root_dir=args.default_root_dir,
)
trainer.fit(model, train_dataloader, val_dataloader)
t1 = time.time()
logging.info("%d epochs took %fs, or %fs/epoch", args.num_epochs, t1 - t0, (t1-t0)/args.num_epochs)

# todo below: I don't yet understand the Lightning checkpoint schema
# logging.info("testing...")
# ckpt_path = os.path.join(args.default_root_dir, 'model.pt')
# model.load_from_checkpoint(ckpt_path) # load the best checkpoint we found
# trainer.test(test_dataloader=test_dataloader)
Copy link

@williamFalcon williamFalcon Aug 30, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# ckpt_path = os.path.join(args.default_root_dir, 'model.pt')
# model.load_from_checkpoint(ckpt_path) # load the best checkpoint we found
# trainer.test(test_dataloader=test_dataloader)
# Note: LIGHTNING automatically loads the best checkpoint when you call .test()
trainer.test(test_dataloader=test_dataloader)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. It looks like test_dataloader is not a kwarg; it's test_dataloaders, with an 's'. That matches val_dataloaders, but not train_dataloader, which has no 's'. Some of the docs are inconsistent about the "s", by the way, I think.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. We enable multiple dataloaders for val and test; support for train is coming.
It's not something I'm used to in research, but it turns out some people need two datasets to validate, haha. Go figure.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some people always need something, which is why frameworks are so hard. Next thing you know you can't use a list of data loaders and have to introduce a DataLoaderSetManager object.


logging.info("sampling:")
context = "anarchism originated as a term of"
# encode the prompt with the training vocabulary and add a leading batch dimension
x = torch.tensor(
    [train_dataset.stoi[s] for s in context],
    dtype=torch.long,
)[None, ...]
# keep the prompt on the same device as the model weights
model_on_gpu = next(model.parameters()).is_cuda
if model_on_gpu:
    x = x.cuda()
# draw 200 more tokens and keep the single sequence of the batch
y = sample(model, x, 200, temperature=1.0, sample=True, top_k=None)[0]
completion = ''.join(train_dataset.itos[int(i)] for i in y)
print(completion)
163 changes: 163 additions & 0 deletions mingpt/fake_lightning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""
A manual, minimal and non-full-featured implementation of boilerplate training loop.
Intentionally made to have the same API as PyTorch Lightning, giving two benefits:
1) Everyone can inspect/hack this simple implementation for educational purposes
2) Everyone can run the full Lightning implementation when they just want to go FAST
"""

import os
import math
import logging

from tqdm import tqdm
import torch
import torch.nn as nn

logger = logging.getLogger(__name__)

# -----------------------------------------------------------------------------

class Result:
    """ very thin wrapper around a result of a train/val/test step of the model """

    def __init__(self, minimize=None, checkpoint_on=None):
        # both fields are stored as given; this class attaches no logic to them
        self.minimize, self.checkpoint_on = minimize, checkpoint_on

    def log(self, key, val):
        """ store `val` on this result as an attribute named `key` """
        setattr(self, key, val)


class TrainResult(Result):
    """ Result subclass that adds no behavior of its own """


class EvalResult(Result):
    """ Result subclass that adds no behavior of its own """

class LightningModule(nn.Module):
    """ nn.Module base class that can restore its weights from a state_dict checkpoint file """

    def load_from_checkpoint(self, checkpoint_path):
        """
        Load the state_dict stored at `checkpoint_path` into this module, in place.

        The file is expected to hold a state_dict as written by torch.save().
        Errors from torch.load() / load_state_dict() (missing file, mismatching
        keys or shapes) propagate to the caller.
        """
        logger.info("loading the best model checkpoint from %s", checkpoint_path)
        # map the stored tensors to the CPU first, so that a checkpoint written during a
        # GPU run can still be read on a machine without CUDA; load_state_dict() then
        # copies the values into this module's existing parameters, wherever they live
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.load_state_dict(state_dict)

class Callback:
    """ empty base class for callbacks; it defines no hooks of its own """

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry! This is 100% optional. This is a new addition, and I see we forgot to include the simple case and doc examples using a dict or the loss directly.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got it, ok converted to use of dicts with latest commit


# -----------------------------------------------------------------------------
"""
Simple Trainer object; Boilerplate that could apply to any arbitrary neural network,
so nothing here really has anything to do with GPT specifically. This is a
very basic Trainer class that will only train the model on up to one GPU.
"""

class Trainer:
    """
    Minimal single-device training loop with a PyTorch Lightning-like interface.

    The model handed to fit() has to provide configure_optimizers() and
    training_step(batch); validation_step(batch) and test_step(batch) are needed for
    evaluation. The step methods return an object that exposes the loss as
    `.minimize`, `.val_loss` and `.test_loss` respectively.
    """

    def __init__(self, max_epochs, gpus=0, gradient_clip_val=None, default_root_dir='.', callbacks=None,
                 precision=32, **kwargs):
        """
        max_epochs: number of passes over the training dataloader done by fit()
        gpus: CUDA is used when this is > 0 and a device is available; None counts as 0
        gradient_clip_val: maximum gradient norm, or None to disable clipping
        default_root_dir: directory that receives 'model.pt'; None disables checkpointing
        callbacks: optional list of callback objects, see fit()
        precision: accepted for API compatibility, only 32 is supported
        **kwargs: accepted and ignored
        """
        # tolerate gpus=None so that the comparisons below don't raise a TypeError
        self.gpus = 0 if gpus is None else gpus
        self.max_epochs = max_epochs
        self.gradient_clip_val = gradient_clip_val
        self.callbacks = [] if callbacks is None else callbacks
        self.model = None

        if default_root_dir is not None:
            os.makedirs(default_root_dir, exist_ok=True)
        self.default_root_dir = default_root_dir

        if self.gpus > 1:
            logger.error("This simple Trainer does not support > 1 GPUs, will just use one.")

        if precision != 32:
            logger.error("This simple Trainer does not support non-fp32 precision, will use fp32")

    def save_checkpoint(self):
        """ write the state_dict of the bound model to <default_root_dir>/model.pt """
        ckpt_path = os.path.join(self.default_root_dir, 'model.pt')
        logger.info("saving model checkpoint to %s", ckpt_path)
        torch.save(self.model.state_dict(), ckpt_path)

    def eval_split_(self, dataloader, split):
        """
        Run the bound model over `dataloader` without gradients and return the mean
        of the per-batch losses as a float. `split` is 'val' or 'test' and selects
        validation_step / test_step; anything else raises ValueError.

        The model is put into eval mode for the duration of the call and is returned
        to whatever mode it was in before, so that evaluating in the middle of
        training does not leave e.g. dropout switched off for the following epochs.
        """
        if split not in ('val', 'test'):
            raise ValueError("split must be 'val' or 'test', got %r" % (split,))

        use_gpu = self.gpus > 0 and torch.cuda.is_available()
        was_training = self.model.training
        self.model.eval()
        losses = []
        try:
            for x, y in dataloader:
                # place data on the correct device
                if use_gpu:
                    x, y = x.cuda(), y.cuda()
                # forward the model
                with torch.no_grad():
                    if split == 'val':
                        result = self.model.validation_step((x, y))
                        loss = result.val_loss
                    else:
                        result = self.model.test_step((x, y))
                        loss = result.test_loss
                losses.append(loss.item())
        finally:
            # restore the previous train/eval mode, also if a step raised
            self.model.train(was_training)
        mean_loss = torch.mean(torch.tensor(losses)).item()
        logger.info("%s loss: %f", split, mean_loss)
        return mean_loss

    def test(self, test_dataloader):
        """ mean test loss of the bound model over `test_dataloader` """
        return self.eval_split_(test_dataloader, 'test')

    def val(self, val_dataloader):
        """ mean validation loss of the bound model over `val_dataloader` """
        return self.eval_split_(val_dataloader, 'val')

    def fit(self, model, train_dataloader, val_dataloader=None):
        """
        Train `model` for max_epochs passes over `train_dataloader`.

        After every batch, each callback that has an on_train_batch_end attribute is
        called as cb.on_train_batch_end(trainer, None, (x, y)). When `val_dataloader`
        is given, the validation loss is computed after every epoch and, if
        default_root_dir is set, the model is checkpointed whenever that loss improves.
        """
        self.model = model  # bind model to the class here

        # ship model to gpu if possible
        use_gpu = self.gpus > 0 and torch.cuda.is_available()
        if use_gpu:
            logger.info("found CUDA device, shipping model to GPU")
            self.model.cuda()

        # prepare the optimizer
        optimizer = self.model.configure_optimizers()
        self.optimizers = [optimizer]

        # start the training loop
        best_val_loss = float('inf')
        for epoch in range(self.max_epochs):

            # (re-)enter training mode at the start of every epoch, not just once up front
            self.model.train()

            # do an epoch of training
            pbar = tqdm(enumerate(train_dataloader), total=len(train_dataloader))
            for it, (x, y) in pbar:

                # place data on the correct device
                if use_gpu:
                    x, y = x.cuda(), y.cuda()

                # forward the model
                result = self.model.training_step((x, y))
                loss = result.minimize

                # reset gradient
                for param in self.model.parameters():
                    param.grad = None  # a faster alternative to model.zero_grad()

                # backward pass
                loss.backward()

                # clip the gradient to mitigate loss explosions
                if self.gradient_clip_val is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.gradient_clip_val)

                # update all parameters
                optimizer.step()  # todo: use fused optimizer

                # notify all relevant callbacks that a batch update ended. e.g. a callback may decay learning rate
                for cb in self.callbacks:
                    if hasattr(cb, 'on_train_batch_end'):
                        cb.on_train_batch_end(self, None, (x, y))

                # report progress
                lr = optimizer.param_groups[0]['lr']
                pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            # calculate the current validation loss and checkpoint the model for early stopping
            if val_dataloader is not None:
                val_loss = self.val(val_dataloader)
                if (self.default_root_dir is not None) and (val_loss < best_val_loss):
                    best_val_loss = val_loss
                    self.save_checkpoint()
40 changes: 40 additions & 0 deletions mingpt/lr_decay.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import math

# -----------------------------------------------------------------------------
import os
if int(os.environ.get('USE_LIGHTNING', 0)):
import pytorch_lightning as pl
else:
import mingpt.fake_lightning as pl
# -----------------------------------------------------------------------------

class WarmupCosineLearningRateDecay(pl.Callback):
    """
    based on the number of tokens seen during training will adjust the learning rate:
    1. first it will start at zero and gradually ramp up to full learning rate
    2. then it will decay down with the cosine learning rate decay down until 10% of original
    Once final_tokens have been processed the learning rate stays at that 10% floor.
    """

    def __init__(self, learning_rate, warmup_tokens, final_tokens):
        """
        learning_rate: peak learning rate, reached at the end of the warmup
        warmup_tokens: number of tokens over which the rate ramps up linearly
        final_tokens: number of tokens at which the cosine decay is complete
        """
        super().__init__()
        self.learning_rate = learning_rate
        self.warmup_tokens = warmup_tokens
        self.final_tokens = final_tokens
        # state in this class, will count number of tokens processed so far
        self.tokens = 0

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx=None, dataloader_idx=None):
        """
        Count the target tokens of `batch` and write the resulting learning rate into
        every param group of every optimizer in trainer.optimizers.
        """
        _, y = batch
        # y == -100 is "ignore", so don't count these. int() keeps the counter a plain
        # python int instead of silently turning it into a tensor
        self.tokens += int((y >= 0).sum())
        if self.tokens < self.warmup_tokens:
            # linear warmup
            lr_mult = float(self.tokens) / float(max(1, self.warmup_tokens))
        else:
            # followed by cosine learning rate decay
            progress = float(self.tokens - self.warmup_tokens) / float(
                max(1, self.final_tokens - self.warmup_tokens))
            # clamp so that the cosine does not swing back up once final_tokens is passed
            progress = min(1.0, progress)
            # the multiplier runs from 1.0 down to a floor of 0.1; flooring with max()
            # (rather than adding 0.1) keeps it continuous with the end of the warmup
            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
        lr = self.learning_rate * lr_mult
        for optimizer in trainer.optimizers:
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
Loading