forked from open-compass/opencompass
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/InternLM/opencompass into a…
…dd_api_qa_tot
- Loading branch information
Showing
12 changed files
with
339 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,5 +10,5 @@ assign: | |
- Leymore | ||
- gaotongxiao | ||
- yingfhu | ||
- Ezra-Yu | ||
- fangyixiao18 | ||
- tonysy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
from mmengine.config import read_base | ||
from opencompass.models.claude_api import Claude | ||
from opencompass.partitioners import NaivePartitioner | ||
from opencompass.runners import LocalRunner | ||
from opencompass.tasks import OpenICLInferTask | ||
|
||
with read_base():
    # Datasets to evaluate on.
    from .datasets.collections.chat_medium import datasets
    # Summarizer that decides the format the results are reported in.
    from .summarizers.medium import summarizer

# Models under evaluation; replace the placeholder key with a real one.
models = [
    dict(
        abbr='Claude2',
        type=Claude,
        path='claude-2',
        key='YOUR_CLAUDE_KEY',
        query_per_second=1,
        max_out_len=2048,
        max_seq_len=2048,
        batch_size=2,
    ),
]

# Inference setup: partitioner, runner and task types as configured below.
infer = dict(
    partitioner=dict(type=NaivePartitioner),
    runner=dict(
        type=LocalRunner,
        max_num_workers=8,
        task=dict(type=OpenICLInferTask),
    ),
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# OpenFlamingo | ||
|
||
### Prepare the environment | ||
|
||
Install [MMPretrain](https://github.com/open-mmlab/mmpretrain) according to this [doc](https://mmpretrain.readthedocs.io/en/latest/get_started.html#installation) | ||
|
||
### Start evaluation | ||
|
||
#### Slurm | ||
|
||
```sh | ||
cd $root | ||
python run.py configs/multimodal/tasks.py --mm-eval --slurm -p $PARTITION | ||
``` | ||
|
||
#### PyTorch | ||
|
||
```sh | ||
cd $root | ||
python run.py configs/multimodal/tasks.py --mm-eval | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# dataloader settings
# Preprocessing steps applied to each sample, in order.
val_pipeline = [
    dict(type='mmpretrain.PILToNumpy'),
    dict(
        type='mmpretrain.ResizeEdge',
        scale=224,
        interpolation='bicubic',
        backend='pillow',
    ),
    dict(type='CenterCrop', crop_size=(224, 224)),
    dict(
        type='mmpretrain.PackInputs',
        algorithm_keys=[
            'question',
            'options',
            'category',
            'l2-category',
            'index',
            'context',
            'options_dict',
        ],
    ),
]

# MMBench test split read from a local TSV file.
dataset = dict(
    type='opencompass.MMBenchDataset',
    data_file='data/mmbench/mmbench_test_20230712.tsv',
    pipeline=val_pipeline,
)

openflamingo_dataloader = dict(
    batch_size=1,
    num_workers=4,
    dataset=dataset,
    sampler=dict(type='DefaultSampler', shuffle=False),
    collate_fn=dict(type='default_collate'),
    persistent_workers=True,
)

# model settings
openflamingo_model = dict(
    type='openflamingo',
    data_preprocessor=dict(
        type='mmpretrain.MultiModalDataPreprocessor',
        mean=[122.770938, 116.7460125, 104.09373615],
        std=[68.5005327, 66.6321579, 70.32316305],
        to_rgb=True,
    ),
    tokenizer=dict(
        type='mmpretrain.LlamaTokenizer',
        name_or_path='decapoda-research/llama-7b-hf',
    ),
    vision_encoder=dict(
        type='mmpretrain.VisionTransformer',
        arch='l',
        patch_size=14,
        pre_norm=True,
        norm_cfg=dict(type='LN', eps=1e-5),
        layer_cfgs=dict(act_cfg=dict(type='mmpretrain.QuickGELU')),
        final_norm=False,
        out_type='raw',
        # Placeholder: point this at the vision encoder checkpoint.
        pretrained='/path/to/vision/encoder',  # noqa
    ),
    lang_encoder=dict(
        base=dict(
            type='mmpretrain.AutoModelForCausalLM',
            name_or_path='decapoda-research/llama-7b-hf',
            local_files_only=True,
        ),
        adapter=dict(
            type='mmpretrain.FlamingoLMAdapter',
            vis_hidden_size=1024,
            cross_attn_every_n_layers=4,
            use_media_placement_augmentation=False,
        ),
    ),
    generation_cfg=dict(
        num_beams=3,
        max_new_tokens=20,
        length_penalty=-2.0,
    ),
)

# evaluation settings
openflamingo_evaluator = [
    dict(
        type='opencompass.DumpResults',
        save_path='work_dirs/9b-flamingo/9b-flamingo-mmbench.xlsx',
    ),
]

# Placeholder: point this at the pretrained OpenFlamingo weights.
openflamingo_load_from = '/path/to/pretrained/weights'  # noqa
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import Dict, List, Optional, Union | ||
|
||
from opencompass.registry import MODELS | ||
from opencompass.utils import PromptList | ||
|
||
from .base_api import BaseAPIModel | ||
|
||
PromptType = Union[PromptList, str] | ||
|
||
|
||
@MODELS.register_module()
class Claude(BaseAPIModel):
    """Model wrapper around the Claude API.

    Args:
        key (str): Authorization key.
        path (str): The model to be used. Defaults to 'claude-2'.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 2.
        max_seq_len (int): Forwarded to the base class; not otherwise
            used by this wrapper. Defaults to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        retry (int): Number of retries if the API call fails.
            Defaults to 2.

    Raises:
        ImportError: If the ``anthropic`` package cannot be imported.
    """

    def __init__(
        self,
        key: str,
        path: str = 'claude-2',
        query_per_second: int = 2,
        max_seq_len: int = 2048,
        meta_template: Optional[Dict] = None,
        retry: int = 2,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         query_per_second=query_per_second,
                         meta_template=meta_template,
                         retry=retry)
        # Imported here so that ``anthropic`` is only required when this
        # wrapper is actually instantiated.
        try:
            from anthropic import AI_PROMPT, HUMAN_PROMPT, Anthropic
        except ImportError as err:
            raise ImportError(
                'Import anthropic failed. Please install it '
                'with "pip install anthropic" and try again.') from err

        self.anthropic = Anthropic(api_key=key)
        self.model = path
        self.human_prompt = HUMAN_PROMPT
        self.ai_prompt = AI_PROMPT

    def generate(
        self,
        inputs: List[PromptType],
        max_out_len: int = 512,
    ) -> List[str]:
        """Generate results given a list of inputs.

        Each input is submitted to ``_generate`` through a thread pool;
        results are returned in the same order as ``inputs``.

        Args:
            inputs (List[str or PromptList]): A list of strings or
                PromptLists. A PromptList should be organized in
                OpenCompass' API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.
        """
        with ThreadPoolExecutor() as executor:
            results = list(
                executor.map(self._generate, inputs,
                             [max_out_len] * len(inputs)))
        return results

    def _generate(
        self,
        input: PromptType,
        max_out_len: int = 512,
    ) -> str:
        """Generate a result given a single input.

        Args:
            input (str or PromptList): A string or PromptList. A
                PromptList should be organized in OpenCompass' API
                format.
            max_out_len (int): The maximum length of the output.

        Returns:
            str: The generated string.

        Raises:
            RuntimeError: If every attempt to call the API fails.
        """
        assert isinstance(input, (str, PromptList))

        if isinstance(input, str):
            messages = f'{self.human_prompt} {input}{self.ai_prompt}'
        else:
            messages = ''
            for item in input:
                # SYSTEM turns are sent as human turns; items with any
                # other role are skipped.
                if item['role'] in ('HUMAN', 'SYSTEM'):
                    messages += f'{self.human_prompt} {item["prompt"]}'
                elif item['role'] == 'BOT':
                    messages += f'{self.ai_prompt} {item["prompt"]}'
            # Make sure the prompt ends with the assistant marker.
            if not messages.endswith(self.ai_prompt):
                messages += self.ai_prompt

        last_error = None
        for _ in range(self.retry):
            # Provided by the base class; presumably enforces the
            # query_per_second limit -- confirm in BaseAPIModel.
            self.wait()
            try:
                completion = self.anthropic.completions.create(
                    model=self.model,
                    max_tokens_to_sample=max_out_len,
                    prompt=messages)
                return completion.completion
            except Exception as err:
                # Any failure counts as one used attempt; log and retry.
                self.logger.error(err)
                last_error = err
        # Chain the last failure so the root cause stays in the traceback.
        raise RuntimeError(
            'Calling Claude API failed after retrying for '
            f'{self.retry} times. Check the logs for details.'
        ) from last_error
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,13 @@ | ||
import os.path as osp

from opencompass.utils import satisfy_requirement

# InstructBLIP is only imported when the 'salesforce-lavis' requirement
# is satisfied.
if satisfy_requirement('salesforce-lavis'):
    from .instructblip import *  # noqa: F401, F403

# MiniGPT-4 is only imported when its checkout exists at this path, which
# is resolved relative to the current working directory. It must not also
# be imported unconditionally, or this guard has no effect.
if osp.exists('opencompass/multimodal/models/minigpt_4/MiniGPT-4'):
    from .minigpt_4 import *  # noqa: F401, F403

from .llava import *  # noqa: F401, F403
from .openflamingo import *  # noqa: F401, F403
from .visualglm import *  # noqa: F401, F403
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .openflamingo import OpenFlamingoInferencer | ||
|
||
__all__ = ['OpenFlamingoInferencer'] |
Oops, something went wrong.