
Feature/lightning #35

Open · wants to merge 20 commits into master
Conversation

@karpathy (Owner) commented Aug 29, 2020

I'm trying it, @williamFalcon, but I have somewhat mixed feelings about it. The APIs are now matched up and I can train the basic loop with either:

$ USE_LIGHTNING=0 python bench.py
5 epochs took 33.225314s, or 6.645063s/epoch

or

$ USE_LIGHTNING=1 python bench.py
5 epochs took 30.068728s, or 6.013746s/epoch

Some overhead is incurred, not that it matters too much at the stage of a single GPU.

To merge, I would still have to:

  • clean up a bit further and make even more transparent ideally, needs a bit more thought
  • delete bench.py
  • uprev all notebooks
  • uprev Readme file
  • include also support for validation/test splits, right now only training set stuff is included

@karpathy (Owner, Author)

ok with the last commit I feel a bit better about things. mingpt/fake_lightning.py is now basically a minLightning :D, which could ideally be imported instead of Lightning and get a strict subset of just the very basic functionality. It's not functionally equivalent and a little bit hardcoded to minGPT purposes, but that's okay.

@williamFalcon commented Aug 29, 2020

nice! yeah, when you enable tensorboard and checkpoints the training slows down a bit.

we actually have tests to ensure we don’t incur a meaningful overhead over a vanilla pytorch script. lightning only adds 300ms per epoch (and no memory leaks, etc)

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/benchmarks/test_parity.py

@williamFalcon

next you should convert the datasets to datamodules. this means you can 100% decouple the model from the data :)

@karpathy (Owner, Author) commented Aug 29, 2020

@williamFalcon ok done. A few things I find a bit gross:

  • feels like I have to manually create a spurious train_dataset, because I need to know the size of vocab to calculate the learning rate decay, and to sample at the end etc.
  • see data_module.setup('fit') # fit... should be 'train'? ;\ . Based on this https://pytorch-lightning.readthedocs.io/en/latest/new-project.html is 'fit' the currently sanctioned name for the training stage? Should it maybe be 'train'?
  • any reason you don't pin_memory in your examples?
  • when I do have the validation set it seems like the encoding for val/test data will be a function of the vocabulary of the training set, so they have to be called serially and the val/test datasets have to be informed of some sufficient statistics from train?
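That last point can be sketched in plain Python (hypothetical helper names, not the actual minGPT code): the vocabulary is a sufficient statistic computed on the training split only, and the val/test splits must be encoded with it.

```python
# Sketch only: passing vocabulary statistics from the train split
# to the val split so both use the same encoding.
def build_vocab(text):
    chars = sorted(set(text))
    return {ch: i for i, ch in enumerate(chars)}

def encode(text, stoi):
    # characters unseen in training map to a reserved "unknown" index
    unk = len(stoi)
    return [stoi.get(ch, unk) for ch in text]

train_text = "hello world"
val_text = "held out!"

stoi = build_vocab(train_text)      # computed on train only
train_ids = encode(train_text, stoi)
val_ids = encode(val_text, stoi)    # reuses the train vocabulary
```

So the datasets do have to be constructed serially: val/test cannot be encoded until the train vocabulary exists.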

@karpathy (Owner, Author)

Alright calling it here for today, I'm tired and still have some actual work to do. I'm pretty sure I don't understand how pl.LightningDataModule API is supposed to be used properly between init, prepare, setup, and dataloader calls, what stage=None means, and how it is expected to be called from the Trainer, etc. To be continued...

@williamFalcon commented Aug 30, 2020

ah i see the confusion!

what is stage in setup?

.fit() calls setup (for train, val data)
and .test() also calls setup (for test only).

Our data calls are lazy, so we defer initializing them as long as we can to not cause unnecessary overhead.
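A plain-Python sketch of that dispatch (no Lightning import; the class name and split contents are placeholders): setup('fit') builds the train and val data, setup('test') builds only the test data, and stage=None builds everything.

```python
# Minimal sketch of the stage dispatch described above.
class CharDataModule:
    def __init__(self):
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        # .fit() calls setup with stage='fit' -> build train + val
        if stage in (None, 'fit'):
            self.train_dataset = "train split"
            self.val_dataset = "val split"
        # .test() calls setup with stage='test' -> build test only
        if stage in (None, 'test'):
            self.test_dataset = "test split"

dm = CharDataModule()
dm.setup('fit')   # train/val are built; test stays untouched for now
```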


0.9.0 -> 1.0 feedback

Got through a good chunk of the refactors (check out the evaluate loop).

Wrapping up the rest over the next few days to get the train loop to look pristine as well haha.

But, if there are any weird issues you see with the API, or any other paradigms we might consider, happy to make the changes this week to enable any use cases we may have missed or get rid of any parts that feel gross.


Do the below only if you care about multi-gpu...

Datamodule tutorial by one of our team members.

If you're training on 1 GPU, none of what I'm about to say matters. This only matters when making the code agnostic to n GPUs or n TPUs.

prepare data

This is to do something only once (i.e. in a 100-GPU world, only on GPU 0; examples are: tokenize, download, etc.).

setup

This is another prep stage but it's called on every GPU. This means that splitting or anything like that can be done here.

train, val, test dataloaders

These are lazy called... which means that you don't have the overhead of creating the data until you really need it (this is key for performance applications).
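The lazy-call idea can be illustrated without Lightning (the class and names are made up): the expensive loader construction happens only on the first request and is then cached.

```python
# Sketch of lazy dataloader construction: nothing expensive runs at
# __init__ time; the loader is built on first call and cached after.
class LazyLoaders:
    def __init__(self, dataset):
        self.dataset = dataset
        self.builds = 0        # counts how often the expensive work ran
        self._loader = None

    def train_dataloader(self):
        if self._loader is None:
            self.builds += 1   # expensive construction happens here, once
            self._loader = list(self.dataset)
        return self._loader

dm = LazyLoaders(range(4))
dm.train_dataloader()
dm.train_dataloader()          # second call reuses the cached loader
```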

data for init

there's a case where your model might depend on information about the data (i.e. vocab size, num_classes, etc). In this case, you can just hardcode this into the datamodule:

def __init__(...):
    self.num_classes = x
    self.vocab_size = y

or get the info in setup:

def setup(...):
    download()
    tokenize()
    self.vocab = count_vocab()

Normally, Lightning calls prepare_data and setup for you automatically in training. However, depending on how you set things up (let's say you got the vocab size in setup), you can manually call them after init.

dm = Datamodule()

# even if called manually, lightning makes sure it only happens on the correct devices
dm.prepare_data()
dm.setup()

model = LitModel(vocab_size=dm.vocab_size)
trainer = Trainer()
trainer.fit(model, dm)

I also don't love that you have to call prepare_data + setup yourself. Open to any ideas that you might have :)
This is a pattern we came up with in close collaboration with our partners at Facebook.

Datamodule examples

SimCLR made agnostic of dataset (it's pretty cool that you can use the same code and train on any dataset without changing your original base code).

Imagenet datamodule

STL-10 datamodule

datamodule videos

and of course, the obligatory tutorial on datamodules by one of our team members.

real-world tutorials

Here's one of our new tutorials as well on implementing SimCLR which would be a more realistic complex example.
https://www.youtube.com/watch?v=p8QFB1CiAoQ

bench.py Outdated
def setup(self, stage): # called for every GPU/machine
    if stage == 'train' or stage == 'fit':
        pass # nothing to do, the train_dataset is initialized in the constructor
    elif stage == 'val':


stage here is whether the setup is happening on fit or test...

BUT... to your point, do you think it might make more sense to change the stage to 'train', 'val', 'test'? the thing is that train, val are usually handled together (ie: split train into train/val and have a separate test set)

@karpathy (Owner, Author):

Sorry I'm pretty sure I totally misunderstood what a "stage" is and thought it referred to splits.

bench.py Outdated
# -----------------------------------------------------------------------------

parser = argparse.ArgumentParser()
parser.add_argument('-x', '--num-epochs', type=int, default=5, help="number of epochs to train for")


Datamodules can do:

parser = argparse.ArgumentParser()

# enables whatever you have in your init in argparse :) 
parser = CharDataModule.add_argparse_args(parser)

# enable all the trainer flags in argparse
parser = Trainer.add_argparse_args(parser)

args = parser.parse_args()

# now you can init whatever objects automatically as well:
trainer = Trainer.from_argparse_args(args, any_flag_to_override=...)

dm = CharDataModule.from_argparse_args(args)

Which lets you do things like:

python main.py --gpus 2 --num_nodes 3 --batch_size 32

@karpathy (Owner, Author):

Neat! I'll have to read more of the docs

bench.py Outdated
def train_dataloader(self):
    loader = DataLoader(self.train_dataset, batch_size=self.batch_size,
                        shuffle=True, pin_memory=bool(self.pin_memory),
                        num_workers=self.num_workers)


btw... with multiple GPUs num_workers > 0 in ddp_spawn mode is slow. This is a pytorch limitation because ddp_spawn generates subprocesses and in each subprocess there are more subprocesses generated by dataloaders.

that's why we recommend ddp as the backend for multi-gpu, but unfortunately it can't be used from a Jupyter notebook because those have limitations as well haha.

basically, until we re-invent jupyter notebooks we are a bit stuck haha...
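That rule of thumb can be captured in a tiny helper (the function is made up for illustration, not a Lightning API): under ddp_spawn, extra dataloader workers mean subprocesses spawning subprocesses, so keep num_workers at 0 there.

```python
# Hedged sketch of the advice above: pick dataloader workers based on
# the distributed backend. Under ddp_spawn, worker subprocesses are
# spawned inside already-spawned training subprocesses, which is slow.
def pick_num_workers(backend, preferred=4):
    if backend == 'ddp_spawn':
        return 0          # avoid nested subprocess overhead
    return preferred      # ddp (or single process) can use workers
```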

@karpathy (Owner, Author):

Yes, subtle point!

if self.gpus > 0 and torch.cuda.is_available():
    logger.info("found CUDA device, shipping model to GPU")
    device = 'cuda'
    self.model = self.model.to(device)


if you are enabling gpus=x I would just do model.cuda(x) so people can place models on a GPU indexed by the PCI_BUS_ID (you might need this flag enabled though):

export CUDA_DEVICE_ORDER=PCI_BUS_ID


but for teaching purposes it may be overkill

@karpathy (Owner, Author):

Thank you, yes it's cleaner. I like the more general .to syntax because we'll have many different XPUs etc, feels a bit more future proof

@williamFalcon

> Alright calling it here for today, I'm tired and still have some actual work to do. I'm pretty sure I don't understand how pl.LightningDataModule API is supposed to be used properly between init, prepare, setup, and dataloader calls, what stage=None means, and how it is expected to be called from the Trainer, etc. To be continued...

yeah, this is an optional abstraction (you can use any dataloader with lightning). we also just introduced it, so need to make the docs more clear. but it really is optional, so it’s not a big deal to use dataloaders directly. it just makes the data more reusable

@karpathy (Owner, Author)

Okay I merged one more big refactor. Honestly I am starting to think this branch was a very bad idea. I thought I could make things clean but there is a lot of baggage that Lightning "leaks" in a number of places, e.g. w.r.t. model checkpointing, the use of Training/Eval Result structures, forcing me into relatively odd looking abstractions and half-measures.

Anyway, thank you for your help @williamFalcon , I'll have to sleep on this a few days, read the Lightning docs more, and then maybe give it another shot some other time.

@williamFalcon commented Aug 30, 2020

ok, i understand the confusion!

doc updates

I updated the docs to show results as an optional extension! Also split the docs into optional vs required.
Doc updates here

Only required APIs are:

  1. LightningModule
  2. Trainer

Optional:

  • datamodules
  • results

no results

We added the results object recently, but forgot to show in the docs that it is 100% optional.

def training_step(...):
    loss = ...

    # option 1
    return loss

    # if you also want to log
    return {'loss': loss, 'log': {'train_loss': loss}}

    # Option 2 (optional):
    # results just make it more flexible/clean and adds functionality
    result = TrainResult(loss)
    result.log('train_loss', loss, on_step=True, on_epoch=True)

checkpoints

Checkpoints store hyperparams, training state, etc... however if you just want the plain python checkpoint:

ckpt = torch.load(path)
model.load_state_dict(ckpt['state_dict'])

bench.py Outdated
Comment on lines 129 to 131
# ckpt_path = os.path.join(args.default_root_dir, 'model.pt')
# model.load_from_checkpoint(ckpt_path) # load the best checkpoint we found
# trainer.test(test_dataloader=test_dataloader)
@williamFalcon commented Aug 30, 2020:

Suggested change
# ckpt_path = os.path.join(args.default_root_dir, 'model.pt')
# model.load_from_checkpoint(ckpt_path) # load the best checkpoint we found
# trainer.test(test_dataloader=test_dataloader)
# Note: LIGHTNING automatically loads the best checkpoint when you call .test()
trainer.test(test_dataloader=test_dataloader)

@karpathy (Owner, Author):

got it. it looks like test_dataloader is not a kwarg, it's test_dataloaders with an 's'. Similar to val_dataloaders, but not the same as train_dataloader without the s, it looks like. Some of the docs are inconsistent on the use of "s" btw, I think.


yes. we enable multiple dataloaders for val and test. coming support for train.
not in research i’m used to, but turns out some people need two datasets to validate haha. go figure

@karpathy (Owner, Author):

some people always need something, which is why frameworks are so hard. Next thing you know you can't use a list of data loaders and have to introduce a DataLoaderSetManager object.

Comment on lines 20 to 43
class Result:
    """ very thin wrapper around a result of a train/val/test step of the model """
    def __init__(self, minimize=None, checkpoint_on=None):
        self.minimize = minimize
        self.checkpoint_on = checkpoint_on

    def log(self, key, val):
        setattr(self, key, val)

class TrainResult(Result):
    pass

class EvalResult(Result):
    pass

class LightningModule(nn.Module):

    def load_from_checkpoint(self, checkpoint_path):
        logger.info("loading the best model checkpoint from %s", checkpoint_path)
        state_dict = torch.load(checkpoint_path)
        self.load_state_dict(state_dict)

class Callback:
    pass


sorry! this is 100% optional. This is a new addition and I see we forgot to include the simple case and doc examples using a dict or the loss directly

@karpathy (Owner, Author):

got it, ok converted to use of dicts with latest commit

@karpathy (Owner, Author)

Ok I think things have improved quite a bit. In particular, my "fake lightning" has now been reduced all the way to

class LightningModule(nn.Module):
    pass

class Callback:
    pass

which is fun :) And I can train with the fake trainer or the lightning trainer and the code looks decent ish.
