diff --git a/src/genbench/tasks/cross_lingual_consistency/config.jsonnet b/src/genbench/tasks/cross_lingual_consistency/config.jsonnet index 7bb1b15..bce14a1 100644 --- a/src/genbench/tasks/cross_lingual_consistency/config.jsonnet +++ b/src/genbench/tasks/cross_lingual_consistency/config.jsonnet @@ -23,9 +23,7 @@ data_source: { type: 'manual', - BMLAMA17: 'https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/BMLAMA17/', - BMLAMA53: 'https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/BMLAMA53/', - test: 'https://placeholder' + test: 'https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/', }, diff --git a/src/genbench/tasks/cross_lingual_consistency/task.py b/src/genbench/tasks/cross_lingual_consistency/task.py index 076b384..c2b8856 100644 --- a/src/genbench/tasks/cross_lingual_consistency/task.py +++ b/src/genbench/tasks/cross_lingual_consistency/task.py @@ -2,7 +2,7 @@ import os import pickle as pkl from collections import defaultdict -from typing import Any, Dict, List +from typing import Any, Callable, List, Mapping, Optional, Union, Dict import datasets import numpy as np @@ -20,8 +20,15 @@ pipeline, ) +from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset + from genbench import Task +from genbench.api import DatasetSplit, EvaluationResult, PreparationStrategy, TaskInterface, TaskType +from genbench.task_config import PromptBuilderConfig, TaskConfig +from genbench.utils.file import load_jsonnet +from genbench.utils.logging import get_logger +from genbench.utils.tasks import get_task_dir class CrossLingualConsistencyTask(Task): def _load_data_source( @@ -29,7 +36,7 @@ def _load_data_source( mini, lang1, lang2, - ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]: + ): """ Private method to load the data source based on the type specified in the configuration. @@ -48,10 +55,10 @@ def _load_data_source( if self.config.data_source.type == "manual": if mini: # langs = ['en', 'fr', 'nl', 'es', 'ru', 'ja', 'zh', 'ko', 'vi', 'el', 'hu', 'he', 'tr', 'ca', 'ar', 'uk', 'fa'] - file_path = self.config.data_source.BMLAMA17 + file_path = self.config.data_source.test + 'BMLAMA17/' else: # langs = ['ca', 'az', 'en', 'ar', 'uk', 'fa', 'tr', 'it', 'el', 'ru', 'hr', 'hi', 'sv', 'sq', 'fr', 'ga', 'eu', 'de', 'nl', 'et', 'he', 'es', 'bn', 'ms', 'sr', 'hy', 'ur', 'hu', 'la', 'sl', 'cs', 'af', 'gl', 'fi', 'ro', 'ko', 'cy', 'th', 'be', 'id', 'pt', 'vi', 'ka', 'ja', 'da', 'bg', 'zh', 'pl', 'lv', 'sk', 'lt', 'ta', 'ceb'] - file_path = self.config.data_source.BMLAMA53 + file_path = self.config.data_source.test + 'BMLAMA53/' data_files = dict() for lang in [lang1, lang2]: @@ -86,7 +93,7 @@ def _load_data_source( else: raise ValueError(f"Unsupported data source type: {self.config.data_source.type}") - def get_datasets_raw(self, mini, lang1, lang2) -> Mapping[DatasetSplit, Dataset]: + def get_datasets_raw(self, mini=True, lang1='en', lang2='es'): data_source = self._load_data_source(mini=mini, lang1=lang1, lang2=lang2) if self.config.split_file is not None: @@ -122,10 +129,10 @@ def get_prepared_datasets( preparation_strategy: PreparationStrategy, shot_list: Optional[List[int]] = None, random_seed: int = 42, - mini=None, + mini=True, lang1=None, lang2=None, - ) -> Union[Mapping[DatasetSplit, Dataset], Mapping[int, Dataset]]: + ): if not mini: raise ValueError("Value for 'mini=True/False' is required for this task") if not lang1 or not lang2: