diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 38b22f445..9c2748b8e 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -23,7 +23,7 @@ ) from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix -BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset] +BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset] def check_valid_data(data: Any) -> None: @@ -32,10 +32,9 @@ def check_valid_data(data: Any) -> None: 'The specified Data for Dataset must have both __getitem__ and __len__ attribute.') -def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None: - """To avoid unexpected behavior, we use loops over indices.""" - for i in range(len(train_tensors)): - check_valid_data(train_tensors[i]) +def type_check(train_tensors: BaseDatasetInputType, val_tensors: Optional[BaseDatasetInputType] = None) -> None: + for train_tensor in train_tensors: + check_valid_data(train_tensor) if val_tensors is not None: for i in range(len(val_tensors)): check_valid_data(val_tensors[i]) @@ -63,10 +62,10 @@ def __getitem__(self, idx: int) -> np.ndarray: class BaseDataset(Dataset, metaclass=ABCMeta): def __init__( self, - train_tensors: BaseDatasetType, + train_tensors: BaseDatasetInputType, dataset_name: Optional[str] = None, - val_tensors: Optional[BaseDatasetType] = None, - test_tensors: Optional[BaseDatasetType] = None, + val_tensors: Optional[BaseDatasetInputType] = None, + test_tensors: Optional[BaseDatasetInputType] = None, resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, resampling_strategy_args: Optional[Dict[str, Any]] = None, shuffle: Optional[bool] = True, @@ -313,7 +312,7 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]: return (TransformSubset(self, self.splits[split_id][0], train=True), TransformSubset(self, self.splits[split_id][1], 
train=False)) - def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset': + def replace_data(self, X_train: BaseDatasetInputType, X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset': """ To speed up the training of small dataset, early pre-processing of the data can be made on the fly by the pipeline. diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py index 1c8fea5fd..ee9458217 100644 --- a/autoPyTorch/datasets/resampling_strategy.py +++ b/autoPyTorch/datasets/resampling_strategy.py @@ -150,9 +150,12 @@ def k_fold_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: Any) """ Standard k fold cross validation. - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices + Args: + indices (np.ndarray): array of indices to be split + num_splits (int): number of cross validation splits + + Returns: + splits (List[Tuple[np.ndarray, np.ndarray]]): list of tuples of training and validation indices """ cv = KFold(n_splits=num_splits) splits = list(cv.split(indices)) @@ -163,14 +166,21 @@ def time_series_cross_validation(num_splits: int, indices: np.ndarray, **kwargs: -> List[Tuple[np.ndarray, np.ndarray]]: """ Returns train and validation indices respecting the temporal ordering of the data. 
- Dummy example: [0, 1, 2, 3] with 3 folds yields - [0] [1] - [0, 1] [2] - [0, 1, 2] [3] - - :param indices: array of indices to be split - :param num_splits: number of cross validation splits - :return: list of tuples of training and validation indices + + Args: + indices (np.ndarray): array of indices to be split + num_splits (int): number of cross validation splits + + Returns: + splits (List[Tuple[np.ndarray, np.ndarray]]): list of tuples of training and validation indices + + Examples: + >>> indices = np.array([0, 1, 2, 3]) + >>> time_series_cross_validation(3, indices) + [([0], [1]), + ([0, 1], [2]), + ([0, 1, 2], [3])] + """ cv = TimeSeriesSplit(n_splits=num_splits) splits = list(cv.split(indices))