Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring base dataset splitting functions #106

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import time
import typing
import unittest.mock
import uuid
import warnings
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union, cast
Expand Down Expand Up @@ -782,13 +781,15 @@ def _search(
":{}".format(self.task_type, dataset.task_type))

# Initialise information needed for the experiment
experiment_task_name = 'runSearch'
experiment_task_name: str = 'runSearch'
dataset_requirements = get_dataset_requirements(
info=self._get_required_dataset_properties(dataset))
self._dataset_requirements = dataset_requirements
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
self._stopwatch.start_task(experiment_task_name)
self.dataset_name = dataset.dataset_name
assert self.dataset_name is not None

if self._logger is None:
self._logger = self._get_logger(self.dataset_name)
self._all_supported_metrics = all_supported_metrics
Expand Down Expand Up @@ -897,7 +898,7 @@ def _search(
start_time=time.time(),
time_left_for_ensembles=time_left_for_ensembles,
backend=copy.deepcopy(self._backend),
dataset_name=dataset.dataset_name,
dataset_name=str(dataset.dataset_name),
output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type],
task_type=STRING_TO_TASK_TYPES[self.task_type],
metrics=[self._metric],
Expand All @@ -916,7 +917,7 @@ def _search(
self._stopwatch.stop_task(ensemble_task_name)

# ==> Run SMAC
smac_task_name = 'runSMAC'
smac_task_name: str = 'runSMAC'
self._stopwatch.start_task(smac_task_name)
elapsed_time = self._stopwatch.wall_elapsed(experiment_task_name)
time_left_for_smac = max(0, total_walltime_limit - elapsed_time)
Expand All @@ -928,7 +929,7 @@ def _search(

_proc_smac = AutoMLSMBO(
config_space=self.search_space,
dataset_name=dataset.dataset_name,
dataset_name=str(dataset.dataset_name),
backend=self._backend,
total_walltime_limit=total_walltime_limit,
func_eval_time_limit_secs=func_eval_time_limit_secs,
Expand Down Expand Up @@ -1035,11 +1036,11 @@ def refit(
Returns:
self
"""
if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

self.dataset_name = dataset.dataset_name

if self._logger is None:
self._logger = self._get_logger(self.dataset_name)
self._logger = self._get_logger(str(self.dataset_name))

dataset_requirements = get_dataset_requirements(
info=self._get_required_dataset_properties(dataset))
Expand Down Expand Up @@ -1105,11 +1106,10 @@ def fit(self,
Returns:
(BasePipeline): fitted pipeline
"""
if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
self.dataset_name = dataset.dataset_name

if self._logger is None:
self._logger = self._get_logger(self.dataset_name)
self._logger = self._get_logger(str(self.dataset_name))

# get dataset properties
dataset_requirements = get_dataset_requirements(
Expand Down
51 changes: 28 additions & 23 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import uuid
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

Expand All @@ -13,18 +15,17 @@

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CROSS_VAL_FN,
CrossValFunc,
CrossValFuncs,
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
HOLDOUT_FN,
HoldoutValTypes,
get_cross_validators,
get_holdout_validators,
is_stratified,
HoldOutFunc,
HoldOutFuncs,
nabenabe0928 marked this conversation as resolved.
Show resolved Hide resolved
HoldoutValTypes
)
from autoPyTorch.utils.common import FitRequirement, hash_array_or_matrix
from autoPyTorch.utils.common import FitRequirement

BaseDatasetType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]


def check_valid_data(data: Any) -> None:
Expand All @@ -33,7 +34,8 @@ def check_valid_data(data: Any) -> None:
'The specified Data for Dataset must have both __getitem__ and __len__ attribute.')


def type_check(train_tensors: BaseDatasetType, val_tensors: Optional[BaseDatasetType] = None) -> None:
def type_check(train_tensors: BaseDatasetInputType,
val_tensors: Optional[BaseDatasetInputType] = None) -> None:
"""To avoid unexpected behavior, we use loops over indices."""
for i in range(len(train_tensors)):
check_valid_data(train_tensors[i])
Expand All @@ -49,8 +51,8 @@ class TransformSubset(Subset):
we require different transformation for each data point.
This class helps to take the subset of the dataset
with either training or validation transformation.

We achieve so by adding a train flag to the pytorch subset
The TransformSubset allows to add train flags
while indexing the main dataset towards this goal.

Attributes:
dataset (BaseDataset/Dataset): Dataset to sample the subset
Expand All @@ -71,10 +73,10 @@ def __getitem__(self, idx: int) -> np.ndarray:
class BaseDataset(Dataset, metaclass=ABCMeta):
def __init__(
self,
train_tensors: BaseDatasetType,
train_tensors: BaseDatasetInputType,
dataset_name: Optional[str] = None,
val_tensors: Optional[BaseDatasetType] = None,
test_tensors: Optional[BaseDatasetType] = None,
val_tensors: Optional[BaseDatasetInputType] = None,
test_tensors: Optional[BaseDatasetInputType] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
shuffle: Optional[bool] = True,
Expand Down Expand Up @@ -106,14 +108,16 @@ def __init__(
val_transforms (Optional[torchvision.transforms.Compose]):
Additional Transforms to be applied to the validation/test data
"""
self.dataset_name = dataset_name if dataset_name is not None \
else hash_array_or_matrix(train_tensors[0])
self.dataset_name = dataset_name

if self.dataset_name is None:
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then if you think that it should not be required, then maybe do it as:

self.dataset_name = dataset_name if dataset_name is not None else str(
    uuid.uuid1(clock_seq=os.getpid())
)

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, actually I am also thinking about it this way.
Probably, it will be better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, wait. But when I think about the case where we would like to use totally new datasets which do not have any name, probably we would like to choose our own name. In this sense, it is better to get back to Optional.


if not hasattr(train_tensors[0], 'shape'):
type_check(train_tensors, val_tensors)
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
self.cross_validators: Dict[str, CROSS_VAL_FN] = {}
self.holdout_validators: Dict[str, HOLDOUT_FN] = {}
ravinkohli marked this conversation as resolved.
Show resolved Hide resolved
self.cross_validators: Dict[str, CrossValFunc] = {}
self.holdout_validators: Dict[str, HoldOutFunc] = {}
self.rng = np.random.RandomState(seed=seed)
self.shuffle = shuffle
self.resampling_strategy = resampling_strategy
Expand All @@ -134,8 +138,8 @@ def __init__(
self.is_small_preprocess = True

# Make sure cross validation splits are created once
self.cross_validators = get_cross_validators(*CrossValTypes)
self.holdout_validators = get_holdout_validators(*HoldoutValTypes)
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
self.splits = self.get_splits_from_resampling_strategy()

# We also need to be able to transform the data, be it for pre-processing
Expand Down Expand Up @@ -263,7 +267,7 @@ def create_cross_val_splits(
if not isinstance(cross_val_type, CrossValTypes):
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
kwargs = {}
if is_stratified(cross_val_type):
if cross_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
splits = self.cross_validators[cross_val_type.name](
Expand Down Expand Up @@ -298,7 +302,7 @@ def create_holdout_val_split(
if not isinstance(holdout_val_type, HoldoutValTypes):
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
kwargs = {}
if is_stratified(holdout_val_type):
if holdout_val_type.is_stratified():
# we need additional information about the data for stratification
kwargs["stratify"] = self.train_tensors[-1]
train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
Expand All @@ -321,7 +325,8 @@ def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
return (TransformSubset(self, self.splits[split_id][0], train=True),
TransformSubset(self, self.splits[split_id][1], train=False))

def replace_data(self, X_train: BaseDatasetType, X_test: Optional[BaseDatasetType]) -> 'BaseDataset':
def replace_data(self, X_train: BaseDatasetInputType,
X_test: Optional[BaseDatasetInputType]) -> 'BaseDataset':
"""
To speed up the training of small dataset, early pre-processing of the data
can be made on the fly by the pipeline.
Expand Down
Loading