[Refactor] Make CrossValTypes and HoldoutValTypes provide split functions directly #108

Open · wants to merge 12 commits into master
164 changes: 54 additions & 110 deletions autoPyTorch/datasets/base_dataset.py
@@ -1,7 +1,7 @@
import os
import uuid
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

@@ -14,15 +14,7 @@
import torchvision

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CrossValFunc,
CrossValFuncs,
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
HoldOutFunc,
HoldOutFuncs,
HoldoutValTypes
)
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
from autoPyTorch.utils.common import FitRequirement

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
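The heart of this refactor is visible in the import change above: the `CrossValFuncs`/`HoldOutFuncs` registries and their function type aliases disappear, and the `CrossValTypes`/`HoldoutValTypes` enums are expected to produce splits themselves. A minimal sketch of a callable holdout enum, assuming a scikit-learn backend; the member names and the `__call__` signature follow the calls later in this diff, but the body is illustrative, not the PR's actual implementation:

```python
from enum import Enum
from typing import Any, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import train_test_split


class HoldoutValTypes(Enum):
    """Sketch: each member is callable and produces its own splits."""
    holdout_validation = "holdout_validation"
    stratified_holdout_validation = "stratified_holdout_validation"

    def __call__(self, val_share: Optional[float], indices: np.ndarray,
                 random_state: np.random.RandomState, shuffle: bool = False,
                 labels_to_stratify: Optional[np.ndarray] = None,
                 **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]:
        # Fall back to a default share when none is configured (value assumed).
        val_share = 0.33 if val_share is None else val_share
        if labels_to_stratify is not None:
            shuffle = True  # sklearn requires shuffling for stratified splits
        train, val = train_test_split(
            indices, test_size=val_share, shuffle=shuffle,
            stratify=labels_to_stratify,
            random_state=random_state if shuffle else None,
        )
        return [(train, val)]
```

Making the member itself callable removes the name-keyed registry dicts and with them the risk of an enum member that has no registered split function.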
@@ -97,10 +89,9 @@ def __init__(
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
(default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
in ```datasets/resampling_strategy.py```.
resampling_strategy_args (Optional[Dict[str, Any]]):
arguments required for the chosen resampling strategy.
The details are provided in autoPyTorch/datasets/resampling_strategy.py
shuffle: Whether to shuffle the data before performing splits
seed (int), (default=1): seed to be used for reproducibility.
train_transforms (Optional[torchvision.transforms.Compose]):
@@ -116,12 +107,17 @@ def __init__(
if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.random_state = np.random.RandomState(seed=seed)
self.shuffle = shuffle

self.resampling_strategy = resampling_strategy
self.resampling_strategy_args = resampling_strategy_args
self.resampling_strategy_args: Dict[str, Any] = {}
if resampling_strategy_args is not None:
self.resampling_strategy_args = resampling_strategy_args

self.shuffle_split = self.resampling_strategy_args.get('shuffle', False)
self.is_stratify = self.resampling_strategy_args.get('stratify', False)

self.task_type: Optional[str] = None
self.issparse: bool = issparse(self.train_tensors[0])
self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
Expand All @@ -137,9 +133,6 @@ def __init__(
# TODO: Look for a criterion to define small enough to preprocess
self.is_small_preprocess = True

# Make sure cross validation splits are created once
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
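The new `__init__` flattens `resampling_strategy_args` into plain attributes instead of threading the dict through helper methods. The keys read here and in the checks below are `val_share`, `num_splits`, `shuffle`, and `stratify`; a sketch of a typical dict (the values are illustrative):

```python
# Illustrative resampling arguments; the keys match those read in this diff.
resampling_strategy_args = {
    'val_share': 0.25,  # holdout strategies: fraction of data held out for validation
    'num_splits': 5,    # cross-validation strategies: number of folds
    'shuffle': True,    # shuffle indices before splitting
    'stratify': True,   # stratify splits by the label tensor
}

# Mirrors the attribute assignments above.
shuffle_split = resampling_strategy_args.get('shuffle', False)
is_stratify = resampling_strategy_args.get('stratify', False)
```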
@@ -207,107 +200,58 @@ def __len__(self) -> int:
def _get_indices(self) -> np.ndarray:
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))

def _check_resampling_strategy_args(self) -> None:
if not any(isinstance(self.resampling_strategy, val_type)
for val_type in [HoldoutValTypes, CrossValTypes]):
raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")

if self.resampling_strategy_args is not None and \
not isinstance(self.resampling_strategy_args, dict):

raise TypeError("resampling_strategy_args must be dict or None,"
f" but got {type(self.resampling_strategy_args)}")

val_share = self.resampling_strategy_args.get('val_share', None)
num_splits = self.resampling_strategy_args.get('num_splits', None)

if val_share is not None and (val_share < 0 or val_share > 1):
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")

if num_splits is not None:
    if not isinstance(num_splits, int):
        raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
    if num_splits <= 0:
        raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")

def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
"""
Creates a set of splits based on a resampling strategy provided

Returns
(List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
"""
splits = []
# check if the requirements are met and if we can get splits
self._check_resampling_strategy_args()

labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None
kwargs: Dict[str, Any] = {}
kwargs.update(
random_state=self.random_state,
shuffle=self.shuffle_split,
indices=self._get_indices(),
labels_to_stratify=labels_to_stratify
)

if isinstance(self.resampling_strategy, HoldoutValTypes):
val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
'val_share', None)
if self.resampling_strategy_args is not None:
val_share = self.resampling_strategy_args.get('val_share', val_share)
splits.append(
self.create_holdout_val_split(
holdout_val_type=self.resampling_strategy,
val_share=val_share,
)
)
val_share = self.resampling_strategy_args.get('val_share', None)
return self.resampling_strategy(val_share=val_share, **kwargs)

elif isinstance(self.resampling_strategy, CrossValTypes):
num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
'num_splits', None)
if self.resampling_strategy_args is not None:
num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
# Create the split if it was not created before
splits.extend(
self.create_cross_val_splits(
cross_val_type=self.resampling_strategy,
num_splits=cast(int, num_splits),
)
)
num_splits = self.resampling_strategy_args.get('num_splits', None)
return self.resampling_strategy(num_splits=num_splits, **kwargs)

else:
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
return splits
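The cross-validation branch above calls the enum the same way, just with `num_splits` instead of `val_share`. A companion sketch, assuming scikit-learn's `KFold`/`StratifiedKFold` under the hood (again illustrative rather than the PR's actual code):

```python
from enum import Enum
from typing import Any, List, Optional, Tuple

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold


class CrossValTypes(Enum):
    """Sketch: cross-validation counterpart of the callable enum."""
    k_fold_cross_validation = "k_fold_cross_validation"
    stratified_k_fold_cross_validation = "stratified_k_fold_cross_validation"

    def __call__(self, num_splits: Optional[int], indices: np.ndarray,
                 random_state: np.random.RandomState, shuffle: bool = False,
                 labels_to_stratify: Optional[np.ndarray] = None,
                 **kwargs: Any) -> List[Tuple[np.ndarray, np.ndarray]]:
        # Fall back to a default fold count when none is configured (value assumed).
        num_splits = 5 if num_splits is None else num_splits
        rs = random_state if shuffle else None  # sklearn rejects a random_state without shuffle
        if labels_to_stratify is not None:
            cv = StratifiedKFold(n_splits=num_splits, shuffle=shuffle, random_state=rs)
            return [(indices[tr], indices[va])
                    for tr, va in cv.split(indices, labels_to_stratify)]
        cv = KFold(n_splits=num_splits, shuffle=shuffle, random_state=rs)
        return [(indices[tr], indices[va]) for tr, va in cv.split(indices)]
```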

def create_cross_val_splits(
self,
cross_val_type: CrossValTypes,
num_splits: int
) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
"""
This function creates the cross validation split for the given task.

It is done once per dataset to have comparable results among pipelines
Args:
cross_val_type (CrossValTypes):
num_splits (int): number of splits to be created

Returns:
(List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
list containing 'num_splits' splits.
"""
# Create the split just once.
# This function is going to be called multiple times, because the current
# dataset is being used for multiple pipelines. That is, to be efficient with
# memory, we dump the dataset to memory and read it on an as-needed basis.
# So this function should be robust against multiple calls, and it does so
# by remembering the splits
if not isinstance(cross_val_type, CrossValTypes):
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
kwargs = {}
if cross_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
splits = self.cross_validators[cross_val_type.name](
self.random_state, num_splits, self._get_indices(), **kwargs)
return splits

def create_holdout_val_split(
self,
holdout_val_type: HoldoutValTypes,
val_share: float,
) -> Tuple[np.ndarray, np.ndarray]:
"""
This function creates the holdout split for the given task.

It is done once per dataset to have comparable results among pipelines
Args:
holdout_val_type (HoldoutValTypes):
val_share (float): share of the validation data

Returns:
(Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
"""
if holdout_val_type is None:
raise ValueError(
'`val_share` specified, but `holdout_val_type` not specified.'
)
if self.val_tensors is not None:
raise ValueError(
'`val_share` specified, but the Dataset was given a pre-defined split at initialization already.')
if val_share < 0 or val_share > 1:
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
if not isinstance(holdout_val_type, HoldoutValTypes):
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
kwargs = {}
if holdout_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
train, val = self.holdout_validators[holdout_val_type.name](
self.random_state, val_share, self._get_indices(), **kwargs)
return train, val
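Taken together, the two methods removed above reduce to direct enum calls. A hedged before/after comparison (`dataset` stands for any `BaseDataset` instance; the signatures follow the calls shown in this diff):

```python
# Before this PR: dispatch through a registry dict keyed by the enum name.
train, val = dataset.holdout_validators[HoldoutValTypes.holdout_validation.name](
    dataset.random_state, 0.33, dataset._get_indices())

# After this PR: the enum member performs the split itself.
splits = dataset.resampling_strategy(
    val_share=0.33,
    indices=dataset._get_indices(),
    random_state=dataset.random_state,
    shuffle=False,
    labels_to_stratify=None,
)
```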

def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
"""
7 changes: 3 additions & 4 deletions autoPyTorch/datasets/image_dataset.py
@@ -42,10 +42,9 @@ class ImageDataset(BaseDataset):
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
(default=HoldoutValTypes.holdout_validation):
strategy to split the training data.
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
required for the chosen resampling strategy. If None, uses
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
in ```datasets/resampling_strategy.py```.
resampling_strategy_args (Optional[Dict[str, Any]]):
arguments required for the chosen resampling strategy.
The details are provided in autoPyTorch/datasets/resampling_strategy.py
shuffle: Whether to shuffle the data before performing splits
seed (int), (default=1): seed to be used for reproducibility.
train_transforms (Optional[torchvision.transforms.Compose]):