Skip to content

Commit

Permalink
Add GaoKaoMath Dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
liushz committed Oct 8, 2024
1 parent 4d6349d commit 25efa55
Show file tree
Hide file tree
Showing 4 changed files with 268 additions and 0 deletions.
48 changes: 48 additions & 0 deletions configs/datasets/gaokao_math/gaokao_math_gen_9b076f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GaoKaoMATHDataset, GaoKaoMATHEvaluator


# Answer-extraction prompt, written in Chinese. It asks the model to act as a
# math grading expert and return only the key answer found in a response
# (option letters for choice questions, a list of answers for fill-in-the-blank
# and free-response questions, or "[No valid answer]" when the response never
# settles on one answer).
# The three single-brace placeholders -- {question_type}, {question} and
# {response} -- match the reader config's input columns. Literal braces in the
# LaTeX examples are doubled ({{ }}), presumably so that a str.format-style
# fill leaves them intact -- confirm against PromptTemplate.
MATH_CN_PROMPT="""
你是一个数学阅卷专家,任务是从给定的回答句子中提取精确的关键答案。你必须只提供提取的关键答案,不包括任何额外的文字。
我将为你提供一个问题、回答句子和问题类型。回答句子是对所提供问题的回应。利用提供的信息,你必须准确而精确地确定并从回答句子中提取预期的关键答案。请不要对问题发表主观看法。
对于单选题,答案应该是选项字母,例如 "A";
对于多选题,答案应该是一个选项字母的列表,例如 "A" 或 "A", "B", "C";
对于填空题,答案应该是一个填入空白处的答案列表,列表的数量应该与问题中的空白数量相同,例如 ["$$\\frac{{1}}{{2}}$$"] 或 ["$$\\frac{{1}}{{2}}$$", "2"]。
对于问答题,类似填空题,为每个小问抽出相应答案,例如 ["$$\\frac{{1}}{{2}}$$"] 或 ["$$\\frac{{1}}{{2}}$$", "2"]。
如果回答句子提供了多个不同的答案,请仔细判断后面提供的答案是否是对前面答案的修正或修改。如果是这样,提取这个修正或修改后的答案作为最终答案。相反,如果回答句子在多个答案之间波动而没有明确的最终答案,你应该输出 [No valid answer]。
问题类型: {question_type}
原始问题: {question}
回答: {response}
提取的关键答案:
"""

# Columns named as inputs (they match the placeholders in MATH_CN_PROMPT) and
# the column named as the output.
gaokao_math_reader_cfg = {
    'input_columns': ['question', 'response', 'question_type'],
    'output_column': 'extract_answer',
}

# One HUMAN turn carrying the extraction prompt, ZeroRetriever, and a
# GenInferencer limited to max_out_len=512.
gaokao_math_infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': {
            'round': [
                {'role': 'HUMAN', 'prompt': MATH_CN_PROMPT},
            ],
        },
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}

# `url` is handed to GaoKaoMATHEvaluator as its endpoint address.
gaokao_math_eval_cfg = {
    'evaluator': {
        'type': GaoKaoMATHEvaluator,
        'url': 'http://0.0.0.0:23333/v1',
    },
}

# The single dataset entry exported by this config.
gaokao_math_datasets = [
    {
        'type': GaoKaoMATHDataset,
        'abbr': 'GaoKaoMATH',
        'path': './data/gaokao_math/test_2k.json',
        'reader_cfg': gaokao_math_reader_cfg,
        'infer_cfg': gaokao_math_infer_cfg,
        'eval_cfg': gaokao_math_eval_cfg,
    },
]
48 changes: 48 additions & 0 deletions opencompass/configs/datasets/gaokao_math/gaokao_math_gen_9b076f.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GaoKaoMATHDataset, GaoKaoMATHEvaluator


