This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Commit

update
Betswish committed Nov 17, 2023
1 parent 40b8d1d commit 68ab640
Showing 2 changed files with 15 additions and 10 deletions.
4 changes: 1 addition & 3 deletions src/genbench/tasks/cross_lingual_consistency/config.jsonnet
@@ -23,9 +23,7 @@

  data_source: {
    type: 'manual',
-   BMLAMA17: 'https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/BMLAMA17/',
-   BMLAMA53: 'https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/BMLAMA53/',
-   test: 'https://placeholder'
+   test: 'https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/',
  },


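Note: the two dataset-specific URLs are collapsed into a single base URL under `test`; the task code below now appends the dataset folder name ('BMLAMA17/' or 'BMLAMA53/') at load time. A minimal sketch of that composition (the `BASE` constant and `pick_file_path` helper are illustrative, not part of the repository):

```python
# Illustrative only: mirrors how task.py builds the download path from the
# single base URL now stored under data_source.test in config.jsonnet.
BASE = "https://raw.githubusercontent.com/Betswish/genbench_cbt/BMLAMA/src/genbench/tasks/cross_lingual_consistency/"


def pick_file_path(base: str, mini: bool) -> str:
    """Return the dataset folder URL: BMLAMA17 for the mini set, BMLAMA53 otherwise."""
    return base + ("BMLAMA17/" if mini else "BMLAMA53/")


assert pick_file_path(BASE, mini=True).endswith("cross_lingual_consistency/BMLAMA17/")
assert pick_file_path(BASE, mini=False).endswith("cross_lingual_consistency/BMLAMA53/")
```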
21 changes: 14 additions & 7 deletions src/genbench/tasks/cross_lingual_consistency/task.py
@@ -2,7 +2,7 @@
import os
import pickle as pkl
from collections import defaultdict
-from typing import Any, Dict, List
+from typing import Any, Callable, List, Mapping, Optional, Union, Dict

import datasets
import numpy as np
@@ -20,16 +20,23 @@
pipeline,
)

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset

from genbench import Task

from genbench.api import DatasetSplit, EvaluationResult, PreparationStrategy, TaskInterface, TaskType
from genbench.task_config import PromptBuilderConfig, TaskConfig
from genbench.utils.file import load_jsonnet
from genbench.utils.logging import get_logger
from genbench.utils.tasks import get_task_dir

class CrossLingualConsistencyTask(Task):
def _load_data_source(
self,
mini,
lang1,
lang2,
-) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]:
+):
"""
Private method to load the data source based on the type specified in the configuration.
Expand All @@ -48,10 +55,10 @@ def _load_data_source(
if self.config.data_source.type == "manual":
if mini:
# langs = ['en', 'fr', 'nl', 'es', 'ru', 'ja', 'zh', 'ko', 'vi', 'el', 'hu', 'he', 'tr', 'ca', 'ar', 'uk', 'fa']
-file_path = self.config.data_source.BMLAMA17
+file_path = self.config.data_source.test + 'BMLAMA17/'
else:
# langs = ['ca', 'az', 'en', 'ar', 'uk', 'fa', 'tr', 'it', 'el', 'ru', 'hr', 'hi', 'sv', 'sq', 'fr', 'ga', 'eu', 'de', 'nl', 'et', 'he', 'es', 'bn', 'ms', 'sr', 'hy', 'ur', 'hu', 'la', 'sl', 'cs', 'af', 'gl', 'fi', 'ro', 'ko', 'cy', 'th', 'be', 'id', 'pt', 'vi', 'ka', 'ja', 'da', 'bg', 'zh', 'pl', 'lv', 'sk', 'lt', 'ta', 'ceb']
-file_path = self.config.data_source.BMLAMA53
+file_path = self.config.data_source.test + 'BMLAMA53/'

data_files = dict()
for lang in [lang1, lang2]:
@@ -86,7 +86,7 @@ def _load_data_source(
else:
raise ValueError(f"Unsupported data source type: {self.config.data_source.type}")

-def get_datasets_raw(self, mini, lang1, lang2) -> Mapping[DatasetSplit, Dataset]:
+def get_datasets_raw(self, mini=True, lang1='en', lang2='es'):
data_source = self._load_data_source(mini=mini, lang1=lang1, lang2=lang2)

if self.config.split_file is not None:
@@ -122,10 +122,10 @@ def get_prepared_datasets(
preparation_strategy: PreparationStrategy,
shot_list: Optional[List[int]] = None,
random_seed: int = 42,
-mini=None,
+mini=True,
lang1=None,
lang2=None,
-) -> Union[Mapping[DatasetSplit, Dataset], Mapping[int, Dataset]]:
+):
if not mini:
raise ValueError("Value for 'mini=True/False' is required for this task")
if not lang1 or not lang2:
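Note: with the new defaults (`mini=True`, `lang1='en'`, `lang2='es'`), a caller only has to supply the preparation strategy. The sketch below is an assumed usage pattern; the task id string and the `PreparationStrategy.PROMPT_BASED_TESTING` member are assumptions to verify against the installed genbench version.

```python
# Assumed usage sketch for the updated signature; not taken from this commit.
from genbench import load_task
from genbench.api import PreparationStrategy

task = load_task("cross_lingual_consistency")  # assumed task id (matches the task directory name)
datasets_raw = task.get_prepared_datasets(
    preparation_strategy=PreparationStrategy.PROMPT_BASED_TESTING,  # assumed enum member
    mini=True,    # BMLAMA17 subset; set False for BMLAMA53
    lang1="en",
    lang2="es",
)
```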
