diff --git a/docs/src/architectures/soap-bpnn.rst b/docs/src/architectures/soap-bpnn.rst index 1c87eda3f..f0715081b 100644 --- a/docs/src/architectures/soap-bpnn.rst +++ b/docs/src/architectures/soap-bpnn.rst @@ -147,10 +147,15 @@ The parameters for training are :param learning_rate: learning rate :param log_interval: number of epochs that elapse between reporting new training results :param checkpoint_interval: Interval to save a checkpoint to disk. -:param per_atom_targets: Specifies whether the model should be trained on a per-atom +:param fixed_composition_weights: allows to set fixed isolated atom energies from + outside. These are per target name and per (integer) atom type. For example, + ``fixed_composition_weights: {"energy": {1: -396.0, 6: -500.0}, "mtt::U0": {1: 0.0, + 6: 0.0}}`` sets the isolated atom energies for H (``1``) and O (``6``) to different + values for the two distinct targets. +:param per_atom_targets: specifies whether the model should be trained on a per-atom loss. In that case, the logger will also output per-atom metrics for that target. In any case, the final summary will be per-structure. -:param loss_weights: Specifies the weights to be used in the loss for each target. The +:param loss_weights: specifies the weights to be used in the loss for each target. The weights should be a dictionary of floats, one for each target. All missing targets are assigned a weight of 1.0. diff --git a/src/metatrain/experimental/soap_bpnn/model.py b/src/metatrain/experimental/soap_bpnn/model.py index 54980c0b6..a2be10d10 100644 --- a/src/metatrain/experimental/soap_bpnn/model.py +++ b/src/metatrain/experimental/soap_bpnn/model.py @@ -144,7 +144,7 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None: * (self.hypers["soap"]["max_angular"] + 1) ) - hypers_bpnn = self.hypers["bpnn"] + hypers_bpnn = {**self.hypers["bpnn"]} hypers_bpnn["input_size"] = soap_size if hypers_bpnn["layernorm"]: diff --git a/src/metatrain/experimental/soap_bpnn/schema-hypers.json b/src/metatrain/experimental/soap_bpnn/schema-hypers.json index 2ce8a1b4e..570931d49 100644 --- a/src/metatrain/experimental/soap_bpnn/schema-hypers.json +++ b/src/metatrain/experimental/soap_bpnn/schema-hypers.json @@ -120,8 +120,14 @@ "fixed_composition_weights": { "type": "object", "patternProperties": { - "^[0-9]+$": { - "type": "number" + "^.*$": { + "type": "object", + "propertyNames": { + "pattern": "^[0-9]+$" + }, + "additionalProperties": { + "type": "number" + } } }, "additionalProperties": false diff --git a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py index 45dae1a8d..25dd250b6 100644 --- a/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py +++ b/src/metatrain/experimental/soap_bpnn/tests/test_functionality.py @@ -1,11 +1,15 @@ import metatensor.torch +import pytest import torch +from jsonschema.exceptions import ValidationError from metatensor.torch.atomistic import ModelOutput, System +from omegaconf import OmegaConf from metatrain.experimental.soap_bpnn import SoapBpnn +from metatrain.utils.architectures import check_architecture_options from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict -from . import MODEL_HYPERS +from . import DEFAULT_HYPERS, MODEL_HYPERS def test_prediction_subset_elements(): @@ -195,3 +199,33 @@ def test_output_per_atom(): assert outputs["energy"].block().samples.names == ["system", "atom"] assert outputs["energy"].block().values.shape == (4, 1) + + +def test_fixed_composition_weights(): + """Tests the correctness of the json schema for fixed_composition_weights""" + + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["fixed_composition_weights"] = { + "energy": { + 1: 1.0, + 6: 0.0, + 7: 0.0, + 8: 0.0, + 9: 3000.0, + } + } + hypers = OmegaConf.create(hypers) + check_architecture_options( + name="experimental.soap_bpnn", options=OmegaConf.to_container(hypers) + ) + + +def test_fixed_composition_weights_error(): + """Test that only inputd of type Dict[str, Dict[int, float]] are allowed.""" + hypers = DEFAULT_HYPERS.copy() + hypers["training"]["fixed_composition_weights"] = {"energy": {"H": 300.0}} + hypers = OmegaConf.create(hypers) + with pytest.raises(ValidationError, match=r"'H' does not match '\^\[0-9\]\+\$'"): + check_architecture_options( + name="experimental.soap_bpnn", options=OmegaConf.to_container(hypers) + )