Skip to content

Commit

Permalink
[fix] Fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
nabenabe0928 committed Apr 14, 2021
1 parent 9481437 commit b7726a8
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __init__(
self._logger: Optional[PicklableClientLogger] = None
self.run_history: RunHistory = RunHistory()
self.trajectory: Optional[List] = None
self.dataset_name: Optional[str] = None
self.dataset_name: str = ""
self.cv_models_: Dict = {}
self.experiment_task_name: str = 'runSearch'

Expand Down Expand Up @@ -702,16 +702,18 @@ def _run_traditional_ml(self) -> None:
self._stopwatch.start_task(traditional_task_name)
elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)

assert self._func_eval_time_limit_secs is not None
time_for_traditional = int(
self._total_walltime_limit - elapsed_time - self._func_eval_time_limit_secs
)
self._do_traditional_prediction(time_left=time_for_traditional)
self._stopwatch.stop_task(traditional_task_name)

def _run_ensemble(self, dataset: BaseDataset, optimize_metric: str,
precision: int) -> EnsembleBuilderManager:
precision: int) -> Optional[EnsembleBuilderManager]:

assert self._logger is not None
assert self._metric is not None

elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name)
time_left_for_ensembles = max(0, self._total_walltime_limit - elapsed_time)
Expand Down Expand Up @@ -788,7 +790,7 @@ def _start_smac(self, proc_smac: AutoMLSMBO) -> None:
except Exception as e:
self._logger.warning(f"Could not save {trajectory_filename} due to {e}...")

def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager,
def _run_smac(self, dataset: BaseDataset, proc_ensemble: Optional[EnsembleBuilderManager],
budget_type: Optional[str] = None, budget: Optional[float] = None,
get_smac_object_callback: Optional[Callable] = None,
smac_scenario_args: Optional[Dict[str, Any]] = None) -> None:
Expand All @@ -805,6 +807,9 @@ def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager,
self._logger.warning(" Could not run SMAC because there is no time left")
else:
budget_config = self._get_budget_config(budget_type=budget_type, budget=budget)

assert self._func_eval_time_limit_secs is not None
assert self._metric is not None
proc_smac = AutoMLSMBO(
config_space=self.search_space,
dataset_name=dataset.dataset_name,
Expand Down Expand Up @@ -1095,7 +1100,7 @@ def refit(
Returns:
self
"""
if self.dataset_name is None:
if self.dataset_name == "":
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

if self._logger is None:
Expand Down Expand Up @@ -1165,7 +1170,7 @@ def fit(self,
Returns:
(BasePipeline): fitted pipeline
"""
if self.dataset_name is None:
if self.dataset_name == "":
self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

if self._logger is None:
Expand Down

0 comments on commit b7726a8

Please sign in to comment.