diff --git a/autoPyTorch/evaluation/abstract_evaluator.py b/autoPyTorch/evaluation/abstract_evaluator.py
index 7171177d7..63ef6bbb0 100644
--- a/autoPyTorch/evaluation/abstract_evaluator.py
+++ b/autoPyTorch/evaluation/abstract_evaluator.py
@@ -84,8 +84,8 @@ def __init__(self, config: str,
         configuration_space = self.pipeline.get_hyperparameter_search_space()
         default_configuration = configuration_space.get_default_configuration().get_dictionary()
         default_configuration['model_trainer:tabular_classifier:classifier'] = config
-        configuration = Configuration(configuration_space, default_configuration)
-        self.pipeline.set_hyperparameters(configuration)
+        self.configuration = Configuration(configuration_space, default_configuration)
+        self.pipeline.set_hyperparameters(self.configuration)

     def fit(self, X: Dict[str, Any], y: Any,
             sample_weight: Optional[np.ndarray] = None) -> object:
@@ -102,8 +102,18 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
     def estimator_supports_iterative_fit(self) -> bool:  # pylint: disable=R0201
         return False

-    def get_additional_run_info(self) -> None:  # pylint: disable=R0201
-        return None
+    def get_additional_run_info(self) -> Dict[str, Any]:  # pylint: disable=R0201
+        """
+        Can be used to return additional info for the run.
+        Returns:
+            Dict[str, Any]:
+            Currently contains
+                1. pipeline_configuration: the configuration of the pipeline, i.e., the traditional model used
+                2. trainer_configuration: the parameters for the traditional model used.
+                   Can be found in autoPyTorch/pipeline/components/setup/traditional_ml/classifier_configs
+        """
+        return {'pipeline_configuration': self.configuration,
+                'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config()}

     def get_pipeline_representation(self) -> Dict[str, str]:
         return self.pipeline.get_pipeline_representation()
@@ -134,7 +144,9 @@ def __init__(self, config: Configuration,
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
                  init_params: Optional[Dict] = None
                  ) -> None:
-        self.configuration = config
+        self.config = config
+        self.init_params = init_params
+        self.random_state = random_state
         if config == 1:
             super(DummyClassificationPipeline, self).__init__(strategy="uniform")
         else:
@@ -163,8 +175,8 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
     def estimator_supports_iterative_fit(self) -> bool:  # pylint: disable=R0201
         return False

-    def get_additional_run_info(self) -> None:  # pylint: disable=R0201
-        return None
+    def get_additional_run_info(self) -> Dict:  # pylint: disable=R0201
+        return {}

     def get_pipeline_representation(self) -> Dict[str, str]:
         return {
@@ -198,7 +210,9 @@ class DummyRegressionPipeline(DummyRegressor):
     def __init__(self, config: Configuration,
                  random_state: Optional[Union[int, np.random.RandomState]] = None,
                  init_params: Optional[Dict] = None) -> None:
-        self.configuration = config
+        self.config = config
+        self.init_params = init_params
+        self.random_state = random_state
         if config == 1:
             super(DummyRegressionPipeline, self).__init__(strategy='mean')
         else:
@@ -219,8 +233,8 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
     def estimator_supports_iterative_fit(self) -> bool:  # pylint: disable=R0201
         return False

-    def get_additional_run_info(self) -> None:  # pylint: disable=R0201
-        return None
+    def get_additional_run_info(self) -> Dict:  # pylint: disable=R0201
+        return {}

     @staticmethod
     def get_default_pipeline_options() -> Dict[str, Any]:
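Illustrative sketch (not part of the diff): `get_additional_run_info()` now returns a dict for the traditional pipeline and `{}` for the dummy pipelines, while older or third-party pipelines may still return `None`. The helper below, with an assumed name and signature, mirrors the defensive `update()` pattern that `autoPyTorch/evaluation/train_evaluator.py` adopts further down:

```python
from typing import Any, Dict, List


def collect_additional_run_info(pipelines: List[Any]) -> Dict[str, Any]:
    """Merge the additional run info reported by several fitted pipelines."""
    additional_run_info: Dict[str, Any] = {}
    for pipeline in pipelines:
        if hasattr(pipeline, 'get_additional_run_info'):
            info = pipeline.get_additional_run_info()
            # Older or third-party pipelines may still return None instead of a dict.
            additional_run_info.update(info if info is not None else {})
    return additional_run_info
```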
diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py
index b4f53b6e4..1ef4f552d 100644
--- a/autoPyTorch/evaluation/tae.py
+++ b/autoPyTorch/evaluation/tae.py
@@ -57,7 +57,7 @@ def fit_predict_try_except_decorator(
 def get_cost_of_crash(metric: autoPyTorchMetric) -> float:
     # The metric must always be defined to extract optimum/worst
     if not isinstance(metric, autoPyTorchMetric):
-        raise ValueError("The metric must be stricly be an instance of autoPyTorchMetric")
+        raise ValueError("The metric must strictly be an instance of autoPyTorchMetric")

     # Autopytorch optimizes the err. This function translates
     # worst_possible_result to be a minimization problem.
diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py
index 88b5e81da..5ffb8b8db 100644
--- a/autoPyTorch/evaluation/train_evaluator.py
+++ b/autoPyTorch/evaluation/train_evaluator.py
@@ -143,6 +143,8 @@ def fit_predict_and_loss(self) -> None:
             # weights for opt_losses.
             opt_fold_weights = [np.NaN] * self.num_folds

+            additional_run_info = {}
+
             for i, (train_split, test_split) in enumerate(self.splits):

                 pipeline = self.pipelines[i]
@@ -178,7 +180,8 @@ def fit_predict_and_loss(self) -> None:
                 # number of optimization data points for this fold.
                 # Used for weighting the average.
                 opt_fold_weights[i] = len(train_split)
-
+                additional_run_info.update(pipeline.get_additional_run_info() if hasattr(
+                    pipeline, 'get_additional_run_info') and pipeline.get_additional_run_info() is not None else {})
             # Compute weights of each fold based on the number of samples in each
             # fold.
             train_fold_weights = [w / sum(train_fold_weights)
diff --git a/autoPyTorch/utils/backend.py b/autoPyTorch/utils/backend.py
index 5111c116f..50a1c4d38 100644
--- a/autoPyTorch/utils/backend.py
+++ b/autoPyTorch/utils/backend.py
@@ -169,6 +169,10 @@ def __init__(self, context: BackendContext):
         self._logger = None  # type: Optional[PicklableClientLogger]
         self.context = context

+        # Track the number of configurations launched
+        # num_run == 1 means a dummy estimator run
+        self.active_num_run = 1
+
         # Create the temporary directory if it does not yet exist
         try:
             os.makedirs(self.temporary_directory)
@@ -329,6 +333,47 @@ def get_runs_directory(self) -> str:
     def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str:
         return os.path.join(self.internals_directory, 'runs', '%d_%d_%s' % (seed, num_run, budget))

+    def get_next_num_run(self, peek: bool = False) -> int:
+        """
+        Every pipeline that is fitted by the estimator is stored with an
+        identifier called num_run. A dummy classifier will always have a num_run
+        equal to 1, and all other new configurations that are explored will
+        have a sequentially increasing identifier.
+
+        This method returns the next num_run a configuration should take.
+
+        Parameters
+        ----------
+        peek: bool
+            By default, the next num_run will be returned, i.e. self.active_num_run + 1.
+            Yet, if this bool parameter is equal to True, the value of the current
+            num_run is provided, i.e., self.active_num_run.
+            In other words, peek allows one to get the current maximum identifier
+            of a configuration.
+
+        Returns
+        -------
+        num_run: int
+            A unique identifier for a configuration
+        """
+
+        # If there are other num_runs, their name would be runs/<seed>_<num_run>_<budget>
+        other_num_runs = [int(os.path.basename(run_dir).split('_')[1])
+                          for run_dir in glob.glob(os.path.join(self.internals_directory,
+                                                                'runs',
+                                                                '*'))]
+        if len(other_num_runs) > 0:
+            # We track the number of runs from two forefronts:
+            # the physically available num_runs (which might have been deleted or lost in a crash)
+            # and an internally kept attribute. The latter should be sufficient, but we
+            # want to be robust against multiple backend copies on different workers
+            self.active_num_run = max([self.active_num_run] + other_num_runs)

+        # We are interested in the next run id
+        if not peek:
+            self.active_num_run += 1
+        return self.active_num_run
+
     def get_model_filename(self, seed: int, idx: int, budget: float) -> str:
         return '%s.%s.%s.model' % (seed, idx, budget)
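Illustrative sketch (not part of the diff): the numbering contract documented above, expressed as a pytest-style function against the same `backend` fixture used in `test/test_utils/test_backend.py` further down; the test name is assumed, not part of the patch:

```python
def test_num_run_counter_sketch(backend):
    # num_run == 1 is reserved for the dummy estimator, so the first real
    # configuration is assigned 2 and the counter then grows by one per call.
    assert backend.get_next_num_run() == 2
    assert backend.get_next_num_run() == 3
    # peek=True reports the current maximum without reserving a new identifier.
    assert backend.get_next_num_run(peek=True) == 3
```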
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index 6af387298..7866e7674 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -12,8 +12,11 @@
 import sklearn
 import sklearn.datasets
+from sklearn.base import clone
 from sklearn.ensemble import VotingClassifier, VotingRegressor

+from smac.runhistory.runhistory import RunHistory
+
 import torch

 from autoPyTorch.api.tabular_classification import TabularClassificationTask
@@ -23,6 +26,7 @@
     HoldoutValTypes,
 )
 from autoPyTorch.optimizer.smbo import AutoMLSMBO
+from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy


 # Fixtures
@@ -104,17 +108,20 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):

     # Search for an existing run key in disc. A individual model might have
     # a timeout and hence was not written to disc
+    successful_num_run = None
+    SUCCESS = False
     for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
-        if 'SUCCESS' not in str(value.status):
-            continue
-
-        run_key_model_run_dir = estimator._backend.get_numrun_directory(
-            estimator.seed, run_key.config_id + 1, run_key.budget)
-        if os.path.exists(run_key_model_run_dir):
-            # Runkey config id is different from the num_run
-            # more specifically num_run = config_id + 1(dummy)
+        if 'SUCCESS' in str(value.status):
+            run_key_model_run_dir = estimator._backend.get_numrun_directory(
+                estimator.seed, run_key.config_id + 1, run_key.budget)
             successful_num_run = run_key.config_id + 1
-            break
+            if os.path.exists(run_key_model_run_dir):
+                # Runkey config id is different from the num_run
+                # more specifically num_run = config_id + 1(dummy)
+                SUCCESS = True
+                break
+
+    assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

     if resampling_strategy == HoldoutValTypes.holdout_validation:
         model_file = os.path.join(run_key_model_run_dir,
@@ -272,17 +279,20 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):

     # Search for an existing run key in disc. A individual model might have
     # a timeout and hence was not written to disc
+    successful_num_run = None
+    SUCCESS = False
     for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
-        if 'SUCCESS' not in str(value.status):
-            continue
-
-        run_key_model_run_dir = estimator._backend.get_numrun_directory(
-            estimator.seed, run_key.config_id + 1, run_key.budget)
-        if os.path.exists(run_key_model_run_dir):
-            # Runkey config id is different from the num_run
-            # more specifically num_run = config_id + 1(dummy)
+        if 'SUCCESS' in str(value.status):
+            run_key_model_run_dir = estimator._backend.get_numrun_directory(
+                estimator.seed, run_key.config_id + 1, run_key.budget)
             successful_num_run = run_key.config_id + 1
-            break
+            if os.path.exists(run_key_model_run_dir):
+                # Runkey config id is different from the num_run
+                # more specifically num_run = config_id + 1(dummy)
+                SUCCESS = True
+                break
+
+    assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

     if resampling_strategy == HoldoutValTypes.holdout_validation:
         model_file = os.path.join(run_key_model_run_dir,
@@ -384,7 +394,7 @@ def test_tabular_input_support(openml_id, backend):
     estimator._do_dummy_prediction = unittest.mock.MagicMock()

     with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
-        AutoMLSMBOMock.return_value = ({}, {}, 'epochs')
+        AutoMLSMBOMock.return_value = (RunHistory(), {}, 'epochs')
         estimator.search(
             X_train=X_train, y_train=y_train,
             X_test=X_test, y_test=y_test,
@@ -394,3 +404,48 @@ def test_tabular_input_support(openml_id, backend):
         enable_traditional_pipeline=False,
         load_models=False,
     )
+
+
+@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
+def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
+    backend = fit_dictionary_tabular['backend']
+    estimator = TabularClassificationTask(
+        backend=backend,
+        resampling_strategy=HoldoutValTypes.holdout_validation,
+        ensemble_size=0,
+    )
+
+    # Set up the pre-requisites normally set by search()
+    estimator._create_dask_client()
+    estimator._metric = accuracy
+    estimator._logger = estimator._get_logger('test')
+    estimator._memory_limit = 5000
+    estimator._time_for_task = 60
+    estimator._disable_file_output = []
+    estimator._all_supported_metrics = False
+
+    estimator._do_dummy_prediction()
+
+    # Ensure that the dummy predictions are not in the current working
+    # directory, but in the temporary directory.
+    assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
+    assert os.path.exists(os.path.join(
+        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
+        'predictions_ensemble_1_1_1.0.npy')
+    )
+
+    model_path = os.path.join(backend.temporary_directory,
+                              '.autoPyTorch',
+                              'runs', '1_1_1.0',
+                              '1.1.1.0.model')
+
+    # Make sure the dummy model complies with scikit-learn
+    # get/set params
+    assert os.path.exists(model_path)
+    with open(model_path, 'rb') as model_handler:
+        clone(pickle.load(model_handler))
+
+    estimator._close_dask_client()
+    estimator._clean_logger()
+
+    del estimator
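Illustrative sketch (not part of the diff): the `clone(pickle.load(...))` check above works because scikit-learn's `get_params()`/`clone()` expect `__init__` to store every constructor argument unmodified, which is exactly what the dummy pipelines now do with `config`, `init_params` and `random_state`. A minimal, hypothetical estimator showing the convention:

```python
from sklearn.base import BaseEstimator, clone


class TinyDummy(BaseEstimator):
    # Hypothetical estimator: get_params() introspects __init__ and reads back
    # attributes of the same name, so the arguments must be stored unmodified.
    def __init__(self, config=1, random_state=None, init_params=None):
        self.config = config
        self.random_state = random_state
        self.init_params = init_params


clone(TinyDummy(config=1))  # succeeds because every __init__ argument is an attribute
```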
diff --git a/test/test_evaluation/test_train_evaluator.py b/test/test_evaluation/test_train_evaluator.py
index 67132285e..ec3a5b0aa 100644
--- a/test/test_evaluation/test_train_evaluator.py
+++ b/test/test_evaluation/test_train_evaluator.py
@@ -50,8 +50,8 @@ def __init__(self):
     def predict_proba(self, X, batch_size=None):
         return np.tile([0.6, 0.4], (len(X), 1))

-    def get_additional_run_info(self) -> None:
-        return None
+    def get_additional_run_info(self):
+        return {}


 class TestTrainEvaluator(BaseEvaluatorTest, unittest.TestCase):
diff --git a/test/test_pipeline/test_tabular_classification.py b/test/test_pipeline/test_tabular_classification.py
index 4ba8572d1..bee8a820a 100644
--- a/test/test_pipeline/test_tabular_classification.py
+++ b/test/test_pipeline/test_tabular_classification.py
@@ -439,5 +439,5 @@ def test_constant_pipeline_iris(fit_dictionary_tabular):
     val_score = run_summary.performance_tracker['val_metrics'][epoch_where_best]['balanced_accuracy']
     train_score = run_summary.performance_tracker['train_metrics'][epoch_where_best]['balanced_accuracy']

-    assert val_score >= 0.9, run_summary.performance_tracker['val_metrics']
-    assert train_score >= 0.9, run_summary.performance_tracker['train_metrics']
+    assert val_score >= 0.8, run_summary.performance_tracker['val_metrics']
+    assert train_score >= 0.8, run_summary.performance_tracker['train_metrics']
diff --git a/test/test_utils/test_backend.py b/test/test_utils/test_backend.py
index becea67fb..9f8432884 100644
--- a/test/test_utils/test_backend.py
+++ b/test/test_utils/test_backend.py
@@ -1,8 +1,11 @@
 # -*- encoding: utf-8 -*-
 import builtins
+import logging.handlers
 import unittest
 import unittest.mock

+import numpy as np
+
 import pytest

 from autoPyTorch.utils.backend import Backend
@@ -81,3 +84,23 @@ def test_loads_models_by_identifiers(exists_mock, openMock, pickleLoadMock, back

     assert isinstance(actual_dict, dict)
     assert expected_dict == actual_dict
+
+
+def test_get_next_num_run(backend):
+    # Asking for a num_run increases the tracked num_run
+    assert backend.get_next_num_run() == 2
+    assert backend.get_next_num_run() == 3
+    # Then test that we are robust against new files being generated
+    backend.setup_logger('Test', logging.handlers.DEFAULT_TCP_LOGGING_PORT)
+    backend.save_numrun_to_dir(
+        seed=0,
+        idx=12,
+        budget=0.0,
+        model=dict(),
+        cv_model=None,
+        ensemble_predictions=np.zeros(10),
+        valid_predictions=None,
+        test_predictions=None,
+    )
+    assert backend.get_next_num_run() == 13
+    assert backend.get_next_num_run(peek=True) == 13