Skip to content

Commit

Permalink
fix flake and issue #299
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Dec 25, 2021
1 parent 4d16352 commit 2569602
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 69 deletions.
25 changes: 10 additions & 15 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def __init__(
self._scoring_functions: Optional[List[autoPyTorchMetric]] = None
self._logger: Optional[PicklableClientLogger] = None
self.dataset_name: Optional[str] = None
self.dataset = Optional[BaseDataset]
self.cv_models_: Dict = {}

self._results_manager = ResultsManager()
Expand Down Expand Up @@ -616,20 +617,7 @@ def _load_best_individual_model(self) -> SingleBest:
run_history=self.run_history,
backend=self._backend,
)
if self._logger is None:
warnings.warn(
"No valid ensemble was created. Please check the log"
"file for errors. Default to the best individual estimator:{}".format(
ensemble.identifiers_
)
)
else:
self._logger.exception(
"No valid ensemble was created. Please check the log"
"file for errors. Default to the best individual estimator:{}".format(
ensemble.identifiers_
)
)


return ensemble

Expand Down Expand Up @@ -1257,7 +1245,6 @@ def _search(
if proc_ensemble is not None:
self._collect_results_ensemble(proc_ensemble)


self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")
Expand All @@ -1267,6 +1254,14 @@ def _search(
self._load_models()
self._logger.info("Finished loading models...")

if isinstance(self.ensemble_, SingleBest) and ensemble_size > 0:
self._logger.exception(
"No valid ensemble was created. Please check the log"
"file for errors. Default to the best individual estimator:{}".format(
self.ensemble_.identifiers_
)
)

self._cleanup()

return self
Expand Down
100 changes: 50 additions & 50 deletions examples/40_advanced/example_posthoc_ensemble_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,59 +24,59 @@
from autoPyTorch.api.tabular_classification import TabularClassificationTask


if __name__ == '__main__':
############################################################################
# Data Loading
# ============
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
X,
y,
random_state=42,
)

############################################################################
# Data Loading
# ============
X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
X,
y,
random_state=42,
)
############################################################################
# Build and fit a classifier
# ==========================
api = TabularClassificationTask(
seed=42,
)

############################################################################
# Build and fit a classifier
# ==========================
api = TabularClassificationTask(
ensemble_size=0,
seed=42,
)
############################################################################
# Search for the best neural network
# ==================================
api.search(
X_train=X_train,
y_train=y_train,
X_test=X_test.copy(),
y_test=y_test.copy(),
optimize_metric='accuracy',
total_walltime_limit=100,
func_eval_time_limit_secs=50,
ensemble_size=0,
)

############################################################################
# Search for the best neural network
# ==================================
api.search(
X_train=X_train,
y_train=y_train,
X_test=X_test.copy(),
y_test=y_test.copy(),
optimize_metric='accuracy',
total_walltime_limit=100,
func_eval_time_limit_secs=50
)
############################################################################
# Print the final performance of the incumbent neural network
# ===========================================================
print(api.run_history, api.trajectory)
y_pred = api.predict(X_test)
score = api.score(y_pred, y_test)
print(score)

############################################################################
# Print the final performance of the incumbent neural network
# ===========================================================
print(api.run_history, api.trajectory)
y_pred = api.predict(X_test)
score = api.score(y_pred, y_test)
print(score)
############################################################################
# Fit an ensemble with the neural networks fitted during the search
# =================================================================

############################################################################
# Fit an ensemble with the neural networks fitted during the search
# =================================================================
api.fit_ensemble(ensemble_size=5,
# Set the enable_traditional_pipeline=True
# to also include traditional models
# in the ensemble
enable_traditional_pipeline=False)
# Print the final ensemble built by AutoPyTorch
y_pred = api.predict(X_test)
score = api.score(y_pred, y_test)
print(score)
print(api.show_models())

api.fit_ensemble(ensemble_size=5,
# Set the enable_traditional_pipeline=True
# to also include traditional models
# in the ensemble
enable_traditional_pipeline=False)
# Print the final ensemble built by AutoPyTorch
y_pred = api.predict(X_test)
score = api.score(y_pred, y_test)
print(score)
print(api.show_models())
api._cleanup()
# Print statistics from search
print(api.sprint_statistics())
7 changes: 3 additions & 4 deletions test/test_api/test_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from autoPyTorch.constants import TABULAR_CLASSIFICATION, TABULAR_REGRESSION
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline


# ====
Expand Down Expand Up @@ -167,13 +167,12 @@ def test_init_ensemble_builder(backend):
time_left_for_ensembles=60,
optimize_metric='accuracy',
ensemble_nbest=10,
ensemble_size=5
)
ensemble_size=5)

assert isinstance(proc_ensemble, EnsembleBuilderManager)
assert proc_ensemble.opt_metric == 'accuracy'
assert proc_ensemble.metrics[0] == accuracy

estimator._cleanup()

del estimator
del estimator

0 comments on commit 2569602

Please sign in to comment.