Skip to content

Commit

Permalink
fix bug in config
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Jul 19, 2024
1 parent 2ca1424 commit 621a542
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/fairchem/core/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,17 @@ def __init__(self, config) -> None:
def setup(self, trainer) -> None:
self.trainer = trainer

self.chkpt_path = os.path.join(
self.trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
)

# if the supplied checkpoint exists, then load that
if self.config["checkpoint"] is not None:
logging.info(f"Attemping to load user specified checkpoint at {self.config['checkpoint']}")
self.trainer.load_checkpoint(checkpoint_path=self.config["checkpoint"])
# otherwise check if slurm job already exists and try to resume from latest checkpoint
elif "slurm" in self.config and "slurm_job_id" in self.config["slurm"]:
uuid = unique_job_id(self.config["timestamp_id"], self.config["slurm"]["slurm_job_id"])
checkpoint_dir = get_checkpoint_dir(uuid, self.config["run_dir"])
self.chkpt_path = os.path.join(checkpoint_dir, DEFAULT_CHECKPOINT_NAME)

# if the path already exists, then attempt to load it
if os.path.exists(self.chkpt_path):
self.trainer.load_checkpoint(checkpoint_path=self.chkpt_path)
elif os.path.exists(self.chkpt_path):
logging.info(f"Previous checkpoint found at {self.chkpt_path}, resuming job from this checkecpoint")
self.trainer.load_checkpoint(checkpoint_path=self.chkpt_path)


def run(self):
Expand Down

0 comments on commit 621a542

Please sign in to comment.