Skip to content

Commit

Permalink
[Feature] Add gpqa prompt from simple_evals, openai (open-compass#1080)
Browse files Browse the repository at this point in the history
* add gpqa_openai_simple_eval

* 触发CI构建

* reorg

---------

Co-authored-by: Leymore <[email protected]>
  • Loading branch information
Francis-llgg and Leymore authored Apr 26, 2024
1 parent f054e24 commit 980f704
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 2 deletions.
2 changes: 1 addition & 1 deletion configs/datasets/gpqa/gpqa_gen.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from mmengine.config import read_base

with read_base():
from .gpqa_gen_4baadb import gpqa_datasets
from .gpqa_openai_simple_evals_gen_5aeece import gpqa_datasets
52 changes: 52 additions & 0 deletions configs/datasets/gpqa/gpqa_openai_simple_evals_gen_5aeece.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GPQADataset_Simple_Eval, GPQA_Simple_Eval_postprocess, GPQAEvaluator

# Prompt wording taken from the openai_simple_eval multiple-choice template.
# The fields {question}, {A}, {B}, {C} and {D} are filled in per sample.
align_prompt = '\n'.join([
    "Answer the following multiple choice question. The last line of your response should be of the following format: 'ANSWER: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.",
    '{question}',
    'A) {A}',
    'B) {B}',
    'C) {C}',
    'D) {D}',
])

# Columns substituted into the prompt, and the column holding the gold label.
gpqa_reader_cfg = {
    'input_columns': ['question', 'A', 'B', 'C', 'D'],
    'output_column': 'answer',
}

# Inference setup: a single HUMAN turn carrying align_prompt, handled by the
# ZeroRetriever / GenInferencer pair.
gpqa_infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': {
            'round': [
                {'role': 'HUMAN', 'prompt': align_prompt},
            ],
        },
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer},
}

# Evaluation setup: predictions go through GPQA_Simple_Eval_postprocess
# before GPQAEvaluator scores them.
gpqa_eval_cfg = {
    'evaluator': {'type': GPQAEvaluator},
    'pred_postprocessor': {'type': GPQA_Simple_Eval_postprocess},
}

gpqa_datasets = []

# Maps subset name -> CSV file name under ./data/gpqa/. Only 'diamond' is
# enabled; uncomment the other entries to evaluate them as well.
gpqa_subsets = {
    # 'extended': 'gpqa_extended.csv',
    # 'main': 'gpqa_main.csv',
    'diamond': 'gpqa_diamond.csv'
}

# One dataset entry per enabled subset; all entries share the same
# reader / inference / evaluation configs.
for split in gpqa_subsets:
    gpqa_datasets.append({
        'abbr': f'GPQA_{split}',
        'type': GPQADataset_Simple_Eval,
        'path': './data/gpqa/',
        'name': gpqa_subsets[split],
        'reader_cfg': gpqa_reader_cfg,
        'infer_cfg': gpqa_infer_cfg,
        'eval_cfg': gpqa_eval_cfg,
    })
54 changes: 53 additions & 1 deletion opencompass/datasets/gpqa.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import csv
import os
import random
import re

from datasets import Dataset

from opencompass.openicl import BaseEvaluator
from opencompass.registry import LOAD_DATASET
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS

from .base import BaseDataset

Expand Down Expand Up @@ -57,3 +59,53 @@ def score(self, predictions, references):
details.append(detail)
result = {'accuracy': 100 * correct / count, 'details': details}
return result


@LOAD_DATASET.register_module()
class GPQADataset_Simple_Eval(BaseDataset):
    """GPQA loader that repeats each question with shuffled option orders.

    Every question read from the CSV is repeated ``n_repeats`` (4) times and
    each copy receives its own option permutation drawn from a fixed-seed
    RNG, so the resulting dataset is identical across runs.
    """

    @staticmethod
    def load(path: str, name: str):
        """Read a GPQA CSV file and build the repeated, shuffled dataset.

        Args:
            path: Directory containing the CSV file.
            name: File name of the CSV inside ``path``.

        Returns:
            A ``Dataset`` built from a list of dicts with the keys
            ``question``, ``answer`` (letter of the correct option after
            shuffling), ``options`` (shuffled option texts), ``permutation``
            and the per-letter columns ``A``, ``B``, ``C``, ``D``.
        """
        n_repeats = 4
        data = []
        with open(os.path.join(path, name), 'r', encoding='utf-8') as f:
            reader = csv.reader(f, delimiter=',')
            for row in reader:
                # csv.reader yields an empty list for blank lines; indexing
                # it below would raise IndexError.
                if not row:
                    continue
                # Skip the header row.
                if row[7] == 'Question':
                    continue
                data.append({
                    'question': row[7],
                    # Placeholder; replaced once the options are shuffled.
                    'answer': 'A',
                    # The first option is the correct one.
                    'options': [row[8], row[9], row[10], row[11]],
                })

        # Fixed seed keeps the shuffles reproducible. Permutations are drawn
        # in list order, one per repeated entry.
        rng = random.Random(0)
        data_list = []
        for item in data * n_repeats:
            permutation = rng.sample(range(4), 4)
            shuffled = [item['options'][i] for i in permutation]
            # Build a fresh dict per copy ({**...} instead of the 3.9-only
            # ``dict | dict``) so repeated entries never share state.
            entry = {**item, 'options': shuffled, 'permutation': permutation}
            entry.update(zip('ABCD', shuffled))
            # Original index 0 is the correct option; its position in the
            # permutation is where it landed after shuffling.
            entry['answer'] = 'ABCD'[permutation.index(0)]
            data_list.append(entry)

        dataset = Dataset.from_list(data_list)
        return dataset


@TEXT_POSTPROCESSORS.register_module()
def GPQA_Simple_Eval_postprocess(text: str) -> str:
    """Extract the answer letter from a model response.

    Searches ``text`` for the first case-insensitive occurrence of
    ``ANSWER: <letter>`` (optional whitespace around the colon) and returns
    the captured letter exactly as it appears in the text.

    Args:
        text: Raw model output.

    Returns:
        The captured letter, or ``None`` when the pattern is absent.
    """
    found = re.search(r'(?i)ANSWER\s*:\s*([A-D])', text)
    return found.group(1) if found else None

0 comments on commit 980f704

Please sign in to comment.