From d719ed9c4d79a713907df63b9ca43b48031e12da Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 6 May 2022 16:34:08 -0700
Subject: [PATCH] Add grad checkpointing support

---
 src/open_clip/model.py | 22 +++++++++++++++++++++-
 src/training/main.py   |  3 +++
 src/training/params.py |  6 ++++++
 3 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index c6d3a03bf..0db4ab294 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -11,6 +11,7 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.utils.checkpoint import checkpoint
 
 from .timm_model import TimmModel
 from .utils import freeze_batch_norm_2d
@@ -167,6 +168,11 @@ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
         if freeze_bn_stats:
             freeze_batch_norm_2d(self)
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        # FIXME support for non-transformer
+        pass
+
     def stem(self, x):
         x = self.relu1(self.bn1(self.conv1(x)))
         x = self.relu2(self.bn2(self.conv2(x)))
@@ -228,6 +234,8 @@ def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0,
         super().__init__()
         self.width = width
         self.layers = layers
+        self.grad_checkpointing = False
+
         self.resblocks = nn.ModuleList([
             ResidualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer)
             for _ in range(layers)
@@ -235,7 +243,10 @@ def __init__(self, width: int, layers: int, heads: int, mlp_ratio: float = 4.0,
 
     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
         for r in self.resblocks:
-            x = r(x, attn_mask=attn_mask)
+            if self.grad_checkpointing:
+                x = checkpoint(r, x, attn_mask)
+            else:
+                x = r(x, attn_mask=attn_mask)
         return x
 
 
@@ -263,6 +274,10 @@ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
         for param in self.parameters():
             param.requires_grad = False
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.transformer.grad_checkpointing = enable
+
     def forward(self, x: torch.Tensor):
         x = self.conv1(x)  # shape = [*, width, grid, grid]
         x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
@@ -411,6 +426,11 @@ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
         # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
         self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.visual.set_grad_checkpointing(enable)
+        self.transformer.grad_checkpointing = enable
+
     def encode_image(self, image):
         return self.visual(image)
 
diff --git a/src/training/main.py b/src/training/main.py
index de41d8b63..ef470ad44 100644
--- a/src/training/main.py
+++ b/src/training/main.py
@@ -131,6 +131,9 @@ def main():
             unlocked_groups=args.lock_image_unlocked_groups,
             freeze_bn_stats=args.lock_image_freeze_bn_stats)
 
+    if args.grad_checkpointing:
+        model.set_grad_checkpointing()
+
     if is_master(args):
         logging.info("Model:")
         logging.info(f"{str(model)}")
diff --git a/src/training/params.py b/src/training/params.py
index 79b17936d..ef2b0990a 100644
--- a/src/training/params.py
+++ b/src/training/params.py
@@ -181,6 +181,12 @@ def parse_args():
         action='store_true',
         help="Freeze BatchNorm running stats in image tower for any locked layers.",
     )
+    parser.add_argument(
+        "--grad-checkpointing",
+        default=False,
+        action='store_true',
+        help="Enable gradient checkpointing.",
+    )
     parser.add_argument(
         "--local-loss",
         default=False,
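
Usage note (not part of the patch): a minimal sketch of how the new code path is exercised, assuming the patch above is applied to open_clip. During training the same path is enabled by passing the new --grad-checkpointing flag, which makes main() call model.set_grad_checkpointing(). The width/layers/heads values and tensor shapes below are illustrative assumptions, not values taken from the patch.

# Sketch only: exercises the Transformer.grad_checkpointing branch added above.
import torch
from open_clip.model import Transformer

model = Transformer(width=64, layers=2, heads=4)
model.grad_checkpointing = True  # normally toggled via CLIP.set_grad_checkpointing()

x = torch.randn(16, 2, 64, requires_grad=True)  # (seq_len, batch, width), the layout CLIP uses
y = model(x)         # each residual block now runs under torch.utils.checkpoint
y.sum().backward()   # block activations are recomputed here instead of being stored in forward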