diff --git a/open_flamingo/scripts/run_train.sh b/open_flamingo/scripts/run_train.sh
index 8d45355e..ffe5cfa3 100644
--- a/open_flamingo/scripts/run_train.sh
+++ b/open_flamingo/scripts/run_train.sh
@@ -1,7 +1,11 @@
 #!/bin/bash
 #SBATCH --nodes 1
-#SBATCH --ntasks-per-node=8
+#SBATCH --ntasks-per-node=6
 #SBATCH --gpus-per-task=1
+#SBATCH --account=efml
+#SBATCH --partition=gpu
+#SBATCH --time=48:00:00
+#SBATCH --job-name=flamingo

 export PYTHONFAULTHANDLER=1
 export CUDA_LAUNCH_BLOCKING=0
@@ -9,24 +13,26 @@ export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
 export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
 export MASTER_PORT=15000
 export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
+export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface"

 export PYTHONPATH="$PYTHONPATH:open_flamingo"
-srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
-    --lm_path anas-awadalla/mpt-1b-redpajama-200b \
-    --tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
-    --cross_attn_every_n_layers 1 \
+deepspeed open_flamingo/open_flamingo/train/train.py \
+    --lm_path anas-awadalla/mpt-7b \
+    --tokenizer_path anas-awadalla/mpt-7b \
+    --cross_attn_every_n_layers 4 \
     --dataset_resampled \
-    --batch_size_mmc4 32 \
-    --batch_size_laion 64 \
+    --batch_size_mmc4 16 \
+    --batch_size_laion 32 \
+    --deepspeed \
     --train_num_samples_mmc4 125000\
     --train_num_samples_laion 250000 \
     --loss_multiplier_laion 0.2 \
     --workers=4 \
-    --run_name OpenFlamingo-3B-vitl-mpt1b \
+    --run_name "deepspeed" \
     --num_epochs 480 \
-    --warmup_steps 1875 \
-    --mmc4_textsim_threshold 0.24 \
-    --laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
-    --mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
+    --warmup_steps 0 \
+    --mmc4_textsim_threshold 0.0 \
+    --laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \
+    --mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \
     --gradient_checkpointing \
     --report_to_wandb \
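For reference, a minimal sketch (plain Python, not part of the patch) of the per-step batch these flags imply, mirroring the ds_config arithmetic built in train.py below; it assumes the 6 single-GPU ranks requested above and that gradient_accumulation_steps stays at 1:

```python
# Effective batch implied by the launcher flags above (assumed values, 1 node x 6 GPUs).
batch_size_mmc4 = 16
batch_size_laion = 32
world_size = 6                     # --nodes 1 * --ntasks-per-node=6
gradient_accumulation_steps = 1    # assumed; the script does not pass this flag

per_gpu = batch_size_mmc4 + batch_size_laion                       # 48 samples per rank per step
global_batch = per_gpu * world_size * gradient_accumulation_steps  # 288 samples per optimizer step
print(per_gpu, global_batch)  # 48 288
```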
Ignoring.") return True diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index 3a110901..9485ae16 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -35,6 +35,9 @@ CheckpointImpl, apply_activation_checkpointing, ) + +import deepspeed + from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups from torch.distributed.distributed_c10d import _get_default_group import functools @@ -183,6 +186,8 @@ def main(): action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", ) + + # fsdp args parser.add_argument( "--fsdp", default=False, @@ -199,6 +204,27 @@ def main(): "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"] ) + # deepspeed args + parser.add_argument( + "--deepspeed", + default=False, + action="store_true", + help="Use deepspeed for distributed training.", + ) + parser.add_argument( + "--deepspeed_stage", + default=2, + type=int, + options=[1, 2, 3], + help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)", + ) + parser.add_argument( + "--local_rank", + type=int, + default=-1, + help="local rank passed from deepspeed distributed launcher", + ) + # wandb args parser.add_argument("--report_to_wandb", default=False, action="store_true") parser.add_argument( @@ -252,8 +278,50 @@ def main(): if args.offline: os.environ["WANDB_MODE"] = "offline" os.environ["TRANSFORMERS_OFFLINE"] = "1" + args.local_rank, args.rank, args.world_size = world_info_from_env() - device_id = init_distributed_device(args) + if args.deepspeed: + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + deepspeed.init_distributed() + + zero_opt_dict = { + "stage": args.deepspeed_stage, + "offload_param": {"device": "none"}, # TODO: Support CPU offload + "offload_optimizer": {"device": "none"}, + "stage3_param_persistence_threshold": 1e4, + "stage3_max_live_parameters": 3e7, + "stage3_prefetch_bucket_size": 3e7, + "memory_efficient_linear": False, + } + ds_config = { + "train_batch_size": (args.batch_size_mmc4 + args.batch_size_laion) + * args.world_size + * args.gradient_accumulation_steps, + "train_micro_batch_size_per_gpu": ( + args.batch_size_mmc4 + args.batch_size_laion + ) + * args.gradient_accumulation_steps, + "steps_per_print": 10, + "zero_optimization": zero_opt_dict, + "fp16": {"enabled": True, "loss_scale_window": 100}, + "gradient_clipping": 1.0, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "hybrid_engine": { # TODO: investigate + "enabled": False, + "max_out_tokens": 512, + "inference_tp_size": 1, + "release_inference_cache": False, + "pin_parameters": True, + "tp_gather_partition_size": 8, + }, + } + + else: + device_id = init_distributed_device(args) + random_seed(args.seed) # Initialize model @@ -361,27 +429,27 @@ def main(): f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}" ) - else: + elif not args.deepspeed: model = model.to(device_id) ddp_model = DDP(model, device_ids=[device_id]) # Initialize gradient checkpointing - if args.gradient_checkpointing: - non_reentrant_wrapper = functools.partial( - checkpoint_wrapper, - offload_to_cpu=True, - checkpoint_impl=CheckpointImpl.NO_REENTRANT, - ) - apply_activation_checkpointing( - ddp_model, - 
@@ -361,27 +429,27 @@ def main():
             f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
         )

-    else:
+    elif not args.deepspeed:
         model = model.to(device_id)
         ddp_model = DDP(model, device_ids=[device_id])

     # Initialize gradient checkpointing
-    if args.gradient_checkpointing:
-        non_reentrant_wrapper = functools.partial(
-            checkpoint_wrapper,
-            offload_to_cpu=True,
-            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-        )
-        apply_activation_checkpointing(
-            ddp_model,
-            checkpoint_wrapper_fn=non_reentrant_wrapper,
-            check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
-            and not isinstance(m, FSDP)
-            and not isinstance(m, CheckpointWrapper),
-        )
+    # if args.gradient_checkpointing:
+    #     non_reentrant_wrapper = functools.partial(
+    #         checkpoint_wrapper,
+    #         offload_to_cpu=True,
+    #         checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+    #     )
+    #     apply_activation_checkpointing(
+    #         ddp_model,
+    #         checkpoint_wrapper_fn=non_reentrant_wrapper,
+    #         check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
+    #         and not isinstance(m, FSDP)
+    #         and not isinstance(m, CheckpointWrapper),
+    #     )

     # Initialize optimizer
-    params_to_optimize = ddp_model.named_parameters()
+    params_to_optimize = model.named_parameters()
     params_to_optimize = list(
         filter(
             lambda x: x[1].requires_grad
@@ -453,8 +521,32 @@ def get_grouped_params(model):
     if args.resume_from_checkpoint is not None:
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

-    # Start training!
-    ddp_model.train()
+    if args.deepspeed:
+        print(
+            f"Before deepspeed parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
+        )
+        print(
+            f"Before deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
+        )
+        ddp_model, optimizer, _, lr_scheduler = deepspeed.initialize(
+            model=model,
+            optimizer=optimizer,
+            args=args,
+            config=ds_config,
+            lr_scheduler=lr_scheduler,
+            dist_init_required=True,
+        )
+        print(
+            f"After deepspeed parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
+        )
+        print(
+            f"DeepSpeed Engine parameters: {sum(p.numel() for p in ddp_model.parameters())}"
+        )
+        print(
+            f"After deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
+        )
+
+    # Start training!
+    ddp_model.train()

     for epoch in range(resume_from_epoch, args.num_epochs):
         laion_dataset.set_epoch(epoch)
@@ -471,7 +563,7 @@ def get_grouped_params(model):
             lr_scheduler=lr_scheduler,
             laion_loader=laion_loader,
             mmc4_loader=mmc4_loader,
-            device_id=device_id,
+            device_id=args.local_rank if args.deepspeed else device_id,
             wandb=wandb,
         )
         save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
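Condensed, for review: a minimal sketch of the initialization path this patch adds to train.py. `setup_deepspeed` is a hypothetical helper (not repo code), and the config dict below is a trimmed subset of the ds_config built above; the actual change passes the full dict to `deepspeed.initialize`.

```python
# Hypothetical condensation of the DeepSpeed branch added above; not repo code.
import deepspeed
import torch


def setup_deepspeed(args, model, optimizer, lr_scheduler):
    torch.cuda.set_device(args.local_rank)
    deepspeed.init_distributed()  # env:// init from MASTER_ADDR/MASTER_PORT, RANK, WORLD_SIZE set by the launcher

    ds_config = {  # trimmed subset of the ds_config constructed in train.py
        "train_micro_batch_size_per_gpu": args.batch_size_mmc4 + args.batch_size_laion,
        "zero_optimization": {"stage": args.deepspeed_stage},
        "fp16": {"enabled": True},
        "gradient_clipping": 1.0,
    }
    # The returned engine wraps the model and owns backward()/step(); it takes the
    # place of the DDP/FSDP-wrapped model (ddp_model) in the rest of the script.
    engine, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        config=ds_config,
        args=args,
    )
    return engine, optimizer, lr_scheduler
```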
diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py
index e508c4c9..fe0122f5 100644
--- a/open_flamingo/train/train_utils.py
+++ b/open_flamingo/train/train_utils.py
@@ -92,7 +92,7 @@ def train_one_epoch(
         global_step = num_steps + epoch * num_batches_per_epoch

         #### LAION FORWARD PASS ####
-        images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True)
+        images = batch_laion[0].to(device_id, dtype=torch.float16, non_blocking=True)
         images = rearrange(images, "(b t f) c h w -> b t f c h w", t=1, f=1)
         input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
         attention_mask = batch_laion[1][1].to(
@@ -116,38 +116,29 @@ def train_one_epoch(
         )[0]

         divided_loss_laion = loss_laion / args.gradient_accumulation_steps
-        (divided_loss_laion * args.loss_multiplier_laion).backward()
+        if args.deepspeed:
+            model.backward(divided_loss_laion * args.loss_multiplier_laion)
+        else:
+            (divided_loss_laion * args.loss_multiplier_laion).backward()

         #### MMC4 FORWARD PASS ####
-        images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True)
+        images = batch_mmc4[0].to(device_id, dtype=torch.float16, non_blocking=True)
         images = rearrange(images, "b (t f) c h w -> b t f c h w", f=1)
-        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
-        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)
+        input_ids = (
+            torch.stack([x[0] for x in batch_mmc4[1]])
+            .squeeze(1)
+            .to(device_id, dtype=cast_dtype, non_blocking=True)
+        )
+        attention_mask = (
+            torch.stack([x[1] for x in batch_mmc4[1]])
+            .squeeze(1)
+            .to(device_id, dtype=cast_dtype, non_blocking=True)
+        )

         # set up labels; language model is expected to handle shifting
         labels = input_ids.clone()
         labels[labels == tokenizer.pad_token_id] = -100
         labels[labels == tokenizer.eos_token] = -100
-        for i in range(labels.shape[0]):
-            # remove loss for any token before the first <image> token
-            label_idx = 0
-            while (
-                label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id
-            ):
-                labels[i][label_idx] = -100
-                label_idx += 1
-
-            # get index of all endofchunk tokens in the sequence
-            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
-            for endofchunk_idx in endofchunk_idxs:
-                token_idx = endofchunk_idx + 1
-                while (
-                    token_idx < labels.shape[1]
-                    and labels[i][token_idx] != media_token_id
-                ):
-                    labels[i][token_idx] = -100
-                    token_idx += 1
-
         labels[labels == media_token_id] = -100
         labels = labels.to(device_id)
@@ -171,31 +162,35 @@ def train_one_epoch(
             continue

         divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps
-        (divided_loss_mmc4 * args.loss_multiplier_mmc4).backward()
-
-        if (not args.freeze_lm_embeddings) and (
-            not args.fsdp or args.fsdp_use_orig_params
-        ):
-            # Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
-            if args.fsdp:
-                embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
-            else:
-                embed_grad = (
-                    model.module.lang_encoder.get_input_embeddings().weight.grad
-                )
-            zero_mask = torch.zeros_like(embed_grad)
-            zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
-            zero_mask[endofchunk_token_id] = torch.ones_like(
-                zero_mask[endofchunk_token_id]
-            )
-            if args.fsdp:
-                model.lang_encoder.get_input_embeddings().weight.grad = (
-                    embed_grad * zero_mask
-                )
-            else:
-                model.module.lang_encoder.get_input_embeddings().weight.grad = (
-                    embed_grad * zero_mask
-                )
+        if args.deepspeed:
+            model.backward(divided_loss_mmc4 * args.loss_multiplier_mmc4)
+        else:
+            (divided_loss_mmc4 * args.loss_multiplier_mmc4).backward()
+
+        # TODO: investigate whether this is necessary
+        # if (not args.freeze_lm_embeddings) and (
+        #     not args.fsdp or args.fsdp_use_orig_params
+        # ):
+        #     # Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
+        #     if args.fsdp:
+        #         embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
+        #     else:
+        #         embed_grad = (
+        #             model.module.lang_encoder.get_input_embeddings().weight.grad
+        #         )
+        #     zero_mask = torch.zeros_like(embed_grad)
+        #     zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
+        #     zero_mask[endofchunk_token_id] = torch.ones_like(
+        #         zero_mask[endofchunk_token_id]
+        #     )
+        #     if args.fsdp:
+        #         model.lang_encoder.get_input_embeddings().weight.grad = (
+        #             embed_grad * zero_mask
+        #         )
+        #     else:
+        #         model.module.lang_encoder.get_input_embeddings().weight.grad = (
+        #             embed_grad * zero_mask
+        #         )

         # clip gradient norm
         if args.fsdp:
@@ -206,16 +201,19 @@ def train_one_epoch(
             At least for OPT-125M, this didn't seem to make a difference in performance.
             """
             model.clip_grad_norm_(1.0)
-        else:
+        elif not args.deepspeed:  # deepspeed handles clipping internally
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

         # step optimizer and log
         if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
             num_steps == num_batches_per_epoch - 1
         ):
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad(set_to_none=True)
+            if args.deepspeed:
+                model.step()
+            else:
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad(set_to_none=True)

             # step time and reset end outside of rank 0
             step_time_m.update(time.time() - end)
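A schematic of the per-batch update this patch establishes in train_utils.py (illustrative only; `update_step`, `accum`, and `is_boundary` are hypothetical names, and the real loop runs separate LAION and MMC4 backward passes):

```python
# Illustrative control flow only; `model` is the DeepSpeed engine returned by
# deepspeed.initialize (named ddp_model in train.py) when --deepspeed is set,
# otherwise the DDP/FSDP-wrapped model; `loss` is an already-weighted dataset loss.
import torch


def update_step(args, model, optimizer, lr_scheduler, loss, accum, is_boundary):
    if args.deepspeed:
        # The engine scales, all-reduces, and clips gradients per its config,
        # and step() advances the lr_scheduler attached at initialize().
        model.backward(loss / accum)
        if is_boundary:
            model.step()
    else:
        (loss / accum).backward()
        if is_boundary:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad(set_to_none=True)
```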