Skip to content

Commit

Permalink
autoPyTorch/api/
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Jul 21, 2022
1 parent 8f8dee1 commit 5f3d4b6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
3 changes: 1 addition & 2 deletions test/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,6 @@ def test_tabular_input_support(openml_id, backend):
estimator = TabularClassificationTask(
backend=backend,
resampling_strategy=HoldoutValTypes.holdout_validation,
ensemble_size=0,
)

estimator._do_dummy_prediction = unittest.mock.MagicMock()
Expand All @@ -624,6 +623,7 @@ def test_tabular_input_support(openml_id, backend):
func_eval_time_limit_secs=50,
enable_traditional_pipeline=False,
load_models=False,
ensemble_size=0,
)


Expand All @@ -633,7 +633,6 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
estimator = TabularClassificationTask(
backend=backend,
resampling_strategy=HoldoutValTypes.holdout_validation,
ensemble_size=0,
)

# Setup pre-requisites normally set by search()
Expand Down
5 changes: 2 additions & 3 deletions test/test_api/test_base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_set_pipeline_config():
])
def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, budget_type, expected):
BaseTask.__abstractmethods__ = set()
estimator = BaseTask(task_type='tabular_classification', ensemble_size=0)
estimator = BaseTask(task_type='tabular_classification')

# Fixture pipeline config
default_pipeline_config = {
Expand All @@ -141,7 +141,7 @@ def test_pipeline_get_budget(fit_dictionary_tabular, min_budget, max_budget, bud
smac_mock.return_value = smac
estimator._search(optimize_metric='accuracy', dataset=dataset, tae_func=pipeline_fit,
min_budget=min_budget, max_budget=max_budget, budget_type=budget_type,
enable_traditional_pipeline=False,
ensemble_size=0, enable_traditional_pipeline=False,
total_walltime_limit=20, func_eval_time_limit_secs=10,
load_models=False)
assert list(smac_mock.call_args)[1]['ta_kwargs']['pipeline_config'] == default_pipeline_config
Expand Down Expand Up @@ -210,7 +210,6 @@ def test_init_ensemble_builder(backend):
BaseTask.__abstractmethods__ = set()
estimator = BaseTask(
backend=backend,
ensemble_size=0,
)

# Setup pre-requisites normally set by search()
Expand Down

0 comments on commit 5f3d4b6

Please sign in to comment.