diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 2084d7138..e4b226d86 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -29,7 +29,7 @@ jobs:
     - name: Run tests
       run: |
         if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
-        python -m pytest --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
+        python -m pytest --forked --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
     - name: Check for files left behind by test
       if: ${{ always() }}
       run: |
diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index 3c712efa9..5c842c0aa 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -1235,7 +1235,8 @@ def __del__(self) -> None:
         # When a multiprocessing work is done, the
         # objects are deleted. We don't want to delete run areas
         # until the estimator is deleted
-        self._backend.context.delete_directories(force=False)
+        if hasattr(self, '_backend'):
+            self._backend.context.delete_directories(force=False)
 
     @typing.no_type_check
     def get_incumbent_results(
diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py
index 2f99e54f7..1bd283d7b 100644
--- a/autoPyTorch/datasets/base_dataset.py
+++ b/autoPyTorch/datasets/base_dataset.py
@@ -118,7 +118,7 @@ def __init__(
         self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
         self.cross_validators: Dict[str, CrossValFunc] = {}
         self.holdout_validators: Dict[str, HoldOutFunc] = {}
-        self.rng = np.random.RandomState(seed=seed)
+        self.random_state = np.random.RandomState(seed=seed)
         self.shuffle = shuffle
         self.resampling_strategy = resampling_strategy
         self.resampling_strategy_args = resampling_strategy_args
@@ -205,7 +205,7 @@ def __len__(self) -> int:
         return self.train_tensors[0].shape[0]
 
     def _get_indices(self) -> np.ndarray:
-        return self.rng.permutation(len(self)) if self.shuffle else np.arange(len(self))
+        return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
 
     def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
         """
@@ -271,7 +271,7 @@ def create_cross_val_splits(
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
         splits = self.cross_validators[cross_val_type.name](
-            num_splits, self._get_indices(), **kwargs)
+            self.random_state, num_splits, self._get_indices(), **kwargs)
         return splits
 
     def create_holdout_val_split(
@@ -305,7 +305,8 @@ def create_holdout_val_split(
         if holdout_val_type.is_stratified():
             # we need additional information about the data for stratification
             kwargs["stratify"] = self.train_tensors[-1]
-        train, val = self.holdout_validators[holdout_val_type.name](val_share, self._get_indices(), **kwargs)
+        train, val = self.holdout_validators[holdout_val_type.name](
+            self.random_state, val_share, self._get_indices(), **kwargs)
         return train, val
 
     def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
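The `hasattr` guard above matters because Python still invokes `__del__` on a partially constructed object: if `__init__` raises before `_backend` is assigned, an unguarded `__del__` turns the original error into a confusing `AttributeError`. A minimal sketch of the failure mode (hypothetical class, not autoPyTorch code):

```python
class Estimator:
    """Hypothetical stand-in for the autoPyTorch estimator; not library code."""

    def __init__(self, fail: bool = False) -> None:
        if fail:
            # __init__ aborts before _backend is ever assigned
            raise ValueError("initialisation failed early")
        self._backend = object()

    def __del__(self) -> None:
        # Without the hasattr guard, deleting a partially constructed
        # instance would raise AttributeError on top of the original error.
        if hasattr(self, '_backend'):
            pass  # safe place to clean up run directories


try:
    Estimator(fail=True)
except ValueError:
    pass  # __del__ still runs on the partial object, but harmlessly
```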
diff --git a/autoPyTorch/datasets/resampling_strategy.py b/autoPyTorch/datasets/resampling_strategy.py
index 765a31cdb..a1e599dd6 100644
--- a/autoPyTorch/datasets/resampling_strategy.py
+++ b/autoPyTorch/datasets/resampling_strategy.py
@@ -18,6 +18,7 @@
 # Use callback protocol as workaround, since callable with function fields count 'self' as argument
 class CrossValFunc(Protocol):
     def __call__(self,
+                 random_state: np.random.RandomState,
                  num_splits: int,
                  indices: np.ndarray,
                  stratify: Optional[Any]) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -25,7 +26,8 @@ def __call__(self,
 
 
 class HoldOutFunc(Protocol):
-    def __call__(self, val_share: float, indices: np.ndarray, stratify: Optional[Any]
+    def __call__(self, random_state: np.random.RandomState, val_share: float,
+                 indices: np.ndarray, stratify: Optional[Any]
                  ) -> Tuple[np.ndarray, np.ndarray]:
         ...
 
@@ -85,35 +87,42 @@ def is_stratified(self) -> bool:
         'val_share': 0.33,
     },
     CrossValTypes.k_fold_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
     CrossValTypes.stratified_k_fold_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
     CrossValTypes.shuffle_split_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
     CrossValTypes.time_series_cross_validation: {
-        'num_splits': 3,
+        'num_splits': 5,
     },
 }  # type: Dict[Union[HoldoutValTypes, CrossValTypes], Dict[str, Any]]
 
 
 class HoldOutFuncs():
     @staticmethod
-    def holdout_validation(val_share: float,
+    def holdout_validation(random_state: np.random.RandomState,
+                           val_share: float,
                            indices: np.ndarray,
                            **kwargs: Any
                            ) -> Tuple[np.ndarray, np.ndarray]:
-        train, val = train_test_split(indices, test_size=val_share, shuffle=False)
+        shuffle = kwargs.get('shuffle', True)
+        train, val = train_test_split(indices, test_size=val_share,
+                                      shuffle=shuffle,
+                                      random_state=random_state if shuffle else None,
+                                      )
         return train, val
 
     @staticmethod
-    def stratified_holdout_validation(val_share: float,
+    def stratified_holdout_validation(random_state: np.random.RandomState,
+                                      val_share: float,
                                       indices: np.ndarray,
                                       **kwargs: Any
                                       ) -> Tuple[np.ndarray, np.ndarray]:
-        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"])
+        train, val = train_test_split(indices, test_size=val_share, shuffle=True, stratify=kwargs["stratify"],
+                                      random_state=random_state)
         return train, val
 
     @classmethod
@@ -128,34 +137,38 @@ def get_holdout_validators(cls, *holdout_val_types: HoldoutValTypes) -> Dict[str
 
 
 class CrossValFuncs():
     @staticmethod
-    def shuffle_split_cross_validation(num_splits: int,
+    def shuffle_split_cross_validation(random_state: np.random.RandomState,
+                                       num_splits: int,
                                        indices: np.ndarray,
                                        **kwargs: Any
                                        ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = ShuffleSplit(n_splits=num_splits)
+        cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
         splits = list(cv.split(indices))
         return splits
 
     @staticmethod
-    def stratified_shuffle_split_cross_validation(num_splits: int,
+    def stratified_shuffle_split_cross_validation(random_state: np.random.RandomState,
+                                                  num_splits: int,
                                                   indices: np.ndarray,
                                                   **kwargs: Any
                                                   ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedShuffleSplit(n_splits=num_splits)
+        cv = StratifiedShuffleSplit(n_splits=num_splits, random_state=random_state)
         splits = list(cv.split(indices, kwargs["stratify"]))
         return splits
 
     @staticmethod
-    def stratified_k_fold_cross_validation(num_splits: int,
+    def stratified_k_fold_cross_validation(random_state: np.random.RandomState,
+                                           num_splits: int,
                                            indices: np.ndarray,
                                            **kwargs: Any
                                            ) -> List[Tuple[np.ndarray, np.ndarray]]:
-        cv = StratifiedKFold(n_splits=num_splits)
+        shuffle = kwargs.get('shuffle', True)
+        # random_state only takes effect when shuffle is True;
+        # scikit-learn raises a ValueError otherwise
+        cv = StratifiedKFold(n_splits=num_splits, shuffle=shuffle,
+                             random_state=random_state if shuffle else None)
         splits = list(cv.split(indices, kwargs["stratify"]))
         return splits
 
     @staticmethod
-    def k_fold_cross_validation(num_splits: int,
+    def k_fold_cross_validation(random_state: np.random.RandomState,
+                                num_splits: int,
                                 indices: np.ndarray,
                                 **kwargs: Any
                                 ) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -169,12 +182,14 @@ def k_fold_cross_validation(num_splits: int,
         Returns:
             splits (List[Tuple[List, List]]): list of tuples of training and validation indices
         """
-        cv = KFold(n_splits=num_splits)
+        shuffle = kwargs.get('shuffle', True)
+        cv = KFold(n_splits=num_splits, random_state=random_state if shuffle else None, shuffle=shuffle)
         splits = list(cv.split(indices))
         return splits
 
     @staticmethod
-    def time_series_cross_validation(num_splits: int,
+    def time_series_cross_validation(random_state: np.random.RandomState,
+                                     num_splits: int,
                                      indices: np.ndarray,
                                      **kwargs: Any
                                      ) -> List[Tuple[np.ndarray, np.ndarray]]:
@@ -196,7 +211,7 @@ def time_series_cross_validation(num_splits: int,
              ([0, 1, 2], [3])]
 
         """
+        # TimeSeriesSplit is deterministic and accepts no random_state, so the
+        # argument above exists only to match the CrossValFunc interface
         cv = TimeSeriesSplit(n_splits=num_splits)
         splits = list(cv.split(indices))
         return splits
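Threading a single `np.random.RandomState` through every split function makes the resampling a pure function of the dataset seed. A small standalone sketch, mirroring the signature change above, showing that two identically seeded states produce identical splits:

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit


def shuffle_split(random_state: np.random.RandomState,
                  num_splits: int,
                  indices: np.ndarray):
    # Same shape as the CrossValFunc protocol above: the caller owns the state
    cv = ShuffleSplit(n_splits=num_splits, random_state=random_state)
    return list(cv.split(indices))


indices = np.arange(100)
a = shuffle_split(np.random.RandomState(1), 5, indices)
b = shuffle_split(np.random.RandomState(1), 5, indices)
assert all(np.array_equal(ta, tb) and np.array_equal(va, vb)
           for (ta, va), (tb, vb) in zip(a, b))
```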
diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py
index 63ef6bbb0..ac022f41b 100644
--- a/autoPyTorch/evaluation/abstract_evaluator.py
+++ b/autoPyTorch/evaluation/abstract_evaluator.py
@@ -80,7 +80,8 @@ def __init__(self, config: str,
         self.random_state = random_state
         self.init_params = init_params
         self.pipeline = autoPyTorch.pipeline.traditional_tabular_classification.\
-            TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties)
+            TraditionalTabularClassificationPipeline(dataset_properties=dataset_properties,
+                                                     random_state=self.random_state)
         configuration_space = self.pipeline.get_hyperparameter_search_space()
         default_configuration = configuration_space.get_default_configuration().get_dictionary()
         default_configuration['model_trainer:tabular_classifier:classifier'] = config
diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py
index 1ef4f552d..cf55bba41 100644
--- a/autoPyTorch/evaluation/tae.py
+++ b/autoPyTorch/evaluation/tae.py
@@ -4,6 +4,7 @@
 import logging
 import math
 import multiprocessing
+import os
 import time
 import traceback
 import typing
@@ -25,6 +26,7 @@
 from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.common import replace_string_bool_to_bool
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
 
@@ -144,7 +146,12 @@ def __init__(
         self.exclude = exclude
         self.disable_file_output = disable_file_output
         self.init_params = init_params
-        self.pipeline_config = pipeline_config
+        self.pipeline_config: typing.Dict[str, typing.Union[int, str, float]] = dict()
+        if pipeline_config is None:
+            pipeline_config = replace_string_bool_to_bool(json.load(open(
+                os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
+        self.pipeline_config.update(pipeline_config)
+        self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
 
         self.logger_port = logger_port
         if self.logger_port is None:
@@ -199,7 +206,7 @@ def run_wrapper(
             )
         else:
             if run_info.budget == 0:
-                run_info = run_info._replace(budget=100.0)
+                run_info = run_info._replace(budget=self.pipeline_config[self.budget_type])
             elif run_info.budget <= 0 or run_info.budget > 100:
                 raise ValueError('Illegal value for budget, must be >0 and <=100, but is %f' %
                                  run_info.budget)
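With this change, a budget of 0 resolves to the default recorded in `default_pipeline_options.json` rather than a hardcoded 100.0. A sketch of the resolution logic, assuming the config maps `'budget_type'` to a key such as `'epochs'` whose value is the default budget (the 50-epoch default is suggested by the `1_1_50.0` run directories in the test changes below):

```python
from typing import Dict, Union


def resolve_budget(requested: float,
                   pipeline_config: Dict[str, Union[int, str, float]]) -> float:
    """A budget of 0 means 'use the default for the configured budget type'."""
    if requested == 0:
        budget_type = pipeline_config['budget_type']   # e.g. 'epochs'
        return float(pipeline_config[budget_type])     # e.g. 50 epochs
    if requested <= 0 or requested > 100:
        raise ValueError('Illegal value for budget, must be >0 and <=100, but is %f' % requested)
    return requested


assert resolve_budget(0, {'budget_type': 'epochs', 'epochs': 50}) == 50.0
```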
diff --git a/autoPyTorch/pipeline/base_pipeline.py b/autoPyTorch/pipeline/base_pipeline.py
index 792268487..fc086c902 100644
--- a/autoPyTorch/pipeline/base_pipeline.py
+++ b/autoPyTorch/pipeline/base_pipeline.py
@@ -70,6 +70,10 @@ def __init__(
         self.include = include if include is not None else {}
         self.exclude = exclude if exclude is not None else {}
         self.search_space_updates = search_space_updates
+        if random_state is None:
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = check_random_state(random_state)
 
         if steps is None:
             self.steps = self._get_pipeline_steps(dataset_properties)
@@ -98,10 +102,6 @@ def __init__(
 
         self.set_hyperparameters(self.config, init_params=init_params)
 
-        if random_state is None:
-            self.random_state = check_random_state(1)
-        else:
-            self.random_state = check_random_state(random_state)
         super().__init__(steps=self.steps)
 
         self._additional_run_info = {}  # type: Dict[str, str]
diff --git a/autoPyTorch/pipeline/components/base_component.py b/autoPyTorch/pipeline/components/base_component.py
index 09d981342..c67c2827d 100644
--- a/autoPyTorch/pipeline/components/base_component.py
+++ b/autoPyTorch/pipeline/components/base_component.py
@@ -8,7 +8,10 @@
 
 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
 
+import numpy as np
+
 from sklearn.base import BaseEstimator
+from sklearn.utils import check_random_state
 
 from autoPyTorch.utils.common import FitRequirement, HyperparameterSearchSpace
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdate
@@ -93,8 +96,12 @@ def add_component(self, obj: BaseEstimator) -> None:
 class autoPyTorchComponent(BaseEstimator):
     _required_properties: Optional[List[str]] = None
 
-    def __init__(self) -> None:
+    def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None:
         super().__init__()
+        if random_state is None:
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = check_random_state(random_state)
         self._fit_requirements: List[FitRequirement] = list()
         self._cs_updates: Dict[str, HyperparameterSearchSpaceUpdate] = dict()
diff --git a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/TruncatedSVD.py b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/TruncatedSVD.py
index fdf6a751a..bfe4568b3 100644
--- a/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/TruncatedSVD.py
+++ b/autoPyTorch/pipeline/components/preprocessing/tabular_preprocessing/feature_preprocessing/TruncatedSVD.py
@@ -26,7 +26,8 @@ def __init__(self, target_dim: int = 128,
 
     def fit(self, X: Dict[str, Any], y: Any = None) -> BaseEstimator:
 
-        self.preprocessor['numerical'] = sklearn.decomposition.TruncatedSVD(self.target_dim, algorithm="randomized")
+        self.preprocessor['numerical'] = sklearn.decomposition.TruncatedSVD(self.target_dim, algorithm="randomized",
+                                                                            random_state=self.random_state)
 
         return self
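`check_random_state` is what lets every component accept `None`, an `int`, or a ready-made `RandomState`, and it is the reason a single state object can be shared across the whole pipeline. Its semantics, illustrated standalone:

```python
import numpy as np
from sklearn.utils import check_random_state

# An int is turned into a fresh RandomState seeded with it
rs_from_int = check_random_state(1)
assert isinstance(rs_from_int, np.random.RandomState)

# An existing RandomState is passed through unchanged, so components that
# receive the same object consume one shared stream of randomness
shared = np.random.RandomState(1)
assert check_random_state(shared) is shared

# None returns numpy's global RandomState singleton (not reproducible)
assert check_random_state(None) is np.random.mtrand._rand
```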
diff --git a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
index c20ca5ed2..aa46876fa 100644
--- a/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
+++ b/autoPyTorch/pipeline/components/setup/network_backbone/utils.py
@@ -1,4 +1,3 @@
-import random
 import typing
 import warnings
@@ -120,7 +119,8 @@ def shake_drop_get_bl(
     pl = 1 - ((block_index + 1) / num_blocks) * (1 - min_prob_no_shake)
 
     if not is_training:
-        bl = torch.tensor(1.0) if random.random() <= pl else torch.tensor(0.0)
+        # Move to torch.rand(1) for reproducibility; like random.random(),
+        # it draws from the uniform distribution on [0, 1)
+        bl = torch.tensor(1.0) if torch.rand(1) <= pl else torch.tensor(0.0)
     if is_training:
         bl = torch.tensor(pl)
diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py b/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py
index 620c49c45..fba374a34 100644
--- a/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py
+++ b/autoPyTorch/pipeline/components/setup/traditional_ml/base_model.py
@@ -74,6 +74,7 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent:
 
         # instantiate model
         self.model = self.build_model(input_shape=input_shape,
+                                      logger_port=X['logger_port'],
                                       output_shape=output_shape)
 
         # train model
@@ -91,7 +92,8 @@ def fit(self, X: Dict[str, Any], y: Any = None) -> autoPyTorchSetupComponent:
         return self
 
     @abstractmethod
-    def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> BaseClassifier:
+    def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...],
+                    logger_port: int) -> BaseClassifier:
         """
         This method returns a pytorch model, that is dynamically built using
         a self.config that is model specific, and contains the additional
diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/base_classifier.py b/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/base_classifier.py
index 67d905aa5..63d2508c6 100644
--- a/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/base_classifier.py
+++ b/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/base_classifier.py
@@ -1,24 +1,37 @@
 import json
-import logging
+import logging.handlers
 import os as os
 from abc import abstractmethod
 from typing import Any, Dict, List, Optional
 
 import numpy as np
 
+from sklearn.utils import check_random_state
+
 from autoPyTorch.metrics import accuracy
+from autoPyTorch.utils.logging_ import get_named_client_logger
 
 
-class BaseClassifier():
+class BaseClassifier:
     """
     Base class for classifiers.
     """
 
-    def __init__(self, name: str = ''):
-
-        self.configure_logging()
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None, name: str = ''):
 
         self.name = name
+        self.logger_port = logger_port
+        self.logger = get_named_client_logger(
+            name=name,
+            host='localhost',
+            port=logger_port,
+        )
+
+        if random_state is None:
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = check_random_state(random_state)
 
         self.config = self.get_config()
 
         self.categoricals: np.ndarray = np.array(())
@@ -28,17 +41,6 @@ def __init__(self, name: str = ''):
 
         self.metric = accuracy
 
-    def configure_logging(self) -> None:
-        """
-        Setup self.logger
-        """
-        self.logger = logging.getLogger(__name__)
-        self.logger.setLevel(logging.INFO)
-
-        ch = logging.StreamHandler()
-        ch.setLevel(logging.INFO)
-        self.logger.addHandler(ch)
-
     def get_config(self) -> Dict[str, Any]:
         """
         Load the parameters for the classifier model from ../classifier_configs/modelname.json.
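`get_named_client_logger` replaces the per-instance `StreamHandler` with autoPyTorch's TCP-based client logger, so records from forked workers reach one central listener owned by the parent process. The helper's internals are not shown in this diff; the following is only a rough, hypothetical stand-in built on `logging.handlers.SocketHandler` to convey the idea:

```python
import logging
import logging.handlers


def named_client_logger(name: str, host: str = 'localhost',
                        port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT) -> logging.Logger:
    # Hypothetical stand-in for autoPyTorch's get_named_client_logger:
    # records are pickled and shipped to a TCP listener in the parent
    # process, so output from forked workers ends up in one place.
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.handlers.SocketHandler(host, port))
    return logger
```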
diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/classifiers.py b/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/classifiers.py
index 58eed388f..af3622dfa 100644
--- a/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/classifiers.py
+++ b/autoPyTorch/pipeline/components/setup/traditional_ml/classifier_models/classifiers.py
@@ -1,3 +1,4 @@
+import logging.handlers
 import tempfile
 from typing import Any, Dict, List, Optional, Union
 
@@ -49,8 +50,11 @@ def encode_categoricals(X_train: np.ndarray,
 
 
 class LGBModel(BaseClassifier):
-    def __init__(self) -> None:
-        super(LGBModel, self).__init__(name="lgb")
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None):
+        super(LGBModel, self).__init__(name="lgb",
+                                       logger_port=logger_port,
+                                       random_state=random_state)
 
     def fit(self, X_train: np.ndarray,
             y_train: np.ndarray,
@@ -73,7 +77,7 @@ def fit(self, X_train: np.ndarray,
         X_train = np.nan_to_num(X_train)
         X_val = np.nan_to_num(X_val)
 
-        self.model = LGBMClassifier(**self.config)
+        self.model = LGBMClassifier(**self.config, random_state=self.random_state)
         self.model.fit(X_train, y_train, eval_set=[(X_val, y_val)])
 
         pred_train = self.model.predict_proba(X_train)
@@ -116,8 +120,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
 
 
 class CatboostModel(BaseClassifier):
-    def __init__(self) -> None:
-        super(CatboostModel, self).__init__(name="catboost")
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None):
+        super(CatboostModel, self).__init__(name="catboost",
+                                            logger_port=logger_port,
+                                            random_state=random_state)
         self.config["train_dir"] = tempfile.gettempdir()
 
     def fit(self, X_train: np.ndarray,
@@ -142,7 +149,8 @@ def fit(self, X_train: np.ndarray,
         X_train_pooled = Pool(data=X_train, label=y_train, cat_features=categoricals)
         X_val_pooled = Pool(data=X_val, label=y_val, cat_features=categoricals)
 
-        self.model = CatBoostClassifier(**self.config)
+        # CatBoost cannot handle a random state object, just the seed
+        self.model = CatBoostClassifier(**self.config, random_state=self.random_state.get_state()[1][0])
         self.model.fit(X_train_pooled, eval_set=X_val_pooled, use_best_model=True, early_stopping_rounds=early_stopping)
 
         pred_train = self.model.predict_proba(X_train)
@@ -184,8 +192,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
 
 
 class RFModel(BaseClassifier):
-    def __init__(self) -> None:
-        super(RFModel, self).__init__(name="random_forest")
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None):
+        super(RFModel, self).__init__(name="random_forest",
+                                      logger_port=logger_port,
+                                      random_state=random_state)
 
     def fit(self, X_train: np.ndarray,
             y_train: np.ndarray,
@@ -209,7 +220,7 @@ def fit(self, X_train: np.ndarray,
             self.config["n_estimators"] = 8
             self.config["warm_start"] = True
 
-        self.model = RandomForestClassifier(**self.config)
+        self.model = RandomForestClassifier(**self.config, random_state=self.random_state)
         self.model.fit(X_train, y_train)
 
         if self.config["warm_start"]:
@@ -250,8 +261,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
 
 
 class ExtraTreesModel(BaseClassifier):
-    def __init__(self) -> None:
-        super(ExtraTreesModel, self).__init__(name="extra_trees")
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None):
+        super(ExtraTreesModel, self).__init__(name="extra_trees",
+                                              logger_port=logger_port,
+                                              random_state=random_state)
 
     def fit(self, X_train: np.ndarray,
             y_train: np.ndarray,
@@ -275,7 +289,7 @@ def fit(self, X_train: np.ndarray,
             self.config["n_estimators"] = 8
             self.config["warm_start"] = True
 
-        self.model = ExtraTreesClassifier(**self.config)
+        self.model = ExtraTreesClassifier(**self.config, random_state=self.random_state)
         self.model.fit(X_train, y_train)
 
         if self.config["warm_start"]:
@@ -316,8 +330,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
 
 
 class KNNModel(BaseClassifier):
-    def __init__(self) -> None:
-        super(KNNModel, self).__init__(name="knn")
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None):
+        super(KNNModel, self).__init__(name="knn",
+                                       logger_port=logger_port,
+                                       random_state=random_state)
 
     def fit(self, X_train: np.ndarray,
            y_train: np.ndarray,
@@ -338,6 +355,7 @@ def fit(self, X_train: np.ndarray,
 
         self.num_classes = len(np.unique(y_train))
 
+        # KNN is deterministic, no random seed needed
         self.model = KNeighborsClassifier(**self.config)
         self.model.fit(X_train, y_train)
 
@@ -376,8 +394,11 @@ def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[
 
 
 class SVMModel(BaseClassifier):
-    def __init__(self) -> None:
-        super(SVMModel, self).__init__(name="svm")
+    def __init__(self, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT,
+                 random_state: Optional[np.random.RandomState] = None):
+        super(SVMModel, self).__init__(name="svm",
+                                       logger_port=logger_port,
+                                       random_state=random_state)
 
     def fit(self, X_train: np.ndarray,
             y_train: np.ndarray,
@@ -392,7 +413,7 @@ def fit(self, X_train: np.ndarray,
 
         X_train = np.nan_to_num(X_train)
         X_val = np.nan_to_num(X_val)
 
-        self.model = SVC(**self.config, probability=True)
+        self.model = SVC(**self.config, probability=True, random_state=self.random_state)
         self.model.fit(X_train, y_train)
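CatBoost accepts an integer seed rather than a `RandomState` object, hence `self.random_state.get_state()[1][0]` above. `get_state()` returns the Mersenne Twister state tuple `('MT19937', keys, pos, has_gauss, cached_gaussian)`, so `keys[0]` is a deterministic integer derived from the seed:

```python
import numpy as np

random_state = np.random.RandomState(1)

# get_state() -> ('MT19937', keys, pos, has_gauss, cached_gaussian);
# keys is a uint32 array, so keys[0] serves as a stable integer seed
seed = random_state.get_state()[1][0]
assert seed == np.random.RandomState(1).get_state()[1][0]  # deterministic
print(int(seed))
```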
diff --git a/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py b/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py
index 03343d9f3..07422f229 100644
--- a/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py
+++ b/autoPyTorch/pipeline/components/setup/traditional_ml/tabular_classifier.py
@@ -48,7 +48,8 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict[str, str]]
 
         return cs
 
-    def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...]) -> BaseClassifier:
+    def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...],
+                    logger_port: int) -> BaseClassifier:
         """
         This method returns a classifier, that is dynamically built using
         a self.config that is model specific, and contains the additional
@@ -57,7 +58,7 @@ def build_model(self, input_shape: Tuple[int, ...], output_shape: Tuple[int, ...
         classifier_name = self.config["classifier"]
         Classifier = self._classifiers[classifier_name]
 
-        classifier = Classifier()
+        classifier = Classifier(random_state=self.random_state, logger_port=logger_port)
 
         return classifier
 
diff --git a/autoPyTorch/pipeline/components/training/base_training.py b/autoPyTorch/pipeline/components/training/base_training.py
index 3145d636b..ebf7ccbc4 100644
--- a/autoPyTorch/pipeline/components/training/base_training.py
+++ b/autoPyTorch/pipeline/components/training/base_training.py
@@ -10,7 +10,7 @@ class autoPyTorchTrainingComponent(autoPyTorchComponent):
     in Auto-Pytorch"""
 
     def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None:
-        super(autoPyTorchTrainingComponent, self).__init__()
+        super(autoPyTorchTrainingComponent, self).__init__(random_state=random_state)
 
     def transform(self, X: np.ndarray) -> np.ndarray:
         """The transform function calls the transform function of the
diff --git a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
index afb43ea97..97d6873c7 100644
--- a/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
+++ b/autoPyTorch/pipeline/components/training/data_loader/base_data_loader.py
@@ -31,8 +31,9 @@ class BaseDataLoaderComponent(autoPyTorchTrainingComponent):
 
     """
 
-    def __init__(self, batch_size: int = 64) -> None:
-        super().__init__()
+    def __init__(self, batch_size: int = 64,
+                 random_state: Optional[np.random.RandomState] = None) -> None:
+        super().__init__(random_state=random_state)
         self.batch_size = batch_size
         self.train_data_loader = None  # type: Optional[torch.utils.data.DataLoader]
         self.val_data_loader = None  # type: Optional[torch.utils.data.DataLoader]
diff --git a/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py b/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py
index f001689ea..1a692f2de 100644
--- a/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/MixUpTrainer.py
@@ -45,7 +45,7 @@ def data_preparation(self, X: np.ndarray, y: np.ndarray,
                 np.ndarray: that processes data
                 typing.Dict[str, np.ndarray]: arguments to the criterion function
         """
-        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0. else 1.
+        lam = self.random_state.beta(self.alpha, self.alpha) if self.alpha > 0. else 1.
         batch_size = X.size()[0]
         index = torch.randperm(batch_size).cuda() if X.is_cuda else torch.randperm(batch_size)
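Sampling the MixUp coefficient from the component's own `RandomState` (instead of the global `np.random`) makes each trainer's augmentation reproducible in isolation. A minimal standalone sketch of the sampling step:

```python
import numpy as np
from sklearn.utils import check_random_state


def sample_mixup_lambda(alpha: float, random_state: np.random.RandomState) -> float:
    # lam ~ Beta(alpha, alpha) controls how strongly two samples are mixed;
    # alpha <= 0 disables mixing entirely
    return random_state.beta(alpha, alpha) if alpha > 0. else 1.


lam_a = sample_mixup_lambda(0.2, check_random_state(1))
lam_b = sample_mixup_lambda(0.2, check_random_state(1))
assert lam_a == lam_b  # same seed, same coefficient
```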
diff --git a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
index f61d334e2..8c7e6907a 100644
--- a/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
+++ b/autoPyTorch/pipeline/components/training/trainer/base_trainer.py
@@ -5,6 +5,8 @@
 
 import pandas as pd
 
+from sklearn.utils import check_random_state
+
 import torch
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
@@ -161,9 +163,15 @@ def repr_last_epoch(self) -> str:
 
 class BaseTrainerComponent(autoPyTorchTrainingComponent):
 
-    def __init__(self, random_state: Optional[Union[np.random.RandomState, int]] = None) -> None:
-        super().__init__()
-        self.random_state = random_state
+    def __init__(self, random_state: Optional[np.random.RandomState] = None) -> None:
+        if random_state is None:
+            # Trainer components need a random state for
+            # sampling -- for example in MixUp training
+            self.random_state = check_random_state(1)
+        else:
+            self.random_state = random_state
+        super().__init__(random_state=self.random_state)
+
         self.weighted_loss: bool = False
 
     def prepare(
diff --git a/autoPyTorch/pipeline/tabular_classification.py b/autoPyTorch/pipeline/tabular_classification.py
index 1ca0635b6..bb4cb10ac 100644
--- a/autoPyTorch/pipeline/tabular_classification.py
+++ b/autoPyTorch/pipeline/tabular_classification.py
@@ -10,6 +10,8 @@
 import sklearn.preprocessing
 from sklearn.base import ClassifierMixin
 
+import torch
+
 from autoPyTorch.pipeline.base_pipeline import BasePipeline
 from autoPyTorch.pipeline.components.base_choice import autoPyTorchChoice
 from autoPyTorch.pipeline.components.base_component import autoPyTorchComponent
@@ -79,6 +81,11 @@ def __init__(
             config, steps, dataset_properties, include, exclude,
             random_state, init_params, search_space_updates)
 
+        # Because a pipeline is passed to a worker, we need to honor the random seed
+        # in this context. A tabular classification pipeline will implement a torch
+        # model, so we comply with https://pytorch.org/docs/stable/notes/randomness.html
+        torch.manual_seed(self.random_state.get_state()[1][0])
+
     def _predict_proba(self, X: np.ndarray) -> np.ndarray:
         # Pre-process X
         loader = self.named_steps['data_loader'].get_loader(X=X)
@@ -238,21 +245,28 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]],
         default_dataset_properties.update(dataset_properties)
 
         steps.extend([
-            ("imputer", SimpleImputer()),
-            ("encoder", EncoderChoice(default_dataset_properties)),
-            ("scaler", ScalerChoice(default_dataset_properties)),
-            ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties)),
-            ("tabular_transformer", TabularColumnTransformer()),
-            ("preprocessing", EarlyPreprocessing()),
-            ("network_embedding", NetworkEmbeddingChoice(default_dataset_properties)),
-            ("network_backbone", NetworkBackboneChoice(default_dataset_properties)),
-            ("network_head", NetworkHeadChoice(default_dataset_properties)),
-            ("network", NetworkComponent()),
-            ("network_init", NetworkInitializerChoice(default_dataset_properties)),
-            ("optimizer", OptimizerChoice(default_dataset_properties)),
-            ("lr_scheduler", SchedulerChoice(default_dataset_properties)),
-            ("data_loader", FeatureDataLoader()),
-            ("trainer", TrainerChoice(default_dataset_properties)),
+            ("imputer", SimpleImputer(random_state=self.random_state)),
+            ("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)),
+            ("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)),
+            ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties,
+                                                               random_state=self.random_state)),
+            ("tabular_transformer", TabularColumnTransformer(random_state=self.random_state)),
+            ("preprocessing", EarlyPreprocessing(random_state=self.random_state)),
+            ("network_embedding", NetworkEmbeddingChoice(default_dataset_properties,
+                                                         random_state=self.random_state)),
+            ("network_backbone", NetworkBackboneChoice(default_dataset_properties,
+                                                       random_state=self.random_state)),
+            ("network_head", NetworkHeadChoice(default_dataset_properties,
+                                               random_state=self.random_state)),
+            ("network", NetworkComponent(random_state=self.random_state)),
+            ("network_init", NetworkInitializerChoice(default_dataset_properties,
+                                                      random_state=self.random_state)),
+            ("optimizer", OptimizerChoice(default_dataset_properties,
+                                          random_state=self.random_state)),
+            ("lr_scheduler", SchedulerChoice(default_dataset_properties,
+                                             random_state=self.random_state)),
+            ("data_loader", FeatureDataLoader(random_state=self.random_state)),
+            ("trainer", TrainerChoice(default_dataset_properties, random_state=self.random_state)),
         ])
         return steps
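Deriving an integer from the shared `RandomState` and handing it to `torch.manual_seed` pins down weight initialization and torch-side sampling inside each worker, per the linked randomness notes. A standalone sketch of the idea:

```python
import numpy as np
import torch


def seed_torch_from(random_state: np.random.RandomState) -> None:
    # Derive an integer seed from the shared RandomState and hand it to
    # torch, following https://pytorch.org/docs/stable/notes/randomness.html
    torch.manual_seed(int(random_state.get_state()[1][0]))


seed_torch_from(np.random.RandomState(42))
first = torch.rand(3)
seed_torch_from(np.random.RandomState(42))
assert torch.equal(first, torch.rand(3))  # same seed, same draws
```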
diff --git a/autoPyTorch/pipeline/tabular_regression.py b/autoPyTorch/pipeline/tabular_regression.py
index 855a025e8..af8702695 100644
--- a/autoPyTorch/pipeline/tabular_regression.py
+++ b/autoPyTorch/pipeline/tabular_regression.py
@@ -19,6 +19,8 @@
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.encoding.base_encoder_choice import (
     EncoderChoice
 )
+from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.feature_preprocessing. \
+    base_feature_preprocessor_choice import FeatureProprocessorChoice
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.imputation.SimpleImputer import SimpleImputer
 from autoPyTorch.pipeline.components.preprocessing.tabular_preprocessing.scaling.base_scaler_choice import ScalerChoice
 from autoPyTorch.pipeline.components.setup.early_preprocessor.EarlyPreprocessing import EarlyPreprocessing
@@ -187,20 +189,28 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]]) -> L
         default_dataset_properties.update(dataset_properties)
 
         steps.extend([
-            ("imputer", SimpleImputer()),
-            ("encoder", EncoderChoice(default_dataset_properties)),
-            ("scaler", ScalerChoice(default_dataset_properties)),
-            ("tabular_transformer", TabularColumnTransformer()),
-            ("preprocessing", EarlyPreprocessing()),
-            ("network_embedding", NetworkEmbeddingChoice(default_dataset_properties)),
-            ("network_backbone", NetworkBackboneChoice(default_dataset_properties)),
-            ("network_head", NetworkHeadChoice(default_dataset_properties)),
-            ("network", NetworkComponent()),
-            ("network_init", NetworkInitializerChoice(default_dataset_properties)),
-            ("optimizer", OptimizerChoice(default_dataset_properties)),
-            ("lr_scheduler", SchedulerChoice(default_dataset_properties)),
-            ("data_loader", FeatureDataLoader()),
-            ("trainer", TrainerChoice(default_dataset_properties)),
+            ("imputer", SimpleImputer(random_state=self.random_state)),
+            ("encoder", EncoderChoice(default_dataset_properties, random_state=self.random_state)),
+            ("scaler", ScalerChoice(default_dataset_properties, random_state=self.random_state)),
+            ("feature_preprocessor", FeatureProprocessorChoice(default_dataset_properties,
+                                                               random_state=self.random_state)),
+            ("tabular_transformer", TabularColumnTransformer(random_state=self.random_state)),
+            ("preprocessing", EarlyPreprocessing(random_state=self.random_state)),
+            ("network_embedding", NetworkEmbeddingChoice(default_dataset_properties,
+                                                         random_state=self.random_state)),
+            ("network_backbone", NetworkBackboneChoice(default_dataset_properties,
+                                                       random_state=self.random_state)),
+            ("network_head", NetworkHeadChoice(default_dataset_properties,
+                                               random_state=self.random_state)),
+            ("network", NetworkComponent(random_state=self.random_state)),
+            ("network_init", NetworkInitializerChoice(default_dataset_properties,
+                                                      random_state=self.random_state)),
+            ("optimizer", OptimizerChoice(default_dataset_properties,
+                                          random_state=self.random_state)),
+            ("lr_scheduler", SchedulerChoice(default_dataset_properties,
+                                             random_state=self.random_state)),
+            ("data_loader", FeatureDataLoader(random_state=self.random_state)),
+            ("trainer", TrainerChoice(default_dataset_properties, random_state=self.random_state)),
         ])
         return steps
diff --git a/autoPyTorch/pipeline/traditional_tabular_classification.py b/autoPyTorch/pipeline/traditional_tabular_classification.py
index 1ac6aac50..51d8e6616 100644
--- a/autoPyTorch/pipeline/traditional_tabular_classification.py
+++ b/autoPyTorch/pipeline/traditional_tabular_classification.py
@@ -185,7 +185,8 @@ def _get_pipeline_steps(self, dataset_properties: Optional[Dict[str, Any]],
         default_dataset_properties.update(dataset_properties)
 
         steps.extend([
-            ("model_trainer", ModelChoice(default_dataset_properties)),
+            ("model_trainer", ModelChoice(default_dataset_properties,
+                                          random_state=self.random_state)),
         ])
         return steps
diff --git a/examples/tabular/20_basics/example_tabular_classification.py b/examples/tabular/20_basics/example_tabular_classification.py
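Putting the example changes together (see the file diff that follows), a reproducible run fixes both the data split and the estimator seed. A condensed sketch of the updated example, using dataset id 40981 as in the original example file:

```python
import sklearn.datasets
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask

X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
    X, y, random_state=42)  # fixed data split

api = TabularClassificationTask(seed=42)  # fixed search/pipeline seed
```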
index 1e5b08cac..1fbec6718 100644
--- a/examples/tabular/20_basics/example_tabular_classification.py
+++ b/examples/tabular/20_basics/example_tabular_classification.py
@@ -33,7 +33,7 @@
 X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
     X,
     y,
-    random_state=1,
+    random_state=42,
 )
 
 ############################################################################
@@ -44,7 +44,8 @@
     output_directory='./tmp/autoPyTorch_example_out_01',
     # To maintain logs of the run, set the next two as False
     delete_tmp_folder_after_terminate=True,
-    delete_output_folder_after_terminate=True
+    delete_output_folder_after_terminate=True,
+    seed=42,
 )
 
 ############################################################################
diff --git a/setup.py b/setup.py
index 1d8e47ba5..4fd732fdd 100755
--- a/setup.py
+++ b/setup.py
@@ -44,6 +44,7 @@
         "pyarrow",
         "pre-commit",
         "pytest-cov",
+        'pytest-forked',
         "codecov",
         "pep8",
         "mypy",
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index 7866e7674..59a6a3166 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -72,7 +72,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
 
     # Internal dataset has expected settings
     assert estimator.dataset.task_type == 'tabular_classification'
-    expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3
+    expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 5
     assert estimator.resampling_strategy == resampling_strategy
     assert estimator.dataset.resampling_strategy == resampling_strategy
     assert len(estimator.dataset.splits) == expected_num_splits
@@ -140,7 +140,7 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):
         model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
             estimator.seed, successful_num_run, run_key.budget)
         assert isinstance(model, VotingClassifier)
-        assert len(model.estimators_) == 3
+        assert len(model.estimators_) == 5
         assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
                           torch.nn.Module)
     else:
@@ -243,7 +243,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
 
     # Internal dataset has expected settings
     assert estimator.dataset.task_type == 'tabular_regression'
-    expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 3
+    expected_num_splits = 1 if resampling_strategy == HoldoutValTypes.holdout_validation else 5
     assert estimator.resampling_strategy == resampling_strategy
     assert estimator.dataset.resampling_strategy == resampling_strategy
     assert len(estimator.dataset.splits) == expected_num_splits
@@ -310,7 +310,7 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):
         model = estimator._backend.load_cv_model_by_seed_and_id_and_budget(
             estimator.seed, successful_num_run, run_key.budget)
         assert isinstance(model, VotingRegressor)
-        assert len(model.estimators_) == 3
+        assert len(model.estimators_) == 5
         assert isinstance(model.estimators_[0].named_steps['network'].get_network(),
                           torch.nn.Module)
     else:
@@ -430,14 +430,14 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
     # directory, but in the temporary directory.
     assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
     assert os.path.exists(os.path.join(
-        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
-        'predictions_ensemble_1_1_1.0.npy')
+        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_50.0',
+        'predictions_ensemble_1_1_50.0.npy')
     )
 
     model_path = os.path.join(backend.temporary_directory,
                               '.autoPyTorch',
-                              'runs', '1_1_1.0',
-                              '1.1.1.0.model')
+                              'runs', '1_1_50.0',
+                              '1.1.50.0.model')
 
     # Make sure the dummy model complies with scikit learn
     # get/set params
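The renamed test paths follow the backend's `<seed>_<num_run>_<budget>` run-directory convention; with the dummy run's budget now coming from the default pipeline options (50 epochs) instead of the old hardcoded 1.0, the directory becomes `1_1_50.0`. A sketch assuming that naming convention:

```python
import os


def run_directory(temporary_directory: str, seed: int, num_run: int, budget: float) -> str:
    # Assumed convention: runs/<seed>_<num_run>_<budget>, e.g. runs/1_1_50.0
    return os.path.join(temporary_directory, '.autoPyTorch', 'runs',
                        '%d_%d_%s' % (seed, num_run, budget))


assert os.path.basename(run_directory('/tmp/example', 1, 1, 50.0)) == '1_1_50.0'
```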
diff --git a/test/test_evaluation/test_train_evaluator.py b/test/test_evaluation/test_train_evaluator.py
index ec3a5b0aa..952ec7c78 100644
--- a/test/test_evaluation/test_train_evaluator.py
+++ b/test/test_evaluation/test_train_evaluator.py
@@ -112,7 +112,7 @@ def test_holdout(self, pipeline_mock):
         self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1)
 
         self.assertEqual(evaluator.file_output.call_count, 1)
-        self.assertEqual(result, 0.4782608695652174)
+        self.assertEqual(result, 0.5652173913043479)
         self.assertEqual(pipeline_mock.fit.call_count, 1)
         # 3 calls because of train, holdout and test set
         self.assertEqual(pipeline_mock.predict_proba.call_count, 3)
@@ -150,15 +150,17 @@ def test_cv(self, pipeline_mock):
         self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1)
 
         self.assertEqual(evaluator.file_output.call_count, 1)
-        self.assertEqual(result, 0.463768115942029)
-        self.assertEqual(pipeline_mock.fit.call_count, 3)
+        self.assertEqual(result, 0.46235467431119603)
+        self.assertEqual(pipeline_mock.fit.call_count, 5)
-        # 9 calls because of the training, holdout and
-        # test set (3 sets x 3 folds = 9)
-        self.assertEqual(pipeline_mock.predict_proba.call_count, 9)
-        # as the optimisation preds in cv is concatenation of the three folds,
-        # so it is 3*splits
+        # 15 calls because of the training, holdout and
+        # test set (3 sets x 5 folds = 15)
+        self.assertEqual(pipeline_mock.predict_proba.call_count, 15)
+        # as the optimisation preds in cv are the concatenation of the 5 folds,
+        # their total length is 5 * splits
         self.assertEqual(evaluator.file_output.call_args[0][0].shape[0],
-                         3 * len(D.splits[0][1]))
+                         # Notice this - 1: It is because the dataset D
+                         # has shape ((69, )) which is not divisible by 5
+                         5 * len(D.splits[0][1]) - 1, evaluator.file_output.call_args)
         self.assertIsNone(evaluator.file_output.call_args[0][1])
         self.assertEqual(evaluator.file_output.call_args[0][2].shape[0],
                          D.test_tensors[1].shape[0])
diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py
index bee8a820a..e1ccd49a3 100644
--- a/test/test_pipeline/test_tabular_classification.py
+++ b/test/test_pipeline/test_tabular_classification.py
@@ -157,6 +157,7 @@ def test_pipeline_predict_proba(self, fit_dictionary_tabular):
             assert isinstance(prediction, np.ndarray)
             assert prediction.shape == expected_output_shape
 
+    @flaky.flaky(max_runs=2)
     def test_pipeline_transform(self, fit_dictionary_tabular):
         """
         In the context of autopytorch, transform expands a fit dictionary with