[refactor] Remove get_cross_validators and get_holdout_validators
Since we can call each split function directly from CrossValTypes
and HoldoutValTypes, I removed these two functions.
nabenabe0928 committed May 10, 2021
1 parent 94df1e3 commit 36cef27
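
To illustrate the new call pattern, here is a minimal sketch of invoking a split function directly on the enum, assuming HoldoutTypes.holdout accepts the keyword arguments used in the diff below (the seed, index range, and val_share values are illustrative only):

import numpy as np

from autoPyTorch.datasets.resampling_strategy import HoldoutTypes

# Previously, the split callable was fetched via
# HoldOutFuncs.get_holdout_validators(...) and looked up by name;
# after this commit the enum member itself is the entry point.
splits = HoldoutTypes.holdout(
    random_state=np.random.RandomState(seed=42),
    val_share=0.33,           # illustrative holdout share
    shuffle=True,             # shuffling now happens inside the split function
    indices=np.arange(100),   # illustrative index array
    labels_to_stratify=None,  # pass labels here to request a stratified split
)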
Showing 3 changed files with 230 additions and 308 deletions.
163 changes: 54 additions & 109 deletions autoPyTorch/datasets/base_dataset.py
@@ -1,7 +1,7 @@
import os
import uuid
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

@@ -14,15 +14,7 @@
import torchvision

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CrossValFunc,
CrossValFuncs,
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
HoldOutFunc,
HoldOutFuncs,
HoldoutValTypes
)
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, DEFAULT_RESAMPLING_PARAMETERS, HoldoutTypes
from autoPyTorch.utils.common import FitRequirement

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
@@ -77,7 +69,7 @@ def __init__(
dataset_name: Optional[str] = None,
val_tensors: Optional[BaseDatasetInputType] = None,
test_tensors: Optional[BaseDatasetInputType] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
seed: Optional[int] = 42,
@@ -94,14 +86,14 @@ def __init__(
validation data
test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
test data
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
(default=HoldoutValTypes.holdout_validation):
resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
(default=HoldoutTypes.holdout):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
in ```datasets/resampling_strategy.py```.
shuffle: Whether to shuffle the data before performing splits
shuffle: Whether to shuffle the data when performing splits
seed (int), (default=1): seed to be used for reproducibility.
train_transforms (Optional[torchvision.transforms.Compose]):
Additional Transforms to be applied to the training data
@@ -116,12 +108,12 @@ def __init__(
if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.random_state = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
self.resampling_strategy_args = resampling_strategy_args
self.is_stratify = (self.resampling_strategy_args or {}).get('stratify', False)

self.task_type: Optional[str] = None
self.issparse: bool = issparse(self.train_tensors[0])
self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
@@ -137,9 +129,6 @@ def __init__(
# TODO: Look for a criteria to define small enough to preprocess
self.is_small_preprocess = True

# Make sure cross validation splits are created once
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
@@ -205,7 +194,30 @@ def __len__(self) -> int:
return self.train_tensors[0].shape[0]

def _get_indices(self) -> np.ndarray:
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
return np.arange(len(self))

def _process_resampling_strategy_args(self) -> None:
    if not isinstance(self.resampling_strategy, (HoldoutTypes, CrossValTypes)):
        raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")

    if self.resampling_strategy_args is not None and \
            not isinstance(self.resampling_strategy_args, dict):
        raise TypeError("resampling_strategy_args must be dict or None,"
                        f" but got {type(self.resampling_strategy_args)}")

    if self.resampling_strategy_args is None:
        # fall back to the defaults documented in datasets/resampling_strategy.py
        self.resampling_strategy_args = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy]

    val_share = self.resampling_strategy_args.get('val_share', None)
    num_splits = self.resampling_strategy_args.get('num_splits', None)

    if val_share is not None and (val_share < 0 or val_share > 1):
        raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")

    if num_splits is not None:
        if not isinstance(num_splits, int):
            raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
        if num_splits <= 0:
            raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")

def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
"""
@@ -214,100 +226,33 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
Returns
(List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
"""
splits = []
if isinstance(self.resampling_strategy, HoldoutValTypes):
val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
'val_share', None)
if self.resampling_strategy_args is not None:
val_share = self.resampling_strategy_args.get('val_share', val_share)
splits.append(
self.create_holdout_val_split(
holdout_val_type=self.resampling_strategy,
val_share=val_share,
)
# check if the requirements are met and if we can get splits
self._process_resampling_strategy_args()

labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None

if isinstance(self.resampling_strategy, HoldoutTypes):
val_share = self.resampling_strategy_args['val_share']

return self.resampling_strategy(
random_state=self.random_state,
val_share=val_share,
shuffle=self.shuffle,
indices=self._get_indices(),
labels_to_stratify=labels_to_stratify
)
elif isinstance(self.resampling_strategy, CrossValTypes):
num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
'num_splits', None)
if self.resampling_strategy_args is not None:
num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
# Create the split if it was not created before
splits.extend(
self.create_cross_val_splits(
cross_val_type=self.resampling_strategy,
num_splits=cast(int, num_splits),
)
num_splits = self.resampling_strategy_args['num_splits']

return self.create_cross_val_splits(
random_state=self.random_state,
num_splits=int(num_splits),
shuffle=self.shuffle,
indices=self._get_indices(),
labels_to_stratify=labels_to_stratify
)
else:
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
return splits

def create_cross_val_splits(
self,
cross_val_type: CrossValTypes,
num_splits: int
) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
"""
This function creates the cross validation split for the given task.
It is done once per dataset to have comparable results among pipelines
Args:
cross_val_type (CrossValTypes):
num_splits (int): number of splits to be created
Returns:
(List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
list containing 'num_splits' splits.
"""
# Create just the split once
# This is gonna be called multiple times, because the current dataset
# is being used for multiple pipelines. That is, to be efficient with memory
# we dump the dataset to memory and read it on a need basis. So this function
# should be robust against multiple calls, and it does so by remembering the splits
if not isinstance(cross_val_type, CrossValTypes):
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
kwargs = {}
if cross_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
splits = self.cross_validators[cross_val_type.name](
self.random_state, num_splits, self._get_indices(), **kwargs)
return splits

def create_holdout_val_split(
self,
holdout_val_type: HoldoutValTypes,
val_share: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""
This function creates the holdout split for the given task.
It is done once per dataset to have comparable results among pipelines
Args:
holdout_val_type (HoldoutValTypes):
val_share (float): share of the validation data
Returns:
(Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
"""
if holdout_val_type is None:
raise ValueError(
'`val_share` specified, but `holdout_val_type` not specified.'
)
if self.val_tensors is not None:
raise ValueError(
'`val_share` specified, but the Dataset was given a pre-defined split at initialization already.')
if val_share < 0 or val_share > 1:
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
if not isinstance(holdout_val_type, HoldoutValTypes):
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
kwargs = {}
if holdout_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
train, val = self.holdout_validators[holdout_val_type.name](
self.random_state, val_share, self._get_indices(), **kwargs)
return train, val

def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
"""
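The split format consumed downstream is unchanged: get_splits_from_resampling_strategy still returns splits in the [train_indices, val_indices] format. A brief usage sketch, where dataset stands in for any concrete BaseDataset subclass (the names are illustrative):

splits = dataset.get_splits_from_resampling_strategy()
for fold, (train_indices, val_indices) in enumerate(splits):
    # each entry is one train/validation split of the training data
    print(f"fold {fold}: {len(train_indices)} train, {len(val_indices)} val")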