diff --git a/lm_eval/evaluator.py b/lm_eval/evaluator.py
index e8ddec67e3..96e9562c54 100644
--- a/lm_eval/evaluator.py
+++ b/lm_eval/evaluator.py
@@ -294,7 +294,9 @@ def _adjust_config(task_dict):
             model_source=model,
             model_args=model_args,
             system_instruction=system_instruction,
-            chat_template=lm.chat_template(apply_chat_template),
+            # TODO: change this back
+            # chat_template=lm.chat_template(apply_chat_template),
+            chat_template=None,
             fewshot_as_multiturn=fewshot_as_multiturn,
         )
 
@@ -425,7 +427,7 @@ def evaluate(
     if len(incompatible_tasks) > 0:
         if not getattr(lm, "MULTIMODAL", False):
             raise ValueError(
-                f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the 'hf-multimodal' model type."
+                f"Attempted to run tasks: {incompatible_tasks} which require multimodal input, but the selected model type does not currently implement this. Multimodal support is currently restricted to the ['hf-multimodal', 'vllm-vlm'] model types."
             )
         else:
             raise ValueError(
diff --git a/lm_eval/models/__init__.py b/lm_eval/models/__init__.py
index 657932cb6b..553be3a5fd 100644
--- a/lm_eval/models/__init__.py
+++ b/lm_eval/models/__init__.py
@@ -13,6 +13,7 @@
     optimum_lm,
     textsynth,
     vllm_causallms,
+    vllm_vlms,
 )
 
 
diff --git a/lm_eval/models/hf_vlms.py b/lm_eval/models/hf_vlms.py
index 6316a7e843..c086c03364 100644
--- a/lm_eval/models/hf_vlms.py
+++ b/lm_eval/models/hf_vlms.py
@@ -5,12 +5,18 @@
 import torch.nn.functional as F
 import transformers
 from tqdm import tqdm
+from transformers import BatchEncoding
 
 from lm_eval import utils
 from lm_eval.api.instance import Instance
 from lm_eval.api.registry import register_model
 from lm_eval.models.huggingface import HFLM
-from lm_eval.models.utils import Collator, pad_and_concat, stop_sequences_criteria
+from lm_eval.models.utils import (
+    Collator,
+    pad_and_concat,
+    replace_placeholders,
+    stop_sequences_criteria,
+)
 
 
 DEFAULT_IMAGE_PLACEHOLDER = "<image>"
@@ -32,6 +38,11 @@ def __init__(
         self,
         pretrained: Union[str, transformers.PreTrainedModel],
         image_token_id: Optional[int] = None,
+        image_string="",
+        interleave: bool = True,
+        # TODO: handle whitespace in image placeholder (replacement)
+        max_images: Optional[int] = 999,
+        convert_img_format=False,
         **kwargs,
     ):
         # We initialize using HFLM's init. Sub-methods like _create_model and _create_tokenizer
@@ -47,26 +58,32 @@
         # HF AutoModelForVision2Seq models have an `image_token_id` value in their configs
         # denoting the token which indicates a location where an image will be substituted in.
         # This can take different string values across models, e.g. <image> for Idefics2 and <|image_pad|> for Qwen2-VL
+        self.interleave = interleave
+        self.max_images = max_images
+        self.rgb = convert_img_format
         # WARNING: improperly set image_token_id can lead to ignored image input or other (potentially silent) errors!
-        self.image_token_id = (
-            int(image_token_id)
-            if image_token_id
-            else (
-                getattr(self.config, "image_token_id", None)
-                or getattr(self.config, "image_token_index", None)
+        if not image_string:
+            self.image_token_id = (
+                int(image_token_id)
+                if image_token_id
+                else (
+                    getattr(self.config, "image_token_id", None)
+                    or getattr(self.config, "image_token_index", None)
+                )
             )
-        )
-        assert (
-            self.image_token_id is not None
-        ), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
-        # get the string this token ID corresponds to
-        self.image_token = self.tok_decode(
-            [self.image_token_id], skip_special_tokens=False
-        )
-        if image_token_id is not None:
-            eval_logger.info(
-                f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
+            assert (
+                self.image_token_id is not None
+            ), "Must have a non-None image_token_id to evaluate a Hugging Face AutoModelForVision2Seq model. Please pass `image_token_id` in `--model_args` if model's config does not already specify one."
+            # get the string this token ID corresponds to
+            self.image_token = self.tok_decode(
+                [self.image_token_id], skip_special_tokens=False
             )
+            if image_token_id is not None:
+                eval_logger.info(
+                    f"A non-default image_token_id with image_token_id={self.image_token_id} and string value '{self.image_token}' was specified manually. Note that using an improper image_token placeholder may lead to ignored image input or errors!"
+                )
+        else:
+            self.image_token = image_string
 
     def _create_tokenizer(
         self,
@@ -180,22 +197,52 @@ def _encode_multimodal_pair(self, context, continuation, images):
 
     def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
         self.chat_applied = True
-        for content in chat_history:
-            c = []
-            text = content["content"]
+        if not self.interleave:
+            for content in chat_history:
+                c = []
+                text = content["content"]
+
+                # Count and remove image placeholders
+                image_count = min(
+                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
+                )
+                text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
+
+                # Add image entries
+                for _ in range(image_count):
+                    c.append({"type": "image", "image": None})
 
-            # Count and remove image placeholders
-            image_count = text.count(DEFAULT_IMAGE_PLACEHOLDER)
-            text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
+                # Add single text entry at the end
+                c.append({"type": "text", "text": text})
 
-            # Add image entries
-            for _ in range(image_count):
-                c.append({"type": "image", "image": None})
+                content["content"] = c
+        else:
+            for content in chat_history:
+                c = []
+                text = content["content"]
+                expected_image_count = min(
+                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
+                )
+                actual_image_count = 0
+
+                text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
+
+                for i, part in enumerate(text_parts):
+                    # TODO: concatenate text parts (esp. if skipping images)?
+                    if part:  # Add non-empty text parts
+                        c.append({"type": "text", "text": part})
+                    if (
+                        (i < len(text_parts) - 1) and i < self.max_images
+                    ):  # Add image placeholder after each split except the last
+                        c.append({"type": "image"})
+                        actual_image_count += 1
 
-            # Add single text entry at the end
-            c.append({"type": "text", "text": text})
+                content["content"] = c
 
-            content["content"] = c
+                if actual_image_count != expected_image_count:
+                    raise ValueError(
+                        f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
+                    )
 
         return self.processor.apply_chat_template(
             chat_history, add_generation_prompt=True
@@ -216,17 +263,20 @@ def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str
     def tok_batch_multimodal_encode(
         self,
         strings: List[str],  # note that input signature of this fn is different
-        images,  # TODO: typehint on this
+        images: List[List],  # TODO: images are PIL.Image objects; add a precise typehint
         padding_side: str = "left",
         left_truncate_len: int = None,
         truncation: bool = False,
-    ) -> Dict[
-        str, torch.Tensor
+    ) -> Union[
+        BatchEncoding, Dict[str, torch.Tensor]
     ]:  # note that this return signature differs from HFLM tok_batch_encode.
         # NOTE: here, we replace <image> tags with our model's corresponding image_token string value.
+        # Moves the encodings to device
         if not self.chat_applied:
             strings = [
-                string.replace(DEFAULT_IMAGE_PLACEHOLDER, self.image_token)
+                replace_placeholders(
+                    string, DEFAULT_IMAGE_PLACEHOLDER, self.image_token, self.max_images
+                )
                 for string in strings
             ]
 
@@ -236,6 +286,10 @@ def tok_batch_multimodal_encode(
 
         # add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
 
+        images = [img[: self.max_images] for img in images]
+        if self.rgb:
+            images = [[img.convert("RGB") for img in sublist] for sublist in images]
+
         encoding = self.processor(
             images=images,
             text=strings,
@@ -247,7 +301,7 @@
 
         encoding.to(
             self.device, self.model.dtype
-        )  # TODO: casting to dtype seems odd for input_ids and attn_mask.
+        )  # TODO: This only casts the pixel values. Should they always be float16?
         if left_truncate_len:
             encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
             encoding["attention_mask"] = encoding["attention_mask"][
@@ -621,7 +675,7 @@ def _collate(x):
                 visuals,
                 left_truncate_len=max_ctx_len,
                 truncation=self.truncation,
-            ).to(self.device, self.model.dtype)
+            )
 
             context_enc = inputs["input_ids"]
diff --git a/lm_eval/models/utils.py b/lm_eval/models/utils.py
index 8a81e5deca..7afbf085ee 100644
--- a/lm_eval/models/utils.py
+++ b/lm_eval/models/utils.py
@@ -664,3 +664,21 @@ def configure_pad_token(
             tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
 
     return tokenizer
+
+
+def replace_placeholders(string, placeholder, replacement, max_count):
+    count = 0
+    result = ""
+    start = 0
+    while True:
+        index = string.find(placeholder, start)
+        if index == -1:
+            result += string[start:]
+            break
+        if count < max_count:
+            result += string[start:index] + replacement
+            count += 1
+        else:
+            result += string[start:index]
+        start = index + len(placeholder)
+    return result
diff --git a/lm_eval/models/vllm_vlms.py b/lm_eval/models/vllm_vlms.py
new file mode 100644
index 0000000000..27202e9305
--- /dev/null
+++ b/lm_eval/models/vllm_vlms.py
@@ -0,0 +1,284 @@
+import copy
+from typing import Dict, List, Optional
+
+import transformers
+from more_itertools import distribute
+from tqdm import tqdm
+
+from lm_eval.api.instance import Instance
+from lm_eval.api.registry import register_model
+from lm_eval.models.utils import Collator, undistribute
+from lm_eval.models.vllm_causallms import VLLM
+from lm_eval.utils import simple_parse_args_string
+
+
+try:
+    import ray
+    from vllm import LLM, SamplingParams
+    from vllm.lora.request import LoRARequest  # noqa: F401
+    from vllm.transformers_utils.tokenizer import get_tokenizer  # noqa: F401
+except ModuleNotFoundError:
+    pass
+
+
+DEFAULT_IMAGE_PLACEHOLDER = "<image>"
+
+
+@register_model("vllm-vlm")
+class VLLM_VLM(VLLM):
+    MULTIMODAL = True
+
+    def __init__(
+        self,
+        pretrained: str,
+        trust_remote_code: Optional[bool] = False,
+        revision: Optional[str] = None,
+        interleave: bool = True,
+        # TODO: handle max_images and limit_mm_per_prompt better
+        max_images: int = 999,
+        limit_mm_per_prompt: str = "image=1",
+        **kwargs,
+    ):
+        kwargs["limit_mm_per_prompt"] = simple_parse_args_string(limit_mm_per_prompt)
+        super().__init__(
+            pretrained=pretrained,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )
+        self.interleave = interleave
+        self.max_images = max_images
+        self.processor = transformers.AutoProcessor.from_pretrained(
+            pretrained,
+            revision=revision,
+            trust_remote_code=trust_remote_code,
+        )
+        self.chat_applied: bool = False
+
+    def tok_batch_multimodal_encode(
+        self,
+        strings: List[str],  # note that input signature of this fn is different
+        images,  # TODO: typehint on this
+        left_truncate_len: int = None,
+        truncation: bool = False,
+    ):
+        images = [img[: self.max_images] for img in images]
+
+        outputs = []
+        for x, i in zip(strings, images):
+            inputs = {
+                "prompt": x,
+                "multi_modal_data": {"image": i},
+            }
+            outputs.append(inputs)
+        return outputs
+
+    def _model_generate(
+        self,
+        requests: List[List[dict]] = None,
+        generate: bool = False,
+        max_tokens: int = None,
+        stop: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        if generate:
+            kwargs = self.modify_gen_kwargs(kwargs)
+            sampling_params = SamplingParams(max_tokens=max_tokens, stop=stop, **kwargs)
+        else:
+            sampling_params = SamplingParams(
+                temperature=0, prompt_logprobs=1, max_tokens=1, detokenize=False
+            )
+        if self.data_parallel_size > 1:
+            # vLLM hangs if tensor_parallel > 1 and resources are set in ray.remote
+            # also seems to only work with decorator and not with ray.remote() fn
+            # see https://github.com/vllm-project/vllm/issues/973
+            # note: this has changed in 0.3.3, and it only works now if num_gpus is set.
+            # but then tensor_parallel breaks
+            @ray.remote
+            def run_inference_one_model(
+                model_args: dict, sampling_params, requests: List[List[dict]]
+            ):
+                llm = LLM(**model_args)
+                return llm.generate(requests, sampling_params=sampling_params)
+
+            # dispatch requests to all self.data_parallel_size workers, in interleaved fashion
+            # interleaving is important to balance context lengths across workers
+            requests = [list(x) for x in distribute(self.data_parallel_size, requests)]
+            inputs = ((self.model_args, sampling_params, req) for req in requests)
+            object_refs = [run_inference_one_model.remote(*x) for x in inputs]
+            results = ray.get(object_refs)
+            # Invoke ray.shutdown() to prevent hang-ups if subsequent calls are required.
+            ray.shutdown()
+            # flatten results
+            return undistribute(results)
+
+        if self.lora_request is not None:
+            outputs = self.model.generate(
+                requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+                lora_request=self.lora_request,
+            )
+        else:
+            outputs = self.model.generate(
+                requests,
+                sampling_params=sampling_params,
+                use_tqdm=True if self.batch_size == "auto" else False,
+            )
+        return outputs
+
+    def apply_chat_template(self, chat_history: List[Dict[str, str]]) -> str:
+        self.chat_applied = True
+        if not self.interleave:
+            for content in chat_history:
+                c = []
+                text = content["content"]
+
+                # Count and remove image placeholders
+                image_count = min(
+                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
+                )
+                text = text.replace(DEFAULT_IMAGE_PLACEHOLDER, "")
+
+                # Add image entries
+                for _ in range(image_count):
+                    c.append({"type": "image", "image": None})
+
+                # Add single text entry at the end
+                c.append({"type": "text", "text": text})
+
+                content["content"] = c
+        else:
+            for content in chat_history:
+                c = []
+                text = content["content"]
+                expected_image_count = min(
+                    self.max_images, text.count(DEFAULT_IMAGE_PLACEHOLDER)
+                )
+                actual_image_count = 0
+
+                text_parts = text.split(DEFAULT_IMAGE_PLACEHOLDER)
+
+                for i, part in enumerate(text_parts):
+                    # TODO: concatenate text parts (esp. if skipping images)?
+                    if part:  # Add non-empty text parts
+                        c.append({"type": "text", "text": part})
+                    if (
+                        (i < len(text_parts) - 1) and i < self.max_images
+                    ):  # Add image placeholder after each split except the last
+                        c.append({"type": "image"})
+                        actual_image_count += 1
+
+                content["content"] = c
+
+                if actual_image_count != expected_image_count:
+                    raise ValueError(
+                        f"Mismatch in image placeholder count. Expected: {expected_image_count}, Actual: {actual_image_count}"
+                    )
+
+        return self.processor.apply_chat_template(
+            chat_history, add_generation_prompt=True
+        )
+
+    def generate_until(
+        self, requests: List[Instance], disable_tqdm: bool = False
+    ) -> List[str]:
+        # TODO: back out to HFLM.generate_until() for all requests without aux_arguments (text-only reqs)
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]
+
+        pbar = tqdm(
+            total=len(requests),
+            disable=(disable_tqdm or (self.rank != 0)),
+            desc="Running generate_until requests with text+image input",
+        )
+        # TODO: port auto-batch sizing into this.
+
+        # we group requests by their generation_kwargs,
+        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
+        # in the same batch.
+        re_ords = Collator(
+            [reg.args for reg in requests],
+            _collate,
+            group_by="gen_kwargs",
+            group_fn=lambda x: x[1],
+        )
+        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+
+        ### Up to here: was identical to non-multimodal HFLM generate_until ###
+
+        for chunk in chunks:
+            contexts, all_gen_kwargs, aux_arguments = zip(*chunk)
+
+            visuals = [arg["visual"] for arg in aux_arguments]
+
+            if not isinstance(contexts, list):
+                contexts = list(
+                    contexts
+                )  # for Qwen2-VL, processor is unhappy accepting a tuple of strings instead of a list.
+            # TODO: could we upstream this workaround to HF?
+            ### this part onward: same as HFLM ###
+
+            # we assume all gen kwargs in the batch are the same
+            # this is safe to assume because the `grouper` object ensures it.
+            gen_kwargs = all_gen_kwargs[0]
+            # unpack our keyword arguments.
+            until = None
+            if isinstance(gen_kwargs, dict):
+                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
+                if "until" in kwargs.keys():
+                    until = kwargs.pop("until")
+                    if isinstance(until, str):
+                        until = [until]
+                    elif not isinstance(until, list):
+                        raise ValueError(
+                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
+                        )
+            else:
+                raise ValueError(
+                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
+                )
+            # add EOS token to stop sequences
+            eos = self.tokenizer.decode(self.eot_token_id)
+            if not until:
+                until = [eos]
+            else:
+                until.append(eos)
+            if "max_gen_toks" in kwargs.keys():
+                max_gen_toks = kwargs.pop("max_gen_toks")
+            else:
+                max_gen_toks = self.max_gen_toks
+
+            ### end stuff that's entirely copied verbatim from HFLM ###
+
+            max_ctx_len = self.max_length - max_gen_toks
+
+            inputs = self.tok_batch_multimodal_encode(
+                contexts,
+                visuals,
+                left_truncate_len=max_ctx_len,
+            )
+
+            cont = self._model_generate(inputs, stop=until, generate=True, **kwargs)
+
+            for output, context in zip(cont, contexts):
+                generated_text = output.outputs[0].text
+                res.append(generated_text)
+                self.cache_hook.add_partial(
+                    "generate_until", (context, gen_kwargs), generated_text
+                )
+                pbar.update(1)
+        # reorder this group of results back to original unsorted form
+        res = re_ords.get_original(res)
+
+        pbar.close()
+        return res
diff --git a/lm_eval/tasks/mathvista/testmini.yaml b/lm_eval/tasks/mathvista/testmini.yaml
index abaf81a665..52b5384041 100644
--- a/lm_eval/tasks/mathvista/testmini.yaml
+++ b/lm_eval/tasks/mathvista/testmini.yaml
@@ -1,20 +1,17 @@
 dataset_path: AI4Math/MathVista
+task: mathvista_mcq
 test_split: testmini
-output_type: generate_until
+output_type: multiple_choice
+process_docs: !function utils.process_docs
 doc_to_image: !function utils.doc_to_image
-doc_to_text: " {{query}}"
-doc_to_target: "answer"
-# TODO: add process_results. multiple-choice question_type need to be handled differently
-process_results: !function utils.process_results
-generation_kwargs:
-  until:
-    - "<|endoftext|>"
-  temperature: 0.0
-  do_sample: false
-  max_gen_toks: 512
+doc_to_text: "{{query}}"
+doc_to_choice: '{{ ["A", "B", "C", "D", "E", "F"][:choices|length] }}'
+doc_to_target: "{{choices.index(answer)}}"
 metric_list:
   - metric: acc
     aggregation: mean
     higher_is_better: true
 metadata:
-  version: 0.0
+  version: 1.0
+dataset_kwargs:
+  trust_remote_code: true
diff --git a/lm_eval/tasks/mathvista/utils.py b/lm_eval/tasks/mathvista/utils.py
new file mode 100644
index 0000000000..19c64035ea
--- /dev/null
+++ b/lm_eval/tasks/mathvista/utils.py
@@ -0,0 +1,5 @@
+import datasets
+
+
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    return dataset.filter(lambda x: x["question_type"].strip() == "multi_choice")
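
Note (not part of the patch): a minimal usage sketch for the replace_placeholders helper added to lm_eval/models/utils.py above. The placeholder and replacement strings below are illustrative; the point being shown is that only the first max_count occurrences are substituted and any further placeholders are dropped.

# Sketch only: exercises the replace_placeholders helper introduced in this diff.
from lm_eval.models.utils import replace_placeholders

prompt = "<image> What is shown? <image> And here? <image>"
# Replace at most two placeholders with a model-specific image token;
# occurrences beyond max_count are removed rather than substituted.
out = replace_placeholders(prompt, "<image>", "<|image_pad|>", max_count=2)
print(out)  # "<|image_pad|> What is shown? <|image_pad|> And here? "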