[refactor] Remove get_cross_validators and get_holdout_validators
Since we can call each split function directly from CrossValTypes
and HoldoutValTypes, I removed these two functions.
nabenabe0928 committed Mar 18, 2021
1 parent ef6acf2 commit a7e8a7f
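
The refactor leans on the resampling enums themselves being callable: each member of CrossValTypes and HoldoutValTypes carries its split function, which is what lets the dataset call `self.resampling_strategy(val_share=..., indices=...)` directly in the diff below. A minimal, self-contained sketch of that callable-enum pattern, assuming a single unstratified holdout member (the member name and split logic here are illustrative, not autoPyTorch's exact code):

import functools
from enum import Enum
from typing import Tuple

import numpy as np


def holdout_validation(val_share: float, indices: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # plain holdout: the last `val_share` fraction of the indices becomes the validation set
    split = int(len(indices) * (1 - val_share))
    return indices[:split], indices[split:]


class HoldoutValTypes(Enum):
    # wrapping in functools.partial keeps Enum from treating the function as a method
    holdout_validation = functools.partial(holdout_validation)

    def __call__(self, *args, **kwargs):
        # delegate the call to the wrapped split function
        return self.value(*args, **kwargs)


train_idx, val_idx = HoldoutValTypes.holdout_validation(val_share=0.25, indices=np.arange(100))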
Showing 3 changed files with 298 additions and 226 deletions.
158 changes: 62 additions & 96 deletions autoPyTorch/datasets/base_dataset.py
@@ -1,5 +1,5 @@
 from abc import ABCMeta
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

 import numpy as np

@@ -13,14 +13,9 @@

 from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
 from autoPyTorch.datasets.resampling_strategy import (
-    CROSS_VAL_FN,
     CrossValTypes,
     DEFAULT_RESAMPLING_PARAMETERS,
-    HOLDOUT_FN,
-    HoldoutValTypes,
-    get_cross_validators,
-    get_holdout_validators,
-    is_stratified,
+    HoldoutValTypes
 )
 from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix

@@ -112,8 +107,6 @@ def __init__(
         if not hasattr(train_tensors[0], 'shape'):
             type_check(train_tensors, val_tensors)
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-        self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
-        self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
         self.rng = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
@@ -133,9 +126,6 @@ def __init__(
             # TODO: Look for a criteria to define small enough to preprocess
             self.is_small_preprocess = True

-        # Make sure cross validation splits are created once
-        self.cross_validators = get_cross_validators(*CrossValTypes)
-        self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
         self.splits = self.get_splits_from_resampling_strategy()

         # We also need to be able to transform the data, be it for pre-processing
@@ -203,106 +193,82 @@ def __len__(self) -> int:
     def _get_indices(self) -> np.ndarray:
         return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))

+    def _process_resampling_strategy_args(self) -> None:
+        """TODO: Refactor this function after introducing BaseDict"""
+
+        if not any(isinstance(self.resampling_strategy, val_type)
+                   for val_type in [HoldoutValTypes, CrossValTypes]):
+            raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
+
+        if self.resampling_strategy_args is not None and \
+                not isinstance(self.resampling_strategy_args, dict):
+
+            raise TypeError("resampling_strategy_args must be dict or None,"
+                            f" but got {type(self.resampling_strategy_args)}")
+
+        if self.resampling_strategy_args is None:
+            self.resampling_strategy_args = {}
+
+        if isinstance(self.resampling_strategy, HoldoutValTypes):
+            val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
+                'val_share', None)
+            self.resampling_strategy_args.setdefault('val_share', val_share)
+        elif isinstance(self.resampling_strategy, CrossValTypes):
+            num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
+                'num_splits', None)
+            self.resampling_strategy_args.setdefault('num_splits', num_splits)
+
"""Comment: Do we need this raise Error?"""
if self.val_tensors is not None: # if we need it, we should share it with cross val as well
raise ValueError('`val_share` specified, but the Dataset was'
' a given a pre-defined split at initialization already.')

val_share = self.resampling_strategy_args.get('val_share', None)
num_splits = self.resampling_strategy_args.get('num_splits', None)

if val_share is not None and (val_share < 0 or val_share > 1):
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")

if num_splits is not None:
if num_splits <= 0:
raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")
elif not isinstance(num_splits, int):
raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")

     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
         Creates a set of splits based on a resampling strategy provided
         Returns
             (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
         """
-        splits = []
-        if isinstance(self.resampling_strategy, HoldoutValTypes):
-            val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                'val_share', None)
-            if self.resampling_strategy_args is not None:
-                val_share = self.resampling_strategy_args.get('val_share', val_share)
-            splits.append(
-                self.create_holdout_val_split(
-                    holdout_val_type=self.resampling_strategy,
-                    val_share=val_share,
-                )
-            )
-        elif isinstance(self.resampling_strategy, CrossValTypes):
-            num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                'num_splits', None)
-            if self.resampling_strategy_args is not None:
-                num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
-            # Create the split if it was not created before
-            splits.extend(
-                self.create_cross_val_splits(
-                    cross_val_type=self.resampling_strategy,
-                    num_splits=cast(int, num_splits),
-                )
-            )
-        else:
-            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
-        return splits

-    def create_cross_val_splits(
-        self,
-        cross_val_type: CrossValTypes,
-        num_splits: int
-    ) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
-        """
-        This function creates the cross validation split for the given task.
+        # check if the requirements are met and if we can get splits
+        self._process_resampling_strategy_args()

-        It is done once per dataset to have comparable results among pipelines
-        Args:
-            cross_val_type (CrossValTypes):
-            num_splits (int): number of splits to be created
-        Returns:
-            (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
-                list containing 'num_splits' splits.
-        """
-        # Create just the split once
-        # This is gonna be called multiple times, because the current dataset
-        # is being used for multiple pipelines. That is, to be efficient with memory
-        # we dump the dataset to memory and read it on a need basis. So this function
-        # should be robust against multiple calls, and it does so by remembering the splits
-        if not isinstance(cross_val_type, CrossValTypes):
-            raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
         kwargs = {}
-        if is_stratified(cross_val_type):
+        if self.resampling_strategy.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
-        splits = self.cross_validators[cross_val_type.name](
-            num_splits, self._get_indices(), **kwargs)
-        return splits

-    def create_holdout_val_split(
-        self,
-        holdout_val_type: HoldoutValTypes,
-        val_share: float,
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        """
-        This function creates the holdout split for the given task.
+        if isinstance(self.resampling_strategy, HoldoutValTypes):
+            val_share = self.resampling_strategy_args['val_share']

-        It is done once per dataset to have comparable results among pipelines
-        Args:
-            holdout_val_type (HoldoutValTypes):
-            val_share (float): share of the validation data
+            return self.resampling_strategy(
+                val_share=val_share,
+                indices=self._get_indices(),
+                **kwargs
+            )
+        elif isinstance(self.resampling_strategy, CrossValTypes):
+            num_splits = self.resampling_strategy_args['num_splits']

-        Returns:
-            (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
-        """
-        if holdout_val_type is None:
-            raise ValueError(
-                '`val_share` specified, but `holdout_val_type` not specified.'
+            return self.create_cross_val_splits(
+                num_splits=int(num_splits),
+                indices=self._get_indices(),
+                **kwargs
             )
-        if self.val_tensors is not None:
-            raise ValueError(
-                '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
-        if val_share < 0 or val_share > 1:
-            raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
-        if not isinstance(holdout_val_type, HoldoutValTypes):
-            raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
-        kwargs = {}
-        if is_stratified(holdout_val_type):
-            # we need additional information about the data for stratification
-            kwargs["stratify"] = self.train_tensors[-1]
-        train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
-        return train, val
+        else:
+            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")

     def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
         """
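
With the helper dicts gone, callers ask the strategy member for splits directly instead of looking a validator up by name. A hypothetical usage sketch against this commit (the member names `holdout_validation` and `k_fold_cross_validation` come from autoPyTorch's resampling_strategy module; the keyword signatures are inferred from the diff above):

import numpy as np

from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes

indices = np.arange(100)

# holdout: one (train_indices, val_indices) pair
train_idx, val_idx = HoldoutValTypes.holdout_validation(val_share=0.25, indices=indices)

# k-fold cross-validation: a list of num_splits (train_indices, val_indices) pairs
folds = CrossValTypes.k_fold_cross_validation(num_splits=5, indices=indices)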