Commit

Merge branch 'multimodal-prototyping' of https://github.com/EleutherAI/lm-evaluation-harness into multimodal-prototyping
haileyschoelkopf committed Sep 13, 2024
2 parents f9cf90e + 357cf64 commit f04be6c
Showing 7 changed files with 411 additions and 50 deletions.
6 changes: 4 additions & 2 deletions lm_eval/evaluator.py
@@ -294,7 +294,9 @@ def _adjust_config(task_dict):
model_source=model,
model_args=model_args,
system_instruction=system_instruction,
chat_template=lm.chat_template(apply_chat_template),
# TODO: change this back
# chat_template=lm.chat_template(apply_chat_template),
chat_template=None,
fewshot_as_multiturn=fewshot_as_multiturn,
)

@@ -425,7 +427,7 @@ def evaluate(
if len(incompatible_tasks) > 0:
if not getattr(lm, "MULTIMODAL", False):
raise ValueError(
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the 'hf-multimodal' model type."
f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model type."
)
else:
raise ValueError(
1 change: 1 addition & 0 deletions lm_eval/models/__init__.py
@@ -13,6 +13,7 @@
optimum_lm,
textsynth,
vllm_causallms,
vllm_vlms,
)


126 changes: 90 additions & 36 deletions lm_eval/models/hf_vlms.py
@@ -5,12 +5,18 @@
import torch.nn.functional as F
import transformers
from tqdm import tqdm
from transformers import BatchEncoding

from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.registry import register_model
from lm_eval.models.huggingface import HFLM
from lm_eval.models.utils import Collator, pad_and_concat, stop_sequences_criteria
from lm_eval.models.utils import (
Collator,
pad_and_concat,
replace_placeholders,
stop_sequences_criteria,
)


DEFAULT_IMAGE_PLACEHOLDER = "<image>"
@@ -32,6 +38,11 @@ def __init__(
self,
pretrained: Union[str, transformers.PreTrainedModel],
image_token_id: Optional[int] = None,
image_string="<image>",
interleave: bool = True,
# TODO: handle whitespace in image placeholder (replacement)
max_images: Optional[int] = 999,
convert_img_format=False,
**kwargs,
):
# We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
@@ -47,26 +58,32 @@
# HF AutoModelForVision2Seq models have an `image_token_id` value in their configs
# denoting the token which indicates a location where an image will be substituted in.
# This can take different string values across models, e.g. <image> for Idefics2 and <|image_pad|> for Qwen2-VL
self.interleave = interleave
self.max_images = max_images
self.rgb = convert_img_format
# WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
self.image_token_id = (
int(image_token_id)
if image_token_id
else (
getattr(self.config, "image_token_id", None)
or getattr(self.config, "image_token_index", None)
if not image_string:
self.image_token_id = (
int(image_token_id)
if image_token_id
else (
getattr(self.config, "image_token_id", None)
or getattr(self.config, "image_token_index", None)
)
)
)
assert (
self.image_token_id is not None
), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
# get the string this token ID corresponds to
self.image_token = self.tok_decode(
[self.image_token_id], skip_special_tokens=False
)
if image_token_id is not None:
eval_logger.info(
f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
assert (
self.image_token_id is not None
), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
# get the string this token ID corresponds to
self.image_token = self.tok_decode(
[self.image_token_id], skip_special_tokens=False
)
if image_token_id is not None:
eval_logger.info(
f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
)
else:
self.image_token = image_string
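
For context, options like the new `image_string`, `interleave`, `max_images`, and `convert_img_format` parameters above would be supplied through `--model_args` when selecting the 'hf-multimodal' model type, e.g. `--model hf-multimodal --model_args pretrained=HuggingFaceM4/idefics2-8b,max_images=2,interleave=True` (illustrative values).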

def _create_tokenizer(
self,
@@ -180,22 +197,52 @@ def _encode_multimodal_pair(self, context, continuation, images):

def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
self.chat_applied = True
for content in chat_history:
c = []
text = content["content"]
if not self.interleave:
for content in chat_history:
c = []
text = content["content"]

# Count and remove image placeholders
image_count = min(
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
)
text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")

# Add image entries
for _ in range(image_count):
c.append({"type": "image", "image": None})

# Count and remove image placeholders
image_count = text.count(DEFAULT_IMAGE_PLACEHOLDER)
text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
# Add single text entry at the end
c.append({"type": "text", "text": text})

# Add image entries
for _ in range(image_count):
c.append({"type": "image", "image": None})
content["content"] = c
else:
for content in chat_history:
c = []
text = content["content"]
expected_image_count = min(
self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
)
actual_image_count = 0

text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)

for i, part in enumerate(text_parts):
# TODO: concatenate text parts (esp. if skipping images)?
if part: # Add non-empty text parts
c.append({"type": "text", "text": part})
if (
(i < len(text_parts) - 1) and i < self.max_images
): # Add image placeholder after each split except the last
c.append({"type": "image"})
actual_image_count += 1

# Add single text entry at the end
c.append({"type": "text", "text": text})
content["content"] = c

content["content"] = c
if actual_image_count != expected_image_count:
raise ValueError(
f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
)

return self.processor.apply_chat_template(
chat_history, add_generation_prompt=True
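
To make the two branches above concrete, here is a minimal sketch of how a message containing two image placeholders is converted (the message text is hypothetical):

    content = "A cat <image> and a dog <image>."

    # interleave=True: text and image entries alternate in document order
    # [{"type": "text", "text": "A cat "}, {"type": "image"},
    #  {"type": "text", "text": " and a dog "}, {"type": "image"},
    #  {"type": "text", "text": "."}]

    # interleave=False: placeholders are stripped, all image entries come
    # first, and a single text entry is appended at the end
    # [{"type": "image", "image": None}, {"type": "image", "image": None},
    #  {"type": "text", "text": "A cat  and a dog ."}]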
@@ -216,17 +263,20 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str
def tok_batch_multimodal_encode(
self,
strings: List[str], # note that input signature of this fn is different
images, # TODO: typehint on this
images: List[List], # TODO: images are PIL.Image
padding_side: str = "left",
left_truncate_len: int = None,
truncation: bool = False,
) -> Dict[
str, torch.Tensor
) -> Union[
BatchEncoding, Dict[str, torch.Tensor]
]: # note that this return signature differs from HFLM tok_batch_encode.
# NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
# Moves the encodings to device
if not self.chat_applied:
strings = [
string.replace(DEFAULT_IMAGE_PLACEHOLDER, self.image_token)
replace_placeholders(
string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
)
for string in strings
]

Expand All @@ -236,6 +286,10 @@ def tok_batch_multimodal_encode(

# add_special_tokens = {"add_special_tokens": False or self.add_bos_token}

images = [img[: self.max_images] for img in images]
if self.rgb:
images = [[img.convert("RGB") for img in sublist] for sublist in images]

encoding = self.processor(
images=images,
text=strings,
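
As a side note, the `max_images` cap and RGB normalization above amount to the following minimal sketch (assuming PIL inputs; values are illustrative):

    from PIL import Image

    max_images = 2
    images = [[Image.new("RGBA", (4, 4)), Image.new("P", (4, 4))]]
    images = [img[:max_images] for img in images]  # cap per-request image count
    images = [[im.convert("RGB") for im in sub] for sub in images]  # unify color mode
    print(images[0][0].mode)  # "RGB"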
@@ -247,7 +301,7 @@

encoding.to(
self.device, self.model.dtype
) # TODO: casting to dtype seems odd for input_ids and attn_mask.
) # TODO: This only casts the pixel values. Should they always be float16?
if left_truncate_len:
encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
encoding["attention_mask"] = encoding["attention_mask"][
@@ -621,7 +675,7 @@ def _collate(x):
visuals,
left_truncate_len=max_ctx_len,
truncation=self.truncation,
).to(self.device, self.model.dtype)
)

context_enc = inputs["input_ids"]

18 changes: 18 additions & 0 deletions lm_eval/models/utils.py
@@ -664,3 +664,21 @@ def configure_pad_token(
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

return tokenizer


def replace_placeholders(string, placeholder, replacement, max_count):
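    """Replace the first max_count occurrences of placeholder in string with replacement; any further occurrences are removed entirely."""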
count = 0
result = ""
start = 0
while True:
index = string.find(placeholder, start)
if index == -1:
result += string[start:]
break
if count < max_count:
result += string[start:index] + replacement
count += 1
else:
result += string[start:index]
start = index + len(placeholder)
return result
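
A quick check of the helper's behavior (illustrative strings; note that occurrences beyond max_count are dropped, not kept):

    >>> replace_placeholders("a <image> b <image> c", "<image>", "<|image_pad|>", 1)
    'a <|image_pad|> b  c'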