diff --git a/src/utils/utils.py b/src/utils/utils.py index e126078b9..0860d5fcf 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -22,31 +22,39 @@ def task_wrapper(task_func: Callable) -> Callable: Utilities: - Calling the `utils.extras()` before the task is started - - Calling the `utils.close_loggers()` after the task is finished + - Calling the `utils.close_loggers()` after the task is finished or failed - Logging the exception if occurs - - Logging the task total execution time - Logging the output dir """ def wrap(cfg: DictConfig): - # apply extra utilities - extras(cfg) - # execute the task try: - start_time = time.time() + + # apply extra utilities + extras(cfg) + metric_dict, object_dict = task_func(cfg=cfg) + + # things to do if exception occurs except Exception as ex: - log.exception("") # save exception to `.log` file + + # save exception to `.log` file + log.exception("") + + # when using hydra plugins like Optuna, you might want to disable raising exception + # to avoid multirun failure raise ex + + # things to always do after either success or exception finally: - path = Path(cfg.paths.output_dir, "exec_time.log") - content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)" - save_file(path, content) # save task execution time (even if exception occurs) - close_loggers() # close loggers (even if exception occurs so multirun won't fail) - log.info(f"Output dir: {cfg.paths.output_dir}") + # display output dir path in terminal + log.info(f"Output dir: {cfg.paths.output_dir}") + + # close loggers (even if exception occurs so multirun won't fail) + close_loggers() return metric_dict, object_dict @@ -83,13 +91,6 @@ def extras(cfg: DictConfig) -> None: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) -@rank_zero_only -def save_file(path: str, content: str) -> None: - """Save file in rank zero mode (only on one process in multi-GPU setup).""" - with open(path, "w+") as file: - file.write(content) - - def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: """Instantiates callbacks from config.""" callbacks: List[Callback] = [] @@ -204,3 +205,10 @@ def close_loggers() -> None: if wandb.run: log.info("Closing wandb!") wandb.finish() + + +@rank_zero_only +def save_file(path: str, content: str) -> None: + """Save file in rank zero mode (only on one process in multi-GPU setup).""" + with open(path, "w+") as file: + file.write(content)