Resolve the auto-resume checkpoint dir from the CLI override or the config.

When --auto-resume-ckpt-dir is "true" or "auto", the checkpoint dir is taken
from the --trainer.callbacks.init_args.dirpath entry of CLI_ARGS_MAP when it
is set, otherwise from trainer.callbacks in LIGHTNING_CONFIG, which may be a
single callback object or a list of callback objects.

NOTE(review): line breaks and indentation of this patch were reconstructed
from a whitespace-collapsed copy - confirm the context lines against the
target file before applying. The post-image blob id on the index line
predates the edits made to the added lines.

diff --git a/RWKV-v5/lightning_trainer.py b/RWKV-v5/lightning_trainer.py
index 9652770e..d5213958 100644
--- a/RWKV-v5/lightning_trainer.py
+++ b/RWKV-v5/lightning_trainer.py
@@ -134,7 +134,36 @@ def process_auto_resume_ckpt():
     # Handle auto_resume_ckpt_dir if its true or auto
     if auto_resume_ckpt_dir.lower() == "true" or auto_resume_ckpt_dir.lower() == "auto":
         print(f"[RWKV.lightning_trainer.py] Extracting checkpoint dir from config, for --auto-resume-ckpt-dir={auto_resume_ckpt_dir}")
-        auto_resume_ckpt_dir = LIGHTNING_CONFIG.get("trainer", {}).get("callbacks", {}).get("init_args", {}).get("dirpath", None)
+        # Handle the auto resume overwrite, via CLI.
+        # EAFP lookup: the flag may be absent from the parsed CLI arguments.
+        try:
+            cli_ckpt_dir = CLI_ARGS_MAP["--trainer.callbacks.init_args.dirpath"]
+        except KeyError:
+            cli_ckpt_dir = None
+
+        if cli_ckpt_dir is not None:
+            auto_resume_ckpt_dir = cli_ckpt_dir
+        else:
+            # `trainer.callbacks` may be a single callback object or a list of
+            # them; normalise to a list so both shapes share one code path
+            # (calling .get() on a list, or [0] on a dict, would raise).
+            callbacks = (LIGHTNING_CONFIG.get("trainer") or {}).get("callbacks")
+            if isinstance(callbacks, dict):
+                callbacks = [callbacks]
+            elif not isinstance(callbacks, (list, tuple)):
+                callbacks = []
+
+            # Use the first callback that declares init_args.dirpath
+            auto_resume_ckpt_dir = None
+            for callback in callbacks:
+                if not isinstance(callback, dict):
+                    continue
+                dirpath = (callback.get("init_args") or {}).get("dirpath")
+                if dirpath is not None:
+                    auto_resume_ckpt_dir = dirpath
+                    break
+
+        # Safety check on the dir
         assert auto_resume_ckpt_dir is not None, "Failed to extract checkpoint dir from config, for --auto-resume-ckpt-dir=True"
 
     # Log the setting flag