From 79ad152a725bdd485c34dbe32394d8d8fc3688af Mon Sep 17 00:00:00 2001 From: root Date: Sun, 17 Mar 2024 07:12:34 +0000 Subject: [PATCH] some more fixes --- open_flamingo/eval/eval_datasets.py | 8 +-- open_flamingo/eval/eval_models/blip.py | 21 ++++--- open_flamingo/eval/eval_models/idefics.py | 21 ++++--- .../eval/eval_models/open_flamingo.py | 57 +++++++++++-------- open_flamingo/eval/evaluate.py | 33 ++++++----- open_flamingo/eval/ok_vqa_utils.py | 1 + open_flamingo/eval/vqa_metric.py | 1 + open_flamingo/src/blip.py | 8 --- open_flamingo/src/factory.py | 11 +++- open_flamingo/src/helpers.py | 45 ++++----------- open_flamingo/src/llava.py | 7 +-- open_flamingo/src/vlm.py | 16 ++++-- open_flamingo/train/data.py | 15 ++++- open_flamingo/train/distributed.py | 19 ++++++- open_flamingo/train/losses.py | 5 +- open_flamingo/train/train.py | 18 +++--- open_flamingo/train/train_utils.py | 26 +++------ 17 files changed, 159 insertions(+), 153 deletions(-) diff --git a/open_flamingo/eval/eval_datasets.py b/open_flamingo/eval/eval_datasets.py index eba08263..df50af6a 100644 --- a/open_flamingo/eval/eval_datasets.py +++ b/open_flamingo/eval/eval_datasets.py @@ -9,9 +9,9 @@ SUPPORTED_TASKS = [ "coco", - "flickr", + "flickr30", "vqav2", - "ok_vqa", + "okvqa", "vizwiz", "textvqa", "hateful_memes", @@ -87,7 +87,7 @@ def __init__( self.image_dir_path = image_dir_path self.is_train = is_train self.dataset_name = dataset_name - if self.dataset_name in {"vqav2", "ok_vqa"}: + if self.dataset_name in {"vqav2", "okvqa"}: self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1] assert self.img_coco_split in {"train2014", "val2014", "test2015"} @@ -95,7 +95,7 @@ def __len__(self): return len(self.questions) def get_img_path(self, question): - if self.dataset_name in {"vqav2", "ok_vqa"}: + if self.dataset_name in {"vqav2", "okvqa"}: return os.path.join( self.image_dir_path, f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg" diff --git a/open_flamingo/eval/eval_models/blip.py b/open_flamingo/eval/eval_models/blip.py index 81179794..2fefa994 100644 --- a/open_flamingo/eval/eval_models/blip.py +++ b/open_flamingo/eval/eval_models/blip.py @@ -12,16 +12,15 @@ class EvalModel(BaseEvalModel): """BLIP-2 model evaluation.""" - def __init__(self, model_args, init_on_device=False): - super().__init__(model_args, init_on_device) - with self.init_ctx: - self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) - self.model = Blip2ForConditionalGeneration.from_pretrained( - model_args["lm_path"] - ) - self.tokenizer = self.processor.tokenizer + def __init__(self, model_args): + super().__init__(model_args) + self.processor = Blip2Processor.from_pretrained(model_args["processor_path"]) + self.model = Blip2ForConditionalGeneration.from_pretrained( + model_args["lm_path"] + ) + self.tokenizer = self.processor.tokenizer + self._check_init() - @property def required_args(self): return ["processor_path", "lm_path"] @@ -100,7 +99,7 @@ def get_outputs( def get_vqav2_prompt(self, question, answer=None) -> str: return f"Question:{question} Short answer:{answer if answer is not None else ''}" - def get_ok_vqa_prompt(self, question, answer=None) -> str: + def get_okvqa_prompt(self, question, answer=None) -> str: return f"Question:{question} Short answer:{answer if answer is not None else ''}" def get_vizwiz_prompt(self, question, answer=None) -> str: @@ -112,5 +111,5 @@ def get_textvqa_prompt(self, question, answer=None) -> str: def get_coco_prompt(self, caption=None) -> str: return f"A photo of {caption if caption is not None else ''}" - def get_flickr_prompt(self, caption=None) -> str: + def get_flickr30_prompt(self, caption=None) -> str: return f"A photo of {caption if caption is not None else ''}" diff --git a/open_flamingo/eval/eval_models/idefics.py b/open_flamingo/eval/eval_models/idefics.py index f664b817..b9d2f584 100644 --- a/open_flamingo/eval/eval_models/idefics.py +++ b/open_flamingo/eval/eval_models/idefics.py @@ -17,16 +17,15 @@ class EvalModel(BaseEvalModel): """IDEFICS model evaluation.""" - def __init__(self, model_args, init_on_device=False): - super().__init__(model_args, init_on_device) - with self.init_ctx: - self.model = IdeficsForVisionText2Text.from_pretrained( - model_args["lm_path"] - ) - self.processor = AutoProcessor.from_pretrained(model_args["processor_path"]) - self.tokenizer = self.processor.tokenizer + def __init__(self, model_args): + super().__init__(model_args) + self.model = IdeficsForVisionText2Text.from_pretrained( + model_args["lm_path"] + ) + self.processor = AutoProcessor.from_pretrained(model_args["processor_path"]) + self.tokenizer = self.processor.tokenizer + self._check_init() - @property def required_args(self): return ["lm_path", "processor_path"] @@ -171,7 +170,7 @@ def get_vqav2_prompt(self, question, answer=None) -> str: # TODO: handle prefix prompts return f"Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" - def get_ok_vqa_prompt(self, question, answer=None) -> str: + def get_okvqa_prompt(self, question, answer=None) -> str: # TODO: handle prefix prompts return f"Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" @@ -187,6 +186,6 @@ def get_coco_prompt(self, caption=None) -> str: # TODO: handle prefix prompts return f"Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" - def get_flickr_prompt(self, caption=None) -> str: + def get_flickr30_prompt(self, caption=None) -> str: # TODO: handle prefix prompts return f"Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" diff --git a/open_flamingo/eval/eval_models/open_flamingo.py b/open_flamingo/eval/eval_models/open_flamingo.py index 247fc9da..0a25198c 100644 --- a/open_flamingo/eval/eval_models/open_flamingo.py +++ b/open_flamingo/eval/eval_models/open_flamingo.py @@ -13,24 +13,32 @@ class EvalModel(BaseEvalModel): """OpenFlamingo model evaluation.""" - def __init__(self, model_args, init_on_device=False): - super().__init__(model_args, init_on_device) + def __init__(self, model_args): + super().__init__(model_args) + + if model_args["model_family"] == "openflamingo": + assert "cross_attn_every_n_layers" in model_args, "cross_attn_every_n_layers is required for Flamingo models" + else: + assert "cross_attn_every_n_layers" not in model_args, "cross_attn_every_n_layers is only for Flamingo models" + # initialize the model - with self.init_ctx: - ( - self.model, - self.image_processor, - self.tokenizer, - ) = create_model_and_transforms( - clip_vision_encoder_path=model_args["vision_encoder_path"], - clip_vision_encoder_pretrained=model_args["vision_encoder_pretrained"], - lang_model_path=model_args["lm_path"], - tokenizer_path=model_args["tokenizer_path"], - model_family=model_args["model_family"], - cross_attn_every_n_layers=int( - model_args.get("cross_attn_every_n_layers", 1) - ), - ) + additional_kwargs = ( + {"cross_attn_every_n_layers": model_args.get("cross_attn_every_n_layers", 1)} + if model_args["model_family"] == "flamingo" + else {} + ) + ( + self.model, + self.image_processor, + self.tokenizer, + ) = create_model_and_transforms( + clip_vision_encoder_path=model_args["vision_encoder_path"], + clip_vision_encoder_pretrained=model_args["vision_encoder_pretrained"], + lang_model_path=model_args["lm_path"], + tokenizer_path=model_args["tokenizer_path"], + model_family=model_args["model_family"], + **additional_kwargs, + ) # load the checkpoint checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu") @@ -38,6 +46,7 @@ def __init__(self, model_args, init_on_device=False): checkpoint = checkpoint["model_state_dict"] checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()} self.model.load_state_dict(checkpoint, strict=False) + self.model_family = model_args["model_family"] self._check_init() @@ -46,11 +55,10 @@ def required_args(self): """Return list of required arguments to initialize model.""" return [ "vision_encoder_path", - "model_familyl", + "model_family", "lm_path", "checkpoint_path", "tokenizer_path", - "cross_attn_every_n_layers", "vision_encoder_pretrained", ] @@ -170,8 +178,9 @@ def get_outputs( **decode_kwargs, ) - # Extract only the new generated tokens - outputs = outputs[:, len(input_ids[0]) :] + if self.model_family == "flamingo": + # Extract only the new generated tokens + outputs = outputs[:, len(input_ids[0]) :] return self.tokenizer.batch_decode(outputs, skip_special_tokens=True) def get_rank_classifications( @@ -270,8 +279,8 @@ def get_rank_classifications( def get_vqav2_prompt(self, question, answer=None) -> str: return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" - def get_ok_vqa_prompt(self, question, answer=None) -> str: - return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" + def get_okvqa_prompt(self, question, answer=None) -> str: + return f"Instruct: {question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" def get_vizwiz_prompt(self, question, answer=None) -> str: return f"Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}" @@ -282,7 +291,7 @@ def get_textvqa_prompt(self, question, answer=None) -> str: def get_coco_prompt(self, caption=None) -> str: return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" - def get_flickr_prompt(self, caption=None) -> str: + def get_flickr30_prompt(self, caption=None) -> str: return f"Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}" def get_imagenet_prompt(self, label=None) -> str: diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py index 26dcfa94..4a25fca0 100644 --- a/open_flamingo/eval/evaluate.py +++ b/open_flamingo/eval/evaluate.py @@ -122,7 +122,7 @@ help="Whether to evaluate on VQAV2.", ) parser.add_argument( - "--eval_ok_vqa", + "--eval_okvqa", action="store_true", default=False, help="Whether to evaluate on OK-VQA.", @@ -408,10 +408,8 @@ def main(): model_args["device"] = device_id # initialize model - eval_model = get_eval_model(args.model, model_args, init_on_device=False) - eval_model.init_distributed( - local_rank=args.local_rank, - ) + eval_model = get_eval_model(args.model, model_args) + eval_model.init_distributed() # Validate args if args.model in ZERO_SHOT_ONLY_MODELS and args.shots != [0]: @@ -504,7 +502,7 @@ def main(): } ) - if args.eval_ok_vqa: + if args.eval_okvqa: print("Evaluating on OK-VQA...") # load cached demonstration features for RICES @@ -523,7 +521,7 @@ def main(): eval_model=eval_model, num_shots=shot, seed=seed, - dataset_name="ok_vqa", + dataset_name="okvqa", cached_features=cached_features, ) if args.rank == 0: @@ -919,7 +917,7 @@ def evaluate_vqa( seed: int = 42, min_new_tokens: int = 0, max_new_tokens: int = 5, - num_beams: int = 3, + num_beams: int = 5, length_penalty: float = 0.0, num_shots: int = 8, dataset_name: str = "vqav2", @@ -936,13 +934,13 @@ def evaluate_vqa( num_beams (int, optional): number of beams to use for beam search. Defaults to 3. length_penalty (float, optional): length penalty for beam search. Defaults to -2.0. num_shots (int, optional): number of shots to use. Defaults to 8. - dataset_name (string): type of vqa dataset: currently supports vqav2, ok_vqa. Defaults to vqav2. + dataset_name (string): type of vqa dataset: currently supports vqav2, okvqa. Defaults to vqav2. cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None. Returns: float: accuracy score """ - if dataset_name == "ok_vqa": + if dataset_name == "okvqa": train_image_dir_path = args.ok_vqa_train_image_dir_path train_questions_json_path = args.ok_vqa_train_questions_json_path train_annotations_json_path = args.ok_vqa_train_annotations_json_path @@ -989,7 +987,7 @@ def evaluate_vqa( dataset_name=dataset_name, ) - effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model) + effective_num_shots = num_shots #utils.compute_effective_num_shots(num_shots, args.model) np.random.seed(seed) test_dataloader = utils.prepare_eval_samples( @@ -1012,6 +1010,11 @@ def evaluate_vqa( utils.random_seed(seed, args.rank) predictions = [] + + get_vqa_prompt = getattr( + eval_model, f"get_{dataset_name}_prompt" + ) + for batch in tqdm( test_dataloader, desc=f"Running inference {dataset_name}", @@ -1034,7 +1037,7 @@ def evaluate_vqa( context_text = "".join( [ - eval_model.get_vqa_prompt( + get_vqa_prompt( question=x["question"], answer=x["answers"][0] ) + "\n" @@ -1047,9 +1050,9 @@ def evaluate_vqa( context_text = context_text.replace("", "") batch_text.append( - context_text + eval_model.get_vqa_prompt(question=batch["question"][i]) + context_text + get_vqa_prompt(question=batch["question"][i]) ) - + outputs = eval_model.get_outputs( batch_images=batch_images, batch_text=batch_text, @@ -1186,7 +1189,7 @@ def evaluate_classification( class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names)) - effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model) + effective_num_shots = num_shots #utils.compute_effective_num_shots(num_shots, args.model) np.random.seed(seed) test_dataloader = utils.prepare_eval_samples( diff --git a/open_flamingo/eval/ok_vqa_utils.py b/open_flamingo/eval/ok_vqa_utils.py index 69c422e7..21ae0cec 100644 --- a/open_flamingo/eval/ok_vqa_utils.py +++ b/open_flamingo/eval/ok_vqa_utils.py @@ -210,6 +210,7 @@ def stem(self, input_string): def postprocess_ok_vqa_generation(predictions) -> str: prediction = re.split("Question|Answer|Short", predictions, 1)[0] + prediction = prediction.split(". ", 1)[0] prediction = re.split(", ", prediction, 1)[0] prediction_stem = stemmer.stem(prediction) return prediction_stem diff --git a/open_flamingo/eval/vqa_metric.py b/open_flamingo/eval/vqa_metric.py index 90568166..3659c556 100644 --- a/open_flamingo/eval/vqa_metric.py +++ b/open_flamingo/eval/vqa_metric.py @@ -556,5 +556,6 @@ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_p def postprocess_vqa_generation(predictions): answer = re.split("Question|Answer|Short", predictions, 1)[0] + answer = answer.split(". ", 1)[0] answer = re.split(", ", answer, 1)[0] return answer diff --git a/open_flamingo/src/blip.py b/open_flamingo/src/blip.py index a97ead24..b3806959 100644 --- a/open_flamingo/src/blip.py +++ b/open_flamingo/src/blip.py @@ -50,14 +50,6 @@ def set_trainable(self): """ self.requires_grad_(False) self.vision_tokenizer.requires_grad_(True) - self.lang_model.get_output_embeddings().set_requires_grad( - require_regular_grad=False, - require_additional_grad=True, - ) - self.lang_model.get_input_embeddings().set_requires_grad( - require_regular_grad=False, - require_additional_grad=True, - ) def _should_apply_weight_decay(self, parameter_name): """BLIP applies 0.05 weight decay to everything""" diff --git a/open_flamingo/src/factory.py b/open_flamingo/src/factory.py index 89a28096..637f05ed 100644 --- a/open_flamingo/src/factory.py +++ b/open_flamingo/src/factory.py @@ -58,11 +58,12 @@ def create_model_and_transforms( clip_vision_encoder_path, pretrained=clip_vision_encoder_pretrained, cache_dir=cache_dir, + force_image_size=490, ) vision_encoder.visual.output_tokens = True vision_encoder = vision_encoder.visual vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path) - if "SigLIP" in clip_vision_encoder_path: # SigLIP models have a different config format + if "SigLIP" in clip_vision_encoder_path or "EVA" in clip_vision_encoder_path: # SigLIP models have a different config format vis_hidden_dim = vision_encoder_config["embed_dim"] else: vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"] @@ -74,8 +75,9 @@ def create_model_and_transforms( trust_remote_code=True, cache_dir=cache_dir, ) - if text_tokenizer.pad_token is None: - text_tokenizer.pad_token_id = text_tokenizer.eos_token_id + if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token: + # add a pad token if it doesn't exist + text_tokenizer.add_special_tokens({"pad_token": ""}) # load langauge model lang_model = AutoModelForCausalLM.from_pretrained( @@ -150,6 +152,9 @@ def _infer_decoder_layers_attr_name(model): "gemma": "model.layers", "phi": "model.layers", "minicpm": "model.layers", + "stablelm": "model.layers", + "qwen": "model.layers", + "mistral": "model.layers" } diff --git a/open_flamingo/src/helpers.py b/open_flamingo/src/helpers.py index c4ab859e..8db95307 100644 --- a/open_flamingo/src/helpers.py +++ b/open_flamingo/src/helpers.py @@ -345,43 +345,22 @@ def __init__( dim_inner=768, num_hidden_layers=12, num_query_tokens=32, - pretrained_path=None, ): super().__init__(dim_media=dim_out, num_tokens_per_media=num_query_tokens) # initialize the qformer - from transformers import Blip2Model, Blip2QFormerModel, Blip2QFormerConfig - - if pretrained_path is None: - self.qformer = Blip2QFormerModel( - Blip2QFormerConfig( - encoder_hidden_size=dim_input, - hidden_size=dim_inner, - num_hidden_layers=num_hidden_layers, - ) - ) - self.query_tokens = nn.Parameter( - torch.zeros(1, num_query_tokens, dim_inner) - ) - self.proj = nn.Linear(dim_inner, dim_out) - else: - model = Blip2Model.from_pretrained( - pretrained_path, + from transformers import Blip2QFormerModel, Blip2QFormerConfig + + self.qformer = Blip2QFormerModel( + Blip2QFormerConfig( + encoder_hidden_size=dim_input, + hidden_size=dim_inner, + num_hidden_layers=num_hidden_layers, ) - self.qformer = model.qformer - self.query_tokens = model.query_tokens - self.proj = model.language_projection - assert ( - self.qformer.config.hidden_size == dim_inner - ), f"dim_inner={dim_inner} but pretrained model expects {self.qformer.config.hidden_size}" - assert ( - self.qformer.config.encoder_hidden_size == dim_input - ), f"dim_input={dim_input} but pretrained model expects {self.qformer.config.encoder_hidden_size}" - assert ( - self.qformer.config.num_hidden_layers == num_hidden_layers - ), f"num_hidden_layers={num_hidden_layers} but pretrained model expects {self.qformer.config.num_hidden_layers}" - assert ( - self.query_tokens.shape[1] == num_query_tokens - ), f"num_query_tokens={num_query_tokens} but pretrained model expects {self.query_tokens.shape[1]}" + ) + self.query_tokens = nn.Parameter( + torch.zeros(1, num_query_tokens, dim_inner) + ) + self.proj = nn.Linear(dim_inner, dim_out) def forward(self, x): """ diff --git a/open_flamingo/src/llava.py b/open_flamingo/src/llava.py index aa0b97d1..1d3fea2f 100644 --- a/open_flamingo/src/llava.py +++ b/open_flamingo/src/llava.py @@ -51,11 +51,10 @@ def __init__( def set_trainable(self): """ - Freeze everything except the Q-former and the inserted LM embeddings + Unfreeze everything except the vision_encoder """ - self.requires_grad_(False) - self.vision_tokenizer.requires_grad_(True) - self.lang_model.requires_grad_(True) + self.requires_grad_(True) + self.vision_encoder.requires_grad_(False) def _should_apply_weight_decay(self, parameter_name): return True diff --git a/open_flamingo/src/vlm.py b/open_flamingo/src/vlm.py index 22db0dc7..b43fb039 100644 --- a/open_flamingo/src/vlm.py +++ b/open_flamingo/src/vlm.py @@ -147,6 +147,7 @@ def forward( labels=labels, past_key_values=past_key_values, past_media_locations=past_media_locations, + padding_side="right", past_vision_tokens=past_vision_tokens, ) output = self.lang_model( @@ -273,7 +274,7 @@ def generate( past_key_values=past_key_values, past_media_locations=past_media_locations, past_vision_tokens=past_vision_tokens, - padding_side="right", + padding_side="left", num_beams=num_beams, ) output = self.lang_model.generate( @@ -303,10 +304,11 @@ def group_params_by_weight_decay(self): """ params_with_wd, params_without_wd = [], [] for n, p in self.named_parameters(): - if self._should_apply_weight_decay(n): - params_with_wd.append(p) - else: - params_without_wd.append(p) + if p.requires_grad: + if self._should_apply_weight_decay(n): + params_with_wd.append(p) + else: + params_without_wd.append(p) return params_with_wd, params_without_wd def _should_apply_weight_decay(self, parameter_name): @@ -475,7 +477,9 @@ def get_fsdp_lambda_fn(self): ) from .helpers import GatedCrossAttentionBlock - original_decoder_block_class = self.lang_model.decoder_block_class + decoder_block_class = getattr_recursive( + self.lang_model, self.decoder_layers_attr_name + )[0].__class__ def lambda_fn(module: nn.Module): # we want FSDP(ckpt(module)), not ckpt(FSDP(module)) diff --git a/open_flamingo/train/data.py b/open_flamingo/train/data.py index 6bb06028..828f75bb 100644 --- a/open_flamingo/train/data.py +++ b/open_flamingo/train/data.py @@ -24,6 +24,7 @@ from open_flamingo.train.data_utils import * SUPPORTED_DATASETS = ["laion", "mmc4"] +CAPTION_BAN_PATTERN = r'\b(?:This image showcases |This image depicts |This image appears to be |This image is |This image captures |The image showcases |The image depicts | The image appears to be )\b' Image.MAX_IMAGE_PIXELS = 1000000000 N_CHANNELS = 3 @@ -68,23 +69,31 @@ def preprocess_laion_image(sample, image_processor): return rearrange(sample, "(b t f) c h w -> b t f c h w", t=1, f=1) -def preprocess_laion_text(sample, tokenizer, max_tokens=256): +def preprocess_laion_text(sample, tokenizer, max_tokens=128): """ Preprocess text for LAION. Applied to a batch of captions. Captions are truncated to 256 tokens by default. """ tokenizer.padding_side = "right" + + if any("" in s for s in sample): + raise ValueError("Image token found in text") + sample = [ # (f"{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample - (f"{s.strip()}{tokenizer.eos_token}") for s in sample + (f"{re.sub(CAPTION_BAN_PATTERN, '', s.split('<|synthetic caption|>')[-1].strip())}{tokenizer.eos_token}") for s in sample ] + + print(sample[0]) + text = tokenizer( sample, max_length=max_tokens, - padding="longest", + padding="max_length", truncation="only_first", return_tensors="pt", ) + return text["input_ids"], text["attention_mask"] diff --git a/open_flamingo/train/distributed.py b/open_flamingo/train/distributed.py index 4b558ed2..8c76c2cb 100644 --- a/open_flamingo/train/distributed.py +++ b/open_flamingo/train/distributed.py @@ -190,13 +190,26 @@ def get_fsdp_config( BackwardPrefetch, ) + if args.fsdp_sharding_strategy == "full": + sharding_strategy = ShardingStrategy.FULL_SHARD + elif args.fsdp_sharding_strategy == "hybrid": + sharding_strategy = ShardingStrategy.HYBRID_SHARD + elif args.fsdp_sharding_strategy == "shard_grad_op": + sharding_strategy = ShardingStrategy.SHARD_GRAD_OP + elif args.fsdp_sharding_strategy == "hybrid_shard_grad_op": + sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2 + elif args.fsdp_sharding_strategy == "no_shard": + sharding_strategy = ShardingStrategy.NO_SHARD + else: + raise ValueError( + f"Invalid sharding strategy: {args.fsdp_sharding_strategy}. Supported: full, hybrid, shard_grad_op, hybrid_shard_grad_op, no_shard" + ) + return dict( cpu_offload=None, device_id=device_id, sync_module_states=True, # broadcast loaded ckpt from rank 0 -> all ranks - sharding_strategy=ShardingStrategy.FULL_SHARD - if args.fsdp_sharding_strategy == "full" - else ShardingStrategy.HYBRID_SHARD, + sharding_strategy=sharding_strategy, use_orig_params=True, mixed_precision=mp_policy, forward_prefetch=True, diff --git a/open_flamingo/train/losses.py b/open_flamingo/train/losses.py index f57727b4..19bede10 100644 --- a/open_flamingo/train/losses.py +++ b/open_flamingo/train/losses.py @@ -57,14 +57,15 @@ def __call__( attention_mask: torch.Tensor, autocast: callable, ): + print(input_ids[0]) # set up labels; language model is expected to handle shifting labels = input_ids.clone() labels[labels == tokenizer.pad_token_id] = -100 - labels[labels == tokenizer.eos_token] = -100 special_token_ids = torch.Tensor(unwrap_model(model).special_token_ids).to( labels.device ) - labels[torch.isin(labels, special_token_ids)] = -100 + labels[torch.isin(labels, special_token_ids)] = -100 # TODO: dont want to remove loss on <|endofchunk|> tokens + labels = labels.to(input_ids.device) # call forward diff --git a/open_flamingo/train/train.py b/open_flamingo/train/train.py index 4b5a9aee..6e732f66 100644 --- a/open_flamingo/train/train.py +++ b/open_flamingo/train/train.py @@ -180,7 +180,7 @@ def main(): help="Use FullyShardedDataParallel for distributed training. Not supported for some models, e.g. OPT.", ) parser.add_argument( - "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"] + "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid", "shard_grad_op", "hybrid_shard_grad_op", "no_shard"] ) # wandb args @@ -321,16 +321,16 @@ def main(): # load optimizer checkpoint if args.resume_from_checkpoint is not None: - osd = checkpoint["optimizer_state_dict"] + optim_state_dict = checkpoint["optimizer_state_dict"] if args.fsdp: - FSDP.set_state_dict_type( - distributed_model, - **args.fsdp_checkpoint_config, + # FSDP.set_state_dict_type( + # distributed_model, + # **args.fsdp_checkpoint_config, + # ) + optim_state_dict = FSDP.optim_state_dict_to_load( + model=distributed_model, optim=optimizer, optim_state_dict=optim_state_dict ) - osd = FSDP.optim_state_dict_to_load( - model=distributed_model, optim=optimizer, optim_state_dict=osd - ) - optimizer.load_state_dict(osd) + optimizer.load_state_dict(optim_state_dict) # Initialize datasets datasets = [ diff --git a/open_flamingo/train/train_utils.py b/open_flamingo/train/train_utils.py index badc6bb6..69a6f8bb 100644 --- a/open_flamingo/train/train_utils.py +++ b/open_flamingo/train/train_utils.py @@ -48,7 +48,6 @@ def train_one_epoch( # set up model, autocast, and dtypes model.train() autocast = get_autocast(args.precision) - cast_dtype = get_cast_dtype(args.precision) # set up logging step_time_m = AverageMeter() @@ -70,7 +69,7 @@ def train_one_epoch( batch_metadata_to_log = {} for dataset_ix, (images, (input_ids, attention_mask)) in enumerate(batches): # unpack the batch and move to device - images = images.to(device_id, dtype=cast_dtype, non_blocking=True) + images = images.to(device_id, non_blocking=True) input_ids = input_ids.to(device_id, non_blocking=True) attention_mask = attention_mask.to(device_id, non_blocking=True) @@ -102,7 +101,7 @@ def train_one_epoch( # clip gradient norm if args.fsdp: - model.clip_grad_norm_(1.0) + model.clip_grad_norm_(1.0, norm_type=2.0) else: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) @@ -149,18 +148,6 @@ def train_one_epoch( ) -def get_cast_dtype(precision: str): - """ - Parses the precision argument and returns the dtype to cast inputs to. - """ - cast_dtype = None - if precision == "bf16": - cast_dtype = torch.bfloat16 - elif precision == "fp16": - cast_dtype = torch.float16 - return cast_dtype - - def get_autocast(precision, cache_enabled=True): """ Parses the precision argument and returns an autocast context manager. @@ -168,7 +155,6 @@ def get_autocast(precision, cache_enabled=True): if precision == "amp": return torch.cuda.amp.autocast(cache_enabled=cache_enabled) elif precision == "amp_bfloat16" or precision == "amp_bf16": - # amp_bfloat16 is more stable than amp float16 for clip training return lambda: torch.cuda.amp.autocast( dtype=torch.bfloat16, cache_enabled=cache_enabled ) @@ -284,13 +270,17 @@ def load_checkpoint(args, model): checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu") msd = checkpoint.pop("model_state_dict") msd = {k.replace("module.", ""): v for k, v in msd.items()} + + # remove any module with vision_encoder in the name + # msd = {k: v for k, v in msd.items() if "vision_encoder" not in k} + resume_from_epoch = checkpoint["epoch"] + 1 if args.fsdp: FSDP.set_state_dict_type( model, **args.fsdp_checkpoint_config, ) - model.load_state_dict(msd, False) + model.load_state_dict(msd, strict=False) return resume_from_epoch, checkpoint def filter_state_dict_to_trainable(model, state_dict): @@ -325,6 +315,8 @@ def save_checkpoint(model, optimizer, lr_scheduler, epoch, args): """ Save training checkpoint with model, optimizer, and lr_scheduler state. """ + torch.cuda.empty_cache() # (Sometimes this is necessary to avoid OOM errors when saving checkpoints) + if args.fsdp: FSDP.set_state_dict_type( model,