# Answer-extraction prompt, written in Chinese. It asks the model to act as a
# math grading expert and return only the key answer found in a response
# (option letters for choice questions, a list of answers for fill-in-the-blank
# and free-response questions, or "[No valid answer]" when the response never
# settles on one answer).
# The three single-brace placeholders -- {question_type}, {question} and
# {response} -- match the reader config's input columns. Literal braces in the
# LaTeX examples are doubled ({{ }}), presumably so that a str.format-style
# fill leaves them intact -- confirm against PromptTemplate.
MATH_CN_PROMPT="""
你是一个数学阅卷专家,任务是从给定的回答句子中提取精确的关键答案。你必须只提供提取的关键答案,不包括任何额外的文字。
我将为你提供一个问题、回答句子和问题类型。回答句子是对所提供问题的回应。利用提供的信息,你必须准确而精确地确定并从回答句子中提取预期的关键答案。请不要对问题发表主观看法。
对于单选题,答案应该是选项字母,例如 "A";
对于多选题,答案应该是一个选项字母的列表,例如 "A" 或 "A", "B", "C";
对于填空题,答案应该是一个填入空白处的答案列表,列表的数量应该与问题中的空白数量相同,例如 ["$$\\frac{{1}}{{2}}$$"] 或 ["$$\\frac{{1}}{{2}}$$", "2"]。
对于问答题,类似填空题,为每个小问抽出相应答案,例如 ["$$\\frac{{1}}{{2}}$$"] 或 ["$$\\frac{{1}}{{2}}$$", "2"]。
如果回答句子提供了多个不同的答案,请仔细判断后面提供的答案是否是对前面答案的修正或修改。如果是这样,提取这个修正或修改后的答案作为最终答案。相反,如果回答句子在多个答案之间波动而没有明确的最终答案,你应该输出 [No valid answer]。
问题类型: {question_type}
原始问题: {question}
回答: {response}
提取的关键答案:
"""

# Columns named as inputs (they match the placeholders in MATH_CN_PROMPT) and
# the column named as the output.
gaokao_math_reader_cfg = {
    'input_columns': ['question', 'response', 'question_type'],
    'output_column': 'extract_answer',
}

# One HUMAN turn carrying the extraction prompt, ZeroRetriever, and a
# GenInferencer limited to max_out_len=512.
gaokao_math_infer_cfg = {
    'prompt_template': {
        'type': PromptTemplate,
        'template': {
            'round': [
                {'role': 'HUMAN', 'prompt': MATH_CN_PROMPT},
            ],
        },
    },
    'retriever': {'type': ZeroRetriever},
    'inferencer': {'type': GenInferencer, 'max_out_len': 512},
}

# `url` is handed to GaoKaoMATHEvaluator as its endpoint address.
gaokao_math_eval_cfg = {
    'evaluator': {
        'type': GaoKaoMATHEvaluator,
        'url': 'http://0.0.0.0:23333/v1',
    },
}

# The single dataset entry exported by this config.
gaokao_math_datasets = [
    {
        'type': GaoKaoMATHDataset,
        'abbr': 'GaoKaoMATH',
        'path': './data/gaokao_math/test_2k.json',
        'reader_cfg': gaokao_math_reader_cfg,
        'infer_cfg': gaokao_math_infer_cfg,
        'eval_cfg': gaokao_math_eval_cfg,
    },
]
1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .flames import * # noqa: F401, F403
from .flores import * # noqa: F401, F403
from .game24 import * # noqa: F401, F403
from .gaokao_math import * # noqa: F401, F403
from .GaokaoBench import * # noqa: F401, F403
from .govrepcrs import * # noqa: F401, F403
from .gpqa import * # noqa: F401, F403
Expand Down
171 changes: 171 additions & 0 deletions opencompass/datasets/gaokao_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import json
import re
import time
from logging import getLogger

from datasets import Dataset
from openai import OpenAI

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET
from opencompass.utils import get_data_path

from .base import BaseDataset

# Judge prompt, written in Chinese. It asks the model to act as a Gaokao math
# grading expert, decide whether the examinee's answer agrees with the
# standard answer, and finish its analysis with \boxed{yes} or \boxed{no}
# (\boxed{no} when it cannot decide).
# {answer} and {gold_answer} are plain markers substituted with str.replace by
# GaoKaoMATHEvaluator.is_equiv, which is why the braces around "yes"/"no" are
# not doubled.
EVAL_PROMPT = """
请你作为一个数学高考阅卷专家,判断下面的答案是否与标准答案一致,即考生是否回答正确。下面是一些评判标准:
1. 有些答案可能包含多项内容,可能有单选题,多选题,填空题等,只要答案与标准答案一致即可, 对于多选题和多个空的填空题,需要考生对应的选项或空都回答正确才算正确。
2. 有些答案可能通过不同的方式表达,比如有些答案可能是一个数学表达式,有些答案可能是一个文字描述,只要表达的意思一致即可。且有些公式通过不同的方式表达,但等价,也是正确的。
请你根据上述标准,判断下面的答案是否与标准答案一致,如果一致,请在最后输出\\boxed{yes}, 否则输出\\boxed{no}, 如果难以判断,请输出\\boxed{no}.
考生答案:{answer}
标准答案:{gold_answer}
分析:
""" # noqa E501


