diff --git a/.owners.yml b/.owners.yml
index f3b96a3a9..ad20322bb 100644
--- a/.owners.yml
+++ b/.owners.yml
@@ -10,5 +10,5 @@ assign:
 - Leymore
 - gaotongxiao
 - yingfhu
-- Ezra-Yu
+- fangyixiao18
 - tonysy
diff --git a/configs/eval_claude2.py b/configs/eval_claude2.py
new file mode 100644
index 000000000..ef33174b0
--- /dev/null
+++ b/configs/eval_claude2.py
@@ -0,0 +1,28 @@
+from mmengine.config import read_base
+from opencompass.models.claude_api import Claude
+from opencompass.partitioners import NaivePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.tasks import OpenICLInferTask
+
+with read_base():
+    # choose a list of datasets
+    from .datasets.collections.chat_medium import datasets
+    # and output the results in a chosen format
+    from .summarizers.medium import summarizer
+
+models = [
+    dict(abbr='Claude2',
+         type=Claude,
+         path='claude-2',
+         key='YOUR_CLAUDE_KEY',
+         query_per_second=1,
+         max_out_len=2048, max_seq_len=2048, batch_size=2),
+]
+
+infer = dict(
+    partitioner=dict(type=NaivePartitioner),
+    runner=dict(
+        type=LocalRunner,
+        max_num_workers=8,
+        task=dict(type=OpenICLInferTask)),
+)
diff --git a/configs/multimodal/minigpt_4/README.md b/configs/multimodal/minigpt_4/README.md
index aa434d0ef..c7a06d342 100644
--- a/configs/multimodal/minigpt_4/README.md
+++ b/configs/multimodal/minigpt_4/README.md
@@ -22,5 +22,5 @@ python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
 
 ```sh
 cd $root
-python run.py configs/multimodal/tasks.py
+python run.py configs/multimodal/tasks.py --mm-eval
 ```
