Merge pull request #85 from mlfoundations/grad_checkpointing
Add gradient checkpointing support
rwightman authored May 9, 2022
2 parents 964981a + d719ed9 commit d441b92
Showing 3 changed files with 30 additions and 1 deletion.
src/open_clip/model.py (22 changes: 21 additions & 1 deletion)
@@ -11,6 +11,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.utils.checkpoint import checkpoint
 
 from .timm_model import TimmModel
 from .utils import freeze_batch_norm_2d
@@ -167,6 +168,11 @@ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
         if freeze_bn_stats:
             freeze_batch_norm_2d(self)
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        # FIXME support for non-transformer
+        pass
+
     def stem(self, x):
         x = self.relu1(self.bn1(self.conv1(x)))
         x = self.relu2(self.bn2(self.conv2(x)))
@@ -228,14 +234,19 @@ def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0,
         super().__init__()
         self.width = width
         self.layers = layers
+        self.grad_checkpointing = False
 
         self.resblocks = nn.ModuleList([
             ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer)
             for _ in range(layers)
         ])
 
     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         for r in self.resblocks:
-            x = r(x, attn_mask=attn_mask)
+            if self.grad_checkpointing:
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
         return x
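Note (not part of the diff): checkpoint() runs the wrapped block without saving its intermediate activations and recomputes them during the backward pass, trading extra compute for memory. A minimal, self-contained sketch of the same pattern with stand-in blocks (the Linear/GELU modules and sizes below are illustrative, not open_clip code):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    # Stand-ins for transformer residual blocks.
    blocks = nn.ModuleList([nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)])

    # With the default (reentrant) checkpoint, at least one input must require grad,
    # otherwise the recomputed graph is detached and no gradients flow through the block.
    x = torch.randn(8, 64, requires_grad=True)
    grad_checkpointing = True

    for blk in blocks:
        if grad_checkpointing:
            x = checkpoint(blk, x)   # activations inside blk are recomputed in backward
        else:
            x = blk(x)

    x.sum().backward()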


@@ -263,6 +274,10 @@ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
         for param in self.parameters():
             param.requires_grad = False
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.transformer.grad_checkpointing = enable
+
     def forward(self, x: torch.Tensor):
         x = self.conv1(x)  # shape = [*, width, grid, grid]
         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
@@ -411,6 +426,11 @@ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
         self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.visual.set_grad_checkpointing(enable)
+        self.transformer.grad_checkpointing = enable
+
     def encode_image(self, image):
         return self.visual(image)
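Usage note (not part of the diff): the new method can be called on any constructed model before training. A hedged sketch, assuming the repo's create_model_and_transforms factory (the model is its first return value) and the bundled ViT-B-32 config:

    import torch
    import open_clip

    model = open_clip.create_model_and_transforms('ViT-B-32')[0]  # assumed factory helper
    model.set_grad_checkpointing()   # new API from this PR; enables checkpointing in both towers
    model.train()

    images = torch.randn(2, 3, 224, 224)
    texts = open_clip.tokenize(["a diagram", "a photo of a dog"])

    image_features = model.encode_image(images)
    text_features = model.encode_text(texts)
    (image_features.sum() + text_features.sum()).backward()  # placeholder objective, just to exercise backward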

src/training/main.py (3 changes: 3 additions & 0 deletions)
@@ -131,6 +131,9 @@ def main():
             unlocked_groups=args.lock_image_unlocked_groups,
             freeze_bn_stats=args.lock_image_freeze_bn_stats)
 
+    if args.grad_checkpointing:
+        model.set_grad_checkpointing()
+
     if is_master(args):
         logging.info("Model:")
         logging.info(f"{str(model)}")
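Note (not part of the diff): the call above only flips per-module flags that the forward passes check. A quick hedged sanity check, reusing the assumed factory from the model.py note and the attribute paths introduced there (ViT image tower):

    import open_clip

    model = open_clip.create_model_and_transforms('ViT-B-32')[0]  # assumed factory helper
    model.set_grad_checkpointing()

    assert model.transformer.grad_checkpointing             # text transformer
    assert model.visual.transformer.grad_checkpointing      # ViT image tower (ResNet towers are a no-op for now)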
src/training/params.py (6 changes: 6 additions & 0 deletions)
@@ -181,6 +181,12 @@ def parse_args():
         action='store_true',
         help="Freeze BatchNorm running stats in image tower for any locked layers.",
     )
+    parser.add_argument(
+        "--grad-checkpointing",
+        default=False,
+        action='store_true',
+        help="Enable gradient checkpointing.",
+    )
     parser.add_argument(
         "--local-loss",
         default=False,
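Note (not part of the diff): argparse maps the dashed option name to the underscored attribute read in main.py (args.grad_checkpointing). A minimal stand-alone sketch of the same store_true pattern:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--grad-checkpointing",
        default=False,
        action='store_true',
        help="Enable gradient checkpointing.",
    )

    assert parser.parse_args([]).grad_checkpointing is False
    assert parser.parse_args(["--grad-checkpointing"]).grad_checkpointing is True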
