diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index 9c2dd1d0..056abf99 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -636,7 +636,9 @@ def save_checkpoint(self, epoch: int, mae_error: dict, save_dir: str) -> None: filename, os.path.join(save_dir, f"bestE_epoch{epoch}_{err_str}.pth.tar"), ) - if mae_error["f"] == min(self.training_history["f"]["val"]): + if "f" in self.targets and mae_error["f"] == min( + self.training_history["f"]["val"] + ): for fname in os.listdir(save_dir): if fname.startswith("bestF"): os.remove(os.path.join(save_dir, fname))