Add Math Evaluation with Judge Model Evaluator #1077

Closed
wants to merge 4 commits
37 changes: 37 additions & 0 deletions configs/datasets/math/math_gen_78ced2.py
@@ -0,0 +1,37 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import MATHDataset, MATHEvaluator, math_postprocess

QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form ANSWER: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{problem}

Remember to put your answer on its own line after "ANSWER:", and you do not need to use a \\boxed command.
""".strip()

math_reader_cfg = dict(input_columns=['problem'], output_column='solution')

math_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,

        template=dict(round=[
            dict(role="HUMAN", prompt=QUERY_TEMPLATE),
        ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512))

math_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator), pred_postprocessor=dict(type=math_postprocess))

math_datasets = [
    dict(
        type=MATHDataset,
        abbr='math',
        path='./data/math/math.json',
        reader_cfg=math_reader_cfg,
        infer_cfg=math_infer_cfg,
        eval_cfg=math_eval_cfg)
]
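Illustration (not part of the diff): the QUERY_TEMPLATE above carries a single {problem} placeholder. A minimal sketch of how it renders for an invented toy problem, using plain str.format in place of OpenCompass's PromptTemplate machinery:

# Illustration only -- the toy problem below is invented.
QUERY_TEMPLATE = """
Solve the following math problem step by step. The last line of your response should be of the form ANSWER: $ANSWER (without quotes) where $ANSWER is the answer to the problem.

{problem}

Remember to put your answer on its own line after "ANSWER:", and you do not need to use a \\boxed command.
""".strip()

prompt = QUERY_TEMPLATE.format(problem='What is $1 + 2 + \\cdots + 10$?')
print(prompt)
# A well-formed model reply is expected to end with a line such as:
# ANSWER: 55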
137 changes: 137 additions & 0 deletions configs/eval_math_judgement.py
@@ -0,0 +1,137 @@
# Most of the code in this file is copied from https://github.com/openai/simple-evals/blob/main/math_eval.py
from mmengine.config import read_base
with read_base():
    from .models.hf_llama.hf_llama3_8b_instruct import models as hf_llama3_8b_instruct_model  # noqa: F401, F403
    from .models.hf_internlm.hf_internlm2_chat_20b import models as hf_internlm2_chat_20b_model  # noqa: F401, F403
    from .models.hf_llama.hf_llama3_70b_instruct import models as hf_llama3_70b_instruct_model  # noqa: F401, F403
    from .datasets.math.math_gen_78ced2 import math_datasets  # noqa: F401, F403
from opencompass.models.openai_api import OpenAIAllesAPIN
from opencompass.datasets import math_judement_preprocess
from opencompass.partitioners import NaivePartitioner, SizePartitioner
from opencompass.partitioners.sub_naive import SubjectiveNaivePartitioner
from opencompass.partitioners.sub_size import SubjectiveSizePartitioner
from opencompass.runners import LocalRunner
from opencompass.runners import SlurmSequentialRunner
from opencompass.tasks import OpenICLInferTask
from opencompass.tasks.subjective_eval import SubjectiveEvalTask
from opencompass.summarizers import AllObjSummarizer
from opencompass.openicl.icl_evaluator import LMEvaluator
from opencompass.openicl.icl_prompt_template import PromptTemplate


# -------------Prompt Settings ----------------------------------------
eng_obj_prompt = """
Look at the following two expressions (answers to a math problem) and judge whether they are equivalent. Only perform trivial simplifications

Examples:

Expression 1: $2x+3$
Expression 2: $3+2x$

Result: [[Correct]]

Expression 1: 3/2
Expression 2: 1.5

Result: [[Correct]]

Expression 1: $x^2+2x+1$
Expression 2: $y^2+2y+1$

Result: [[Incorrect]]

Expression 1: $x^2+2x+1$
Expression 2: $(x+1)^2$

Result: [[Correct]]

Expression 1: 3245/5
Expression 2: 649

Result: [[Incorrect]]
(these are actually equal, don't mark them equivalent if you need to do nontrivial simplifications)

Expression 1: 2/(-3)
Expression 2: -2/3

Result: [[Correct]]
(trivial simplifications are allowed)

Expression 1: 72 degrees
Expression 2: 72

Result: [[Correct]]
(give benefit of the doubt to units)

Expression 1: 64
Expression 2: 64 square feet

Result: [[Correct]]
(give benefit of the doubt to units)

---

YOUR TASK


Respond with only "Result: [[Correct]]" or "Result: [[Incorrect]]" (without quotes). Do not include a rationale.

Expression 1: {obj_gold}
Expression 2: {prediction}
""".strip()

# -------------Inference Stage ----------------------------------------
# eval models
models = [*hf_llama3_8b_instruct_model]
# judge models
judge_models = hf_llama3_70b_instruct_model

eng_datasets = [*math_datasets]
chn_datasets = []
datasets = eng_datasets + chn_datasets
work_dir = 'outputs/obj_all/'

for d in eng_datasets:
    d['eval_cfg'] = dict(
        evaluator=dict(
            type=LMEvaluator,
            # If you need to preprocess the prediction before judging,
            # you can specify the pred_postprocessor function here
            pred_postprocessor=dict(type=math_judement_preprocess),
            prompt_template=dict(
                type=PromptTemplate,
                template=dict(round=[
                    dict(
                        role='HUMAN',
                        prompt=eng_obj_prompt
                    ),
                ]),
            ),
        ),
        pred_role="BOT",
    )

infer = dict(
    partitioner=dict(type=SizePartitioner, max_task_size=40000),
    runner=dict(
        type=SlurmSequentialRunner,
        partition='llmeval',
        quotatype='auto',
        max_num_workers=256,
        task=dict(type=OpenICLInferTask)),
)

# ------------- Evaluation Configuration --------------------------------
eval = dict(
    partitioner=dict(
        type=SubjectiveSizePartitioner, max_task_size=100000, mode='singlescore', models=models, judge_models=judge_models,
    ),
    runner=dict(type=SlurmSequentialRunner,
                partition='llmeval',
                quotatype='auto',
                max_num_workers=16, task=dict(type=SubjectiveEvalTask)),
)

summarizer = dict(
    type=AllObjSummarizer
)
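Illustration (not part of the diff): a minimal sketch of how the judge prompt's two placeholders are filled. The stand-in template and sample answers below are invented; in the config above, LMEvaluator and PromptTemplate perform the substitution, binding {obj_gold} to the reference answer and {prediction} to the model answer after math_judement_preprocess.

# Sketch only -- the stand-in template and values are invented.
judge_template = ('Expression 1: {obj_gold}\n'
                  'Expression 2: {prediction}\n'
                  'Respond with only "Result: [[Correct]]" or "Result: [[Incorrect]]".')

gold_answer = '(x+1)^2'      # bound to {obj_gold}
model_answer = 'x^2+2x+1'    # bound to {prediction}
print(judge_template.format(obj_gold=gold_answer, prediction=model_answer))
# The judge model is then expected to reply "Result: [[Correct]]".

With the standard OpenCompass entry point, a config like this is typically launched with "python run.py configs/eval_math_judgement.py"; exact flags may vary by setup.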
14 changes: 14 additions & 0 deletions opencompass/datasets/math.py
@@ -125,6 +125,14 @@ def normalize_final_answer(final_answer: str) -> str:
    return final_answer


ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'


def extract_answer(response_text: str):
    match = re.search(ANSWER_PATTERN, response_text)
    return match.group(1) if match else None


@LOAD_DATASET.register_module()
class MATHDataset(BaseDataset):

@@ -156,6 +164,12 @@ def math_postprocess(text: str) -> str:
    # text.split('Final Answer: ', 1)[-1].split('\n\n')[0])


@TEXT_POSTPROCESSORS.register_module('math_judement_preprocess')
def math_judement_preprocess(text: str) -> str:
"""Preprocess prediction before judgement."""
return extract_answer(text)


@TEXT_POSTPROCESSORS.register_module('math_postprocess_v2')
def math_postprocess_v2(text: str) -> str:

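Illustration (not part of the diff): a quick sketch of what the new ANSWER_PATTERN extraction does; the sample responses are invented.

import re

ANSWER_PATTERN = r'(?i)ANSWER\s*:\s*([^\n]+)'

def extract_answer(response_text: str):
    match = re.search(ANSWER_PATTERN, response_text)
    return match.group(1) if match else None

print(extract_answer('We add the two roots.\nanswer: 7'))  # -> '7' (match is case-insensitive)
print(extract_answer('no final answer line here'))         # -> None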
9 changes: 4 additions & 5 deletions opencompass/openicl/icl_evaluator/lm_evaluator.py
@@ -12,8 +12,6 @@
from opencompass.registry import ICL_PROMPT_TEMPLATES
from opencompass.utils import build_dataset_from_cfg, build_model_from_cfg
from opencompass.utils.logging import get_logger
from opencompass.utils.text_postprocessors import first_number_postprocess
from opencompass.utils.types import get_type_from_cfg


def extract_dicts(data):
@@ -80,7 +78,7 @@ class LMEvaluator:
        dataset_cfg (ConfigDict, optional): The config of the dataset to be
            evaluated.
        pack_all_predictions (bool, optional): For multiround evaluation, judge all rounds or judge every single round.
        postprocessor (ConfigDict): The model prediction's postprocessor
        pred_postprocessor (ConfigDict): The model prediction's postprocessor
            config.
    """

@@ -92,7 +90,7 @@ def __init__(
        meta_review_prompt_template: Optional[ConfigDict] = None,
        pack_all_predictions: Optional[bool] = False,
        dataset_cfg: Optional[ConfigDict] = None,
        postprocessor: ConfigDict = dict(type=first_number_postprocess)
        pred_postprocessor: Optional[ConfigDict] = None,
    ) -> None:
        self.output_path = output_path
        out_dir, out_name = osp.split(output_path)
@@ -112,7 +110,6 @@ def __init__(
            batch_size=batch_size,
            output_json_filepath=out_dir,
            output_json_filename=out_name)
        self.postprocessor = get_type_from_cfg(postprocessor)
        self.logger = get_logger()
        self.dataset_cfg = dataset_cfg
        self.pack_all_predictions = pack_all_predictions
@@ -163,7 +160,9 @@ def score(self,
            ):  #single chat for format like [['xxx', 'xxxx'], ['xxx', 'xxxx']]
                for i in range(len(predictions)):
                    key = 'prediction' if i == 0 else f'prediction{i + 1}'
                    gold_key = 'obj_gold'
                    pred_dict[key] = predictions[i]
                    pred_dict[gold_key] = references
            if judgements:
                for i in range(len(judgements)):
                    key = 'judgement' if i == 0 else f'judgement{i + 1}'
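Illustration (not part of the diff): the change to score() above exposes the reference answers to the judge template under the obj_gold key. A minimal sketch with invented data:

# Field names follow the diff; the prediction and reference values are invented.
predictions = [['x^2 + 2x + 1']]   # one model's predictions
references = ['(x+1)^2']           # gold answers from the dataset

pred_dict = {}
for i in range(len(predictions)):
    key = 'prediction' if i == 0 else f'prediction{i + 1}'
    pred_dict[key] = predictions[i]
    pred_dict['obj_gold'] = references   # new: fills the {obj_gold} placeholder

print(pred_dict)
# {'prediction': ['x^2 + 2x + 1'], 'obj_gold': ['(x+1)^2']}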
1 change: 1 addition & 0 deletions opencompass/summarizers/subjective/__init__.py
@@ -1,5 +1,6 @@
# flake8: noqa: F401, E501
from .alignmentbench import AlignmentBenchSummarizer
from .all_obj import AllObjSummarizer
from .alpacaeval import AlpacaSummarizer
from .compass_arena import CompassArenaSummarizer
from .corev2 import Corev2Summarizer
122 changes: 122 additions & 0 deletions opencompass/summarizers/subjective/all_obj.py
@@ -0,0 +1,122 @@
# flake8: noqa: E501
import csv
import os
import os.path as osp
import re
from collections import defaultdict
from datetime import datetime

import numpy as np
from mmengine import ConfigDict
from prettytable import from_csv

from opencompass.utils import dataset_abbr_from_cfg, model_abbr_from_cfg

from .utils import get_judgeanswer_and_reference, get_outdir


def post_process_allobj(judgement: str):
"""Input a string like below:

xxx[[correct]]xxx, and extract the judge
"""
pattern = r'(?i)\[(incorrect|correct|正确|错误)\]'
matched_result = re.findall(pattern, judgement)
if matched_result:
content = matched_result[0].lower()
if content in ['correct', '正确']:
return {'score': 1}
elif content in ['incorrect', '错误']:
return {'score': 0}
else:
return None


def get_capability_results(
    judged_answers,
    references,
    fout,
    fout_flag,
    model,
):
    capability_ratings = defaultdict(int)
    capability_counts = defaultdict(int)
    for ans, ref in zip(judged_answers, references):
        capability_ratings['total'] += ans['score']
        capability_counts['total'] += 1

    capability_avg_ratings = defaultdict(float)

    for capability, total_score in capability_ratings.items():
        capability_avg_ratings[
            capability] = total_score / capability_counts[capability]
    columns = list(capability_avg_ratings.keys())
    columns.insert(0, columns.pop(columns.index('total')))
    with open(fout, 'a+', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if fout_flag == 0:
            writer.writerow(['model'] + columns)
        writer.writerow([model] +
                        [capability_avg_ratings[column] for column in columns])


class AllObjSummarizer:
"""Do the subjectivity analyze based on evaluation results.

Args:
config (ConfigDict): The configuration object of the evaluation task.
It's expected to be filled out at runtime.
"""

def __init__(self, config: ConfigDict, judge_type='single') -> None:
self.judge_type = judge_type
self.tasks = []
self.cfg = config
if self.judge_type == 'single':
self.eval_model_cfgs = self.cfg['eval']['partitioner']['models']
self.eval_model_abbrs = [
model_abbr_from_cfg(model) for model in self.eval_model_cfgs
]
elif self.judge_type == 'pair':
self.base_models = self.cfg['eval']['partitioner']['base_models']
self.compare_models = self.cfg['eval']['partitioner'][
'compare_models']
self.judge_abbr = model_abbr_from_cfg(self.cfg['judge_models'][0])
self.judge_map = {'single': post_process_allobj}
self.judge_function = self.judge_map[self.judge_type]

def summarize(self,
time_str: str = datetime.now().strftime('%Y%m%d_%H%M%S')):
"""Summarize the subjectivity analysis based on evaluation results.

Args:
time_str (str): Timestamp for file naming.

Returns:
pd.DataFrame: The summary results.
"""
if self.judge_type == 'single':
dataset_cfgs = self.cfg['datasets']
judge_model = self.judge_abbr
output_dir, results_folder = get_outdir(self.cfg, time_str)
for dataset in dataset_cfgs:
dataset_abbr = dataset_abbr_from_cfg(dataset)
fout = osp.join(
output_dir,
'judged-by--' + judge_model + '-' + dataset_abbr + '.csv')
fout_flag = 0
for eval_model_abbr in self.eval_model_abbrs:
subdir = eval_model_abbr + '_judged-by--' + self.judge_abbr
subdir_path = os.path.join(results_folder, subdir)
if os.path.isdir(subdir_path):
model = eval_model_abbr
judged_answers, references = get_judgeanswer_and_reference(
dataset, subdir_path, self.judge_function)
get_capability_results(judged_answers, references,
fout, fout_flag, model)
fout_flag += 1
else:
print(subdir_path + ' is not exist! please check!')
with open(fout, 'r') as f:
x = from_csv(f)
print(x)
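Illustration (not part of the diff): a quick check of the new judgement post-processor. The sample judge outputs are invented, and the function body is repeated so the snippet runs on its own.

import re

def post_process_allobj(judgement: str):
    pattern = r'(?i)\[(incorrect|correct|正确|错误)\]'
    matched_result = re.findall(pattern, judgement)
    if matched_result:
        content = matched_result[0].lower()
        if content in ['correct', '正确']:
            return {'score': 1}
        elif content in ['incorrect', '错误']:
            return {'score': 0}
    else:
        return None

print(post_process_allobj('Result: [[Correct]]'))    # -> {'score': 1}
print(post_process_allobj('Result: [[Incorrect]]'))  # -> {'score': 0}
print(post_process_allobj('no verdict given'))       # -> None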