update eval code to match new src args
i-gao committed Sep 16, 2023
1 parent be9a4dd commit b0ff9a4
Showing 8 changed files with 262 additions and 204 deletions.
2 changes: 2 additions & 0 deletions open_flamingo/eval/eval_datasets.py
@@ -7,6 +7,8 @@

from open_flamingo.eval.classification_utils import IMAGENET_CLASSNAMES

SUPPORTED_TASKS = ["coco", "flickr", "vqav2", "ok_vqa", "vizwiz", "textvqa", "hateful_memes", "imagenet"]


class CaptionDataset(Dataset):
def __init__(
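For context, a minimal sketch of how the new SUPPORTED_TASKS constant can be used to validate a requested task before building a dataset (illustrative only, not part of this diff):

    from open_flamingo.eval.eval_datasets import SUPPORTED_TASKS

    task = "ok_vqa"
    if task not in SUPPORTED_TASKS:
        raise ValueError(f"Unsupported task {task}; expected one of {SUPPORTED_TASKS}")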
1 change: 1 addition & 0 deletions open_flamingo/eval/eval_models/__init__.py
@@ -0,0 +1 @@
from .eval_model import *
open_flamingo/eval/eval_models/blip.py
@@ -13,9 +13,6 @@ class EvalModel(BaseEvalModel):
"""BLIP-2 model evaluation."""

def __init__(self, model_args, init_on_device=False):
assert (
"processor_path" in model_args and "lm_path" in model_args
), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified"
super().__init__(model_args, init_on_device)
with self.init_ctx:
self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
@@ -25,6 +22,10 @@ def __init__(self, model_args, init_on_device=False):
self.tokenizer = self.processor.tokenizer
self._check_init()

@property
def required_args(self):
return ["processor_path", "lm_path"]

def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
batch_images = None
assert all(
@@ -58,6 +59,7 @@ def prepare_text(
max_length=2000,
add_special_tokens=True,
):
self._validate_text(batch)
encodings = self.tokenizer(
batch,
padding=padding,
@@ -95,39 +97,20 @@ def get_outputs(

return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)

def get_vqa_prompt(self, question, answer=None) -> str:
return (
f"Question:{question} Short answer:{answer if answer is not None else ''}"
)

def get_caption_prompt(self, caption=None) -> str:
def get_vqav2_prompt(self, question, answer=None) -> str:
return f"Question:{question} Short answer:{answer if answer is not None else ''}"

def get_ok_vqa_prompt(self, question, answer=None) -> str:
return f"Question:{question} Short answer:{answer if answer is not None else ''}"

def get_vizwiz_prompt(self, question, answer=None) -> str:
return f"Question:{question} Short answer:{answer if answer is not None else ''}"

def get_textvqa_prompt(self, question, answer=None) -> str:
return f"Question:{question} Short answer:{answer if answer is not None else ''}"

def get_coco_prompt(self, caption=None) -> str:
return f"A photo of {caption if caption is not None else ''}"

def get_flickr_prompt(self, caption=None) -> str:
return f"A photo of {caption if caption is not None else ''}"

def __call__(
self,
lang_x: torch.Tensor,
vision_x: torch.Tensor,
attention_mask: torch.Tensor,
):
with self.autocast():
outputs = self.model(
pixel_values=vision_x,
input_ids=lang_x,
attention_mask=attention_mask,
)

# remove vision tokens
outputs.logits = outputs.logits[:, -lang_x.size(1) :, :]
return outputs

def get_rank_classifications(
self,
batch_text: List[str],
batch_images: List[List[Image.Image]],
all_class_names: List[str],
use_cache: bool,
normalize_length: bool,
):
raise NotImplementedError(
"BLIP-2 classification-based evaluation not implemented"
)
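The per-model asserts deleted above are replaced by the required_args property: BaseEvalModel.__init__ (below) checks that every name in required_args is present in model_args. A minimal sketch of constructing this BLIP-2 eval model (the checkpoint paths are placeholder assumptions):

    model_args = {
        "processor_path": "Salesforce/blip2-opt-2.7b",  # placeholder processor path
        "lm_path": "Salesforce/blip2-opt-2.7b",         # placeholder model path
        "precision": "fp32",
    }
    model = EvalModel(model_args, init_on_device=False)
    # Omitting "processor_path" or "lm_path" now fails in BaseEvalModel.__init__ with:
    # Missing required args for EvalModel: ['processor_path', 'lm_path']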
open_flamingo/eval/eval_models/eval_model.py
@@ -6,36 +6,75 @@
import torch
from contextlib import suppress

SUPPORTED_MODELS = ["open_flamingo", "blip", "idefics"]
ZERO_SHOT_ONLY_MODELS = ["blip"]


def get_eval_model(name, *args, **kwargs):
"""Return an EvalModel object."""
if name == "open_flamingo":
from .open_flamingo import EvalModel

return EvalModel(*args, **kwargs)
elif name == "blip":
from .blip import EvalModel

return EvalModel(*args, **kwargs)
elif name == "idefics":
from .idefics import EvalModel

return EvalModel(*args, **kwargs)
else:
raise ValueError(f"Unsupported EvalModel type {name}")
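A hypothetical call to this factory (the model_args dict must satisfy the chosen model's required_args; the checkpoint id is an assumed example):

    model = get_eval_model(
        "idefics",
        {"lm_path": "HuggingFaceM4/idefics-9b", "processor_path": "HuggingFaceM4/idefics-9b"},
        init_on_device=False,
    )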


class BaseEvalModel(abc.ABC):
"""Base class encapsulating functionality needed to evaluate a model."""

def __init__(self, model_args: List[str]):
def __init__(self, model_args: List[str], init_on_device=False):
"""Initialize model.
Args:
args: arguments to model. These should be parsed, or if the model
has no applicable arguments, an error should be thrown if `args`
is non-empty.
"""
# check model args
assert all(
arg in model_args for arg in self.required_args
), f"Missing required args for {self.__class__.__name__}: {self.required_args}"
self.lm_name = model_args["lm_path"].split("/")[-1]

def __init__(self, model_args, init_on_device=False):
assert "lm_path" in model_args, "All models require the lm_path argument"
# set device and precision
self.device = (
model_args["device"]
if ("device" in model_args and (type(model_args["device"]) != int or model_args["device"] >= 0))
if (
"device" in model_args
and (type(model_args["device"]) != int or model_args["device"] >= 0)
)
else "cpu"
)
print("Using device:", self.device)
self.precision = model_args.get("precision", "fp32")
self.lm_name = model_args["lm_path"].split("/")[-1]
self.autocast = get_autocast(self.precision)
self.cast_dtype = get_cast_dtype(self.precision)

# initialization context
if init_on_device:
            # for deepspeed, must init on device, or likely CPU OOM
import deepspeed
self.init_ctx = deepspeed.OnDevice(dtype=self.cast_dtype, device=self.device)

self.init_ctx = deepspeed.OnDevice(
dtype=self.cast_dtype, device=self.device
)
else:
self.init_ctx = suppress()

@property
def required_args(self):
"""Return list of required arguments to initialize model."""
return ["lm_path"]

def _check_init(self):
"""Finish model initialization."""
assert hasattr(self, "model"), "Model has not been initialized"
@@ -49,6 +88,7 @@ def init_distributed(self, world_size=None, use_deepspeed=False):
if use_deepspeed:
assert "amp" not in self.precision, "Deepspeed does not support amp"
import deepspeed

self.ds_engine = deepspeed.init_inference(
self.model,
mp_size=world_size,
@@ -61,12 +101,6 @@ def init_distributed(self, world_size=None, use_deepspeed=False):
else:
self.model = DDP(self.model, device_ids=[self.device])

def set_device(self, device):
"""Set device for model."""
torch.cuda.set_device(device)
self.device = torch.device("cuda", device)
self.model = self.model.to(device, dtype=self.cast_dtype)

def __call__(
self,
lang_x: torch.Tensor,
@@ -76,12 +110,13 @@ def __call__(
use_cache: bool = False,
):
"""
Calls the forward function of the model.
Special logic to handle the case if past_key_values is not None:
Calls the forward function of the model, and returns an object that includes logits.
Note: implementations should handle the case if past_key_values is not None:
then lang_x is assumed to contain the tokens to be generated
*excluding* the tokens already in past_key_values.
We then repeatedly call forward, updating the past_key_values.
"""
raise NotImplementedError

def prepare_text(
self,
@@ -92,7 +127,7 @@ def prepare_text(
add_special_tokens=True,
):
"""
Prepare text for model.
Prepare text for model. Note that padding is always on the left.
Args:
batch: list of text strings
@@ -101,36 +136,38 @@ def prepare_text(
max_length: maximum length of the text
Returns:
input_ids: tensor of shape (B, T)
attention_mask: tensor of shape (B, T)
input_ids: tensor of shape (B, T_txt)
attention_mask: tensor of shape (B, T_txt)
"""
raise NotImplementedError

def prepare_images(self, batch: List[List[Image.Image]]):
"""
Prepare images for model.
Args:
batch: list of lists of PIL images
Returns:
tensor of shape (B, T, *, C, H, W)
tensor of shape (B, T_img, F, C, H, W)
"""
raise NotImplementedError

def get_outputs(
self,
batch_text: List[str],
batch_images: List[List[Image.Image]],
**decode_kwargs,
) -> List[str]:
"""Get outputs for a batch of images and text.
"""Call generate on a batch of images and text.
Args:
batch_text: list of text strings, with the text "<image>" in place
of any images to be included.
batch_text: list of text strings
batch_images: images to provide to model. Should be a list of lists,
where each list contains the images for a single example.
Returns:
List of decoded output strings.
"""
raise NotImplementedError

def get_rank_classifications(
self,
@@ -150,7 +187,29 @@ def get_rank_classifications(
all_class_names: list of all class names.
use_cache: whether to cache the context to speed up evaluations.
normalize_length: whether to normalize logprobs by the length of the
class name
class name; use with caution, as this can change predictions quite a bit.
Returns:
(B, |all_class_names|) tensor containing the logprobs for each class name.
"""
raise NotImplementedError

@property
def supported_tasks(self):
"""
Return list of tasks that this model can be evaluated on.
        Determined by checking whether the model has a method called `get_{task}_prompt`.
"""
        return [
            task[len("get_") : -len("_prompt")]  # keeps multi-word tasks like "ok_vqa" intact
            for task in dir(self)
            if task.startswith("get_") and task.endswith("_prompt")
        ]

def _validate_text(self, batch_text):
"""
        Checks for trailing whitespace in the text and prints a warning.
        """
        if any(x.endswith(" ") for x in batch_text):
print(
"Warning: trailing whitespace detected in text. This can cause unexpected behavior."
)
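Taken together, a short sketch of how these two helpers behave on an already-constructed model (output shown as comments; list order follows dir(), i.e. alphabetical):

    print(model.supported_tasks)
    # for the IDEFICS model below: ['coco', 'flickr', 'ok_vqa', 'textvqa', 'vizwiz', 'vqav2']

    model._validate_text(["Question:What is this? Short answer: "])
    # prints: Warning: trailing whitespace detected in text. This can cause unexpected behavior.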
open_flamingo/eval/eval_models/idefics.py
@@ -18,16 +18,17 @@ class EvalModel(BaseEvalModel):
"""IDEFICS model evaluation."""

def __init__(self, model_args, init_on_device=False):
assert (
"lm_path" in model_args and "processor_path" in model_args
), "IDEFICS requires lm_path and lm_tokenizer_path"
super().__init__(model_args, init_on_device)
with self.init_ctx:
self.model = IdeficsForVisionText2Text.from_pretrained(model_args["lm_path"])
self.processor = AutoProcessor.from_pretrained(model_args["processor_path"])
self.tokenizer = self.processor.tokenizer
self._check_init()

@property
def required_args(self):
return ["lm_path", "processor_path"]

def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
batch_images = self.processor(batch)["pixel_values"]
if batch_images is not None:
@@ -44,6 +45,7 @@ def prepare_text(
max_length=2000,
add_special_tokens=True,
):
self._validate_text(batch)
        # check to see if there are any <image> tokens without <fake_token_around_image> wrapping them
for i, text in enumerate(batch):
if "<image>" in text and "<fake_token_around_image>" not in text:
@@ -88,19 +90,6 @@ def _compute_image_attention_mask(self, batch_tokens: torch.Tensor) -> torch.Tensor:
)
return image_attention_mask

def get_rank_classifications(
self,
batch_text: List[str],
batch_images: List[List[Image.Image]],
all_class_names: List[str],
use_cache: bool,
normalize_length: bool,
):
"""
Returns a (B, |all_class_names|) tensor containing the logprobs for each class name.
"""
raise NotImplementedError

def get_outputs(
self,
batch_text: List[str],
@@ -176,18 +165,26 @@ def __call__(
past_key_values=past_key_values,
)

def get_vqa_prompt(self, question, answer=None) -> str:
def get_vqav2_prompt(self, question, answer=None) -> str:
# TODO: handle prefix prompts
return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_caption_prompt(self, caption=None) -> str:
def get_ok_vqa_prompt(self, question, answer=None) -> str:
# TODO: handle prefix prompts
return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_vizwiz_prompt(self, question, answer=None) -> str:
# TODO: handle prefix prompts
return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_imagenet_prompt(self, label=None) -> str:
def get_textvqa_prompt(self, question, answer=None) -> str:
# TODO: handle prefix prompts
return f"<image>Output:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"

def get_hateful_memes_prompt(self, text, label=None) -> str:
def get_coco_prompt(self, caption=None) -> str:
# TODO: handle prefix prompts
return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

def get_flickr_prompt(self, caption=None) -> str:
# TODO: handle prefix prompts
return f"<image>is an image with: '{text}' written on it. Is it hateful? Answer: {label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
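For reference, a sketch of how these prompt builders compose a few-shot prompt (hypothetical questions; note that with answer=None the string ends with a trailing space after "Answer:", which _validate_text would warn about):

    context = model.get_vqav2_prompt(question="What color is the bus?", answer="red")
    query = model.get_vqav2_prompt(question="How many dogs are there?")
    # context: "<image>Question:What color is the bus? Answer: red<|endofchunk|>"
    # query:   "<image>Question:How many dogs are there? Answer: "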