From a7e8a7f45174878bccc1c7e64c934452e945620d Mon Sep 17 00:00:00 2001
From: nabenabe0928
Date: Fri, 19 Mar 2021 08:00:54 +0900
Subject: [PATCH] [refactor] Remove get_cross_validators and get_holdout_validators

Since each split function can now be called directly from CrossValTypes
and HoldoutValTypes, these two helper functions are removed.
---
 autoPyTorch/datasets/base_dataset.py        | 158 ++++-------
 autoPyTorch/datasets/resampling_strategy.py | 296 +++++++++++++-------
 autoPyTorch/datasets/time_series_dataset.py |  70 +++--
 3 files changed, 298 insertions(+), 226 deletions(-)

diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 4c19fa17d..13505e375 100644
--- a/autoPyTorch/datasets/base_dataset.py
+++ b/autoPyTorch/datasets/base_dataset.py
@@ -1,5 +1,5 @@
 from abc import ABCMeta
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 
@@ -13,14 +13,9 @@
 from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
 from autoPyTorch.datasets.resampling_strategy import (
-    CROSS_VAL_FN,
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
-    HOLDOUT_FN,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators,
-    is_stratified,
+    HoldoutValTypes
 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
 
@@ -112,8 +107,6 @@ def __init__(
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
         self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
@@ -133,9 +126,6 @@ def __init__(
         # TODO: Look for a criteria to define small enough to preprocess
         self.is_small_preprocess = True
 
-        # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(*CrossValTypes)
-        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()
 
         # We also need to be able to transform the data, be it for pre-processing
@@ -203,6 +193,48 @@ def __len__(self) -> int:
     def _get_indices(self) -> np.ndarray:
         return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
 
+    def _process_resampling_strategy_args(self) -> None:
+        """TODO: Refactor this function after introducing BaseDict"""
+
+        if not any(isinstance(self.resampling_strategy, val_type)
+                   for val_type in [HoldoutValTypes, CrossValTypes]):
+            raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
+
+        if self.resampling_strategy_args is not None and \
+                not isinstance(self.resampling_strategy_args, dict):
+            raise TypeError("resampling_strategy_args must be dict or None,"
+                            f" but got {type(self.resampling_strategy_args)}")
+
+        if self.resampling_strategy_args is None:
+            self.resampling_strategy_args = {}
+
+        # Fill in the defaults without overwriting user-provided values
+        if isinstance(self.resampling_strategy, HoldoutValTypes):
+            default_val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
+                'val_share', None)
+            self.resampling_strategy_args.setdefault('val_share', default_val_share)
+        elif isinstance(self.resampling_strategy, CrossValTypes):
+            default_num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
+                'num_splits', None)
+            self.resampling_strategy_args.setdefault('num_splits', default_num_splits)
+
+        # TODO: Do we need this raise? If we keep it, we should share it with cross validation as well.
+        if isinstance(self.resampling_strategy, HoldoutValTypes) and self.val_tensors is not None:
+            raise ValueError('`val_share` specified, but the Dataset was'
+                             ' given a pre-defined split at initialization already.')
+
+        val_share = self.resampling_strategy_args.get('val_share', None)
+        num_splits = self.resampling_strategy_args.get('num_splits', None)
+
+        if val_share is not None and (val_share < 0 or val_share > 1):
+            raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
+
+        if num_splits is not None:
+            if not isinstance(num_splits, int):
+                raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
+            elif num_splits <= 0:
+                raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")
+
     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
         Creates a set of splits based on a resampling strategy provided
@@ -210,99 +242,33 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]
         Returns
             (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
         """
-        splits = []
-        if isinstance(self.resampling_strategy, HoldoutValTypes):
-            val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                'val_share', None)
-            if self.resampling_strategy_args is not None:
-                val_share = self.resampling_strategy_args.get('val_share', val_share)
-            splits.append(
-                self.create_holdout_val_split(
-                    holdout_val_type=self.resampling_strategy,
-                    val_share=val_share,
-                )
-            )
-        elif isinstance(self.resampling_strategy, CrossValTypes):
-            num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                'num_splits', None)
-            if self.resampling_strategy_args is not None:
-                num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
-            # Create the split if it was not created before
-            splits.extend(
-                self.create_cross_val_splits(
-                    cross_val_type=self.resampling_strategy,
-                    num_splits=cast(int, num_splits),
-                )
-            )
-        else:
-            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
-        return splits
+        # check if the requirements are met and if we can get splits
+        self._process_resampling_strategy_args()
 
-    def create_cross_val_splits(
-        self,
-        cross_val_type: CrossValTypes,
-        num_splits: int
-    ) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
-        """
-        This function creates the cross validation split for the given task.
-
-        It is done once per dataset to have comparable results among pipelines
-        Args:
-            cross_val_type (CrossValTypes):
-            num_splits (int): number of splits to be created
-
-        Returns:
-            (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
-                list containing 'num_splits' splits.
-        """
-        # Create just the split once
-        # This is gonna be called multiple times, because the current dataset
-        # is being used for multiple pipelines. That is, to be efficient with memory
-        # we dump the dataset to memory and read it on a need basis. So this function
-        # should be robust against multiple calls, and it does so by remembering the splits
-        if not isinstance(cross_val_type, CrossValTypes):
-            raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
         kwargs = {}
-        if is_stratified(cross_val_type):
+        if self.resampling_strategy.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
-        splits = self.cross_validators[cross_val_type.name](
-            num_splits, self._get_indices(), **kwargs)
-        return splits
 
-    def create_holdout_val_split(
-        self,
-        holdout_val_type: HoldoutValTypes,
-        val_share: float,
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        """
-        This function creates the holdout split for the given task.
-
-        It is done once per dataset to have comparable results among pipelines
-        Args:
-            holdout_val_type (HoldoutValTypes):
-            val_share (float): share of the validation data
-
-        Returns:
-            (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
-        """
-        if holdout_val_type is None:
-            raise ValueError(
-                '`val_share` specified, but `holdout_val_type` not specified.'
+        if isinstance(self.resampling_strategy, HoldoutValTypes):
+            val_share = self.resampling_strategy_args['val_share']
+
+            return self.resampling_strategy(
+                val_share=val_share,
+                indices=self._get_indices(),
+                **kwargs
             )
-        if self.val_tensors is not None:
-            raise ValueError(
-                '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
-        if val_share < 0 or val_share > 1:
-            raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
-        if not isinstance(holdout_val_type, HoldoutValTypes):
-            raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
-        kwargs = {}
-        if is_stratified(holdout_val_type):
-            # we need additional information about the data for stratification
-            kwargs["stratify"] = self.train_tensors[-1]
-        train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
-        return train, val
+        elif isinstance(self.resampling_strategy, CrossValTypes):
+            num_splits = self.resampling_strategy_args['num_splits']
+
+            return self.resampling_strategy(
+                num_splits=int(num_splits),
+                indices=self._get_indices(),
+                **kwargs
+            )
+        else:
+            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
 
     def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         """
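For context, a minimal usage sketch of the refactored flow (illustration only, not part of the patch; the data below is made up):

    # After this patch, get_splits_from_resampling_strategy() boils down to
    # calling the enum member itself once the defaults have been filled in.
    import numpy as np
    from autoPyTorch.datasets.resampling_strategy import CrossValTypes

    indices = np.arange(100)  # stand-in for the dataset indices
    splits = CrossValTypes.k_fold_cross_validation(num_splits=5, indices=indices, stratify=None)
    assert len(splits) == 5  # one (train_indices, val_indices) pair per fold
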
diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py
index b853fac0a..f47d47b82 100644
--- a/autoPyTorch/datasets/resampling_strategy.py
+++ b/autoPyTorch/datasets/resampling_strategy.py
@@ -1,4 +1,5 @@
-from enum import IntEnum
+from enum import Enum
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -12,39 +13,204 @@
     train_test_split
 )
 
-from typing_extensions import Protocol
+
+class HoldoutValFuncs:
+    """Holdout validation split functions; the return type follows the cross validation functions."""
+
+    @staticmethod
+    def holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
+        return [(train, val)]
+
+    @staticmethod
+    def stratified_holdout_validation(val_share: float, indices: np.ndarray, stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        # NOTE: sklearn requires shuffle=True whenever `stratify` is given
+        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=stratify)
+        return [(train, val)]
+
+
+class CrossValFuncs:
+    @staticmethod
+    def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        cv = ShuffleSplit(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
+
+    @staticmethod
+    def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray,
+                                                  stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        cv = StratifiedShuffleSplit(n_splits=num_splits)
+        splits = list(cv.split(indices, stratify))
+        return splits
+
+    @staticmethod
+    def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        cv = StratifiedKFold(n_splits=num_splits)
+        splits = list(cv.split(indices, stratify))
+        return splits
+
+    @staticmethod
+    def k_fold_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Standard k fold cross validation.
+
+        :param num_splits: number of cross validation splits
+        :param indices: array of indices to be split
+        :return: list of tuples of training and validation indices
+        """
+        cv = KFold(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
+
+    @staticmethod
+    def time_series_cross_validation(num_splits: int, indices: np.ndarray, stratify: Optional[Any] = None) \
+            -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Returns train and validation indices respecting the temporal ordering of the data.
+        Dummy example: [0, 1, 2, 3] with 3 folds yields
+            [0] [1]
+            [0, 1] [2]
+            [0, 1, 2] [3]
+
+        :param num_splits: number of cross validation splits
+        :param indices: array of indices to be split
+        :return: list of tuples of training and validation indices
+        """
+        cv = TimeSeriesSplit(n_splits=num_splits)
+        splits = list(cv.split(indices))
+        return splits
+
+
+class CrossValTypes(Enum):
+    """The type of cross validation
+
+    This class is used to specify the cross validation function
+    and is not supposed to be instantiated.
+
+    Examples:
+    >>> cv_type = CrossValTypes.k_fold_cross_validation
+    >>> print(cv_type.name)
+
+    k_fold_cross_validation
+
+    >>> print(cv_type.value)
+
+    functools.partial(<function CrossValFuncs.k_fold_cross_validation at 0x...>)
+
+    >>> for cross_val_type in CrossValTypes:
+    ...     print(cross_val_type.name)
+
+    stratified_k_fold_cross_validation
+    k_fold_cross_validation
+    stratified_shuffle_split_cross_validation
+    shuffle_split_cross_validation
+    time_series_cross_validation
+
+    Additionally, each member of CrossValTypes can be called directly.
+    """
+    stratified_k_fold_cross_validation = partial(
+        CrossValFuncs.stratified_k_fold_cross_validation
+    )
+    k_fold_cross_validation = partial(
+        CrossValFuncs.k_fold_cross_validation
+    )
+    stratified_shuffle_split_cross_validation = partial(
+        CrossValFuncs.stratified_shuffle_split_cross_validation
+    )
+    shuffle_split_cross_validation = partial(
+        CrossValFuncs.shuffle_split_cross_validation
+    )
+    time_series_cross_validation = partial(
+        CrossValFuncs.time_series_cross_validation
+    )
 
-# Use callback protocol as workaround, since callable with function fields count 'self' as argument
-class CROSS_VAL_FN(Protocol):
-    def __call__(self,
-                 num_splits: int,
-                 indices: np.ndarray,
-                 stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
-        ...
+    def is_stratified(self) -> bool:
+        stratified = [CrossValTypes.stratified_k_fold_cross_validation,
+                      CrossValTypes.stratified_shuffle_split_cross_validation]
+        return self in stratified
 
+    def __call__(self, num_splits: int, indices: np.ndarray, stratify: Optional[Any]
+                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Calls the split function specified by this enum member.
+        Defining __call__ keeps the call type-checked, since the raw
+        enum value is an untyped functools.partial object.
 
-class HOLDOUT_FN(Protocol):
-    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
-                 ) -> Tuple[np.ndarray, np.ndarray]:
-        ...
+        Args:
+            num_splits (int): The number of splits in cross validation
+            indices (np.ndarray): The indices of data points in a dataset
+            stratify (Optional[Any]): The labels of the corresponding data points
+
+        Returns:
+            splits (List[Tuple[np.ndarray, np.ndarray]]):
+                splits[a split identifier][0: train, 1: val][a data point identifier]
+        """
+        return self.value(num_splits=num_splits, indices=indices, stratify=stratify)
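A quick sketch of the enum defined above (illustration only, not part of the diff):

    import numpy as np
    from autoPyTorch.datasets.resampling_strategy import CrossValTypes

    cv = CrossValTypes.time_series_cross_validation
    assert not cv.is_stratified()
    # Three ordered (train, val) folds; later folds train on a longer history.
    folds = cv(num_splits=3, indices=np.arange(10), stratify=None)
    assert len(folds) == 3
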
+
+
+class HoldoutValTypes(Enum):
+    """The type of holdout validation
+
+    This class is used to specify the holdout validation function
+    and is not supposed to be instantiated.
+
+    Examples:
+    >>> holdout_type = HoldoutValTypes.holdout_validation
+    >>> print(holdout_type.name)
+
+    holdout_validation
+
+    >>> print(holdout_type.value)
+
+    functools.partial(<function HoldoutValFuncs.holdout_validation at 0x...>)
+
+    >>> for holdout_type in HoldoutValTypes:
+    ...     print(holdout_type.name)
 
-class CrossValTypes(IntEnum):
-    stratified_k_fold_cross_validation = 1
-    k_fold_cross_validation = 2
-    stratified_shuffle_split_cross_validation = 3
-    shuffle_split_cross_validation = 4
-    time_series_cross_validation = 5
+    holdout_validation
+    stratified_holdout_validation
 
+    Additionally, each member of HoldoutValTypes can be called directly.
+    """
+
+    holdout_validation = partial(
+        HoldoutValFuncs.holdout_validation
+    )
+    stratified_holdout_validation = partial(
+        HoldoutValFuncs.stratified_holdout_validation
+    )
+
+    def is_stratified(self) -> bool:
+        stratified = [HoldoutValTypes.stratified_holdout_validation]
+        return self in stratified
+
+    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
+                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
+        """
+        Calls the holdout split function specified by this enum member.
+        Defining __call__ keeps the call type-checked, since the raw
+        enum value is an untyped functools.partial object.
+
+        Args:
+            val_share (float): The ratio of the validation data to the given dataset
+            indices (np.ndarray): The indices of data points in a dataset
+            stratify (Optional[Any]): The labels of the corresponding data points
 
-class HoldoutValTypes(IntEnum):
-    holdout_validation = 6
-    stratified_holdout_validation = 7
+        Returns:
+            splits (List[Tuple[np.ndarray, np.ndarray]]):
+                splits[0][0: train, 1: val][a data point identifier]
+        """
+        return self.value(val_share=val_share, indices=indices, stratify=stratify)
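The same pattern for holdout (illustration only, not part of the diff):

    import numpy as np
    from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes

    splits = HoldoutValTypes.holdout_validation(val_share=0.33, indices=np.arange(100), stratify=None)
    train, val = splits[0]  # a single pair, wrapped in a list to mirror the cross validation return type
    assert len(val) == 33   # sklearn rounds the validation share up
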
+
+
+# TODO: deprecate soon (will rename CrossValTypes -> CrossValFunc)
 RESAMPLING_STRATEGIES = [CrossValTypes, HoldoutValTypes]
 
+# TODO: deprecate soon
 DEFAULT_RESAMPLING_PARAMETERS = {
     HoldoutValTypes.holdout_validation: {
         'val_share': 0.33,
@@ -65,89 +231,3 @@ class HoldoutValTypes(IntEnum):
         'num_splits': 3,
     },
 }  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
-
-
-def get_cross_validators(*cross_val_types: CrossValTypes) -> Dict[str, CROSS_VAL_FN]:
-    cross_validators = {}  # type: Dict[str, CROSS_VAL_FN]
-    for cross_val_type in cross_val_types:
-        cross_val_fn = globals()[cross_val_type.name]
-        cross_validators[cross_val_type.name] = cross_val_fn
-    return cross_validators
-
-
-def get_holdout_validators(*holdout_val_types: HoldoutValTypes) -> Dict[str, HOLDOUT_FN]:
-    holdout_validators = {}  # type: Dict[str, HOLDOUT_FN]
-    for holdout_val_type in holdout_val_types:
-        holdout_val_fn = globals()[holdout_val_type.name]
-        holdout_validators[holdout_val_type.name] = holdout_val_fn
-    return holdout_validators
-
-
-def is_stratified(val_type: Union[str, CrossValTypes, HoldoutValTypes]) -> bool:
-    if isinstance(val_type, str):
-        return val_type.lower().startswith("stratified")
-    else:
-        return val_type.name.lower().startswith("stratified")
-
-
-def holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) -> Tuple[np.ndarray, np.ndarray]:
-    train, val = train_test_split(indices, test_size=val_share, shuffle=False)
-    return train, val
-
-
-def stratified_holdout_validation(val_share: float, indices: np.ndarray, **kwargs: Any) \
-        -> Tuple[np.ndarray, np.ndarray]:
-    train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
-    return train, val
-
-
-def shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    cv = ShuffleSplit(n_splits=num_splits)
-    splits = list(cv.split(indices))
-    return splits
-
-
-def stratified_shuffle_split_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    cv = StratifiedShuffleSplit(n_splits=num_splits)
-    splits = list(cv.split(indices, kwargs["stratify"]))
-    return splits
-
-
-def stratified_k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    cv = StratifiedKFold(n_splits=num_splits)
-    splits = list(cv.split(indices, kwargs["stratify"]))
-    return splits
-
-
-def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]:
-    """
-    Standard k fold cross validation.
-
-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
-    """
-    cv = KFold(n_splits=num_splits)
-    splits = list(cv.split(indices))
-    return splits
-
-
-def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) \
-        -> List[Tuple[np.ndarray, np.ndarray]]:
-    """
-    Returns train and validation indices respecting the temporal ordering of the data.
-    Dummy example: [0, 1, 2, 3] with 3 folds yields
-        [0] [1]
-        [0, 1] [2]
-        [0, 1, 2] [3]
-
-    :param indices: array of indices to be split
-    :param num_splits: number of cross validation splits
-    :return: list of tuples of training and validation indices
-    """
-    cv = TimeSeriesSplit(n_splits=num_splits)
-    splits = list(cv.split(indices))
-    return splits
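How the defaults kept above are consumed by _process_resampling_strategy_args() (illustration only, not part of the patch):

    from autoPyTorch.datasets.resampling_strategy import (
        DEFAULT_RESAMPLING_PARAMETERS,
        HoldoutValTypes
    )

    # The value filled into resampling_strategy_args when the user passes None
    val_share = DEFAULT_RESAMPLING_PARAMETERS[HoldoutValTypes.holdout_validation].get('val_share', None)
    assert val_share == 0.33
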
diff --git a/autoPyTorch/datasets/time_series_dataset.py b/autoPyTorch/datasets/time_series_dataset.py
index 7b0435d19..14e089c95 100644
--- a/autoPyTorch/datasets/time_series_dataset.py
+++ b/autoPyTorch/datasets/time_series_dataset.py
@@ -7,9 +7,7 @@
 from autoPyTorch.datasets.base_dataset import BaseDataset
 from autoPyTorch.datasets.resampling_strategy import (
     CrossValTypes,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators
+    HoldoutValTypes
 )
 
 TIME_SERIES_FORECASTING_INPUT = Tuple[np.ndarray, np.ndarray]  # currently only numpy arrays are supported
@@ -17,6 +15,33 @@
 TIME_SERIES_CLASSIFICATION_INPUT = Tuple[np.ndarray, np.ndarray]
 
 
+def _check_prohibited_resampling(task_name: str,
+                                 resampling_strategy: Union[CrossValTypes, HoldoutValTypes],
+                                 *args: Union[CrossValTypes, HoldoutValTypes]) -> None:
+    """Check whether the given resampling strategy is suitable for the task
+
+    Args:
+        task_name (str): Typically the Dataset class name
+        resampling_strategy (Union[CrossValTypes, HoldoutValTypes]):
+            The splitting function
+        args (Union[CrossValTypes, HoldoutValTypes]):
+            The cross validation and holdout validation functions
+            that are suitable for the given task
+
+    Returns:
+        None
+    """
+    if resampling_strategy not in args:
+        raise ValueError(f'The resampling strategy for {task_name} must be '
+                         f'chosen from {args}, but got {resampling_strategy}.')
+
+
 class TimeSeriesForecastingDataset(BaseDataset):
     def __init__(self,
                  target_variables: Tuple[int],
@@ -60,8 +85,11 @@ def __init__(self,
             train_transforms=train_transforms,
             val_transforms=val_transforms,
         )
-        self.cross_validators = get_cross_validators(CrossValTypes.time_series_cross_validation)
-        self.holdout_validators = get_holdout_validators(HoldoutValTypes.holdout_validation)
+
+        task_name = self.__class__.__name__
+        _check_prohibited_resampling(task_name, resampling_strategy,
+                                     CrossValTypes.time_series_cross_validation,
+                                     HoldoutValTypes.holdout_validation)
 
 
 def _check_time_series_forecasting_inputs(target_variables: Tuple[int],
@@ -117,16 +145,15 @@ def __init__(self,
                          val=val,
                          task_type="time_series_classification")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = get_cross_validators(
-            CrossValTypes.stratified_k_fold_cross_validation,
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation,
-            CrossValTypes.stratified_shuffle_split_cross_validation
-        )
-        self.holdout_validators = get_holdout_validators(
-            HoldoutValTypes.holdout_validation,
-            HoldoutValTypes.stratified_holdout_validation
-        )
+
+        task_name = self.__class__.__name__
+        _check_prohibited_resampling(task_name, self.resampling_strategy,
+                                     CrossValTypes.stratified_k_fold_cross_validation,
+                                     CrossValTypes.k_fold_cross_validation,
+                                     CrossValTypes.shuffle_split_cross_validation,
+                                     CrossValTypes.stratified_shuffle_split_cross_validation,
+                                     HoldoutValTypes.holdout_validation,
+                                     HoldoutValTypes.stratified_holdout_validation)
 
 
 class TimeSeriesRegressionDataset(BaseDataset):
@@ -135,13 +162,12 @@ def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.
                          val=val,
                          task_type="time_series_regression")
         super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
-        self.cross_validators = get_cross_validators(
-            CrossValTypes.k_fold_cross_validation,
-            CrossValTypes.shuffle_split_cross_validation
-        )
-        self.holdout_validators = get_holdout_validators(
-            HoldoutValTypes.holdout_validation
-        )
+
+        task_name = self.__class__.__name__
+        _check_prohibited_resampling(task_name, self.resampling_strategy,
+                                     CrossValTypes.k_fold_cross_validation,
+                                     CrossValTypes.shuffle_split_cross_validation,
+                                     HoldoutValTypes.holdout_validation)
 
 
 def _check_time_series_inputs(task_type: str,