Merge pull request #12 from gooofy/master
implement gradient checkpointing
lopuhin authored Jul 30, 2019
2 parents 0638ca5 + 27ea69e commit fa3f529
Showing 3 changed files with 11 additions and 2 deletions.
lm/main.py (4 additions, 0 deletions)

@@ -30,6 +30,7 @@ def main(
         lr=2.5e-4,
         batch_size=2,  # per GPU
         g_accum_gradients=None,  # accumulate gradients N times (globally)
+        gradient_checkpointing=False,  # saves GPU memory
         n_ctx=1024,
         n_embed=768,
         n_head=12,
@@ -86,6 +87,7 @@ def main(
         n_hidden=n_hidden or n_embed,
         n_head=n_head,
         n_layer=n_layer,
+        gradient_checkpointing=gradient_checkpointing,
     )
     params = dict(
         hparams=attr.asdict(hparams),
@@ -225,6 +227,8 @@ def get_valid_loss():
         with torch.no_grad():
             for ctx in _valid_batch_iter(
                     valid_dataset, batch_size=batch_size, n_ctx=n_ctx):
+                if not ctx:
+                    continue
                 ctx = torch.LongTensor(ctx).to(device)
                 logits = model(ctx)['logits']
                 loss = loss_fn(logits, ctx)
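
The new `if not ctx: continue` guard in get_valid_loss() skips batches for which _valid_batch_iter() yields nothing. A minimal sketch, not part of this commit, of one failure mode such a guard avoids: averaging a loss over an empty batch gives nan (or an error on some PyTorch versions), which would corrupt the reported validation loss.

    import torch
    import torch.nn.functional as F

    empty_logits = torch.empty(0, 10)                    # a batch with no tokens
    empty_targets = torch.empty(0, dtype=torch.long)
    # Mean reduction over zero elements is 0/0, i.e. nan (some versions raise instead),
    # so it is safer to skip the batch before building tensors from it at all.
    print(F.cross_entropy(empty_logits, empty_targets))  # tensor(nan)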

lm/model.py (6 additions, 2 deletions)

@@ -7,7 +7,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-
+import torch.utils.checkpoint
 
 @attr.s(auto_attribs=True, frozen=True)
 class HParams:
@@ -17,6 +17,7 @@ class HParams:
     n_hidden: int
     n_head: int
     n_layer: int
+    gradient_checkpointing: bool
 
 
 class Model(nn.Module):
@@ -48,7 +49,10 @@ def forward(self, x, past=None):
         # Transformer
         presents = []
         for i, block in enumerate(self.blocks):
-            h, present = block(h, past=past[:, i] if past is not None else None)
+            if self.hparams.gradient_checkpointing:
+                h, present = torch.utils.checkpoint.checkpoint(block, h, past[:, i] if past is not None else None)
+            else:
+                h, present = block(h, past=past[:, i] if past is not None else None)
             presents.append(present)
         h = self.ln_f(h)
         if self.out_proj:
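
When the flag is set, each transformer block now runs through torch.utils.checkpoint.checkpoint, which keeps only the block's inputs during the forward pass and recomputes its intermediate activations during backward, trading roughly one extra forward pass per block for a much smaller activation footprint. A standalone sketch of that trade-off using plain PyTorch modules (not this repo's Model class):

    import torch
    import torch.utils.checkpoint
    from torch import nn

    # A toy block standing in for one transformer layer.
    block = nn.Sequential(nn.Linear(768, 3072), nn.ReLU(), nn.Linear(3072, 768))
    h = torch.randn(2, 1024, 768, requires_grad=True)

    # Intermediate activations inside `block` are not kept; they are recomputed
    # when backward() runs, so peak memory drops while the result is unchanged.
    out = torch.utils.checkpoint.checkpoint(block, h)
    out.sum().backward()

Note that the reentrant form of checkpoint() forwards positional arguments only; this is consistent with the checkpointed branch above passing `past[:, i]` positionally rather than as the `past=` keyword used in the plain call.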

tests/test_model.py (1 addition, 0 deletions)

@@ -50,6 +50,7 @@ def test_attention_mask():
         n_hidden=32,
         n_head=4,
         n_layer=5,
+        gradient_checkpointing=False,
     )
 
 
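The existing test only needs the new required HParams field. A hedged sketch of a follow-up test one could add (not part of this commit; the n_vocab, n_ctx and n_embed field names and the Model(hparams) constructor are assumed to match the rest of lm/model.py): checkpointing should change memory use only, so both configurations must produce identical logits.

    import torch
    from lm.model import HParams, Model


    def test_gradient_checkpointing_same_logits():
        kwargs = dict(n_vocab=50, n_ctx=16, n_embed=32, n_hidden=32, n_head=4, n_layer=2)
        plain = Model(HParams(gradient_checkpointing=False, **kwargs))
        checkpointed = Model(HParams(gradient_checkpointing=True, **kwargs))
        checkpointed.load_state_dict(plain.state_dict())  # identical weights
        plain.eval()
        checkpointed.eval()
        x = torch.randint(0, kwargs['n_vocab'], (2, kwargs['n_ctx']))
        assert torch.allclose(plain(x)['logits'], checkpointed(x)['logits'], atol=1e-6)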
