Commit

added ds checkpointing
anas-awadalla committed Aug 26, 2023
1 parent 870f20c commit f9162a0
Showing 2 changed files with 70 additions and 26 deletions.
69 changes: 46 additions & 23 deletions open_flamingo/train/train.py
@@ -16,6 +16,7 @@
train_one_epoch,
get_mp_policy_dtype,
save_checkpoint,
ds_save_checkpoint,
)
from transformers import (
get_constant_schedule_with_warmup,
@@ -80,7 +81,7 @@ def main():
parser.add_argument(
"--resume_from_checkpoint",
type=str,
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default",
help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default. If using deepspeed this should be a directory, not a file.",
default=None,
)
parser.add_argument(
@@ -217,12 +218,6 @@ def main():
type=int,
help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)",
)
parser.add_argument(
"--local_rank",
type=int,
default=-1,
help="local rank passed from deepspeed distributed launcher",
)

# wandb args
parser.add_argument("--report_to_wandb", default=False, action="store_true")
@@ -242,6 +237,8 @@
)

args = parser.parse_args()

args.local_rank = int(os.environ.get('LOCAL_RANK', -1)) # for deepspeed

# Validate args
if args.laion_shards.startswith("s3"):
@@ -319,6 +316,8 @@
elif "amp" in args.precision:
raise ValueError("amp not supported with DeepSpeed")

device_id = args.local_rank

else:
device_id = init_distributed_device(args)

@@ -351,19 +350,28 @@
if os.path.exists(f"{args.run_name}") and args.resume_from_checkpoint is None:
# if args do not specify a checkpoint to resume from, check if checkpoints exist for this run
# and automatically resume from the latest checkpoint
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
if len(checkpoint_list) == 0:
print(f"Found no checkpoints for run {args.run_name}.")
if args.deepspeed:
if os.path.exists(f"{args.run_name}/latest"):
args.resume_from_checkpoint = args.run_name
print(
f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}."
)
else:
print(f"Found no checkpoints for run {args.run_name}.")
else:
args.resume_from_checkpoint = sorted(
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[-1]
print(
f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}."
)
checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
if len(checkpoint_list) == 0:
print(f"Found no checkpoints for run {args.run_name}.")
else:
args.resume_from_checkpoint = sorted(
checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
)[-1]
print(
f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}."
)

resume_from_epoch = 0
if args.resume_from_checkpoint is not None:
if args.resume_from_checkpoint is not None and not args.deepspeed:
if args.rank == 0:
print(f"Loading checkpoint from {args.resume_from_checkpoint}")
checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
@@ -468,7 +476,7 @@ def get_grouped_params(model):
)

# load optimizer checkpoint
if args.resume_from_checkpoint is not None:
if args.resume_from_checkpoint is not None and not args.deepspeed:
osd = checkpoint["optimizer_state_dict"]
if args.fsdp:
osd = FSDP.optim_state_dict_to_load(osd, ddp_model, optimizer)
@@ -503,7 +511,7 @@ def get_grouped_params(model):
)

# load lr scheduler checkpoint
if args.resume_from_checkpoint is not None:
if args.resume_from_checkpoint is not None and not args.deepspeed:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

if args.deepspeed:
@@ -519,6 +527,20 @@ def get_grouped_params(model):
f"After deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
)

if args.resume_from_checkpoint is not None:
if args.rank == 0:
print(f"Loading checkpoint from {args.resume_from_checkpoint}")
# We will not pass in a 'tag' and instead rely on 'latest' file in the checkpoint directory
ddp_model.load_checkpoint(
load_dir=args.resume_from_checkpoint, # Note: this is the dir, not the file
load_module_strict=False,
)
# read latest file to get epoch
latest_file = os.path.join(args.resume_from_checkpoint, "latest")
with open(latest_file, "r") as f:
checkpoint_epoch = int(f.read().split("_")[-1])
resume_from_epoch = checkpoint_epoch + 1

# Initialize gradient checkpointing
if args.gradient_checkpointing:
if args.deepspeed:
@@ -553,13 +575,14 @@ def get_grouped_params(model):
lr_scheduler=lr_scheduler,
laion_loader=laion_loader,
mmc4_loader=mmc4_loader,
device_id=args.local_rank if args.deepspeed else device_id,
device_id=device_id,
wandb=wandb,
)
save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)

# save final checkpoint
save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
if args.deepspeed:
ds_save_checkpoint(ddp_model, epoch, args)
else:
save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)


if __name__ == "__main__":
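For context, a minimal sketch (not part of this commit) of the DeepSpeed checkpoint layout the new resume logic assumes: save_checkpoint with tag=f"epoch_{epoch}" writes a 'latest' text file naming the most recent tag, plus a tag subdirectory holding the sharded states, so the resume epoch can be recovered from the tag string. Paths and the exact file set below are hypothetical.

import os

run_name = "my_run"  # hypothetical run directory
# Assumed layout after a DeepSpeed save with tag="epoch_3":
#   my_run/latest                               -> text file containing "epoch_3"
#   my_run/epoch_3/mp_rank_00_model_states.pt   -> module (model) states
#   my_run/epoch_3/*_optim_states.pt            -> ZeRO-partitioned optimizer states (typically one per rank)
with open(os.path.join(run_name, "latest")) as f:
    tag = f.read().strip()                       # e.g. "epoch_3"
resume_from_epoch = int(tag.split("_")[-1]) + 1  # resume training at the next epoch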
27 changes: 24 additions & 3 deletions open_flamingo/train/train_utils.py
@@ -9,6 +9,7 @@
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig
import os
import shutil
import wandb
from einops import rearrange

@@ -170,7 +171,7 @@ def train_one_epoch(
not args.fsdp or args.fsdp_use_orig_params
):
# Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
if args.fsdp:
if args.fsdp or args.deepspeed:
embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
else:
embed_grad = (
@@ -181,7 +182,7 @@
zero_mask[endofchunk_token_id] = torch.ones_like(
zero_mask[endofchunk_token_id]
)
if args.fsdp:
if args.fsdp or args.deepspeed:
model.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
@@ -344,7 +345,6 @@ def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
)
model_state = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer, group=args.my_group)

else:
model_state = model.state_dict()
optim_state = optimizer.state_dict()
@@ -371,3 +371,24 @@ def save_checkpoint(model, optimizer, lr_scheduler, epoch, args):
if args.delete_previous_checkpoint:
if epoch > 0:
os.remove(f"{args.run_name}/checkpoint_{epoch-1}.pt")


def ds_save_checkpoint(model, epoch, args):
"""
Save training checkpoint for deepspeed.
"""
print(f"Saving checkpoint to {args.run_name}")
model.save_checkpoint(
save_dir=args.run_name,
save_latest=True,
tag=f"epoch_{epoch}",
exclude_frozen_parameters=True,
)

if args.rank == 0:
if args.report_to_wandb and args.save_checkpoints_to_wandb:
wandb.save(f"{args.run_name}/epoch_{epoch}/mp_rank_00_model_states.pt")

if args.delete_previous_checkpoint:
if epoch > 0: # remove checkpoint dir epoch_{epoch-1}
shutil.rmtree(f"{args.run_name}/epoch_{epoch-1}")
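A possible follow-up use of these checkpoints, shown as a hedged sketch rather than anything this commit adds: consolidating the ZeRO-sharded states into a single fp32 state dict for single-GPU evaluation or export. This assumes DeepSpeed's zero_to_fp32 helper and a hypothetical run directory and tag; note that with exclude_frozen_parameters=True only trainable parameters are recovered, so frozen weights still come from the base model.

import torch
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

run_name = "my_run"  # hypothetical run directory containing 'latest' and epoch_* tag dirs
# Merge the ZeRO shards for the given tag into a single fp32 state dict on CPU.
state_dict = get_fp32_state_dict_from_zero_checkpoint(run_name, tag="epoch_3")
torch.save(state_dict, f"{run_name}/epoch_3_consolidated.pt")  # plain torch checkpoint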
