Skip to content

Commit

Permalink
Refactor task_wrapper decorator (#488)
Browse files Browse the repository at this point in the history
  • Loading branch information
ashleve authored Dec 18, 2022
1 parent ac93624 commit 0571b4f
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions src/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,39 @@ def task_wrapper(task_func: Callable) -> Callable:
Utilities:
- Calling the `utils.extras()` before the task is started
- Calling the `utils.close_loggers()` after the task is finished
- Calling the `utils.close_loggers()` after the task is finished or failed
- Logging the exception if occurs
- Logging the task total execution time
- Logging the output dir
"""

def wrap(cfg: DictConfig):

# apply extra utilities
extras(cfg)

# execute the task
try:
start_time = time.time()

# apply extra utilities
extras(cfg)

metric_dict, object_dict = task_func(cfg=cfg)

# things to do if exception occurs
except Exception as ex:
log.exception("") # save exception to `.log` file

# save exception to `.log` file
log.exception("")

# when using hydra plugins like Optuna, you might want to disable raising exception
# to avoid multirun failure
raise ex

# things to always do after either success or exception
finally:
path = Path(cfg.paths.output_dir, "exec_time.log")
content = f"'{cfg.task_name}' execution time: {time.time() - start_time} (s)"
save_file(path, content) # save task execution time (even if exception occurs)
close_loggers() # close loggers (even if exception occurs so multirun won't fail)

log.info(f"Output dir: {cfg.paths.output_dir}")
# display output dir path in terminal
log.info(f"Output dir: {cfg.paths.output_dir}")

# close loggers (even if exception occurs so multirun won't fail)
close_loggers()

return metric_dict, object_dict

Expand Down Expand Up @@ -83,13 +91,6 @@ def extras(cfg: DictConfig) -> None:
rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)


@rank_zero_only
def save_file(path: str, content: str) -> None:
"""Save file in rank zero mode (only on one process in multi-GPU setup)."""
with open(path, "w+") as file:
file.write(content)


def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
"""Instantiates callbacks from config."""
callbacks: List[Callback] = []
Expand Down Expand Up @@ -204,3 +205,10 @@ def close_loggers() -> None:
if wandb.run:
log.info("Closing wandb!")
wandb.finish()


@rank_zero_only
def save_file(path: str, content: str) -> None:
"""Save file in rank zero mode (only on one process in multi-GPU setup)."""
with open(path, "w+") as file:
file.write(content)

0 comments on commit 0571b4f

Please sign in to comment.