diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index afc128625..427522beb 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -174,7 +174,7 @@ def __init__( self._logger: Optional[PicklableClientLogger] = None self.run_history: RunHistory = RunHistory() self.trajectory: Optional[List] = None - self.dataset_name: Optional[str] = None + self.dataset_name: str = "" self.cv_models_: Dict = {} self.experiment_task_name: str = 'runSearch' @@ -702,6 +702,7 @@ def _run_traditional_ml(self) -> None: self._stopwatch.start_task(traditional_task_name) elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) + assert self._func_eval_time_limit_secs is not None time_for_traditional = int( self._total_walltime_limit - elapsed_time - self._func_eval_time_limit_secs ) @@ -709,9 +710,10 @@ def _run_traditional_ml(self) -> None: self._stopwatch.stop_task(traditional_task_name) def _run_ensemble(self, dataset: BaseDataset, optimize_metric: str, - precision: int) -> EnsembleBuilderManager: + precision: int) -> Optional[EnsembleBuilderManager]: assert self._logger is not None + assert self._metric is not None elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) time_left_for_ensembles = max(0, self._total_walltime_limit - elapsed_time) @@ -788,7 +790,7 @@ def _start_smac(self, proc_smac: AutoMLSMBO) -> None: except Exception as e: self._logger.warning(f"Could not save {trajectory_filename} due to {e}...") - def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager, + def _run_smac(self, dataset: BaseDataset, proc_ensemble: Optional[EnsembleBuilderManager], budget_type: Optional[str] = None, budget: Optional[float] = None, get_smac_object_callback: Optional[Callable] = None, smac_scenario_args: Optional[Dict[str, Any]] = None) -> None: @@ -805,6 +807,9 @@ def _run_smac(self, dataset: BaseDataset, proc_ensemble: EnsembleBuilderManager, self._logger.warning(" Could not run SMAC because there is no time left") else: budget_config = self._get_budget_config(budget_type=budget_type, budget=budget) + + assert self._func_eval_time_limit_secs is not None + assert self._metric is not None proc_smac = AutoMLSMBO( config_space=self.search_space, dataset_name=dataset.dataset_name, @@ -1095,7 +1100,7 @@ def refit( Returns: self """ - if self.dataset_name is None: + if self.dataset_name == "": self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) if self._logger is None: @@ -1165,7 +1170,7 @@ def fit(self, Returns: (BasePipeline): fitted pipeline """ - if self.dataset_name is None: + if self.dataset_name == "": self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) if self._logger is None: