From 0991dd33a0161c3e713abca4007553c7378833b7 Mon Sep 17 00:00:00 2001 From: Fengzhe Zhou Date: Wed, 24 Jan 2024 16:30:32 +0800 Subject: [PATCH] [Sync] Updata dataset cfg for internMath (#837) Co-authored-by: liuhongwei --- configs/datasets/gsm8k/gsm8k_gen_701491.py | 33 ++ configs/datasets/math/math_gen_0957ff.py | 68 ++++ configs/datasets/math/math_gen_736506.py | 28 ++ opencompass/datasets/__init__.py | 1 + opencompass/datasets/math_intern.py | 342 +++++++++++++++++++++ opencompass/utils/text_postprocessors.py | 1 + 6 files changed, 473 insertions(+) create mode 100644 configs/datasets/gsm8k/gsm8k_gen_701491.py create mode 100644 configs/datasets/math/math_gen_0957ff.py create mode 100644 configs/datasets/math/math_gen_736506.py create mode 100644 opencompass/datasets/math_intern.py diff --git a/configs/datasets/gsm8k/gsm8k_gen_701491.py b/configs/datasets/gsm8k/gsm8k_gen_701491.py new file mode 100644 index 000000000..ebdbd6d3d --- /dev/null +++ b/configs/datasets/gsm8k/gsm8k_gen_701491.py @@ -0,0 +1,33 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.openicl.icl_evaluator import AccEvaluator +from opencompass.datasets import GSM8KDataset, gsm8k_postprocess, gsm8k_dataset_postprocess, Gsm8kEvaluator + +gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer') + +gsm8k_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict( + round=[ + dict(role='HUMAN', prompt="Question: {question}\nLet's think step by step\nAnswer:") + ], + )), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +gsm8k_eval_cfg = dict(evaluator=dict(type=Gsm8kEvaluator), + pred_role="BOT", + pred_postprocessor=dict(type=gsm8k_postprocess), + dataset_postprocessor=dict(type=gsm8k_dataset_postprocess)) + +gsm8k_datasets = [ + dict( + abbr='gsm8k', + type=GSM8KDataset, + path='./data/gsm8k', + reader_cfg=gsm8k_reader_cfg, + infer_cfg=gsm8k_infer_cfg, + eval_cfg=gsm8k_eval_cfg) +] diff --git a/configs/datasets/math/math_gen_0957ff.py b/configs/datasets/math/math_gen_0957ff.py new file mode 100644 index 000000000..1b8561d23 --- /dev/null +++ b/configs/datasets/math/math_gen_0957ff.py @@ -0,0 +1,68 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess + +math_reader_cfg = dict(input_columns=['problem'], output_column='solution') + +math_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict( + role="HUMAN", + prompt= + "Problem:\nFind the domain of the expression $\\frac{\sqrt{x-2}}{\sqrt{5-x}}$.}\nSolution:" + ), + dict( + role="BOT", + prompt= + "The expressions inside each square root must be non-negative. Therefore, $x-2 \ge 0$, so $x\ge2$, and $5 - x \ge 0$, so $x \le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.\n" + ), + dict( + role="HUMAN", + prompt= + "Problem:\nIf $\det \mathbf{A} = 2$ and $\det \mathbf{B} = 12,$ then find $\det (\mathbf{A} \mathbf{B}).$\nSolution:" + ), + dict( + role="BOT", + prompt= + "We have that $\det (\mathbf{A} \mathbf{B}) = (\det \mathbf{A})(\det \mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.\n" + ), + dict( + role="HUMAN", + prompt= + "Problem:\nTerrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?\nSolution:" + ), + dict( + role="BOT", + prompt= + "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\cdot 12\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\cdot15\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \Rightarrow\qquad n&=480/30=\\boxed{16} \end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.\n" + ), + dict( + role="HUMAN", + prompt= + "Problem:\nIf the system of equations: \\begin{align*} 6x-4y&=a,\\\\ 6y-9x &=b. \end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.\nSolution:" + ), + dict( + role="BOT", + prompt= + "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.\n" + ), + dict(role="HUMAN", prompt="Problem:\n{problem}\nSolution:\n"), + ])), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +math_eval_cfg = dict( + evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess)) + +math_datasets = [ + dict( + type=MATHDataset, + abbr='math', + path='./data/math/math.json', + reader_cfg=math_reader_cfg, + infer_cfg=math_infer_cfg, + eval_cfg=math_eval_cfg) +] diff --git a/configs/datasets/math/math_gen_736506.py b/configs/datasets/math/math_gen_736506.py new file mode 100644 index 000000000..e68c4146e --- /dev/null +++ b/configs/datasets/math/math_gen_736506.py @@ -0,0 +1,28 @@ +from opencompass.openicl.icl_prompt_template import PromptTemplate +from opencompass.openicl.icl_retriever import ZeroRetriever +from opencompass.openicl.icl_inferencer import GenInferencer +from opencompass.datasets import MATHInternDataset, MATHInternEvaluator, math_intern_postprocess + +math_reader_cfg = dict(input_columns=['problem'], output_column='solution') + +math_infer_cfg = dict( + prompt_template=dict( + type=PromptTemplate, + template=dict(round=[ + dict(role='HUMAN', prompt="Question: {problem}\nLet's think step by step\nAnswer:") + ])), + retriever=dict(type=ZeroRetriever), + inferencer=dict(type=GenInferencer, max_out_len=512)) + +math_eval_cfg = dict( + evaluator=dict(type=MATHInternEvaluator), pred_postprocessor=dict(type=math_intern_postprocess)) + +math_datasets = [ + dict( + type=MATHInternDataset, + abbr='math', + path='./data/math/math.json', + reader_cfg=math_reader_cfg, + infer_cfg=math_infer_cfg, + eval_cfg=math_eval_cfg) +] \ No newline at end of file diff --git a/opencompass/datasets/__init__.py b/opencompass/datasets/__init__.py index b60d03cef..48eeecb5e 100644 --- a/opencompass/datasets/__init__.py +++ b/opencompass/datasets/__init__.py @@ -61,6 +61,7 @@ from .mastermath2024v1 import * # noqa: F401, F403 from .math import * # noqa: F401, F403 from .math401 import * # noqa: F401, F403 +from .math_intern import * # noqa: F401, F403 from .mathbench import * # noqa: F401, F403 from .mbpp import * # noqa: F401, F403 from .medbench import * # noqa: F401, F403 diff --git a/opencompass/datasets/math_intern.py b/opencompass/datasets/math_intern.py new file mode 100644 index 000000000..7d5a6b6e7 --- /dev/null +++ b/opencompass/datasets/math_intern.py @@ -0,0 +1,342 @@ +import json +import re + +from datasets import Dataset, DatasetDict + +from opencompass.openicl.icl_evaluator import BaseEvaluator +from opencompass.registry import (ICL_EVALUATORS, LOAD_DATASET, + TEXT_POSTPROCESSORS) + +from .base import BaseDataset + + +def last_boxed_only_string(string): + idx = string.rfind('\\boxed') + if idx < 0: + idx = string.rfind('\\fbox') + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == '{': + num_left_braces_open += 1 + if string[i] == '}': + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx:right_brace_idx + 1] + + return retval + + +def remove_boxed(s): + left = '\\boxed{' + try: + assert s[:len(left)] == left + assert s[-1] == '}' + return s[len(left):-1] + except Exception: + return None + + +def extract_boxed_answer(pred_str, strip_double_curly_brace=False): + boxed_str = last_boxed_only_string(pred_str) + if boxed_str is None: + return None + answer = remove_boxed(boxed_str) + if answer is None: + return None + if strip_double_curly_brace: + match = re.match('^\{(.*)\}$', answer) # noqa: W605 + if match: + answer = match.group(1) + return answer + + +@LOAD_DATASET.register_module() +class MATHInternDataset(BaseDataset): + + @staticmethod + def load(path: str): + dataset = DatasetDict() + data = json.load(open(path)) + raw_data = [] + for i in data.keys(): + raw_data.append({ + 'problem': + data[i]['problem'], + 'solution': + extract_boxed_answer(data[i]['solution']) + }) + dataset['test'] = Dataset.from_list(raw_data) + dataset['train'] = Dataset.from_list(raw_data) + return dataset + + +@ICL_EVALUATORS.register_module() +class MATHInternEvaluator(BaseEvaluator): + + def score(self, predictions, references): + if len(predictions) != len(references): + return { + 'error': 'predictions and references have different ' + 'length' + } + correct = 0 + count = 0 + details = [] + for i, j in zip(predictions, references): + detail = {'pred': i, 'answer': j, 'correct': False} + count += 1 + if is_equiv(i, j): + correct += 1 + detail['correct'] = True + details.append(detail) + result = {'accuracy': 100 * correct / count, 'details': details} + return result + + +@TEXT_POSTPROCESSORS.register_module('math_intern_postprocess') +def math_intern_postprocess(text: str) -> str: + extractor = Extractor() + return extractor.extract_answer(text) + + +class Extractor: + + def extract_matching_bracket(cls, target_str: str): + if not target_str: + return target_str + current_nest_level = 1 + for i, ch in enumerate(target_str): + if ch == '{': + current_nest_level += 1 + elif ch == '}': + current_nest_level -= 1 + if current_nest_level == 0: + break + return target_str[:i] + + def clean(cls, target_str: str): + opt = target_str.strip().replace('{{', '{').replace('}}', '}') + if not opt: + return opt + if opt[-1] == '.' or opt[-1] == '。': + return opt[:-1] + return opt + + def extract_answer(cls, pred: str, extract_last_num=False): + if pred.find('The final answer is ') >= 0: + x = pred[pred.find('The final answer is ') + + len('The final answer is '):] + x = x[1:x.find('$.')] + # print(x) + return cls.clean(x) + if pred.find('\n\nQuestion:') >= 0: + pred = pred.split('\n\nQuestion:')[0] + if pred.find('The answer is'): + pred = pred[pred.find('The answer is') + len('The answer is'):] + return cls.clean(pred) + if pred.find('# Answer') >= 0: + return cls.clean(pred[pred.find('# Answer') + len('# Answer'):]) + if pred.find('The answer is:') >= 0: + return cls.clean(pred[pred.find('The answer is:') + + len('The answer is:'):]) + if pred.find('####') >= 0: + return cls.clean(pred[pred.find('####') + 4:]) + left = '\\boxed{' + if pred.find(left) >= 0: + pred = pred[pred.find(left) + len(left):] + return cls.clean(cls.extract_matching_bracket(pred)) + + if extract_last_num: + nums = [] + opt = '' + + def contain_digit(opt): + for ch in opt: + if ch.isdigit(): + return True + return False + + for ch in pred: + if ch.isdigit() or ch in ' ,.': + opt = opt + ch + else: + if contain_digit(opt): + nums.append(opt) + opt = '' + if contain_digit(opt): + return cls.clean(opt) + if nums: + return cls.clean(nums[-1]) + return None + + +def fix_fracs(string): + substrs = string.split('\\frac') + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += '\\frac' + if substr[0] == '{': + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != '{': + if len(substr) > 2: + post_substr = substr[2:] + new_str += '{' + a + '}{' + b + '}' + post_substr + else: + new_str += '{' + a + '}{' + b + '}' + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += '{' + a + '}' + b + post_substr + else: + new_str += '{' + a + '}' + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split('/')) != 2: + return string + a = string.split('/')[0] + b = string.split('/')[1] + try: + a = int(a) + b = int(b) + assert string == '{}/{}'.format(a, b) + new_string = '\\frac{' + str(a) + '}{' + str(b) + '}' + return new_string + except AssertionError: + return string + + +def remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) + if '\\text{ ' in string: + splits = string.split('\\text{ ') + assert len(splits) == 2 + return splits[0] + else: + return string + + +def fix_sqrt(string): + if '\\sqrt' not in string: + return string + splits = string.split('\\sqrt') + new_string = splits[0] + for split in splits[1:]: + if split[0] != '{': + a = split[0] + new_substr = '\\sqrt{' + a + '}' + split[1:] + else: + new_substr = '\\sqrt' + split + new_string += new_substr + return new_string + + +def strip_string(string): + # linebreaks + string = string.replace('\n', '') + + # remove inverse spaces + string = string.replace('\\!', '') + + # replace \\ with \ + string = string.replace('\\\\', '\\') + + # replace tfrac and dfrac with frac + string = string.replace('tfrac', 'frac') + string = string.replace('dfrac', 'frac') + + # remove \left and \right + string = string.replace('\\left', '') + string = string.replace('\\right', '') + + # Remove circ (degrees) + string = string.replace('^{\\circ}', '') + string = string.replace('^\\circ', '') + + # remove dollar signs + string = string.replace('\\$', '') + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace('\\%', '') + string = string.replace('\%', '') # noqa: W605 + + string = string.replace(' .', ' 0.') + string = string.replace('{.', '{0.') + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == '.': + string = '0' + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split('=')) == 2: + if len(string.split('=')[0]) <= 2: + string = string.split('=')[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(' ', '') + + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == '0.5': + string = '\\frac{1}{2}' + + string = fix_a_slash_b(string) + string = string.replace('x \\in', '').strip() # noqa: W605 + + # a_b == a, a_{b} == a_b for bit conversion + if string.find('_') >= 0: + p = string.split('_') + p[1] = p[1].replace('{', '').replace('}', '') + string = '_'.join(p) + + # 10800 == 10,800; we only deal with single number + if string.strip().find(' ') == -1 and string.find('(') == -1: + string = string.replace(',', '') + + return string + + +def is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + # print("WARNING: Both None") + return False + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + return ss1 == ss2 + except Exception: + return str1 == str2 diff --git a/opencompass/utils/text_postprocessors.py b/opencompass/utils/text_postprocessors.py index 7b478402c..dc39ff701 100644 --- a/opencompass/utils/text_postprocessors.py +++ b/opencompass/utils/text_postprocessors.py @@ -71,6 +71,7 @@ def first_option_postprocess(text: str, options: str, cushion=True) -> str: f'答案为\s?([{options}])', f'答案选\s?([{options}])', f'选择?\s?([{options}])', + f'故选?\s?([{options}])' f'只有选?项?\s?([{options}])\s?是?对', f'只有选?项?\s?([{options}])\s?是?错', f'只有选?项?\s?([{options}])\s?不?正确',