add LAION 2B only training #145

Closed · wants to merge 4 commits
open_flamingo/train/train.py (34 changes: 23 additions & 11 deletions)
```diff
@@ -79,6 +79,7 @@ def main():
     )
     parser.add_argument(
         "--mmc4_shards",
+        default=None,
         type=str,
         help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar",
     )
```
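With `default=None`, simply omitting `--mmc4_shards` becomes the signal for LAION-only training. A minimal sketch of that behavior (a hypothetical two-flag parser and a made-up shard path, not the full open_flamingo argument set):

```python
import argparse

# Minimal sketch: omitting --mmc4_shards leaves args.mmc4_shards as None,
# which the rest of train.py uses to switch into LAION-only mode.
parser = argparse.ArgumentParser()
parser.add_argument("--laion_shards", type=str)
parser.add_argument("--mmc4_shards", default=None, type=str)

args = parser.parse_args(["--laion_shards", "/data/laion/shard-{0000..0999}.tar"])
assert args.mmc4_shards is None  # LAION-only path taken downstream
```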
```diff
@@ -155,15 +156,17 @@ def main():
     if args.laion_shards.startswith("s3"):
         args.laion_shards = f"pipe:aws s3 cp {args.laion_shards} -"

-    if args.mmc4_shards.startswith("s3"):
-        args.mmc4_shards = f"pipe:aws s3 cp {args.mmc4_shards} -"
+    if args.mmc4_shards is not None:
+        if args.mmc4_shards.startswith("s3"):
+            args.mmc4_shards = f"pipe:aws s3 cp {args.mmc4_shards} -"

     if args.save_checkpoints_to_wandb and not args.report_to_wandb:
         raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

-    assert (args.train_num_samples_laion // args.batch_size_laion) == (
-        args.train_num_samples_mmc4 // args.batch_size_mmc4
-    ), "number of samples per epoch must be equal for mmc4 and laion"
+    if args.mmc4_shards is not None:
+        assert (args.train_num_samples_laion // args.batch_size_laion) == (
+            args.train_num_samples_mmc4 // args.batch_size_mmc4
+        ), "number of samples per epoch must be equal for mmc4 and laion"

     if args.offline:
         os.environ["WANDB_MODE"] = "offline"
```
```diff
@@ -203,7 +206,8 @@ def main():
     ddp_model = DDP(model, device_ids=[device_id])

     laion_dataset = get_data(args, image_processor, tokenizer, "image_text")
-    mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4")
+    if args.mmc4_shards is not None:
+        mmc4_dataset = get_data(args, image_processor, tokenizer, "mmc4")

     def get_grouped_params(model):
         params_with_wd, params_without_wd = [], []
```
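Note that in the LAION-only branch `mmc4_dataset` is never bound at all, so every later reference must sit behind the same `args.mmc4_shards is not None` guard or the script dies with a `NameError`. One way to keep those guards from scattering (a sketch of an alternative, not what this PR does; the names come from the surrounding code):

```python
# Sketch: keeping the optional dataset in a dict avoids an unbound-name
# hazard, since lookups can use datasets.get("mmc4") instead of a bare name.
datasets = {"laion": get_data(args, image_processor, tokenizer, "image_text")}
if args.mmc4_shards is not None:
    datasets["mmc4"] = get_data(args, image_processor, tokenizer, "mmc4")
```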
```diff
@@ -231,9 +235,14 @@ def apply_decay(x):

     optimizer = torch.optim.AdamW(get_grouped_params(ddp_model), lr=args.learning_rate)

-    total_training_steps = (
-        (args.train_num_samples_mmc4) // (args.batch_size_mmc4 * args.world_size)
-    ) * args.num_epochs
+    if args.mmc4_shards is not None:
+        total_training_steps = (
+            (args.train_num_samples_mmc4) // (args.batch_size_mmc4 * args.world_size)
+        ) * args.num_epochs
+    else:
+        total_training_steps = (
+            (args.train_num_samples_laion) // (args.batch_size_laion * args.world_size)
+        ) * args.num_epochs

     if args.rank == 0:
         print(f"Total training steps: {total_training_steps}")
```
```diff
@@ -283,8 +292,11 @@ def apply_decay(x):
     for epoch in range(resume_from_epoch, args.num_epochs):
         laion_dataset.set_epoch(epoch)
         laion_loader = laion_dataset.dataloader
-        mmc4_dataset.set_epoch(epoch)
-        mmc4_loader = mmc4_dataset.dataloader
+        if args.mmc4_shards is not None:
+            mmc4_dataset.set_epoch(epoch)
+            mmc4_loader = mmc4_dataset.dataloader
+        else:
+            mmc4_loader = None

         train_one_epoch(
             args=args,
```
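Since this diff only touches `train.py`, `train_one_epoch` itself must already cope with `mmc4_loader=None`, or be updated in a change not shown here. A hedged sketch of the kind of guard it would need; `step_on_laion` and `step_on_mmc4` are hypothetical helpers, not open_flamingo functions:

```python
def train_one_epoch(args, laion_loader, mmc4_loader=None, **kwargs):
    # Assumption: only interleave MMC4 steps when a loader was provided.
    mmc4_iter = iter(mmc4_loader) if mmc4_loader is not None else None
    for laion_batch in laion_loader:
        step_on_laion(laion_batch)         # hypothetical LAION train step
        if mmc4_iter is not None:
            step_on_mmc4(next(mmc4_iter))  # hypothetical MMC4 train step
```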