diff --git a/open_flamingo/eval/cache_rices_features.py b/open_flamingo/eval/cache_rices_features.py
deleted file mode 100644
index 6dd10076..00000000
--- a/open_flamingo/eval/cache_rices_features.py
+++ /dev/null
@@ -1,346 +0,0 @@
-"""
-Cache CLIP ViT-B/32 features for all images in training split in preparation for RICES
-"""
-import argparse
-from rices import RICES
-from eval_datasets import (
-    CaptionDataset,
-    VQADataset,
-    ImageNetDataset,
-    HatefulMemesDataset,
-)
-import os
-import torch
-
-parser = argparse.ArgumentParser()
-parser.add_argument(
-    "--output_dir",
-    type=str,
-    required=True,
-    help="Directory to save the cached features.",
-)
-
-parser.add_argument("--batch_size", default=256)
-
-# Per-dataset flags
-parser.add_argument(
-    "--eval_coco",
-    action="store_true",
-    default=False,
-    help="Whether to cache COCO.",
-)
-parser.add_argument(
-    "--eval_vqav2",
-    action="store_true",
-    default=False,
-    help="Whether to cache VQAV2.",
-)
-parser.add_argument(
-    "--eval_ok_vqa",
-    action="store_true",
-    default=False,
-    help="Whether to cache OK-VQA.",
-)
-parser.add_argument(
-    "--eval_vizwiz",
-    action="store_true",
-    default=False,
-    help="Whether to cache VizWiz.",
-)
-parser.add_argument(
-    "--eval_textvqa",
-    action="store_true",
-    default=False,
-    help="Whether to cache TextVQA.",
-)
-parser.add_argument(
-    "--eval_imagenet",
-    action="store_true",
-    default=False,
-    help="Whether to cache ImageNet.",
-)
-parser.add_argument(
-    "--eval_flickr30",
-    action="store_true",
-    default=False,
-    help="Whether to cache Flickr30.",
-)
-parser.add_argument(
-    "--eval_hateful_memes",
-    action="store_true",
-    default=False,
-    help="Whether to cache Hateful Memes.",
-)
-
-# Dataset arguments
-
-## Flickr30 Dataset
-parser.add_argument(
-    "--flickr_image_dir_path",
-    type=str,
-    help="Path to the flickr30/flickr30k_images directory.",
-    default=None,
-)
-parser.add_argument(
-    "--flickr_karpathy_json_path",
-    type=str,
-    help="Path to the dataset_flickr30k.json file.",
-    default=None,
-)
-parser.add_argument(
-    "--flickr_annotations_json_path",
-    type=str,
-    help="Path to the dataset_flickr30k_coco_style.json file.",
-)
-## COCO Dataset
-parser.add_argument(
-    "--coco_train_image_dir_path",
-    type=str,
-    default=None,
-)
-parser.add_argument(
-    "--coco_val_image_dir_path",
-    type=str,
-    default=None,
-)
-parser.add_argument(
-    "--coco_karpathy_json_path",
-    type=str,
-    default=None,
-)
-parser.add_argument(
-    "--coco_annotations_json_path",
-    type=str,
-    default=None,
-)
-
-## VQAV2 Dataset
-parser.add_argument(
-    "--vqav2_train_image_dir_path",
-    type=str,
-    default=None,
-)
-parser.add_argument(
-    "--vqav2_train_questions_json_path",
-    type=str,
-    default=None,
-)
-parser.add_argument(
-    "--vqav2_train_annotations_json_path",
-    type=str,
-    default=None,
-)
-
-## OK-VQA Dataset
-parser.add_argument(
-    "--ok_vqa_train_image_dir_path",
-    type=str,
-    help="Path to the vqav2/train2014 directory.",
-    default=None,
-)
-parser.add_argument(
-    "--ok_vqa_train_questions_json_path",
-    type=str,
-    help="Path to the v2_OpenEnded_mscoco_train2014_questions.json file.",
-    default=None,
-)
-parser.add_argument(
-    "--ok_vqa_train_annotations_json_path",
-    type=str,
-    help="Path to the v2_mscoco_train2014_annotations.json file.",
-    default=None,
-)
-
-## VizWiz Dataset
-parser.add_argument(
-    "--vizwiz_train_image_dir_path",
-    type=str,
-    help="Path to the vizwiz train images directory.",
-    default=None,
-)
-parser.add_argument(
-    "--vizwiz_train_questions_json_path",
-    type=str,
-    help="Path to the vizwiz questions json file.",
-    default=None,
-)
-parser.add_argument(
-    "--vizwiz_train_annotations_json_path",
-    type=str,
-    help="Path to the vizwiz annotations json file.",
-    default=None,
-)
-
-# TextVQA Dataset
-parser.add_argument(
-    "--textvqa_image_dir_path",
-    type=str,
-    help="Path to the textvqa images directory.",
-    default=None,
-)
-parser.add_argument(
-    "--textvqa_train_questions_json_path",
-    type=str,
-    help="Path to the textvqa questions json file.",
-    default=None,
-)
-parser.add_argument(
-    "--textvqa_train_annotations_json_path",
-    type=str,
-    help="Path to the textvqa annotations json file.",
-    default=None,
-)
-
-
-## Imagenet dataset
-parser.add_argument("--imagenet_root", type=str, default="/tmp")
-
-## Hateful Memes dataset
-parser.add_argument(
-    "--hateful_memes_image_dir_path",
-    type=str,
-    default=None,
-)
-parser.add_argument(
-    "--hateful_memes_train_annotations_json_path",
-    type=str,
-    default=None,
-)
-
-
-def main():
-    args, leftovers = parser.parse_known_args()
-    device_id = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
-    if args.eval_flickr30:
-        print("Caching Flickr30k...")
-        train_dataset = CaptionDataset(
-            image_train_dir_path=args.flickr_image_dir_path,
-            image_val_dir_path=None,
-            annotations_path=args.flickr_karpathy_json_path,
-            is_train=True,
-            dataset_name="flickr30",
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "flickr30.pkl"),
-        )
-
-    if args.eval_coco:
-        print("Caching COCO...")
-        train_dataset = CaptionDataset(
-            image_train_dir_path=args.coco_train_image_dir_path,
-            image_val_dir_path=args.coco_val_image_dir_path,
-            annotations_path=args.coco_karpathy_json_path,
-            is_train=True,
-            dataset_name="coco",
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "coco.pkl"),
-        )
-
-    if args.eval_ok_vqa:
-        print("Caching OK-VQA...")
-        train_dataset = VQADataset(
-            image_dir_path=args.ok_vqa_train_image_dir_path,
-            question_path=args.ok_vqa_train_questions_json_path,
-            annotations_path=args.ok_vqa_train_annotations_json_path,
-            is_train=True,
-            dataset_name="ok_vqa",
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "ok_vqa.pkl"),
-        )
-
-    if args.eval_vizwiz:
-        print("Caching VizWiz...")
-        train_dataset = VQADataset(
-            image_dir_path=args.vizwiz_train_image_dir_path,
-            question_path=args.vizwiz_train_questions_json_path,
-            annotations_path=args.vizwiz_train_annotations_json_path,
-            is_train=True,
-            dataset_name="vizwiz",
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "vizwiz.pkl"),
-        )
-
-    if args.eval_vqav2:
-        print("Caching VQAv2...")
-        train_dataset = VQADataset(
-            image_dir_path=args.vqav2_train_image_dir_path,
-            question_path=args.vqav2_train_questions_json_path,
-            annotations_path=args.vqav2_train_annotations_json_path,
-            is_train=True,
-            dataset_name="vqav2",
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "vqav2.pkl"),
-        )
-
-    if args.eval_textvqa:
-        print("Caching TextVQA...")
-        train_dataset = VQADataset(
-            image_dir_path=args.textvqa_image_dir_path,
-            question_path=args.textvqa_train_questions_json_path,
-            annotations_path=args.textvqa_train_annotations_json_path,
-            is_train=True,
-            dataset_name="textvqa",
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "textvqa.pkl"),
-        )
-
-    if args.eval_hateful_memes:
-        print("Caching Hateful Memes...")
-        train_dataset = HatefulMemesDataset(
-            image_dir_path=args.hateful_memes_image_dir_path,
-            annotations_path=args.hateful_memes_train_annotations_json_path,
-        )
-        rices_dataset = RICES(
-            train_dataset,
-            device_id,
-            args.batch_size,
-        )
-        torch.save(
-            rices_dataset.features,
-            os.path.join(args.output_dir, "hateful_memes.pkl"),
-        )
-
-
-if __name__ == "__main__":
-    main()
diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py
index fc2e8963..a98d57da 100644
--- a/open_flamingo/eval/evaluate.py
+++ b/open_flamingo/eval/evaluate.py
@@ -501,9 +501,6 @@ def main():
     if len(args.trial_seeds) != args.num_trials:
         raise ValueError("Number of trial seeds must be == number of trials.")
 
-    if args.rices and args.classification_num_classes_in_demos is not None:
-        raise NotImplementedError("RICES + class conditional sampling not yet implemented.")
-
     # set up wandb
     if args.rank == 0 and args.report_to_wandb:
         cfg_dict = vars(args)
@@ -750,7 +747,8 @@ def evaluate_captioning(
         ]
 
         if args.rank == 0:
-            print("Context:", batch_text[0], "\n", "Generated:", new_predictions[0])
+            for i in range(len(batch_text)):
+                print("Context:", batch_text[i], "\n", "Generated:", new_predictions[i])
 
         for i, sample_id in enumerate(batch["image_id"]):
             predictions[sample_id] = {
@@ -1141,16 +1139,30 @@ def evaluate_classification(
     ):
         end = time.time()
-        if args.rices:
-            batch_demo_samples = rices_dataset.find(batch["image"], effective_num_shots)
-        elif args.classification_num_classes_in_demos is not None:
-            _, batch_demo_samples = utils.sample_class_conditional_batch_demos_from_query_set(
+        if args.classification_num_classes_in_demos is not None:
+            batch_classes, batch_demo_samples = utils.sample_class_conditional_batch_demos_from_query_set(
                 batch["class_id"],
                 args.classification_num_classes_in_demos,
                 query_set,
                 effective_num_shots,
             )
+            if args.rices:
+                batch_classes = torch.LongTensor(batch_classes)
+                num_samples = effective_num_shots // args.classification_num_classes_in_demos * torch.ones_like(batch_classes)
+                num_samples += torch.tensor([int(i < effective_num_shots % args.classification_num_classes_in_demos) for i in range(args.classification_num_classes_in_demos)]).unsqueeze(0).repeat(len(batch_classes), 1)
+                batch_demo_samples = rices_dataset.find_filtered(
+                    utils.repeat_interleave(batch["image"], args.classification_num_classes_in_demos),
+                    num_samples.view(-1),
+                    [torch.where(query_set.class_id_array == class_id)[0].tolist() for class_id in batch_classes.view(-1)],
+                )
+                batch_demo_samples = utils.reshape_nested_list(batch_demo_samples, (len(batch_classes), effective_num_shots))
         else:
-            batch_demo_samples = utils.sample_batch_demos_from_query_set(
-                query_set, effective_num_shots, len(batch["image"])
-            )
+            if not args.rices:
+                batch_demo_samples = utils.sample_batch_demos_from_query_set(
+                    query_set, effective_num_shots, len(batch["image"])
+                )
+            else:
+                batch_demo_samples = rices_dataset.find(
+                    batch["image"],
+                    effective_num_shots,
+                )
         prompt_time_m.update(time.time() - end)
         end = time.time()
diff --git a/open_flamingo/eval/rices.py b/open_flamingo/eval/rices.py
index d1535c71..5b3ea39e 100644
--- a/open_flamingo/eval/rices.py
+++ b/open_flamingo/eval/rices.py
@@ -2,8 +2,9 @@
 import torch
 from tqdm import tqdm
 import torch
-from utils import custom_collate_fn
-
+from open_flamingo.eval.utils import custom_collate_fn
+from functools import partial
+import multiprocessing
 
 class RICES:
     def __init__(
@@ -11,11 +12,12 @@ def __init__(
         dataset,
         device,
         batch_size,
-        vision_encoder_path="ViT-B-32",
+        vision_encoder_path="ViT-L-14",
         vision_encoder_pretrained="openai",
         cached_features=None,
     ):
         self.dataset = dataset
+        self.dataset_indices = torch.arange(len(dataset))
         self.device = device
         self.batch_size = batch_size
 
@@ -23,7 +25,7 @@ def __init__(
         vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
             vision_encoder_path,
             pretrained=vision_encoder_pretrained,
-            cache_dir="/mmfs1/gscratch/efml/anasa2/clip_cache",
+            cache_dir="/juice/scr/irena/.cache",
         )
         self.model = vision_encoder.to(self.device)
         self.image_processor = image_processor
@@ -63,7 +65,7 @@ def _precompute_features(self):
         features = torch.cat(features)
         return features
 
-    def find(self, batch, num_examples):
+    def find(self, batch, num_examples, return_similarity=False):
         """
         Get the top num_examples most similar examples to the images.
         """
@@ -89,8 +91,27 @@ def find(self, batch, num_examples):
             if similarity.ndim == 1:
                 similarity = similarity.unsqueeze(0)
 
+            if return_similarity:
+                return similarity
+
             # Get the indices of the 'num_examples' most similar images
             indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples]
 
         # Return with the most similar images last
-        return [[self.dataset[i] for i in reversed(row)] for row in indices]
+        return [[self.dataset[self.dataset_indices[i]] for i in reversed(row)] for row in indices]
+
+    def find_filtered(self, batch, num_examples, indices):
+        """
+        For each element in batch, find the top num_examples most similar examples
+        out of indices.
+        Args:
+        - indices: list of lists of indices of examples to consider for each element in batch
+        """
+        similarity = self.find(batch, None, return_similarity=True)  # (B, len(self.dataset))
+        mask = torch.zeros_like(similarity)
+        for i, idx_list in enumerate(indices):
+            mask[i, idx_list] = 1
+        similarity[~mask.bool()] = -torch.inf
+        indices = similarity.argsort(dim=-1, descending=True)
+        # Return with the most similar images last
+        return [[self.dataset[self.dataset_indices[i]] for i in reversed(row[:num_examples[j]])] for j, row in enumerate(indices)]
\ No newline at end of file
diff --git a/open_flamingo/eval/utils.py b/open_flamingo/eval/utils.py
index 14f05b47..a64991e3 100644
--- a/open_flamingo/eval/utils.py
+++ b/open_flamingo/eval/utils.py
@@ -96,7 +96,7 @@ def sample_class_conditional_batch_demos_from_query_set(
 def sample_examples_from_class(dataset, y, num_samples, replace_if_insufficient=False):
     """
     Given a class id y and a torch dataset containing examples from multiple classes,
-    samples num_samples examples from class y.
+    samples num_samples examples from class y uniformly at random.
     Returns: indices of selected examples
     """
     class_indices = torch.where(dataset.class_id_array == y)[0].tolist()
@@ -229,3 +229,25 @@ def update(self, val, n=1):
         self.sum += val * n
         self.count += n
         self.avg = self.sum / self.count
+
+def repeat_interleave(list, n):
+    """
+    Mimics torch.repeat_interleave for a list of arbitrary objects
+    """
+    return [item for item in list for _ in range(n)]
+
+def reshape_nested_list(original_list, shape: tuple):
+    """
+    Reshapes a 2D list into a 2D list of the given shape
+    """
+    assert len(shape) == 2
+    outer_list, inner_list = [], []
+    for list in original_list:
+        for x in list:
+            inner_list.append(x)
+            if len(inner_list) == shape[1]:
+                outer_list.append(inner_list)
+                inner_list = []
+    if len(outer_list) != shape[0]:
+        raise ValueError(f"List could not be reshaped to {shape}")
+    return outer_list
diff --git a/open_flamingo/scripts/cache_rices_features.py b/open_flamingo/scripts/cache_rices_features.py
index 63d49834..7b53ee9f 100644
--- a/open_flamingo/scripts/cache_rices_features.py
+++ b/open_flamingo/scripts/cache_rices_features.py
@@ -365,6 +365,22 @@ def main():
             os.path.join(args.output_dir, "hateful_memes.pkl"),
         )
 
+    if args.eval_imagenet:
+        print("Caching ImageNet...")
+        train_dataset = ImageNetDataset(
+            os.path.join(args.imagenet_root, "train")
+        )
+        rices_dataset = RICES(
+            train_dataset,
+            device_id,
+            args.batch_size,
+            vision_encoder_path=args.vision_encoder_path,
+            vision_encoder_pretrained=args.vision_encoder_pretrained,
+        )
+        torch.save(
+            rices_dataset.features,
+            os.path.join(args.output_dir, "imagenet.pkl"),
+        )
 
 if __name__ == "__main__":
     main()
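
For reviewers, here is a small, self-contained sketch (not part of the diff) of how the two helpers added to `open_flamingo/eval/utils.py` compose in the new RICES + class-conditional path: each query image is repeated once per demo class, the per-(query, class) retrievals come back as a flat nested list, and `reshape_nested_list` folds them back into one demo list per query. The toy retrieval strings below are hypothetical placeholders standing in for what `RICES.find_filtered` would return.

```python
def repeat_interleave(items, n):
    """Mimics torch.repeat_interleave for a list of arbitrary objects."""
    return [item for item in items for _ in range(n)]


def reshape_nested_list(original_list, shape):
    """Folds a flat list of sub-lists into a 2D list of the given (rows, cols) shape."""
    assert len(shape) == 2
    outer, inner = [], []
    for sub in original_list:
        for x in sub:
            inner.append(x)
            if len(inner) == shape[1]:
                outer.append(inner)
                inner = []
    if len(outer) != shape[0]:
        raise ValueError(f"List could not be reshaped to {shape}")
    return outer


# Two query images, two classes in the demos, four shots per query.
queries = ["img_a", "img_b"]
num_classes, num_shots = 2, 4

# Each query is repeated once per demo class before retrieval.
repeated = repeat_interleave(queries, num_classes)  # ['img_a', 'img_a', 'img_b', 'img_b']

# Hypothetical per-(query, class) retrievals, num_shots // num_classes demos each,
# standing in for the class-filtered RICES results.
retrieved = [
    [f"{q}_class{i % num_classes}_demo{k}" for k in range(num_shots // num_classes)]
    for i, q in enumerate(repeated)
]

# Fold the flat list of retrievals back into one demo list per query.
demos = reshape_nested_list(retrieved, (len(queries), num_shots))
print(demos[0])  # ['img_a_class0_demo0', 'img_a_class0_demo1', 'img_a_class1_demo0', 'img_a_class1_demo1']
```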