Skip to content

Commit

Permalink
[Refactor] Move fix_id_list to Retriever (#442)
Browse files Browse the repository at this point in the history
* [Refactor] Move fix_id_list to Retriever

* update

* move to base

* fix
  • Loading branch information
gaotongxiao authored Oct 7, 2023
1 parent 767c12a commit 119bfd1
Show file tree
Hide file tree
Showing 30 changed files with 68 additions and 98 deletions.
4 changes: 2 additions & 2 deletions configs/datasets/GLUE_CoLA/GULE_CoLA_ppl_77d0df.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
},
ice_token='</E>',
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=PPLInferencer, fix_id_list=[17, 18, 19, 20, 21]))
retriever=dict(type=FixKRetriever, fix_id_list=[17, 18, 19, 20, 21]),
inferencer=dict(type=PPLInferencer))

CoLA_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/GLUE_QQP/GLUE_QQP_ppl_250d00.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
},
ice_token='</E>',
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]))
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer))

QQP_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )

Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/ceval/ceval_gen_2daf24.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@
]),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

ceval_eval_cfg = dict(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/ceval/ceval_gen_5f30c7.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,8 @@
]),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

ceval_eval_cfg = dict(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/ceval/ceval_ppl_578f8d.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@
},
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer),
)

ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/ceval/ceval_ppl_93e5ce.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,8 @@
},
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer),
)

ceval_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/cmb/cmb_gen_72cbb7.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

cmb_datasets.append(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/cmmlu/cmmlu_gen_c13365.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@
]),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

cmmlu_eval_cfg = dict(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/cmmlu/cmmlu_ppl_8b9c76.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@
},
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer),
)

cmmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator))
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/mmlu/mmlu_gen_23a9a9.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
dict(role='BOT', prompt='{target}\n')
])),
prompt_template=mmlu_prompt_template,
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]))
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer))

mmlu_eval_cfg = dict(
evaluator=dict(type=AccEvaluator),
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/mmlu/mmlu_gen_5d1409.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

mmlu_eval_cfg = dict(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/mmlu/mmlu_gen_79e572.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@
f"{_hint}</E>{{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer:",
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

mmlu_eval_cfg = dict(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/mmlu/mmlu_gen_a484b3.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=GenInferencer),
)

mmlu_eval_cfg = dict(
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/mmlu/mmlu_ppl_ac766d.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@
},
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=PPLInferencer, fix_id_list=[0, 1, 2, 3, 4]),
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
inferencer=dict(type=PPLInferencer),
)

mmlu_eval_cfg = dict(evaluator=dict(type=AccEvaluator), )
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/nq/nq_gen_0356ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
inferencer=dict(type=GenInferencer, max_out_len=50),
)

nq_eval_cfg = dict(evaluator=dict(type=NQEvaluator), pred_role="BOT")
Expand Down
4 changes: 2 additions & 2 deletions configs/datasets/triviaqa/triviaqa_gen_0356ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@
),
ice_token="</E>",
),
retriever=dict(type=FixKRetriever),
inferencer=dict(type=GenInferencer, max_out_len=50, fix_id_list=list(range(k))),
retriever=dict(type=FixKRetriever, fix_id_list=list(range(k))),
inferencer=dict(type=GenInferencer, max_out_len=50),
)

triviaqa_eval_cfg = dict(evaluator=dict(type=TriviaQAEvaluator), pred_role="BOT")
Expand Down
4 changes: 2 additions & 2 deletions docs/en/prompt/prompt_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ infer_cfg = dict(
template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>"
),
retriever=dict(type=FixKRetriever), # Definition of how to retrieve in-context examples.
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # Method used to generate predictions.
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # Definition of how to retrieve in-context examples.
inferencer=dict(type=GenInferencer), # Method used to generate predictions.
)
```

Expand Down
4 changes: 2 additions & 2 deletions docs/zh_cn/prompt/prompt_template.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ infer_cfg=dict(
template='Solve the following questions.\n</E>{question}\n{answer}',
ice_token="</E>"
),
retriever=dict(type=FixKRetriever), # 定义 in context example 的获取方式
inferencer=dict(type=GenInferencer, fix_id_list=[0, 1]), # 使用何种方式推理得到 prediction
retriever=dict(type=FixKRetriever, fix_id_list=[0, 1]), # 定义 in context example 的获取方式
inferencer=dict(type=GenInferencer), # 使用何种方式推理得到 prediction
)
```