\ No newline at end of file
diff --git a/configs/multimodal/openflamingo/README.md b/configs/multimodal/openflamingo/README.md
new file mode 100644
index 000000000..c8b62736d
--- /dev/null
+++ b/configs/multimodal/openflamingo/README.md
@@ -0,0 +1,21 @@
+# OpenFlamingo
+
+### Prepare the environment
+
+Install [MMPretrain](https://github.com/open-mmlab/mmpretrain) according to this [doc](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation).
+
+### Start evaluation
+
+#### Slurm
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION
+```
+
+#### PyTorch
+
+```sh
+cd $root
+python run.py configs/multimodal/tasks.py --mm-eval
+```
\ No newline at end of file
diff --git a/configs/multimodal/openflamingo/openflamingo_mmbench.py b/configs/multimodal/openflamingo/openflamingo_mmbench.py
new file mode 100644
index 000000000..8327fb09d
--- /dev/null
+++ b/configs/multimodal/openflamingo/openflamingo_mmbench.py
@@ -0,0 +1,73 @@
+# dataloader settings
+val_pipeline = [
+    dict(type='mmpretrain.PILToNumpy'),
+    dict(type='mmpretrain.ResizeEdge',
+         scale=224,
+         interpolation='bicubic',
+         backend='pillow'),
+    dict(type='CenterCrop', crop_size=(224, 224)),
+    dict(type='mmpretrain.PackInputs',
+         algorithm_keys=[
+             'question', 'options', 'category', 'l2-category', 'index',
+             'context', 'options_dict'
+         ])
+]
+
+dataset = dict(type='opencompass.MMBenchDataset',
+               data_file='data/mmbench/mmbench_test_20230712.tsv',
+               pipeline=val_pipeline)
+
+openflamingo_dataloader = dict(
+    batch_size=1,
+    num_workers=4,
+    dataset=dataset,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    collate_fn=dict(type='default_collate'),
+    persistent_workers=True,
+)
+
+# model settings
+openflamingo_model = dict(
+    type='openflamingo',
+    data_preprocessor=dict(
+        type='mmpretrain.MultiModalDataPreprocessor',
+        mean=[122.770938, 116.7460125, 104.09373615],
+        std=[68.5005327, 66.6321579, 70.32316305],
+        to_rgb=True,
+    ),
+    tokenizer=dict(type='mmpretrain.LlamaTokenizer',
+                   name_or_path='decapoda-research/llama-7b-hf'),
+    vision_encoder=dict(
+        type='mmpretrain.VisionTransformer',
+        arch='l',
+        patch_size=14,
+        pre_norm=True,
+        norm_cfg=dict(type='LN', eps=1e-5),
+        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
+        final_norm=False,
+        out_type='raw',
+        pretrained=  # noqa: E251
+        '/path/to/vision/encoder',  # noqa
+    ),
+    lang_encoder=dict(
+        base=dict(type='mmpretrain.AutoModelForCausalLM',
+                  name_or_path=
+                  'decapoda-research/llama-7b-hf',
+                  local_files_only=True),
+        adapter=dict(type='mmpretrain.FlamingoLMAdapter',
+                     vis_hidden_size=1024,
+                     cross_attn_every_n_layers=4,
+                     use_media_placement_augmentation=False),
+    ),
+    generation_cfg=dict(num_beams=3, max_new_tokens=20, length_penalty=-2.0),
+)
+
+# evaluation settings
+openflamingo_evaluator = [
+    dict(
+        type='opencompass.DumpResults',
+        save_path=  # noqa: E251
+        'work_dirs/9b-flamingo/9b-flamingo-mmbench.xlsx')
+]
+
+openflamingo_load_from = '/path/to/pretrained/weights'  # noqa
diff --git a/configs/multimodal/tasks.py b/configs/multimodal/tasks.py
index e03a1ed26..ef6bd417b 100644
--- a/configs/multimodal/tasks.py
+++ b/configs/multimodal/tasks.py
@@ -10,6 +10,7 @@
 datasets = [minigpt_4_mmbench_dataloader]
 evaluators = [minigpt_4_mmbench_evaluator]
 load_froms = [minigpt_4_mmbench_load_from]
+
 num_gpus = 8
 num_procs = 8
 launcher = 'pytorch'
\ No newline at end of file
diff --git a/opencompass/datasets/game24.py b/opencompass/datasets/game24.py
index 1730311f1..acfb14abf 100644
--- a/opencompass/datasets/game24.py
+++ b/opencompass/datasets/game24.py
@@ -4,7 +4,6 @@ from typing import List
 
 import pandas as pd
-import sympy
 from datasets import Dataset
 
 from opencompass.openicl.icl_evaluator import BaseEvaluator
@@ -234,6 +233,8 @@ def game24_postprocess(output: str):
 class Game24Evaluator(BaseEvaluator):
 
     def __init__(self) -> None:
+        import sympy
+        self.sympy = sympy
         super().__init__()
 
     def check_nums(self, prediction, reference):
@@ -242,7 +243,7 @@ def check_nums(self, prediction, reference):
         if sorted(numbers) != sorted(problem_numbers):
             return 0
         try:
-            return int(sympy.simplify(prediction) == 24)
+            return int(self.sympy.simplify(prediction) == 24)
         except Exception:
             return 0
diff --git a/opencompass/models/claude_api.py b/opencompass/models/claude_api.py
new file mode 100644
index 000000000..08df7e249
--- /dev/null
+++ b/opencompass/models/claude_api.py
@@ -0,0 +1,118 @@
+from concurrent.futures import ThreadPoolExecutor
+from typing import Dict, List, Optional, Union
+
+from opencompass.registry import MODELS
+from opencompass.utils import PromptList
+
+from .base_api import BaseAPIModel
+
+PromptType = Union[PromptList, str]
+
+
+@MODELS.register_module()
+class Claude(BaseAPIModel):
+    """Model wrapper around Claude API.
+
+    Args:
+        key (str): Authorization key.
+        path (str): The model to be used. Defaults to claude-2.
+        query_per_second (int): The maximum queries allowed per second
+            between two consecutive calls of the API. Defaults to 2.
+        max_seq_len (int): Unused here.
+        meta_template (Dict, optional): The model's meta prompt
+            template if needed, in case meta instructions need to be
+            injected or wrapped. Defaults to None.
+        retry (int): Number of retries if the API call fails. Defaults to 2.
+ """ + + def __init__( + self, + key: str, + path: str = 'claude-2', + query_per_second: int = 2, + max_seq_len: int = 2048, + meta_template: Optional[Dict] = None, + retry: int = 2, + ): + super().__init__(path=path, + max_seq_len=max_seq_len, + query_per_second=query_per_second, + meta_template=meta_template, + retry=retry) + try: + from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic + except ImportError: + raise ImportError('Import anthropic failed. Please install it ' + 'with "pip install anthropic" and try again.') + + self.anthropic = Anthropic(api_key=key) + self.model = path + self.human_prompt = HUMAN_PROMPT + self.ai_prompt = AI_PROMPT + + def generate( + self, + inputs: List[str or PromptList], + max_out_len: int = 512, + ) -> List[str]: + """Generate results given a list of inputs. + + Args: + inputs (List[str or PromptList]): A list of strings or PromptDicts. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + List[str]: A list of generated strings. + """ + with ThreadPoolExecutor() as executor: + results = list( + executor.map(self._generate, inputs, + [max_out_len] * len(inputs))) + return results + + def _generate( + self, + input: str or PromptList, + max_out_len: int = 512, + ) -> str: + """Generate results given an input. + + Args: + inputs (str or PromptList): A string or PromptDict. + The PromptDict should be organized in OpenCompass' + API format. + max_out_len (int): The maximum length of the output. + + Returns: + str: The generated string. + """ + assert isinstance(input, (str, PromptList)) + + if isinstance(input, str): + messages = f'{self.human_prompt} {input}{self.ai_prompt}' + else: + messages = '' + for item in input: + if item['role'] == 'HUMAN' or item['role'] == 'SYSTEM': + messages += f'{self.human_prompt} {item["prompt"]}' + elif item['role'] == 'BOT': + messages += f'{self.ai_prompt} {item["prompt"]}' + if not messages.endswith(self.ai_prompt): + messages += self.ai_prompt + + num_retries = 0 + while num_retries < self.retry: + self.wait() + try: + completion = self.anthropic.completions.create( + model=self.model, + max_tokens_to_sample=max_out_len, + prompt=messages) + return completion.completion + except Exception as e: + self.logger.error(e) + num_retries += 1 + raise RuntimeError('Calling Claude API failed after retrying for ' + f'{self.retry} times. Check the logs for details.') diff --git a/opencompass/models/huggingface.py b/opencompass/models/huggingface.py index fa3d28a1f..a293ff8b7 100644 --- a/opencompass/models/huggingface.py +++ b/opencompass/models/huggingface.py @@ -203,7 +203,9 @@ def _single_generate(self, inputs: List[str], max_out_len: int, max_length=self.max_seq_len - max_out_len)['input_ids'] input_ids = torch.tensor(input_ids, device=self.model.device) - outputs = self.model.generate(input_ids, + # To accommodate the PeftModel, parameters should be passed in + # key-value format for generate. 
+        outputs = self.model.generate(input_ids=input_ids,
                                       max_new_tokens=max_out_len,
                                       **kwargs)
 
diff --git a/opencompass/multimodal/models/__init__.py b/opencompass/multimodal/models/__init__.py
index 724657069..b61e20f0c 100644
--- a/opencompass/multimodal/models/__init__.py
+++ b/opencompass/multimodal/models/__init__.py
@@ -1,8 +1,13 @@
+import os.path as osp
+
 from opencompass.utils import satisfy_requirement
 
 if satisfy_requirement('salesforce-lavis'):
     from .instructblip import *  # noqa: F401, F403
 
+if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
+    from .minigpt_4 import *  # noqa: F401, F403
+
 from .llava import *  # noqa: F401, F403
-from .minigpt_4 import *  # noqa: F401, F403
+from .openflamingo import *  # noqa: F401, F403
 from .visualglm import *  # noqa: F401, F403
diff --git a/opencompass/multimodal/models/openflamingo/__init__.py b/opencompass/multimodal/models/openflamingo/__init__.py
new file mode 100644
index 000000000..a6707eaf0
--- /dev/null
+++ b/opencompass/multimodal/models/openflamingo/__init__.py
@@ -0,0 +1,3 @@
+from .openflamingo import OpenFlamingoInferencer
+
+__all__ = ['OpenFlamingoInferencer']
diff --git a/opencompass/multimodal/models/openflamingo/openflamingo.py b/opencompass/multimodal/models/openflamingo/openflamingo.py
new file mode 100644
index 000000000..a46e7ff0f
--- /dev/null
+++ b/opencompass/multimodal/models/openflamingo/openflamingo.py
@@ -0,0 +1,81 @@
+from typing import List, Optional, Union
+
+import mmengine
+import torch
+from mmpretrain.models.multimodal import Flamingo
+from mmpretrain.structures import DataSample
+
+from opencompass.registry import MM_MODELS
+
+
+@MM_MODELS.register_module('openflamingo')
+class OpenFlamingoInferencer(Flamingo):
+    """Inference code of OpenFlamingo.
+
+    Args:
+        prompt_constructor (dict, optional): The config of the prompt
+            constructor. Defaults to None.
+        post_processor (dict, optional): The config of the post processor.
+            Defaults to None.
+        mode (str): The mode of inference. Defaults to 'generation'.
+    """
+
+    def __init__(self,
+                 prompt_constructor: Optional[dict] = None,
+                 post_processor: Optional[dict] = None,
+                 mode: str = 'generation',
+                 **kwargs):
+        super().__init__(**kwargs)
+        if prompt_constructor is not None:
+            self.prompt_constructor = mmengine.registry.build_from_cfg(
+                prompt_constructor, MM_MODELS)
+        if post_processor is not None:
+            self.post_processor = mmengine.registry.build_from_cfg(
+                post_processor, MM_MODELS)
+        self.mode = mode
+
+    def preprocess_text(self, data_samples: List[DataSample],
+                        device: torch.device) -> List[DataSample]:
+        """Preprocess text before it is fed into the language model.
+
+        Args:
+            data_samples (List[DataSample]): The annotation
+                data of every sample.
+            device (torch.device): Device for text to put on.
+
+        Returns:
+            List[DataSample]: Return list of data samples.
+ """ + prompts = [] + for sample in data_samples: + question = sample.get('question') + option = sample.get('options') + + prompt = '' + question + ' ' + option + ' ' + 'Answer:' + if data_samples[0].get('context') is not None: + prompt = sample.get('context') + ' ' + prompt + + prompts.append(prompt) + + self.tokenizer.padding_side = 'left' + input_text = self.tokenizer( + prompts, + padding='longest', + truncation=True, + return_tensors='pt', + max_length=2000, + ).to(device) + return input_text + + def forward(self, batch: dict) -> Union[DataSample, List[DataSample]]: + + if self.mode == 'generation': + return self.generate(batch) + else: + raise RuntimeError(f'Unsupported mode: {self.mode}') + + def generate(self, batch: dict) -> Union[DataSample, List[DataSample]]: + batch = self.data_preprocessor(batch, False) + images = batch['images'] + data_samples = batch['data_samples'] + return self.predict(images, data_samples)