
Commit

merge
ssagawa committed Aug 25, 2023
2 parents cfba7ae + fb3afad commit fef4df2
Showing 5 changed files with 171 additions and 35 deletions.
8 changes: 6 additions & 2 deletions open_flamingo/eval/eval_datasets.py
@@ -2,6 +2,7 @@
import os

from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder

@@ -126,6 +127,7 @@ def __init__(self, root, **kwargs):
self.class_id_to_name = dict(
zip(range(len(IMAGENET_CLASSNAMES)), IMAGENET_CLASSNAMES)
)
self.class_id_array = torch.tensor([y for _, y in self.samples])

def __getitem__(self, idx):
sample, target = super().__getitem__(idx)
@@ -143,6 +145,7 @@ def __init__(self, image_dir_path, annotations_path):
self.image_dir_path = image_dir_path
with open(annotations_path, "r") as f:
self.annotations = [json.loads(line) for line in f]
self.class_id_array = torch.tensor([y["label"] for y in self.annotations])

def __len__(self):
return len(self.annotations)
@@ -185,10 +188,11 @@ def __init__(self, dataset_name: str, split: str, root_dir: str):
)
else:
raise Exception(f"Unimplemented WILDS dataset {dataset_name}")

self.class_id_array = self.dataset.y_array

def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
x, y, m = self.dataset[idx]
y = y.item()
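Note: the class_id_array attribute added above stores one integer label per sample in each evaluation dataset. Presumably it exists so that the class-conditional demonstration sampling introduced in evaluate.py (next file) can look up examples of a given class without a Python loop; the sketch below assumes class_id_array is a 1-D integer tensor aligned with dataset indices, and indices_for_class is a hypothetical helper, not part of this commit.

import torch


def indices_for_class(dataset, class_id: int) -> list:
    """Return the indices of all samples with the given label.

    Assumes dataset.class_id_array is a 1-D integer tensor aligned with
    dataset indexing, as added in this commit.
    """
    return torch.where(dataset.class_id_array == class_id)[0].tolist()


# Hypothetical usage: draw 4 in-context demonstrations of class 7.
# demo_idxs = random.sample(indices_for_class(query_set, 7), 4)
# demos = [query_set[i] for i in demo_idxs]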
51 changes: 43 additions & 8 deletions open_flamingo/eval/evaluate.py
@@ -4,6 +4,7 @@
import os
import uuid
import random
import time
from collections import defaultdict

import numpy as np
@@ -80,6 +81,7 @@

# Trial arguments
parser.add_argument("--shots", nargs="+", default=[0, 4, 8, 16, 32], type=int)
parser.add_argument("--true_zero_shot", action="store_true")
parser.add_argument(
"--num_trials",
type=int,
@@ -100,7 +102,7 @@
help="Number of samples to evaluate on. -1 for all samples.",
)
parser.add_argument(
"--query_set_size", type=int, default=2048, help="Size of demonstration query set"
"--query_set_size", type=int, default=-1, help="Size of demonstration query set. -1 for all samples."
)

parser.add_argument("--batch_size", type=int, default=8)
@@ -115,6 +117,12 @@
action="store_true",
help="Whether to use prompt ensembling (average log-likelihoods over permutations of in-context examples)",
)
parser.add_argument(
"--classification_num_classes_in_demos",
type=int,
default=None,
help="If set, demonstrations use class-conditional sampling with this many classes. Otherwise, random sampling.",
)
parser.add_argument(
"--rices",
action="store_true",
@@ -493,6 +501,9 @@ def main():
if len(args.trial_seeds) != args.num_trials:
raise ValueError("Number of trial seeds must be == number of trials.")

if args.rices and args.classification_num_classes_in_demos is not None:
raise NotImplementedError("RICES + class conditional sampling not yet implemented.")

# set up wandb
if args.rank == 0 and args.report_to_wandb:
cfg_dict = vars(args)
@@ -669,12 +680,12 @@ def evaluate_captioning(
dataset_name=dataset_name,
)

effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model, args.true_zero_shot)

np.random.seed(seed)
test_dataloader = utils.prepare_eval_samples(
test_dataset,
args.num_samples if args.num_samples > 0 else len(test_dataset),
args.num_samples,
args.batch_size,
)

@@ -862,12 +873,12 @@ def evaluate_vqa(
dataset_name=dataset_name,
)

effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model, args.true_zero_shot)

np.random.seed(seed)
test_dataloader = utils.prepare_eval_samples(
test_dataset,
args.num_samples if args.num_samples > 0 else len(test_dataset),
args.num_samples,
args.batch_size,
)

@@ -1099,12 +1110,12 @@ def evaluate_classification(

class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names))

effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model, args.true_zero_shot)

np.random.seed(seed)
test_dataloader = utils.prepare_eval_samples(
test_dataset,
args.num_samples if args.num_samples > 0 else len(test_dataset),
args.num_samples,
args.batch_size,
)

@@ -1120,21 +1131,34 @@ def evaluate_classification(
else:
# subset of the training set to sample context images from
query_set = utils.get_query_set(train_dataset, args.query_set_size)
assert hasattr(query_set, 'class_id_array')

utils.random_seed(seed, args.rank)
predictions = []
prompt_time_m = utils.AverageMeter()
rank_time_m = utils.AverageMeter()

for batch_idx, batch in tqdm(
enumerate(test_dataloader),
desc=f"Running inference {dataset_name}",
disable=args.rank != 0,
):

end = time.time()
if args.rices:
batch_demo_samples = rices_dataset.find(batch["image"], effective_num_shots)
elif args.classification_num_classes_in_demos is not None:
_, batch_demo_samples = utils.sample_class_conditional_batch_demos_from_query_set(
batch["class_id"], args.classification_num_classes_in_demos, query_set, effective_num_shots,
)
else:
batch_demo_samples = utils.sample_batch_demos_from_query_set(
query_set, effective_num_shots, len(batch["image"])
)

prompt_time_m.update(time.time() - end)
end = time.time()

# set up prompt ensembling
num_permutations = (
min(6, math.factorial(effective_num_shots)) if use_prompt_ensembling else 1
@@ -1176,7 +1200,8 @@ def evaluate_classification(
)

# ensemble logprobs together
logprobs = torch.mean(torch.stack(logprobs, dim=-1), dim=-1)
logprobs = torch.mean(torch.stack(logprobs, dim=-1), dim=-1).to(dtype=torch.float32)
rank_time_m.update(time.time() - end)

(
predicted_class_ixs,
@@ -1188,6 +1213,10 @@
class_id_to_name,
)

# dev: print some results
if batch_idx == 0:
print(list(zip(batch_text, predicted_classnames[:1]))[:5])

# compute accuracy
for i, topk in enumerate(predicted_classnames):
y_i = batch["class_name"][i]
@@ -1206,6 +1235,12 @@
pred_info["metadata"] = batch["metadata"][i]
predictions.append(pred_info)

if args.rank == 0:
print(f"Avg prompt loading time: {prompt_time_m.avg}")
print(f"Avg rank classification w/ ensembling time: {rank_time_m.avg}")

end = time.time()

# all gather
all_predictions = [None for _ in range(args.world_size)]
torch.distributed.all_gather_object(all_predictions, predictions) # list of lists
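Note: this diff changes how compute_effective_num_shots is called (adding args.true_zero_shot) and introduces a call to utils.sample_class_conditional_batch_demos_from_query_set, but the utils implementations themselves are not shown. The following is a sketch consistent with how both are called above, under two assumptions: zero-shot evaluation conventionally substitutes two text-only demonstrations unless --true_zero_shot is passed, and the sampler returns a (classes, demos) pair per test image with the true class always among the sampled classes.

import random

import torch


def compute_effective_num_shots(num_shots, model_type, true_zero_shot=False):
    """Sketch: OpenFlamingo-style models conventionally substitute two
    text-only demonstrations at 0 shots; --true_zero_shot disables that."""
    if true_zero_shot:
        return num_shots
    if model_type == "open_flamingo":
        return num_shots if num_shots > 0 else 2
    return num_shots


def sample_class_conditional_batch_demos_from_query_set(
    batch_class_ids, num_classes_in_demos, query_set, num_samples
):
    """Sketch: for each test image, choose num_classes_in_demos classes
    (always including the image's true class), then spread num_samples
    demonstrations as evenly as possible across those classes, using the
    dataset's class_id_array to look up examples of each class."""
    all_classes = torch.unique(query_set.class_id_array).tolist()
    batch_classes, batch_demos = [], []
    for y in batch_class_ids:
        y = int(y)
        other = [c for c in all_classes if c != y]
        classes = [y] + random.sample(other, num_classes_in_demos - 1)
        # split num_samples across the chosen classes as evenly as possible
        counts = [num_samples // num_classes_in_demos] * num_classes_in_demos
        for i in range(num_samples % num_classes_in_demos):
            counts[i] += 1
        idxs = []
        for c, k in zip(classes, counts):
            pool = torch.where(query_set.class_id_array == c)[0].tolist()
            idxs += random.sample(pool, min(k, len(pool)))
        random.shuffle(idxs)
        batch_classes.append(classes)
        batch_demos.append([query_set[i] for i in idxs])
    return batch_classes, batch_demos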
18 changes: 7 additions & 11 deletions open_flamingo/eval/models/open_flamingo.py
@@ -35,7 +35,8 @@ def __init__(self, model_args):
if ("device" in model_args and model_args["device"] >= 0)
else "cpu"
)

self.autocast = get_autocast(model_args["precision"])
self.cast_dtype = get_cast_dtype(model_args["precision"])
(
self.model,
self.image_processor,
@@ -52,16 +53,11 @@ def __init__(self, model_args):
checkpoint = checkpoint["model_state_dict"]
checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
self.model.load_state_dict(checkpoint, strict=False)
self.model.to(self.device)
self.model.to(self.device, dtype=self.cast_dtype)
self.model.eval()
self.tokenizer.padding_side = "left"

self.lm_name = model_args["lm_path"].split("/")[-1]

# autocast
self.autocast = get_autocast(model_args["precision"])
self.cast_dtype = get_cast_dtype(model_args["precision"])

def _prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
"""
Convert images to tensors, reshape them, and stack them.
@@ -114,9 +110,9 @@ def _prepare_text(
max_length=max_length,
)
input_ids, attention_mask = encodings["input_ids"], encodings["attention_mask"]
input_ids = input_ids.to(self.device, dtype=self.cast_dtype, non_blocking=True)
input_ids = input_ids.to(self.device, non_blocking=True)
attention_mask = attention_mask.to(
self.device, dtype=self.cast_dtype, non_blocking=True
self.device, non_blocking=True
)
return input_ids, attention_mask.bool()

@@ -334,7 +330,7 @@ def get_hateful_memes_prompt(self, text, label=None) -> str:
return f"<image>is an image with: '{text}' written on it. Is it hateful? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"

def get_waterbirds_prompt(self, label=None) -> str:
return f"<image>Question: Is this a landbird or waterbird? Answer: {label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
return f"<image>Question: Is this a landbird or waterbird? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"

def get_camelyon17_prompt(self, label=None) -> str:
return f"<image>Question: Is this a normal tissue or cancer tissue? Answer: {label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
return f"<image>Question: Is this a normal tissue or cancer tissue? Answer:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
