diff --git a/open_flamingo/eval/eval_datasets.py b/open_flamingo/eval/eval_datasets.py
index 06d4c829..b14576cc 100644
--- a/open_flamingo/eval/eval_datasets.py
+++ b/open_flamingo/eval/eval_datasets.py
@@ -2,6 +2,7 @@
 import os
 
 from PIL import Image
+import torch
 from torch.utils.data import Dataset
 from torchvision.datasets import ImageFolder
@@ -125,6 +126,7 @@ def __init__(self, root, **kwargs):
         self.class_id_to_name = dict(
             zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES)
         )
+        self.class_id_array = torch.tensor([y for _, y in self.samples])
 
     def __getitem__(self, idx):
         sample, target = super().__getitem__(idx)
@@ -142,6 +144,7 @@ def __init__(self, image_dir_path, annotations_path):
         self.image_dir_path = image_dir_path
         with open(annotations_path, "r") as f:
             self.annotations = [json.loads(line) for line in f]
+        self.class_id_array = torch.tensor([y["label"] for y in self.annotations])
 
     def __len__(self):
         return len(self.annotations)
@@ -178,10 +181,11 @@ def __init__(self, dataset_name: str, split: str, root_dir: str):
             )
         else:
             raise Exception(f"Unimplemented WILDS dataset {dataset_name}")
+        self.class_id_array = self.dataset.y_array
 
     def __len__(self):
         return len(self.dataset)
-    
+
     def __getitem__(self, idx):
         x, y, m = self.dataset[idx]
         y = y.item()
diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py
index e170fb1c..755c30af 100644
--- a/open_flamingo/eval/evaluate.py
+++ b/open_flamingo/eval/evaluate.py
@@ -4,6 +4,7 @@
 import os
 import uuid
 import random
+import time
 from collections import defaultdict
 
 import numpy as np
@@ -78,6 +79,7 @@
 # Trial arguments
 parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
+parser.add_argument("--true_zero_shot", action="store_true")
 parser.add_argument(
     "--num_trials",
     type=int,
@@ -98,7 +100,7 @@
     help="Number of samples to evaluate on. -1 for all samples.",
 )
 parser.add_argument(
-    "--query_set_size", type=int, default=2048, help="Size of demonstration query set"
+    "--query_set_size", type=int, default=-1, help="Size of demonstration query set. -1 for all samples."
 )
 parser.add_argument("--batch_size", type=int, default=8)
@@ -113,6 +115,12 @@
     action="store_true",
     help="Whether to use prompt ensembling (average log-likelihoods over permutations of in-context examples)",
 )
+parser.add_argument(
+    "--classification_num_classes_in_demos",
+    type=int,
+    default=None,
+    help="If set, demonstrations use class-conditional sampling with this many classes. Otherwise, random sampling.",
+)
 parser.add_argument(
     "--rices",
     action="store_true",
@@ -485,6 +493,9 @@ def main():
     if len(args.trial_seeds) != args.num_trials:
         raise ValueError("Number of trial seeds must be == number of trials.")
 
+    if args.rices and args.classification_num_classes_in_demos is not None:
+        raise NotImplementedError("RICES + class-conditional sampling is not yet implemented.")
+
     # set up wandb
     if args.rank == 0 and args.report_to_wandb:
         cfg_dict = vars(args)
@@ -650,12 +661,12 @@ def evaluate_captioning(
         dataset_name=dataset_name,
     )
 
-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model, args.true_zero_shot)
 
     np.random.seed(seed)
     test_dataloader = utils.prepare_eval_samples(
         test_dataset,
-        args.num_samples if args.num_samples > 0 else len(test_dataset),
+        args.num_samples,
         args.batch_size,
     )
@@ -843,12 +854,12 @@ def evaluate_vqa(
         dataset_name=dataset_name,
     )
 
-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model, args.true_zero_shot)
 
     np.random.seed(seed)
     test_dataloader = utils.prepare_eval_samples(
         test_dataset,
-        args.num_samples if args.num_samples > 0 else len(test_dataset),
+        args.num_samples,
         args.batch_size,
     )
@@ -1064,12 +1075,12 @@ def evaluate_classification(
     class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names))
 
-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model, args.true_zero_shot)
 
     np.random.seed(seed)
    test_dataloader = utils.prepare_eval_samples(
         test_dataset,
-        args.num_samples if args.num_samples > 0 else len(test_dataset),
+        args.num_samples,
         args.batch_size,
     )
@@ -1085,21 +1096,34 @@ def evaluate_classification(
     else:
         # subset of the training set to sample context images from
         query_set = utils.get_query_set(train_dataset, args.query_set_size)
+        assert hasattr(query_set, "class_id_array")
 
     utils.random_seed(seed, args.rank)
     predictions = []
+    prompt_time_m = utils.AverageMeter()
+    rank_time_m = utils.AverageMeter()
+
     for batch_idx, batch in tqdm(
         enumerate(test_dataloader),
         desc=f"Running inference {dataset_name}",
         disable=args.rank != 0,
     ):
+        end = time.time()
         if args.rices:
             batch_demo_samples = rices_dataset.find(batch["image"], effective_num_shots)
+        elif args.classification_num_classes_in_demos is not None:
+            _, batch_demo_samples = utils.sample_class_conditional_batch_demos_from_query_set(
+                batch["class_id"], args.classification_num_classes_in_demos, query_set, effective_num_shots,
+            )
         else:
             batch_demo_samples = utils.sample_batch_demos_from_query_set(
                 query_set, effective_num_shots, len(batch["image"])
             )
+        prompt_time_m.update(time.time() - end)
+        end = time.time()
+
         # set up prompt ensembling
         num_permutations = (
             min(6, math.factorial(effective_num_shots)) if use_prompt_ensembling else 1
@@ -1141,7 +1165,8 @@ def evaluate_classification(
             )
 
         # ensemble logprobs together
-        logprobs = torch.mean(torch.stack(logprobs, dim=-1), dim=-1)
+        logprobs = torch.mean(torch.stack(logprobs, dim=-1), dim=-1).to(dtype=torch.float32)
+        rank_time_m.update(time.time() - end)
 
         (
             predicted_class_ixs,
@@ -1153,6 +1178,10 @@ def evaluate_classification(
             class_id_to_name,
         )
 
+        # dev: print some results
+        if batch_idx == 0:
+            print(list(zip(batch_text, predicted_classnames))[:5])
+
         # compute accuracy
         for i, topk in enumerate(predicted_classnames):
             y_i = batch["class_name"][i]
@@ -1171,6 +1200,12 @@ def evaluate_classification(
                 pred_info["metadata"] = batch["metadata"][i]
             predictions.append(pred_info)
 
+    if args.rank == 0:
+        print(f"Avg prompt loading time: {prompt_time_m.avg}")
+        print(f"Avg rank classification w/ ensembling time: {rank_time_m.avg}")
+
+    end = time.time()
+
     # all gather
     all_predictions = [None for _ in range(args.world_size)]
     torch.distributed.all_gather_object(all_predictions, predictions)  # list of lists
diff --git a/open_flamingo/eval/models/open_flamingo.py b/open_flamingo/eval/models/open_flamingo.py
index e59a042f..f0b0c97b 100644
--- a/open_flamingo/eval/models/open_flamingo.py
+++ b/open_flamingo/eval/models/open_flamingo.py
@@ -35,7 +35,8 @@ def __init__(self, model_args):
             if ("device" in model_args and model_args["device"] >= 0)
             else "cpu"
         )
-
+        self.autocast = get_autocast(model_args["precision"])
+        self.cast_dtype = get_cast_dtype(model_args["precision"])
         (
             self.model,
             self.image_processor,
@@ -52,16 +53,11 @@ def __init__(self, model_args):
             checkpoint = checkpoint["model_state_dict"]
         checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
         self.model.load_state_dict(checkpoint, strict=False)
-        self.model.to(self.device)
+        self.model.to(self.device, dtype=self.cast_dtype)
         self.model.eval()
         self.tokenizer.padding_side = "left"
-        self.lm_name = model_args["lm_path"].split("/")[-1]
 
-        # autocast
-        self.autocast = get_autocast(model_args["precision"])
-        self.cast_dtype = get_cast_dtype(model_args["precision"])
-
     def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
         """
         Convert images to tensors, reshape them, and stack them.
@@ -114,9 +110,9 @@ def _prepare_text(
             max_length=max_length,
         )
         input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
-        input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True)
+        input_ids = input_ids.to(self.device, non_blocking=True)
         attention_mask = attention_mask.to(
-            self.device, dtype=self.cast_dtype, non_blocking=True
+            self.device, non_blocking=True
         )
 
         return input_ids, attention_mask.bool()
@@ -334,4 +330,4 @@ def get_hateful_memes_prompt(self, text, label=None) -> str:
         return f"is an image with: '{text}' written on it. Is it hateful? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
 
     def get_waterbirds_prompt(self, label=None) -> str:
-        return f"Question: Is this a landbird or waterbird? Answer: {label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
\ No newline at end of file
+        return f"Question: Is this a landbird or waterbird? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
\ No newline at end of file
diff --git a/open_flamingo/eval/utils.py b/open_flamingo/eval/utils.py
index 41535b5f..924ec645 100644
--- a/open_flamingo/eval/utils.py
+++ b/open_flamingo/eval/utils.py
@@ -2,7 +2,10 @@
 import torch
 import random
 import torch.nn as nn
+from torch.utils.data import Subset
 from contextlib import suppress
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
 
 
 def random_seed(seed=42, rank=0):
@@ -21,39 +24,120 @@ def custom_collate_fn(batch):
     return collated_batch
 
 
-def compute_effective_num_shots(num_shots, model_type):
+def compute_effective_num_shots(num_shots, model_type, true_zero_shot=False):
     """
     Compute the effective number of shots for a given model type.
     For example, following Flamingo, 0-shot OF evaluations use two text-only shots.
     """
-    if model_type == "open_flamingo":
+    if model_type == "open_flamingo" and not true_zero_shot:
         return num_shots if num_shots > 0 else 2
     return num_shots
 
 
 def sample_batch_demos_from_query_set(query_set, num_samples, batch_size):
     """
-    Sample random demonstrations from the query set.
-    """
-    return [random.sample(query_set, num_samples) for _ in range(batch_size)]
+    Sample random demonstrations with replacement from the query set.
+    Returns a torch Subset.
+    """
+    random_indices = np.random.choice(len(query_set), num_samples, replace=True)
+    return Subset(query_set, random_indices)
+
+
+def sample_class_conditional_batch_demos_from_query_set(
+    batch_class_ids,
+    num_classes: int,
+    query_set: Subset,
+    num_samples: int,
+):
+    """
+    Two-stage demo sampling procedure.
+    1. Sample num_classes classes to include in the demos, always including the true class
+       (from batch_class_ids). Classes are sampled only from the classes present in the
+       query set; if batch_class_ids contains classes not in the query set, an error is raised.
+    2. For each sampled class, sample floor(num_samples / num_classes) demos; the remainder
+       is distributed among randomly chosen classes. If a class has too few examples in the
+       query set, sample with replacement.
+    Returns the sampled class ids and a list of torch Subsets, one per batch element.
+    """
+    # sanity checks
+    all_classes = torch.unique(query_set.class_id_array)
+    assert num_classes <= len(
+        all_classes
+    ), "Attempting to select more classes in the demos than there are classes in the dataset."
+    if not isinstance(batch_class_ids, torch.Tensor):
+        batch_class_ids = torch.LongTensor(batch_class_ids)
+    if torch.any(~torch.isin(batch_class_ids, all_classes)):
+        raise ValueError("batch_class_ids contains classes not in the query set.")
+    if num_samples < num_classes:
+        raise ValueError("num_samples must be >= num_classes.")
+
+    # sample classes + demos per class
+    sampled_classes, sampled_demos = [], []
+    samples_per_class = num_samples // num_classes
+    leftover_samples = num_samples % num_classes
+    for y in batch_class_ids:
+        if isinstance(y, torch.Tensor):
+            y = y.item()
+        other_classes = np.setdiff1d(all_classes, [y]).tolist()
+        classes = random.sample(other_classes, num_classes - 1) + [y]
+        random.shuffle(classes)
+        sampled_classes.append(classes)
+        demos = [
+            sample_examples_from_class(
+                query_set,
+                yp,
+                samples_per_class + int(i < leftover_samples),
+                replace_if_insufficient=True,
+            )
+            for i, yp in enumerate(classes)
+        ]
+        demos = [item for sublist in demos for item in sublist]
+        random.shuffle(demos)  # otherwise, examples will be grouped in class chunks
+        sampled_demos.append(Subset(query_set, demos))
+
+    return sampled_classes, sampled_demos
+
+
+def sample_examples_from_class(dataset, y, num_samples, replace_if_insufficient=False):
+    """
+    Given a class id y and a torch dataset containing examples from multiple classes,
+    samples num_samples examples from class y.
+    Returns: indices of the selected examples.
+    """
+    class_indices = torch.where(dataset.class_id_array == y)[0].tolist()
+    selected_indices = random.sample(
+        class_indices, min(num_samples, len(class_indices))
+    )
+    if len(selected_indices) < num_samples:
+        print(
+            f"Warning: insufficient samples in query set for class {y}; sampling with replacement={replace_if_insufficient}"
+        )
+        if replace_if_insufficient:
+            selected_indices += random.choices(
+                class_indices, k=num_samples - len(selected_indices)
+            )
+    return selected_indices
 
 
 def get_query_set(train_dataset, query_set_size):
     """
-    Get a subset of the training dataset to use as the query set.
+    Get a subset of the training dataset to use as the query set. Returns a torch Subset.
     """
-    query_set = np.random.choice(len(train_dataset), query_set_size, replace=False)
-    return [train_dataset[i] for i in query_set]
+    if query_set_size == -1:
+        return train_dataset
+    query_set_indices = np.random.choice(len(train_dataset), query_set_size, replace=False)
+    query_set = Subset(train_dataset, query_set_indices)
+    if hasattr(train_dataset, "class_id_array"):
+        query_set.class_id_array = train_dataset.class_id_array[query_set_indices]
+        if len(np.unique(query_set.class_id_array)) != len(np.unique(train_dataset.class_id_array)):
+            print(
+                f"Warning: query set does not contain examples from all classes; "
+                f"{len(np.unique(query_set.class_id_array))} classes remain."
+            )
+    return query_set
 
 
 def prepare_eval_samples(test_dataset, num_samples, batch_size):
     """
     Subset the test dataset and return a DataLoader.
     """
-    random_indices = np.random.choice(len(test_dataset), num_samples, replace=False)
-    dataset = torch.utils.data.Subset(test_dataset, random_indices)
-    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-    loader = torch.utils.data.DataLoader(
+    if num_samples != -1:
+        random_indices = np.random.choice(len(test_dataset), num_samples, replace=False)
+        dataset = Subset(test_dataset, random_indices)
+    else:
+        dataset = test_dataset
+    sampler = DistributedSampler(dataset)
+    loader = DataLoader(
         dataset,
         batch_size=batch_size,
         sampler=sampler,
@@ -123,3 +207,21 @@ def get_autocast(precision):
         return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
     else:
         return suppress
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
diff --git a/open_flamingo/src/flamingo.py b/open_flamingo/src/flamingo.py
index 9a67cfed..02b8b7fe 100644
--- a/open_flamingo/src/flamingo.py
+++ b/open_flamingo/src/flamingo.py
@@ -191,8 +191,7 @@ def _encode_vision_x(self, vision_x: torch.Tensor):
         assert F == 1, "Only single frame supported"
 
         vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
-        with torch.no_grad():
-            vision_x = self.vision_encoder(vision_x)[1]
+        vision_x = self.vision_encoder(vision_x)[1]
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
 
         vision_x = self.perceiver(vision_x)
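
Reviewer note: the following is a minimal usage sketch, not part of the patch, showing how the new demo-sampling utilities in open_flamingo/eval/utils.py compose. It assumes the patched open_flamingo.eval.utils is importable; ToyClassificationDataset, its keys ("image", "class_id"), and the concrete numbers are hypothetical stand-ins for a real evaluation dataset.

# Sketch (not part of the patch): exercises get_query_set and
# sample_class_conditional_batch_demos_from_query_set from the patched utils.
# ToyClassificationDataset is hypothetical; it only exposes the class_id_array
# attribute that the new helpers rely on.
import torch
from torch.utils.data import Dataset

from open_flamingo.eval.utils import (
    get_query_set,
    sample_class_conditional_batch_demos_from_query_set,
)


class ToyClassificationDataset(Dataset):
    """100 toy examples spread evenly over 10 classes."""

    def __init__(self):
        self.class_id_array = torch.arange(100) % 10  # required by the new utils

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return {"image": idx, "class_id": self.class_id_array[idx].item()}


dataset = ToyClassificationDataset()
query_set = get_query_set(dataset, query_set_size=-1)  # -1 (the new default): use the whole set

# 8 demos per test example, drawn from 4 classes that always include the true class.
batch_class_ids = torch.tensor([0, 3, 7])
sampled_classes, demos = sample_class_conditional_batch_demos_from_query_set(
    batch_class_ids, 4, query_set, 8
)
for y, classes, subset in zip(batch_class_ids.tolist(), sampled_classes, demos):
    assert y in classes  # the true class is among the sampled classes
    assert len(subset) == 8  # floor(8 / 4) = 2 demos from each of the 4 classes
    assert y in {subset[i]["class_id"] for i in range(len(subset))}  # and in the demos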