diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 1bd283d7b..3060a5aed 100644
--- a/autoPyTorch/datasets/base_dataset.py
+++ b/autoPyTorch/datasets/base_dataset.py
@@ -1,7 +1,7 @@
 import os
 import uuid
 from abc import ABCMeta
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -14,15 +14,7 @@
 import torchvision
 
 from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
-from autoPyTorch.datasets.resampling_strategy import (
-    CrossValFunc,
-    CrossValFuncs,
-    CrossValTypes,
-    DEFAULT_RESAMPLING_PARAMETERS,
-    HoldOutFunc,
-    HoldOutFuncs,
-    HoldoutValTypes
-)
+from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
 from autoPyTorch.utils.common import FitRequirement
 
 BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
@@ -77,7 +69,7 @@ def __init__(
         dataset_name: Optional[str] = None,
         val_tensors: Optional[BaseDatasetInputType] = None,
         test_tensors: Optional[BaseDatasetInputType] = None,
-        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+        resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
         resampling_strategy_args: Optional[Dict[str, Any]] = None,
         shuffle: Optional[bool] = True,
         seed: Optional[int] = 42,
@@ -94,14 +86,14 @@ def __init__(
                 validation data
             test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                 test data
-            resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
-                (default=HoldoutValTypes.holdout_validation):
+            resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
+                (default=HoldoutTypes.holdout):
                 strategy to split the training data.
             resampling_strategy_args (Optional[Dict[str, Any]]): arguments
                 required for the chosen resampling strategy. If None, uses
                 the default values provided in DEFAULT_RESAMPLING_PARAMETERS
                 in ```datasets/resampling_strategy.py```.
-            shuffle: Whether to shuffle the data before performing splits
+            shuffle: Whether to shuffle the data when performing splits
             seed (int), (default=1): seed to be used for reproducibility.
             train_transforms (Optional[torchvision.transforms.Compose]):
                 Additional Transforms to be applied to the training data
@@ -116,12 +108,12 @@ def __init__(
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-        self.cross_validators: Dict[str, CrossValFunc] = {}
-        self.holdout_validators: Dict[str, HoldOutFunc] = {}
         self.random_state = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
+        self.is_stratify = (self.resampling_strategy_args or {}).get('stratify', False)
+
         self.task_type: Optional[str] = None
         self.issparse: bool = issparse(self.train_tensors[0])
         self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
@@ -137,9 +129,6 @@ def __init__(
         # TODO: Look for a criteria to define small enough to preprocess
         self.is_small_preprocess = True
 
-        # Make sure cross validation splits are created once
-        self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
-        self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
@@ -205,7 +194,30 @@ def __len__(self) -> int:
         return self.train_tensors[0].shape[0]
 
     def _get_indices(self) -> np.ndarray:
-        return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
+        return np.arange(len(self))
+
+    def _process_resampling_strategy_args(self) -> None:
+        if not any(isinstance(self.resampling_strategy, val_type)
+                   for val_type in [HoldoutTypes, CrossValTypes]):
+            raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
+
+        if self.resampling_strategy_args is not None and \
+                not isinstance(self.resampling_strategy_args, dict):
+            raise TypeError("resampling_strategy_args must be dict or None,"
+                            f" but got {type(self.resampling_strategy_args)}")
+
+        # normalise None to an empty dict so the lookups below cannot fail
+        self.resampling_strategy_args = self.resampling_strategy_args or {}
+
+        val_share = self.resampling_strategy_args.get('val_share', None)
+        num_splits = self.resampling_strategy_args.get('num_splits', None)
+
+        if val_share is not None and (val_share < 0 or val_share > 1):
+            raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
+
+        if num_splits is not None:
+            if not isinstance(num_splits, int):
+                raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
+            elif num_splits <= 0:
+                raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")
 
     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
@@ -214,100 +226,33 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
         Returns
             (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
         """
-        splits = []
-        if isinstance(self.resampling_strategy, HoldoutValTypes):
-            val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                'val_share', None)
-            if self.resampling_strategy_args is not None:
-                val_share = self.resampling_strategy_args.get('val_share', val_share)
-            splits.append(
-                self.create_holdout_val_split(
-                    holdout_val_type=self.resampling_strategy,
-                    val_share=val_share,
-                )
-            )
+        # check if the requirements are met and if we can get splits
+        self._process_resampling_strategy_args()
+
+        labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None
+
+        if isinstance(self.resampling_strategy, HoldoutTypes):
+            val_share = self.resampling_strategy_args.get('val_share', 0.33)
+
+            # holdout yields a single (train, val) pair; wrap it so that
+            # self.splits is always a list of splits
+            return [self.resampling_strategy(
+                random_state=self.random_state,
+                val_share=val_share,
+                shuffle=self.shuffle,
+                indices=self._get_indices(),
+                labels_to_stratify=labels_to_stratify
+            )]
         elif isinstance(self.resampling_strategy, CrossValTypes):
-            num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                'num_splits', None)
-            if self.resampling_strategy_args is not None:
-                num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
-            # Create the split if it was not created before
-            splits.extend(
-                self.create_cross_val_splits(
-                    cross_val_type=self.resampling_strategy,
-                    num_splits=cast(int, num_splits),
-                )
+            num_splits = self.resampling_strategy_args.get('num_splits', 5)
+
+            return self.resampling_strategy(
+                random_state=self.random_state,
+                num_splits=int(num_splits),
+                shuffle=self.shuffle,
+                indices=self._get_indices(),
+                labels_to_stratify=labels_to_stratify
            )
         else:
             raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
-        return splits
-
-    def create_cross_val_splits(
-        self,
-        cross_val_type: CrossValTypes,
-        num_splits: int
-    ) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
-        """
-        This function creates the cross validation split for the given task.
-
-        It is done once per dataset to have comparable results among pipelines
-        Args:
-            cross_val_type (CrossValTypes):
-            num_splits (int): number of splits to be created
-
-        Returns:
-            (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
-                list containing 'num_splits' splits.
-        """
-        # Create just the split once
-        # This is gonna be called multiple times, because the current dataset
-        # is being used for multiple pipelines. That is, to be efficient with memory
-        # we dump the dataset to memory and read it on a need basis. So this function
-        # should be robust against multiple calls, and it does so by remembering the splits
-        if not isinstance(cross_val_type, CrossValTypes):
-            raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
-        kwargs = {}
-        if cross_val_type.is_stratified():
-            # we need additional information about the data for stratification
-            kwargs["stratify"] = self.train_tensors[-1]
-        splits = self.cross_validators[cross_val_type.name](
-            self.random_state, num_splits, self._get_indices(), **kwargs)
-        return splits
-
-    def create_holdout_val_split(
-        self,
-        holdout_val_type: HoldoutValTypes,
-        val_share: float,
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        """
-        This function creates the holdout split for the given task.
-
-        It is done once per dataset to have comparable results among pipelines
-        Args:
-            holdout_val_type (HoldoutValTypes):
-            val_share (float): share of the validation data
-
-        Returns:
-            (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
-        """
-        if holdout_val_type is None:
-            raise ValueError(
-                '`val_share` specified, but `holdout_val_type` not specified.'
-            )
-        if self.val_tensors is not None:
-            raise ValueError(
-                '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
-        if val_share < 0 or val_share > 1:
-            raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
-        if not isinstance(holdout_val_type, HoldoutValTypes):
-            raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
-        kwargs = {}
-        if holdout_val_type.is_stratified():
-            # we need additional information about the data for stratification
-            kwargs["stratify"] = self.train_tensors[-1]
-        train, val = self.holdout_validators[holdout_val_type.name](
-            self.random_state, val_share, self._get_indices(), **kwargs)
-        return train, val
 
     def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         """
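For reference, a minimal sketch of how the refactored split machinery above is expected to be driven from the caller's side. This snippet is not part of the patch; the toy data is made up, and it assumes the patch has been applied:

import numpy as np

from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes

rng = np.random.RandomState(42)
indices = np.arange(100)

# Holdout: the enum member is callable and returns a single (train, val) pair,
# which get_splits_from_resampling_strategy wraps into a one-element list.
train_idx, val_idx = HoldoutTypes.holdout(
    random_state=rng, indices=indices, val_share=0.33, shuffle=True
)
assert len(train_idx) + len(val_idx) == len(indices)

# Cross validation: returns num_splits (train, val) pairs.
splits = CrossValTypes.k_fold(random_state=rng, indices=indices, num_splits=5)
assert len(splits) == 5
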
diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py
index a1e599dd6..efd999e7c 100644
--- a/autoPyTorch/datasets/resampling_strategy.py
+++ b/autoPyTorch/datasets/resampling_strategy.py
@@ -1,5 +1,6 @@
-from enum import IntEnum
-from typing import Any, Dict, List, Optional, Tuple, Union
+from enum import Enum
+from functools import partial
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -12,187 +13,62 @@
     train_test_split
 )
 
-from typing_extensions import Protocol
+from torch.utils.data import Dataset
 
-
-# Use callback protocol as workaround, since callable with function fields count 'self' as argument
-class CrossValFunc(Protocol):
-    def __call__(self,
-                 random_state: np.random.RandomState,
-                 num_splits: int,
-                 indices: np.ndarray,
-                 stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
-        ...
-
-
-class HoldOutFunc(Protocol):
-    def __call__(self, random_state: np.random.RandomState, val_share: float,
-                 indices: np.ndarray, stratify: Optional[Any]
-                 ) -> Tuple[np.ndarray, np.ndarray]:
-        ...
-
-
-class CrossValTypes(IntEnum):
-    """The type of cross validation
-
-    This class is used to specify the cross validation function
-    and is not supposed to be instantiated.
-
-    Examples: This class is supposed to be used as follows
-    >>> cv_type = CrossValTypes.k_fold_cross_validation
-    >>> print(cv_type.name)
-
-    k_fold_cross_validation
-
-    >>> for cross_val_type in CrossValTypes:
-            print(cross_val_type.name, cross_val_type.value)
-
-    stratified_k_fold_cross_validation 1
-    k_fold_cross_validation 2
-    stratified_shuffle_split_cross_validation 3
-    shuffle_split_cross_validation 4
-    time_series_cross_validation 5
-    """
-    stratified_k_fold_cross_validation = 1
-    k_fold_cross_validation = 2
-    stratified_shuffle_split_cross_validation = 3
-    shuffle_split_cross_validation = 4
-    time_series_cross_validation = 5
-
-    def is_stratified(self) -> bool:
-        stratified = [self.stratified_k_fold_cross_validation,
-                      self.stratified_shuffle_split_cross_validation]
-        return getattr(self, self.name) in stratified
-
-
-class HoldoutValTypes(IntEnum):
-    """TODO: change to enum using functools.partial"""
-    """The type of hold out validation (refer to CrossValTypes' doc-string)"""
-    holdout_validation = 6
-    stratified_holdout_validation = 7
-
-    def is_stratified(self) -> bool:
-        stratified = [self.stratified_holdout_validation]
-        return getattr(self, self.name) in stratified
-
-
-# TODO: replace it with another way
-RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
-
-DEFAULT_RESAMPLING_PARAMETERS = {
-    HoldoutValTypes.holdout_validation: {
-        'val_share': 0.33,
-    },
-    HoldoutValTypes.stratified_holdout_validation: {
-        'val_share': 0.33,
-    },
-    CrossValTypes.k_fold_cross_validation: {
-        'num_splits': 5,
-    },
-    CrossValTypes.stratified_k_fold_cross_validation: {
-        'num_splits': 5,
-    },
-    CrossValTypes.shuffle_split_cross_validation: {
-        'num_splits': 5,
-    },
-    CrossValTypes.time_series_cross_validation: {
-        'num_splits': 5,
-    },
-}  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
-
-
-class HoldOutFuncs():
-    @staticmethod
-    def holdout_validation(random_state: np.random.RandomState,
-                           val_share: float,
-                           indices: np.ndarray,
-                           **kwargs: Any
-                           ) -> Tuple[np.ndarray, np.ndarray]:
-        shuffle = kwargs.get('shuffle', True)
-        train, val = train_test_split(indices, test_size=val_share,
-                                      shuffle=shuffle,
-                                      random_state=random_state if shuffle else None,
-                                      )
-        return train, val
-
+class HoldoutFuncs():
     @staticmethod
-    def stratified_holdout_validation(random_state: np.random.RandomState,
-                                      val_share: float,
-                                      indices: np.ndarray,
-                                      **kwargs: Any
-                                      ) -> Tuple[np.ndarray, np.ndarray]:
-        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"],
-                                      random_state=random_state)
+    def holdout(
+        random_state: np.random.RandomState,
+        val_share: float,
+        indices: np.ndarray,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        train, val = train_test_split(
+            indices, test_size=val_share, shuffle=shuffle,
+            random_state=random_state if shuffle else None,
+            stratify=labels_to_stratify
+        )
         return train, val
 
-    @classmethod
-    def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str, HoldOutFunc]:
-
-        holdout_validators = {
-            holdout_val_type.name: getattr(cls, holdout_val_type.name)
-            for holdout_val_type in holdout_val_types
-        }
-        return holdout_validators
-
 
 class CrossValFuncs():
-    @staticmethod
-    def shuffle_split_cross_validation(random_state: np.random.RandomState,
-                                       num_splits: int,
-                                       indices: np.ndarray,
-                                       **kwargs: Any
-                                       ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
-        splits = list(cv.split(indices))
-        return splits
-
-    @staticmethod
-    def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
-                                                  num_splits: int,
-                                                  indices: np.ndarray,
-                                                  **kwargs: Any
-                                                  ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
-        splits = list(cv.split(indices, kwargs["stratify"]))
-        return splits
-
-    @staticmethod
-    def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
-                                           num_splits: int,
-                                           indices: np.ndarray,
-                                           **kwargs: Any
-                                           ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedKFold(n_splits=num_splits, random_state=random_state)
-        splits = list(cv.split(indices, kwargs["stratify"]))
-        return splits
+    # (shuffle, is_stratify) -> split_fn
+    _args2split_fn = {
+        (True, True): StratifiedShuffleSplit,
+        (True, False): ShuffleSplit,
+        (False, True): StratifiedKFold,
+        (False, False): KFold,
+    }
 
     @staticmethod
-    def k_fold_cross_validation(random_state: np.random.RandomState,
-                                num_splits: int,
-                                indices: np.ndarray,
-                                **kwargs: Any
-                                ) -> List[Tuple[np.ndarray, np.ndarray]]:
+    def k_fold(
+        random_state: np.random.RandomState,
+        num_splits: int,
+        indices: np.ndarray,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray]]:
         """
-        Standard k fold cross validation.
-
-        Args:
-            indices (np.ndarray): array of indices to be split
-            num_splits (int): number of cross validation splits
-
         Returns:
             splits (List[Tuple[List, List]]): list of tuples of training and validation indices
         """
-        shuffle = kwargs.get('shuffle', True)
-        cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle)
-        splits = list(cv.split(indices))
+        split_fn = CrossValFuncs._args2split_fn[(shuffle, labels_to_stratify is not None)]
+        # KFold and StratifiedKFold reject a random_state unless shuffling is enabled
+        cv = split_fn(n_splits=num_splits, random_state=random_state if shuffle else None)
+        # the stratified splitters require the labels as the second argument
+        splits = list(cv.split(indices, labels_to_stratify))
         return splits
 
     @staticmethod
-    def time_series_cross_validation(random_state: np.random.RandomState,
-                                     num_splits: int,
-                                     indices: np.ndarray,
-                                     **kwargs: Any
-                                     ) -> List[Tuple[np.ndarray, np.ndarray]]:
+    def time_series(
+        random_state: np.random.RandomState,
+        num_splits: int,
+        indices: np.ndarray,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray]]:
         """
         Returns train and validation indices respecting the temporal ordering of the data.
@@ -215,10 +91,115 @@ def time_series_cross_validation(random_state: np.random.RandomState,
         splits = list(cv.split(indices))
         return splits
 
-    @classmethod
-    def get_cross_validators(cls, *cross_val_types: CrossValTypes) -> Dict[str, CrossValFunc]:
-        cross_validators = {
-            cross_val_type.name: getattr(cls, cross_val_type.name)
-            for cross_val_type in cross_val_types
-        }
-        return cross_validators
+
+class CrossValTypes(Enum):
+    """The type of cross validation
+
+    This class is used to specify the cross validation function
+    and is not supposed to be instantiated.
+
+    Examples: This class is supposed to be used as follows
+    >>> cv_type = CrossValTypes.k_fold
+    >>> print(cv_type.name)
+
+    k_fold
+
+    >>> for cross_val_type in CrossValTypes:
+            print(cross_val_type.name, cross_val_type.value)
+
+    k_fold functools.partial(...)
+    time_series functools.partial(...)
+    """
+    k_fold = partial(CrossValFuncs.k_fold)
+    time_series = partial(CrossValFuncs.time_series)
+
+    def __call__(
+        self,
+        random_state: np.random.RandomState,
+        indices: np.ndarray,
+        num_splits: int = 5,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        This function allows calling and type-checking the specified function.
+
+        Args:
+            random_state (np.random.RandomState): random number generator for reproducibility
+            num_splits (int): The number of splits in cross validation
+            indices (np.ndarray): The indices of data points in a dataset
+            shuffle (bool): Whether to shuffle the indices or not
+            labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
+                The labels of the corresponding data points. They are used for stratification.
+
+        Returns:
+            splits (List[Tuple[np.ndarray, np.ndarray]]):
+                splits[a split identifier][0: train, 1: val][a data point identifier]
+        """
+        return self.value(
+            random_state=random_state,
+            num_splits=num_splits,
+            indices=indices,
+            shuffle=shuffle,
+            labels_to_stratify=labels_to_stratify
+        )
+
+
+class HoldoutTypes(Enum):
+    """The type of holdout validation
+
+    This class is used to specify the holdout validation function
+    and is not supposed to be instantiated.
+
+    Examples: This class is supposed to be used as follows
+    >>> holdout_type = HoldoutTypes.holdout
+    >>> print(holdout_type.name)
+
+    holdout
+
+    >>> print(holdout_type.value)
+
+    functools.partial(...)
+
+    >>> for holdout_type in HoldoutTypes:
+            print(holdout_type.name)
+
+    holdout
+
+    Additionally, HoldoutTypes members can be called directly.
+    """
+    holdout = partial(HoldoutFuncs.holdout)
+
+    def __call__(
+        self,
+        random_state: np.random.RandomState,
+        indices: np.ndarray,
+        val_share: float = 0.33,
+        shuffle: bool = False,
+        labels_to_stratify: Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]] = None
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        This function allows calling and type-checking the specified function.
+
+        Args:
+            random_state (np.random.RandomState): random number generator for reproducibility
+            val_share (float): The ratio of the validation dataset vs the given dataset
+            indices (np.ndarray): The indices of data points in a dataset
+            shuffle (bool): Whether to shuffle the indices or not
+            labels_to_stratify (Optional[Union[Tuple[np.ndarray, np.ndarray], Dataset]]):
+                The labels of the corresponding data points. They are used for stratification.
+
+        Returns:
+            (Tuple[np.ndarray, np.ndarray]): the (train_indices, val_indices) pair
+        """
+        return self.value(
+            random_state=random_state,
+            val_share=val_share,
+            indices=indices,
+            shuffle=shuffle,
+            labels_to_stratify=labels_to_stratify
+        )
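For reference, a hedged sketch of the stratified dispatch above: when labels_to_stratify is given, the _args2split_fn table resolves k_fold to StratifiedKFold (or StratifiedShuffleSplit when shuffle=True), so every validation fold preserves the class ratio. The toy labels are made up, and the snippet assumes the patch has been applied:

import numpy as np

from autoPyTorch.datasets.resampling_strategy import CrossValTypes

rng = np.random.RandomState(0)
indices = np.arange(20)
labels = np.array([0] * 10 + [1] * 10)  # toy binary labels

# (shuffle=False, stratify=True) -> StratifiedKFold via _args2split_fn
for train_idx, val_idx in CrossValTypes.k_fold(
    random_state=rng,
    indices=indices,
    num_splits=5,
    shuffle=False,
    labels_to_stratify=labels,
):
    # each validation fold holds two samples of each class
    assert labels[val_idx].mean() == 0.5
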
diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py
index edd07a80e..95968c2aa 100644
--- a/autoPyTorch/datasets/time_series_dataset.py
+++ b/autoPyTorch/datasets/time_series_dataset.py
@@ -5,18 +5,33 @@
 import torchvision.transforms
 
 from autoPyTorch.datasets.base_dataset import BaseDataset
-from autoPyTorch.datasets.resampling_strategy import (
-    CrossValFuncs,
-    CrossValTypes,
-    HoldOutFuncs,
-    HoldoutValTypes
-)
+from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
 
 TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray]  # currently only numpy arrays are supported
 TIME_SERIES_REGRESSION_INPUT = Tuple[np.ndarray, np.ndarray]
 TIME_SERIES_CLASSIFICATION_INPUT = Tuple[np.ndarray, np.ndarray]
 
 
+def _check_prohibited_resampling(
+    task_name: str,
+    resampling_strategy: Union[CrossValTypes, HoldoutTypes],
+    *args: Union[CrossValTypes, HoldoutTypes]
+) -> None:
+    """Check if the resampling strategy is suitable for a given task
+
+    Args:
+        task_name (str): Typically the Dataset class name
+        resampling_strategy (Union[CrossValTypes, HoldoutTypes]):
+            The splitting function
+        args (Union[CrossValTypes, HoldoutTypes]):
+            The list of cross validation functions and
+            holdout validation functions that are suitable for the given task
+
+    Returns:
+        None
+
+    TODO: Especially, reject shuffle splits
+    """
+    pass
+
+
 class TimeSeriesForecastingDataset(BaseDataset):
     def __init__(self,
                  target_variables: Tuple[int],
@@ -24,7 +39,7 @@ def __init__(self,
                  n_steps: int,
                  train: TIME_SERIES_FORECASTING_INPUT,
                  val: Optional[TIME_SERIES_FORECASTING_INPUT] = None,
-                 resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                 resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
                  resampling_strategy_args: Optional[Dict[str, Any]] = None,
                  shuffle: Optional[bool] = False,
                  seed: Optional[int] = 42,
@@ -60,8 +75,6 @@ def __init__(self,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
         )
-        self.cross_validators = CrossValFuncs.get_cross_validators(CrossValTypes.time_series_cross_validation)
-        self.holdout_validators = HoldOutFuncs.get_holdout_validators(HoldoutValTypes.holdout_validation)
 
 
 def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
@@ -117,16 +130,6 @@ def __init__(self,
                                   val=val,
                                   task_type="time_series_classification")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = CrossValFuncs.get_cross_validators(
-            CrossValTypes.stratified_k_fold_cross_validation,
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation,
-            CrossValTypes.stratified_shuffle_split_cross_validation
-        )
-        self.holdout_validators = HoldOutFuncs.get_holdout_validators(
-            HoldoutValTypes.holdout_validation,
-            HoldoutValTypes.stratified_holdout_validation
-        )
 
 
 class TimeSeriesRegressionDataset(BaseDataset):
@@ -135,13 +138,6 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
                                   val=val,
                                   task_type="time_series_regression")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = CrossValFuncs.get_cross_validators(
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation
-        )
-        self.holdout_validators = HoldOutFuncs.get_holdout_validators(
-            HoldoutValTypes.holdout_validation
-        )
 
 
 def _check_time_series_inputs(task_type: str,