Skip to content

Commit

Permalink
better autoresume support
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCreator committed Sep 16, 2023
1 parent 3a62ea3 commit cd0620b
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion RWKV-v5/lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,15 @@ def process_auto_resume_ckpt():
# Resolve the checkpoint directory when auto_resume_ckpt_dir is "true" or "auto"
if auto_resume_ckpt_dir.lower() in ("true", "auto"):
    print(f"[RWKV.lightning_trainer.py] Extracting checkpoint dir from config, for --auto-resume-ckpt-dir={auto_resume_ckpt_dir}")

    # A dirpath supplied on the CLI overrides whatever the config file declares.
    cli_dirpath = CLI_ARGS_MAP["--trainer.callbacks.init_args.dirpath"]
    if cli_dirpath is not None:
        auto_resume_ckpt_dir = cli_dirpath
    else:
        # `trainer.callbacks` may be a single object (mapping) or an array of
        # such objects. Chaining `.get()` on a list raises AttributeError and
        # indexing `[0]` on a mapping raises KeyError, so normalise the value
        # to a list of callbacks before looking anything up.
        trainer_config = LIGHTNING_CONFIG.get("trainer", {}) or {}
        callbacks = trainer_config.get("callbacks", None)
        if isinstance(callbacks, dict):
            callbacks = [callbacks]
        elif not isinstance(callbacks, (list, tuple)):
            callbacks = []

        # Use the first callback that declares `init_args.dirpath`; entries
        # without it (or with an unexpected shape) are skipped rather than
        # aborting the search.
        auto_resume_ckpt_dir = None
        for callback in callbacks:
            if not isinstance(callback, dict):
                continue
            init_args = callback.get("init_args", None)
            if not isinstance(init_args, dict):
                continue
            dirpath = init_args.get("dirpath", None)
            if dirpath is not None:
                auto_resume_ckpt_dir = dirpath
                break

    # Safety check on the dir. Raised explicitly (instead of via `assert`) so
    # the check still runs under `python -O`; the exception type is unchanged.
    if auto_resume_ckpt_dir is None:
        raise AssertionError("Failed to extract checkpoint dir from config, for --auto-resume-ckpt-dir=True")

# Log the setting flag
Expand Down

0 comments on commit cd0620b

Please sign in to comment.