From e413a9082ccbae7f97959845c63b208dd53cc9b3 Mon Sep 17 00:00:00 2001
From: Guillaume Fraux
Date: Tue, 3 Sep 2024 14:26:20 +0200
Subject: [PATCH] Fix/silence warnings in tests
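
PyTorch 2.4 started emitting a FutureWarning whenever torch.load() is
called without an explicit `weights_only` argument (the default is
scheduled to change to `weights_only=True`). Our checkpoints contain
pickled Python objects (hypers, dataset information) in addition to
tensors, so they can not be loaded with `weights_only=True`; instead,
pass `weights_only=False` explicitly everywhere to keep the current
behavior and silence the warning:

    # checkpoints store arbitrary pickled objects, not only tensors
    checkpoint = torch.load(path, weights_only=False)

Also move the rascaline warning filters of the SOAP-BPNN trainer from
import time into Trainer.train(), so that merely importing the module
does not install process-wide filters, and spell out the full messages
of the second-derivative warnings.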

---
 .../experimental/alchemical_model/model.py    |  2 +-
 .../experimental/alchemical_model/trainer.py  |  2 +-
 src/metatrain/experimental/pet/model.py       |  4 ++--
 src/metatrain/experimental/pet/trainer.py     |  6 +++---
 src/metatrain/experimental/soap_bpnn/model.py |  2 +-
 .../experimental/soap_bpnn/trainer.py         | 21 ++++++++++++-------
 tests/cli/test_train_model.py                 |  4 ++--
 7 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/src/metatrain/experimental/alchemical_model/model.py b/src/metatrain/experimental/alchemical_model/model.py
index 21b622a27..52b9af5fb 100644
--- a/src/metatrain/experimental/alchemical_model/model.py
+++ b/src/metatrain/experimental/alchemical_model/model.py
@@ -129,7 +129,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "AlchemicalModel":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
 
diff --git a/src/metatrain/experimental/alchemical_model/trainer.py b/src/metatrain/experimental/alchemical_model/trainer.py
index 9ff96599e..2ff0693dd 100644
--- a/src/metatrain/experimental/alchemical_model/trainer.py
+++ b/src/metatrain/experimental/alchemical_model/trainer.py
@@ -349,7 +349,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
diff --git a/src/metatrain/experimental/pet/model.py b/src/metatrain/experimental/pet/model.py
index 0ff1ad0e3..ee5202bd4 100644
--- a/src/metatrain/experimental/pet/model.py
+++ b/src/metatrain/experimental/pet/model.py
@@ -114,14 +114,14 @@ def forward(
 
     @classmethod
     def load_checkpoint(cls, path: Union[str, Path]) -> "PET":
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         hypers = checkpoint["hypers"]
         dataset_info = checkpoint["dataset_info"]
         model = cls(
             model_hypers=hypers["ARCHITECTURAL_HYPERS"], dataset_info=dataset_info
         )
 
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         state_dict = checkpoint["checkpoint"]["model_state_dict"]
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
diff --git a/src/metatrain/experimental/pet/trainer.py b/src/metatrain/experimental/pet/trainer.py
index b36551aa3..f9c6ef670 100644
--- a/src/metatrain/experimental/pet/trainer.py
+++ b/src/metatrain/experimental/pet/trainer.py
@@ -163,7 +163,7 @@ def train(
         else:
             load_path = self.pet_dir / "best_val_rmse_energies_model_state_dict"
 
-        state_dict = torch.load(load_path)
+        state_dict = torch.load(load_path, weights_only=False)
 
         ARCHITECTURAL_HYPERS = Hypers(model.hypers)
         raw_pet = PET(ARCHITECTURAL_HYPERS, 0.0, len(model.atomic_types))
@@ -186,7 +186,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
         # together with the hypers inside a file that will act as a metatrain
         # checkpoint
         checkpoint_path = self.pet_dir / "checkpoint"  # type: ignore
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path, weights_only=False)
         torch.save(
             {
                 "checkpoint": checkpoint,
@@ -204,7 +204,7 @@ def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
         # This function loads a metatrain PET checkpoint and returns a Trainer
         # instance with the hypers, while also saving the checkpoint in the
         # class
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         trainer = cls(train_hypers)
         trainer.pet_checkpoint = checkpoint["checkpoint"]
         return trainer
diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py
index a2be10d10..7f144d27f 100644
--- a/src/metatrain/experimental/soap_bpnn/model.py
+++ b/src/metatrain/experimental/soap_bpnn/model.py
@@ -287,7 +287,7 @@ def forward(
     def load_checkpoint(cls, path: Union[str, Path]) -> "SoapBpnn":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
 
diff --git a/src/metatrain/experimental/soap_bpnn/trainer.py b/src/metatrain/experimental/soap_bpnn/trainer.py
index 2a2967fa6..4a342b158 100644
--- a/src/metatrain/experimental/soap_bpnn/trainer.py
+++ b/src/metatrain/experimental/soap_bpnn/trainer.py
@@ -31,13 +31,6 @@
 logger = logging.getLogger(__name__)
 
 
-# Filter out the second derivative and device warnings from rascaline-torch
-warnings.filterwarnings("ignore", category=UserWarning, message="second derivative")
-warnings.filterwarnings(
-    "ignore", category=UserWarning, message="Systems data is on device"
-)
-
-
 class Trainer:
     def __init__(self, train_hypers):
         self.hypers = train_hypers
@@ -54,6 +47,17 @@ def train(
         val_datasets: List[Union[Dataset, torch.utils.data.Subset]],
         checkpoint_dir: str,
     ):
+        # Filter out the second derivative and device warnings from rascaline
+        warnings.filterwarnings(action="ignore", message="Systems data is on device")
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to positions are not implemented",
+        )
+        warnings.filterwarnings(
+            action="ignore",
+            message="second derivatives with respect to cell matrix",
+        )
+
         assert dtype in SoapBpnn.__supported_dtypes__
 
         is_distributed = self.hypers["distributed"]
@@ -290,6 +294,7 @@ def train(
            targets = average_by_num_atoms(targets, systems, per_structure_targets)
 
            train_loss_batch = loss_fn(predictions, targets)
+
            train_loss_batch.backward()
            optimizer.step()
 
@@ -409,7 +414,7 @@ def save_checkpoint(self, model, path: Union[str, Path]):
     def load_checkpoint(cls, path: Union[str, Path], train_hypers) -> "Trainer":
 
         # Load the checkpoint
-        checkpoint = torch.load(path)
+        checkpoint = torch.load(path, weights_only=False)
         model_hypers = checkpoint["model_hypers"]
         model_state_dict = checkpoint["model_state_dict"]
         epoch = checkpoint["epoch"]
diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py
index 854559339..acecceea1 100644
--- a/tests/cli/test_train_model.py
+++ b/tests/cli/test_train_model.py
@@ -449,8 +449,8 @@ def test_model_consistency_with_seed(options, monkeypatch, tmp_path, seed):
     train_model(options, output="model2.pt")
 
-    m1 = torch.load("model1.ckpt")
-    m2 = torch.load("model2.ckpt")
+    m1 = torch.load("model1.ckpt", weights_only=False)
+    m2 = torch.load("model2.ckpt", weights_only=False)
 
     for i in m1["model_state_dict"]:
         tensor1 = m1["model_state_dict"][i]