diff --git a/configs/datasets/subjective/alpaca_eval/alpacav1_judgeby_gpt4.py b/configs/datasets/subjective/alpaca_eval/alpacav1_judgeby_gpt4.py
index acd4ab27e..b5f89cd82 100644
--- a/configs/datasets/subjective/alpaca_eval/alpacav1_judgeby_gpt4.py
+++ b/configs/datasets/subjective/alpaca_eval/alpacav1_judgeby_gpt4.py
@@ -90,7 +90,7 @@
         dict(
             abbr=f"{_name}",
             type=SubjectiveCmpDataset,
-            path="./data/subjective/",
+            path="./data/subjective/alpaca_eval",
             name=_name,
             reader_cfg=subjective_reader_cfg,
             infer_cfg=subjective_infer_cfg,
diff --git a/configs/datasets/subjective/alpaca_eval/alpacav2_judgeby_gpt4.py b/configs/datasets/subjective/alpaca_eval/alpacav2_judgeby_gpt4.py
index 0e3255288..b45b7622a 100644
--- a/configs/datasets/subjective/alpaca_eval/alpacav2_judgeby_gpt4.py
+++ b/configs/datasets/subjective/alpaca_eval/alpacav2_judgeby_gpt4.py
@@ -92,7 +92,7 @@
         dict(
             abbr=f"{_name}",
             type=SubjectiveCmpDataset,
-            path="./data/subjective/",
+            path="./data/subjective/alpaca_eval",
             name=_name,
             reader_cfg=subjective_reader_cfg,
             infer_cfg=subjective_infer_cfg,
diff --git a/configs/eval_subjective_alpacaeval.py b/configs/eval_subjective_alpacaeval.py
index 13fd5ebe5..0d27b4c88 100644
--- a/configs/eval_subjective_alpacaeval.py
+++ b/configs/eval_subjective_alpacaeval.py
@@ -1,7 +1,6 @@
 from mmengine.config import read_base
 
 with read_base():
-    from .datasets.subjective.alpaca_eval.alpacav1_judgeby_gpt4 import subjective_datasets as alpacav1
     from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2
 
 from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
@@ -12,7 +11,7 @@
 from opencompass.runners import LocalRunner
 from opencompass.runners import SlurmSequentialRunner
 from opencompass.tasks import OpenICLInferTask
-from opencompass.tasks.subjective_eval import SubjectiveEvalTask
+from opencompass.tasks.outer_eval.alpacaeval import AlpacaEvalTask
 from opencompass.summarizers import AlpacaSummarizer
 
 api_meta_template = dict(
@@ -29,7 +28,7 @@
 models = [
     dict(
         type=HuggingFaceChatGLM3,
-        abbr='chatglm3-6b-hf',
+        abbr='chatglm3-6b',
         path='THUDM/chatglm3-6b',
         tokenizer_path='THUDM/chatglm3-6b',
         model_kwargs=dict(
@@ -54,52 +53,25 @@
 
 datasets = [*alpacav2]
 
-gpt4 = dict(
-    abbr='gpt4-turbo',
-    type=OpenAI,
-    path='gpt-4-1106-preview',
-    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
-    meta_template=api_meta_template,
-    query_per_second=1,
-    max_out_len=2048,
-    max_seq_len=4096,
-    batch_size=4,
-    retry=20,
-    temperature=1,
-)  # Re-inference gpt4's predictions or you can choose to use the pre-commited gpt4's predictions
-
-
-
 # -------------Evalation Stage ----------------------------------------
 
 ## ------------- JudgeLLM Configuration
-judge_model = dict(
+gpt4_judge = dict(
     abbr='GPT4-Turbo',
-    type=OpenAI,
     path='gpt-4-1106-preview',
     key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
-    meta_template=api_meta_template,
-    query_per_second=1,
-    max_out_len=1024,
-    max_seq_len=4096,
-    batch_size=2,
-    retry=20,
-    temperature=0,
+    config='weighted_alpaca_eval_gpt4_turbo'
 )
-
 ## ------------- Evaluation Configuration
 eval = dict(
     partitioner=dict(
-        type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
+        type=NaivePartitioner
     ),
     runner=dict(
-        type=SlurmSequentialRunner,
-        partition='llmeval',
-        quotatype='auto',
+        type=LocalRunner,
         max_num_workers=256,
-        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
-    ),
+        task=dict(type=AlpacaEvalTask, judge_cfg=gpt4_judge),
+    )
 )
 work_dir = 'outputs/alpaca/'
 
-summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
diff --git a/configs/eval_subjective_alpacaeval_oc.py b/configs/eval_subjective_alpacaeval_oc.py
new file mode 100644
index 000000000..13d1971b8
--- /dev/null
+++ b/configs/eval_subjective_alpacaeval_oc.py
@@ -0,0 +1,105 @@
+from mmengine.config import read_base
+
+with read_base():
+    from .datasets.subjective.alpaca_eval.alpacav1_judgeby_gpt4 import subjective_datasets as alpacav1
+    from .datasets.subjective.alpaca_eval.alpacav2_judgeby_gpt4 import subjective_datasets as alpacav2
+
+from opencompass.models import HuggingFaceCausalLM, HuggingFace, HuggingFaceChatGLM3
+from opencompass.models.openai_api import OpenAI, OpenAIAllesAPIN
+from opencompass.partitioners import NaivePartitioner, SizePartitioner
+from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
+from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
+from opencompass.runners import LocalRunner
+from opencompass.runners import SlurmSequentialRunner
+from opencompass.tasks import OpenICLInferTask
+from opencompass.tasks.subjective_eval import SubjectiveEvalTask
+from opencompass.summarizers import AlpacaSummarizer
+
+api_meta_template = dict(
+    round=[
+        dict(role='HUMAN', api_role='HUMAN'),
+        dict(role='BOT', api_role='BOT', generate=True),
+    ],
+    reserved_roles=[dict(role='SYSTEM', api_role='SYSTEM')],
+)
+
+# -------------Inference Stage ----------------------------------------
+
+# For subjective evaluation, we often set do sample for models
+models = [
+    dict(
+        type=HuggingFaceChatGLM3,
+        abbr='chatglm3-6b-hf',
+        path='THUDM/chatglm3-6b',
+        tokenizer_path='THUDM/chatglm3-6b',
+        model_kwargs=dict(
+            device_map='auto',
+            trust_remote_code=True,
+        ),
+        tokenizer_kwargs=dict(
+            padding_side='left',
+            truncation_side='left',
+            trust_remote_code=True,
+        ),
+        generation_kwargs=dict(
+            do_sample=True,
+        ),
+        meta_template=api_meta_template,
+        max_out_len=2048,
+        max_seq_len=4096,
+        batch_size=1,
+        run_cfg=dict(num_gpus=1, num_procs=1),
+    )
+]
+
+datasets = [*alpacav2]
+
+gpt4 = dict(
+    abbr='gpt4-turbo',
+    type=OpenAI,
+    path='gpt-4-1106-preview',
+    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+    meta_template=api_meta_template,
+    query_per_second=1,
+    max_out_len=2048,
+    max_seq_len=4096,
+    batch_size=4,
+    retry=20,
+    temperature=1,
+)  # Re-inference gpt4's predictions or you can choose to use the pre-commited gpt4's predictions
+
+
+
+# -------------Evalation Stage ----------------------------------------
+
+## ------------- JudgeLLM Configuration
+judge_model = dict(
+    abbr='GPT4-Turbo',
+    type=OpenAI,
+    path='gpt-4-1106-preview',
+    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+    meta_template=api_meta_template,
+    query_per_second=1,
+    max_out_len=1024,
+    max_seq_len=4096,
+    batch_size=2,
+    retry=20,
+    temperature=0,
+)
+
+## ------------- Evaluation Configuration
+eval = dict(
+    partitioner=dict(
+        type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
+    ),
+    runner=dict(
+        type=SlurmSequentialRunner,
+        partition='llmeval',
+        quotatype='auto',
+        max_num_workers=256,
+        task=dict(type=SubjectiveEvalTask, judge_cfg=judge_model),
+    ),
+)
+work_dir = 'outputs/alpaca/'
+
+summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
\ No newline at end of file
diff --git a/configs/subjective/eval_subjective_alpacaeval.py b/configs/subjective/eval_subjective_alpacaeval.py
index 6d9a1b884..31094c18b 100644
--- a/configs/subjective/eval_subjective_alpacaeval.py
+++ b/configs/subjective/eval_subjective_alpacaeval.py
@@ -7,14 +7,24 @@
 from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
 from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
 from opencompass.summarizers import AlpacaSummarizer
+from opencompass.tasks.outer_eval.alpacaeval import AlpacaEvalTask
 
 datasets = [*alpacav2]
+
+gpt4_judge = dict(
+    abbr='GPT4-Turbo',
+    path='gpt-4-1106-preview',
+    key='',  # The key will be obtained from $OPENAI_API_KEY, but you can write down your key here as well
+    config='weighted_alpaca_eval_gpt4_turbo'
+)
+## ------------- Evaluation Configuration
 eval = dict(
     partitioner=dict(
-        type=SubjectiveSizePartitioner, max_task_size=1000, mode='m2n', base_models=[gpt4], compare_models=models
+        type=NaivePartitioner
     ),
-    runner=runner,
-    given_pred=given_pred
+    runner=dict(
+        type=LocalRunner,
+        max_num_workers=256,
+        task=dict(type=AlpacaEvalTask, judge_cfg=gpt4_judge),
+    )
 )
 work_dir = 'outputs/alpaca/'
-summarizer = dict(type=AlpacaSummarizer, judge_type='v2')
diff --git a/docs/en/advanced_guides/subjective_evaluation.md b/docs/en/advanced_guides/subjective_evaluation.md
index b4304bcee..dda9d6f5f 100644
--- a/docs/en/advanced_guides/subjective_evaluation.md
+++ b/docs/en/advanced_guides/subjective_evaluation.md
@@ -13,6 +13,13 @@ A popular evaluation method involves
 
 We support the use of GPT-4 (or other JudgeLLM) for the subjective evaluation of models based on above methods.
 
+## Currently Supported Subjective Evaluation Datasets
+
+1. AlignBench (https://github.com/THUDM/AlignBench)
+2. MTBench (https://github.com/lm-sys/FastChat)
+3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
+4. CompassArena (Internal dataset)
+
 ## Subjective Evaluation with Custom Dataset
 
 The specific process includes:
diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md
index af4532c54..13cc00f70 100644
--- a/docs/en/get_started/installation.md
+++ b/docs/en/get_started/installation.md
@@ -72,6 +72,19 @@
+5. Install alpaca-eval (Optional):
+
+   If you want to **evaluate alpaca-eval in the official way**, follow this step.
+
+   <details>
+   <summary>click to show the details</summary>
+
+   ```bash
+   pip install alpaca-eval
+   ```
+
+   </details>
+
 # Dataset Preparation
 
 The datasets supported by OpenCompass mainly include two parts:
diff --git a/docs/zh_cn/advanced_guides/subjective_evaluation.md b/docs/zh_cn/advanced_guides/subjective_evaluation.md
index f59d8135e..565b03ea5 100644
--- a/docs/zh_cn/advanced_guides/subjective_evaluation.md
+++ b/docs/zh_cn/advanced_guides/subjective_evaluation.md
@@ -13,6 +13,13 @@
 
 我们基于以上方法支持了JudgeLLM用于模型的主观能力评估（目前opencompass仓库里支持的所有模型都可以直接作为JudgeLLM进行调用，此外一些专用的JudgeLLM我们也在计划支持中）。
 
+## 目前已支持的主观评测数据集
+
+1. AlignBench (https://github.com/THUDM/AlignBench)
+2. MTBench (https://github.com/lm-sys/FastChat)
+3. AlpacaEvalv2 (https://github.com/tatsu-lab/alpaca_eval)
+4. CompassArena (内部数据集)
+
 ## 自定义主观数据集评测
 
 主观评测的具体流程包括:
diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md
index e87189608..088ad6368 100644
--- a/docs/zh_cn/get_started/installation.md
+++ b/docs/zh_cn/get_started/installation.md
@@ -73,6 +73,19 @@
+5. 安装 alpaca-eval（可选）：
+
+   如果你需要**使用官方alpaca-eval实现评测 alpaca-eval 数据集**，请执行此步骤，否则忽略这一步。
+
+   <details>
+   <summary>点击查看详细</summary>
+
+   ```bash
+   pip install alpaca-eval
+   ```
+
+   </details>
+
 # 数据集准备
 
 OpenCompass 支持的数据集主要包括两个部分：
diff --git a/opencompass/models/openai_api.py b/opencompass/models/openai_api.py
index 251e97972..a36670942 100644
--- a/opencompass/models/openai_api.py
+++ b/opencompass/models/openai_api.py
@@ -65,6 +65,8 @@ def __init__(self,
                  meta_template: Optional[Dict] = None,
                  openai_api_base: str = OPENAI_API_BASE,
                  mode: str = 'none',
+                 logprobs: Optional[bool] = False,
+                 top_logprobs: Optional[int] = None,
                  temperature: Optional[float] = None):
 
         super().__init__(path=path,
@@ -78,6 +80,8 @@ def __init__(self,
         self.temperature = temperature
         assert mode in ['none', 'front', 'mid', 'rear']
         self.mode = mode
+        self.logprobs = logprobs
+        self.top_logprobs = top_logprobs
 
         if isinstance(key, str):
             self.keys = [os.getenv('OPENAI_API_KEY') if key == 'ENV' else key]
@@ -218,6 +222,8 @@ def _generate(self, input: str or PromptList, max_out_len: int,
                 messages=messages,
                 max_tokens=max_out_len,
                 n=1,
+                logprobs=self.logprobs,
+                top_logprobs=self.top_logprobs,
                 stop=None,
                 temperature=temperature,
             )
@@ -234,7 +240,10 @@ def _generate(self, input: str or PromptList, max_out_len: int,
                                   str(raw_response.content))
                 continue
             try:
-                return response['choices'][0]['message']['content'].strip()
+                if self.logprobs:
+                    return response['choices']
+                else:
+                    return response['choices'][0]['message']['content'].strip()
             except KeyError:
                 if 'error' in response:
                     if response['error']['code'] == 'rate_limit_exceeded':
diff --git a/opencompass/tasks/outer_eval/alpacaeval.py b/opencompass/tasks/outer_eval/alpacaeval.py
new file mode 100644
index 000000000..4e7c7146f
--- /dev/null
+++ b/opencompass/tasks/outer_eval/alpacaeval.py
@@ -0,0 +1,128 @@
+# flake8: noqa: E501
+import copy
+import json
+import os.path as osp
+
+import mmengine
+from mmengine.config import Config, ConfigDict
+
+from opencompass.tasks.base import BaseTask
+from opencompass.utils import (build_dataset_from_cfg, get_infer_output_path,
+                               get_logger)
+
+
+class PredictionMerger:
+    """"""
+
+    def __init__(self, cfg: ConfigDict) -> None:
+
+        self.cfg = cfg
+        self.model_cfg = copy.deepcopy(self.cfg['model'])
+        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
+
+        self.work_dir = self.cfg.get('work_dir')
+
+    def run(self):
+        filename = get_infer_output_path(
+            self.model_cfg, self.dataset_cfg,
+            osp.join(self.work_dir, 'predictions'))
+        root, ext = osp.splitext(filename)
+        partial_filename = root + '_0' + ext
+
+        if osp.exists(osp.realpath(filename)):
+            return
+
+        if not osp.exists(osp.realpath(partial_filename)):
+            print(f'{filename} not found')
+            return
+
+        # Load predictions
+        partial_filenames = []
+        if osp.exists(osp.realpath(filename)):
+            preds = mmengine.load(filename)
+        else:
+            preds, offset = {}, 0
+            i = 1
+            while osp.exists(osp.realpath(partial_filename)):
+                partial_filenames.append(osp.realpath(partial_filename))
+                _preds = mmengine.load(partial_filename)
+                partial_filename = root + f'_{i}' + ext
+                i += 1
+                for _o in range(len(_preds)):
+                    preds[str(offset)] = _preds[str(_o)]
+                    offset += 1
+
+        dataset = build_dataset_from_cfg(self.dataset_cfg)
+        if len(preds) != len(dataset.test):
+            print('length mismatch')
+            return
+
+        with open(
+                osp.realpath(osp.join(self.dataset_cfg['path'],
+                                      'example.json')), 'r') as f:
+            data_format = json.load(f)
+
+        for idx in range(len(preds)):
+            data_format[idx]['output'] = preds[str(idx)]['prediction']
+            data_format[idx]['generator'] = self.model_cfg['abbr']
+
+        print(f'Merge {partial_filenames} to {filename}')
+        with open(filename, 'w', encoding='utf-8') as f:
+            json.dump(data_format, f, indent=4, ensure_ascii=False)
+
+
+class AlpacaEvalTask(BaseTask):
+    """Subjective Evaluation Task.
+
+    This task is used to evaluate the metric between predictions and
+    references.
+
+    Args:
+        cfg (ConfigDict): The configuration of the entire evaluation task.
+    """
+
+    name_prefix = 'SubjectiveEval'
+    log_subdir = 'logs/eval'
+    output_subdir = 'results'
+
+    def __init__(self, cfg: ConfigDict):
+        super().__init__(cfg)
+        self.logger = get_logger()
+        judge_cfg = cfg.eval.runner.task.get('judge_cfg', {})
+        assert type(judge_cfg) == ConfigDict
+        run_cfg = judge_cfg.get('run_cfg', {})
+        self.num_gpus = run_cfg.get('num_gpus', 0)
+        self.num_procs = run_cfg.get('num_procs', 1)
+        self.judge_cfg = copy.deepcopy(judge_cfg)
+
+    def get_command(self, cfg_path, template):
+        """Get the command template for the task.
+
+        Args:
+            cfg_path (str): The path to the config file of the task.
+            template (str): The template which have '{task_cmd}' to format
+                the command.
+        """
+        # script_path = __file__
+        alpaca_cfg = self.judge_cfg.get('config', None)
+        api_key = self.judge_cfg.get('key', None)
+        assert alpaca_cfg is not None
+        all_cfg = Config.fromfile(cfg_path)
+        model_cfg = all_cfg['models']
+        dataset_cfg = all_cfg['datasets'][0][0]
+        work_dir = osp.realpath(all_cfg['work_dir'])
+        for m_cfg in model_cfg:
+            PredictionMerger({
+                'model': m_cfg,
+                'dataset': dataset_cfg,
+                'work_dir': work_dir
+            }).run()
+            filename = get_infer_output_path(m_cfg, dataset_cfg,
+                                             osp.join(work_dir, 'predictions'))
+            output_path = osp.join(work_dir, 'results', m_cfg['abbr'])
+            command = f'export OPENAI_API_KEY={api_key}; alpaca_eval --model_outputs {filename} --annotators_config {alpaca_cfg} --output_path {output_path}'
+        return template.format(task_cmd=command)
+
+    def run(self):
+        # model_cfg can be a list of model configs
+        pass
diff --git a/opencompass/tasks/subjective_eval.py b/opencompass/tasks/subjective_eval.py
index 08e37b5ba..30847f790 100644
--- a/opencompass/tasks/subjective_eval.py
+++ b/opencompass/tasks/subjective_eval.py
@@ -132,30 +132,29 @@ def _load_model_pred(
         # Get partition name
         root, ext = osp.splitext(filename)
         partial_filename = root + '_0' + ext
-        # If no predictions get in predictions dir
-        if not osp.exists(osp.realpath(filename)) and not osp.exists(
-                osp.realpath(partial_filename)):
-            return {'error': 'No predictions found.'}
+        assert osp.exists(filename) or osp.exists(
+            osp.realpath(partial_filename)
+        ), 'No predictions found for {filename}.'.format(filename=filename)
+
+        # If use Naive partition in infer stage
+        if osp.exists(osp.realpath(filename)):
+            preds = mmengine.load(filename)
+            pred_strs = [
+                preds[str(i)]['prediction'] for i in range(len(preds))
+            ]
+        # If use Size partition in infer stage
         else:
-            # If use Naive partition in infer stage
-            if osp.exists(osp.realpath(filename)):
+            filename = partial_filename
+            pred_strs = []
+            i = 1
+            while osp.exists(osp.realpath(filename)):
                 preds = mmengine.load(filename)
-                pred_strs = [
+                filename = root + f'_{i}' + ext
+                i += 1
+                pred_strs += [
                     preds[str(i)]['prediction'] for i in range(len(preds))
                 ]
-            # If use Size partition in infer stage
-            else:
-                filename = partial_filename
-                pred_strs = []
-                i = 1
-                while osp.exists(osp.realpath(filename)):
-                    preds = mmengine.load(filename)
-                    filename = root + f'_{i}' + ext
-                    i += 1
-                    pred_strs += [
-                        preds[str(i)]['prediction'] for i in range(len(preds))
-                    ]
 
         # Get all predictions in pred_strs
         # If take SubjectSizePartition, get new pred_strs based on test_range
diff --git a/requirements/extra.txt b/requirements/extra.txt
index f5f709ce3..6b3409e74 100644
--- a/requirements/extra.txt
+++ b/requirements/extra.txt
@@ -1 +1,2 @@
+alpaca-eval
 faiss_gpu==1.7.2
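
As a rough illustration of the flow this patch wires up, here is a minimal launch sketch. It assumes the standard OpenCompass `python run.py <config>` entry point, that `alpaca-eval` has been installed as shown in the installation docs above, and that `OPENAI_API_KEY` is exported; the work directory and placeholder paths are illustrative, not prescribed by the patch.

```bash
# Sketch only: run inference on AlpacaEval v2 and judge with the official
# alpaca_eval pipeline via the new AlpacaEvalTask.
export OPENAI_API_KEY=sk-...   # consumed by the judge command AlpacaEvalTask builds

python run.py configs/eval_subjective_alpacaeval.py -w outputs/alpaca

# AlpacaEvalTask merges partial prediction files into one JSON per model and
# then shells out to roughly:
#   alpaca_eval --model_outputs <work_dir>/predictions/<model>/<dataset>.json \
#               --annotators_config weighted_alpaca_eval_gpt4_turbo \
#               --output_path <work_dir>/results/<model>
```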