forked from open-compass/opencompass
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add custom model postprocess function
- Loading branch information
liushz
committed
Sep 11, 2024
1 parent
7c7fa36
commit 4044407
Showing
11 changed files
with
665 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
configs/datasets/gsm8k/gsm8k_model_postprocess_gen_a58960.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import ZeroRetriever | ||
from opencompass.openicl.icl_inferencer import GenInferencer | ||
from opencompass.datasets import GSM8KDataset, gsm8k_dataset_postprocess | ||
from opencompass.datasets import MATHEvaluator, math_postprocess_v2 | ||
from opencompass.utils.model_postprocessors import navie_model_postprocess | ||
from opencompass.utils.postprocessors.naive import MATH_NAVIE_PROMPT_TEMPLATE | ||
|
||
# Reader config: 'question' is the only input column and 'answer' is the output column. | ||
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer') | ||
|
||
# Inference config: one HUMAN turn holding the question followed by an instruction to | ||
# reason step by step and put the final answer inside \boxed{}. | ||
gsm8k_infer_cfg = dict( | ||
prompt_template=dict( | ||
type=PromptTemplate, | ||
template=dict( | ||
round=[ | ||
dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'), | ||
], | ||
), | ||
), | ||
# NOTE(review): ZeroRetriever presumably means no in-context examples (zero-shot) - confirm. | ||
retriever=dict(type=ZeroRetriever), | ||
# Generation is configured with max_out_len=512. | ||
inferencer=dict(type=GenInferencer, max_out_len=512), | ||
) | ||
|
||
# # You can write your own postprocess prompt and pass it as `custom_instruction`, for example: | ||
# GSM8K_NAVIE_PROMPT_TEMPLATE = """ | ||
# There is a detailed explanation of the final answer you should extract: | ||
# 1. ... | ||
# 2. ... | ||
# ... | ||
# """ | ||
|
||
# Evaluation config: MATHEvaluator (version 'v2') with math_postprocess_v2 applied to | ||
# predictions and gsm8k_dataset_postprocess applied to the dataset side. | ||
# The 'NAVIE' spelling below follows the imported identifiers and must stay as-is. | ||
gsm8k_eval_cfg = dict( | ||
evaluator=dict(type=MATHEvaluator, version='v2'), | ||
pred_postprocessor=dict(type=math_postprocess_v2), | ||
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess), | ||
# Model-based postprocessor configured with the stock MATH prompt template. | ||
# NOTE(review): model_name is empty here; presumably it must be filled in or is | ||
# resolved by the postprocessor - confirm. | ||
# NOTE(review): api_url holds two comma-separated endpoints (ports 23333 and 23334). | ||
# 0.0.0.0 is normally a bind address; a client URL usually needs 127.0.0.1 or a real | ||
# host name - confirm against the deployment. | ||
model_postprocessor=dict( | ||
type=navie_model_postprocess, | ||
custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE, | ||
model_name='', | ||
api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1') | ||
) | ||
|
||
# Exported dataset list: a single GSM8K entry wiring together the reader, inference | ||
# and evaluation configs defined in this file. | ||
gsm8k_datasets = [ | ||
dict( | ||
abbr='gsm8k', | ||
type=GSM8KDataset, | ||
path='opencompass/gsm8k', | ||
reader_cfg=gsm8k_reader_cfg, | ||
infer_cfg=gsm8k_infer_cfg, | ||
eval_cfg=gsm8k_eval_cfg, | ||
) | ||
] |
141 changes: 141 additions & 0 deletions
141
configs/datasets/mmlu/mmlu_model_postprocess_gen_4d595a.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import FixKRetriever | ||
from opencompass.openicl.icl_inferencer import GenInferencer | ||
from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator | ||
from opencompass.datasets import MMLUDataset | ||
from opencompass.utils.text_postprocessors import first_option_postprocess | ||
from opencompass.utils.model_postprocessors import navie_model_postprocess | ||
from opencompass.utils.postprocessors.naive import OPTION_NAVIE_PROMPT_TEMPLATE | ||
|
||
|
||
# None of the MMLU datasets on Hugging Face is parsed correctly, so we use our own dataset reader. | ||
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar | ||
|
||
# Reader config: the question text plus the four options are input columns, 'target' is | ||
# the output column, and the 'dev' split is used as the train split. | ||
mmlu_reader_cfg = dict( | ||
input_columns=['input', 'A', 'B', 'C', 'D'], | ||
output_column='target', | ||
train_split='dev') | ||
|
||
# The 57 MMLU subject names; the loop below in this file builds one dataset config per entry. | ||
mmlu_all_sets = [ | ||
'college_biology', | ||
'college_chemistry', | ||
'college_computer_science', | ||
'college_mathematics', | ||
'college_physics', | ||
'electrical_engineering', | ||
'astronomy', | ||
'anatomy', | ||
'abstract_algebra', | ||
'machine_learning', | ||
'clinical_knowledge', | ||
'global_facts', | ||
'management', | ||
'nutrition', | ||
'marketing', | ||
'professional_accounting', | ||
'high_school_geography', | ||
'international_law', | ||
'moral_scenarios', | ||
'computer_security', | ||
'high_school_microeconomics', | ||
'professional_law', | ||
'medical_genetics', | ||
'professional_psychology', | ||
'jurisprudence', | ||
'world_religions', | ||
'philosophy', | ||
'virology', | ||
'high_school_chemistry', | ||
'public_relations', | ||
'high_school_macroeconomics', | ||
'human_sexuality', | ||
'elementary_mathematics', | ||
'high_school_physics', | ||
'high_school_computer_science', | ||
'high_school_european_history', | ||
'business_ethics', | ||
'moral_disputes', | ||
'high_school_statistics', | ||
'miscellaneous', | ||
'formal_logic', | ||
'high_school_government_and_politics', | ||
'prehistory', | ||
'security_studies', | ||
'high_school_biology', | ||
'logical_fallacies', | ||
'high_school_world_history', | ||
'professional_medicine', | ||
'high_school_mathematics', | ||
'college_medicine', | ||
'high_school_us_history', | ||
'sociology', | ||
'econometrics', | ||
'high_school_psychology', | ||
'human_aging', | ||
'us_foreign_policy', | ||
'conceptual_physics', | ||
] | ||
|
||
# Build one dataset config per MMLU subject listed in mmlu_all_sets. | ||
mmlu_datasets = [] | ||
for _name in mmlu_all_sets: | ||
# Subject-specific instruction; underscores in the subject name are replaced by spaces. | ||
_hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.' | ||
mmlu_infer_cfg = dict( | ||
# Template for one in-context example: a HUMAN question with options, then a BOT turn | ||
# holding the target. The doubled braces keep {input}, {A}..{D} as literal placeholders | ||
# after f-string formatting. | ||
ice_template=dict( | ||
type=PromptTemplate, | ||
template=dict(round=[ | ||
dict( | ||
role='HUMAN', | ||
prompt= | ||
f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: ' | ||
), | ||
dict(role='BOT', prompt='{target}\n') | ||
]), | ||
), | ||
# Template for the actual query. '</E>' is both the `begin` value and the `ice_token`; | ||
# presumably the in-context examples are substituted there - confirm. | ||
prompt_template=dict( | ||
type=PromptTemplate, | ||
template=dict( | ||
begin='</E>', | ||
round=[ | ||
dict( | ||
role='HUMAN', | ||
prompt=f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: ' | ||
), | ||
], | ||
), | ||
ice_token='</E>', | ||
), | ||
# NOTE(review): presumably picks items 0-4 of the train split ('dev') as fixed | ||
# in-context examples - confirm. | ||
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]), | ||
inferencer=dict(type=GenInferencer), | ||
) | ||
|
||
# # You can write your own postprocess prompt and pass it as `custom_instruction`, for example: | ||
# MMLU_NAVIE_PROMPT_TEMPLATE = """ | ||
# There is a detailed explanation of the final answer you should extract: | ||
# 1. ... | ||
# 2. ... | ||
# ... | ||
# """ | ||
|
||
# Evaluation config: AccwithDetailsEvaluator with first_option_postprocess restricted to | ||
# the options 'ABCD'. The 'NAVIE' spelling follows the imported identifiers. | ||
mmlu_eval_cfg = dict( | ||
evaluator=dict(type=AccwithDetailsEvaluator), | ||
pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'), | ||
# NOTE(review): model_name is empty and api_url lists two comma-separated endpoints on | ||
# 0.0.0.0 (normally a bind address, not a client target) - confirm both. | ||
model_postprocessor=dict( | ||
type=navie_model_postprocess, | ||
custom_instruction=OPTION_NAVIE_PROMPT_TEMPLATE, | ||
model_name='', | ||
api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1') | ||
) | ||
|
||
|
||
# Register this subject's dataset under the abbreviation 'lukaemon_mmlu_<subject>'. | ||
mmlu_datasets.append( | ||
dict( | ||
abbr=f'lukaemon_mmlu_{_name}', | ||
type=MMLUDataset, | ||
path='opencompass/mmlu', | ||
name=_name, | ||
reader_cfg=mmlu_reader_cfg, | ||
infer_cfg=mmlu_infer_cfg, | ||
eval_cfg=mmlu_eval_cfg, | ||
)) | ||
|
||
# Drop the loop temporaries from the module namespace; mmlu_infer_cfg and mmlu_eval_cfg | ||
# are not deleted here. | ||
del _name, _hint |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
52 changes: 52 additions & 0 deletions
52
opencompass/configs/datasets/gsm8k/gsm8k_model_postprocess_gen_a58960.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from opencompass.openicl.icl_prompt_template import PromptTemplate | ||
from opencompass.openicl.icl_retriever import ZeroRetriever | ||
from opencompass.openicl.icl_inferencer import GenInferencer | ||
from opencompass.datasets import GSM8KDataset, gsm8k_dataset_postprocess | ||
from opencompass.datasets import MATHEvaluator, math_postprocess_v2 | ||
from opencompass.utils.model_postprocessors import navie_model_postprocess | ||
from opencompass.utils.postprocessors.naive import MATH_NAVIE_PROMPT_TEMPLATE | ||
|
||
# Reader config: 'question' is the only input column and 'answer' is the output column. | ||
gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer') | ||
|
||
# Inference config: one HUMAN turn holding the question followed by an instruction to | ||
# reason step by step and put the final answer inside \boxed{}. | ||
gsm8k_infer_cfg = dict( | ||
prompt_template=dict( | ||
type=PromptTemplate, | ||
template=dict( | ||
round=[ | ||
dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'), | ||
], | ||
), | ||
), | ||
# NOTE(review): ZeroRetriever presumably means no in-context examples (zero-shot) - confirm. | ||
retriever=dict(type=ZeroRetriever), | ||
# Generation is configured with max_out_len=512. | ||
inferencer=dict(type=GenInferencer, max_out_len=512), | ||
) | ||
|
||
# # You can write your own postprocess prompt and pass it as `custom_instruction`, for example: | ||
# GSM8K_NAVIE_PROMPT_TEMPLATE = """ | ||
# There is a detailed explanation of the final answer you should extract: | ||
# 1. ... | ||
# 2. ... | ||
# ... | ||
# """ | ||
|
||
# Evaluation config: MATHEvaluator (version 'v2') with math_postprocess_v2 applied to | ||
# predictions and gsm8k_dataset_postprocess applied to the dataset side. | ||
# The 'NAVIE' spelling below follows the imported identifiers and must stay as-is. | ||
gsm8k_eval_cfg = dict( | ||
evaluator=dict(type=MATHEvaluator, version='v2'), | ||
pred_postprocessor=dict(type=math_postprocess_v2), | ||
dataset_postprocessor=dict(type=gsm8k_dataset_postprocess), | ||
# Model-based postprocessor configured with the stock MATH prompt template. | ||
# NOTE(review): model_name is empty here; presumably it must be filled in or is | ||
# resolved by the postprocessor - confirm. | ||
# NOTE(review): api_url holds two comma-separated endpoints (ports 23333 and 23334). | ||
# 0.0.0.0 is normally a bind address; a client URL usually needs 127.0.0.1 or a real | ||
# host name - confirm against the deployment. | ||
model_postprocessor=dict( | ||
type=navie_model_postprocess, | ||
custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE, | ||
model_name='', | ||
api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1') | ||
) | ||
|
||
# Exported dataset list: a single GSM8K entry wiring together the reader, inference | ||
# and evaluation configs defined in this file. | ||
gsm8k_datasets = [ | ||
dict( | ||
abbr='gsm8k', | ||
type=GSM8KDataset, | ||
path='opencompass/gsm8k', | ||
reader_cfg=gsm8k_reader_cfg, | ||
infer_cfg=gsm8k_infer_cfg, | ||
eval_cfg=gsm8k_eval_cfg, | ||
) | ||
] |
Oops, something went wrong.