diff --git a/src/cnlpt/cnlp_processors.py b/src/cnlpt/cnlp_processors.py index f446898e..1918b46e 100644 --- a/src/cnlpt/cnlp_processors.py +++ b/src/cnlpt/cnlp_processors.py @@ -185,8 +185,11 @@ def __init__(self, data_dir: str, tasks: Set[str] = None, max_train_items=-1): dataset_tasks = first_split.features.keys() - set( ["text", "text_a", "text_b"] ) - active_tasks = set(tasks).intersection(dataset_tasks) - active_tasks = list(active_tasks) + if tasks is None: + active_tasks = list(dataset_tasks) + else: + active_tasks = set(tasks).intersection(dataset_tasks) + active_tasks = list(active_tasks) active_tasks.sort() self.dataset.task_output_modes = {} elif ext_check_file.endswith("json"):