diff --git a/open_flamingo/eval/evaluate.py b/open_flamingo/eval/evaluate.py
index efe84397..fc2e8963 100644
--- a/open_flamingo/eval/evaluate.py
+++ b/open_flamingo/eval/evaluate.py
@@ -689,18 +689,17 @@ def evaluate_captioning(
         args.batch_size,
     )
 
+    # subset of the training set to sample context images from
+    query_set = utils.get_query_set(train_dataset, args.query_set_size)
     if args.rices:
         rices_dataset = RICES(
-            train_dataset,
+            query_set,
             eval_model.device,
             args.batch_size,
-            cached_features=cached_features,
+            cached_features=cached_features[query_set.indices],
             vision_encoder_path=args.rices_vision_encoder_path,
             vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
         )
-    else:
-        # subset of the training set to sample context images from
-        query_set = utils.get_query_set(train_dataset, args.query_set_size)
 
     utils.random_seed(seed, args.rank)
     predictions = defaultdict()
@@ -882,17 +881,16 @@ def evaluate_vqa(
         args.batch_size,
     )
 
+    query_set = utils.get_query_set(train_dataset, args.query_set_size)
     if args.rices:
         rices_dataset = RICES(
-            train_dataset,
+            query_set,
             eval_model.device,
             args.batch_size,
-            cached_features=cached_features,
+            cached_features=cached_features[query_set.indices],
             vision_encoder_path=args.rices_vision_encoder_path,
             vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
         )
-    else:
-        query_set = utils.get_query_set(train_dataset, args.query_set_size)
 
     utils.random_seed(seed, args.rank)
     predictions = []
@@ -1119,19 +1117,17 @@ def evaluate_classification(
         args.batch_size,
     )
 
+    query_set = utils.get_query_set(train_dataset, args.query_set_size)
+    assert hasattr(query_set, 'class_id_array')
     if args.rices:
         rices_dataset = RICES(
-            train_dataset,
+            query_set,
             eval_model.device,
             args.batch_size,
-            cached_features=cached_features,
+            cached_features=cached_features[query_set.indices],
             vision_encoder_path=args.rices_vision_encoder_path,
             vision_encoder_pretrained=args.rices_vision_encoder_pretrained,
         )
-    else:
-        # subset of the training set to sample context images from
-        query_set = utils.get_query_set(train_dataset, args.query_set_size)
-        assert hasattr(query_set, 'class_id_array')
 
     utils.random_seed(seed, args.rank)
     predictions = []
@@ -1195,7 +1191,7 @@ def evaluate_classification(
                 batch_images,
                 all_class_names,
                 use_cache=(not no_kv_caching),
-                normalize_length=True,
+                normalize_length=False,
             )
         )
 
@@ -1215,7 +1211,7 @@ def evaluate_classification(
 
         # dev: print some results
         if batch_idx == 0:
-            print(list(zip(batch_text, predicted_classnames[:1]))[:5])
+            print("Context:", batch_text[0], "\n", "Generated:", predicted_classnames[0][0], "\n", "True:", batch["class_name"][0])
 
         # compute accuracy
         for i, topk in enumerate(predicted_classnames):
diff --git a/open_flamingo/eval/utils.py b/open_flamingo/eval/utils.py
index 924ec645..14f05b47 100644
--- a/open_flamingo/eval/utils.py
+++ b/open_flamingo/eval/utils.py
@@ -116,14 +116,18 @@ def sample_examples_from_class(dataset, y, num_samples, replace_if_insufficient=
 def get_query_set(train_dataset, query_set_size):
     """
     Get a subset of the training dataset to use as the query set. Returns a torch Subset.
+    Adds the "indices" attribute containing the indices of each example in the original set.
     """
-    if query_set_size == -1: return train_dataset
+    if query_set_size == -1:
+        train_dataset.indices = np.arange(len(train_dataset))
+        return train_dataset
     query_set_indices = np.random.choice(len(train_dataset), query_set_size, replace=False)
     query_set = Subset(train_dataset, query_set_indices)
     if hasattr(train_dataset, "class_id_array"):
         query_set.class_id_array = train_dataset.class_id_array[query_set_indices]
         if len(np.unique(query_set.class_id_array)) != len(np.unique(train_dataset.class_id_array)):
             print(f"Warning: query set does not contain examples from all classes; {len(np.unique(query_set.class_id_array))} remaining classes.")
+    query_set.indices = query_set_indices
     return query_set
 
 