Skip to content

Commit

Permalink
[Feature] Evaluating acc based on minimum edit distance, update SIQA (#…
Browse files Browse the repository at this point in the history
…130)

* [Feature] Support evaluating acc based on minimum edit distance, update SIQA

* update
  • Loading branch information
gaotongxiao authored Aug 1, 2023
1 parent e9b7b8a commit c00179d
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
8 changes: 3 additions & 5 deletions configs/datasets/siqa/siqa_gen_e78df3.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.openicl.icl_evaluator import EDAccEvaluator
from opencompass.datasets import siqaDataset_V2
from opencompass.utils.text_postprocessors import first_capital_postprocess

# Reader configuration for SIQA: the columns used to build the prompt, the
# column handed to the evaluator as the reference, and the split to test on.
siqa_reader_cfg = dict(
    input_columns=["context", "question", "answerA", "answerB", "answerC"],
    # "all_labels" carries both the answer candidates and the gold index,
    # which is the reference structure EDAccEvaluator works on. Only one
    # output_column may be given; a repeated keyword does not parse.
    output_column="all_labels",
    test_split="validation")

siqa_infer_cfg = dict(
Expand All @@ -27,9 +26,8 @@
)

# Evaluation configuration for SIQA.
siqa_eval_cfg = dict(
    # Accuracy is computed by mapping the model output to the answer
    # candidate with the smallest edit distance.
    evaluator=dict(type=EDAccEvaluator),
    pred_role="BOT",
    # Deliberately no pred_postprocessor: EDAccEvaluator is documented to
    # require the un-postprocessed model output, so the prediction must not
    # be reduced to a single capital letter before scoring.
)

siqa_datasets = [
Expand Down
9 changes: 9 additions & 0 deletions opencompass/datasets/siqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ def load(**kwargs):
dataset = load_dataset(**kwargs)

def preprocess(example):
    """Add edit-distance references to ``example`` and letter-code its label.

    Stores under ``all_labels`` the three lettered answer strings together
    with the gold index (numeric label minus one), then replaces ``label``
    with the character of ``' ABC'`` at the numeric label, so 1 becomes
    ``'A'``. The same (mutated) mapping is returned.
    """
    options = [
        f'A. {example["answerA"]}',
        f'B. {example["answerB"]}',
        f'C. {example["answerC"]}',
    ]
    # Parse the label once, before it is overwritten with its letter below.
    gold = int(example['label'])
    example['all_labels'] = {'candidates': options, 'label': gold - 1}
    example['label'] = ' ABC'[gold]
    return example

Expand Down
49 changes: 49 additions & 0 deletions opencompass/openicl/icl_evaluator/icl_hf_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,52 @@ def _postprocess(self, scores: dict) -> dict:
dict: postprocessed scores.
"""
return scores['f1']


@ICL_EVALUATORS.register_module()
class EDAccEvaluator(AccEvaluator):
    """Edit distance based accuracy evaluator.

    This implementation requires the un-postprocessed outputs from the model,
    and the reference list where each item is structured as:

    .. code-block:: python

        {
            'candidates': [],  # a list of informative answer candidates
            'label': 0,  # the index of the gold answer
        }

    It always matches the model's output to a valid answer with the criterion
    as the minimum editing distance.
    """

    def __init__(self) -> None:
        super().__init__()
        # Imported here rather than at module level, so rapidfuzz is only
        # needed once this evaluator is actually instantiated.
        from rapidfuzz.distance import Levenshtein
        self.dist = Levenshtein.distance

    def _preprocess(self, predictions: List, references: List) -> dict:
        """Preprocess the final predictions and references to needed format.

        Each prediction is replaced by the index of the candidate closest to
        it in edit distance, and each reference by its gold index.

        Args:
            predictions (List): List of predictions of each sample.
            references (List): List of targets for each sample, each a dict
                with the keys ``'candidates'`` and ``'label'``.

        Returns:
            dict: preprocessed results, with the candidate indices under
            ``'predictions'`` and the gold indices under ``'references'``.
        """
        # NOTE(review): the two lists are walked in lockstep and are assumed
        # to have equal length -- confirm the caller enforces this.
        preds = []
        golds = []

        for pred, ref in zip(predictions, references):
            dists = [self.dist(pred, cand) for cand in ref['candidates']]
            # argmin reports the first minimum, so ties resolve to the
            # earliest candidate; int() hands on a plain Python integer
            # instead of a NumPy scalar.
            preds.append(int(np.argmin(dists)))
            golds.append(ref['label'])

        return {
            'predictions': preds,
            'references': golds,
        }

0 comments on commit c00179d

Please sign in to comment.