From cd0620b445b88792a421b52c521825f70ab23180 Mon Sep 17 00:00:00 2001 From: "@picocreator (Eugene Cheah)" Date: Sat, 16 Sep 2023 00:50:49 +0000 Subject: [PATCH] better autoresume support --- RWKV-v5/lightning_trainer.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/RWKV-v5/lightning_trainer.py b/RWKV-v5/lightning_trainer.py index 9652770e..d5213958 100644 --- a/RWKV-v5/lightning_trainer.py +++ b/RWKV-v5/lightning_trainer.py @@ -134,7 +134,15 @@ def process_auto_resume_ckpt(): # Handle auto_resume_ckpt_dir if its true or auto if auto_resume_ckpt_dir.lower() == "true" or auto_resume_ckpt_dir.lower() == "auto": print(f"[RWKV.lightning_trainer.py] Extracting checkpoint dir from config, for --auto-resume-ckpt-dir={auto_resume_ckpt_dir}") - auto_resume_ckpt_dir = LIGHTNING_CONFIG.get("trainer", {}).get("callbacks", {}).get("init_args", {}).get("dirpath", None) + # Handle the auto resume overwrite, via CLI + if CLI_ARGS_MAP["--trainer.callbacks.init_args.dirpath"] is not None: + auto_resume_ckpt_dir = CLI_ARGS_MAP["--trainer.callbacks.init_args.dirpath"] + else: + # Try to get as an object, then an object in an array + auto_resume_ckpt_dir = LIGHTNING_CONFIG.get("trainer", {}).get("callbacks", {}).get("init_args", {}).get("dirpath", None) + if auto_resume_ckpt_dir is None: + auto_resume_ckpt_dir = LIGHTNING_CONFIG.get("trainer", {}).get("callbacks", [{}])[0].get("init_args", {}).get("dirpath", None) + # Safety check on the dir assert auto_resume_ckpt_dir is not None, "Failed to extract checkpoint dir from config, for --auto-resume-ckpt-dir=True" # Log the setting flag