Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Yggdrasill7D6 committed Apr 24, 2024
1 parent 61c477a commit 61cfabe
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
skip = *.ipynb
count =
quiet-level = 3
ignore-words-list = nd, ans, ques, rouge, softwares, wit
ignore-words-list = nd, ans, ques, rouge, softwares, wit, te
94 changes: 51 additions & 43 deletions opencompass/datasets/mgsm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os.path as osp
import re

from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS
from opencompass.registry import LOAD_DATASET

from .base import BaseDataset

Expand All @@ -15,12 +14,11 @@ class MGSMSDataset(BaseDataset):
@staticmethod
def load(path: str):


src_lines = open(path, 'r', encoding='utf-8').readlines()

data = {'question': [], 'answer': []}

for lines in src_lines:
for lines in src_lines:
data['question'].append(lines.split('\t')[0])
data['answer'].append(lines.split('\t')[1])

Expand All @@ -30,6 +28,7 @@ def load(path: str):
})
return dataset


# LANG_TO_ANSWER_PREFIX = {
# "en": "Answer",
# "bn": "উত্তর",
Expand All @@ -45,63 +44,70 @@ def load(path: str):
# }



def mgsm_zh_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Chinese MGSM response.

    Only the text after the last occurrence of the answer marker '答案' is
    searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('答案')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_bn_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Bengali MGSM response.

    Only the text after the last occurrence of the answer marker 'উত্তর' is
    searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('উত্তর')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_de_postprocess(text: str) -> str:
    """Extract the final numeric answer from a German MGSM response.

    Only the text after the last occurrence of the answer marker 'Antwort'
    is searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'; a decimal comma
    is removed as well, which means '3,5' is read as '35'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('Antwort')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_en_postprocess(text: str) -> str:
    """Extract the final numeric answer from an English MGSM response.

    Only the text after the last occurrence of the answer marker 'Answer'
    (case-sensitive) is searched; the whole text is searched when the marker
    is absent. Every comma is removed before matching, so '1,234' is read as
    '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('Answer')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_es_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Spanish MGSM response.

    Only the text after the last occurrence of the answer marker 'Respuesta'
    is searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'; a decimal comma
    is removed as well, which means '3,5' is read as '35'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('Respuesta')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_fr_postprocess(text: str) -> str:
    """Extract the final numeric answer from a French MGSM response.

    Only the text after the last occurrence of the answer marker 'Réponse'
    is searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'; a decimal comma
    is removed as well, which means '3,5' is read as '35'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('Réponse')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_ja_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Japanese MGSM response.

    Only the text after the last occurrence of the answer marker '答え' is
    searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('答え')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_ru_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Russian MGSM response.

    Only the text after the last occurrence of the answer marker 'Ответ' is
    searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'; a decimal comma
    is removed as well, which means '3,5' is read as '35'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('Ответ')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_sw_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Swahili MGSM response.

    Only the text after the last occurrence of the answer marker 'Jibu' is
    searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('Jibu')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


def mgsm_te_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Telugu MGSM response.

    Only the text after the last occurrence of the answer marker 'సమాధానం'
    is searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('సమాధానం')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''

def mgsm_th_postprocess(text: str) -> str:
    """Extract the final numeric answer from a Thai MGSM response.

    Only the text after the last occurrence of the answer marker 'คำตอบ' is
    searched (the whole text when the marker is absent). Every comma is
    removed before matching, so '1,234' is read as '1234'.

    Args:
        text: Raw model output.

    Returns:
        The last number found, without a trailing '.', or '' when the
        searched text contains no digits.
    """
    answer_text = text.split('คำตอบ')[-1].strip()
    # The pattern accepts a trailing '.', so a sentence-ending period can be
    # captured together with the number; rstrip('.') removes it again.
    numbers = re.findall(r'\d+\.?\d*', answer_text.replace(',', ''))
    return numbers[-1].rstrip('.') if numbers else ''


class MGSM_Evaluator(BaseEvaluator):
Expand All @@ -110,12 +116,14 @@ def score(self, predictions, references):
assert len(predictions) == len(references)

result = {'pass': 0, 'fail': 0}
for index, (references_answer, predictions_answer) in enumerate(zip(references, predictions)):
for index, (references_answer, predictions_answer) in enumerate(
zip(references, predictions)):
if references_answer == predictions_answer:
result['pass'] += 1
else:
result['fail'] += 1

result['score'] = float(result['pass'] / (result['pass'] + result['fail'])) * 100
result['score'] = float(result['pass'] /
(result['pass'] + result['fail'])) * 100
final_result = {'Acc': result['score']}
return final_result
return final_result

0 comments on commit 61cfabe

Please sign in to comment.