def extract_boxed_answer(text):
    r"""Return the content of the last non-empty ``\boxed{...}`` in *text*.

    Braces are matched by depth, so nested groups such as
    ``\boxed{\frac{1}{2}}`` are returned whole and a ``}`` appearing later in
    the text is never swallowed into the answer. The last box wins because
    the judge prompt asks for the verdict at the end of the analysis, so
    earlier boxes may be quotes of the instructions or retracted answers.

    Args:
        text (str): Text to scan.

    Returns:
        str | None: The box content with surrounding whitespace stripped, or
        ``None`` when no non-empty, properly closed box exists.
    """
    marker = '\\boxed{'
    start = text.rfind(marker)
    while start != -1:
        body_start = start + len(marker)
        depth = 1
        for pos in range(body_start, len(text)):
            char = text[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    content = text[body_start:pos].strip()
                    if content:
                        return content
                    # Empty box: fall back to an earlier one.
                    break
        # Either unterminated or empty; look for the previous occurrence.
        start = text.rfind(marker, 0, start)
    return None


@LOAD_DATASET.register_module()
class GaoKaoMATHDataset(BaseDataset):
    """Dataset loader for the GaoKao math JSON file."""

    @staticmethod
    def load(path: str):
        """Load the dataset from a JSON file.

        Args:
            path (str): Dataset path; it is resolved through
                ``get_data_path(path, local_mode=True)`` before being opened.

        Returns:
            Dataset: Built with ``Dataset.from_list`` from the parsed JSON,
            so the file is expected to contain a list of records.
        """
        path = get_data_path(path, local_mode=True)
        # Close the handle deterministically and decode as UTF-8 explicitly:
        # the data is Chinese text and must not depend on the platform's
        # default encoding.
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        return Dataset.from_list(data)


class API_Infer:
    """Small client for an OpenAI-compatible chat-completions endpoint."""

    def __init__(self, api_key, url, model_name, temperature, max_tokens):
        """Store the connection and sampling settings.

        Args:
            api_key: API key passed to the ``OpenAI`` client.
            url: Base URL of the endpoint, or a list of base URLs from which
                one is picked at random for every request.
            model_name: Model to query. An empty string means "use the first
                model reported by the endpoint".
            temperature: Sampling temperature forwarded to the API.
            max_tokens: Completion token limit forwarded to the API.
        """
        self.api_key = api_key
        self.url = url
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.SYSTEM = 'You are a helpful assistant.'
        self.logger = getLogger(__name__)

    def _pick_url(self):
        """Return the endpoint for one request without modifying ``self.url``.

        Overwriting ``self.url`` with the chosen entry would pin every later
        request to the same endpoint and defeat the load balancing.
        """
        if isinstance(self.url, list):
            import random
            return random.choice(self.url)
        return self.url

    def openai_infer(self, query: str, retry=9) -> str:
        """Send *query* as a single user message and return the reply text.

        Args:
            query (str): The user message content.
            retry (int): Maximum number of attempts.

        Returns:
            str: The stripped message content of the first choice (an empty
            string if the endpoint returned no content).

        Raises:
            ValueError: If every attempt failed; the last underlying error is
                chained as the cause.
        """
        url = self._pick_url()
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=url,
        )
        self.retry = retry

        started = time.time()
        remaining = self.retry
        last_error = None
        response = ''
        while remaining > 0:
            try:
                # Resolved inside the try block because listing models is a
                # network call that can fail like the completion itself.
                model = self.model_name
                if model == '':
                    model = self.client.models.list().data[0].id
                chat_response = self.client.chat.completions.create(
                    model=model,
                    messages=[
                        {
                            'role': 'system',
                            'content': self.SYSTEM
                        },
                        {
                            'role': 'user',
                            'content': query
                        },
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                js_response = json.loads(chat_response.model_dump_json())
                response = js_response['choices'][0]['message']['content']
                break
            except Exception as e:
                last_error = e
                remaining -= 1
                self.logger.info('Error: %s', e)
                self.logger.info('%s is down. Retrying...', url)
                self.logger.info('Time elapsed: %s seconds',
                                 time.time() - started)
                # No point in waiting once the attempts are used up.
                if remaining > 0:
                    time.sleep(6)
        if remaining == 0:
            response = 'Error: Failed to get response.'
            self.logger.info('%s after %s tries.', response, self.retry)
            raise ValueError('The api is down') from last_error
        # The API may return None as content; treat it as an empty reply.
        return (response or '').strip()


@ICL_EVALUATORS.register_module()
class GaoKaoMATHEvaluator(BaseEvaluator):
    """Score predictions by asking a judge model served behind ``url``.

    Each prediction/reference pair is inserted into ``EVAL_PROMPT`` and sent
    to the endpoint; the pair counts as correct when the reply's boxed
    verdict is ``yes``.
    """

    def __init__(self,
                 url,
                 temperature=1e-6,
                 max_tokens=2048,
                 procs=8,
                 **kwargs):
        """Create the evaluator.

        Args:
            url: Base URL (or list of base URLs) of the judge endpoint.
            temperature: Sampling temperature for the judge model.
            max_tokens: Completion token limit for the judge model.
            procs: Stored on the instance; not used by the code in this class.
            **kwargs: Accepted and ignored.
        """
        # NOTE(review): assumes BaseEvaluator.__init__ takes no required
        # arguments -- confirm against the base class.
        super().__init__()
        self.model = API_Infer('', url, '', temperature, max_tokens)
        self.procs = procs
        # score() logs through self.logger, so make sure it exists.
        self.logger = getLogger(__name__)

    def is_equiv(self, i, j):
        """Judge two parallel sequences of answers.

        Args:
            i: Sequence of predicted answers (strings).
            j: Sequence of reference answers (strings), aligned with ``i``.

        Returns:
            list[int]: One entry per pair, ``1`` when the judge's boxed
            verdict is ``yes`` and ``0`` otherwise.
        """
        judges = []
        for pred, ref in zip(i, j):
            prompt = EVAL_PROMPT.replace('{answer}',
                                         pred).replace('{gold_answer}', ref)
            verdict = self.model.openai_infer(prompt)
            judges.append(1 if extract_boxed_answer(verdict) == 'yes' else 0)
        return judges

    def score(self, predictions, references):
        """Compute accuracy over all prediction/reference pairs.

        Args:
            predictions: Sequence of predicted answers.
            references: Sequence of reference answers.

        Returns:
            dict: ``{'accuracy': percentage, 'details': [...]}``, or an
            ``{'error': ...}`` dict when the two sequences differ in length.
            Empty input yields an accuracy of ``0.0``.
        """
        if len(predictions) != len(references):
            return {'error': 'preds and refrs have different length'}

        # is_equiv works on whole sequences. Calling it with a single pair of
        # strings would zip them character by character and return a list
        # that is truthy whenever it is non-empty.
        results = self.is_equiv(predictions, references)

        details = []
        correct = 0
        for pred, ref, result in zip(predictions, references, results):
            is_correct = bool(result)
            if is_correct:
                correct += 1
            details.append({
                'pred': pred,
                'answer': ref,
                'correct': is_correct
            })

        count = len(details)
        detailed_result = {
            # Guard against division by zero on empty input.
            'accuracy': 100 * correct / count if count else 0.0,
            'details': details
        }
        self.logger.info(json.dumps(detailed_result, indent=4))
        return detailed_result


if __name__ == '__main__':
    # Manual smoke test: three identical prediction/reference pairs are sent
    # to the hard-coded judge endpoint below.
    demo_predictions = ['1', '2', '3']
    demo_references = ['1', '2', '3']
    demo_evaluator = GaoKaoMATHEvaluator(
        'http://22.8.75.210:23333/v1',
        temperature=0.01,
        max_tokens=2048,
        procs=8,
    )
    demo_evaluator.score(demo_predictions, demo_references)

0 comments on commit 25efa55

Please sign in to comment.