Expand Down
5 changes: 1 addition & 4 deletions opencompass/openicl/icl_inferencer/icl_agent_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,7 @@ def inference(self,
output_json_filename = self.output_json_filename

# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
ice_idx_list = retriever.retrieve()

# Create tmp json file for saving intermediate results and future
# resuming
Expand Down
7 changes: 1 addition & 6 deletions opencompass/openicl/icl_inferencer/icl_attack_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
dataset_cfg: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(
Expand All @@ -78,7 +77,6 @@ def __init__(
self.output_column = dataset_cfg['reader_cfg']['output_column']
self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
self.fix_id_list = fix_id_list

if self.model.is_api and save_every is None:
save_every = 1
Expand All @@ -94,10 +92,7 @@ def predict(self, adv_prompt) -> List:
output_json_filename = self.output_json_filename

# 2. Get results of retrieval process
if 'Fix' in self.retriever.__class__.__name__:
ice_idx_list = self.retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = self.retriever.retrieve()
ice_idx_list = self.retriever.retrieve()

# 3. Generate prompts for testing input
prompt_list, label_list = self.get_generation_prompt_list_from_retriever_indices( # noqa
Expand Down
10 changes: 7 additions & 3 deletions opencompass/openicl/icl_inferencer/icl_base_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ class BaseInferencer:
`JSON` file.
output_json_filename (:obj:`str`, optional): File name for output
`JSON` file.
api_name (:obj:`str`, optional): Name of API service.
call_api (:obj:`bool`): If ``True``, an API for LM models will be used,
determined by :obj:`api_name`.
"""
model = None

Expand All @@ -38,8 +35,15 @@ def __init__(
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
fix_id_list: Optional[List[int]] = None,
**kwargs,
) -> None:

if fix_id_list:
raise ValueError('Passing fix_id_list to Inferencer is no longer '
'allowed. Please pass it to FixKRetriever '
'instead.')

self.model = model

self.max_seq_len = max_seq_len
Expand Down
7 changes: 1 addition & 6 deletions opencompass/openicl/icl_inferencer/icl_clp_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
batch_size: Optional[int] = 1,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
fix_id_list: Optional[List[int]] = None,
single_token: bool = True,
**kwargs) -> None:
super().__init__(
Expand All @@ -66,7 +65,6 @@ def __init__(
**kwargs,
)

self.fix_id_list = fix_id_list
# TODO: support multiple token
assert single_token, 'Only support single token choice currently.'
self.single_token = single_token
Expand Down Expand Up @@ -103,10 +101,7 @@ def inference(self,
raise ValueError(err_msg)

# 2. Get results of retrieval process
if self.fix_id_list:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
ice_idx_list = retriever.retrieve()

# 3. Generate in-context examples for testing inputs
for idx in range(len(ice_idx_list)):
Expand Down
12 changes: 2 additions & 10 deletions opencompass/openicl/icl_inferencer/icl_gen_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(
model=model,
Expand All @@ -64,7 +63,6 @@ def __init__(

self.gen_field_replace_token = gen_field_replace_token
self.max_out_len = max_out_len
self.fix_id_list = fix_id_list

if self.model.is_api and save_every is None:
save_every = 1
Expand All @@ -85,10 +83,7 @@ def inference(self,
output_json_filename = self.output_json_filename

# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
ice_idx_list = retriever.retrieve()

# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
Expand Down Expand Up @@ -220,10 +215,7 @@ def inference(self,
output_json_filename = self.output_json_filename

# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
ice_idx_list = retriever.retrieve()

# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
Expand Down
7 changes: 1 addition & 6 deletions opencompass/openicl/icl_inferencer/icl_ppl_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
labels: Optional[List] = None,
fix_id_list: Optional[List[int]] = None,
**kwargs) -> None:
super().__init__(
model=model,
Expand All @@ -53,7 +52,6 @@ def __init__(
)

self.labels = labels
self.fix_id_list = fix_id_list

def inference(self,
retriever: BaseRetriever,
Expand All @@ -75,10 +73,7 @@ def inference(self,
output_json_filename = self.output_json_filename

# 2. Get results of retrieval process
if self.fix_id_list:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
ice_idx_list = retriever.retrieve()

# 3. Get labels of all the classes
if self.labels is None:
Expand Down
7 changes: 1 addition & 6 deletions opencompass/openicl/icl_inferencer/icl_sc_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = None,
fix_id_list: Optional[List[int]] = None,
sc_size: Optional[int] = 1,
infer_type: Optional[str] = '',
generation_kwargs: dict = {},
Expand All @@ -69,7 +68,6 @@ def __init__(
self.gen_field_replace_token = gen_field_replace_token
self.generation_kwargs = generation_kwargs
self.max_out_len = max_out_len
self.fix_id_list = fix_id_list
self.sc_size = sc_size

if self.model.is_api and save_every is None:
Expand All @@ -91,10 +89,7 @@ def inference(self,
output_json_filename = self.output_json_filename

# 2. Get results of retrieval process
if 'Fix' in retriever.__class__.__name__:
ice_idx_list = retriever.retrieve(self.fix_id_list)
else:
ice_idx_list = retriever.retrieve()
ice_idx_list = retriever.retrieve()

# 3. Generate prompts for testing input
prompt_list = self.get_generation_prompt_list_from_retriever_indices(
Expand Down
Loading

0 comments on commit 119bfd1

Please sign in to comment.