Commit
Merge branch 'refactor_development' into refactor_basetask
nabenabe0928 authored Apr 14, 2021
2 parents 9481437 + 5fef094 commit 02991d5
Showing 8 changed files with 175 additions and 35 deletions.
34 changes: 24 additions & 10 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -84,8 +84,8 @@ def __init__(self, config: str,
configuration_space = self.pipeline.get_hyperparameter_search_space()
default_configuration = configuration_space.get_default_configuration().get_dictionary()
default_configuration['model_trainer:tabular_classifier:classifier'] = config
configuration = Configuration(configuration_space, default_configuration)
self.pipeline.set_hyperparameters(configuration)
self.configuration = Configuration(configuration_space, default_configuration)
self.pipeline.set_hyperparameters(self.configuration)

def fit(self, X: Dict[str, Any], y: Any,
sample_weight: Optional[np.ndarray] = None) -> object:
@@ -102,8 +102,18 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
return False

def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None
def get_additional_run_info(self) -> Dict[str, Any]: # pylint: disable=R0201
"""
Can be used to return additional info for the run.
Returns:
Dict[str, Any]:
Currently contains
1. pipeline_configuration: the configuration of the pipeline, i.e., the traditional model used
2. trainer_configuration: the parameters for the traditional model used.
Can be found in autoPyTorch/pipeline/components/setup/traditional_ml/classifier_configs
"""
return {'pipeline_configuration': self.configuration,
'trainer_configuration': self.pipeline.named_steps['model_trainer'].choice.model.get_config()}

def get_pipeline_representation(self) -> Dict[str, str]:
return self.pipeline.get_pipeline_representation()
@@ -134,7 +144,9 @@ def __init__(self, config: Configuration,
random_state: Optional[Union[int, np.random.RandomState]] = None,
init_params: Optional[Dict] = None
) -> None:
self.configuration = config
self.config = config
self.init_params = init_params
self.random_state = random_state
if config == 1:
super(DummyClassificationPipeline, self).__init__(strategy="uniform")
else:
@@ -163,8 +175,8 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
return False

def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None
def get_additional_run_info(self) -> Dict: # pylint: disable=R0201
return {}

def get_pipeline_representation(self) -> Dict[str, str]:
return {
@@ -198,7 +210,9 @@ class DummyRegressionPipeline(DummyRegressor):
def __init__(self, config: Configuration,
random_state: Optional[Union[int, np.random.RandomState]] = None,
init_params: Optional[Dict] = None) -> None:
self.configuration = config
self.config = config
self.init_params = init_params
self.random_state = random_state
if config == 1:
super(DummyRegressionPipeline, self).__init__(strategy='mean')
else:
@@ -219,8 +233,8 @@ def predict(self, X: Union[np.ndarray, pd.DataFrame],
def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201
return False

def get_additional_run_info(self) -> None: # pylint: disable=R0201
return None
def get_additional_run_info(self) -> Dict: # pylint: disable=R0201
return {}

@staticmethod
def get_default_pipeline_options() -> Dict[str, Any]:
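For context, here is a minimal sketch (not part of this diff) of how a caller might consume the dictionary that get_additional_run_info() now returns for traditional pipelines; the evaluator and logger objects below are assumptions for illustration only.

from typing import Any, Dict


def log_additional_run_info(evaluator: Any, logger: Any) -> None:
    # Keys mirror the docstring above: 'pipeline_configuration' holds the
    # Configuration of the traditional pipeline, 'trainer_configuration'
    # the parameters of the traditional model.
    info: Dict[str, Any] = evaluator.get_additional_run_info() or {}
    logger.info("pipeline_configuration: %s", info.get('pipeline_configuration'))
    logger.info("trainer_configuration: %s", info.get('trainer_configuration'))
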
2 changes: 1 addition & 1 deletion autoPyTorch/evaluation/tae.py
@@ -57,7 +57,7 @@ def fit_predict_try_except_decorator(
def get_cost_of_crash(metric: autoPyTorchMetric) -> float:
# The metric must always be defined to extract optimum/worst
if not isinstance(metric, autoPyTorchMetric):
raise ValueError("The metric must be stricly be an instance of autoPyTorchMetric")
raise ValueError("The metric must be strictly be an instance of autoPyTorchMetric")

# Autopytorch optimizes the err. This function translates
# worst_possible_result to be a minimization problem.
5 changes: 4 additions & 1 deletion autoPyTorch/evaluation/train_evaluator.py
@@ -143,6 +143,8 @@ def fit_predict_and_loss(self) -> None:
# weights for opt_losses.
opt_fold_weights = [np.NaN] * self.num_folds

additional_run_info = {}

for i, (train_split, test_split) in enumerate(self.splits):

pipeline = self.pipelines[i]
@@ -178,7 +180,8 @@ def fit_predict_and_loss(self) -> None:
# number of optimization data points for this fold.
# Used for weighting the average.
opt_fold_weights[i] = len(train_split)

additional_run_info.update(pipeline.get_additional_run_info() if hasattr(
pipeline, 'get_additional_run_info') and pipeline.get_additional_run_info() is not None else {})
# Compute weights of each fold based on the number of samples in each
# fold.
train_fold_weights = [w / sum(train_fold_weights)
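As a standalone illustration of the defensive aggregation added above, the sketch below collects per-fold run info only from pipelines that implement get_additional_run_info() and return something; the helper name and signature are assumptions, not project code.

from typing import Any, Dict, List


def collect_additional_run_info(pipelines: List[Any]) -> Dict[str, Any]:
    additional_run_info: Dict[str, Any] = {}
    for pipeline in pipelines:
        getter = getattr(pipeline, 'get_additional_run_info', None)
        if getter is None:
            continue
        info = getter()
        if info is not None:
            # Later folds overwrite earlier keys, matching dict.update semantics.
            additional_run_info.update(info)
    return additional_run_info
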
45 changes: 45 additions & 0 deletions autoPyTorch/utils/backend.py
@@ -169,6 +169,10 @@ def __init__(self, context: BackendContext):
self._logger = None # type: Optional[PicklableClientLogger]
self.context = context

# Track the number of configurations launched
# num_run == 1 means a dummy estimator run
self.active_num_run = 1

# Create the temporary directory if it does not yet exist
try:
os.makedirs(self.temporary_directory)
@@ -329,6 +333,47 @@ def get_runs_directory(self) -> str:
def get_numrun_directory(self, seed: int, num_run: int, budget: float) -> str:
return os.path.join(self.internals_directory, 'runs', '%d_%d_%s' % (seed, num_run, budget))

def get_next_num_run(self, peek: bool = False) -> int:
"""
Every pipeline that is fitted by the estimator is stored with an
identifier called num_run. A dummy classifier will always have a num_run
equal to 1, and all other new configurations that are explored will
have a sequentially increasing identifier.
This method returns the next num_run a configuration should take.
Parameters
----------
peek: bool
By default, the next num_run will be returned, i.e. self.active_num_run + 1.
If this bool parameter is True, the value of the current
num_run is provided instead, i.e. self.active_num_run.
In other words, peek allows one to get the current maximum identifier
of a configuration.
Returns
-------
num_run: int
A unique identifier for a configuration
"""

# If there are other num_runs, their name would be runs/<seed>_<num_run>_<budget>
other_num_runs = [int(os.path.basename(run_dir).split('_')[1])
for run_dir in glob.glob(os.path.join(self.internals_directory,
'runs',
'*'))]
if len(other_num_runs) > 0:
# We track the number of runs from two fronts:
# the physically available num_runs (which might be deleted, or a crash could happen)
# and an internally kept attribute. The latter should be sufficient, but we
# want to be robust against multiple backend copies on different workers
self.active_num_run = max([self.active_num_run] + other_num_runs)

# We are interested in the next run id
if not peek:
self.active_num_run += 1
return self.active_num_run

def get_model_filename(self, seed: int, idx: int, budget: float) -> str:
return '%s.%s.%s.model' % (seed, idx, budget)

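A minimal, stateless sketch of the num_run bookkeeping introduced in get_next_num_run(), assuming the runs/<seed>_<num_run>_<budget> layout described in the docstring; this is an illustration, not the Backend implementation.

import glob
import os


def next_num_run(runs_directory: str, active_num_run: int, peek: bool = False) -> int:
    # Reconcile the in-memory counter with run directories already on disk,
    # so the bookkeeping stays robust if other workers created runs meanwhile.
    on_disk = [int(os.path.basename(run_dir).split('_')[1])
               for run_dir in glob.glob(os.path.join(runs_directory, '*'))]
    active_num_run = max([active_num_run] + on_disk)
    # peek=True returns the current maximum identifier; otherwise hand out the next one.
    return active_num_run if peek else active_num_run + 1
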
93 changes: 74 additions & 19 deletions test/test_api/test_api.py
@@ -12,8 +12,11 @@

import sklearn
import sklearn.datasets
from sklearn.base import clone
from sklearn.ensemble import VotingClassifier, VotingRegressor

from smac.runhistory.runhistory import RunHistory

import torch

from autoPyTorch.api.tabular_classification import TabularClassificationTask
@@ -23,6 +26,7 @@
HoldoutValTypes,
)
from autoPyTorch.optimizer.smbo import AutoMLSMBO
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy


# Fixtures
@@ -104,17 +108,20 @@ def test_tabular_classification(openml_id, resampling_strategy, backend):

# Search for an existing run key on disk. An individual model might have
# a timeout and hence was not written to disk
successful_num_run = None
SUCCESS = False
for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
if 'SUCCESS' not in str(value.status):
continue

run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
if 'SUCCESS' in str(value.status):
run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
successful_num_run = run_key.config_id + 1
break
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
SUCCESS = True
break

assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

if resampling_strategy == HoldoutValTypes.holdout_validation:
model_file = os.path.join(run_key_model_run_dir,
@@ -272,17 +279,20 @@ def test_tabular_regression(openml_name, resampling_strategy, backend):

# Search for an existing run key on disk. An individual model might have
# a timeout and hence was not written to disk
successful_num_run = None
SUCCESS = False
for i, (run_key, value) in enumerate(estimator.run_history.data.items()):
if 'SUCCESS' not in str(value.status):
continue

run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
if 'SUCCESS' in str(value.status):
run_key_model_run_dir = estimator._backend.get_numrun_directory(
estimator.seed, run_key.config_id + 1, run_key.budget)
successful_num_run = run_key.config_id + 1
break
if os.path.exists(run_key_model_run_dir):
# Runkey config id is different from the num_run
# more specifically num_run = config_id + 1(dummy)
SUCCESS = True
break

assert SUCCESS, f"Successful run was not properly saved for num_run: {successful_num_run}"

if resampling_strategy == HoldoutValTypes.holdout_validation:
model_file = os.path.join(run_key_model_run_dir,
@@ -384,7 +394,7 @@ def test_tabular_input_support(openml_id, backend):
estimator._do_dummy_prediction = unittest.mock.MagicMock()

with unittest.mock.patch.object(AutoMLSMBO, 'run_smbo') as AutoMLSMBOMock:
AutoMLSMBOMock.return_value = ({}, {}, 'epochs')
AutoMLSMBOMock.return_value = (RunHistory(), {}, 'epochs')
estimator.search(
X_train=X_train, y_train=y_train,
X_test=X_test, y_test=y_test,
@@ -394,3 +404,48 @@
enable_traditional_pipeline=False,
load_models=False,
)


@pytest.mark.parametrize("fit_dictionary_tabular", ['classification_categorical_only'], indirect=True)
def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
backend = fit_dictionary_tabular['backend']
estimator = TabularClassificationTask(
backend=backend,
resampling_strategy=HoldoutValTypes.holdout_validation,
ensemble_size=0,
)

# Setup pre-requisites normally set by search()
estimator._create_dask_client()
estimator._metric = accuracy
estimator._logger = estimator._get_logger('test')
estimator._memory_limit = 5000
estimator._time_for_task = 60
estimator._disable_file_output = []
estimator._all_supported_metrics = False

estimator._do_dummy_prediction()

# Ensure that the dummy predictions are not in the current working
# directory, but in the temporary directory.
assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
assert os.path.exists(os.path.join(
backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
'predictions_ensemble_1_1_1.0.npy')
)

model_path = os.path.join(backend.temporary_directory,
'.autoPyTorch',
'runs', '1_1_1.0',
'1.1.1.0.model')

# Make sure the dummy model complies with scikit learn
# get/set params
assert os.path.exists(model_path)
with open(model_path, 'rb') as model_handler:
clone(pickle.load(model_handler))

estimator._close_dask_client()
estimator._clean_logger()

del estimator
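
A sketch (for illustration only) of the run-directory lookup exercised by the classification and regression tests above, relying on the num_run = config_id + 1 convention, where num_run 1 is reserved for the dummy run.

import os
from typing import Any, Optional


def find_successful_run_dir(estimator: Any) -> Optional[str]:
    for run_key, run_value in estimator.run_history.data.items():
        if 'SUCCESS' not in str(run_value.status):
            continue
        # num_run differs from the config id: the dummy run takes num_run 1.
        num_run = run_key.config_id + 1
        run_dir = estimator._backend.get_numrun_directory(
            estimator.seed, num_run, run_key.budget)
        if os.path.exists(run_dir):
            return run_dir
    return None
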
4 changes: 2 additions & 2 deletions test/test_evaluation/test_train_evaluator.py
@@ -50,8 +50,8 @@ def __init__(self):
def predict_proba(self, X, batch_size=None):
return np.tile([0.6, 0.4], (len(X), 1))

def get_additional_run_info(self) -> None:
return None
def get_additional_run_info(self):
return {}


class TestTrainEvaluator(BaseEvaluatorTest, unittest.TestCase):
4 changes: 2 additions & 2 deletions test/test_pipeline/test_tabular_classification.py
@@ -439,5 +439,5 @@ def test_constant_pipeline_iris(fit_dictionary_tabular):
val_score = run_summary.performance_tracker['val_metrics'][epoch_where_best]['balanced_accuracy']
train_score = run_summary.performance_tracker['train_metrics'][epoch_where_best]['balanced_accuracy']

assert val_score >= 0.9, run_summary.performance_tracker['val_metrics']
assert train_score >= 0.9, run_summary.performance_tracker['train_metrics']
assert val_score >= 0.8, run_summary.performance_tracker['val_metrics']
assert train_score >= 0.8, run_summary.performance_tracker['train_metrics']
23 changes: 23 additions & 0 deletions test/test_utils/test_backend.py
@@ -1,8 +1,11 @@
# -*- encoding: utf-8 -*-
import builtins
import logging.handlers
import unittest
import unittest.mock

import numpy as np

import pytest

from autoPyTorch.utils.backend import Backend
@@ -81,3 +84,23 @@ def test_loads_models_by_identifiers(exists_mock, openMock, pickleLoadMock, back

assert isinstance(actual_dict, dict)
assert expected_dict == actual_dict


def test_get_next_num_run(backend):
# Asking for a num_run increases the tracked num_run
assert backend.get_next_num_run() == 2
assert backend.get_next_num_run() == 3
# Then test that we are robust against new files being generated
backend.setup_logger('Test', logging.handlers.DEFAULT_TCP_LOGGING_PORT)
backend.save_numrun_to_dir(
seed=0,
idx=12,
budget=0.0,
model=dict(),
cv_model=None,
ensemble_predictions=np.zeros(10),
valid_predictions=None,
test_predictions=None,
)
assert backend.get_next_num_run() == 13
assert backend.get_next_num_run(peek=True) == 13
