[Fix] Refactor development reproducibility (#172)
* [Fix] pass random state to randomized algorithms

* [Fix] double instantiation of random state

* [fix] Flaky for sample configuration

* [FIX] Runtime warning

* [FIX] hardcoded budget

* [FIX] flake

* [Fix] try forked

* [Fix] try forked

* [FIX] budget

* [Fix] missing random_state in trainer

* [Fix] overwrite in random_state

* [FIX] fix seed in splits

* [Rebase]

* [FIX] Update cv score after split num change

* [FIX] CV split
franchuterivera authored May 3, 2021
1 parent fae72a4 commit 9f4b855
Showing 26 changed files with 225 additions and 127 deletions.
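Taken together, these changes thread a single seeded np.random.RandomState from the dataset and pipeline down to every randomized component, instead of letting each piece sample from the global RNG or build its own generator. A minimal sketch of the pattern (class and attribute names are illustrative, not autoPyTorch's actual API):

import numpy as np

class Dataset:
    def __init__(self, seed: int = 1):
        # one generator, created once, shared with everything downstream
        self.random_state = np.random.RandomState(seed=seed)

    def shuffled_indices(self, n: int) -> np.ndarray:
        return self.random_state.permutation(n)

class Component:
    def __init__(self, random_state: np.random.RandomState):
        self.random_state = random_state  # reuse; never instantiate a second one

dataset = Dataset(seed=42)
component = Component(random_state=dataset.random_state)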
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
@@ -29,7 +29,7 @@ jobs:
     - name: Run tests
       run: |
         if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
-        python -m pytest --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
+        python -m pytest --forked --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
     - name: Check for files left behind by test
       if: ${{ always() }}
       run: |
3 changes: 2 additions & 1 deletion autoPyTorch/api/base_task.py
@@ -1235,7 +1235,8 @@ def __del__(self) -> None:
         # When a multiprocessing work is done, the
         # objects are deleted. We don't want to delete run areas
         # until the estimator is deleted
-        self._backend.context.delete_directories(force=False)
+        if hasattr(self, '_backend'):
+            self._backend.context.delete_directories(force=False)

     @typing.no_type_check
     def get_incumbent_results(
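The hasattr guard matters because Python will still call __del__ on an object whose __init__ raised before _backend was ever assigned. A small, self-contained illustration of that failure mode (not autoPyTorch code):

class Estimator:
    def __init__(self, fail: bool = False):
        if fail:
            raise RuntimeError("failed before _backend was assigned")
        self._backend = object()

    def __del__(self):
        # __del__ also runs for partially initialised objects,
        # so guard the attribute lookup before cleaning up
        if hasattr(self, '_backend'):
            print("cleaning up backend")

try:
    Estimator(fail=True)  # __del__ fires later without an AttributeError
except RuntimeError:
    pass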
9 changes: 5 additions & 4 deletions autoPyTorch/datasets/base_dataset.py
@@ -118,7 +118,7 @@ def __init__(
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
         self.cross_validators: Dict[str, CrossValFunc] = {}
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
-        self.rng = np.random.RandomState(seed=seed)
+        self.random_state = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
@@ -205,7 +205,7 @@ def __len__(self) -> int:
         return self.train_tensors[0].shape[0]

     def _get_indices(self) -> np.ndarray:
-        return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
+        return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))

     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
@@ -271,7 +271,7 @@ def create_cross_val_splits(
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         splits = self.cross_validators[cross_val_type.name](
-            num_splits, self._get_indices(), **kwargs)
+            self.random_state, num_splits, self._get_indices(), **kwargs)
         return splits

     def create_holdout_val_split(
@@ -305,7 +305,8 @@ def create_holdout_val_split(
         if holdout_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
-        train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
+        train, val = self.holdout_validators[holdout_val_type.name](
+            self.random_state, val_share, self._get_indices(), **kwargs)
         return train, val

     def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
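Renaming rng to random_state matches the sklearn-style naming used in the rest of the commit, and the seeded generator makes _get_indices reproducible. In plain numpy, independent of autoPyTorch:

import numpy as np

a = np.random.RandomState(seed=42).permutation(10)
b = np.random.RandomState(seed=42).permutation(10)
assert (a == b).all()  # the same seed always yields the same shuffled indices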
53 changes: 34 additions & 19 deletions autoPyTorch/datasets/resampling_strategy.py
@@ -18,14 +18,16 @@
 # Use callback protocol as workaround, since callable with function fields count 'self' as argument
 class CrossValFunc(Protocol):
     def __call__(self,
+                 random_state: np.random.RandomState,
                  num_splits: int,
                  indices: np.ndarray,
                  stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
         ...


 class HoldOutFunc(Protocol):
-    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
+    def __call__(self, random_state: np.random.RandomState, val_share: float,
+                 indices: np.ndarray, stratify: Optional[Any]
                  ) -> Tuple[np.ndarray, np.ndarray]:
         ...
@@ -85,35 +87,42 @@ def is_stratified(self) -> bool:
         'val_share': 0.33,
     },
     CrossValTypes.k_fold_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
     CrossValTypes.stratified_k_fold_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
     CrossValTypes.shuffle_split_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
     CrossValTypes.time_series_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
 }  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]


 class HoldOutFuncs():
     @staticmethod
-    def holdout_validation(val_share: float,
+    def holdout_validation(random_state: np.random.RandomState,
+                           val_share: float,
                            indices: np.ndarray,
                            **kwargs: Any
                            ) -> Tuple[np.ndarray, np.ndarray]:
-        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
+        shuffle = kwargs.get('shuffle', True)
+        train, val = train_test_split(indices, test_size=val_share,
+                                      shuffle=shuffle,
+                                      random_state=random_state if shuffle else None,
+                                      )
         return train, val

     @staticmethod
-    def stratified_holdout_validation(val_share: float,
+    def stratified_holdout_validation(random_state: np.random.RandomState,
+                                      val_share: float,
                                       indices: np.ndarray,
                                       **kwargs: Any
                                       ) -> Tuple[np.ndarray, np.ndarray]:
-        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
+        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"],
+                                      random_state=random_state)
         return train, val

     @classmethod
@@ -128,34 +137,38 @@ def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str

 class CrossValFuncs():
     @staticmethod
-    def shuffle_split_cross_validation(num_splits: int,
+    def shuffle_split_cross_validation(random_state: np.random.RandomState,
+                                       num_splits: int,
                                        indices: np.ndarray,
                                        **kwargs: Any
                                        ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = ShuffleSplit(n_splits=num_splits)
+        cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
         splits = list(cv.split(indices))
         return splits

     @staticmethod
-    def stratified_shuffle_split_cross_validation(num_splits: int,
+    def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
+                                                  num_splits: int,
                                                   indices: np.ndarray,
                                                   **kwargs: Any
                                                   ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedShuffleSplit(n_splits=num_splits)
+        cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
         splits = list(cv.split(indices, kwargs["stratify"]))
         return splits

     @staticmethod
-    def stratified_k_fold_cross_validation(num_splits: int,
+    def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
+                                           num_splits: int,
                                            indices: np.ndarray,
                                            **kwargs: Any
                                            ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedKFold(n_splits=num_splits)
+        cv = StratifiedKFold(n_splits=num_splits, random_state=random_state)
         splits = list(cv.split(indices, kwargs["stratify"]))
         return splits

     @staticmethod
-    def k_fold_cross_validation(num_splits: int,
+    def k_fold_cross_validation(random_state: np.random.RandomState,
+                                num_splits: int,
                                 indices: np.ndarray,
                                 **kwargs: Any
                                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -169,12 +182,14 @@ def k_fold_cross_validation(num_splits: int,
         Returns:
             splits (List[Tuple[List, List]]): list of tuples of training and validation indices
         """
-        cv = KFold(n_splits=num_splits)
+        shuffle = kwargs.get('shuffle', True)
+        cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle)
         splits = list(cv.split(indices))
         return splits

     @staticmethod
-    def time_series_cross_validation(num_splits: int,
+    def time_series_cross_validation(random_state: np.random.RandomState,
+                                     num_splits: int,
                                      indices: np.ndarray,
                                      **kwargs: Any
                                      ) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -196,7 +211,7 @@ def time_series_cross_validation(num_splits: int,
             ([0, 1, 2], [3])]

-        cv = TimeSeriesSplit(n_splits=num_splits)
+        cv = TimeSeriesSplit(n_splits=num_splits, random_state=random_state)
         splits = list(cv.split(indices))
         return splits
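All splitters above now take the dataset's random_state as their first argument, so the folds are fixed by the seed rather than by the global RNG. The effect is easiest to check with an integer seed (a shared RandomState instance instead draws from one ongoing stream, which is what ties a whole run together):

import numpy as np
from sklearn.model_selection import KFold

indices = np.arange(20)
first = list(KFold(n_splits=5, shuffle=True, random_state=0).split(indices))
second = list(KFold(n_splits=5, shuffle=True, random_state=0).split(indices))
# identical folds on every invocation
assert all(np.array_equal(tr1, tr2) and np.array_equal(va1, va2)
           for (tr1, va1), (tr2, va2) in zip(first, second))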
3 changes: 2 additions & 1 deletion autoPyTorch/evaluation/abstract_evaluator.py
@@ -80,7 +80,8 @@ def __init__(self, config: str,
         self.random_state = random_state
         self.init_params = init_params
         self.pipeline = autoPyTorch.pipeline.traditional_tabular_classification.\
-            TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties)
+            TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties,
+                                                     random_state=self.random_state)
         configuration_space = self.pipeline.get_hyperparameter_search_space()
         default_configuration = configuration_space.get_default_configuration().get_dictionary()
         default_configuration['model_trainer:tabular_classifier:classifier'] = config
11 changes: 9 additions & 2 deletions autoPyTorch/evaluation/tae.py
@@ -4,6 +4,7 @@
 import logging
 import math
 import multiprocessing
+import os
 import time
 import traceback
 import typing
@@ -25,6 +26,7 @@
 from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.common import replace_string_bool_to_bool
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
@@ -144,7 +146,12 @@ def __init__(
         self.exclude = exclude
         self.disable_file_output = disable_file_output
         self.init_params = init_params
-        self.pipeline_config = pipeline_config
+        self.pipeline_config: typing.Dict[str, typing.Union[int, str, float]] = dict()
+        if pipeline_config is None:
+            pipeline_config = replace_string_bool_to_bool(json.load(open(
+                os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
+        self.pipeline_config.update(pipeline_config)
+
         self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
         self.logger_port = logger_port
         if self.logger_port is None:
@@ -199,7 +206,7 @@ def run_wrapper(
             )
         else:
             if run_info.budget == 0:
-                run_info = run_info._replace(budget=100.0)
+                run_info = run_info._replace(budget=self.pipeline_config[self.budget_type])
             elif run_info.budget <= 0 or run_info.budget > 100:
                 raise ValueError('Illegal value for budget, must be >0 and <=100, but is %f' %
                                  run_info.budget)
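With the default pipeline options loaded in the constructor, a budget of 0 now falls back to the configured default for the active budget type instead of a hardcoded 100.0. Schematically (the keys and values here are hypothetical, not autoPyTorch's actual defaults):

pipeline_config = {'budget_type': 'epochs', 'epochs': 50}
budget_type = pipeline_config['budget_type']

budget = 0.0  # by convention, 0 means "use the default budget"
if budget == 0:
    budget = pipeline_config[budget_type]  # 50 epochs rather than a fixed 100.0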
8 changes: 4 additions & 4 deletions autoPyTorch/pipeline/base_pipeline.py
@@ -70,6 +70,10 @@ def __init__(
         self.include = include if include is not None else {}
         self.exclude = exclude if exclude is not None else {}
         self.search_space_updates = search_space_updates
+        if random_state is None:
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = check_random_state(random_state)

         if steps is None:
             self.steps = self._get_pipeline_steps(dataset_properties)
@@ -98,10 +102,6 @@ def __init__(

         self.set_hyperparameters(self.config, init_params=init_params)

-        if random_state is None:
-            self.random_state = check_random_state(1)
-        else:
-            self.random_state = check_random_state(random_state)
         super().__init__(steps=self.steps)

         self._additional_run_info = {}  # type: Dict[str, str]
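Moving the random_state assignment above the step construction is the substance of this change: _get_pipeline_steps instantiates components that read self.random_state, so the attribute must exist first. In outline (a simplified sketch, not the real class):

from sklearn.utils import check_random_state

class Component:
    def __init__(self, random_state):
        self.random_state = random_state

class Pipeline:
    def __init__(self, random_state=None):
        # must be assigned before the steps are built, because the
        # components constructed below read self.random_state
        self.random_state = check_random_state(1 if random_state is None else random_state)
        self.steps = self._get_pipeline_steps()

    def _get_pipeline_steps(self):
        return [('component', Component(random_state=self.random_state))]

pipeline = Pipeline(random_state=3)  # builds without an AttributeError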
9 changes: 8 additions & 1 deletion autoPyTorch/pipeline/components/base_component.py
@@ -8,7 +8,10 @@

 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace

+import numpy as np
+
 from sklearn.base import BaseEstimator
+from sklearn.utils import check_random_state

 from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdate
@@ -93,8 +96,12 @@ def add_component(self, obj: BaseEstimator) -> None:
 class autoPyTorchComponent(BaseEstimator):
     _required_properties: Optional[List[str]] = None

-    def __init__(self) -> None:
+    def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None:
         super().__init__()
+        if random_state is None:
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = check_random_state(random_state)
         self._fit_requirements: List[FitRequirement] = list()
         self._cs_updates: Dict[str, HyperparameterSearchSpaceUpdate] = dict()
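check_random_state is what lets every component accept an int, None, or an existing RandomState uniformly. Its behaviour in brief:

import numpy as np
from sklearn.utils import check_random_state

rs = check_random_state(1)            # int: a fresh RandomState seeded with 1
assert check_random_state(rs) is rs   # RandomState: passed through unchanged
assert isinstance(check_random_state(None), np.random.RandomState)  # None: numpy's global RNG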
Changed file (path not shown):
@@ -26,7 +26,8 @@ def __init__(self, target_dim: int = 128,

     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:

-        self.preprocessor['numerical'] = sklearn.decomposition.TruncatedSVD(self.target_dim, algorithm="randomized")
+        self.preprocessor['numerical'] = sklearn.decomposition.TruncatedSVD(self.target_dim, algorithm="randomized",
+                                                                            random_state=self.random_state)

         return self
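TruncatedSVD with algorithm="randomized" relies on randomized linear algebra, so its output varies between runs unless seeded; pinning random_state makes the embedding deterministic:

import numpy as np
from sklearn.decomposition import TruncatedSVD

X = np.random.RandomState(0).rand(20, 10)
a = TruncatedSVD(3, algorithm="randomized", random_state=42).fit_transform(X)
b = TruncatedSVD(3, algorithm="randomized", random_state=42).fit_transform(X)
assert np.allclose(a, b)  # the same seed gives an identical embedding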
Changed file (path not shown):
@@ -1,4 +1,3 @@
-import random
 import typing
 import warnings
@@ -120,7 +119,8 @@ def shake_drop_get_bl(
     pl = 1 - ((block_index + 1) / num_blocks) * (1 - min_prob_no_shake)

     if not is_training:
-        bl = torch.tensor(1.0) if random.random() <= pl else torch.tensor(0.0)
+        # Move to torch.randn(1) for reproducibility
+        bl = torch.tensor(1.0) if torch.randn(1) <= pl else torch.tensor(0.0)
     if is_training:
         bl = torch.tensor(pl)
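Dropping the random module routes the draw through torch's seeded RNG. Note that torch.randn samples a standard normal while the original random.random() was uniform on [0, 1); a uniform torch.rand draw under a fixed torch.manual_seed would preserve the Bernoulli-gate semantics. A sketch of that variant (an alternative, not what the commit does):

import torch

torch.manual_seed(0)  # fixes every subsequent torch draw
pl = 0.8              # probability of keeping the block active
# uniform on [0, 1), matching the semantics of the original random.random() gate
bl = torch.tensor(1.0) if torch.rand(1).item() <= pl else torch.tensor(0.0)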
Changed file (path not shown):
@@ -74,6 +74,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent:

         # instantiate model
         self.model = self.build_model(input_shape=input_shape,
+                                      logger_port=X['logger_port'],
                                       output_shape=output_shape)

         # train model
@@ -91,7 +92,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent:
         return self

     @abstractmethod
-    def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> BaseClassifier:
+    def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...],
+                    logger_port: int) -> BaseClassifier:
         """
         This method returns a pytorch model, that is dynamically built using
         a self.config that is model specific, and contains the additional
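Passing logger_port into build_model lets each model log to the central TCP log listener rather than configuring its own stream handlers in every worker. Using only the standard library, the client side of that pattern looks roughly like this:

import logging
import logging.handlers

logger = logging.getLogger('classifier')
logger.setLevel(logging.INFO)
# ship records over TCP to a single listener process instead of
# letting every forked worker write to its own stream
logger.addHandler(logging.handlers.SocketHandler(
    'localhost', logging.handlers.DEFAULT_TCP_LOGGING_PORT))
logger.info('training started')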
Changed file (path not shown):
@@ -1,24 +1,37 @@
 import json
 import logging
+import logging.handlers
 import os as os
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional

 import numpy as np

+from sklearn.utils import check_random_state
+
 from autoPyTorch.metrics import accuracy
+from autoPyTorch.utils.logging_ import get_named_client_logger


-class BaseClassifier():
+class BaseClassifier:
     """
     Base class for classifiers.
     """

-    def __init__(self, name: str = ''):
-
-        self.configure_logging()
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None, name: str = ''):

         self.name = name
+        self.logger_port = logger_port
+        self.logger = get_named_client_logger(
+            name=name,
+            host='localhost',
+            port=logger_port,
+        )
+
+        if random_state is None:
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = check_random_state(random_state)
         self.config = self.get_config()

         self.categoricals: np.ndarray = np.array(())
@@ -28,17 +41,6 @@ def __init__(self, name: str = ''):

         self.metric = accuracy

-    def configure_logging(self) -> None:
-        """
-        Setup self.logger
-        """
-        self.logger = logging.getLogger(__name__)
-        self.logger.setLevel(logging.INFO)
-
-        ch = logging.StreamHandler()
-        ch.setLevel(logging.INFO)
-        self.logger.addHandler(ch)
-
     def get_config(self) -> Dict[str, Any]:
         """
         Load the parameters for the classifier model from ../classifier_configs/modelname.json.