deepspeed running
anas-awadalla committed Aug 25, 2023
1 parent b6cd898 commit e19133d
Showing 4 changed files with 186 additions and 86 deletions.
34 changes: 22 additions & 12 deletions open_flamingo/scripts/run_train.sh
@@ -1,32 +1,42 @@
#!/bin/bash
#SBATCH --nodes 1
#SBATCH --ntasks-per-node=8
#SBATCH --ntasks-per-node=6
#SBATCH --gpus-per-task=1
#SBATCH --account=efml
#SBATCH --partition=gpu
#SBATCH --time=48:00:00
#SBATCH --job-name=flamingo

export PYTHONFAULTHANDLER=1
export CUDA_LAUNCH_BLOCKING=0
export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"`
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_PORT=15000
export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l`
export HF_DATASETS_CACHE="/gscratch/efml/anasa2/.huggingface" TRANSFORMERS_CACHE="/gscratch/efml/anasa2/.huggingface"

export PYTHONPATH="$PYTHONPATH:open_flamingo"
srun --cpu_bind=v --accel-bind=gn python open_flamingo/open_flamingo/train/train.py \
--lm_path anas-awadalla/mpt-1b-redpajama-200b \
--tokenizer_path anas-awadalla/mpt-1b-redpajama-200b \
--cross_attn_every_n_layers 1 \
srun --cpu_bind=v --accel-bind=gn python



deepspeed open_flamingo/open_flamingo/train/train.py \
--lm_path anas-awadalla/mpt-7b \
--tokenizer_path anas-awadalla/mpt-7b \
--cross_attn_every_n_layers 4 \
--dataset_resampled \
--batch_size_mmc4 32 \
--batch_size_laion 64 \
--batch_size_mmc4 16 \
--batch_size_laion 32 \
--deepspeed \
--train_num_samples_mmc4 125000 \
--train_num_samples_laion 250000 \
--loss_multiplier_laion 0.2 \
--workers=4 \
--run_name OpenFlamingo-3B-vitl-mpt1b \
--run_name "deepspeed" \
--num_epochs 480 \
--warmup_steps 1875 \
--mmc4_textsim_threshold 0.24 \
--laion_shards "/path/to/shards/shard-{0000..0999}.tar" \
--mmc4_shards "/path/to/shards/shard-{0000..0999}.tar" \
--warmup_steps 0 \
--mmc4_textsim_threshold 0.0 \
--laion_shards "/mmfs1/gscratch/efml/anasa2/laion-samples/{000000..000001}.tar" \
--mmc4_shards "/mmfs1/gscratch/efml/anasa2/mmc4-samples/shard_{0..1}-000000000.tar" \
--gradient_checkpointing \
--report_to_wandb \
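
The launch line changes from `srun ... python .../train.py` to the `deepspeed` launcher, which starts one process per GPU and, by default, passes `--local_rank=<n>` to each copy of the script (which is why train.py gains a `--local_rank` argument below); the exported MASTER_ADDR/MASTER_PORT feed the env:// rendezvous those processes use. A rough single-process sketch of that setup, with placeholder env defaults so the sketch can run standalone:

# Rough sketch of the per-process setup the launcher plus train.py end up doing.
# The env defaults below just let the sketch run as a single local process.
import argparse
import os
import torch
import torch.distributed as dist

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)  # injected by the deepspeed launcher
args, _ = parser.parse_known_args()

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "15000")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

backend = "nccl" if torch.cuda.is_available() else "gloo"
# roughly what deepspeed.init_distributed() sets up in train.py
dist.init_process_group(backend=backend, init_method="env://")
if torch.cuda.is_available():
    torch.cuda.set_device(args.local_rank)
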
2 changes: 1 addition & 1 deletion open_flamingo/train/data_utils.py
@@ -96,7 +96,7 @@ def count_samples(dataloader):

def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, issue a warning, and continue."""
logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
# logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
return True


130 changes: 111 additions & 19 deletions open_flamingo/train/train.py
@@ -35,6 +35,9 @@
CheckpointImpl,
apply_activation_checkpointing,
)

import deepspeed

from torch.distributed.fsdp._init_utils import _init_intra_and_inter_node_groups
from torch.distributed.distributed_c10d import _get_default_group
import functools
@@ -183,6 +186,8 @@ def main():
action="store_true",
help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
)

# fsdp args
parser.add_argument(
"--fsdp",
default=False,
@@ -199,6 +204,27 @@
"--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
)

# deepspeed args
parser.add_argument(
"--deepspeed",
default=False,
action="store_true",
help="Use deepspeed for distributed training.",
)
parser.add_argument(
"--deepspeed_stage",
default=2,
type=int,
choices=[1, 2, 3],
help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)",
)
parser.add_argument(
"--local_rank",
type=int,
default=-1,
help="local rank passed from deepspeed distributed launcher",
)

# wandb args
parser.add_argument("--report_to_wandb", default=False, action="store_true")
parser.add_argument(
@@ -252,8 +278,50 @@ def main():
if args.offline:
os.environ["WANDB_MODE"] = "offline"
os.environ["TRANSFORMERS_OFFLINE"] = "1"

args.local_rank, args.rank, args.world_size = world_info_from_env()
device_id = init_distributed_device(args)
if args.deepspeed:
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
# Initializes the distributed backend, which takes care of synchronizing nodes/GPUs
deepspeed.init_distributed()

zero_opt_dict = {
"stage": args.deepspeed_stage,
"offload_param": {"device": "none"}, # TODO: Support CPU offload
"offload_optimizer": {"device": "none"},
"stage3_param_persistence_threshold": 1e4,
"stage3_max_live_parameters": 3e7,
"stage3_prefetch_bucket_size": 3e7,
"memory_efficient_linear": False,
}
ds_config = {
"train_batch_size": (args.batch_size_mmc4 + args.batch_size_laion)
* args.world_size
* args.gradient_accumulation_steps,
"train_micro_batch_size_per_gpu": (
args.batch_size_mmc4 + args.batch_size_laion
)
* args.gradient_accumulation_steps,
"steps_per_print": 10,
"zero_optimization": zero_opt_dict,
"fp16": {"enabled": True, "loss_scale_window": 100},
"gradient_clipping": 1.0,
"prescale_gradients": False,
"wall_clock_breakdown": False,
"hybrid_engine": { # TODO: investigate
"enabled": False,
"max_out_tokens": 512,
"inference_tp_size": 1,
"release_inference_cache": False,
"pin_parameters": True,
"tp_gather_partition_size": 8,
},
}

else:
device_id = init_distributed_device(args)

random_seed(args.seed)

# Initialize model
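
About the ds_config assembled in the hunk above: DeepSpeed requires train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size, and it infers whichever of the three is missing from the config. A small arithmetic check using the flag values from the updated run_train.sh (world size and the accumulation default are assumptions of this sketch, not asserted by the commit):

# Batch-size bookkeeping for the config above, with illustrative values.
batch_size_mmc4, batch_size_laion = 16, 32   # flags in run_train.sh
gradient_accumulation_steps = 1              # assumed train.py default when the flag is omitted
world_size = 6                               # 1 node x 6 tasks in the example SLURM header

micro_per_gpu = (batch_size_mmc4 + batch_size_laion) * gradient_accumulation_steps
train_batch_size = (batch_size_mmc4 + batch_size_laion) * world_size * gradient_accumulation_steps

# DeepSpeed's invariant; with the expressions above it only balances when
# gradient_accumulation_steps == 1, since the accumulation factor appears on both sides.
assert train_batch_size == micro_per_gpu * gradient_accumulation_steps * world_size
print(micro_per_gpu, train_batch_size)       # 48 288
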
@@ -361,27 +429,27 @@ def main():
f"After FSDP {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
)

else:
elif not args.deepspeed:
model = model.to(device_id)
ddp_model = DDP(model, device_ids=[device_id])

# Initialize gradient checkpointing
if args.gradient_checkpointing:
non_reentrant_wrapper = functools.partial(
checkpoint_wrapper,
offload_to_cpu=True,
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)
apply_activation_checkpointing(
ddp_model,
checkpoint_wrapper_fn=non_reentrant_wrapper,
check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
and not isinstance(m, FSDP)
and not isinstance(m, CheckpointWrapper),
)
# if args.gradient_checkpointing:
# non_reentrant_wrapper = functools.partial(
# checkpoint_wrapper,
# offload_to_cpu=True,
# checkpoint_impl=CheckpointImpl.NO_REENTRANT,
# )
# apply_activation_checkpointing(
# ddp_model,
# checkpoint_wrapper_fn=non_reentrant_wrapper,
# check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
# and not isinstance(m, FSDP)
# and not isinstance(m, CheckpointWrapper),
# )

# Initialize optimizer
params_to_optimize = ddp_model.named_parameters()
params_to_optimize = model.named_parameters()
params_to_optimize = list(
filter(
lambda x: x[1].requires_grad
@@ -453,8 +521,32 @@ def get_grouped_params(model):
if args.resume_from_checkpoint is not None:
lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

# Start training!
ddp_model.train()
if args.deepspeed:
print(
f"Before deepspeed parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
)
print(
f"Before deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
)
ddp_model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model,
optimizer=optimizer,
args=args,
config=ds_config,
lr_scheduler=lr_scheduler,
dist_init_required=True,
)
print(
f"After deepspeed parameter num: {sum(p.numel() for p in model.parameters())} on rank {args.rank}"
)
print(
f"DeepSpeed Engine parameters: {sum(p.numel() for p in ddp_model.parameters())}"
)
print(
f"After deepspeed {torch.cuda.memory_allocated()/1024**3:.3} GB on rank {args.rank}"
)
# Start training!
ddp_model.train()

for epoch in range(resume_from_epoch, args.num_epochs):
laion_dataset.set_epoch(epoch)
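
A note on the restructuring above: the optimizer and scheduler are now built from the bare model's parameters (hence params_to_optimize = model.named_parameters() in the earlier hunk) because deepspeed.initialize takes the already-built model, optimizer, and scheduler, returns a 4-tuple (engine, optimizer, dataloader, lr_scheduler), and the engine then stands in for the DDP/FSDP wrapper; a scheduler passed this way is stepped by the engine. A minimal order-of-operations sketch with a placeholder module, not the Flamingo model:

# Order-of-operations sketch for the DeepSpeed path; nn.Linear stands in for the model.
import torch
import torch.nn as nn

model = nn.Linear(8, 8)
trainable_params = [p for _, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

# Only after this point would the engine be created; uncomment with deepspeed
# installed and a ds_config dict like the one above:
# engine, optimizer, _, lr_scheduler = deepspeed.initialize(
#     model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, config=ds_config
# )
# `engine` is what this file goes on to call ddp_model.
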
@@ -471,7 +563,7 @@ def get_grouped_params(model):
lr_scheduler=lr_scheduler,
laion_loader=laion_loader,
mmc4_loader=mmc4_loader,
device_id=device_id,
device_id=args.local_rank if args.deepspeed else device_id,
wandb=wandb,
)
save_checkpoint(ddp_model, optimizer, lr_scheduler, epoch, args)
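
One thing this hunk leaves untouched is save_checkpoint(ddp_model, ...). With ZeRO stage 2 the optimizer state is partitioned across ranks, so if saving plain state_dicts turns out not to round-trip, the engine's own collective checkpoint API is the usual alternative. A hedged sketch; the wrapper function names are illustrative and not part of this repo:

# Hedged sketch: every rank participates, and DeepSpeed writes the model weights
# plus each rank's partition of the optimizer state under the given tag.
def save_deepspeed_checkpoint(engine, save_dir: str, epoch: int) -> None:
    engine.save_checkpoint(save_dir, tag=f"epoch_{epoch}")

def load_deepspeed_checkpoint(engine, save_dir: str, epoch: int) -> None:
    engine.load_checkpoint(save_dir, tag=f"epoch_{epoch}")
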
106 changes: 52 additions & 54 deletions open_flamingo/train/train_utils.py
@@ -92,7 +92,7 @@ def train_one_epoch(
global_step = num_steps + epoch * num_batches_per_epoch

#### LAION FORWARD PASS ####
images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True)
images = batch_laion[0].to(device_id, dtype=torch.float16, non_blocking=True)
images = rearrange(images, "(b t f) c h w -> b t f c h w", t=1, f=1)
input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
attention_mask = batch_laion[1][1].to(
@@ -116,38 +116,29 @@
)[0]

divided_loss_laion = loss_laion / args.gradient_accumulation_steps
(divided_loss_laion * args.loss_multiplier_laion).backward()
if args.deepspeed:
model.backward(divided_loss_laion * args.loss_multiplier_laion)
else:
(divided_loss_laion * args.loss_multiplier_laion).backward()

#### MMC4 FORWARD PASS ####
images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True)
images = batch_mmc4[0].to(device_id, dtype=torch.float16, non_blocking=True)
images = rearrange(images, "b (t f) c h w -> b t f c h w", f=1)
input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)
input_ids = (
torch.stack([x[0] for x in batch_mmc4[1]])
.squeeze(1)
.to(device_id, dtype=cast_dtype, non_blocking=True)
)
attention_mask = (
torch.stack([x[1] for x in batch_mmc4[1]])
.squeeze(1)
.to(device_id, dtype=cast_dtype, non_blocking=True)
)

# set up labels; language model is expected to handle shifting
labels = input_ids.clone()
labels[labels == tokenizer.pad_token_id] = -100
labels[labels == tokenizer.eos_token_id] = -100
for i in range(labels.shape[0]):
# remove loss for any token before the first <image> token
label_idx = 0
while (
label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id
):
labels[i][label_idx] = -100
label_idx += 1

# get index of all endofchunk tokens in the sequence
endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
for endofchunk_idx in endofchunk_idxs:
token_idx = endofchunk_idx + 1
while (
token_idx < labels.shape[1]
and labels[i][token_idx] != media_token_id
):
labels[i][token_idx] = -100
token_idx += 1

labels[labels == media_token_id] = -100
labels = labels.to(device_id)

@@ -171,31 +162,35 @@
continue

divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps
(divided_loss_mmc4 * args.loss_multiplier_mmc4).backward()

if (not args.freeze_lm_embeddings) and (
not args.fsdp or args.fsdp_use_orig_params
):
# Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
if args.fsdp:
embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
else:
embed_grad = (
model.module.lang_encoder.get_input_embeddings().weight.grad
)
zero_mask = torch.zeros_like(embed_grad)
zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
zero_mask[endofchunk_token_id] = torch.ones_like(
zero_mask[endofchunk_token_id]
)
if args.fsdp:
model.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
else:
model.module.lang_encoder.get_input_embeddings().weight.grad = (
embed_grad * zero_mask
)
if args.deepspeed:
model.backward(divided_loss_mmc4 * args.loss_multiplier_mmc4)
else:
(divided_loss_mmc4 * args.loss_multiplier_mmc4).backward()

# TODO: investigate whether this is necessary
# if (not args.freeze_lm_embeddings) and (
# not args.fsdp or args.fsdp_use_orig_params
# ):
# # Mask gradients for input embeddings s.t. we only update the added tokens <image> and <|endofchunk|>
# if args.fsdp:
# embed_grad = model.lang_encoder.get_input_embeddings().weight.grad
# else:
# embed_grad = (
# model.module.lang_encoder.get_input_embeddings().weight.grad
# )
# zero_mask = torch.zeros_like(embed_grad)
# zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
# zero_mask[endofchunk_token_id] = torch.ones_like(
# zero_mask[endofchunk_token_id]
# )
# if args.fsdp:
# model.lang_encoder.get_input_embeddings().weight.grad = (
# embed_grad * zero_mask
# )
# else:
# model.module.lang_encoder.get_input_embeddings().weight.grad = (
# embed_grad * zero_mask
# )
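
The per-row gradient masking that kept only the <image> and <|endofchunk|> embedding rows trainable is disabled above with a TODO. If it turns out to still be needed, one way to express it that does not depend on how the model is wrapped is a tensor hook on the embedding weight, which masks the gradient inside autograd; whether this composes cleanly with ZeRO's gradient partitioning would still need to be verified. A toy sketch, with made-up vocabulary size and token ids:

# Hedged sketch: mask embedding gradients so only the two "added token" rows train.
import torch
import torch.nn as nn

vocab_size, dim = 100, 16
media_token_id, endofchunk_token_id = 98, 99   # made-up ids for the added tokens
embed = nn.Embedding(vocab_size, dim)

def keep_only_added_rows(grad):
    mask = torch.zeros_like(grad)
    mask[media_token_id] = 1.0
    mask[endofchunk_token_id] = 1.0
    return grad * mask

embed.weight.register_hook(keep_only_added_rows)  # runs during backward

tokens = torch.randint(0, vocab_size, (2, 8))
loss = embed(tokens).sum()
loss.backward()
# nonzero rows: 98/99 if they appeared in `tokens`, otherwise empty
print(embed.weight.grad.abs().sum(dim=1).nonzero().flatten())
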

# clip gradient norm
if args.fsdp:
Expand All @@ -206,16 +201,19 @@ def train_one_epoch(
At least for OPT-125M, this didn't seem to make a difference in performance.
"""
model.clip_grad_norm_(1.0)
else:
elif not args.deepspeed: # deepspeed handles clipping internally
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

# step optimizer and log
if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
num_steps == num_batches_per_epoch - 1
):
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if args.deepspeed:
model.step()
else:
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

# step time and reset end outside of rank 0
step_time_m.update(time.time() - end)
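
Taken together, the changes in this file route the whole update through the engine: model.backward(loss) applies loss scaling and the accumulation bookkeeping, clipping is governed by the "gradient_clipping" entry in the config, and model.step() runs the optimizer step, zeroes gradients, and advances the scheduler that was handed to deepspeed.initialize. A plain-PyTorch reference loop, with comments marking which line each engine call stands in for:

# Schematic: the plain-PyTorch call on each line is what the DeepSpeed engine
# call in the comment replaces in this file.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)                        # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: 1.0)
grad_accum = 2

for step in range(4):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y) / grad_accum
    loss.backward()                            # deepspeed: model.backward(loss)
    if (step + 1) % grad_accum == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # deepspeed: "gradient_clipping" in ds_config
        optimizer.step()                       # deepspeed: model.step() does the step,
        lr_scheduler.step()                    #   the scheduler advance, and
        optimizer.zero_grad(set_to_none=True)  #   the zero_grad in one call
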
