Skip to content

Commit

Permalink
Add toolbench
Browse files Browse the repository at this point in the history
  • Loading branch information
liushz committed Aug 28, 2023
1 parent b2d602f commit 0bf908e
Show file tree
Hide file tree
Showing 7 changed files with 2,908 additions and 0 deletions.
4 changes: 4 additions & 0 deletions configs/datasets/ToolBench/toolbench_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
from .toolbench_gen_3131 import toolbench_datasets # noqa: F401, F403
31 changes: 31 additions & 0 deletions configs/datasets/ToolBench/toolbench_gen_3131.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import ToolBenchDataset, ToolBenchEvaluator

# Columns the reader pulls from every ToolBench record.
# NOTE(review): 'avaliable_tools' is kept exactly as spelled -- it presumably
# mirrors the key used in the ToolBench JSON files; confirm against the data
# before "correcting" it.
toolbench_reader_cfg = dict(
    input_columns=['query', 'avaliable_tools'],
    output_column='answer',
)

# Generation-based inference with no in-context examples.
# NOTE(review): no prompt template is configured here (PromptTemplate is
# imported above but unused) -- confirm whether one is required.
toolbench_infer_cfg = dict(
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer),
)

toolbench_eval_cfg = dict(
    evaluator=dict(type=ToolBenchEvaluator),
    # The name must match the one passed to
    # TEXT_POSTPROCESSORS.register_module() in
    # opencompass/datasets/toolbench.py, which is 'toolbench_dataset'.
    # The previous value 'toolbench' is not registered anywhere.
    pred_postprocessor=dict(type='toolbench_dataset'),
)

# One dataset entry is generated per split; each split is a JSON file of the
# same name under ./data/ToolBench.
splits = [
    'G1_answer_converted',
    'G2_answer_converted',
    'G3_answer_converted',
]

toolbench_datasets = []

for split in splits:
    toolbench_datasets.append(
        dict(
            abbr='toolbench_{}'.format(split),
            type=ToolBenchDataset,
            path='./data/ToolBench',
            name=split,
            reader_cfg=toolbench_reader_cfg,
            infer_cfg=toolbench_infer_cfg,
            eval_cfg=toolbench_eval_cfg,
        ))
802 changes: 802 additions & 0 deletions data/ToolBench/G1_answer_converted.json

Large diffs are not rendered by default.

721 changes: 721 additions & 0 deletions data/ToolBench/G2_answer_converted.json

Large diffs are not rendered by default.

1,309 changes: 1,309 additions & 0 deletions data/ToolBench/G3_answer_converted.json

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions opencompass/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,5 @@
from .xiezhi import XiezhiDataset, XiezhiRetriever # noqa: F401, F403
from .xlsum import * # noqa: F401, F403
from .xsum import * # noqa: F401, F403
from .msra import * # noqa: F401, F403
from .toolbench import * # noqa: F401, F403
39 changes: 39 additions & 0 deletions opencompass/datasets/toolbench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from opencompass.registry import TEXT_POSTPROCESSORS, LOAD_DATASET, ICL_EVALUATORS
from .base import BaseDataset
from opencompass.openicl.icl_evaluator import BaseEvaluator

import os
import json

from datasets import Dataset, load_dataset, DatasetDict


@TEXT_POSTPROCESSORS.register_module('toolbench_dataset')
def toolbench_dataset_postprocess(text: str) -> str:
    """Post-process a raw model prediction for the ToolBench datasets.

    Registered in ``TEXT_POSTPROCESSORS`` under the name
    ``'toolbench_dataset'``.

    Args:
        text: The raw prediction text.

    Returns:
        The prediction text, currently unchanged.
    """
    # The original body was a bare ``pass``, which returned None despite the
    # ``-> str`` annotation and therefore discarded every prediction.
    # Returning the text unchanged honours the declared signature.
    # TODO: implement ToolBench-specific answer extraction if it is needed.
    return text


@ICL_EVALUATORS.register_module()
class ToolBenchEvaluator(BaseEvaluator):
    """Evaluator registered in ``ICL_EVALUATORS`` for ToolBench.

    NOTE(review): scoring is not implemented yet. ``score`` is a placeholder
    whose body is ``pass``, so it returns ``None``.
    """

    def score(self, predictions, references):
        # TODO: implement the ToolBench metric. This stub returns None.
        pass


@LOAD_DATASET.register_module()
class ToolBenchDataset(BaseDataset):
    """Loader for a single ToolBench split stored as a JSON file."""

    @staticmethod
    def load(path: str, name: str):
        """Load ``<path>/<name>.json`` into a ``datasets.Dataset``.

        Args:
            path: Directory that contains the split files.
            name: File name of the split; ``.json`` is appended when the
                name does not already end with it.

        Returns:
            A ``Dataset`` built from the values of the top-level JSON
            object, one row per value.
        """
        # Suffix check rather than substring check, so that only a real
        # ``.json`` extension suppresses the append.
        if not name.endswith('.json'):
            name += '.json'
        json_path = os.path.join(path, name)
        # Explicit UTF-8 keeps decoding independent of the platform locale.
        with open(json_path, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        # The top-level JSON value is expected to be an object; its keys are
        # dropped and only the record values are kept.
        return Dataset.from_list(list(all_data.values()))


0 comments on commit 0bf908e

Please sign in to comment.