diff --git a/src/metatensor/models/soap_bpnn/__init__.py b/src/metatensor/models/soap_bpnn/__init__.py index 4007f94ad..7b066fa9c 100644 --- a/src/metatensor/models/soap_bpnn/__init__.py +++ b/src/metatensor/models/soap_bpnn/__init__.py @@ -1,12 +1,2 @@ -from .model import Model # noqa: F401 -from .train import train # noqa: F401 - -from metatensor.models import ARCHITECTURE_CONFIG_PATH -from omegaconf import OmegaConf - -DEAFAULT_HYPERS = OmegaConf.to_container( - OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") -) - -DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"] -DEFAULT_TRAIN_HYPERS = DEAFAULT_HYPERS["train"] +from .model import Model, DEFAULT_MODEL_HYPERS # noqa: F401 +from .train import train, DEFAULT_TRAINING_HYPERS # noqa: F401 diff --git a/src/metatensor/models/soap_bpnn/model.py b/src/metatensor/models/soap_bpnn/model.py index 70eedc235..a3bdf1af3 100644 --- a/src/metatensor/models/soap_bpnn/model.py +++ b/src/metatensor/models/soap_bpnn/model.py @@ -4,11 +4,18 @@ import rascaline.torch import torch from metatensor.torch import Labels, TensorBlock, TensorMap +from omegaconf import OmegaConf -from . import DEFAULT_MODEL_HYPERS +from .. import ARCHITECTURE_CONFIG_PATH from ..utils.composition import apply_composition_contribution +DEAFAULT_HYPERS = OmegaConf.to_container( + OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") +) + +DEFAULT_MODEL_HYPERS = DEAFAULT_HYPERS["model"] + ARCHITECTURE_NAME = "soap_bpnn" diff --git a/src/metatensor/models/soap_bpnn/tests/test_functionality.py b/src/metatensor/models/soap_bpnn/tests/test_functionality.py index 8aed531c1..c6c6bcc0f 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_functionality.py +++ b/src/metatensor/models/soap_bpnn/tests/test_functionality.py @@ -2,9 +2,7 @@ import rascaline.torch import torch -from metatensor.models.soap_bpnn import Model - -from . import DEAFAULT_HYPERS +from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model def test_prediction_subset(): @@ -12,7 +10,7 @@ def test_prediction_subset(): of the elements it was trained on.""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) soap_bpnn([rascaline.torch.systems_to_torch(structure)]) diff --git a/src/metatensor/models/soap_bpnn/tests/test_invariance.py b/src/metatensor/models/soap_bpnn/tests/test_invariance.py index f87b74644..87358a9ab 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_invariance.py +++ b/src/metatensor/models/soap_bpnn/tests/test_invariance.py @@ -4,16 +4,16 @@ import rascaline.torch import torch -from metatensor.models.soap_bpnn import Model +from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model -from . import DATASET_PATH, DEAFAULT_HYPERS +from . import DATASET_PATH def test_rotational_invariance(): """Tests that the model is rotationally invariant.""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) structure = ase.io.read(DATASET_PATH) original_structure = copy.deepcopy(structure) diff --git a/src/metatensor/models/soap_bpnn/tests/test_regression.py b/src/metatensor/models/soap_bpnn/tests/test_regression.py index 71d783a45..0bf946920 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_regression.py +++ b/src/metatensor/models/soap_bpnn/tests/test_regression.py @@ -2,11 +2,16 @@ import rascaline.torch import torch -from metatensor.models.soap_bpnn import Model, train +from metatensor.models.soap_bpnn import ( + DEFAULT_MODEL_HYPERS, + DEFAULT_TRAINING_HYPERS, + Model, + train, +) from metatensor.models.utils.data import Dataset from metatensor.models.utils.data.readers import read_structures, read_targets -from . import DATASET_PATH, DEAFAULT_HYPERS +from . import DATASET_PATH torch.manual_seed(0) @@ -16,7 +21,7 @@ def test_regression_init(): """Perform a regression test on the model at initialization""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) # Predict on the first fivestructures structures = ase.io.read(DATASET_PATH, ":5") @@ -42,11 +47,11 @@ def test_regression_train(): targets = read_targets(DATASET_PATH, "U0") dataset = Dataset(structures, targets) - soap_bpnn = Model(dataset.all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) + soap_bpnn = Model(dataset.all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) - hypers = DEAFAULT_HYPERS.copy() - hypers["training"]["num_epochs"] = 2 - train(soap_bpnn, dataset, hypers["training"]) + hypers = DEFAULT_TRAINING_HYPERS.copy() + hypers["num_epochs"] = 2 + train(soap_bpnn, dataset, hypers) # Predict on the first five structures output = soap_bpnn(structures[:5]) diff --git a/src/metatensor/models/soap_bpnn/tests/test_torchscript.py b/src/metatensor/models/soap_bpnn/tests/test_torchscript.py index e12fed6ef..22efa8e81 100644 --- a/src/metatensor/models/soap_bpnn/tests/test_torchscript.py +++ b/src/metatensor/models/soap_bpnn/tests/test_torchscript.py @@ -1,13 +1,11 @@ import torch -from metatensor.models.soap_bpnn import Model - -from . import DEAFAULT_HYPERS +from metatensor.models.soap_bpnn import DEFAULT_MODEL_HYPERS, Model def test_torchscript(): """Tests that the model can be jitted.""" all_species = [1, 6, 7, 8] - soap_bpnn = Model(all_species, DEAFAULT_HYPERS["model"]).to(torch.float64) + soap_bpnn = Model(all_species, DEFAULT_MODEL_HYPERS).to(torch.float64) torch.jit.script(soap_bpnn) diff --git a/src/metatensor/models/soap_bpnn/train.py b/src/metatensor/models/soap_bpnn/train.py index 91e05cae3..405de5b0d 100644 --- a/src/metatensor/models/soap_bpnn/train.py +++ b/src/metatensor/models/soap_bpnn/train.py @@ -1,18 +1,14 @@ import logging import torch -from omegaconf import OmegaConf - -from metatensor.models import ARCHITECTURE_CONFIG_PATH from ..utils.composition import calculate_composition_weights from ..utils.data import collate_fn from ..utils.model_io import save_model +from .model import DEAFAULT_HYPERS -DEFAULT_TRAINING_HYPERS = OmegaConf.to_container( - OmegaConf.load(ARCHITECTURE_CONFIG_PATH / "soap_bpnn.yaml") -)["training"] +DEFAULT_TRAINING_HYPERS = DEAFAULT_HYPERS["training"] logger = logging.getLogger(__name__)